diff --git a/research/cv/resnet3d/README_CN.md b/research/cv/resnet3d/README_CN.md
index 076278b9514bb790805d7a3e9c191f1ab37bd208..f59ae412eeb5ad9037ffb1ee0932c20e61e58b53 100644
--- a/research/cv/resnet3d/README_CN.md
+++ b/research/cv/resnet3d/README_CN.md
@@ -14,6 +14,7 @@
  - [Script Parameters](#script-parameters)
  - [Training Process](#training-process)
  - [Evaluation Process](#evaluation-process)
+  - [ONNX Evaluation](#onnx-evaluation)
  - [Export Process](#export-process)
  - [Export](#export)
  - [Inference Process](#inference-process)
@@ -54,9 +55,11 @@ The overall network architecture of resnet3d is as follows:
 - [MIT](http://moments.csail.mit.edu/)
  - Moments in Time, a new million-scale video understanding dataset released by the MIT-IBM Watson AI Lab; 1,000,000 videos in total, used for pre-training
 - [hmdb51](https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/#Downloads)
-  - A small video action recognition dataset with 51 action classes and 6,849 videos in total, at least 51 videos per action; used for fine-tuning
+  - A small video action recognition dataset with 51 action classes and 6,849 videos in total, at least 51 videos per action; used for fine-tuning. The Stabilized HMDB51 release is used here
+  - labels: (http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar)
 - [UCF101](https://www.crcv.ucf.edu/data/UCF101/UCF101.rar)
  - An action recognition dataset of realistic action videos collected from YouTube, covering 101 action classes and 13,320 videos in total; used for fine-tuning
+  - labels: (https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip)
 
 棰勮缁冩ā鍨嬭幏鍙栧湴鍧€锛�
 [閾炬帴](https://github.com/kenshohara/3D-ResNets-PyTorch)
@@ -68,7 +71,7 @@ python pth_to_ckpt.py --pth_path=./pretrained.pth --ckpt_path=./pretrained.ckpt
 
 Special note:
 
-鍘熷鏁版嵁闆嗕笅杞藉悗鏍煎紡涓猴細
+鎸夌収涓嬮潰鏍煎紡鍒涘缓鐩綍锛屽皢涓嬭浇濂界殑hmdb51_sta.rar瑙e帇锛屾妸瑙e帇鍑烘潵鐨勬枃浠跺す鏀惧埌videos鐩綍涓€傚皢鏁版嵁闆嗗搴旂殑labels瑙e帇锛屾妸瑙e帇鍑虹殑txt鏂囦欢绉诲姩鍒發abels鐩綍涓€�
 
 ```text
 .
@@ -87,18 +90,18 @@ python pth_to_ckpt.py --pth_path=./pretrained.pth --ckpt_path=./pretrained.ckpt
  └──json
 ```
 
-Use src/generate_hmdb51_json.py to generate annotation files in json format
+Use src/generate_video_jpgs.py to convert the avi video files into jpg image files
 
 ```text
 cd ~/src
-python3 generate_hmdb51_json.py --dir_path ~/dataset/hmdb51/labels/ --video_path ~/dataset/hmdb51/videos/ --dst_dir_path ~/dataset/hmdb51/json
+python3 generate_video_jpgs.py --video_path ~/dataset/hmdb51/videos/ --target_path ~/dataset/hmdb51/jpg/
 ```
 
-Use src/generate_video_jpgs.py to convert the avi video files into jpg image files
+Use src/generate_hmdb51_json.py to generate annotation files in json format
 
 ```text
 cd ~/src
-python3 generate_video_jpgs.py --video_path ~/dataset/hmdb51/videos/ --target_path ~/dataset/hmdb51/jpg/
+python3 generate_hmdb51_json.py --dir_path ~/dataset/hmdb51/labels/ --video_path ~/dataset/hmdb51/jpg/ --dst_dir_path ~/dataset/hmdb51/json
 ```
 
 # Features
@@ -115,7 +118,7 @@ python3 generate_video_jpgs.py --video_path ~/dataset/hmdb51/videos/ --target_pa
     - [MindSpore](https://www.mindspore.cn/install/en)
 - For more details, see the following resources:
  - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
-  - [MindSpore Python API](https://www.mindspore.cn/docs/zh-CN/master/index.html)
+  - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
 
 # Quick Start
 
@@ -146,6 +149,7 @@ python3 generate_video_jpgs.py --video_path ~/dataset/hmdb51/videos/ --target_pa
    ├── run_distribute_train.sh            # Launch Ascend distributed training (8 devices)
    ├── run_eval.sh                        # Launch Ascend evaluation
    ├── run_standalone_train.sh            # Launch Ascend standalone training (single device)
+    ├── run_eval_onnx.sh                   # Shell script for ONNX evaluation
  ├── src
    ├── __init__.py
    ├── config.py                          # Parses the yaml file
@@ -165,6 +169,7 @@ python3 generate_video_jpgs.py --video_path ~/dataset/hmdb51/videos/ --target_pa
    └──  videodataset_multiclips.py        # Custom dataset loader
  ├── pth_to_ckpt.py                       # Converts the pre-trained model from pth to ckpt format
  ├── eval.py                              # Evaluates the network
+  ├── eval_onnx.py                         # ONNX evaluation script
  ├── train.py                             # Trains the network
  ├── hmdb51_config.yaml                   # Parameter configuration
  └── ucf101_config.yaml                   # Parameter configuration
@@ -180,6 +185,7 @@ python3 generate_video_jpgs.py --video_path ~/dataset/hmdb51/videos/ --target_pa
     'result_path': './results/ucf101',                                             # Path for training and inference results
     'pretrain_path': '~/your_path/pretrained.ckpt',                                # Path to the pre-trained model file
     'inference_ckpt_path': "~/your_path/results/ucf101/result.ckpt",               # Path to the ckpt model used for inference
+    'onnx_path': "~/your_path/results/result-3d.onnx",                             # Path to the ONNX model used for inference
     'n_classes': 101,                                                              # Number of dataset classes
     'sample_size': 112,                                                            # Image resolution
     'sample_duration': 16,                                                         # Video clip length, in frames
@@ -303,6 +309,48 @@ clip: 66.5% top-1: 69.7%  top-5: 93.8%
 clip: 88.8% top-1: 92.7%  top-5: 99.3%
 ```
 
+## ONNX Evaluation
+
+### Export the ONNX model
+
+```bash
+python export.py --ckpt_file=/path/best.ckpt --file_format=ONNX --n_classes=51 --batch_size=1 --device_target=GPU
+```
+
+- `ckpt_file` Path to the ckpt file
+- `file_format` Format of the exported model; set to ONNX here
+- `n_classes` Number of dataset classes: 51 for the hmdb51 dataset, 101 for the ucf101 dataset
+- `batch_size` Batch size, fixed to 1
+- `device_target` Currently only GPU and CPU are supported
+
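+For a quick sanity check of the exported model, a minimal onnxruntime sketch such as the one below can be used (an illustrative example, not part of the repository scripts: the input shape (1, 3, 16, 112, 112) assumes the default batch_size=1, sample_duration=16 and sample_size=112, and resnet-3d.onnx is the default export file name):
+
+```python
+import numpy as np
+import onnxruntime
+
+# Load the exported model; adjust the path and provider to your setup.
+session = onnxruntime.InferenceSession('resnet-3d.onnx', providers=['CUDAExecutionProvider'])
+# One dummy clip: batch 1, 3 channels, 16 frames, 112x112 pixels.
+dummy = np.random.randn(1, 3, 16, 112, 112).astype(np.float32)
+logits = session.run(None, {session.get_inputs()[0].name: dummy})[0]
+print(logits.shape)  # (1, n_classes)
+```
+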
+### Run the ONNX evaluation
+
+```bash
+Usage: bash run_eval_onnx.sh [ucf101|hmdb51] [VIDEO_PATH] [ANNOTATION_PATH] [ONNX_PATH]
+Example: bash run_eval_onnx.sh ucf101 /path/ucf101/jpg /path/ucf101/json/ucf101_01.json /path/resnet-3d.onnx
+```
+
+- `[ucf101|hmdb51]` Dataset to use
+- `[VIDEO_PATH]` Path to the video frames
+- `[ANNOTATION_PATH]` Path to the annotation file
+- `[ONNX_PATH]` Path to the ONNX model
+
+### Results
+
+The evaluation results are saved to "eval_$DATASET.log" (for example, eval_ucf101.log) in the directory containing eval_onnx.py. There you can find results such as the following:
+
+- Evaluating resnet3d on the hmdb51 dataset
+
+```text
+clip: 66.5% top-1: 69.7%  top-5: 93.8%
+```
+
+- Evaluating resnet3d on the ucf101 dataset
+
+```text
+clip: 88.8% top-1: 92.7%  top-5: 99.3%
+```
+
 ## Export Process
 
 ### Export
@@ -310,7 +358,7 @@ clip: 88.8% top-1: 92.7%  top-5: 99.3%
 鍦ㄥ鍑烘椂锛宧mdb51鏁版嵁闆�,鍙傛暟n_classes璁剧疆涓�51,ucf101鏁版嵁闆�,鍙傛暟n_classes璁剧疆涓�101, 鍙傛暟batch_size鍙兘璁剧疆涓�1.
 
 ```shell
-python export.py --ckpt_file=./saved_model/best.ckpt --file_format=MINDIR --n_classes=51, --batch_size=1
+python export.py --ckpt_file=./saved_model/best.ckpt --file_format=MINDIR --n_classes=51 --batch_size=1
 ```
 
 ## Inference Process
diff --git a/research/cv/resnet3d/eval_onnx.py b/research/cv/resnet3d/eval_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..e879324762838633bfa4208397b7abb9b1401cb1
--- /dev/null
+++ b/research/cv/resnet3d/eval_onnx.py
@@ -0,0 +1,146 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Eval.
+"""
+import time
+import random
+import json
+from collections import defaultdict
+import numpy as np
+import onnxruntime
+from mindspore import dataset as de
+from mindspore.common import set_seed
+from src.config import config as args_opt
+from src.dataset import create_eval_dataset
+from src.inference import (topk_, get_video_results, load_ground_truth, load_result,
+                           remove_nonexistent_ground_truth, calculate_clip_acc)
+from src.videodataset_multiclips import get_target_path
+
+
+random.seed(1)
+np.random.seed(1)
+de.config.set_seed(1)
+set_seed(1)
+
+
+if __name__ == '__main__':
+    t1_ = time.time()
+    cfg = args_opt
+    print(cfg)
+    target = args_opt.device_target
+    if target == 'GPU':
+        providers = ['CUDAExecutionProvider']
+    elif target == 'CPU':
+        providers = ['CPUExecutionProvider']
+    else:
+        raise ValueError(
+            f'Unsupported target device {target}, '
+            f'Expected one of: "CPU", "GPU"'
+        )
+
+    session = onnxruntime.InferenceSession(args_opt.onnx_path, providers=providers)
+    predict_data = create_eval_dataset(
+        cfg.video_path, cfg.annotation_path, cfg)
+    size = predict_data.get_dataset_size()
+    total_target_path = get_target_path(cfg.annotation_path)
+    with total_target_path.open('r') as f:
+        total_target_data = json.load(f)
+    results = {'results': defaultdict(list)}
+    count = 0
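+    # each sample from the eval set packs all clips of one video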
+    for data in predict_data.create_dict_iterator(output_numpy=True):
+        t1 = time.time()
+        x, label = data['data'][0], data['label'].tolist()
+        video_ids, segments = zip(
+            *total_target_data['targets'][str(label[0])])
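+        # run each clip through onnxruntime individually (the model was exported with batch_size=1)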
+        x_list = np.split(x, x.shape[0], axis=0)
+        outputs = []
+        for clip in x_list:
+            inputs = {session.get_inputs()[0].name: clip}
+            output = session.run(None, inputs)[0]
+            outputs.append(output)
+        outputs = np.concatenate(outputs, axis=0)
+
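+        # top-1 prediction per clip; video-level accuracy is computed from the raw outputs below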
+        _, locs = topk_(outputs, K=1)
+        locs = locs.reshape(1, -1)
+
+        t2 = time.time()
+        print("[{} / {}] Net time: {} ms".format(count, size, (t2 - t1) * 1000))
+        for j in range(0, outputs.shape[0]):
+            results['results'][video_ids[j]].append({
+                'segment': segments[j],
+                'output': outputs[j]
+            })
+        count += 1
+
+    class_names = total_target_data['class_names']
+    inference_results = {'results': {}}
+    clips_inference_results = {'results': {}}
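+    # average the per-clip scores of each video to obtain video-level top-5 results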
+    for video_id, video_results in results['results'].items():
+        video_outputs = [
+            segment_result['output'] for segment_result in video_results
+        ]
+        video_outputs = np.stack(video_outputs, axis=0)
+        average_scores = np.mean(video_outputs, axis=0)
+        clips_inference_results['results'][video_id] = get_video_results(
+            average_scores, class_names, 5)
+
+        inference_results['results'][video_id] = []
+        for segment_result in video_results:
+            segment = segment_result['segment']
+            result = get_video_results(segment_result['output'],
+                                       class_names, 5)
+            inference_results['results'][video_id].append({
+                'segment': segment,
+                'result': result
+            })
+    # compare predictions with the ground truth
+    print('load ground truth')
+    ground_truth, class_labels_map = load_ground_truth(
+        cfg.annotation_path, "validation")
+    print('number of ground truth: {}'.format(len(ground_truth)))
+
+    n_ground_truth_top_1 = len(ground_truth)
+    n_ground_truth_top_5 = len(ground_truth)
+
+    result_top1, result_top5 = load_result(
+        clips_inference_results, class_labels_map)
+
+    ground_truth_top1 = remove_nonexistent_ground_truth(
+        ground_truth, result_top1)
+    ground_truth_top5 = remove_nonexistent_ground_truth(
+        ground_truth, result_top5)
+
+    if cfg.ignore:
+        n_ground_truth_top_1 = len(ground_truth_top1)
+        n_ground_truth_top_5 = len(ground_truth_top5)
+
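+    # count a hit when the ground-truth label appears in the video's top-1 / top-5 result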
+    correct_top1 = [1 if line[1] in result_top1[line[0]]
+                    else 0 for line in ground_truth_top1]
+    correct_top5 = [1 if line[1] in result_top5[line[0]]
+                    else 0 for line in ground_truth_top5]
+
+    clip_acc = calculate_clip_acc(
+        inference_results, ground_truth, class_labels_map)
+    print('top-1 hits: {} / {}'.format(sum(correct_top1), n_ground_truth_top_1))
+    print('top-5 hits: {} / {}'.format(sum(correct_top5), n_ground_truth_top_5))
+    accuracy_top1 = float(sum(correct_top1)) / float(n_ground_truth_top_1)
+    accuracy_top5 = float(sum(correct_top5)) / float(n_ground_truth_top_5)
+    print('==================Accuracy=================\n'
+          ' clip-acc : {} \ttop-1 : {} \ttop-5: {}'.format(clip_acc, accuracy_top1, accuracy_top5))
+    t2_ = time.time()
+    print("Total time : {} s".format(t2_ - t1_))
diff --git a/research/cv/resnet3d/export.py b/research/cv/resnet3d/export.py
index c1722b6adddb21e85636c626c532d91dd195b90f..8c969897e1cfe14e939a766af81b2d6dd45ef4a4 100644
--- a/research/cv/resnet3d/export.py
+++ b/research/cv/resnet3d/export.py
@@ -31,7 +31,7 @@ parser.add_argument('--ckpt_file', type=str, required=True,
 parser.add_argument('--file_name', type=str,
                     default='resnet-3d', help='Output file name.')
 parser.add_argument('--file_format', type=str,
-                    choices=['AIR', 'MINDIR'], default='MINDIR', help='File format.')
+                    choices=['AIR', 'MINDIR', 'ONNX'], default='MINDIR', help='File format.')
 parser.add_argument('--device_target', type=str, choices=['Ascend', 'CPU', 'GPU'], default='Ascend',
                     help='Device target')
 parser.add_argument('--sample_duration', type=int, default=16)
diff --git a/research/cv/resnet3d/hmdb51_config.yaml b/research/cv/resnet3d/hmdb51_config.yaml
index 3f06cfbf6e00e5b8d1db9e5542396182dc67014c..05a0fe0c33e518397050d70e31e05ecdd8aece36 100644
--- a/research/cv/resnet3d/hmdb51_config.yaml
+++ b/research/cv/resnet3d/hmdb51_config.yaml
@@ -14,6 +14,7 @@ annotation_path: ""
 result_path: ""
 pretrain_path: ""
 inference_ckpt_path: ""
+onnx_path: ""
 n_classes: 51
 sample_size: 112
 sample_duration: 16
diff --git a/research/cv/resnet3d/requirements.txt b/research/cv/resnet3d/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..980596ca780f721d0444c84d45589249000bc739
--- /dev/null
+++ b/research/cv/resnet3d/requirements.txt
@@ -0,0 +1,4 @@
+numpy==1.21.6
+onnxruntime-gpu==1.11.1
+pyyaml==6.0
+Pillow==9.2.0
\ No newline at end of file
diff --git a/research/cv/resnet3d/scripts/run_eval_onnx.sh b/research/cv/resnet3d/scripts/run_eval_onnx.sh
new file mode 100644
index 0000000000000000000000000000000000000000..44dc54ff6790b8113bdde3313764b1959b03cd9a
--- /dev/null
+++ b/research/cv/resnet3d/scripts/run_eval_onnx.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash run_eval_onnx.sh [ucf101|hmdb51] [VIDEO_PATH] [ANNOTATION_PATH] [ONNX_PATH]"
+echo "For example:
+bash run_eval_onnx.sh 0 ucf101 \\
+/path/ucf101/jpg/ \\
+/path/ucf101/json/ucf101_01.json \\
+/path/resnet-3d.onnx"
+echo "It is better to use the ABSOLUTE path."
+echo "=============================================================================================================="
+set -e
+
+if [ $# != 4 ]
+then
+  echo "Usage: bash run_eval_onnx.sh [ucf101|hmdb51] [VIDEO_PATH] [ANNOTATION_PATH] [ONNX_PATH]"
+  exit 1
+fi
+
+DATASET=$1
+VIDEO_PATH=$2
+ANNOTATION_PATH=$3
+ONNX_PATH=$4
+
+EXEC_PATH=$(pwd)
+echo "$EXEC_PATH"
+cd ..
+env > env.log
+echo "Eval begin"
+python eval_onnx.py --is_modelarts False  --config_path ./${DATASET}_config.yaml --video_path $VIDEO_PATH \
+--annotation_path $ANNOTATION_PATH --onnx_path $ONNX_PATH --device_target GPU > eval_$DATASET.log 2>&1 &
+
+echo "Evaling. Check it at eval_$DATASET.log"
+
diff --git a/research/cv/resnet3d/src/pil_transforms.py b/research/cv/resnet3d/src/pil_transforms.py
index cfd0f0fa7166370eef969ea9fc8bc816d4f4c0dd..e194f35cab8be5abb3e97003f13541c1251cbfc2 100644
--- a/research/cv/resnet3d/src/pil_transforms.py
+++ b/research/cv/resnet3d/src/pil_transforms.py
@@ -33,7 +33,7 @@ class PILTrans:
                                      ratio=(opt.train_crop_min_ratio, 1.0 / opt.train_crop_min_ratio))
         self.random_horizontal_flip = vision.RandomHorizontalFlip(prob=0.5)
         self.color = vision.RandomColorAdjust(0.4, 0.4, 0.4, 0.1)
-        self.normalize = vision.Normalize(mean=mean, std=std, is_hwc=False)
+        self.normalize = vision.Normalize(mean=mean, std=std)
         self.to_tensor = vision.ToTensor()
         self.resize = vision.Resize(opt.sample_size)
         self.center_crop = vision.CenterCrop(opt.sample_size)
@@ -75,7 +75,7 @@ class EvalPILTrans:
         self.to_pil = vision.ToPIL()
         self.resize = vision.Resize(opt.sample_size)
         self.center_crop = vision.CenterCrop(opt.sample_size)
-        self.normalize = vision.Normalize(mean=mean, std=std, is_hwc=False)
+        self.normalize = vision.Normalize(mean=mean, std=std)
         self.to_tensor = vision.ToTensor()
 
     def __call__(self, data, labels, batchInfo):
diff --git a/research/cv/resnet3d/ucf101_config.yaml b/research/cv/resnet3d/ucf101_config.yaml
index 25729f91f095f903d1d0ab0d5da2a135852970d8..cd77d34e896e18da0fa23ad6365fb01d34d01082 100644
--- a/research/cv/resnet3d/ucf101_config.yaml
+++ b/research/cv/resnet3d/ucf101_config.yaml
@@ -14,6 +14,7 @@ annotation_path: ""
 result_path: ""
 pretrain_path: ""
 inference_ckpt_path: ""
+onnx_path: ""
 n_classes: 101
 sample_size: 112
 sample_duration: 16
@@ -97,4 +98,4 @@ lr_decay_mode:
   -
     linear
   -
-    cosine
\ No newline at end of file
+    cosine