diff --git a/oneflow/python/ops/user_data_ops.py b/oneflow/python/ops/user_data_ops.py
index 050f310c6a9321e3d6c6119ac08344919560f0e4..c0a78db89117cac336f3151334599dc5eec39d76 100644
--- a/oneflow/python/ops/user_data_ops.py
+++ b/oneflow/python/ops/user_data_ops.py
@@ -229,6 +229,74 @@ def CropMirrorNormalize(
     )
 
 
+@oneflow_export("image.random_crop", "image_random_crop")
+def api_image_random_crop(
+    input_blob: BlobDef,
+    num_attempts: int = 10,
+    seed: Optional[int] = None,
+    random_area: Sequence[float] = None,
+    random_aspect_ratio: Sequence[float] = None,
+    name: str = "ImageRandomCrop",
+) -> BlobDef:
+    assert isinstance(name, str)
+    if seed is not None:
+        assert name is not None
+    if random_area is None:
+        random_area = [0.08, 1.0]
+    if random_aspect_ratio is None:
+        random_aspect_ratio = [0.75, 1.333333]
+    module = flow.find_or_create_module(
+        name,
+        lambda: ImageRandomCropModule(
+            num_attempts=num_attempts,
+            random_seed=seed,
+            random_area=random_area,
+            random_aspect_ratio=random_aspect_ratio,
+            name=name,
+        ),
+    )
+    return module(input_blob)
+
+
+class ImageRandomCropModule(module_util.Module):
+    def __init__(
+        self,
+        num_attempts: int,
+        random_seed: Optional[int],
+        random_area: Sequence[float],
+        random_aspect_ratio: Sequence[float],
+        name: str,
+    ):
+        module_util.Module.__init__(self, name)
+        seed, has_seed = flow.random.gen_seed(random_seed)
+        self.op_module_builder = (
+            flow.user_op_module_builder("image_random_crop")
+            .InputSize("in", 1)
+            .Output("out")
+            .Attr("num_attempts", num_attempts)
+            .Attr("random_area", random_area)
+            .Attr("random_aspect_ratio", random_aspect_ratio)
+            .Attr("has_seed", has_seed)
+            .Attr("seed", seed)
+            .CheckAndComplete()
+        )
+        self.op_module_builder.user_op_module.InitOpKernel()
+
+    def forward(self, input: BlobDef):
+        if self.call_seq_no == 0:
+            name = self.module_name
+        else:
+            name = id_util.UniqueStr("ImageRandomCrop_")
+
+        return (
+            self.op_module_builder.OpName(name)
+            .Input("in", [input])
+            .Build()
+            .InferAndTryRun()
+            .SoleOutputBlob()
+        )
+
+
 @oneflow_export("random.CoinFlip", "random.coin_flip")
 def api_coin_flip(
     batch_size: int = 1,
diff --git a/oneflow/user/image/random_crop_generator.cpp b/oneflow/user/image/random_crop_generator.cpp
index 9c5eb517dc9ef7ffda02c2d9b7788f17f9a99b4d..adafe72a797313ac9a63635b28d4e53965dbfa6b 100644
--- a/oneflow/user/image/random_crop_generator.cpp
+++ b/oneflow/user/image/random_crop_generator.cpp
@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/user/image/random_crop_generator.h"
-#include <iostream>
 
 namespace oneflow {
 
diff --git a/oneflow/user/kernels/image_preprocess_kernels.cpp b/oneflow/user/kernels/image_preprocess_kernels.cpp
index a0ba43ef58ac265efa0910aa8db6e4b73b102ca9..27fc4969758543c718e29990d03d4e504c130af4 100644
--- a/oneflow/user/kernels/image_preprocess_kernels.cpp
+++ b/oneflow/user/kernels/image_preprocess_kernels.cpp
@@ -20,6 +20,7 @@ limitations under the License.
 #include "oneflow/core/kernel/new_kernel_util.h"
 #include "oneflow/core/thread/thread_manager.h"
 #include "oneflow/user/image/image_util.h"
+#include "oneflow/user/kernels/random_crop_kernel_state.h"
 #include "oneflow/user/kernels/random_seed_util.h"
 
 namespace oneflow {
@@ -453,4 +454,70 @@ REGISTER_USER_KERNEL("coin_flip")
     .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)
                      & (user_op::HobDataType("out", 0) == DataType::kInt8));
 
+namespace {
+
+void ImageRandomCropImpl(const TensorBuffer* in_buffer, TensorBuffer* out_buffer,
+                         RandomCropGenerator* random_crop_gen) {
+  cv::Mat image = GenCvMat4ImageBuffer(*in_buffer);
+  int W = image.cols;
+  int H = image.rows;
+  cv::Mat image_roi;
+  CropWindow crop;
+  random_crop_gen->GenerateCropWindow({H, W}, &crop);
+  const int y = crop.anchor.At(0);
+  const int x = crop.anchor.At(1);
+  const int new_h = crop.shape.At(0);
+  const int new_w = crop.shape.At(1);
+  CHECK(new_w > 0 && new_w <= W);
+  CHECK(new_h > 0 && new_h <= H);
+  cv::Rect roi(x, y, new_w, new_h);
+  image(roi).copyTo(image_roi);
+  image = image_roi;
+  W = image.cols;
+  H = image.rows;
+
+  CHECK(image.isContinuous());
+  const int c = in_buffer->shape().At(2);
+  CHECK_EQ(c, image.channels());
+  Shape image_shape({H, W, c});
+  out_buffer->Resize(image_shape, in_buffer->data_type());
+  memcpy(out_buffer->mut_data<>(), image.ptr(), out_buffer->nbytes());
+}
+
+}  // namespace
+
+class ImageRandomCropKernel final : public user_op::OpKernel {
+ public:
+  ImageRandomCropKernel() = default;
+  ~ImageRandomCropKernel() override = default;
+
+  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
+      user_op::KernelInitContext* ctx) const override {
+    return CreateRandomCropKernelState(ctx);
+  }
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override {
+    auto* crop_window_generators = dynamic_cast<RandomCropKernelState*>(state);
+    CHECK_NOTNULL(crop_window_generators);
+    user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0);
+    int64_t record_num = out_blob->shape().elem_cnt();
+    CHECK(record_num > 0);
+    user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex("in", 0);
+    CHECK_EQ(out_blob->shape(), in_blob->shape());
+    const TensorBuffer* in_buffers = in_blob->dptr<TensorBuffer>();
+    TensorBuffer* out_buffers = out_blob->mut_dptr<TensorBuffer>();
+    MultiThreadLoop(record_num, [&](size_t i) {
+      ImageRandomCropImpl(in_buffers + i, out_buffers + i, crop_window_generators->GetGenerator(i));
+    });
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+REGISTER_USER_KERNEL("image_random_crop")
+    .SetCreateFn<ImageRandomCropKernel>()
+    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)
+                     & (user_op::HobDataType("in", 0) == DataType::kTensorBuffer)
+                     & (user_op::HobDataType("out", 0) == DataType::kTensorBuffer));
+
 }  // namespace oneflow
diff --git a/oneflow/user/kernels/ofrecord_decoder_kernels.cpp b/oneflow/user/kernels/ofrecord_decoder_kernels.cpp
index 56256093d3713c770600a1e817d2b909abd348a5..b52235b49f2b11edb54f0472aac542716f11361b 100644
--- a/oneflow/user/kernels/ofrecord_decoder_kernels.cpp
+++ b/oneflow/user/kernels/ofrecord_decoder_kernels.cpp
@@ -22,6 +22,7 @@ limitations under the License.
 #include "oneflow/core/thread/thread_manager.h"
 #include "oneflow/user/image/random_crop_generator.h"
 #include "oneflow/user/image/image_util.h"
+#include "oneflow/user/kernels/random_crop_kernel_state.h"
 #include "oneflow/user/kernels/op_kernel_state_wrapper.h"
 #include "oneflow/user/kernels/random_seed_util.h"
 
@@ -172,24 +173,6 @@ void DecodeRandomCropImageFromOneRecord(const OFRecord& record, TensorBuffer* bu
   memcpy(buffer->mut_data<uint8_t>(), image.ptr(), image_shape.elem_cnt());
 }
 
-class RandCropGens final : public user_op::OpKernelState {
- public:
-  explicit RandCropGens(int32_t size) : gens_(size) {}
-  ~RandCropGens() = default;
-
-  RandomCropGenerator* Get(int32_t idx) { return gens_.at(idx).get(); }
-
-  void New(int32_t idx, AspectRatioRange aspect_ratio_range, AreaRange area_range, int64_t seed,
-           int32_t num_attempts) {
-    CHECK_LT(idx, gens_.size());
-    gens_.at(idx).reset(
-        new RandomCropGenerator(aspect_ratio_range, area_range, seed, num_attempts));
-  }
-
- private:
-  std::vector<std::shared_ptr<RandomCropGenerator>> gens_;
-};
-
 }  // namespace
 
 class OFRecordImageDecoderRandomCropKernel final : public user_op::OpKernel {
@@ -199,36 +182,13 @@ class OFRecordImageDecoderRandomCropKernel final : public user_op::OpKernel {
 
   std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
       user_op::KernelInitContext* ctx) const override {
-    int32_t num_attempts = ctx->Attr<int32_t>("num_attempts");
-    CHECK(num_attempts >= 1);
-    const std::vector<float>& random_aspect_ratio =
-        ctx->Attr<std::vector<float>>("random_aspect_ratio");
-    CHECK(random_aspect_ratio.size() == 2 && 0 < random_aspect_ratio.at(0)
-          && random_aspect_ratio.at(0) <= random_aspect_ratio.at(1));
-    const std::vector<float>& random_area = ctx->Attr<std::vector<float>>("random_area");
-    CHECK(random_area.size() == 2 && 0 < random_area.at(0)
-          && random_area.at(0) <= random_area.at(1));
-    const user_op::TensorDesc* out_tensor_desc = ctx->TensorDesc4ArgNameAndIndex("out", 0);
-    CHECK(out_tensor_desc->shape().NumAxes() == 1);
-    int64_t batch_size = out_tensor_desc->shape().At(0);
-    CHECK(batch_size > 0);
-    int64_t seed = GetOpKernelRandomSeed(ctx);
-    std::seed_seq seq{seed};
-    std::vector<int> seeds(batch_size);
-    seq.generate(seeds.begin(), seeds.end());
-
-    std::shared_ptr<RandCropGens> crop_window_generators(new RandCropGens(batch_size));
-    for (int32_t i = 0; i < batch_size; ++i) {
-      crop_window_generators->New(i, {random_aspect_ratio.at(0), random_aspect_ratio.at(1)},
-                                  {random_area.at(0), random_area.at(1)}, seeds.at(i),
-                                  num_attempts);
-    }
-    return crop_window_generators;
+    return CreateRandomCropKernelState(ctx);
   }
 
  private:
   void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override {
-    auto* crop_window_generators = dynamic_cast<RandCropGens*>(state);
+    auto* crop_window_generators = dynamic_cast<RandomCropKernelState*>(state);
+    CHECK_NOTNULL(crop_window_generators);
     user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0);
     int64_t record_num = out_blob->shape().At(0);
     CHECK(record_num > 0);
@@ -242,7 +202,7 @@ class OFRecordImageDecoderRandomCropKernel final : public user_op::OpKernel {
     MultiThreadLoop(record_num, [&](size_t i) {
       const OFRecord& record = *(records + i);
       TensorBuffer* buffer = buffers + i;
-      RandomCropGenerator* gen = crop_window_generators->Get(i);
+      RandomCropGenerator* gen = crop_window_generators->GetGenerator(i);
       DecodeRandomCropImageFromOneRecord(record, buffer, name, color_space, gen);
     });
   }
diff --git a/oneflow/user/kernels/random_crop_kernel_state.cpp b/oneflow/user/kernels/random_crop_kernel_state.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..9668f11c5a1886ca377d0cb71e2a0d590c57188b
--- /dev/null
+++ b/oneflow/user/kernels/random_crop_kernel_state.cpp
@@ -0,0 +1,38 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/user/kernels/random_seed_util.h"
+#include "oneflow/user/kernels/random_crop_kernel_state.h"
+
+namespace oneflow {
+
+std::shared_ptr<RandomCropKernelState> CreateRandomCropKernelState(
+    user_op::KernelInitContext* ctx) {
+  int32_t num_attempts = ctx->Attr<int32_t>("num_attempts");
+  CHECK(num_attempts >= 1);
+  const std::vector<float>& random_aspect_ratio =
+      ctx->Attr<std::vector<float>>("random_aspect_ratio");
+  CHECK(random_aspect_ratio.size() == 2 && 0 < random_aspect_ratio.at(0)
+        && random_aspect_ratio.at(0) <= random_aspect_ratio.at(1));
+  const std::vector<float>& random_area = ctx->Attr<std::vector<float>>("random_area");
+  CHECK(random_area.size() == 2 && 0 < random_area.at(0) && random_area.at(0) <= random_area.at(1));
+  const user_op::TensorDesc* out_tensor_desc = ctx->TensorDesc4ArgNameAndIndex("out", 0);
+  return std::shared_ptr<RandomCropKernelState>(
+      new RandomCropKernelState(out_tensor_desc->shape().elem_cnt(), GetOpKernelRandomSeed(ctx),
+                                {random_aspect_ratio.at(0), random_aspect_ratio.at(1)},
+                                {random_area.at(0), random_area.at(1)}, num_attempts));
+}
+
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/random_crop_kernel_state.h b/oneflow/user/kernels/random_crop_kernel_state.h
new file mode 100644
index 0000000000000000000000000000000000000000..bfdecaae7bb1d9262e92126bc0890f7ff3dfdf17
--- /dev/null
+++ b/oneflow/user/kernels/random_crop_kernel_state.h
@@ -0,0 +1,49 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef ONEFLOW_USER_KERNELS_RANDOM_CROP_KERNEL_STATE_H_
+#define ONEFLOW_USER_KERNELS_RANDOM_CROP_KERNEL_STATE_H_
+
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/user/image/random_crop_generator.h"
+
+namespace oneflow {
+
+class RandomCropKernelState final : public user_op::OpKernelState {
+ public:
+  explicit RandomCropKernelState(int32_t size, int64_t seed, AspectRatioRange aspect_ratio_range,
+                                 AreaRange area_range, int32_t num_attempts)
+      : gens_(size) {
+    std::seed_seq seq{seed};
+    std::vector<int> seeds(size);
+    seq.generate(seeds.begin(), seeds.end());
+    for (int32_t i = 0; i < size; ++i) {
+      gens_.at(i).reset(
+          new RandomCropGenerator(aspect_ratio_range, area_range, seeds.at(i), num_attempts));
+    }
+  }
+  ~RandomCropKernelState() = default;
+
+  RandomCropGenerator* GetGenerator(int32_t idx) { return gens_.at(idx).get(); }
+
+ private:
+  std::vector<std::shared_ptr<RandomCropGenerator>> gens_;
+};
+
+std::shared_ptr<RandomCropKernelState> CreateRandomCropKernelState(user_op::KernelInitContext* ctx);
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_USER_KERNELS_RANDOM_CROP_KERNEL_STATE_H_
diff --git a/oneflow/user/ops/image_preprocess_ops.cpp b/oneflow/user/ops/image_preprocess_ops.cpp
index 711f57dd913f4dd52d30a2a658abdc4fe5809379..626a44f0df22c4a89fa9cbfb39164f944b8072de 100644
--- a/oneflow/user/ops/image_preprocess_ops.cpp
+++ b/oneflow/user/ops/image_preprocess_ops.cpp
@@ -212,4 +212,28 @@ REGISTER_CPU_ONLY_USER_OP("coin_flip")
       return Maybe<void>::Ok();
     });
 
+REGISTER_CPU_ONLY_USER_OP("image_random_crop")
+    .Input("in")
+    .Output("out")
+    .Attr<int32_t>("num_attempts", UserOpAttrType::kAtInt32, 10)
+    .Attr<int64_t>("seed", UserOpAttrType::kAtInt64, -1)
+    .Attr<bool>("has_seed", UserOpAttrType::kAtBool, false)
+    .Attr<std::vector<float>>("random_area", UserOpAttrType::kAtListFloat, {0.08, 1.0})
+    .Attr<std::vector<float>>("random_aspect_ratio", UserOpAttrType::kAtListFloat, {0.75, 1.333333})
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      user_op::TensorDesc* in_tensor = ctx->TensorDesc4ArgNameAndIndex("in", 0);
+      user_op::TensorDesc* out_tensor = ctx->TensorDesc4ArgNameAndIndex("out", 0);
+      CHECK_OR_RETURN(in_tensor->data_type() == DataType::kTensorBuffer);
+      *out_tensor = *in_tensor;
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis)
+    .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
+                            const user_op::UserOpConfWrapper&) {
+      user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0);
+      CHECK_NOTNULL(in_modifier);
+      in_modifier->set_requires_grad(false);
+    })
+    .SetBatchAxisInferFn(user_op::BatchAxisInferFnUtil::NaiveInferBatchAxis);
+
 }  // namespace oneflow