diff --git a/oneflow/python/nn/modules/conv.py b/oneflow/python/nn/modules/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..42098e9f2e757b5fbb05bb4cbf171ed24d56835d
--- /dev/null
+++ b/oneflow/python/nn/modules/conv.py
@@ -0,0 +1,89 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import math
+import oneflow as flow
+from oneflow.python.oneflow_export import oneflow_export
+from oneflow.python.nn.module import Module
+from oneflow.python.nn.modules.utils import _pair
+from oneflow.python.nn.common_types import _size_2_t
+from oneflow.python.nn import init
+
+
+@oneflow_export("nn.Conv2d")
+class Conv2d(Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: _size_2_t,
+ stride: _size_2_t = 1,
+ padding: _size_2_t = 0,
+ dilation: _size_2_t = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = "zeros", # TODO: refine this type
+ ):
+ super().__init__()
+
+ assert padding_mode == "zeros"
+ kernel_size = _pair(kernel_size)
+ stride = _pair(stride)
+ padding = _pair(padding)
+ dilation = _pair(dilation)
+ self.weight = flow.nn.Parameter(
+ flow.Tensor(out_channels, in_channels // groups, *kernel_size)
+ )
+ self.bias = None
+ self._bias_add_op = None
+ if bias:
+ self.bias = flow.nn.Parameter(flow.Tensor(out_channels))
+ self._bias_add_op = (
+ flow.builtin_op("bias_add")
+ .Input("a")
+ .Input("b")
+ .Output("out")
+ .Attr("axis", 1)
+ .Build()
+ )
+
+ self._op = (
+ flow.builtin_op("conv2d")
+ .Input("in")
+ .Input("weight")
+ .Attr("filters", out_channels)
+ .Attr("padding_before", padding)
+ .Attr("strides", stride)
+ .Attr("kernel_size", kernel_size)
+ .Attr("dilation_rate", dilation)
+ .Attr("groups", groups)
+ .Attr("data_format", "channels_first")
+ .Output("out")
+ .Build()
+ )
+ self.reset_parameters()
+
+ def reset_parameters(self) -> None:
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+ if self.bias is not None:
+ fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
+ bound = 1 / math.sqrt(fan_in)
+ init.uniform_(self.bias, -bound, bound)
+
+ def forward(self, x):
+ res = self._op(x, self.weight)[0]
+ if self._bias_add_op is not None:
+ res = self._bias_add_op(res, self.bias)[0]
+ return res
diff --git a/oneflow/python/test/modules/test_conv.py b/oneflow/python/test/modules/test_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..36253ff01af82a5b7dab50cab26ce393eb7647cb
--- /dev/null
+++ b/oneflow/python/test/modules/test_conv.py
@@ -0,0 +1,787 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+import numpy as np
+import oneflow as flow
+
+test_conv2d_weight = np.array(
+ [
+ [
+ [
+ [0.8586049675941467, -0.2279418259859085, 0.2013147622346878],
+ [0.35005471110343933, 0.5360521078109741, 1.5194443464279175],
+ [1.9040879011154175, -1.5734431743621826, -0.14007866382598877],
+ ]
+ ],
+ [
+ [
+ [0.29670074582099915, 1.3111951351165771, 0.5035904049873352],
+ [-1.1894450187683105, -0.5502137541770935, -1.591875672340393],
+ [-1.1081947088241577, 0.07872020453214645, -0.9185634255409241],
+ ]
+ ],
+ [
+ [
+ [-0.7457143664360046, -1.2080862522125244, 1.8140212297439575],
+ [-1.5227429866790771, -2.515244960784912, -1.3549325466156006],
+ [-0.9574840068817139, -0.7248556613922119, 1.1119636297225952],
+ ]
+ ],
+ ]
+)
+test_conv2d_data = np.array(
+ [
+ [
+ [
+ [
+ 1.1630785465240479,
+ 0.4838046133518219,
+ 0.299563467502594,
+ 0.15302546322345734,
+ -1.168814778327942,
+ ],
+ [
+ 1.5580710172653198,
+ -0.5459445714950562,
+ -2.3556296825408936,
+ 0.5414402484893799,
+ 2.678506374359131,
+ ],
+ [
+ 1.2546343803405762,
+ -0.5487740635871887,
+ -0.6810643672943115,
+ -0.13531559705734253,
+ 0.37723132967948914,
+ ],
+ [
+ 0.41016456484794617,
+ 0.5712682008743286,
+ -2.757962703704834,
+ 1.0762799978256226,
+ -0.6141325235366821,
+ ],
+ [
+ 1.830764889717102,
+ -1.1468064785003662,
+ 0.053837940096855164,
+ -2.5074806213378906,
+ -0.5916498899459839,
+ ],
+ ]
+ ]
+ ]
+)
+
+test_conv2d_output = np.array(
+ [
+ [
+ [
+ [0.9699610471725464, -0.20758534967899323, 2.3857712745666504],
+ [0.3666309118270874, 4.690882682800293, -8.203354835510254],
+ [2.6072847843170166, -1.9033538103103638, 2.331153154373169],
+ ],
+ [
+ [2.519343852996826, 2.3757898807525635, -1.6613528728485107],
+ [0.5777544379234314, -3.5739502906799316, 5.349126815795898],
+ [0.729295015335083, 1.5791023969650269, 3.7627718448638916],
+ ],
+ [
+ [-0.27685487270355225, 6.446267127990723, -2.762883424758911],
+ [-8.25644588470459, 9.616064071655273, 8.005367279052734],
+ [-0.6944921016693115, 3.866114854812622, 4.788446426391602],
+ ],
+ ]
+ ]
+)
+
+test_conv2d_with_bias_weight = np.array(
+ [
+ [
+ [
+ [1.8271433115005493, -1.0446699857711792, 1.0062190294265747],
+ [0.5174201130867004, -0.806931734085083, 1.3769007921218872],
+ [0.205885112285614, 0.9943519234657288, -0.23580588400363922],
+ ]
+ ],
+ [
+ [
+ [0.29881811141967773, -1.9982075691223145, 0.3511354625225067],
+ [-0.7644741535186768, 1.2594351768493652, -0.9629734754562378],
+ [0.5080506205558777, 0.7561734318733215, 1.6839302778244019],
+ ]
+ ],
+ [
+ [
+ [1.2573646306991577, 0.13123232126235962, 1.6403018236160278],
+ [-1.2138012647628784, 2.399970531463623, -0.38509097695350647],
+ [-0.9878040552139282, 0.9585888385772705, -1.4976465702056885],
+ ]
+ ],
+ ]
+)
+test_conv2d_with_bias_bias = np.array(
+ [0.6605162620544434, -0.18903568387031555, -0.27302607893943787]
+)
+test_conv2d_with_bias_data = np.array(
+ [
+ [
+ [
+ [
+ -0.47827261686325073,
+ -1.1739492416381836,
+ -0.7921845316886902,
+ 0.9321041703224182,
+ -3.1557741165161133,
+ ],
+ [
+ 2.1935296058654785,
+ -0.5385921001434326,
+ -0.8611332774162292,
+ -1.881519079208374,
+ -0.7205708026885986,
+ ],
+ [
+ -0.35601571202278137,
+ -0.15963983535766602,
+ 1.797447681427002,
+ 0.19594945013523102,
+ -1.7376397848129272,
+ ],
+ [
+ 0.047347065061330795,
+ 0.14580930769443512,
+ 0.32604914903640747,
+ 0.4578782916069031,
+ -0.8942581415176392,
+ ],
+ [
+ 0.49383941292762756,
+ -0.9043426513671875,
+ -1.2140793800354004,
+ 2.1564064025878906,
+ 1.0938222408294678,
+ ],
+ ]
+ ]
+ ]
+)
+
+test_conv2d_with_bias_output = np.array(
+ [
+ [
+ [
+ [-0.05607491731643677, -0.185230553150177, -3.8808679580688477],
+ [6.861937046051025, -2.3341472148895264, -0.5597308874130249],
+ [1.8299254179000854, -2.770848274230957, 2.1958212852478027],
+ ],
+ [
+ [2.9348952770233154, 4.117504119873047, -6.278541088104248],
+ [0.2638452351093292, 3.998856782913208, 2.612290620803833],
+ [-1.9891828298568726, -1.6476304531097412, 3.39066219329834],
+ ],
+ [
+ [-8.44466781616211, 0.5747121572494507, -8.501373291015625],
+ [-0.036642804741859436, -0.23458999395370483, -2.370849370956421],
+ [2.8372013568878174, -2.987276077270508, 1.8382092714309692],
+ ],
+ ]
+ ]
+)
+test_conv2d_group_weight = np.array(
+ [
+ [
+ [
+ [-0.7248556613922119, 1.1119636297225952, -0.47827261686325073],
+ [-1.1739492416381836, -0.7921845316886902, 0.9321041703224182],
+ [-3.1557741165161133, 2.1935296058654785, -0.5385921001434326],
+ ]
+ ],
+ [
+ [
+ [-0.8611332774162292, -1.881519079208374, -0.7205708026885986],
+ [-0.35601571202278137, -0.15963983535766602, 1.797447681427002],
+ [0.19594945013523102, -1.7376397848129272, 0.047347065061330795],
+ ]
+ ],
+ ]
+)
+
+test_conv2d_group_data = np.array(
+ [
+ [
+ [
+ [
+ 1.1630785465240479,
+ 0.4838046133518219,
+ 0.299563467502594,
+ 0.15302546322345734,
+ -1.168814778327942,
+ ],
+ [
+ 1.5580710172653198,
+ -0.5459445714950562,
+ -2.3556296825408936,
+ 0.5414402484893799,
+ 2.678506374359131,
+ ],
+ [
+ 1.2546343803405762,
+ -0.5487740635871887,
+ -0.6810643672943115,
+ -0.13531559705734253,
+ 0.37723132967948914,
+ ],
+ [
+ 0.41016456484794617,
+ 0.5712682008743286,
+ -2.757962703704834,
+ 1.0762799978256226,
+ -0.6141325235366821,
+ ],
+ [
+ 1.830764889717102,
+ -1.1468064785003662,
+ 0.053837940096855164,
+ -2.5074806213378906,
+ -0.5916498899459839,
+ ],
+ ],
+ [
+ [
+ 0.8586049675941467,
+ -0.2279418259859085,
+ 0.2013147622346878,
+ 0.35005471110343933,
+ 0.5360521078109741,
+ ],
+ [
+ 1.5194443464279175,
+ 1.9040879011154175,
+ -1.5734431743621826,
+ -0.14007866382598877,
+ 0.29670074582099915,
+ ],
+ [
+ 1.3111951351165771,
+ 0.5035904049873352,
+ -1.1894450187683105,
+ -0.5502137541770935,
+ -1.591875672340393,
+ ],
+ [
+ -1.1081947088241577,
+ 0.07872020453214645,
+ -0.9185634255409241,
+ -0.7457143664360046,
+ -1.2080862522125244,
+ ],
+ [
+ 1.8140212297439575,
+ -1.5227429866790771,
+ -2.515244960784912,
+ -1.3549325466156006,
+ -0.9574840068817139,
+ ],
+ ],
+ ]
+ ]
+)
+test_conv2d_group_output = np.array(
+ [
+ [
+ [
+ [-8.836943626403809, 3.2316627502441406, 6.994439601898193],
+ [-0.8386597037315369, -9.857108116149902, 13.68197250366211],
+ [-13.020713806152344, 7.310227870941162, -3.3760271072387695],
+ ],
+ [
+ [-4.803101539611816, 1.026240587234497, 0.5452112555503845],
+ [-6.839838027954102, 2.0195930004119873, 0.11328654736280441],
+ [0.393694669008255, 4.987061023712158, 3.297354221343994],
+ ],
+ ]
+ ]
+)
+test_conv2d_padding_weight = np.array(
+ [
+ [
+ [
+ [0.8586049675941467, -0.2279418259859085, 0.2013147622346878],
+ [0.35005471110343933, 0.5360521078109741, 1.5194443464279175],
+ [1.9040879011154175, -1.5734431743621826, -0.14007866382598877],
+ ]
+ ]
+ ]
+)
+test_conv2d_padding_data = np.array(
+ [
+ [
+ [
+ [
+ 1.1630785465240479,
+ 0.4838046133518219,
+ 0.299563467502594,
+ 0.15302546322345734,
+ -1.168814778327942,
+ ],
+ [
+ 1.5580710172653198,
+ -0.5459445714950562,
+ -2.3556296825408936,
+ 0.5414402484893799,
+ 2.678506374359131,
+ ],
+ [
+ 1.2546343803405762,
+ -0.5487740635871887,
+ -0.6810643672943115,
+ -0.13531559705734253,
+ 0.37723132967948914,
+ ],
+ [
+ 0.41016456484794617,
+ 0.5712682008743286,
+ -2.757962703704834,
+ 1.0762799978256226,
+ -0.6141325235366821,
+ ],
+ [
+ 1.830764889717102,
+ -1.1468064785003662,
+ 0.053837940096855164,
+ -2.5074806213378906,
+ -0.5916498899459839,
+ ],
+ ]
+ ]
+ ]
+)
+test_conv2d_padding_output = np.array(
+ [
+ [
+ [
+ [
+ 1.5489805936813354,
+ -1.0164761543273926,
+ 5.277345657348633,
+ 3.153532028198242,
+ -7.301508903503418,
+ -3.7565059661865234,
+ 4.690962314605713,
+ ],
+ [
+ 2.425799608230591,
+ -2.0592665672302246,
+ 0.9699610471725464,
+ -0.20758534967899323,
+ 2.3857712745666504,
+ 1.1719579696655273,
+ 0.6523551940917969,
+ ],
+ [
+ 2.1625545024871826,
+ -1.3517316579818726,
+ 0.3666309118270874,
+ 4.690882682800293,
+ -8.203354835510254,
+ 3.0248217582702637,
+ 1.2624683380126953,
+ ],
+ [
+ 0.6193475723266602,
+ -2.0285415649414062,
+ 2.6072847843170166,
+ -1.9033538103103638,
+ 2.331153154373169,
+ -3.998155355453491,
+ -1.0176407098770142,
+ ],
+ [
+ 2.8643176555633545,
+ -0.7396122217178345,
+ -0.2253415733575821,
+ -2.846742630004883,
+ -4.961236476898193,
+ -0.1308247298002243,
+ -0.7344070672988892,
+ ],
+ ]
+ ]
+ ]
+)
+
+test_conv2d_stride_weight = np.array(
+ [
+ [
+ [
+ [0.8586049675941467, -0.2279418259859085, 0.2013147622346878],
+ [0.35005471110343933, 0.5360521078109741, 1.5194443464279175],
+ [1.9040879011154175, -1.5734431743621826, -0.14007866382598877],
+ ]
+ ]
+ ]
+)
+test_conv2d_stride_data = np.array(
+ [
+ [
+ [
+ [
+ 1.1630785465240479,
+ 0.4838046133518219,
+ 0.299563467502594,
+ 0.15302546322345734,
+ -1.168814778327942,
+ ],
+ [
+ 1.5580710172653198,
+ -0.5459445714950562,
+ -2.3556296825408936,
+ 0.5414402484893799,
+ 2.678506374359131,
+ ],
+ [
+ 1.2546343803405762,
+ -0.5487740635871887,
+ -0.6810643672943115,
+ -0.13531559705734253,
+ 0.37723132967948914,
+ ],
+ [
+ 0.41016456484794617,
+ 0.5712682008743286,
+ -2.757962703704834,
+ 1.0762799978256226,
+ -0.6141325235366821,
+ ],
+ [
+ 1.830764889717102,
+ -1.1468064785003662,
+ 0.053837940096855164,
+ -2.5074806213378906,
+ -0.5916498899459839,
+ ],
+ ]
+ ]
+ ]
+)
+test_conv2d_stride_output = np.array(
+ [
+ [
+ [
+ [-1.0164761543273926, -7.301508903503418],
+ [-1.3517316579818726, -8.203354835510254],
+ [-0.7396122217178345, -4.961236476898193],
+ ]
+ ]
+ ]
+)
+test_conv2d_kernel_weight = np.array(
+ [
+ [
+ [
+ [
+ -0.9574840068817139,
+ -0.7248556613922119,
+ 1.1119636297225952,
+ -0.47827261686325073,
+ -1.1739492416381836,
+ ],
+ [
+ -0.7921845316886902,
+ 0.9321041703224182,
+ -3.1557741165161133,
+ 2.1935296058654785,
+ -0.5385921001434326,
+ ],
+ [
+ -0.8611332774162292,
+ -1.881519079208374,
+ -0.7205708026885986,
+ -0.35601571202278137,
+ -0.15963983535766602,
+ ],
+ ]
+ ]
+ ]
+)
+test_conv2d_kernel_data = np.array(
+ [
+ [
+ [
+ [
+ 1.1630785465240479,
+ 0.4838046133518219,
+ 0.299563467502594,
+ 0.15302546322345734,
+ -1.168814778327942,
+ 1.5580710172653198,
+ -0.5459445714950562,
+ ],
+ [
+ -2.3556296825408936,
+ 0.5414402484893799,
+ 2.678506374359131,
+ 1.2546343803405762,
+ -0.5487740635871887,
+ -0.6810643672943115,
+ -0.13531559705734253,
+ ],
+ [
+ 0.37723132967948914,
+ 0.41016456484794617,
+ 0.5712682008743286,
+ -2.757962703704834,
+ 1.0762799978256226,
+ -0.6141325235366821,
+ 1.830764889717102,
+ ],
+ [
+ -1.1468064785003662,
+ 0.053837940096855164,
+ -2.5074806213378906,
+ -0.5916498899459839,
+ 0.8586049675941467,
+ -0.2279418259859085,
+ 0.2013147622346878,
+ ],
+ [
+ 0.35005471110343933,
+ 0.5360521078109741,
+ 1.5194443464279175,
+ 1.9040879011154175,
+ -1.5734431743621826,
+ -0.14007866382598877,
+ 0.29670074582099915,
+ ],
+ [
+ 1.3111951351165771,
+ 0.5035904049873352,
+ -1.1894450187683105,
+ -0.5502137541770935,
+ -1.591875672340393,
+ -1.1081947088241577,
+ 0.07872020453214645,
+ ],
+ [
+ -0.9185634255409241,
+ -0.7457143664360046,
+ -1.2080862522125244,
+ 1.8140212297439575,
+ -1.5227429866790771,
+ -2.515244960784912,
+ -1.3549325466156006,
+ ],
+ ]
+ ]
+ ]
+)
+
+test_conv2d_kernel_output = np.array(
+ [
+ [
+ [
+ [-3.5647754669189453, -4.234736919403076, 1.4046944379806519],
+ [-0.6964312791824341, 16.42838478088379, -9.649789810180664],
+ [4.312150478363037, -6.283960819244385, -4.8443922996521],
+ [-2.772286891937256, -4.483709812164307, 12.315184593200684],
+ [7.39893913269043, 1.305102825164795, -2.049992561340332],
+ ]
+ ]
+ ]
+)
+test_conv2d_dilation_weight = np.array(
+ [
+ [
+ [
+ [-0.9574840068817139, -0.7248556613922119, 1.1119636297225952],
+ [-0.47827261686325073, -1.1739492416381836, -0.7921845316886902],
+ [0.9321041703224182, -3.1557741165161133, 2.1935296058654785],
+ ]
+ ]
+ ]
+)
+test_conv2d_dilation_data = np.array(
+ [
+ [
+ [
+ [
+ 1.1630785465240479,
+ 0.4838046133518219,
+ 0.299563467502594,
+ 0.15302546322345734,
+ -1.168814778327942,
+ 1.5580710172653198,
+ -0.5459445714950562,
+ ],
+ [
+ -2.3556296825408936,
+ 0.5414402484893799,
+ 2.678506374359131,
+ 1.2546343803405762,
+ -0.5487740635871887,
+ -0.6810643672943115,
+ -0.13531559705734253,
+ ],
+ [
+ 0.37723132967948914,
+ 0.41016456484794617,
+ 0.5712682008743286,
+ -2.757962703704834,
+ 1.0762799978256226,
+ -0.6141325235366821,
+ 1.830764889717102,
+ ],
+ [
+ -1.1468064785003662,
+ 0.053837940096855164,
+ -2.5074806213378906,
+ -0.5916498899459839,
+ 0.8586049675941467,
+ -0.2279418259859085,
+ 0.2013147622346878,
+ ],
+ [
+ 0.35005471110343933,
+ 0.5360521078109741,
+ 1.5194443464279175,
+ 1.9040879011154175,
+ -1.5734431743621826,
+ -0.14007866382598877,
+ 0.29670074582099915,
+ ],
+ [
+ 1.3111951351165771,
+ 0.5035904049873352,
+ -1.1894450187683105,
+ -0.5502137541770935,
+ -1.591875672340393,
+ -1.1081947088241577,
+ 0.07872020453214645,
+ ],
+ [
+ -0.9185634255409241,
+ -0.7457143664360046,
+ -1.2080862522125244,
+ 1.8140212297439575,
+ -1.5227429866790771,
+ -2.515244960784912,
+ -1.3549325466156006,
+ ],
+ ]
+ ]
+ ]
+)
+test_conv2d_dilation_output = np.array(
+ [[[[-5.2563982009887695], [5.410353183746338], [-8.517012596130371]]]]
+)
+
+
+def _test_conv2d(test_case, conv, data, output, weight, bias=None):
+ x = flow.Tensor(data)
+ conv.weight = flow.nn.Parameter(flow.Tensor(weight))
+ if bias is not None:
+ conv.bias = flow.nn.Parameter(flow.Tensor(bias))
+ of_out = conv(x)
+ test_case.assertTrue(np.allclose(of_out.numpy(), output, rtol=1e-4, atol=1e-8))
+
+
+@unittest.skipIf(
+ not flow.unittest.env.eager_execution_enabled(),
+ ".numpy() doesn't work in lazy mode",
+)
+class TestConv2d(flow.unittest.TestCase):
+ def test_conv2d_default_init(test_case):
+ conv = flow.nn.Conv2d(1, 1, (3, 3), bias=True)
+ test_case.assertTrue(
+ not np.allclose(
+ conv.weight.numpy(), np.zeros((1, 1, 3, 3)), rtol=1e-9, atol=1e-10
+ )
+ )
+ test_case.assertTrue(
+ not np.allclose(conv.bias.numpy(), np.zeros((1,)), rtol=1e-9, atol=1e-10)
+ )
+
+ def test_conv2d(test_case):
+ conv = flow.nn.Conv2d(1, 3, (3, 3), bias=False)
+ _test_conv2d(
+ test_case, conv, test_conv2d_data, test_conv2d_output, test_conv2d_weight
+ )
+
+ def test_conv2d_with_bias(test_case):
+ conv = flow.nn.Conv2d(1, 3, (3, 3), bias=True)
+ _test_conv2d(
+ test_case,
+ conv,
+ test_conv2d_with_bias_data,
+ test_conv2d_with_bias_output,
+ test_conv2d_with_bias_weight,
+ test_conv2d_with_bias_bias,
+ )
+
+ def test_conv2d_group(test_case):
+ conv = flow.nn.Conv2d(2, 2, (3, 3), groups=2, bias=False)
+ _test_conv2d(
+ test_case,
+ conv,
+ test_conv2d_group_data,
+ test_conv2d_group_output,
+ test_conv2d_group_weight,
+ )
+
+ def test_conv2d_padding(test_case):
+ conv = flow.nn.Conv2d(1, 1, (3, 3), padding=(1, 2), bias=False)
+ _test_conv2d(
+ test_case,
+ conv,
+ test_conv2d_padding_data,
+ test_conv2d_padding_output,
+ test_conv2d_padding_weight,
+ )
+
+ def test_conv2d_stride(test_case):
+ conv = flow.nn.Conv2d(1, 1, (3, 3), padding=(1, 1), stride=(2, 3), bias=False)
+ _test_conv2d(
+ test_case,
+ conv,
+ test_conv2d_stride_data,
+ test_conv2d_stride_output,
+ test_conv2d_stride_weight,
+ )
+
+ def test_conv2d_kernel(test_case):
+ conv = flow.nn.Conv2d(1, 1, (3, 5), bias=False)
+ _test_conv2d(
+ test_case,
+ conv,
+ test_conv2d_kernel_data,
+ test_conv2d_kernel_output,
+ test_conv2d_kernel_weight,
+ )
+
+ def test_conv2d_dilation(test_case):
+ conv = flow.nn.Conv2d(1, 1, (3, 3), dilation=(2, 3), bias=False)
+ _test_conv2d(
+ test_case,
+ conv,
+ test_conv2d_dilation_data,
+ test_conv2d_dilation_output,
+ test_conv2d_dilation_weight,
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/oneflow/user/kernels/stateful_local_opkernel.cpp b/oneflow/user/kernels/stateful_local_opkernel.cpp
index 8adc58a6f6b57bb05528721e3067d4992c4b8574..a59d6040885a519f779cb90b4746d01d9ee5ef00 100644
--- a/oneflow/user/kernels/stateful_local_opkernel.cpp
+++ b/oneflow/user/kernels/stateful_local_opkernel.cpp
@@ -101,7 +101,6 @@ user_op::TensorDesc* ZeroCopyBaseContext::TensorDesc4ArgNameAndIndex(const std::
if (i >= 0) { return input_tensor_desc_views_.at(i).get(); }
i = TryGetTensorTupleIndex(arg_name2bn_index2output_tensor_tuple_index_, arg_name, index);
if (i >= 0) { return output_tensor_desc_views_.at(i).get(); }
- LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found";
return nullptr;
}
@@ -112,7 +111,6 @@ user_op::Tensor* ZeroCopyBaseContext::Tensor4ArgNameAndIndex(const std::string&
i = TryGetTensorTupleIndex(arg_name2bn_index2output_tensor_tuple_index_, arg_name, index);
if (i >= 0) { return output_tensor_views_.at(i).get(); }
if (arg_name == "tmp_buffer" && index == 0) { return CHECK_NOTNULL(tmp_buffer_view_.get()); }
- LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found";
return nullptr;
}
diff --git a/oneflow/user/kernels/stateful_local_opkernel.h b/oneflow/user/kernels/stateful_local_opkernel.h
index b21ac222a114e75939f5f97b616bb92a569b9dc0..ed75265c123c3fe236ae7097a81cf4983fef23a6 100644
--- a/oneflow/user/kernels/stateful_local_opkernel.h
+++ b/oneflow/user/kernels/stateful_local_opkernel.h
@@ -156,13 +156,13 @@ class LocalUserOpInferContext : public user_op::InferContext {
user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) override;
Shape* Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
- return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_shape();
+ return NonNullTensorDesc4ArgNameAndIndex(arg_name, index)->mut_shape();
}
DataType* Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
- return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_data_type();
+ return NonNullTensorDesc4ArgNameAndIndex(arg_name, index)->mut_data_type();
}
bool* IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
- return TensorDesc4ArgNameAndIndex(arg_name, index)->mut_is_dynamic();
+ return NonNullTensorDesc4ArgNameAndIndex(arg_name, index)->mut_is_dynamic();
}
const ArgVec& inputs() const override { return zero_copy_base_ctx_.inputs(); }
@@ -187,6 +187,12 @@ class LocalUserOpInferContext : public user_op::InferContext {
void Update(EagerBlobObjectList inputs, EagerBlobObjectList outputs);
private:
+ user_op::TensorDesc* NonNullTensorDesc4ArgNameAndIndex(const std::string& arg_name,
+ int32_t index) {
+ user_op::TensorDesc* tensor_desc = TensorDesc4ArgNameAndIndex(arg_name, index);
+ if (!tensor_desc) { LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found"; }
+ return tensor_desc;
+ }
const user_op::UserOpConfWrapper& user_op_conf() const override { return *user_op_conf_; }
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
const std::string& attr_name) const override {