Unverified commit bdfb39d5, authored by guo ran, committed by GitHub

fuse cast scalar_mul scalar_mul_by_tensor (#4730)


* fused_cast_scale_pass fuse scalar_mul

* test case

* refine

* fix

Co-authored-by: Juncheng <liujuncheng1022@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent d5f17b13
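
This change gives fused_cast_scale a second scaling factor: alongside the existing scale_by_tensor input it gains a constant double attribute `scale`, so the op now computes y = cast(x) * scale_by_tensor * scale. A minimal NumPy sketch of that semantics (array values and dtypes here are illustrative, not taken from the diff):

    import numpy as np

    def fused_cast_scale_reference(x, scale_by_tensor, scale):
        # output dtype follows scale_by_tensor; scale is a compile-time double attribute
        out_dtype = scale_by_tensor.dtype
        return x.astype(out_dtype) * scale_by_tensor[0] * out_dtype.type(scale)

    x = np.random.randn(4, 8).astype(np.float16)          # half input, as in the test below
    scale_by_tensor = np.array([2.0], dtype=np.float32)    # shape (1,) tensor input
    y = fused_cast_scale_reference(x, scale_by_tensor, scale=0.125)
    assert y.dtype == np.float32 and y.shape == x.shape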
......@@ -65,8 +65,16 @@ Maybe<void> FuseCastScalePass::Apply(const OpGraph& op_graph, JobBuilder* job_bu
if (!IsUserOpWithTypeName(op_node->op().op_conf(), "cast")) { return; }
if (!IsSafeToDelete(op_node)) { return; }
if (op_node->out_edges().size() != 1) { return; }
const OpNode* sole_dst_node = op_node->SoleOutEdge()->dst_node();
if (!IsUserOpWithTypeName(sole_dst_node->op().op_conf(), "scalar_mul_by_tensor")) { return; }
OpNode* sole_dst_node = op_node->SoleOutEdge()->dst_node();
if (IsUserOpWithTypeName(sole_dst_node->op().op_conf(), "scalar_mul")) {
if (!IsSafeToDelete(sole_dst_node)) { return; }
if (!IsUserOpWithTypeName(sole_dst_node->SoleOutEdge()->dst_node()->op().op_conf(),
"scalar_mul_by_tensor")) {
return;
}
} else {
if (!IsUserOpWithTypeName(sole_dst_node->op().op_conf(), "scalar_mul_by_tensor")) { return; }
}
const user_op::UserOpConfWrapper cast_user_conf(op_node->op().op_conf());
if (op_node->LogicalBlobDesc4Lbi(GenLogicalBlobId(cast_user_conf.input("in", 0))).data_type()
!= DataType::kFloat16) {
......@@ -77,14 +85,33 @@ Maybe<void> FuseCastScalePass::Apply(const OpGraph& op_graph, JobBuilder* job_bu
return;
}
if (op_node->parallel_desc().device_type() != DeviceType::kGPU) { return; }
double scale = 1.0;
std::vector<OperatorConf> delete_ops;
if (IsUserOpWithTypeName(sole_dst_node->op().op_conf(), "scalar_mul")) {
const user_op::UserOpConfWrapper scalar_mul_op_conf(sole_dst_node->op().op_conf());
if (scalar_mul_op_conf.attr<bool>("has_int_operand")) {
scale = static_cast<double>(scalar_mul_op_conf.attr<int64_t>("int_operand"));
} else if (scalar_mul_op_conf.attr<bool>("has_float_operand")) {
scale = scalar_mul_op_conf.attr<double>("float_operand");
} else {
UNIMPLEMENTED();
}
delete_ops.push_back(sole_dst_node->op().op_conf());
sole_dst_node = sole_dst_node->SoleOutEdge()->dst_node();
}
delete_ops.push_back(op_node->op().op_conf());
const user_op::UserOpConfWrapper scale_user_conf(sole_dst_node->op().op_conf());
user_op::UserOpConfWrapperBuilder fused_op_builder(sole_dst_node->op().op_name());
fused_op_builder.OpTypeName("fused_cast_scale")
.Input("x", cast_user_conf.input("in", 0))
.Input("scale_by_tensor", scale_user_conf.input("scalar", 0))
.Attr<double>("scale", scale)
.Output("y");
OperatorConf new_op_conf = sole_dst_node->op().op_conf();
new_op_conf.mutable_user_conf()->set_op_type_name("fused_cast_scale");
const auto new_val = cast_user_conf.input("in", 0);
const auto& old_val =
ReplaceInputLbnInOpCustomizedConf(&new_op_conf, GenRepeatedBn("x", 0), new_val);
CHECK_EQ(scale_user_conf.input("x", 0), old_val);
job_builder->DelOps({op_node->op().op_conf()});
*new_op_conf.mutable_user_conf() = fused_op_builder.Build().op_conf().user_conf();
job_builder->DelOps(delete_ops);
job_builder->MutOpsOnlyOnce({new_op_conf});
});
return Maybe<void>::Ok();
......
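
After this change the pass matches either cast -> scalar_mul_by_tensor or cast -> scalar_mul -> scalar_mul_by_tensor, folds the constant scalar_mul operand into the fused op's `scale` attribute, and deletes every replaced op. A rough Python sketch of the matching step (the node/attribute helpers here are hypothetical stand-ins for the OpGraph API, not real OneFlow calls):

    def match_cast_scale_chain(cast_node):
        """Return (scale, ops_to_delete, scalar_mul_by_tensor_node) or None."""
        if cast_node.op_type != "cast" or len(cast_node.out_edges) != 1:
            return None
        dst = cast_node.sole_dst()
        scale, delete_ops = 1.0, [cast_node]
        if dst.op_type == "scalar_mul":
            # fold the constant operand into the fused op's `scale` attribute
            key = "int_operand" if dst.attrs["has_int_operand"] else "float_operand"
            scale = float(dst.attrs[key])
            delete_ops.append(dst)
            dst = dst.sole_dst()
        if dst.op_type != "scalar_mul_by_tensor":
            return None
        return scale, delete_ops, dst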
......@@ -29,13 +29,14 @@ for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
def fused_cast_scale(x, scalar, name):
def fused_cast_scale(x, scale_by_tensor, scale, name):
return (
flow.user_op_builder(name)
.Op("fused_cast_scale")
.Input("x", [x])
.Input("scalar", [scalar])
.Input("scale_by_tensor", [scale_by_tensor])
.Output("y")
.Attr("scale", float(scale))
.Build()
.InferAndTryRun()
.RemoteBlobList()[0]
......@@ -43,7 +44,12 @@ def fused_cast_scale(x, scalar, name):
def compare_with_tensorflow(
device_type, input_shape, in_dtype, out_dtype, test_fuse_cast_scale_pass
device_type,
input_shape,
in_dtype,
out_dtype,
test_fuse_cast_scale_pass,
has_scalar_mul,
):
assert device_type in ["gpu", "cpu"]
flow.clear_default_session()
......@@ -71,13 +77,19 @@ def compare_with_tensorflow(
)
loss = flow.cast(x, dtype=type_name_to_flow_type[in_dtype])
if test_fuse_cast_scale_pass:
loss = flow.cast(
loss, dtype=type_name_to_flow_type[out_dtype]
) * flow.cast(scale, dtype=type_name_to_flow_type[out_dtype])
loss = flow.cast(loss, dtype=type_name_to_flow_type[out_dtype])
if has_scalar_mul:
loss = loss * 0.125
loss = loss * flow.cast(scale, dtype=type_name_to_flow_type[out_dtype])
else:
if has_scalar_mul:
scale_val = 0.125
else:
scale_val = 1.0
loss = fused_cast_scale(
loss,
flow.cast(scale, dtype=type_name_to_flow_type[out_dtype]),
scale=scale_val,
name="fused_cast_scale",
)
loss = flow.cast(loss, dtype=flow.float)
......@@ -96,6 +108,8 @@ def compare_with_tensorflow(
tf_out = tf.cast(tf_out, dtype=type_name_to_np_type[out_dtype]) * tf.cast(
scale, dtype=type_name_to_np_type[out_dtype]
)
if has_scalar_mul:
tf_out = tf_out * 0.125
tf_out = tf.cast(tf_out, dtype=tf.float32)
assert np.allclose(of_out.numpy(), tf_out.numpy(), rtol=1e-5, atol=1e-5)
......@@ -110,6 +124,7 @@ class TestFusedCastScale(flow.unittest.TestCase):
arg_dict["in_dtype"] = ["float16", "float32", "double"]
arg_dict["out_dtype"] = ["float16", "float32", "double"]
arg_dict["test_fuse_cast_scale_pass"] = [True, False]
arg_dict["has_scalar_mul"] = [True, False]
for arg in GenArgList(arg_dict):
if arg[2] == arg[3]:
continue
......
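
The test now also varies has_scalar_mul: the unfused reference branch multiplies by a literal 0.125, while the fused branch passes the same 0.125 through the new `scale` argument. The relation the test asserts, restated in NumPy (input values are made up):

    import numpy as np

    x = np.random.randn(10).astype(np.float16)
    scale = np.array([3.0], dtype=np.float32)

    # unfused chain: cast -> scalar_mul(0.125) -> scalar_mul_by_tensor
    unfused = x.astype(np.float32) * 0.125 * scale.astype(np.float32)
    # fused op: one kernel, 0.125 folded into the scale attribute
    fused = x.astype(np.float32) * scale.astype(np.float32) * np.float32(0.125)

    assert np.allclose(unfused, fused, rtol=1e-5, atol=1e-5)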
......@@ -26,13 +26,14 @@ class FusedCastScaleCpuKernel final : public user_op::OpKernel {
private:
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
const user_op::Tensor* scalar = ctx->Tensor4ArgNameAndIndex("scalar", 0);
const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0);
user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
const double scale_val = ctx->Attr<double>("scale");
const int64_t n = x->shape().elem_cnt();
const auto scalar_val = *(scalar->dptr<T>());
const T scale = *(scale_by_tensor->dptr<T>()) * scale_val;
const U* x_ptr = x->dptr<U>();
T* y_ptr = y->mut_dptr<T>();
FOR_RANGE(int64_t, i, 0, n) { y_ptr[i] = static_cast<T>(x_ptr[i]) * scalar_val; }
FOR_RANGE(int64_t, i, 0, n) { y_ptr[i] = static_cast<T>(x_ptr[i]) * scale; }
};
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
......
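
On the CPU path the only behavioral change is that the value loaded from scale_by_tensor is multiplied by the new attribute once, before the elementwise loop; the output type T comes from scale_by_tensor and the input type U from x. A short NumPy sketch with one concrete (T, U) choice:

    import numpy as np

    x = np.array([1.0, 2.0, 3.0], dtype=np.float16)      # U = half input
    scale_by_tensor = np.array([4.0], dtype=np.float64)   # T = double, also y's dtype
    scale_attr = 0.125
    scale = scale_by_tensor[0] * np.float64(scale_attr)    # computed once, outside the loop
    y = x.astype(np.float64) * scale                       # y[i] = static_cast<T>(x[i]) * scale
    print(y)                                               # [0.5 1.  1.5]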
......@@ -21,43 +21,46 @@ namespace oneflow {
namespace {
template<typename T, typename U>
__global__ void FusedCastScaleGpu(const int64_t n, const U* in, const T* scalar, T* out) {
const T scalar_val = *scalar;
CUDA_1D_KERNEL_LOOP(i, n) { out[i] = static_cast<T>(in[i]) * scalar_val; }
__global__ void FusedCastScaleGpu(const int64_t n, const T scale_val, const U* in,
const T* scale_by_ptr, T* out) {
const T scale = *scale_by_ptr * scale_val;
CUDA_1D_KERNEL_LOOP(i, n) { out[i] = static_cast<T>(in[i]) * scale; }
}
template<>
__global__ void FusedCastScaleGpu<float, half>(const int64_t n, const half* in, const float* scalar,
__global__ void FusedCastScaleGpu<float, half>(const int64_t n, const float scale_val,
const half* in, const float* scale_by_ptr,
float* out) {
const float scalar_val = *scalar;
const float scale = *scale_by_ptr * scale_val;
const int64_t n_2 = n / 2;
const auto* in_2 = reinterpret_cast<const half2*>(in);
auto* out_2 = reinterpret_cast<float2*>(out);
CUDA_1D_KERNEL_LOOP(i, n_2) {
float2 f2 = __half22float2(in_2[i]);
f2.x *= scalar_val;
f2.y *= scalar_val;
f2.x *= scale;
f2.y *= scale;
out_2[i] = f2;
}
if (n % 2 == 1 && blockIdx.x == 0 && threadIdx.x == 0) {
out[n - 1] = __half2float(in[n - 1]) * scalar_val;
out[n - 1] = __half2float(in[n - 1]) * scale;
}
}
template<>
__global__ void FusedCastScaleGpu<half, float>(const int64_t n, const float* in, const half* scalar,
__global__ void FusedCastScaleGpu<half, float>(const int64_t n, const half scale_val,
const float* in, const half* scale_by_ptr,
half* out) {
const half scalar_val = *scalar;
const half2 scalar_h2 = __half2half2(scalar_val);
const half scale = *scale_by_ptr * scale_val;
const half2 scale_h2 = __half2half2(scale);
const int64_t n_2 = n / 2;
const auto* in_2 = reinterpret_cast<const float2*>(in);
auto* out_h2 = reinterpret_cast<half2*>(out);
CUDA_1D_KERNEL_LOOP(i, n_2) {
half2 in_h2 = __float22half2_rn(in_2[i]);
out_h2[i] = __hmul2(in_h2, scalar_h2);
out_h2[i] = __hmul2(in_h2, scale_h2);
}
if (n % 2 == 1 && blockIdx.x == 0 && threadIdx.x == 0) {
out[n - 1] = __float2half(in[n - 1]) * scalar_val;
out[n - 1] = __float2half(in[n - 1]) * scale;
}
}
......@@ -70,16 +73,17 @@ class FusedCastScaleGpuKernel final : public user_op::OpKernel {
private:
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
const user_op::Tensor* scalar = ctx->Tensor4ArgNameAndIndex("scalar", 0);
const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0);
user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
const int64_t n = x->shape().elem_cnt();
const double scale = ctx->Attr<double>("scale");
const int64_t launch_n = ((std::is_same<T, half>::value && std::is_same<U, float>::value)
|| (std::is_same<T, float>::value && std::is_same<U, half>::value))
? RoundUp(n, 2) / 2
: n;
FusedCastScaleGpu<T, U><<<BlocksNum4ThreadsNum(launch_n), kCudaThreadsNumPerBlock, 0,
ctx->device_ctx()->cuda_stream()>>>(
n, x->dptr<U>(), scalar->dptr<T>(), y->mut_dptr<T>());
n, static_cast<T>(scale), x->dptr<U>(), scale_by_tensor->dptr<T>(), y->mut_dptr<T>());
};
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
......
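
The half specializations process two elements per iteration (half2/float2 loads), so the launch size is RoundUp(n, 2) / 2 and an odd tail element is handled once by thread 0. A pure-NumPy sketch of that pairing logic for the float-from-half case (no CUDA, just the same arithmetic):

    import numpy as np

    def fused_cast_scale_paired(x_half, scale_by_tensor, scale_attr):
        """Mimic the float<-half GPU kernel: vectorized pairs plus an odd tail."""
        n = x_half.size
        scale = np.float32(scale_by_tensor[0]) * np.float32(scale_attr)
        out = np.empty(n, dtype=np.float32)
        n2 = n // 2                                        # pairs handled by the half2 loop
        pairs = x_half[: 2 * n2].astype(np.float32).reshape(n2, 2)
        out[: 2 * n2] = (pairs * scale).reshape(-1)
        if n % 2 == 1:                                     # tail, done by blockIdx.x==0, threadIdx.x==0
            out[n - 1] = np.float32(x_half[n - 1]) * scale
        return out

    x = np.random.randn(7).astype(np.float16)
    s = np.array([2.0], dtype=np.float32)
    ref = x.astype(np.float32) * (s[0] * np.float32(0.125))
    assert np.allclose(fused_cast_scale_paired(x, s, 0.125), ref)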
......@@ -20,9 +20,10 @@ namespace {
Maybe<void> TensorDescInfer(user_op::InferContext* ctx) {
const user_op::TensorDesc* x = ctx->TensorDesc4ArgNameAndIndex("x", 0);
const user_op::TensorDesc* scalar = ctx->TensorDesc4ArgNameAndIndex("scalar", 0);
CHECK_EQ_OR_RETURN(scalar->shape().NumAxes(), 1);
CHECK_EQ_OR_RETURN(scalar->shape().At(0), 1);
const user_op::TensorDesc* scale_by_tensor =
ctx->TensorDesc4ArgNameAndIndex("scale_by_tensor", 0);
CHECK_EQ_OR_RETURN(scale_by_tensor->shape().NumAxes(), 1);
CHECK_EQ_OR_RETURN(scale_by_tensor->shape().At(0), 1);
user_op::TensorDesc* y = ctx->TensorDesc4ArgNameAndIndex("y", 0);
*y->mut_is_dynamic() = x->is_dynamic();
*y->mut_shape() = x->shape();
......@@ -30,9 +31,10 @@ Maybe<void> TensorDescInfer(user_op::InferContext* ctx) {
}
Maybe<void> DataTypeInfer(user_op::InferContext* ctx) {
const user_op::TensorDesc* scalar = ctx->TensorDesc4ArgNameAndIndex("scalar", 0);
const user_op::TensorDesc* scale_by_tensor =
ctx->TensorDesc4ArgNameAndIndex("scale_by_tensor", 0);
user_op::TensorDesc* y = ctx->TensorDesc4ArgNameAndIndex("y", 0);
*y->mut_data_type() = scalar->data_type();
*y->mut_data_type() = scale_by_tensor->data_type();
return Maybe<void>::Ok();
}
......@@ -40,18 +42,18 @@ Maybe<void> GetSbpSignatures(user_op::SbpContext* ctx) {
const auto& x = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
for (int i = 0; i < x.shape().NumAxes(); ++i) {
ctx->NewBuilder()
.Broadcast(user_op::OpArg("scalar", 0))
.Broadcast(user_op::OpArg("scale_by_tensor", 0))
.Split(user_op::OpArg("x", 0), i)
.Split(user_op::OpArg("y", 0), i)
.Build();
}
ctx->NewBuilder()
.PartialSum(user_op::OpArg("scalar", 0))
.PartialSum(user_op::OpArg("scale_by_tensor", 0))
.Broadcast(user_op::OpArg("x", 0))
.PartialSum(user_op::OpArg("y", 0))
.Build();
ctx->NewBuilder()
.Broadcast(user_op::OpArg("scalar", 0))
.Broadcast(user_op::OpArg("scale_by_tensor", 0))
.PartialSum(user_op::OpArg("x", 0))
.PartialSum(user_op::OpArg("y", 0))
.Build();
......@@ -60,8 +62,9 @@ Maybe<void> GetSbpSignatures(user_op::SbpContext* ctx) {
REGISTER_USER_OP("fused_cast_scale")
.Input("x")
.Input("scalar")
.Input("scale_by_tensor")
.Output("y")
.Attr<double>("scale", 1.0)
.SetTensorDescInferFn(TensorDescInfer)
.SetGetSbpFn(GetSbpSignatures)
.SetInferDataTypeFn(DataTypeInfer);
......
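
The registration keeps scale_by_tensor as a (1,)-shaped tensor input, adds the double `scale` attribute with default 1.0, and the SBP rules broadcast scale_by_tensor whenever x is split. The shape/dtype inference above amounts to the following checks, restated as a small Python sketch (the helper name is illustrative, mirroring TensorDescInfer/DataTypeInfer):

    import numpy as np

    def infer_fused_cast_scale(x_shape, x_dtype, scale_by_tensor_shape, scale_by_tensor_dtype):
        # scale_by_tensor must be a 1-D tensor holding a single element
        assert len(scale_by_tensor_shape) == 1 and scale_by_tensor_shape[0] == 1
        # y takes x's shape and scale_by_tensor's data type
        return x_shape, scale_by_tensor_dtype

    y_shape, y_dtype = infer_fused_cast_scale((4, 8), np.float16, (1,), np.float32)
    assert y_shape == (4, 8) and y_dtype == np.float32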