Skip to content
Snippets Groups Projects
conv_op.cpp 18.8 KiB
Newer Older
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
Niu Chong's avatar
Niu Chong committed
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/ops/nn_util.h"
Niu Chong's avatar
Niu Chong committed

namespace oneflow {

namespace {

template<size_t NDims>
Maybe<void> InferTensorDesc4Conv(user_op::InferContext* ctx) {
  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
  CHECK_EQ_OR_RETURN(NDims + 2, in.shape().NumAxes());
Niu Chong's avatar
Niu Chong committed

OuYang Yu's avatar
OuYang Yu committed
  auto data_format = ctx->Attr<std::string>("data_format");
  auto kernel_size = ctx->Attr<std::vector<int32_t>>("kernel_size");
Niu Chong's avatar
Niu Chong committed
  CHECK_EQ_OR_RETURN(NDims, kernel_size.size());
OuYang Yu's avatar
OuYang Yu committed
  int32_t filters = ctx->Attr<int32_t>("filters");
Niu Chong's avatar
Niu Chong committed
  size_t idx_offset = IdxOffset(data_format);
  {
guo ran's avatar
guo ran committed
    const auto& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before");
OuYang Yu's avatar
OuYang Yu committed
    auto dilation_rate = ctx->Attr<std::vector<int32_t>>("dilation_rate");
    auto strides = ctx->Attr<std::vector<int32_t>>("strides");
Niu Chong's avatar
Niu Chong committed
    CHECK_EQ_OR_RETURN(NDims, dilation_rate.size());
    CHECK_EQ_OR_RETURN(NDims, strides.size());
guo ran's avatar
guo ran committed
    CHECK_EQ_OR_RETURN(NDims, padding_before.size());
Niu Chong's avatar
Niu Chong committed

    user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0);
Niu Chong's avatar
Niu Chong committed
    DimVector out_shape(NDims + 2);
    out_shape.at(0) = in.shape().At(0);
JamiePlur's avatar
JamiePlur committed
    const size_t c_dim = data_format == "channels_first" ? 1 : NDims + 1;
    out_shape.at(c_dim) = filters;
Niu Chong's avatar
Niu Chong committed
    for (int32_t i = 0; i < NDims; ++i) {
      JUST(CalcConvOut(in.shape().At(idx_offset + i), kernel_size.at(i), dilation_rate.at(i),
                       strides.at(i), padding_before.at(i), &out_shape.at(idx_offset + i)));
Niu Chong's avatar
Niu Chong committed
    }
    *out->mut_is_dynamic() = in.is_dynamic();
Niu Chong's avatar
Niu Chong committed
    *out->mut_shape() = Shape(out_shape);
  }

  {
OuYang Yu's avatar
OuYang Yu committed
    int32_t groups = ctx->Attr<int32_t>("groups");
Niu Chong's avatar
Niu Chong committed
    CHECK_GT_OR_RETURN(groups, 0);
    CHECK_LE_OR_RETURN(groups, filters);
    CHECK_EQ_OR_RETURN(filters % groups, 0);

    DimVector weight_shape(in.shape().dim_vec());
Niu Chong's avatar
Niu Chong committed
    weight_shape.at(0) = filters;
    if (data_format == "channels_first") {
      CHECK_LE_OR_RETURN(groups, weight_shape.at(1));
      CHECK_EQ_OR_RETURN(weight_shape.at(1) % groups, 0);
      weight_shape.at(1) = weight_shape.at(1) / groups;
    } else if (data_format == "channels_last") {
      CHECK_LE_OR_RETURN(groups, weight_shape.at(NDims + 1));
      CHECK_EQ_OR_RETURN(weight_shape.at(NDims + 1) % groups, 0);
      weight_shape.at(NDims + 1) = weight_shape.at(NDims + 1) / groups;
    } else {
      UNIMPLEMENTED_THEN_RETURN();
    }
    for (size_t i = 0; i < NDims; ++i) { weight_shape.at(idx_offset + i) = kernel_size.at(i); }

    const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0);
    CHECK_EQ_OR_RETURN(weight.shape(), Shape(weight_shape));
  }

  bool has_bias = ctx->has_input("bias", 0);
  if (has_bias) {
    const user_op::TensorDesc& bias = ctx->InputTensorDesc("bias", 0);
    CHECK_EQ_OR_RETURN(bias.shape(), Shape({filters}));
Niu Chong's avatar
Niu Chong committed
  }

  return Maybe<void>::Ok();
}

Maybe<void> GetSbpSignatures4Conv(user_op::SbpContext* ctx) {
  // TODO(niuchong) : handle bias_multiplier
  bool has_bias = false;
  for (const auto& pair : ctx->inputs()) {
    if (pair.first == "bias") {
      CHECK_EQ_OR_RETURN(0, pair.second);
      has_bias = true;
      break;
    }
  }

  if (has_bias) {
    ctx->NewBuilder()
        .Split(user_op::OpArg("in", 0), 0)
        .Broadcast(user_op::OpArg("weight", 0))
        .Broadcast(user_op::OpArg("bias", 0))
        .Split(user_op::OpArg("out", 0), 0)
        .Build();
  } else {
    ctx->NewBuilder()
        .Split(user_op::OpArg("in", 0), 0)
        .Broadcast(user_op::OpArg("weight", 0))
        .Split(user_op::OpArg("out", 0), 0)
        .Build();
  }
  return Maybe<void>::Ok();
}

template<size_t NDims>
Maybe<void> CheckAttr(const user_op::UserOpDefWrapper& def,
                      const user_op::UserOpConfWrapper& conf) {
  bool is_checked = true;
  std::stringstream err;
  err << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name() << ": ";

  const auto& data_format = conf.attr<std::string>("data_format");
  if (!(data_format == "channels_first" || data_format == "channels_last")) {
    err << " data_format:" << data_format;
    is_checked = false;
  }

  if (NDims != 0) {
guo ran's avatar
guo ran committed
    const auto& padding_before = conf.attr<std::vector<int32_t>>("padding_before");
    if (padding_before.size() != NDims) {
      err << " padding_before: number of element is " << padding_before.size();
      is_checked = false;
    }

Niu Chong's avatar
Niu Chong committed
    const auto& kernel_size = conf.attr<std::vector<int32_t>>("kernel_size");
    if (kernel_size.size() != NDims) {
      err << " kernel_size: number of element is " << kernel_size.size();
      is_checked = false;
    }

    const auto& strides = conf.attr<std::vector<int32_t>>("strides");
    if (strides.size() != NDims) {
      err << " strides: number of element is " << strides.size();
      is_checked = false;
    }

    const auto& dilation_rate = conf.attr<std::vector<int32_t>>("dilation_rate");
    if (dilation_rate.size() != NDims) {
      err << " dilation_rate: number of element is " << dilation_rate.size();
      is_checked = false;
    }
  }

  if (is_checked) {
    return Maybe<void>::Ok();
  } else {
    return oneflow::Error::CheckFailedError() << err.str();
Maybe<void> GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
guo ran's avatar
guo ran committed
  const auto& padding_before = op.attr<std::vector<int32_t>>("padding_before");
Niu Chong's avatar
Niu Chong committed
  std::string data_format = op.attr<std::string>("data_format");
  std::vector<int32_t> kernel_size = op.attr<std::vector<int32_t>>("kernel_size");
  std::vector<int32_t> strides = op.attr<std::vector<int32_t>>("strides");
  std::vector<int32_t> dilation_rate = op.attr<std::vector<int32_t>>("dilation_rate");
  int32_t groups = op.attr<int32_t>("groups");

  int32_t ndims = kernel_size.size();
  CHECK_EQ_OR_RETURN(ndims, strides.size());
  CHECK_EQ_OR_RETURN(ndims, dilation_rate.size());
Niu Chong's avatar
Niu Chong committed

  if (op.user_op_conf().has_input("bias", 0)) {
    if (op.NeedGenGradTensor4OpInput("bias", 0)) {
      auto bias_grad_op =
          user_op::UserOpConfWrapperBuilder("System-AutoGrad-" + op.op_name() + "-BiasGrad")
              .Op("conv_bias_grad")
              .Input("dy", op.GetGradTensorWithOpOutput("out", 0))
              .Output("bias_diff")
              .Attr<std::string>("data_format", data_format)
              .Attr<int32_t>("num_spatial_dims", ndims)
              .Build();
      op.BindGradTensorWithOpInput(bias_grad_op.output("bias_diff", 0), "bias", 0);
      AddOp(bias_grad_op);
    }
  }

  if (op.NeedGenGradTensor4OpInput("weight", 0)) {
    auto filter_grad_op =
        user_op::UserOpConfWrapperBuilder("System-AutoGrad-" + op.op_name() + "-FilterGrad")
            .Op("conv_filter_grad")
            .Input("dy", op.GetGradTensorWithOpOutput("out", 0))
            .Input("x", op.input("in", 0))
            .Output("filter_diff")
            .Attr<int32_t>("num_spatial_dims", ndims)
guo ran's avatar
guo ran committed
            .Attr<std::vector<int32_t>>("padding_before", padding_before)
Niu Chong's avatar
Niu Chong committed
            .Attr<std::string>("data_format", data_format)
            .Attr<std::vector<int32_t>>("kernel_size", kernel_size)
            .Attr<std::vector<int32_t>>("strides", strides)
            .Attr<std::vector<int32_t>>("dilation_rate", dilation_rate)
            .Attr<int32_t>("groups", groups)
            .Build();
    op.BindGradTensorWithOpInput(filter_grad_op.output("filter_diff", 0), "weight", 0);
    AddOp(filter_grad_op);
  }

  if (op.NeedGenGradTensor4OpInput("in", 0)) {
    auto data_grad_op =
        user_op::UserOpConfWrapperBuilder("System-AutoGrad-" + op.op_name() + "-DataGrad")
            .Op("conv_data_grad")
            .Input("dy", op.GetGradTensorWithOpOutput("out", 0))
            .Input("filter", op.input("weight", 0))
            .Input("x_like", op.input("in", 0))
            .Output("dx")
            .Attr<int32_t>("num_spatial_dims", ndims)
guo ran's avatar
guo ran committed
            .Attr<std::vector<int32_t>>("padding_before", padding_before)
Niu Chong's avatar
Niu Chong committed
            .Attr<std::string>("data_format", data_format)
            .Attr<std::vector<int32_t>>("kernel_size", kernel_size)
            .Attr<std::vector<int32_t>>("strides", strides)
            .Attr<std::vector<int32_t>>("dilation_rate", dilation_rate)
            .Attr<int32_t>("groups", groups)
            .Build();
    op.BindGradTensorWithOpInput(data_grad_op.output("dx", 0), "in", 0);
    AddOp(data_grad_op);
  }
  return Maybe<void>::Ok();
Niu Chong's avatar
Niu Chong committed
}

}  // namespace

REGISTER_USER_OP("conv1d")
    .Input("in")
    .Input("weight")
    .OptionalInput("bias")
    .OptionalInput("bias_multiplier")  // cudnn conv doesn't need this
    .Output("out")
    .Attr<int32_t>("filters")
    .Attr<std::vector<int32_t>>("padding_before")
    .Attr<std::string>("data_format")
    .Attr<std::vector<int32_t>>("kernel_size")
    .Attr<std::vector<int32_t>>("strides")
    .Attr<std::vector<int32_t>>("dilation_rate")
    .Attr<int32_t>("groups", 1)
Niu Chong's avatar
Niu Chong committed
    .SetCheckAttrFn(CheckAttr<1>)
    .SetTensorDescInferFn(InferTensorDesc4Conv<1>)
    .SetGetSbpFn(GetSbpSignatures4Conv)
    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
      return Maybe<void>::Ok();
    });
Niu Chong's avatar
Niu Chong committed

REGISTER_USER_OP("conv2d")
    .Input("in")
    .Input("weight")
    .OptionalInput("bias")
    .OptionalInput("bias_multiplier")  // cudnn conv doesn't need this
    .Output("out")
    .Attr<int32_t>("filters")
    .Attr<std::vector<int32_t>>("padding_before")
    .Attr<std::string>("data_format")
    .Attr<std::vector<int32_t>>("kernel_size")
    .Attr<std::vector<int32_t>>("strides")
    .Attr<std::vector<int32_t>>("dilation_rate")
    .Attr<int32_t>("groups", 1)
Niu Chong's avatar
Niu Chong committed
    .SetCheckAttrFn(CheckAttr<2>)
    .SetTensorDescInferFn(InferTensorDesc4Conv<2>)
    .SetGetSbpFn(GetSbpSignatures4Conv)
    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
      return Maybe<void>::Ok();
    });
Niu Chong's avatar
Niu Chong committed

REGISTER_USER_OP("conv3d")
    .Input("in")
    .Input("weight")
    .OptionalInput("bias")
    .OptionalInput("bias_multiplier")  // cudnn conv doesn't need this
    .Output("out")
    .Attr<int32_t>("filters")
    .Attr<std::vector<int32_t>>("padding_before")
    .Attr<std::string>("data_format")
    .Attr<std::vector<int32_t>>("kernel_size")
    .Attr<std::vector<int32_t>>("strides")
    .Attr<std::vector<int32_t>>("dilation_rate")
    .Attr<int32_t>("groups", 1)
Niu Chong's avatar
Niu Chong committed
    .SetCheckAttrFn(CheckAttr<3>)
    .SetTensorDescInferFn(InferTensorDesc4Conv<3>)
    .SetGetSbpFn(GetSbpSignatures4Conv)
    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
      return Maybe<void>::Ok();
    });
Niu Chong's avatar
Niu Chong committed

REGISTER_USER_OP_GRAD("conv1d").SetGenBackwardOpConfFn(GenerateBackwardOpConf4Conv);
REGISTER_USER_OP_GRAD("conv2d").SetGenBackwardOpConfFn(GenerateBackwardOpConf4Conv);
REGISTER_USER_OP_GRAD("conv3d").SetGenBackwardOpConfFn(GenerateBackwardOpConf4Conv);

REGISTER_USER_OP("conv_data_grad")
    .Input("dy")
    .Input("filter")
    .Input("x_like")
Juncheng's avatar
Juncheng committed
    .OptionalInput("_add_to_output")
Niu Chong's avatar
Niu Chong committed
    .Output("dx")
    .Attr<int32_t>("num_spatial_dims")
    .Attr<std::vector<int32_t>>("padding_before")
    .Attr<std::string>("data_format")
    .Attr<std::vector<int32_t>>("kernel_size")
    .Attr<std::vector<int32_t>>("strides")
    .Attr<std::vector<int32_t>>("dilation_rate")
    .Attr<int32_t>("groups")
Niu Chong's avatar
Niu Chong committed
    .SetCheckAttrFn(CheckAttr<0>)
    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0);
      const user_op::TensorDesc& x_like = ctx->InputTensorDesc("x_like", 0);
OuYang Yu's avatar
OuYang Yu committed
      const int32_t num_spatial_dims = ctx->Attr<int32_t>("num_spatial_dims");
Niu Chong's avatar
Niu Chong committed
      CHECK_GE_OR_RETURN(num_spatial_dims, 1);
      CHECK_LE_OR_RETURN(num_spatial_dims, 3);
      CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2);
      CHECK_EQ_OR_RETURN(x_like.shape().NumAxes(), num_spatial_dims + 2);
      if (ctx->has_input("_add_to_output", 0)) {
        const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0);
        CHECK_EQ_OR_RETURN(add_to_output.shape(), x_like.shape());
Juncheng's avatar
Juncheng committed
      }
      *ctx->OutputShape("dx", 0) = ctx->InputShape("x_like", 0);
      *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x_like", 0);
Niu Chong's avatar
Niu Chong committed
      return Maybe<void>::Ok();
    })
    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
Juncheng's avatar
Juncheng committed
      std::vector<user_op::OpArg> split_args;
      split_args.emplace_back("dy", 0);
      split_args.emplace_back("x_like", 0);
      split_args.emplace_back("dx", 0);
      if (ctx->user_op_conf().has_input("_add_to_output", 0)) {
        split_args.emplace_back("_add_to_output", 0);
      }
      ctx->NewBuilder().Split(split_args, 0).Broadcast(user_op::OpArg("filter", 0)).Build();
Niu Chong's avatar
Niu Chong committed
      return Maybe<void>::Ok();
    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0);
      const user_op::TensorDesc& x_like = ctx->InputTensorDesc("x_like", 0);
      CHECK_EQ_OR_RETURN(x_like.data_type(), dy.data_type());
      if (ctx->has_input("_add_to_output", 0)) {
        const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0);
        CHECK_EQ_OR_RETURN(add_to_output.data_type(), x_like.data_type());
      *ctx->OutputDType("dx", 0) = ctx->InputDType("x_like", 0);
      return Maybe<void>::Ok();
Niu Chong's avatar
Niu Chong committed
    });

REGISTER_USER_OP("conv_filter_grad")
    .Input("dy")
    .Input("x")
    .Output("filter_diff")
    .Attr<int32_t>("num_spatial_dims")
    .Attr<std::vector<int32_t>>("padding_before")
    .Attr<std::string>("data_format")
    .Attr<std::vector<int32_t>>("kernel_size")
    .Attr<std::vector<int32_t>>("strides")
    .Attr<std::vector<int32_t>>("dilation_rate")
    .Attr<int32_t>("groups")
Niu Chong's avatar
Niu Chong committed
    .SetCheckAttrFn(CheckAttr<0>)
    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0);
      const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
Niu Chong's avatar
Niu Chong committed

OuYang Yu's avatar
OuYang Yu committed
      const int32_t num_spatial_dims = ctx->Attr<int32_t>("num_spatial_dims");
      const int32_t groups = ctx->Attr<int32_t>("groups");
      const std::string& data_format = ctx->Attr<std::string>("data_format");
      const std::vector<int32_t> kernel_size = ctx->Attr<std::vector<int32_t>>("kernel_size");
Niu Chong's avatar
Niu Chong committed

      CHECK_GE_OR_RETURN(num_spatial_dims, 1);
      CHECK_LE_OR_RETURN(num_spatial_dims, 3);
      CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2);
      CHECK_EQ_OR_RETURN(x.shape().NumAxes(), num_spatial_dims + 2);
Niu Chong's avatar
Niu Chong committed
      CHECK_GT_OR_RETURN(groups, 0);

      DimVector filter_diff_dim_vec;
      if (data_format == "channels_first") {
        CHECK_LE_OR_RETURN(groups, x.shape().At(1));
        CHECK_LE_OR_RETURN(groups, dy.shape().At(1));
        CHECK_EQ_OR_RETURN(x.shape().At(1) % groups, 0);
        CHECK_EQ_OR_RETURN(dy.shape().At(1) % groups, 0);
        filter_diff_dim_vec.push_back(dy.shape().At(1));
        filter_diff_dim_vec.push_back(x.shape().At(1) / groups);
Niu Chong's avatar
Niu Chong committed
        filter_diff_dim_vec.insert(filter_diff_dim_vec.end(), kernel_size.cbegin(),
                                   kernel_size.cend());
      } else {
        CHECK_EQ_OR_RETURN("channels_last", data_format);
        CHECK_EQ_OR_RETURN(groups, 1);
        filter_diff_dim_vec.push_back(dy.shape().dim_vec().back());
Niu Chong's avatar
Niu Chong committed
        filter_diff_dim_vec.insert(filter_diff_dim_vec.end(), kernel_size.cbegin(),
                                   kernel_size.cend());
        filter_diff_dim_vec.push_back(x.shape().dim_vec().back() / groups);
      user_op::TensorDesc* filter_diff = ctx->OutputTensorDesc("filter_diff", 0);
Niu Chong's avatar
Niu Chong committed
      *filter_diff->mut_shape() = Shape(filter_diff_dim_vec);
binbinHan's avatar
binbinHan committed
      filter_diff->set_is_dynamic(false);
Niu Chong's avatar
Niu Chong committed

      return Maybe<void>::Ok();
    })
    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
      ctx->NewBuilder()
          .Split(user_op::OpArg("dy", 0), 0)
          .Split(user_op::OpArg("x", 0), 0)
          .PartialSum(user_op::OpArg("filter_diff", 0))
          .Build();
      return Maybe<void>::Ok();
    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0);
      const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
      CHECK_EQ_OR_RETURN(x.data_type(), dy.data_type());
      user_op::TensorDesc* filter_diff = ctx->OutputTensorDesc("filter_diff", 0);
      *filter_diff->mut_data_type() = x.data_type();
      return Maybe<void>::Ok();
Niu Chong's avatar
Niu Chong committed
    });

REGISTER_USER_OP("conv_bias_grad")
    .Input("dy")
    .Output("bias_diff")
    .Attr<std::string>("data_format")
    .Attr<int32_t>("num_spatial_dims")
Niu Chong's avatar
Niu Chong committed
    .SetCheckAttrFn([](const user_op::UserOpDefWrapper& def,
                       const user_op::UserOpConfWrapper& conf) -> Maybe<void> {
      std::string data_format = conf.attr<std::string>("data_format");
      if (data_format == "channels_first" || data_format == "channels_last") {
        return Maybe<void>::Ok();
      }
      return oneflow::Error::CheckFailedError()
             << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name()
             << ": data_format:" << data_format;
Niu Chong's avatar
Niu Chong committed
    })
    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0);
Niu Chong's avatar
Niu Chong committed
      user_op::TensorDesc* bias_diff = ctx->TensorDesc4ArgNameAndIndex("bias_diff", 0);

OuYang Yu's avatar
OuYang Yu committed
      int32_t num_spatial_dims = ctx->Attr<int32_t>("num_spatial_dims");
      std::string data_format = ctx->Attr<std::string>("data_format");
Niu Chong's avatar
Niu Chong committed

      CHECK_GE_OR_RETURN(num_spatial_dims, 1);
      CHECK_LE_OR_RETURN(num_spatial_dims, 3);
      CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2);
Niu Chong's avatar
Niu Chong committed
      if (data_format == "channels_first") {
        *bias_diff->mut_shape() = Shape({dy.shape().At(1)});
Niu Chong's avatar
Niu Chong committed
      } else if (data_format == "channels_last") {
        *bias_diff->mut_shape() = Shape({dy.shape().At(dy.shape().NumAxes() - 1)});
Niu Chong's avatar
Niu Chong committed
      } else {
        OF_UNIMPLEMENTED();
      }
      return Maybe<void>::Ok();
    })
    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
      ctx->NewBuilder()
          .Split(user_op::OpArg("dy", 0), 0)
          .PartialSum(user_op::OpArg("bias_diff", 0))
          .Build();
      return Maybe<void>::Ok();
    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0);
      user_op::TensorDesc* bias_diff = ctx->TensorDesc4ArgNameAndIndex("bias_diff", 0);
      *bias_diff->mut_data_type() = dy.data_type();
      return Maybe<void>::Ok();
Niu Chong's avatar
Niu Chong committed
    });

}  // namespace oneflow