From 11126849a4f05b99530f412b27c3fdc445f532b7 Mon Sep 17 00:00:00 2001 From: XIE Xuan <xiexuanx2@gmail.com> Date: Thu, 15 Jul 2021 08:43:31 +0800 Subject: [PATCH] Dev ofrecord auto truncating (#5412) * auto_zero_padding -> auto_truncating in cpp * auto_truncating for python interface * fix element count * fix * fix bug auto_truncating<->dim1_varying_length * auto_truncating -> truncate * rm not in if * of format Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- oneflow/python/nn/modules/dataset.py | 22 ++++++++-- oneflow/python/ops/data_ops.py | 11 +++-- oneflow/python/ops/user_data_ops.py | 8 +++- .../user/kernels/ofrecord_decoder_kernels.cpp | 41 ++++++++++--------- oneflow/user/ops/ofrecord_decoder_ops.cpp | 2 +- 5 files changed, 57 insertions(+), 27 deletions(-) diff --git a/oneflow/python/nn/modules/dataset.py b/oneflow/python/nn/modules/dataset.py index 2f18db472..9b3fa9a4b 100644 --- a/oneflow/python/nn/modules/dataset.py +++ b/oneflow/python/nn/modules/dataset.py @@ -87,11 +87,16 @@ class OfrecordRawDecoder(Module): shape: Sequence[int], dtype: flow.dtype, dim1_varying_length: bool = False, + truncate: bool = False, auto_zero_padding: bool = False, name: Optional[str] = None, ): super().__init__() - + if auto_zero_padding: + print( + """WARNING: auto_zero_padding has been deprecated, Please use truncate instead. + """ + ) self._op = ( flow.builtin_op("ofrecord_raw_decoder", name) .Input("in") @@ -100,7 +105,7 @@ class OfrecordRawDecoder(Module): .Attr("shape", shape) .Attr("data_type", dtype) .Attr("dim1_varying_length", dim1_varying_length) - .Attr("auto_zero_padding", auto_zero_padding) + .Attr("truncate", truncate or auto_zero_padding) .Build() ) @@ -424,11 +429,22 @@ def raw_decoder( shape: Sequence[int], dtype: flow.dtype, dim1_varying_length: bool = False, + truncate: bool = False, auto_zero_padding: bool = False, name: Optional[str] = None, ): + if auto_zero_padding: + print( + """WARNING: auto_zero_padding has been deprecated, Please use truncate instead. + """ + ) return OfrecordRawDecoder( - blob_name, shape, dtype, dim1_varying_length, auto_zero_padding, name + blob_name, + shape, + dtype, + dim1_varying_length, + truncate or auto_zero_padding, + name, ).forward(input_record) diff --git a/oneflow/python/ops/data_ops.py b/oneflow/python/ops/data_ops.py index 6b0efed51..a640a2582 100644 --- a/oneflow/python/ops/data_ops.py +++ b/oneflow/python/ops/data_ops.py @@ -94,8 +94,13 @@ class ImageCodec(object): @oneflow_export("data.RawCodec") class RawCodec(object): - def __init__(self, auto_zero_padding: bool = False) -> None: - self.auto_zero_padding = auto_zero_padding + def __init__(self, truncate: bool = False, auto_zero_padding: bool = False) -> None: + if auto_zero_padding: + print( + """WARNING: auto_zero_padding has been deprecated, Please use truncate instead. + """ + ) + self.truncate = truncate or auto_zero_padding @oneflow_export("data.NormByChannelPreprocessor") @@ -183,7 +188,7 @@ class BlobConf(object): blob_name=self.name, shape=self.shape, dtype=self.dtype, - auto_zero_padding=self.codec.auto_zero_padding, + truncate=self.codec.truncate, ) return raw else: diff --git a/oneflow/python/ops/user_data_ops.py b/oneflow/python/ops/user_data_ops.py index 41fcfbc1c..37549b646 100644 --- a/oneflow/python/ops/user_data_ops.py +++ b/oneflow/python/ops/user_data_ops.py @@ -34,9 +34,15 @@ def OFRecordRawDecoder( shape: Sequence[int], dtype: flow.dtype, dim1_varying_length: bool = False, + truncate: bool = False, auto_zero_padding: bool = False, name: Optional[str] = None, ) -> oneflow._oneflow_internal.BlobDesc: + if auto_zero_padding: + print( + """WARNING: auto_zero_padding has been deprecated, Please use truncate instead. + """ + ) if name is None: name = id_util.UniqueStr("OFRecordRawDecoder_") return ( @@ -48,7 +54,7 @@ def OFRecordRawDecoder( .Attr("shape", shape) .Attr("data_type", dtype) .Attr("dim1_varying_length", dim1_varying_length) - .Attr("auto_zero_padding", auto_zero_padding) + .Attr("truncate", truncate or auto_zero_padding) .Build() .InferAndTryRun() .RemoteBlobList()[0] diff --git a/oneflow/user/kernels/ofrecord_decoder_kernels.cpp b/oneflow/user/kernels/ofrecord_decoder_kernels.cpp index 0e323d26b..d59fe8d36 100644 --- a/oneflow/user/kernels/ofrecord_decoder_kernels.cpp +++ b/oneflow/user/kernels/ofrecord_decoder_kernels.cpp @@ -33,8 +33,8 @@ namespace oneflow { namespace { template<typename T> -void DecodeOneRawOFRecord(const Feature& feature, T* dptr, int64_t sample_elem_cnt, - bool dim1_varying_length, bool auto_zero_padding) { +void DecodeOneRawOFRecord(const Feature& feature, T* dptr, int64_t sample_elem_cnt, bool truncate, + bool dim1_varying_length) { if (feature.has_bytes_list()) { CHECK_EQ(feature.bytes_list().value_size(), 1); const auto& value0 = feature.bytes_list().value(0); @@ -42,21 +42,24 @@ void DecodeOneRawOFRecord(const Feature& feature, T* dptr, int64_t sample_elem_c sample_elem_cnt = std::min<int64_t>(sample_elem_cnt, value0.size()); CopyElem<int8_t, T>(in_dptr, dptr, sample_elem_cnt); } -#define DEFINE_ONE_ELIF(PbT, CppT) \ - else if (feature.has_##PbT##_list()) { \ - const auto& list = feature.PbT##_list(); \ - const CppT* in_dptr = list.value().data(); \ - const int64_t padding_elem_num = auto_zero_padding ? sample_elem_cnt - list.value_size() : 0; \ - if (dim1_varying_length || auto_zero_padding) { \ - CHECK_LE(list.value_size(), sample_elem_cnt); \ - sample_elem_cnt = list.value_size(); \ - } else { \ - CHECK_EQ(sample_elem_cnt, list.value_size()); \ - } \ - CopyElem<CppT, T>(in_dptr, dptr, sample_elem_cnt); \ - if (padding_elem_num > 0) { \ - std::memset(dptr + sample_elem_cnt, 0, padding_elem_num * sizeof(T)); \ - } \ +#define DEFINE_ONE_ELIF(PbT, CppT) \ + else if (feature.has_##PbT##_list()) { \ + const auto& list = feature.PbT##_list(); \ + const CppT* in_dptr = list.value().data(); \ + const int64_t padding_elem_num = truncate ? sample_elem_cnt - list.value_size() : 0; \ + if (truncate) { \ + sample_elem_cnt = std::min<int64_t>(sample_elem_cnt, list.value_size()); \ + } else { \ + if (dim1_varying_length) { \ + sample_elem_cnt = list.value_size(); \ + } else { \ + CHECK_EQ(sample_elem_cnt, list.value_size()); \ + } \ + } \ + CopyElem<CppT, T>(in_dptr, dptr, sample_elem_cnt); \ + if (padding_elem_num > 0) { \ + std::memset(dptr + sample_elem_cnt, 0, padding_elem_num * sizeof(T)); \ + } \ } DEFINE_ONE_ELIF(float, float) DEFINE_ONE_ELIF(double, double) @@ -88,7 +91,7 @@ class OFRecordRawDecoderKernel final : public user_op::OpKernel { T* out_dptr = out_blob->mut_dptr<T>(); const std::string& name = ctx->Attr<std::string>("name"); - bool auto_zero_padding = ctx->Attr<bool>("auto_zero_padding"); + bool truncate = ctx->Attr<bool>("truncate"); bool dim1_varying_length = ctx->Attr<bool>("dim1_varying_length"); MultiThreadLoop(record_num, [&](size_t i) { @@ -97,7 +100,7 @@ class OFRecordRawDecoderKernel final : public user_op::OpKernel { CHECK(record.feature().find(name) != record.feature().end()) << "Field " << name << " not found"; const Feature& feature = record.feature().at(name); - DecodeOneRawOFRecord(feature, dptr, sample_elem_cnt, auto_zero_padding, dim1_varying_length); + DecodeOneRawOFRecord(feature, dptr, sample_elem_cnt, truncate, dim1_varying_length); }); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } diff --git a/oneflow/user/ops/ofrecord_decoder_ops.cpp b/oneflow/user/ops/ofrecord_decoder_ops.cpp index 1250fe474..2fa02969b 100644 --- a/oneflow/user/ops/ofrecord_decoder_ops.cpp +++ b/oneflow/user/ops/ofrecord_decoder_ops.cpp @@ -25,7 +25,7 @@ REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_raw_decoder") .Attr<Shape>("shape") .Attr<DataType>("data_type") .Attr<bool>("dim1_varying_length", false) - .Attr<bool>("auto_zero_padding", false) + .Attr<bool>("truncate", false) .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); -- GitLab