Skip to content
Snippets Groups Projects
Unverified Commit 11126849 authored by XIE Xuan's avatar XIE Xuan Committed by GitHub
Browse files

Dev ofrecord auto truncating (#5412)


* auto_zero_padding -> auto_truncating in cpp

* auto_truncating for python interface

* fix element count

* fix

* fix bug auto_truncating<->dim1_varying_length

* auto_truncating -> truncate

* rm not in if

* of format

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 30a37272
No related branches found
No related tags found
No related merge requests found
...@@ -87,11 +87,16 @@ class OfrecordRawDecoder(Module): ...@@ -87,11 +87,16 @@ class OfrecordRawDecoder(Module):
shape: Sequence[int], shape: Sequence[int],
dtype: flow.dtype, dtype: flow.dtype,
dim1_varying_length: bool = False, dim1_varying_length: bool = False,
truncate: bool = False,
auto_zero_padding: bool = False, auto_zero_padding: bool = False,
name: Optional[str] = None, name: Optional[str] = None,
): ):
super().__init__() super().__init__()
if auto_zero_padding:
print(
"""WARNING: auto_zero_padding has been deprecated, Please use truncate instead.
"""
)
self._op = ( self._op = (
flow.builtin_op("ofrecord_raw_decoder", name) flow.builtin_op("ofrecord_raw_decoder", name)
.Input("in") .Input("in")
...@@ -100,7 +105,7 @@ class OfrecordRawDecoder(Module): ...@@ -100,7 +105,7 @@ class OfrecordRawDecoder(Module):
.Attr("shape", shape) .Attr("shape", shape)
.Attr("data_type", dtype) .Attr("data_type", dtype)
.Attr("dim1_varying_length", dim1_varying_length) .Attr("dim1_varying_length", dim1_varying_length)
.Attr("auto_zero_padding", auto_zero_padding) .Attr("truncate", truncate or auto_zero_padding)
.Build() .Build()
) )
...@@ -424,11 +429,22 @@ def raw_decoder( ...@@ -424,11 +429,22 @@ def raw_decoder(
shape: Sequence[int], shape: Sequence[int],
dtype: flow.dtype, dtype: flow.dtype,
dim1_varying_length: bool = False, dim1_varying_length: bool = False,
truncate: bool = False,
auto_zero_padding: bool = False, auto_zero_padding: bool = False,
name: Optional[str] = None, name: Optional[str] = None,
): ):
if auto_zero_padding:
print(
"""WARNING: auto_zero_padding has been deprecated, Please use truncate instead.
"""
)
return OfrecordRawDecoder( return OfrecordRawDecoder(
blob_name, shape, dtype, dim1_varying_length, auto_zero_padding, name blob_name,
shape,
dtype,
dim1_varying_length,
truncate or auto_zero_padding,
name,
).forward(input_record) ).forward(input_record)
......
...@@ -94,8 +94,13 @@ class ImageCodec(object): ...@@ -94,8 +94,13 @@ class ImageCodec(object):
@oneflow_export("data.RawCodec") @oneflow_export("data.RawCodec")
class RawCodec(object): class RawCodec(object):
def __init__(self, auto_zero_padding: bool = False) -> None: def __init__(self, truncate: bool = False, auto_zero_padding: bool = False) -> None:
self.auto_zero_padding = auto_zero_padding if auto_zero_padding:
print(
"""WARNING: auto_zero_padding has been deprecated, Please use truncate instead.
"""
)
self.truncate = truncate or auto_zero_padding
@oneflow_export("data.NormByChannelPreprocessor") @oneflow_export("data.NormByChannelPreprocessor")
...@@ -183,7 +188,7 @@ class BlobConf(object): ...@@ -183,7 +188,7 @@ class BlobConf(object):
blob_name=self.name, blob_name=self.name,
shape=self.shape, shape=self.shape,
dtype=self.dtype, dtype=self.dtype,
auto_zero_padding=self.codec.auto_zero_padding, truncate=self.codec.truncate,
) )
return raw return raw
else: else:
......
...@@ -34,9 +34,15 @@ def OFRecordRawDecoder( ...@@ -34,9 +34,15 @@ def OFRecordRawDecoder(
shape: Sequence[int], shape: Sequence[int],
dtype: flow.dtype, dtype: flow.dtype,
dim1_varying_length: bool = False, dim1_varying_length: bool = False,
truncate: bool = False,
auto_zero_padding: bool = False, auto_zero_padding: bool = False,
name: Optional[str] = None, name: Optional[str] = None,
) -> oneflow._oneflow_internal.BlobDesc: ) -> oneflow._oneflow_internal.BlobDesc:
if auto_zero_padding:
print(
"""WARNING: auto_zero_padding has been deprecated, Please use truncate instead.
"""
)
if name is None: if name is None:
name = id_util.UniqueStr("OFRecordRawDecoder_") name = id_util.UniqueStr("OFRecordRawDecoder_")
return ( return (
...@@ -48,7 +54,7 @@ def OFRecordRawDecoder( ...@@ -48,7 +54,7 @@ def OFRecordRawDecoder(
.Attr("shape", shape) .Attr("shape", shape)
.Attr("data_type", dtype) .Attr("data_type", dtype)
.Attr("dim1_varying_length", dim1_varying_length) .Attr("dim1_varying_length", dim1_varying_length)
.Attr("auto_zero_padding", auto_zero_padding) .Attr("truncate", truncate or auto_zero_padding)
.Build() .Build()
.InferAndTryRun() .InferAndTryRun()
.RemoteBlobList()[0] .RemoteBlobList()[0]
......
...@@ -33,8 +33,8 @@ namespace oneflow { ...@@ -33,8 +33,8 @@ namespace oneflow {
namespace { namespace {
template<typename T> template<typename T>
void DecodeOneRawOFRecord(const Feature& feature, T* dptr, int64_t sample_elem_cnt, void DecodeOneRawOFRecord(const Feature& feature, T* dptr, int64_t sample_elem_cnt, bool truncate,
bool dim1_varying_length, bool auto_zero_padding) { bool dim1_varying_length) {
if (feature.has_bytes_list()) { if (feature.has_bytes_list()) {
CHECK_EQ(feature.bytes_list().value_size(), 1); CHECK_EQ(feature.bytes_list().value_size(), 1);
const auto& value0 = feature.bytes_list().value(0); const auto& value0 = feature.bytes_list().value(0);
...@@ -42,21 +42,24 @@ void DecodeOneRawOFRecord(const Feature& feature, T* dptr, int64_t sample_elem_c ...@@ -42,21 +42,24 @@ void DecodeOneRawOFRecord(const Feature& feature, T* dptr, int64_t sample_elem_c
sample_elem_cnt = std::min<int64_t>(sample_elem_cnt, value0.size()); sample_elem_cnt = std::min<int64_t>(sample_elem_cnt, value0.size());
CopyElem<int8_t, T>(in_dptr, dptr, sample_elem_cnt); CopyElem<int8_t, T>(in_dptr, dptr, sample_elem_cnt);
} }
#define DEFINE_ONE_ELIF(PbT, CppT) \ #define DEFINE_ONE_ELIF(PbT, CppT) \
else if (feature.has_##PbT##_list()) { \ else if (feature.has_##PbT##_list()) { \
const auto& list = feature.PbT##_list(); \ const auto& list = feature.PbT##_list(); \
const CppT* in_dptr = list.value().data(); \ const CppT* in_dptr = list.value().data(); \
const int64_t padding_elem_num = auto_zero_padding ? sample_elem_cnt - list.value_size() : 0; \ const int64_t padding_elem_num = truncate ? sample_elem_cnt - list.value_size() : 0; \
if (dim1_varying_length || auto_zero_padding) { \ if (truncate) { \
CHECK_LE(list.value_size(), sample_elem_cnt); \ sample_elem_cnt = std::min<int64_t>(sample_elem_cnt, list.value_size()); \
sample_elem_cnt = list.value_size(); \ } else { \
} else { \ if (dim1_varying_length) { \
CHECK_EQ(sample_elem_cnt, list.value_size()); \ sample_elem_cnt = list.value_size(); \
} \ } else { \
CopyElem<CppT, T>(in_dptr, dptr, sample_elem_cnt); \ CHECK_EQ(sample_elem_cnt, list.value_size()); \
if (padding_elem_num > 0) { \ } \
std::memset(dptr + sample_elem_cnt, 0, padding_elem_num * sizeof(T)); \ } \
} \ CopyElem<CppT, T>(in_dptr, dptr, sample_elem_cnt); \
if (padding_elem_num > 0) { \
std::memset(dptr + sample_elem_cnt, 0, padding_elem_num * sizeof(T)); \
} \
} }
DEFINE_ONE_ELIF(float, float) DEFINE_ONE_ELIF(float, float)
DEFINE_ONE_ELIF(double, double) DEFINE_ONE_ELIF(double, double)
...@@ -88,7 +91,7 @@ class OFRecordRawDecoderKernel final : public user_op::OpKernel { ...@@ -88,7 +91,7 @@ class OFRecordRawDecoderKernel final : public user_op::OpKernel {
T* out_dptr = out_blob->mut_dptr<T>(); T* out_dptr = out_blob->mut_dptr<T>();
const std::string& name = ctx->Attr<std::string>("name"); const std::string& name = ctx->Attr<std::string>("name");
bool auto_zero_padding = ctx->Attr<bool>("auto_zero_padding"); bool truncate = ctx->Attr<bool>("truncate");
bool dim1_varying_length = ctx->Attr<bool>("dim1_varying_length"); bool dim1_varying_length = ctx->Attr<bool>("dim1_varying_length");
MultiThreadLoop(record_num, [&](size_t i) { MultiThreadLoop(record_num, [&](size_t i) {
...@@ -97,7 +100,7 @@ class OFRecordRawDecoderKernel final : public user_op::OpKernel { ...@@ -97,7 +100,7 @@ class OFRecordRawDecoderKernel final : public user_op::OpKernel {
CHECK(record.feature().find(name) != record.feature().end()) CHECK(record.feature().find(name) != record.feature().end())
<< "Field " << name << " not found"; << "Field " << name << " not found";
const Feature& feature = record.feature().at(name); const Feature& feature = record.feature().at(name);
DecodeOneRawOFRecord(feature, dptr, sample_elem_cnt, auto_zero_padding, dim1_varying_length); DecodeOneRawOFRecord(feature, dptr, sample_elem_cnt, truncate, dim1_varying_length);
}); });
} }
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
......
...@@ -25,7 +25,7 @@ REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_raw_decoder") ...@@ -25,7 +25,7 @@ REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_raw_decoder")
.Attr<Shape>("shape") .Attr<Shape>("shape")
.Attr<DataType>("data_type") .Attr<DataType>("data_type")
.Attr<bool>("dim1_varying_length", false) .Attr<bool>("dim1_varying_length", false)
.Attr<bool>("auto_zero_padding", false) .Attr<bool>("truncate", false)
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0);
user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0);
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment