Unverified Commit 1e351cf8 authored by liufengwei0103, committed by GitHub

registry_callback_fn return maybe (#5456)


* modified SetInputArgModifyFn

* Delete the CHECK changes in the assign_op.cpp file

* Format

* Modified the OutputArgModifyFn interface

* add return

* Propagate the Maybe error stack from CheckAndConstructOp to the OutputArgModifier callback function

* OutputArgModifier return maybe part_1

* input_arg_modifier return maybe

* gen_bw_fn return maybe

* bw_gen_fn return maybe

* registry_callback_fn return maybe

* fix bug after merging master

* fix bug

Co-authored-by: aishangjj <702572275@qq.com>
parent e6da5750
Showing 60 additions and 55 deletions
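All of the hunks below apply the same refactor: helper functions and user-op registry callbacks that previously aborted the process through glog-style CHECK_* macros now return Maybe<void>, so a failed check becomes an error value the caller can handle. As orientation, here is a minimal, self-contained C++ sketch of that pattern. The MaybeVoid type and the macro bodies are simplified stand-ins for illustration only; OneFlow's real Maybe<T>, CHECK_*_OR_RETURN, JUST, and CHECK_JUST are more elaborate and record an error stack as the failure propagates.

// Simplified sketch of the Maybe-based error propagation this PR adopts.
// MaybeVoid and these macro bodies are illustrative assumptions, not
// OneFlow's actual implementation.
#include <cstdlib>
#include <iostream>
#include <optional>
#include <string>

// Stand-in for oneflow::Maybe<void>: carries an error message on failure.
struct MaybeVoid {
  std::optional<std::string> error;
  bool IsOk() const { return !error.has_value(); }
  static MaybeVoid Ok() { return MaybeVoid{}; }
};

// CHECK_GT_OR_RETURN: on failure, return an error Maybe instead of aborting.
#define CHECK_GT_OR_RETURN(a, b) \
  if (!((a) > (b))) return MaybeVoid{std::string("check failed: " #a " > " #b)}

// JUST: propagate a callee's error; usable only where the enclosing function
// itself returns a Maybe.
#define JUST(expr)                     \
  do {                                 \
    MaybeVoid maybe_ = (expr);         \
    if (!maybe_.IsOk()) return maybe_; \
  } while (0)

// CHECK_JUST: consume a Maybe at a boundary that cannot return one
// (for example a kernel's void Compute()); terminate if it holds an error.
#define CHECK_JUST(expr)                  \
  do {                                    \
    MaybeVoid maybe_ = (expr);            \
    if (!maybe_.IsOk()) {                 \
      std::cerr << *maybe_.error << "\n"; \
      std::abort();                       \
    }                                     \
  } while (0)

// A helper in the style this PR converts: void becomes Maybe<void>.
MaybeVoid CalcSamePaddingSketch(int stride) {
  CHECK_GT_OR_RETURN(stride, 0);
  return MaybeVoid::Ok();
}

// Maybe-returning callers (infer functions, registry callbacks) propagate:
MaybeVoid InferSketch(int stride) {
  JUST(CalcSamePaddingSketch(stride));
  return MaybeVoid::Ok();
}

// void boundaries (such as a kernel's Compute) consume the Maybe instead:
void ComputeSketch(int stride) { CHECK_JUST(CalcSamePaddingSketch(stride)); }

Read through that lens, the rewrites below are mechanical: CHECK_X becomes CHECK_X_OR_RETURN inside Maybe-returning functions, calls from Maybe-returning callers are wrapped in JUST, and calls from void code are wrapped in CHECK_JUST.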
@@ -42,8 +42,9 @@ class SamePaddingKernel final : public user_op::OpKernel {
     for (int32_t i = 0; i < num_spatial_dims; ++i) {
       int32_t padding_small = 0;
       int32_t padding_large = 0;
-      CalcSamePadding(x->shape().At(idx_offset + i), kernel_size.at(i), dilation_rate.at(i),
-                      strides.at(i), &padding_small, &padding_large);
+      CHECK_JUST(CalcSamePadding(x->shape().At(idx_offset + i), kernel_size.at(i),
+                                 dilation_rate.at(i), strides.at(i), &padding_small,
+                                 &padding_large));
       if (padding == "same_lower") {
         padding_before[idx_offset + i] = padding_large;
       } else if (padding == "same_upper") {
@@ -123,8 +124,9 @@ class SamePaddingGradKernel final : public user_op::OpKernel {
     for (int32_t i = 0; i < num_spatial_dims; ++i) {
       int32_t padding_small = 0;
       int32_t padding_large = 0;
-      CalcSamePadding(dx->shape().At(idx_offset + i), kernel_size.at(i), dilation_rate.at(i),
-                      strides.at(i), &padding_small, &padding_large);
+      CHECK_JUST(CalcSamePadding(dx->shape().At(idx_offset + i), kernel_size.at(i),
+                                 dilation_rate.at(i), strides.at(i), &padding_small,
+                                 &padding_large));
       if (padding == "same_lower") {
         padding_before[idx_offset + i] = padding_large;
       } else if (padding == "same_upper") {
......
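Note that the two same-padding kernel hunks above use CHECK_JUST rather than JUST: OpKernel::Compute returns void, so the Maybe<void> produced by the converted CalcSamePadding cannot be propagated any further there. CHECK_JUST consumes it at that boundary and terminates on failure, preserving the behavior of the old in-function CHECK macros.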
@@ -37,7 +37,7 @@ REGISTER_NO_GRAD_USER_OP("cast_to_tick")
       const cfg::ParallelDistribution& in_dis_hint =
           ctx->ParallelDistributionHint4InputArgNameAndIndex("in", 0);
       const Shape& parallel_hierarchy = ctx->parallel_hierarchy();
-      CHECK_EQ(in_dis_hint.sbp_parallel_size(), parallel_hierarchy.NumAxes());
+      CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), parallel_hierarchy.NumAxes());
       cfg::ParallelDistribution* in_distribution =
           ctx->ParallelDistribution4ArgNameAndIndex("in", 0);
......
@@ -23,7 +23,7 @@ namespace {
 template<size_t NDims>
 Maybe<void> InferTensorDesc4Conv(user_op::InferContext* ctx) {
   const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
-  CHECK_EQ(NDims + 2, in.shape().NumAxes());
+  CHECK_EQ_OR_RETURN(NDims + 2, in.shape().NumAxes());
   auto data_format = ctx->Attr<std::string>("data_format");
   auto kernel_size = ctx->Attr<std::vector<int32_t>>("kernel_size");
@@ -44,8 +44,8 @@ Maybe<void> InferTensorDesc4Conv(user_op::InferContext* ctx) {
   const size_t c_dim = data_format == "channels_first" ? 1 : NDims + 1;
   out_shape.at(c_dim) = filters;
   for (int32_t i = 0; i < NDims; ++i) {
-    CalcConvOut(in.shape().At(idx_offset + i), kernel_size.at(i), dilation_rate.at(i),
-                strides.at(i), padding_before.at(i), &out_shape.at(idx_offset + i));
+    JUST(CalcConvOut(in.shape().At(idx_offset + i), kernel_size.at(i), dilation_rate.at(i),
+                     strides.at(i), padding_before.at(i), &out_shape.at(idx_offset + i)));
   }
   *out->mut_is_dynamic() = in.is_dynamic();
   *out->mut_shape() = Shape(out_shape);
......
@@ -23,7 +23,7 @@ namespace {
 template<size_t NDims>
 Maybe<void> InferTensorDesc4DeConv(user_op::InferContext* ctx) {
   const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
-  CHECK_EQ(NDims + 2, in.shape().NumAxes());
+  CHECK_EQ_OR_RETURN(NDims + 2, in.shape().NumAxes());
   const std::string& data_format = ctx->Attr<std::string>("data_format");
   const auto& kernel_size = ctx->Attr<std::vector<int32_t>>("kernel_size");
@@ -69,7 +69,7 @@ Maybe<void> InferTensorDesc4DeConv(user_op::InferContext* ctx) {
     for (size_t i = 0; i < NDims; ++i) { weight_shape.at(idx_offset + i) = kernel_size.at(i); }
     const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0);
-    CHECK_EQ(weight.shape(), Shape(weight_shape));
+    CHECK_EQ_OR_RETURN(weight.shape(), Shape(weight_shape));
   }
   return Maybe<void>::Ok();
......
@@ -51,7 +51,7 @@ REGISTER_USER_OP("elu_grad")
       const Shape& x_shape = ctx->InputShape("x", 0);
       const Shape& dy_shape = ctx->InputShape("dy", 0);
       Shape* dx_shape = ctx->OutputShape("dx", 0);
-      CHECK(dy_shape == x_shape);
+      CHECK_OR_RETURN(dy_shape == x_shape);
       *dx_shape = dy_shape;
       return Maybe<void>::Ok();
     })
......
@@ -51,7 +51,7 @@ REGISTER_USER_OP("fused_tril_scale_softmax_mask_scale")
     })
     .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
       const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
-      CHECK_GE(x_tensor.shape().NumAxes(), 2);
+      CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2);
       FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) {
         ctx->NewBuilder()
             .Split(user_op::OpArg("x", 0), axis)
@@ -75,7 +75,7 @@ REGISTER_USER_OP("fused_tril_scale_softmax_mask_scale_grad")
       const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0);
       const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0);
       user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0);
-      CHECK(dy_desc.shape() == softmax_y_desc.shape());
+      CHECK_OR_RETURN(dy_desc.shape() == softmax_y_desc.shape());
       *dx_desc->mut_shape() = dy_desc.shape();
       *dx_desc->mut_is_dynamic() = dy_desc.is_dynamic();
       return Maybe<void>::Ok();
@@ -84,13 +84,13 @@ REGISTER_USER_OP("fused_tril_scale_softmax_mask_scale_grad")
       const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0);
       const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0);
       user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0);
-      CHECK(dy_desc.data_type() == softmax_y_desc.data_type());
+      CHECK_OR_RETURN(dy_desc.data_type() == softmax_y_desc.data_type());
       *dx_desc->mut_data_type() = dy_desc.data_type();
       return Maybe<void>::Ok();
     })
     .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
       const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0);
-      CHECK_GE(dy_tensor.shape().NumAxes(), 2);
+      CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2);
       FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) {
         ctx->NewBuilder()
             .Split(user_op::OpArg("softmax_y", 0), axis)
......
@@ -49,7 +49,7 @@ REGISTER_USER_OP("gelu_grad")
      const Shape& x_shape = ctx->InputShape("x", 0);
      const Shape& dy_shape = ctx->InputShape("dy", 0);
      Shape* dx_shape = ctx->OutputShape("dx", 0);
-     CHECK(dy_shape == x_shape);
+     CHECK_OR_RETURN(dy_shape == x_shape);
      *dx_shape = dy_shape;
      return Maybe<void>::Ok();
    })
......
@@ -51,7 +51,7 @@ REGISTER_USER_OP("hardsigmoid_grad")
      const Shape& x_shape = ctx->InputShape("x", 0);
      const Shape& dy_shape = ctx->InputShape("dy", 0);
      Shape* dx_shape = ctx->OutputShape("dx", 0);
-     CHECK(dy_shape == x_shape);
+     CHECK_OR_RETURN(dy_shape == x_shape);
      *dx_shape = dy_shape;
      return Maybe<void>::Ok();
    })
......
@@ -49,7 +49,7 @@ REGISTER_USER_OP("hardswish_grad")
      const Shape& x_shape = ctx->InputShape("x", 0);
      const Shape& dy_shape = ctx->InputShape("dy", 0);
      Shape* dx_shape = ctx->OutputShape("dx", 0);
-     CHECK(dy_shape == x_shape);
+     CHECK_OR_RETURN(dy_shape == x_shape);
      *dx_shape = dy_shape;
      return Maybe<void>::Ok();
    })
......
@@ -30,7 +30,7 @@ REGISTER_USER_OP("hardtanh")
      *out_shape = in_shape;
      double min_val = ctx->Attr<double>("min_val");
      double max_val = ctx->Attr<double>("max_val");
-     CHECK_LE(min_val, max_val);
+     CHECK_LE_OR_RETURN(min_val, max_val);
      return Maybe<void>::Ok();
    })
    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
@@ -58,11 +58,11 @@ REGISTER_USER_OP("hardtanh_grad")
      const Shape& y_shape = ctx->InputShape("y", 0);
      const Shape& dy_shape = ctx->InputShape("dy", 0);
      Shape* dx_shape = ctx->OutputShape("dx", 0);
-     CHECK(dy_shape == y_shape);
+     CHECK_OR_RETURN(dy_shape == y_shape);
      *dx_shape = dy_shape;
      double min_val = ctx->Attr<double>("min_val");
      double max_val = ctx->Attr<double>("max_val");
-     CHECK_LE(min_val, max_val);
+     CHECK_LE_OR_RETURN(min_val, max_val);
      return Maybe<void>::Ok();
    })
    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
......
@@ -140,7 +140,7 @@ REGISTER_USER_OP("layer_norm_grad")
      user_op::TensorDesc* dx = ctx->OutputTensorDesc("dx", 0);
      CHECK_EQ_OR_RETURN(dy.shape(), x.shape());
      const int64_t begin_norm_axis = ctx->Attr<int64_t>("begin_norm_axis");
-     CHECK_GT(begin_norm_axis, 0);
+     CHECK_GT_OR_RETURN(begin_norm_axis, 0);
      const Shape& bn_param_shape = InferBnParamShape(x.shape(), begin_norm_axis);
      CHECK_EQ_OR_RETURN(mean.shape(), bn_param_shape);
      CHECK_EQ_OR_RETURN(inv_variance.shape(), bn_param_shape);
......
@@ -48,7 +48,7 @@ REGISTER_USER_OP("leaky_relu_grad")
      const Shape& x_shape = ctx->InputShape("x", 0);
      const Shape& dy_shape = ctx->InputShape("dy", 0);
      Shape* dx_shape = ctx->OutputShape("dx", 0);
-     CHECK(dy_shape == x_shape);
+     CHECK_OR_RETURN(dy_shape == x_shape);
      *dx_shape = dy_shape;
      return Maybe<void>::Ok();
    })
......
@@ -17,11 +17,11 @@ limitations under the License.
 namespace oneflow {

-void CalcOutAndPadding(int64_t input_size, int32_t filter_size, int32_t dilation_rate,
-                       int32_t stride, const std::string& padding_type, int64_t* output_size,
-                       int32_t* padding_before, int32_t* padding_after) {
-  CHECK_GT(stride, 0);
-  CHECK_GE(dilation_rate, 1);
+Maybe<void> CalcOutAndPadding(int64_t input_size, int32_t filter_size, int32_t dilation_rate,
+                              int32_t stride, const std::string& padding_type, int64_t* output_size,
+                              int32_t* padding_before, int32_t* padding_after) {
+  CHECK_GT_OR_RETURN(stride, 0);
+  CHECK_GE_OR_RETURN(dilation_rate, 1);
   int32_t effective_filter_size = (filter_size - 1) * dilation_rate + 1;
   if (padding_type == "valid") {
@@ -41,13 +41,14 @@ void CalcOutAndPadding(int64_t input_size, int32_t filter_size, int32_t dilation
   } else {
     UNIMPLEMENTED();
   }
-  if (output_size) { CHECK_GE((*output_size), 0); }
+  if (output_size) { CHECK_GE_OR_RETURN((*output_size), 0); }
+  return Maybe<void>::Ok();
 }

-void CalcSamePadding(int64_t input_size, int32_t filter_size, int32_t dilation_rate, int32_t stride,
-                     int32_t* padding_small, int32_t* padding_large) {
-  CHECK_GT(stride, 0);
-  CHECK_GE(dilation_rate, 1);
+Maybe<void> CalcSamePadding(int64_t input_size, int32_t filter_size, int32_t dilation_rate,
+                            int32_t stride, int32_t* padding_small, int32_t* padding_large) {
+  CHECK_GT_OR_RETURN(stride, 0);
+  CHECK_GE_OR_RETURN(dilation_rate, 1);
   int32_t effective_filter_size = (filter_size - 1) * dilation_rate + 1;
   int64_t tmp_output_size = (input_size + stride - 1) / stride;
@@ -55,18 +56,20 @@ void CalcSamePadding(int64_t input_size, int32_t filter_size, int32_t dilation_r
       0, static_cast<int32_t>((tmp_output_size - 1) * stride + effective_filter_size - input_size));
   if (padding_small) { *padding_small = padding_needed / 2; }
   if (padding_large) { *padding_large = padding_needed - padding_needed / 2; }
+  return Maybe<void>::Ok();
 }

-void CalcConvOut(int64_t input_size, int32_t filter_size, int32_t dilation_rate, int32_t stride,
-                 int32_t padding_before, int64_t* output_size) {
-  CHECK_GT(stride, 0);
-  CHECK_GE(dilation_rate, 1);
+Maybe<void> CalcConvOut(int64_t input_size, int32_t filter_size, int32_t dilation_rate,
+                        int32_t stride, int32_t padding_before, int64_t* output_size) {
+  CHECK_GT_OR_RETURN(stride, 0);
+  CHECK_GE_OR_RETURN(dilation_rate, 1);
   int32_t effective_filter_size = (filter_size - 1) * dilation_rate + 1;
   if (output_size) {
     *output_size = (input_size + 2 * padding_before - effective_filter_size + stride) / stride;
-    CHECK_GE((*output_size), 0);
+    CHECK_GE_OR_RETURN((*output_size), 0);
   }
+  return Maybe<void>::Ok();
 }

 const size_t IdxOffset(const std::string& data_format) {
......
@@ -20,15 +20,15 @@ limitations under the License.
 namespace oneflow {

-void CalcOutAndPadding(int64_t input_size, int32_t filter_size, int32_t dilation_rate,
-                       int32_t stride, const std::string& padding_type, int64_t* output_size,
-                       int32_t* padding_before, int32_t* padding_after);
+Maybe<void> CalcOutAndPadding(int64_t input_size, int32_t filter_size, int32_t dilation_rate,
+                              int32_t stride, const std::string& padding_type, int64_t* output_size,
+                              int32_t* padding_before, int32_t* padding_after);

-void CalcSamePadding(int64_t input_size, int32_t filter_size, int32_t dilation_rate, int32_t stride,
-                     int32_t* padding_small, int32_t* padding_large);
+Maybe<void> CalcSamePadding(int64_t input_size, int32_t filter_size, int32_t dilation_rate,
+                            int32_t stride, int32_t* padding_small, int32_t* padding_large);

-void CalcConvOut(int64_t input_size, int32_t filter_size, int32_t dilation_rate, int32_t stride,
-                 int32_t padding_before, int64_t* output_size);
+Maybe<void> CalcConvOut(int64_t input_size, int32_t filter_size, int32_t dilation_rate,
+                        int32_t stride, int32_t padding_before, int64_t* output_size);

 const size_t IdxOffset(const std::string& data_format);
 const int32_t ChannelIdx(const std::string& data_format, int32_t num_axes);
......
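With the header now declaring Maybe<void> return types, every call site of CalcOutAndPadding, CalcSamePadding, and CalcConvOut has to choose a failure policy explicitly: JUST(...) inside Maybe-returning infer functions (as in the conv and deconv hunks above), or CHECK_JUST(...) at void boundaries (as in the same-padding kernels).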
@@ -146,7 +146,7 @@ user_op::TensorDescInferFn MakeFwTensorDescInferFn(
     JUST(SetParamTensorDesc("mean"));
     JUST(SetParamTensorDesc("inv_variance"));
     if (ctx->has_output("reserve_space", 0)) {
-      CHECK(reserve_space_infer_fn);
+      CHECK_OR_RETURN(reserve_space_infer_fn);
       reserve_space_infer_fn(ctx, &x, ctx->OutputTensorDesc("reserve_space", 0));
     }
     return Maybe<void>::Ok();
@@ -178,7 +178,7 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn(
     JUST(SetParamDataType("mean"));
     JUST(SetParamDataType("inv_variance"));
     if (ctx->has_output("reserve_space", 0)) {
-      CHECK(reserve_space_infer_fn);
+      CHECK_OR_RETURN(reserve_space_infer_fn);
       reserve_space_infer_fn(ctx, &x, ctx->OutputTensorDesc("reserve_space", 0));
     }
     return Maybe<void>::Ok();
......
@@ -26,7 +26,7 @@ REGISTER_USER_OP("pack")
    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0);
      const Shape& in_shape = in_desc.shape();
-     CHECK_GT(in_shape.NumAxes(), 0);
+     CHECK_GT_OR_RETURN(in_shape.NumAxes(), 0);
      user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0);
      *out_desc->mut_is_dynamic() = in_desc.is_dynamic();
      *out_desc->mut_shape() = in_desc.shape();
......
@@ -29,8 +29,8 @@ REGISTER_USER_OP("pad")
      const Shape& x_shape = ctx->InputShape("x", 0);
      const auto& padding_before = ctx->Attr<std::vector<int64_t>>("padding_before");
      const auto& padding_after = ctx->Attr<std::vector<int64_t>>("padding_after");
-     CHECK_EQ(padding_before.size(), x_shape.NumAxes());
-     CHECK_EQ(padding_after.size(), x_shape.NumAxes());
+     CHECK_EQ_OR_RETURN(padding_before.size(), x_shape.NumAxes());
+     CHECK_EQ_OR_RETURN(padding_after.size(), x_shape.NumAxes());
      DimVector y_dim_vec(x_shape.NumAxes());
      FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) {
        y_dim_vec[i] = x_shape.At(i) + padding_before[i] + padding_after[i];
@@ -68,8 +68,8 @@ REGISTER_USER_OP("pad_grad")
      const Shape& dy_shape = ctx->InputShape("dy", 0);
      const auto& padding_before = ctx->Attr<std::vector<int64_t>>("padding_before");
      const auto& padding_after = ctx->Attr<std::vector<int64_t>>("padding_after");
-     CHECK_EQ(padding_before.size(), dy_shape.NumAxes());
-     CHECK_EQ(padding_after.size(), dy_shape.NumAxes());
+     CHECK_EQ_OR_RETURN(padding_before.size(), dy_shape.NumAxes());
+     CHECK_EQ_OR_RETURN(padding_after.size(), dy_shape.NumAxes());
      DimVector dx_dim_vec(dy_shape.NumAxes());
      FOR_RANGE(int64_t, i, 0, dy_shape.NumAxes()) {
        dx_dim_vec[i] = dy_shape.At(i) - padding_before[i] - padding_after[i];
......
@@ -67,7 +67,7 @@ REGISTER_USER_OP_GRAD("parallel_cast")
        ctx->FwOp().BindGradTensorWithOpInput(ctx->FwOp().GetGradTensorWithOpOutput("out", 0),
                                              "in", 0);
      } else {
-       CHECK(IsValidSbpParallelString(grad_sbp_parallel_str));
+       CHECK_OR_RETURN(IsValidSbpParallelString(grad_sbp_parallel_str));
        const std::string grad_op_name = "System-AutoGrad-" + ctx->FwOp().op_name();
        ctx->DefineOp(grad_op_name, [&](user_op::BackwardOpBuilder& builder) {
          return builder.OpTypeName("parallel_cast")
......
@@ -51,7 +51,7 @@ REGISTER_USER_OP("relu_grad")
      const Shape& y_shape = ctx->InputShape("y", 0);
      const Shape& dy_shape = ctx->InputShape("dy", 0);
      Shape* dx_shape = ctx->OutputShape("dx", 0);
-     CHECK(dy_shape == y_shape);
+     CHECK_OR_RETURN(dy_shape == y_shape);
      *dx_shape = dy_shape;
      return Maybe<void>::Ok();
    })
......
@@ -70,7 +70,7 @@ Maybe<void> ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis(
     HashMap<int, int>* group_start_in_axis2out_axis) {
   CHECK_NE_OR_RETURN(in_shape.NumAxes(), 0);
   CHECK_NE_OR_RETURN(out_shape.NumAxes(), 0);
-  CHECK_EQ(in_shape.elem_cnt(), out_shape.elem_cnt());
+  CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), out_shape.elem_cnt());
   int in_axis = in_shape.NumAxes() - 1;
   int out_axis = out_shape.NumAxes() - 1;
   while (in_axis >= 0 && out_axis >= 0) {
......