From dae895912b040e5bf241505afb3476d95b854e98 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=90=B4=E7=94=9F?= <1093287517@qq.com>
Date: Thu, 18 Aug 2022 20:17:56 +0800
Subject: [PATCH] add cpu finetune and infer

---
 official/cv/vgg16/README.md         | 103 ++++++++++++
 official/cv/vgg16/README_CN.md      | 105 ++++++++++++-
 official/cv/vgg16/cpu_config.yaml   |  58 +++++++
 official/cv/vgg16/eval.py           |  29 +++-
 official/cv/vgg16/fine_tune.py      | 232 ++++++++++++++++++++++++++++
 official/cv/vgg16/quick_start.py    | 104 +++++++++++++
 official/cv/vgg16/requirements.txt  |   9 +-
 official/cv/vgg16/src/data_split.py | 158 +++++++++++++++++++
 official/cv/vgg16/src/dataset.py    |  76 +++++++++
 9 files changed, 864 insertions(+), 10 deletions(-)
 create mode 100644 official/cv/vgg16/cpu_config.yaml
 create mode 100644 official/cv/vgg16/fine_tune.py
 create mode 100644 official/cv/vgg16/quick_start.py
 create mode 100644 official/cv/vgg16/src/data_split.py

diff --git a/official/cv/vgg16/README.md b/official/cv/vgg16/README.md
index 06383bb60..303d90c77 100644
--- a/official/cv/vgg16/README.md
+++ b/official/cv/vgg16/README.md
@@ -6,6 +6,7 @@
     - [Dataset](#dataset)
         - [Dataset used: CIFAR-10](#dataset-used-cifar-10)
         - [Dataset used: ImageNet2012](#dataset-used-imagenet2012)
+        - [Dataset used: Custom Dataset](#dataset-used-custom-dataset)
             - [Dataset organize way](#dataset-organize-way)
     - [Features](#features)
         - [Mixed Precision](#mixed-precision)
@@ -24,6 +25,10 @@
         - [Evaluation Process](#evaluation-process)
             - [Evaluation](#evaluation-1)
             - [ONNX Evaluation](#onnx-evaluation)
+        - [Migration process](#migration-process)
+            - [Dataset split](#dataset-split)
+            - [Migration](#migration)
+            - [Eval](#eval)
+            - [Model quick start](#model-quick-start)
     - [Inference Process](#inference-process)
         - [Export MindIR](#export-mindir)
         - [Infer on Ascend310](#infer-on-ascend310)
@@ -66,6 +71,11 @@ Note that you can run the scripts based on the dataset mentioned in original pap
     - Data format: RGB images
     - Note: Data will be processed in src/dataset.py
 
+### Dataset used: Custom Dataset
+
+- Data format: RGB images
+    - Note: Data will be processed in src/data_split.py, which splits it into training and validation sets.
+
 #### Dataset organize way
 
   CIFAR-10
@@ -89,6 +99,21 @@ Note that you can run the scripts based on the dataset mentioned in original pap
  >   └─validation_preprocess # evaluate dataset
   > ```
 
+  Custom Dataset
+
+  > Unzip the custom dataset to any path. The folder structure should contain class-name folders, each holding all the images of that class, as shown below:
+  >
+  > ```bash
+  > .
+  > └─dataset
+  >   ├─class_name1                # class name
+  >   │  ├─xx.jpg                  # all images belonging to this class
+  >   │  ├─ ...
+  >   │  └─xx.jpg
+  >   ├─class_name2
+  >   └─ ...
+  > ```
+
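+  The layout above is what `mindspore.dataset.ImageFolderDataset` (used by `src/data_split.py` and `src/dataset.py`) expects: each immediate sub-folder is one class, and labels are integer indices derived from the folder names. A minimal sketch of loading such a tree (the `./datasets/train/` path is illustrative):
+
+  ```python
+  import mindspore.dataset as ds
+
+  # Each sub-folder of the root becomes one class; labels are integer
+  # indices assigned from the sorted folder names.
+  data_set = ds.ImageFolderDataset("./datasets/train/", shuffle=True)
+  for item in data_set.create_dict_iterator(output_numpy=True, num_epochs=1):
+      print(item["label"])  # integer class index of the first sample
+      break
+  ```
+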
 ## [Features](#contents)
 
 ### Mixed Precision
@@ -141,6 +166,23 @@ bash scripts/run_distribute_train_gpu.sh [DATA_PATH] --dataset=[DATASET_TYPE]
 python eval.py --config_path=[YAML_CONFIG_PATH] --device_target="GPU" --dataset=[DATASET_TYPE] --data_dir=[DATA_PATH]  --pre_trained=[PRE_TRAINED] > output.eval.log 2>&1 &
 ```
 
+- Running on CPU
+
+```python
+
+# run dataset processing example
+python src/data_split.py --split_path [SPLIT_PATH]
+
+# run finetune example
+python fine_tune.py --config_path [YAML_CONFIG_PATH]
+
+# run eval example
+python eval.py --config_path [YAML_CONFIG_PATH]
+
+# quick start
+python quick_start.py --config_path [YAML_CONFIG_PATH]
+```
+
 - Running on [ModelArts](https://support.huaweicloud.com/modelarts/)
 
 ```bash
@@ -300,6 +342,7 @@ python eval.py --config_path=[YAML_CONFIG_PATH] --device_target="GPU" --dataset=
         │   │   ├── var_init.py                   // network parameter init method
         │   ├── crossentropy.py                   // loss calculation
         │   ├── dataset.py                        // creating dataset
+        │   ├── data_split.py                     // CPU dataset split script
         │   ├── linear_warmup.py                  // linear learning rate
         │   ├── warmup_cosine_annealing_lr.py     // cosine annealing learning rate
         │   ├── warmup_step_lr.py                 // step or multi step learning rate
@@ -307,11 +350,14 @@ python eval.py --config_path=[YAML_CONFIG_PATH] --device_target="GPU" --dataset=
         ├── train.py                              // training script
         ├── eval.py                               // evaluation script
         ├── eval_onnx.py                          // ONNX evaluation script
+        ├── fine_tune.py                          // CPU transfer learning script
+        ├── quick_start.py                        // CPU quick start script
         ├── postprocess.py                        // postprocess script
         ├── preprocess.py                         // preprocess script
         ├── mindspore_hub_conf.py                 // mindspore_hub_conf script
         ├── cifar10_config.yaml                   // Configurations for cifar10
         ├── imagenet2012_config.yaml              // Configurations for imagenet2012
+        ├── cpu_config.yaml                       // Configurations for CPU transfer learning
         ├── export.py                             // model convert script
         └── requirements.txt                      // requirements
 ```
@@ -414,6 +460,29 @@ initialize_mode: "KaimingNormal"    # conv2d init mode
 has_dropout: True                   # whether using Dropout layer
 ```
 
+- config for vgg16, custom dataset
+
+```bash
+num_classes: 5                    # number of dataset categories
+lr: 0.001                         # learning rate
+batch_size: 64                    # batch size of input tensor
+num_epochs: 10                    # number of training epochs
+momentum: 0.9                     # momentum
+pad_mode: 'pad'                   # pad mode for conv2d
+padding: 1                        # padding value for conv2d
+has_bias: False                   # whether has bias in conv2d
+batch_norm: False                 # whether has batch_norm in conv2d
+initialize_mode: "KaimingNormal"  # conv2d init mode
+has_dropout: True                 # whether using Dropout layer
+ckpt_file: "./vgg16_bn_ascend_v170_imagenet2012_official_cv_top1acc74.33_top5acc92.1.ckpt" # The path to the pretrained weights file used by the migration
+save_file: "./vgg16.ckpt"         # Weight file path saved after migration
+train_path: "./datasets/train/"   # Migration train set path
+eval_path: "./datasets/test/"     # Migration valid set path
+split_path: "./datasets/"         # Migration dataset path
+pre_trained: "./vgg16.ckpt"       # Weight file path used by CPU inference
+```
+
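+All of these keys are read through `model_utils/config.py`'s `get_config()` helper, which merges the YAML file passed via `--config_path` with any per-key command-line overrides. A minimal sketch of how the CPU scripts consume the configuration (values shown are just the defaults above):
+
+```python
+from model_utils.config import get_config
+
+# Parses --config_path (e.g. cpu_config.yaml) and command-line overrides,
+# returning an attribute-style configuration object.
+config = get_config()
+print(config.num_classes, config.lr, config.batch_size)  # 5 0.001 64
+print(config.ckpt_file)    # pretrained weights consumed by fine_tune.py
+print(config.pre_trained)  # fine-tuned weights consumed by eval.py and quick_start.py
+```
+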
 ### [Training Process](#contents)
 
 #### Training
@@ -538,6 +607,40 @@ top-1 accuracy: 0.7332
 top-5 accuracy: 0.9149
 ```
 
+## Migration process
+
+### Dataset split
+
+- The dataset split process is as follows. It generates /train and /test folders under the dataset directory and saves the training and validation images into them.
+
+```bash
+python src/data_split.py --split_path /dir_to_code/{SPLIT_PATH}
+```
+
+### Migration
+
+- The migration process is as follows. First download the pre-trained weight file [vgg16_bn_ascend_v170_imagenet2012_official_cv_top1acc74.33_top5acc92.1.ckpt](https://download.mindspore.cn/models/r1.7/vgg16_bn_ascend_v170_imagenet2012_official_cv_top1acc74.33_top5acc92.1.ckpt) to the ./vgg16 folder. After training completes, the fine-tuned weights are saved as ./vgg16.ckpt by default.
+
+```bash
+python fine_tune.py --config_path /dir_to_code/cpu_config.yaml
+```
+
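+At its core, `fine_tune.py` replaces the pretrained 1000-class ImageNet head with a new `num_classes`-way dense layer and freezes all other parameters; a condensed sketch of that logic (it uses a bare `nn.Dense` where the script wraps it in a small `DenseHead` cell):
+
+```python
+import mindspore.nn as nn
+
+def replace_and_freeze_head(net, num_classes):
+    """Swap the last classifier layer and leave only the new head trainable."""
+    in_channels = net.classifier[6].in_channels
+    net.classifier[6] = nn.Dense(in_channels, num_classes)
+    for param in net.get_parameters():
+        if not param.name.startswith("classifier.6"):
+            param.requires_grad = False
+    return net
+```
+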
+### Eval
+
+- The evaluation process is as follows. You need to specify the weight file produced by the migration (default: ./vgg16.ckpt).
+
+```bash
+python eval.py --config_path /dir_to_code/cpu_config.yaml
+```
+
+### Model quick start
+
+- The quick start process is as follows. You need to specify the path of the trained weight file and the dataset path.
+
+```bash
+python quick_start.py --config_path /dir_to_code/cpu_config.yaml
+```
+
 ## Inference Process
 
 ### [Export MindIR](#contents)
diff --git a/official/cv/vgg16/README_CN.md b/official/cv/vgg16/README_CN.md
index deaaa6602..355a3ea16 100644
--- a/official/cv/vgg16/README_CN.md
+++ b/official/cv/vgg16/README_CN.md
@@ -8,6 +8,7 @@
     - [Dataset](#dataset)
         - [Dataset used: CIFAR-10](#dataset-used-cifar-10)
         - [Dataset used: ImageNet2012](#dataset-used-imagenet2012)
+        - [Dataset used: Custom Dataset](#dataset-used-custom-dataset)
         - [Dataset organize way](#dataset-organize-way)
     - [Features](#features)
         - [Mixed Precision](#mixed-precision)
@@ -25,6 +26,10 @@
                 - [Running VGG16 on GPU](#running-vgg16-on-gpu)
         - [Evaluation Process](#evaluation-process)
             - [Evaluation](#evaluation-1)
+        - [Migration Process](#migration-process)
+            - [Dataset Split](#dataset-split)
+            - [Migration](#migration)
+            - [Eval](#eval)
+            - [Quick Start](#quick-start)
     - [Inference Process](#inference-process)
         - [Export MindIR](#export-mindir)
         - [Infer on Ascend310](#infer-on-ascend310)
@@ -67,6 +72,11 @@ The VGG16 network is mainly composed of several basic modules (including convolution and pooling layers) and three
 - Data format: RGB images.
     - Note: Data is processed in src/dataset.py.
 
+### Dataset used: Custom Dataset
+
+- Data format: RGB images.
+    - Note: Data is processed in src/data_split.py, which splits it into training and validation sets.
+
 ### Dataset organize way
 
   CIFAR-10
@@ -90,6 +100,21 @@ The VGG16 network is mainly composed of several basic modules (including convolution and pooling layers) and three
  >   └─validation_preprocess # evaluation dataset
   > ```
 
+  鑷畾涔夋暟鎹泦
+
+  > 灏嗚嚜瀹氫箟鏁版嵁闆嗚В鍘嬪埌浠绘剰璺緞锛屾枃浠跺す缁撴瀯搴斿寘鍚被鍚嶇殑鏂囦欢澶逛互鍙婂湪姝ゆ枃浠跺す涓嬬殑鎵€鏈夊浘鐗囷紝濡備笅鎵€绀猴細
+  >
+  > ```bash
+  > .
+  > 鈹斺攢dataset
+  > 鈹溾攢class_name1                # 绫诲悕
+  >  鈹溾攢xx.jpg                    # 瀵瑰簲绫诲悕鐨勬墍鏈夊浘鐗�
+  >  鈹溾攢 ...
+  >  鈹溾攢xx.jpg
+  > 鈹溾攢class_name2
+  > 鈹溾攢  ...
+  > ```
+
 ## Features
 
 ### Mixed Precision
@@ -142,6 +167,23 @@ bash scripts/run_distribute_train_gpu.sh [DATA_PATH] --dataset=[DATASET_TYPE]
 python eval.py --config_path=[YAML_CONFIG_PATH] --device_target="GPU" --dataset=[DATASET_TYPE] --data_dir=[DATA_PATH]  --pre_trained=[PRE_TRAINED] > output.eval.log 2>&1 &
 ```
 
+- Running on CPU
+
+```python
+
+# run dataset processing example
+python src/data_split.py --split_path [SPLIT_PATH]
+
+# run finetune example
+python fine_tune.py --config_path [YAML_CONFIG_PATH]
+
+# run eval example
+python eval.py --config_path [YAML_CONFIG_PATH]
+
+# quick start example
+python quick_start.py --config_path [YAML_CONFIG_PATH]
+```
+
 - Training on ModelArts (if you want to run on ModelArts, refer to the following document [modelarts](https://support.huaweicloud.com/modelarts/))
 
 ```bash
@@ -303,14 +345,18 @@ python eval.py --config_path=[YAML_CONFIG_PATH] --device_target="GPU" --dataset=
         │   ├── linear_warmup.py                  // linear learning rate
         │   ├── warmup_cosine_annealing_lr.py     // cosine annealing learning rate
         │   ├── warmup_step_lr.py                 // step or multi step learning rate
-        │   ├──vgg.py                             // VGG architecture
+        │   ├── vgg.py                            // VGG architecture
+        │   ├── data_split.py                     // CPU migration dataset split script
         ├── train.py                              // training script
         ├── eval.py                               // evaluation script
+        ├── fine_tune.py                          // CPU migration script
+        ├── quick_start.py                        // CPU quick start script
         ├── postprocess.py                        // postprocess script
         ├── preprocess.py                         // preprocess script
         ├── mindspore_hub_conf.py                 // mindspore hub script
         ├── cifar10_config.yaml                   // cifar10 configuration file
         ├── imagenet2012_config.yaml              // imagenet2012 configuration file
+        ├── cpu_config.yaml                       // CPU migration configuration file
         ├── export.py                             // model convert script
         └── requirements.txt                      // requirements
 ```
@@ -413,6 +459,29 @@ initialize_mode: "KaimingNormal"    # conv2d init mode
 has_dropout: True                   # whether to use Dropout layer
 ```
 
+- Config for VGG16, custom dataset
+
+```bash
+num_classes: 5                    # number of dataset classes
+lr: 0.001                         # learning rate
+batch_size: 64                    # batch size of input tensor
+num_epochs: 10                    # number of training epochs
+momentum: 0.9                     # momentum
+pad_mode: 'pad'                   # pad mode for conv2d
+padding: 1                        # padding value for conv2d
+has_bias: False                   # whether conv2d has bias
+batch_norm: False                 # whether conv2d has batch_norm
+initialize_mode: "KaimingNormal"  # conv2d init mode
+has_dropout: True                 # whether to use Dropout layer
+ckpt_file: "./vgg16_bn_ascend_v170_imagenet2012_official_cv_top1acc74.33_top5acc92.1.ckpt" # path of the pretrained weights used for migration
+save_file: "./vgg16.ckpt"         # path to save the weights after migration
+train_path: "./datasets/train/"   # migration training set path
+eval_path: "./datasets/test/"     # migration validation set path
+split_path: "./datasets/"         # migration dataset path
+pre_trained: "./vgg16.ckpt"       # path of the weights used for CPU inference
+```
+
 ### 璁粌杩囩▼
 
 #### 璁粌
@@ -504,6 +573,40 @@ after allreduce eval: top1_correct=36636, tot=50000, acc=73.27%
 after allreduce eval: top5_correct=45582, tot=50000, acc=91.16%
 ```
 
+## Migration Process
+
+### Dataset Split
+
+- The dataset split process is as follows. It generates /train and /test folders under the dataset directory and saves the training and validation images into them.
+
+```bash
+python src/data_split.py --split_path /dir_to_code/{SPLIT_PATH}
+```
+
+### Migration
+
+- The migration process is as follows. First download the pre-trained weight file [vgg16_bn_ascend_v170_imagenet2012_official_cv_top1acc74.33_top5acc92.1.ckpt](https://download.mindspore.cn/models/r1.7/vgg16_bn_ascend_v170_imagenet2012_official_cv_top1acc74.33_top5acc92.1.ckpt) to the vgg16 folder. After training completes, the fine-tuned weights are saved as ./vgg16.ckpt by default.
+
+```bash
+python fine_tune.py --config_path /dir_to_code/cpu_config.yaml
+```
+
+### Eval
+
+- The evaluation process is as follows. You need to specify the weight file produced by the migration (default: ./vgg16.ckpt).
+
+```bash
+python eval.py --config_path /dir_to_code/cpu_config.yaml
+```
+
+### Quick Start
+
+- The quick start process is as follows. You need to specify the path of the trained weight file and the dataset path.
+
+```bash
+python quick_start.py --config_path /dir_to_code/cpu_config.yaml
+```
+
 ## Inference Process
 
 ### [Export MindIR](#contents)
diff --git a/official/cv/vgg16/cpu_config.yaml b/official/cv/vgg16/cpu_config.yaml
new file mode 100644
index 000000000..c8d076c30
--- /dev/null
+++ b/official/cv/vgg16/cpu_config.yaml
@@ -0,0 +1,58 @@
+# ==============================================================================
+# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# device options
+device_target: "CPU"
+
+# dataset options
+train_path: "./datasets/train/"
+eval_path: "./datasets/test/"
+split_path: "./datasets/"
+
+# finetune options
+dataset: 'custom'
+image_size: '224,224'
+log_path: "outputs/"
+num_classes: 5
+lr: 0.001
+batch_size: 64
+num_epochs: 10
+momentum: 0.9
+ckpt_file: "./vgg16_bn_ascend_v170_imagenet2012_official_cv_top1acc74.33_top5acc92.1.ckpt"
+save_file: "./vgg16.ckpt"
+initialize_mode: "KaimingNormal"
+pad_mode: 'pad'
+padding: 1
+has_bias: False
+batch_norm: False
+has_dropout: True
+
+# infer options
+pre_trained: "./vgg16.ckpt"
+
+---
+
+# Help description for each configuration
+
+# device options
+device_target: "device where the code will be implemented."
+
+# dataset options
+train_path: "the training dataset path"
+eval_path: "the eval dataset path"
+split_path: "the original dataset path to split"
+
+# finetune options
+num_classes: "number of classes in the dataset"
+lr: "learning rate"
+batch_size: "batch size"
+num_epochs: "number of training epochs"
+momentum: "momentum for the optimizer"
+ckpt_file: "the .ckpt file used for finetune"
+save_file: "the .ckpt file path to save the fine-tuned weights"
+
+# infer options
+pre_trained: "the .ckpt file path to infer"
diff --git a/official/cv/vgg16/eval.py b/official/cv/vgg16/eval.py
index d62a04e6f..f4cb72126 100644
--- a/official/cv/vgg16/eval.py
+++ b/official/cv/vgg16/eval.py
@@ -29,16 +29,22 @@ from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
 
 from src.utils.logging import get_logger
-from src.vgg import vgg16
+from src.vgg import vgg16, Vgg
 from src.dataset import vgg_create_dataset
 from src.dataset import classification_dataset
+from src.dataset import create_dataset
 
 from model_utils.moxing_adapter import config
 from model_utils.moxing_adapter import moxing_wrapper
 from model_utils.device_adapter import get_device_id, get_rank_id, get_device_num
 
+from model_utils.config import get_config
+from fine_tune import DenseHead, cfg
+
+
 class ParameterReduce(nn.Cell):
     """ParameterReduce"""
+
     def __init__(self):
         super(ParameterReduce, self).__init__()
         self.cast = P.Cast()
@@ -61,6 +67,7 @@ def get_top5_acc(top5_arg, gt_class):
 
 def modelarts_pre_process():
     '''modelarts pre process function.'''
+
     def unzip(zip_file, save_dir):
         import zipfile
         s_time = time.time()
@@ -132,7 +139,6 @@ def run_eval():
     config.rank = get_rank_id()
     config.group_size = get_device_num()
 
-
     _enable_graph_kernel = config.device_target == "GPU"
     context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=_enable_graph_kernel,
                         device_target=config.device_target, save_graphs=False)
@@ -168,11 +174,24 @@ def run_eval():
             config.models = sorted(models, key=f)
         else:
             config.models = [config.pre_trained,]
-
         for model in config.models:
-            dataset = classification_dataset(config.data_dir, config.image_size, config.per_batch_size, mode='eval')
+            if config.dataset == "custom":
+                dataset = create_dataset(dataset_path=config.eval_path, do_train=False,
+                                         batch_size=config.batch_size,
+                                         eval_image_size=config.image_size,
+                                         enable_cache=False)
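+                # rebuild the ImageNet-shaped backbone and swap in the custom
+                # head so parameter names match the fine-tuned checkpoint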
+                model_config = get_config()
+                network = Vgg(cfg['16'], num_classes=1000, args=model_config, batch_norm=True)
+
+                # replace head
+                src_head = network.classifier[6]
+                in_channels = src_head.in_channels
+                head = DenseHead(in_channels, config.num_classes)
+                network.classifier[6] = head
+            else:
+                dataset = classification_dataset(config.data_dir, config.image_size, config.per_batch_size, mode='eval')
+                network = vgg16(config.num_classes, config, phase="test")
             eval_dataloader = dataset.create_tuple_iterator(output_numpy=True, num_epochs=1)
-            network = vgg16(config.num_classes, config, phase="test")
 
             # pre_trained
             load_param_into_net(network, load_checkpoint(model))
diff --git a/official/cv/vgg16/fine_tune.py b/official/cv/vgg16/fine_tune.py
new file mode 100644
index 000000000..eb870088e
--- /dev/null
+++ b/official/cv/vgg16/fine_tune.py
@@ -0,0 +1,232 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore.train import Model
+from mindspore.train.callback import LossMonitor, TimeMonitor
+from model_utils.config import get_config
+from src.vgg import Vgg
+from src.dataset import create_dataset
+
+ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU", save_graphs=False)
+ms.set_seed(21)
+
+
+def import_data(train_dataset_path="./datasets/train/", eval_dataset_path="./datasets/test/", batch_size=32):
+    """
+        Read the dataset
+
+        Args:
+            train_dataset_path(string): the path of train dataset.
+            eval_dataset_path(string): the path of eval dataset.
+            batch_size(int): the batch size of dataset. Default: 32
+
+        Returns:
+            dataset_train: the train dataset
+            dataset_val:   the  val  dataset
+    """
+
+    dataset_train = create_dataset(dataset_path=train_dataset_path, do_train=True,
+                                   batch_size=batch_size, train_image_size=224,
+                                   eval_image_size=224,
+                                   enable_cache=False, cache_session_id=None)
+    dataset_val = create_dataset(dataset_path=eval_dataset_path, do_train=False,
+                                 batch_size=batch_size, train_image_size=224,
+                                 eval_image_size=224,
+                                 enable_cache=False, cache_session_id=None)
+    # print sample data/label
+    data = next(dataset_train.create_dict_iterator())
+    images = data["image"]
+    labels = data["label"]
+    print("Tensor of image", images.shape)  # Tensor of image (18, 3, 224, 224)
+    print("Labels:", labels)  # Labels: [1 0 0 0 1 1 1 1 0 0 1 1 1 0 1 0 0 0]
+
+    return dataset_train, dataset_val
+
+
+# define head layer
+class DenseHead(nn.Cell):
+    def __init__(self, input_channel, num_classes):
+        super(DenseHead, self).__init__()
+        self.dense = nn.Dense(input_channel, num_classes)
+
+    def construct(self, x):
+        return self.dense(x)
+
+
+def init_weight(net, param_dict):
+    """init_weight"""
+
+    has_trained_epoch = 0
+    has_trained_step = 0
+    if param_dict:
+        if param_dict.get("epoch_num") and param_dict.get("step_num"):
+            has_trained_epoch = int(param_dict["epoch_num"].data.asnumpy())
+            has_trained_step = int(param_dict["step_num"].data.asnumpy())
+
+        ms.load_param_into_net(net, param_dict)
+    print("has_trained_epoch:", has_trained_epoch)
+    print("has_trained_step:", has_trained_step)
+    return has_trained_epoch, has_trained_step
+
+
+def eval_net(model_config, checkpoint_path='./vgg16.ckpt',
+             train_dataset_path="./datasets/train/",
+             eval_dataset_path="./datasets/test/",
+             batch_size=32):
+    """
+      eval the accuracy of vgg16 for flower dataset
+
+      Args:
+
+          model_config(Config in './model_utils/config.py'): vgg16 config
+          checkpoint_path(string): model checkout path(end with '.ckpt'). Default: './vgg16.ckpt'
+          train_dataset_path: the train dataset path. Default: "./datasets/train/"
+          eval_dataset_path:  the eval dataset path.  Default: "./datasets/test/"
+          batch_size: the batch size of dataset. Default: 32
+      Returns:
+          None
+      """
+
+    # define val dataset and model
+    _, data_val = import_data(train_dataset_path=train_dataset_path,
+                              eval_dataset_path=eval_dataset_path, batch_size=batch_size)
+    net = Vgg(cfg['16'], num_classes=1000, args=model_config, batch_norm=True)
+
+    # replace head
+    src_head = net.classifier[6]
+    in_channels = src_head.in_channels
+    head = DenseHead(in_channels, model_config.num_classes)
+    net.classifier[6] = head
+
+    # load checkpoint
+    param_dict = ms.load_checkpoint(checkpoint_path)
+    ms.load_param_into_net(net, param_dict)
+    net.set_train(False)
+
+    # define loss
+    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+
+    # define model
+    model = ms.Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
+
+    # eval step
+    res = model.eval(data_val)
+
+    # show accuracy
+    print("result:", res)
+
+
+cfg = {
+    '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+    '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+    '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
+    '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
+}
+
+
+def finetune_train(model_config,
+                   finetune_checkpoint_path=
+                   './vgg16_bn_ascend_v170_imagenet2012_official_cv_top1acc74.33_top5acc92.1.ckpt',
+                   save_checkpoint_path="./vgg16.ckpt",
+                   train_dataset_path="./datasets/train/",
+                   eval_dataset_path="./datasets/test/",
+                   class_num=5,
+                   num_epochs=10,
+                   learning_rate=0.001,
+                   momentum=0.9,
+                   batch_size=32
+                   ):
+    """
+         finetune the flower dataset for vgg16
+
+         Args:
+             model_config(Config in './model_utils/config.py'): vgg16 config
+             class_num(int): the num of class for dataset. Default: 5
+             num_epochs(int): the training epoch. Default: 10
+             save_checkpoint_path(string): model checkout path for save(end with '.ckpt'). Default: ./vgg16.ckpt
+             train_dataset_path(string): the train dataset path. Default: "./datasets/train/"
+             eval_dataset_path(string):  the eval dataset path.  Default: "./datasets/test/"
+             finetune_checkpoint_path(string): model checkout path for initialize
+                       Default: ./vgg16_bn_ascend_v170_imagenet2012_official_cv_top1acc74.33_top5acc92.1.ckpt
+             learning_rate: the finetune learning rate
+             momentum: the finetune momentum
+             batch_size: the batch size of dataset. Default: 32
+         Returns:
+             None
+    """
+
+    # read train/val dataset
+    dataset_train, _ = import_data(train_dataset_path=train_dataset_path,
+                                   eval_dataset_path=eval_dataset_path,
+                                   batch_size=batch_size)
+
+    ckpt_param_dict = ms.load_checkpoint(finetune_checkpoint_path)
+    net = Vgg(cfg['16'], num_classes=1000, args=model_config, batch_norm=True)
+    init_weight(net=net, param_dict=ckpt_param_dict)
+    print("net parameter:")
+    for param in net.get_parameters():
+        print("param:", param)
+
+    # replace head
+    src_head = net.classifier[6]
+    print("classifier.6.bias:", net.classifier[6])
+    in_channels = src_head.in_channels
+    head = DenseHead(in_channels, class_num)
+    net.classifier[6] = head
+
+    # freeze the param except last layer
+    for param in net.get_parameters():
+        if param.name not in ["classifier.6.dense.weight", "classifier.6.dense.bias"]:
+            param.requires_grad = False
+
+    # define optimizer and loss
+    opt = nn.Momentum(params=net.trainable_params(), learning_rate=learning_rate, momentum=momentum)
+    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+
+    # define model
+    model = Model(net, loss, opt, metrics={"Accuracy": nn.Accuracy()})
+
+    # define callbacks
+    batch_num = dataset_train.get_dataset_size()
+    time_cb = TimeMonitor(data_size=batch_num)
+    loss_cb = LossMonitor()
+    callbacks = [time_cb, loss_cb]
+
+    # do training
+    model.train(num_epochs, dataset_train, callbacks=callbacks)
+    ms.save_checkpoint(net, save_checkpoint_path)
+
+
+if __name__ == '__main__':
+    config = get_config()
+    print("config:", config)
+    # finetune
+    finetune_train(config,
+                   finetune_checkpoint_path=config.ckpt_file,
+                   save_checkpoint_path=config.save_file, train_dataset_path=config.train_path,
+                   eval_dataset_path=config.eval_path, num_epochs=config.num_epochs, class_num=config.num_classes,
+                   learning_rate=config.lr,
+                   momentum=config.momentum,
+                   batch_size=config.batch_size
+                   )
+
+    # eval
+    eval_net(config, checkpoint_path=config.save_file, train_dataset_path=config.train_path,
+             eval_dataset_path=config.eval_path,
+             batch_size=config.batch_size)  # 0.8505434782608695
diff --git a/official/cv/vgg16/quick_start.py b/official/cv/vgg16/quick_start.py
new file mode 100644
index 000000000..672db7650
--- /dev/null
+++ b/official/cv/vgg16/quick_start.py
@@ -0,0 +1,104 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""inference for CPU"""
+import matplotlib.pyplot as plt
+import numpy as np
+from mindspore import Tensor, load_checkpoint, load_param_into_net, nn
+from mindspore.train import Model
+from fine_tune import import_data
+from model_utils.moxing_adapter import config
+from src.vgg import Vgg
+
+# class_name for dataset
+class_name = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
+
+cfg = {
+    '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+    '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+    '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
+    '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
+}
+
+
+# define head layer
+class DenseHead(nn.Cell):
+    def __init__(self, input_channel, num_classes):
+        super(DenseHead, self).__init__()
+        self.dense = nn.Dense(input_channel, num_classes)
+
+    def construct(self, x):
+        return self.dense(x)
+
+
+def visualize_model(best_ckpt_path, val_ds, num_classes):
+    """
+             visualize model
+
+             Args:
+                 val_ds: eval dataset
+                 best_ckpt_path(string): the .ckpt file for model to infer
+                 num_classes(int): the class num
+
+             Returns:
+                 None
+        """
+
+    net = Vgg(cfg['16'], num_classes=1000, args=config, batch_norm=True)
+
+    # replace head
+    src_head = net.classifier[6]
+    in_channels = src_head.in_channels
+    head = DenseHead(in_channels, num_classes)
+    net.classifier[6] = head
+
+    # load param
+    param_dict = load_checkpoint(best_ckpt_path)
+    load_param_into_net(net, param_dict)
+
+    net.set_train(False)
+    model = Model(net)
+
+    # fetch one batch from the eval dataset for prediction
+    # (advance the same iterator a few batches so the shown images vary)
+    val_iter = val_ds.create_dict_iterator()
+    for _ in range(5):
+        next(val_iter)
+    data = next(val_iter)
+    images = data["image"].asnumpy()
+    labels = data["label"].asnumpy()
+
+    output = model.predict(Tensor(data['image']))
+    pred = np.argmax(output.asnumpy(), axis=1)
+    print("\nAccuracy:", (pred == labels).sum() / len(labels))
+
+    # show image
+    plt.figure(figsize=(15, 7))
+    for i in range(len(labels)):
+        plt.subplot(4, 8, i + 1)
+        # show blue color if the prediction is right, otherwise show red color
+        color = 'blue' if pred[i] == labels[i] else 'red'
+        plt.title('predict:{}'.format(class_name[pred[i]]), color=color)
+        picture_show = np.transpose(images[i], (1, 2, 0))
+        mean = np.array([0.485, 0.456, 0.406])
+        std = np.array([0.229, 0.224, 0.225])
+        picture_show = std * picture_show + mean
+        picture_show = np.clip(picture_show, 0, 1)
+        plt.imshow(picture_show)
+        plt.axis('off')
+    plt.show()
+
+
+if __name__ == '__main__':
+    _, dataset_val = import_data(train_dataset_path=config.train_path, eval_dataset_path=config.eval_path)
+
+    visualize_model(config.pre_trained, dataset_val, config.num_classes)
diff --git a/official/cv/vgg16/requirements.txt b/official/cv/vgg16/requirements.txt
index f77643b0f..186aee6f3 100644
--- a/official/cv/vgg16/requirements.txt
+++ b/official/cv/vgg16/requirements.txt
@@ -1,4 +1,5 @@
-numpy
-onnxruntime-gpu
-pillow
-pyyaml
+numpy
+onnxruntime-gpu
+pillow
+pyyaml
+matplotlib
diff --git a/official/cv/vgg16/src/data_split.py b/official/cv/vgg16/src/data_split.py
new file mode 100644
index 000000000..c851166ac
--- /dev/null
+++ b/official/cv/vgg16/src/data_split.py
@@ -0,0 +1,158 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""split for CPU dataset"""
+import os
+import shutil
+import multiprocessing
+import mindspore as ms
+import mindspore.dataset as ds
+
+
+def get_num_parallel_workers(num_parallel_workers):
+    """
+    Get num_parallel_workers used in dataset operations.
+    If num_parallel_workers > the real CPU cores number, set num_parallel_workers = the real CPU cores number.
+    """
+    cores = multiprocessing.cpu_count()
+    if isinstance(num_parallel_workers, int):
+        if cores < num_parallel_workers:
+            print("The num_parallel_workers {} is set too large, now set it {}".format(num_parallel_workers, cores))
+            num_parallel_workers = cores
+    else:
+        print("The num_parallel_workers {} is invalid, now set it {}".format(num_parallel_workers, min(cores, 8)))
+        num_parallel_workers = min(cores, 8)
+    return num_parallel_workers
+
+
+def create_dataset(dataset_path, do_train, batch_size=32, train_image_size=224, eval_image_size=224,
+                   enable_cache=False, cache_session_id=None):
+    """
+       create a train or eval flower dataset for vgg16
+
+    Args:
+        dataset_path(string): the path of dataset.
+        do_train(bool): whether dataset is used for train or eval.
+        batch_size(int): the batch size of dataset. Default: 32
+        enable_cache(bool): whether tensor caching service is used for eval. Default: False
+        cache_session_id(int): If enable_cache, cache session_id need to be provided. Default: None
+
+    Returns:
+        dataset
+    """
+
+    ds.config.set_prefetch_size(64)
+    data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=get_num_parallel_workers(12), shuffle=True)
+
+    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
+    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
+
+    # define map operations
+    if do_train:
+        trans = [
+            ds.vision.RandomCropDecodeResize(train_image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
+            ds.vision.RandomHorizontalFlip(prob=0.5)
+        ]
+    else:
+        trans = [
+            ds.vision.Decode(),
+            ds.vision.Resize(256),
+            ds.vision.CenterCrop(eval_image_size)
+        ]
+    trans_norm = [ds.vision.Normalize(mean=mean, std=std), ds.vision.HWC2CHW()]
+
+    type_cast_op = ds.transforms.TypeCast(ms.int32)
+    trans_work_num = 24
+    data_set = data_set.map(operations=trans, input_columns="image",
+                            num_parallel_workers=get_num_parallel_workers(trans_work_num))
+    data_set = data_set.map(operations=trans_norm, input_columns="image",
+                            num_parallel_workers=get_num_parallel_workers(12))
+    # only enable cache for eval
+    if do_train:
+        enable_cache = False
+    if enable_cache:
+        if not cache_session_id:
+            raise ValueError("A cache session_id must be provided to use cache.")
+        eval_cache = ds.DatasetCache(session_id=int(cache_session_id), size=0)
+        data_set = data_set.map(operations=type_cast_op, input_columns="label",
+                                num_parallel_workers=get_num_parallel_workers(12),
+                                cache=eval_cache)
+    else:
+        data_set = data_set.map(operations=type_cast_op, input_columns="label",
+                                num_parallel_workers=get_num_parallel_workers(12))
+
+    # apply batch operations
+    data_set = data_set.batch(batch_size, drop_remainder=True)
+
+    return data_set
+
+
+def generate_data(path="./"):
+    dirs = []
+    abs_path = None
+    for abs_path, j, _ in os.walk(path):
+        print("abs_path:", abs_path)
+        if j:
+            dirs.append(j)
+    print(dirs)
+
+    train_folder = os.path.exists(path + 'train')
+    if not train_folder:
+        os.makedirs(path + 'train')
+    test_folder = os.path.exists(path + 'test')
+    if not test_folder:
+        os.makedirs(path + 'test')
+
+    for class_dir in dirs[0]:
+        print("path", path)
+        print("dir", class_dir)
+        files = os.listdir(path + class_dir)
+        train_set = files[: int(len(files) * 0.8)]
+        test_set = files[int(len(files) * 0.8):]
+        for file in train_set:
+            file_path = path + "train/" + class_dir + "/"
+            folder = os.path.exists(file_path)
+            if not folder:
+                os.makedirs(file_path)
+            src_file = path + class_dir + "/" + file
+            print("src_file:", src_file)
+            dst_file = file_path + file
+            print("dst_file:", dst_file)
+            shutil.copyfile(src_file, dst_file)
+
+        for file in test_set:
+            file_path = path + "test/" + class_dir + "/"
+            folder = os.path.exists(file_path)
+            if not folder:
+                os.makedirs(file_path)
+            src_file = path + class_dir + "/" + file
+            dst_file = file_path + file
+            shutil.copyfile(src_file, dst_file)
+
+
+def main():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--split_path", help="the path of dataset to be split")
+    args = parser.parse_args()
+
+    generate_data(path=args.split_path)
+
+    create_dataset(dataset_path=args.split_path + "train/", do_train=True, batch_size=32, train_image_size=224,
+                   eval_image_size=224,
+                   enable_cache=False, cache_session_id=None)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/official/cv/vgg16/src/dataset.py b/official/cv/vgg16/src/dataset.py
index fc632b1be..6f1be68c4 100644
--- a/official/cv/vgg16/src/dataset.py
+++ b/official/cv/vgg16/src/dataset.py
@@ -16,7 +16,9 @@
 dataset processing.
 """
 import os
+import multiprocessing
 from PIL import Image, ImageFile
+import mindspore as ms
 from mindspore.common import dtype as mstype
 import mindspore.dataset as de
 import mindspore.dataset.transforms as C
@@ -25,6 +27,20 @@ from src.utils.sampler import DistributedSampler
 
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 
+def get_num_parallel_workers(num_parallel_workers):
+    """
+    Get num_parallel_workers used in dataset operations.
+    If num_parallel_workers > the real CPU cores number, set num_parallel_workers = the real CPU cores number.
+    """
+    cores = multiprocessing.cpu_count()
+    if isinstance(num_parallel_workers, int):
+        if cores < num_parallel_workers:
+            print("The num_parallel_workers {} is set too large, now set it {}".format(num_parallel_workers, cores))
+            num_parallel_workers = cores
+    else:
+        print("The num_parallel_workers {} is invalid, now set it {}".format(num_parallel_workers, min(cores, 8)))
+        num_parallel_workers = min(cores, 8)
+    return num_parallel_workers
 
 def vgg_create_dataset(data_home, image_size, batch_size, rank_id=0, rank_size=1, training=True):
     """Data operations."""
@@ -163,6 +179,66 @@ def classification_dataset(data_dir, image_size, per_batch_size, rank=0, group_s
 
     return de_dataset
 
+def create_dataset(dataset_path, do_train, batch_size=32, train_image_size=224, eval_image_size=224,
+                   enable_cache=False, cache_session_id=None):
+    """
+    create a train or eval flower dataset for vgg16
+
+    Args:
+        dataset_path(string): the path of dataset.
+        do_train(bool): whether dataset is used for train or eval.
+        batch_size(int): the batch size of dataset. Default: 32
+        enable_cache(bool): whether tensor caching service is used for eval. Default: False
+        cache_session_id(int): If enable_cache, cache session_id need to be provided. Default: None
+
+    Returns:
+        dataset
+    """
+    de.config.set_prefetch_size(64)
+    data_set = de.ImageFolderDataset(dataset_path, num_parallel_workers=get_num_parallel_workers(12), shuffle=True)
+
+    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
+    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
+
+    # define map operations
+    if do_train:
+        trans = [
+            de.vision.RandomCropDecodeResize(train_image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
+            de.vision.RandomHorizontalFlip(prob=0.5)
+        ]
+    else:
+        trans = [
+            de.vision.Decode(),
+            de.vision.Resize(256),
+            de.vision.CenterCrop(eval_image_size)
+        ]
+    trans_norm = [de.vision.Normalize(mean=mean, std=std), de.vision.HWC2CHW()]
+
+    type_cast_op = de.transforms.TypeCast(ms.int32)
+    trans_work_num = 24
+    data_set = data_set.map(operations=trans, input_columns="image",
+                            num_parallel_workers=get_num_parallel_workers(trans_work_num))
+    data_set = data_set.map(operations=trans_norm, input_columns="image",
+                            num_parallel_workers=get_num_parallel_workers(12))
+    # only enable cache for eval
+    if do_train:
+        enable_cache = False
+    if enable_cache:
+        if not cache_session_id:
+            raise ValueError("A cache session_id must be provided to use cache.")
+        eval_cache = de.DatasetCache(session_id=int(cache_session_id), size=0)
+        data_set = data_set.map(operations=type_cast_op, input_columns="label",
+                                num_parallel_workers=get_num_parallel_workers(12),
+                                cache=eval_cache)
+    else:
+        data_set = data_set.map(operations=type_cast_op, input_columns="label",
+                                num_parallel_workers=get_num_parallel_workers(12))
+
+    # apply batch operations
+    data_set = data_set.batch(batch_size, drop_remainder=True)
+
+    return data_set
+
 
 class TxtDataset:
     """
-- 
GitLab