diff --git a/oneflow/python/ops/user_data_ops.py b/oneflow/python/ops/user_data_ops.py index 2c658bff0608910547cb297d039f3577a11bf9e5..050f310c6a9321e3d6c6119ac08344919560f0e4 100644 --- a/oneflow/python/ops/user_data_ops.py +++ b/oneflow/python/ops/user_data_ops.py @@ -634,3 +634,49 @@ class COCOReader(module_util.Module): .InferAndTryRun() .RemoteBlobList() ) + + +@oneflow_export("data.ofrecord_image_classification_reader") +def ofrecord_image_classification_reader( + ofrecord_dir: str, + image_feature_name: str, + label_feature_name: str, + batch_size: int = 1, + data_part_num: int = 1, + part_name_prefix: str = "part-", + part_name_suffix_length: int = -1, + random_shuffle: bool = False, + shuffle_buffer_size: int = 1024, + shuffle_after_epoch: bool = False, + color_space: str = "BGR", + decode_buffer_size_per_thread: int = 32, + num_decode_threads_per_machine: Optional[int] = None, + name: Optional[str] = None, +) -> BlobDef: + if name is None: + name = id_util.UniqueStr("OFRecordImageClassificationReader_") + (image, label) = ( + flow.user_op_builder(name) + .Op("ofrecord_image_classification_reader") + .Output("image") + .Output("label") + .Attr("data_dir", ofrecord_dir) + .Attr("data_part_num", data_part_num) + .Attr("batch_size", batch_size) + .Attr("part_name_prefix", part_name_prefix) + .Attr("random_shuffle", random_shuffle) + .Attr("shuffle_buffer_size", shuffle_buffer_size) + .Attr("shuffle_after_epoch", shuffle_after_epoch) + .Attr("part_name_suffix_length", part_name_suffix_length) + .Attr("color_space", color_space) + .Attr("image_feature_name", image_feature_name) + .Attr("label_feature_name", label_feature_name) + .Attr("decode_buffer_size_per_thread", decode_buffer_size_per_thread) + .Attr("num_decode_threads_per_machine", num_decode_threads_per_machine or 0) + .Build() + .InferAndTryRun() + .RemoteBlobList() + ) + label = flow.tensor_buffer_to_tensor(label, dtype=flow.int32, instance_shape=[1]) + label = flow.squeeze(label, axis=[-1]) + return image, label diff --git a/oneflow/python/test/models/alexnet.py b/oneflow/python/test/models/alexnet.py index 2e1c70f291a34afb65fa97abdae89e1adea10147..f5a5f701f78180699b20730bf8c0d7a31a98ee0a 100644 --- a/oneflow/python/test/models/alexnet.py +++ b/oneflow/python/test/models/alexnet.py @@ -124,16 +124,15 @@ def _data_load_layer(args, data_dir): node_num = args.num_nodes total_batch_size = args.batch_size * args.gpu_num_per_node * node_num rgb_mean = [123.68, 116.78, 103.94] - ofrecord = flow.data.ofrecord_reader( + (image, label) = flow.data.ofrecord_image_classification_reader( data_dir, batch_size=total_batch_size, data_part_num=args.data_part_num, + image_feature_name="encoded", + label_feature_name="class/label", + color_space="RGB", name="decode", ) - image = flow.data.ofrecord_image_decoder(ofrecord, "encoded", color_space="RGB") - label = flow.data.ofrecord_raw_decoder( - ofrecord, "class/label", shape=(), dtype=flow.int32 - ) rsz = flow.image.resize(image, resize_x=227, resize_y=227, color_space="RGB") normal = flow.image.crop_mirror_normalize( rsz, diff --git a/oneflow/user/data/ofrecord_image_classification_data_reader.h b/oneflow/user/data/ofrecord_image_classification_data_reader.h new file mode 100644 index 0000000000000000000000000000000000000000..9bfdbcd88475953b451b99d8a0484a98fe3df88e --- /dev/null +++ b/oneflow/user/data/ofrecord_image_classification_data_reader.h @@ -0,0 +1,58 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_DATA_READER_H_ +#define ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_DATA_READER_H_ + +#include "oneflow/user/data/data_reader.h" +#include "oneflow/user/data/ofrecord_dataset.h" +#include "oneflow/user/data/ofrecord_parser.h" +#include "oneflow/user/data/random_shuffle_dataset.h" +#include "oneflow/user/data/batch_dataset.h" +#include "oneflow/user/data/ofrecord_image_classification_dataset.h" +#include "oneflow/user/data/ofrecord_image_classification_parser.h" + +namespace oneflow { + +namespace data { + +class OFRecordImageClassificationDataReader final + : public DataReader<ImageClassificationDataInstance> { + public: + explicit OFRecordImageClassificationDataReader(user_op::KernelInitContext* ctx) + : DataReader<ImageClassificationDataInstance>(ctx) { + std::unique_ptr<Dataset<TensorBuffer>> base(new OFRecordDataset(ctx)); + if (ctx->Attr<bool>("random_shuffle")) { + base.reset(new RandomShuffleDataset<TensorBuffer>(ctx, std::move(base))); + } + loader_.reset(new OFRecordImageClassificationDataset(ctx, std::move(base))); + const int64_t batch_size = ctx->TensorDesc4ArgNameAndIndex("image", 0)->shape().elem_cnt(); + loader_.reset( + new BatchDataset<ImageClassificationDataInstance>(batch_size, std::move(loader_))); + parser_.reset(new OFRecordImageClassificationParser()); + StartLoadThread(); + } + ~OFRecordImageClassificationDataReader() override = default; + + protected: + using DataReader<ImageClassificationDataInstance>::loader_; + using DataReader<ImageClassificationDataInstance>::parser_; +}; + +} // namespace data + +} // namespace oneflow + +#endif // ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_DATA_READER_H_ diff --git a/oneflow/user/data/ofrecord_image_classification_dataset.h b/oneflow/user/data/ofrecord_image_classification_dataset.h new file mode 100644 index 0000000000000000000000000000000000000000..ceb611246d1c423b75ef821a2b69ca70301bb0fe --- /dev/null +++ b/oneflow/user/data/ofrecord_image_classification_dataset.h @@ -0,0 +1,207 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#ifndef ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_DATASET_H_ +#define ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_DATASET_H_ + +#include "oneflow/core/thread/thread_pool.h" +#include "oneflow/core/common/buffer.h" +#include "oneflow/user/data/dataset.h" +#include "oneflow/core/common/balanced_splitter.h" +#include "oneflow/core/common/str_util.h" +#include "oneflow/core/framework/op_kernel.h" +#include "oneflow/core/persistence/persistent_in_stream.h" +#include "oneflow/core/job/job_set.pb.h" +#include "oneflow/user/data/ofrecord_dataset.h" +#include "oneflow/user/image/image_util.h" +#include "oneflow/core/job/resource_desc.h" +#include "oneflow/core/job/global_for.h" +#include <opencv2/opencv.hpp> + +namespace oneflow { + +namespace data { + +struct ImageClassificationDataInstance { + std::shared_ptr<TensorBuffer> label; + std::shared_ptr<TensorBuffer> image; +}; + +using BaseDataset = Dataset<TensorBuffer>; +using BaseLoadTargetPtr = BaseDataset::LoadTargetPtr; +using BaseLoadTargetPtrList = BaseDataset::LoadTargetPtrList; + +namespace { + +void DecodeImageFromOFRecord(const OFRecord& record, const std::string& feature_name, + const std::string& color_space, TensorBuffer* out) { + auto image_feature_it = record.feature().find(feature_name); + CHECK(image_feature_it != record.feature().end()); + const Feature& image_feature = image_feature_it->second; + CHECK(image_feature.has_bytes_list()); + CHECK(image_feature.bytes_list().value_size() == 1); + const std::string& src_data = image_feature.bytes_list().value(0); + cv::Mat image = cv::imdecode(cv::Mat(1, src_data.size(), CV_8UC1, (void*)(src_data.data())), + cv::IMREAD_COLOR); + int W = image.cols; + int H = image.rows; + + // convert color space + if (ImageUtil::IsColor(color_space) && color_space != "BGR") { + ImageUtil::ConvertColor("BGR", image, color_space, image); + } + + CHECK(image.isContinuous()); + const int c = ImageUtil::IsColor(color_space) ? 
3 : 1; + CHECK_EQ(c, image.channels()); + Shape image_shape({H, W, c}); + out->Resize(image_shape, DataType::kUInt8); + CHECK_EQ(image_shape.elem_cnt(), out->nbytes()); + CHECK_EQ(image_shape.elem_cnt(), image.total() * image.elemSize()); + memcpy(out->mut_data<uint8_t>(), image.ptr(), image_shape.elem_cnt()); +} + +void DecodeLabelFromFromOFRecord(const OFRecord& record, const std::string& feature_name, + TensorBuffer* out) { + auto label_feature_it = record.feature().find(feature_name); + CHECK(label_feature_it != record.feature().end()); + const Feature& label_feature = label_feature_it->second; + out->Resize(Shape({1}), DataType::kInt32); + if (label_feature.has_int32_list()) { + CHECK_EQ(label_feature.int32_list().value_size(), 1); + *out->mut_data<int32_t>() = label_feature.int32_list().value(0); + } else if (label_feature.has_int64_list()) { + CHECK_EQ(label_feature.int64_list().value_size(), 1); + *out->mut_data<int32_t>() = label_feature.int64_list().value(0); + } else { + UNIMPLEMENTED(); + } +} + +void LoadWorker(BaseDataset* record_dataset, + std::vector<std::unique_ptr<Buffer<BaseLoadTargetPtr>>>* decode_in_buffers) { + int64_t thread_idx = 0; + bool shutdown = false; + while (!shutdown) { + BaseLoadTargetPtrList records = record_dataset->Next(); + for (const auto& record : records) { + auto& current_in_buffer = decode_in_buffers->at(thread_idx); + thread_idx = (thread_idx + 1) % decode_in_buffers->size(); + auto status = current_in_buffer->Send(record); + if (status == kBufferStatusErrorClosed) { + shutdown = true; + break; + } + CHECK(status == kBufferStatusSuccess); + } + } +} + +void DecodeWorker(const std::string image_feature_name, const std::string label_feature_name, + const std::string color_space, Buffer<BaseLoadTargetPtr>* in_buffer, + Buffer<std::shared_ptr<ImageClassificationDataInstance>>* out_buffer) { + while (true) { + BaseLoadTargetPtr serialized_record; + auto receive_status = in_buffer->Receive(&serialized_record); + if (receive_status == kBufferStatusErrorClosed) { break; } + CHECK(receive_status == kBufferStatusSuccess); + OFRecord record; + CHECK(record.ParseFromArray(serialized_record->data<char>(), + serialized_record->shape().elem_cnt())); + std::shared_ptr<ImageClassificationDataInstance> instance( + new ImageClassificationDataInstance()); + instance->image.reset(new TensorBuffer()); + DecodeImageFromOFRecord(record, image_feature_name, color_space, instance->image.get()); + instance->label.reset(new TensorBuffer()); + DecodeLabelFromFromOFRecord(record, label_feature_name, instance->label.get()); + auto send_status = out_buffer->Send(instance); + if (send_status == kBufferStatusErrorClosed) { break; } + CHECK(send_status == kBufferStatusSuccess); + } +} + +int32_t GetNumLocalDecodeThreads(int32_t num_decode_threads_per_machine, + const ParallelDesc& parallel_desc, + const ParallelContext& parallel_ctx) { + if (num_decode_threads_per_machine == 0) { + num_decode_threads_per_machine = + Global<ResourceDesc, ForSession>::Get()->ComputeThreadPoolSize(); + } + const int64_t machine_id = parallel_desc.MachineIdForParallelId(parallel_ctx.parallel_id()); + const int64_t parallel_num_on_this_machine = parallel_desc.sorted_dev_phy_ids(machine_id).size(); + return std::max<int32_t>(num_decode_threads_per_machine / parallel_num_on_this_machine, 1); +} + +} // namespace + +class OFRecordImageClassificationDataset final : public Dataset<ImageClassificationDataInstance> { + public: + using LoadTargetPtr = std::shared_ptr<ImageClassificationDataInstance>; + using 
LoadTargetPtrList = std::vector<LoadTargetPtr>; + OF_DISALLOW_COPY_AND_MOVE(OFRecordImageClassificationDataset); + OFRecordImageClassificationDataset(user_op::KernelInitContext* ctx, + std::unique_ptr<BaseDataset>&& base) + : base_(std::move(base)), out_thread_idx_(0) { + const std::string& color_space = ctx->Attr<std::string>("color_space"); + const std::string& image_feature_name = ctx->Attr<std::string>("image_feature_name"); + const std::string& label_feature_name = ctx->Attr<std::string>("label_feature_name"); + const auto num_decode_threads_per_machine = + ctx->Attr<int32_t>("num_decode_threads_per_machine"); + const auto decode_buffer_size_per_thread = ctx->Attr<int32_t>("decode_buffer_size_per_thread"); + const int32_t num_local_decode_threads = GetNumLocalDecodeThreads( + num_decode_threads_per_machine, ctx->parallel_desc(), ctx->parallel_ctx()); + decode_in_buffers_.resize(num_local_decode_threads); + decode_out_buffers_.resize(num_local_decode_threads); + for (int64_t i = 0; i < num_local_decode_threads; ++i) { + decode_in_buffers_.at(i).reset(new Buffer<BaseLoadTargetPtr>(decode_buffer_size_per_thread)); + decode_out_buffers_.at(i).reset(new Buffer<LoadTargetPtr>(decode_buffer_size_per_thread)); + decode_threads_.emplace_back( + std::thread(&DecodeWorker, image_feature_name, label_feature_name, color_space, + decode_in_buffers_.at(i).get(), decode_out_buffers_.at(i).get())); + } + load_thread_ = std::thread(&LoadWorker, base_.get(), &decode_in_buffers_); + } + ~OFRecordImageClassificationDataset() override { + for (auto& out_buffer : decode_out_buffers_) { out_buffer->Close(); } + for (auto& in_buffer : decode_in_buffers_) { in_buffer->Close(); } + load_thread_.join(); + for (auto& decode_thread : decode_threads_) { decode_thread.join(); } + } + + LoadTargetPtrList Next() override { + LoadTargetPtrList ret; + LoadTargetPtr sample_ptr; + size_t thread_idx = + out_thread_idx_.fetch_add(1, std::memory_order_relaxed) % decode_out_buffers_.size(); + auto status = decode_out_buffers_.at(thread_idx)->Receive(&sample_ptr); + CHECK_EQ(status, kBufferStatusSuccess); + ret.push_back(std::move(sample_ptr)); + return ret; + } + + private: + std::unique_ptr<BaseDataset> base_; + std::thread load_thread_; + std::vector<std::thread> decode_threads_; + std::vector<std::unique_ptr<Buffer<BaseLoadTargetPtr>>> decode_in_buffers_; + std::vector<std::unique_ptr<Buffer<LoadTargetPtr>>> decode_out_buffers_; + std::atomic<size_t> out_thread_idx_; +}; + +} // namespace data + +} // namespace oneflow + +#endif // ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_DATASET_H_ diff --git a/oneflow/user/data/ofrecord_image_classification_parser.h b/oneflow/user/data/ofrecord_image_classification_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..5b4d2880ee8de2131f209d92478e695e98d51ac4 --- /dev/null +++ b/oneflow/user/data/ofrecord_image_classification_parser.h @@ -0,0 +1,59 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#ifndef ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_PARSER_H_ +#define ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_PARSER_H_ + +#include "oneflow/user/data/parser.h" +#include "oneflow/core/common/tensor_buffer.h" +#include "oneflow/core/record/record.pb.h" +#include "oneflow/core/thread/thread_manager.h" +#include "oneflow/user/data/ofrecord_image_classification_dataset.h" + +namespace oneflow { + +namespace data { + +class OFRecordImageClassificationParser final : public Parser<ImageClassificationDataInstance> { + public: + using LoadTargetPtr = std::shared_ptr<ImageClassificationDataInstance>; + using LoadTargetPtrList = std::vector<LoadTargetPtr>; + OFRecordImageClassificationParser() = default; + ~OFRecordImageClassificationParser() override = default; + + void Parse(std::shared_ptr<LoadTargetPtrList> batch_data, + user_op::KernelComputeContext* ctx) override { + const int64_t batch_size = batch_data->size(); + user_op::Tensor* image_tensor = ctx->Tensor4ArgNameAndIndex("image", 0); + CHECK_EQ(image_tensor->shape().NumAxes(), 1); + CHECK_EQ(image_tensor->shape().At(0), batch_size); + auto* image_buffers = image_tensor->mut_dptr<TensorBuffer>(); + user_op::Tensor* label_tensor = ctx->Tensor4ArgNameAndIndex("label", 0); + CHECK_EQ(label_tensor->shape().NumAxes(), 1); + CHECK_EQ(label_tensor->shape().At(0), batch_size); + auto* label_buffers = label_tensor->mut_dptr<TensorBuffer>(); + for (int64_t i = 0; i < batch_data->size(); ++i) { + const auto& instance = batch_data->at(i); + image_buffers[i].Swap(instance->image.get()); + label_buffers[i].Swap(instance->label.get()); + } + } +}; + +} // namespace data + +} // namespace oneflow + +#endif // ONEFLOW_USER_DATA_OFRECORD_IMAGE_CLASSIFICATION_PARSER_H_ diff --git a/oneflow/user/kernels/ofrecord_image_classification_reader_kernel.cpp b/oneflow/user/kernels/ofrecord_image_classification_reader_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..89723519ad29364d7ace277771ae579a1a6c36e9 --- /dev/null +++ b/oneflow/user/kernels/ofrecord_image_classification_reader_kernel.cpp @@ -0,0 +1,62 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/framework/framework.h" +#include "oneflow/user/data/ofrecord_image_classification_data_reader.h" + +namespace oneflow { + +namespace { + +class OFRecordImageClassificationReaderKernelState final : public user_op::OpKernelState { + public: + explicit OFRecordImageClassificationReaderKernelState(user_op::KernelInitContext* ctx) + : reader_(ctx) {} + ~OFRecordImageClassificationReaderKernelState() override = default; + + void Read(user_op::KernelComputeContext* ctx) { reader_.Read(ctx); } + + private: + data::OFRecordImageClassificationDataReader reader_; +}; + +} // namespace + +class OFRecordImageClassificationReaderKernel final : public user_op::OpKernel { + public: + OFRecordImageClassificationReaderKernel() = default; + ~OFRecordImageClassificationReaderKernel() override = default; + + std::shared_ptr<user_op::OpKernelState> CreateOpKernelState( + user_op::KernelInitContext* ctx) const override { + return std::make_shared<OFRecordImageClassificationReaderKernelState>(ctx); + } + + private: + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + auto* reader = dynamic_cast<OFRecordImageClassificationReaderKernelState*>(state); + CHECK_NOTNULL(reader); + reader->Read(ctx); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +REGISTER_USER_KERNEL("ofrecord_image_classification_reader") + .SetCreateFn<OFRecordImageClassificationReaderKernel>() + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) + & (user_op::HobDataType("image", 0) == DataType::kTensorBuffer) + & (user_op::HobDataType("label", 0) == DataType::kTensorBuffer)); + +} // namespace oneflow diff --git a/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp b/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fa8b8dfc53f6532d55b48632aaee983851450360 --- /dev/null +++ b/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp @@ -0,0 +1,72 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/framework/framework.h" + +namespace oneflow { + +REGISTER_CPU_ONLY_USER_OP("ofrecord_image_classification_reader") + .Output("image") + .Output("label") + .Attr("data_dir", UserOpAttrType::kAtString) + .Attr("data_part_num", UserOpAttrType::kAtInt32) + .Attr("batch_size", UserOpAttrType::kAtInt32) + .Attr<std::string>("part_name_prefix", UserOpAttrType::kAtString, "part-") + .Attr<int32_t>("part_name_suffix_length", UserOpAttrType::kAtInt32, -1) + .Attr<bool>("random_shuffle", UserOpAttrType::kAtBool, false) + .Attr<int64_t>("seed", UserOpAttrType::kAtInt64, -1) + .Attr<int32_t>("shuffle_buffer_size", UserOpAttrType::kAtInt32, 1024) + .Attr<bool>("shuffle_after_epoch", UserOpAttrType::kAtBool, false) + .Attr<std::string>("color_space", UserOpAttrType::kAtString, "BGR") + .Attr<std::string>("image_feature_name", UserOpAttrType::kAtString, "encoded") + .Attr<std::string>("label_feature_name", UserOpAttrType::kAtString, "class/label") + .Attr<int32_t>("decode_buffer_size_per_thread", UserOpAttrType::kAtInt32, 8) + .Attr<int32_t>("num_decode_threads_per_machine", UserOpAttrType::kAtInt32, 0) + .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + user_op::TensorDesc* image_tensor = ctx->TensorDesc4ArgNameAndIndex("image", 0); + user_op::TensorDesc* label_tensor = ctx->TensorDesc4ArgNameAndIndex("label", 0); + int32_t local_batch_size = ctx->Attr<int32_t>("batch_size"); + const SbpParallel& sbp = ctx->SbpParallel4ArgNameAndIndex("image", 0); + int64_t parallel_num = ctx->parallel_ctx().parallel_num(); + if (sbp.has_split_parallel() && parallel_num > 1) { + CHECK_EQ_OR_RETURN(local_batch_size % parallel_num, 0); + local_batch_size /= parallel_num; + } + *image_tensor->mut_shape() = Shape({local_batch_size}); + *image_tensor->mut_data_type() = DataType::kTensorBuffer; + *label_tensor->mut_shape() = Shape({local_batch_size}); + *label_tensor->mut_data_type() = DataType::kTensorBuffer; + return Maybe<void>::Ok(); + }) + .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { + ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); + return Maybe<void>::Ok(); + }) + .SetBatchAxisInferFn([](user_op::BatchAxisContext* ctx) -> Maybe<void> { + ctx->BatchAxis4ArgNameAndIndex("image", 0)->set_value(0); + ctx->BatchAxis4ArgNameAndIndex("label", 0)->set_value(0); + return Maybe<void>::Ok(); + }) + .SetOutputArgModifyFn([](user_op::GetOutputArgModifier GetOutputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + user_op::OutputArgModifier* image_modifier = GetOutputArgModifierFn("image", 0); + CHECK(image_modifier != nullptr); + image_modifier->set_header_infered_before_compute(false); + user_op::OutputArgModifier* label_modifier = GetOutputArgModifierFn("label", 0); + CHECK(label_modifier != nullptr); + label_modifier->set_header_infered_before_compute(false); + }); + +} // namespace oneflow
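
Usage sketch (not part of the diff): with these changes applied, the fused reader replaces the separate ofrecord_reader / ofrecord_image_decoder / ofrecord_raw_decoder calls, as in the alexnet.py hunk above. This is a minimal, hypothetical example; the dataset path and batch sizes below are placeholders, not values from the patch.

import oneflow as flow

# Placeholder path and batch size; substitute your own OFRecord dataset layout.
(image, label) = flow.data.ofrecord_image_classification_reader(
    "/dataset/imagenet/ofrecord/train",
    batch_size=256,
    data_part_num=256,
    image_feature_name="encoded",
    label_feature_name="class/label",
    color_space="RGB",
    name="decode",
)
# "image" is a TensorBuffer blob holding one decoded HWC uint8 image per sample;
# "label" has already been converted by the Python wrapper to an int32 tensor of shape (batch_size,).
rsz = flow.image.resize(image, resize_x=227, resize_y=227, color_space="RGB")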