diff --git a/research/cv/ResNeSt50/README.md b/research/cv/ResNeSt50/README.md
index e039c5fff782b7b41d345a59b48e4cb7b29086e9..fad0120f3a5f6012f8ff886d92aa127bf985bf3e 100644
--- a/research/cv/ResNeSt50/README.md
+++ b/research/cv/ResNeSt50/README.md
@@ -1,6 +1,5 @@
 # 目录
 
-- [目录](#目录)
 - [ResNeSt说明](#resnest说明)
 - [模型架构](#模型架构)
 - [数据集](#数据集)
@@ -55,14 +54,14 @@ ResNeSt整体网络架构如下:
 
 ## 混合精度
 
-采用[混合精度](https://www.mindspore.cn/docs/programming_guide/zh-CN/master/enable_mixed_precision.html)的训练方法使用支持单精度和半精度数据来提高深度学习神经网络的训练速度,同时保持单精度训练所能达到的网络精度。混合精度训练提高计算速度、减少内存使用的同时,支持在特定硬件上训练更大的模型或实现更大批次的训练。
+采用[混合精度](https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.6/enable_mixed_precision.html)的训练方法使用支持单精度和半精度数据来提高深度学习神经网络的训练速度,同时保持单精度训练所能达到的网络精度。混合精度训练提高计算速度、减少内存使用的同时,支持在特定硬件上训练更大的模型或实现更大批次的训练。
 
 以FP16算子为例,如果输入数据类型为FP32,MindSpore后台会自动降低精度来处理数据。用户可打开INFO日志,搜索“reduce precision”查看精度降低的算子。
 
 # 环境要求
 
-- 硬件(Ascend)
-    - 使用Ascend处理器来搭建硬件环境。
+- 硬件(Ascend/GPU)
+    - 使用Ascend或GPU处理器来搭建硬件环境。
 - 框架
     - [MindSpore](https://www.mindspore.cn/install)
 - 如需查看详情,请参见如下资源:
@@ -76,12 +75,14 @@ ResNeSt整体网络架构如下:
 ```path
 .
 └─ResNeSt50
-  ├─README.md
   ├─scripts
     ├─run_train.sh
     ├─run_eval.sh
+    ├─run_train_gpu.sh
+    ├─run_eval_gpu.sh
     ├─run_distribute_train.sh              # 启动Ascend分布式训练(8卡)
     ├─run_distribute_eval.sh               # 启动Ascend分布式评估(8卡)
+    ├─run_distribute_train_gpu.sh              # 启动GPU分布式训练(8卡)
     └─run_infer_310.sh                     # 启动310推理
   ├─src
     ├─datasets
@@ -101,7 +102,7 @@ ResNeSt整体网络架构如下:
   ├──export.py                        # 导出Mindir接口
   ├──create_imagenet2012_label.py     # 创建数据集标签用于310推理精度验证
   ├──postprocess.py                   # 后处理
-  ├──README.md                        # README文件
+  └──README.md                        # README文件
 ```
 
 ## 脚本参数
@@ -110,7 +111,7 @@ ResNeSt整体网络架构如下:
 
 ```python
 "net_name": 'resnest50'                   # 网络选择
-"root": '/mass_data/imagenet/imagenet/'   # 数据集路径
+"root": "/home/mindspore/dataset/imagenet_original"   # 数据集路径
 "num_classes": 1000,                      # 数据集类数
 "base_size": 224,                         # 图像大小
 "crop_size": 224,                         # crop大小
@@ -146,7 +147,10 @@ ResNeSt整体网络架构如下:
 您可以通过python脚本开始训练:
 
 ```shell
-python train.py --outdir ./output --device_target Ascend
+Ascend:
+   python train.py --outdir ./output --device_target [device]
+GPU:
+   python train.py --outdir ./output --device_target [device]
 ```
 
 或通过shell脚本开始训练:
@@ -154,9 +158,14 @@ python train.py --outdir ./output --device_target Ascend
 ```shell
 Ascend:
     # 分布式训练示例(8卡)
-    bash run_distribute_train.sh RANK_TABLE_FILE OUTPUT_DIR
+    bash scripts/run_distribute_train.sh RANK_TABLE_FILE OUTPUT_DIR
+    # 单机训练
+    bash scripts/run_train.sh OUTPUT_DIR
+GPU:
+    # 分布式训练示例(8卡)
+    bash scripts/run_distribute_train_gpu.sh [DEVICE_NUM]
     # 单机训练
-    bash run_train.sh OUTPUT_DIR
+    bash scripts/run_train_gpu.sh
 ```
 
 ### 训练样例
@@ -166,6 +175,10 @@ Ascend:
 bash scripts/run_distribute_train.sh RANK_TABLE_FILE OUTPUT_DIR
 # Ascend单机训练示例
 bash scripts/run_train.sh OUTPUT_DIR
+# GPU分布式训练示例(8卡)
+bash scripts/run_distribute_train_gpu.sh 8
+# GPU单机训练示例
+bash scripts/run_train_gpu.sh
 ```
 
 您可以在日志中找到检查点文件和结果。
@@ -177,37 +190,54 @@ bash scripts/run_train.sh OUTPUT_DIR
 您可以通过python脚本开始评估:
 
 ```shell
-python eval.py --outdir ./output --resume_path ~/resnest50-270_2502.ckpt
+Ascend:
+python eval.py --outdir ./output --pretrained_ckpt_path ~/resnest50-270_2502.ckpt
+GPU:
+python eval.py --outdir ./output --pretrained_ckpt_path ~/resnest50-270_2502.ckpt --device_target “GPU”
 ```
 
-或通过shell脚本开始训练:
+或通过shell脚本开始评估:
 
 ```shell
 # 评估
-bash run_eval.sh OUT_DIR PRETRAINED_CKPT_PATH
+Ascend:
+bash scripts/run_eval.sh [OUT_DIR] [PRETRAINED_CKPT_PATH]
+GPU:
+bash scripts/run_eval_gpu.sh [OUT_DIR] [PRETRAINED_CKPT_PATH]
 ```
 
-PLATFORM is Ascend, default is Ascend.
-
 ### 评估样例
 
 ```shell
 # 检查点评估
+Ascend:
 bash scripts/run_eval.sh OUT_DIR PRETRAINED_CKPT_PATH
+GPU:
+bash scripts/run_eval_gpu.sh OUT_DIR PRETRAINED_CKPT_PATH
 
 #或者直接使用脚本运行
-python eval.py --outdir ./output --resume_path ~/resnest50-270_2502.ckpt
+python eval.py --outdir ./output --pretrained_ckpt_path ~/resnest50-270_2502.ckpt
 ```
 
 ### 评估结果
 
-评估结果保存在脚本路径`/scripts/EVAL_LOG/`下。您可以在日志中找到类似以下的结果。
+Ascend评估结果保存在脚本路径`/scripts/EVAL_LOG/`下。您可以在日志中找到类似以下的结果。
 
 ```log
 acc=80.90%(TOP1)
 acc=95.51%(TOP5)
 ```
 
+GPU评估结果保存在脚本路径`/output1/valid下。您可以在日志中找到类似以下的结果。
+
+```log
+2022-01-30 12:10:33,478:INFO:Inference Performance: 379.23 img/sec
+2022-01-30 12:10:33,478:INFO:before results=[[40525], [47716], [49984]]
+2022-01-30 12:10:33,479:INFO:after results=[[40525],[47716],[49984]]
+2022-01-30 12:10:33,479:INFO:after allreduce eval: top1_correct=40525, tot=49984,acc=81.08%(TOP1)
+2022-01-30 12:10:33,479:INFO:after allreduce eval: top5_correct=47716, tot=49984,acc=95.46%(TOP5)
+```
+
 ## 推理过程
 
 在Ascend310执行推理,执行推理之前,需要通过`export.py`文件导出MINDIR模型
@@ -253,31 +283,31 @@ acc=0.9548(TOP5)
 
 ### 训练性能
 
-| 参数                       | ResNeSt50                                                  |
-| -------------------------- | ---------------------------------------------------------- |
-| 资源                       | Ascend 910;CPU:2.60GHz,192核;内存:755GB               |
-| 上传日期                   | 2021-11-09                                                 |
-| MindSpore版本              | 1.3                                                        |
-| 数据集                     | ImageNet                                                   |
-| 训练参数                   | src/config.py                                              |
-| 优化器                     | Momentum                                                   |
-| 损失函数                   | Softmax交叉熵                                              |
-| 损失                       | 1.466                                                      |
-| 准确率                     | 80.9%(TOP1)                                                |
-| 总时长                     | 84h21m39s (8卡)                                          |
-| 调优检查点                 | 223 M(.ckpt文件)                                         |
+| 参数                       | Ascend 910                                                  |  GPU                                                  |
+| -------------------------- | ---------------------------------------------------------- | ---------------------------------------------------------- |
+| 资源                       | Ascend 910;CPU:2.60GHz,192核;内存:755GB | GeForce RTX 3090 ;CPU 2.90GHz,16cores;内存,252G        |
+| 上传日期                   | 2021-11-09                                                  | 2022-2-15                                                  |
+| MindSpore版本              | 1.3                                                        | 1.5                                                        |
+| 数据集                     | ImageNet                                                   | ImageNet                                                   |
+| 训练参数                   | src/config.py                                              | src/config.py                                              |
+| 优化器                     | Momentum                                                   | Momentum                                                   |
+| 损失函数                   | Softmax交叉熵                                              | Softmax交叉熵                                              |
+| 损失                       | 1.466                                                     | 1.5859                                                     |
+| 准确率                     | 80.9%(TOP1)                                               | 81.08%(TOP1)                                               |
+| 总时长                     | 84h21m39s (8卡)                                        | 66h42m42s185(8卡)                                        |
+| 调优检查点                 | 223 M(.ckpt文件)                                         | 212 M(.ckpt文件)                                         |
 
 ### 推理性能
 
-| 参数                       |                      |
-| -------------------------- | -------------------- |
-| 资源                       | Ascend 910           |
-| 上传日期                   | 2021-11-09           |
-| MindSpore版本              | 1.3                  |
-| 数据集                     | ImageNet, 5万       |
-| batch_size                 | 1                    |
-| 输出                       | 分类准确率           |
-| 准确率                     | acc=80.9%(TOP1)      |
+| 参数                       |  Ascend 910   |  GPU             |
+| -------------------------- | -------------------- | -------------------- |
+| 资源                       | Ascend 910           | GeForce RTX 3090     |
+| 上传日期                   | 2021-11-09           |2022-2-15            |
+| MindSpore版本              | 1.3                  | 1.5                  |
+| 数据集                     | ImageNet, 5万       | ImageNet, 5万       |
+| batch_size                 | 1                    | 1                    |
+| 输出                       | 分类准确率           | 分类准确率           |
+| 准确率                     | acc=80.9%(TOP1)      | acc=81.08%(TOP1)      |
 
 # 随机情况说明
 
diff --git a/research/cv/ResNeSt50/eval.py b/research/cv/ResNeSt50/eval.py
index 6f155d18d070084138c7d1e4e5fcc98252f271f4..69d543f58711b7c27e5b8f4ae338b86c81424b9b 100644
--- a/research/cv/ResNeSt50/eval.py
+++ b/research/cv/ResNeSt50/eval.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -91,7 +91,7 @@ def Parse(args=None):
     parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
                         help='device where the code will be implemented (default: Ascend)')
 
-    parser.add_argument('--resume_path', type=str, default="./output/ckpt_0/resnest50-270_2502.ckpt",
+    parser.add_argument('--pretrained_ckpt_path', type=str, default="./output/ckpt_0/resnest50-270_2502.ckpt",
                         help='put the path to resuming file if needed')
     parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
     args = parser.parse_args()
@@ -157,15 +157,13 @@ def test():
 
     # initialize model
     args.logger.important_info('start create network')
-    net = get_network(config.net_name, True, args.resume_path)
+    net = get_network(config.net_name, True, args.pretrained_ckpt_path)
 
     img_tot = 0
     top1_correct = 0
     top5_correct = 0
     if target == "Ascend":
         net.to_float(mstype.float16)
-    else:
-        auto_mixed_precision(net)
     net.set_train(False)
     t_end = time.time()
     it_name = 0
@@ -190,7 +188,7 @@ def test():
         fps = (img_tot - config.batch_size) * config.group_size / time_used
         args.logger.info('Inference Performance: {:.2f} img/sec'.format(fps))
 
-    results = get_result(args, args.resume_path, top1_correct, top5_correct, img_tot)
+    results = get_result(args, args.pretrained_ckpt_path, top1_correct, top5_correct, img_tot)
     top1_correct = results[0, 0]
     top5_correct = results[1, 0]
     img_tot = results[2, 0]
diff --git a/research/cv/ResNeSt50/scripts/run_distribute_train_gpu.sh b/research/cv/ResNeSt50/scripts/run_distribute_train_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..11b12f9369c0c31601487686273beb95af43c48c
--- /dev/null
+++ b/research/cv/ResNeSt50/scripts/run_distribute_train_gpu.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash run_distribute_train_gpu.sh RANK_SIZE"
+echo "For example: bash run_distribute_train_gpu.sh 8"
+echo "=============================================================================================================="
+set -e
+
+export DEVICE_NUM=$1
+export RANK_SIZE=$1
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+echo "start training"
+mpirun -n $1 --allow-run-as-root --output-filename log_output1 --merge-stderr-to-stdout \
+    python ./train.py --device_num $1  --device_target "GPU" --run_distribute True \
+    > train_dis.log 2>&1 &
+
diff --git a/research/cv/ResNeSt50/scripts/run_eval_gpu.sh b/research/cv/ResNeSt50/scripts/run_eval_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1ab03e465e96f8c628071d600d528f0c6ec75ee3
--- /dev/null
+++ b/research/cv/ResNeSt50/scripts/run_eval_gpu.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash run_eval.sh OUT_DIR PRETRAINED_CKPT_PATH"
+echo "For example: bash run_eval.sh ./output /path/dataset pretrained_ckpt_path"
+echo "=============================================================================================================="
+set -e
+
+export DEVICE_NUM=1
+export RANK_SIZE=1
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+
+echo "start evaling"
+python ./eval.py --outdir $1 --pretrained_ckpt_path $2 --device_target "GPU"\
+    > eval.log 2>&1 &
+
diff --git a/research/cv/ResNeSt50/scripts/run_train_gpu.sh b/research/cv/ResNeSt50/scripts/run_train_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..cf6faee8f3a4050ef3a52d9b8f75753fef9b2df3
--- /dev/null
+++ b/research/cv/ResNeSt50/scripts/run_train_gpu.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash run_train_gpu"
+echo "=============================================================================================================="
+set -e
+
+export DEVICE_NUM=1
+export RANK_SIZE=1
+export DATASET_NAME=$1
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+echo "start training"
+python ./train.py --device_num 1 --device_target "GPU"\
+    > train.log 2>&1 &
+
diff --git a/research/cv/ResNeSt50/src/config.py b/research/cv/ResNeSt50/src/config.py
index 7ea8b828e51579a3f19474946e8dc58ee7b9ee2f..727294cc9c6d0f2c1414471947c6c4e1f3138e8a 100644
--- a/research/cv/ResNeSt50/src/config.py
+++ b/research/cv/ResNeSt50/src/config.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@ from easydict import EasyDict as ed
 
 config_train = ed({
     "net_name": "resnest50",
-    "root": "/mass_data/imagenet/imagenet/",
+    "root": "/data1/datasets/imagenet/",
     "base_size": 224,
     "crop_size": 224,
     "num_classes": 1000,
@@ -28,7 +28,7 @@ config_train = ed({
     "final_drop": 1.0,
     "epochs": 270,
     "start_epoch": 0,
-    "num_workers": 64,
+    "num_workers": 50,
 
     "lr": 0.025,
     "steps_per_epoch": 1,
diff --git a/research/cv/ResNeSt50/src/models/resnet.py b/research/cv/ResNeSt50/src/models/resnet.py
index f8e839e46c297cbdd7689ebc48e3c6a4a498950b..e9e4554b9fc9457e3bddd27e43c4c4ae454fa9e6 100644
--- a/research/cv/ResNeSt50/src/models/resnet.py
+++ b/research/cv/ResNeSt50/src/models/resnet.py
@@ -18,7 +18,6 @@ from mindspore.common import initializer as init
 from mindspore import nn
 from mindspore.ops import operations as P
 from mindspore.train.serialization import load_param_into_net
-
 from src.models.splat import SplAtConv2d
 from src.models.utils import Resume
 
diff --git a/research/cv/ResNeSt50/src/models/splat.py b/research/cv/ResNeSt50/src/models/splat.py
index 37de96adfaef1d150ca01019aae611da1fd1d8a2..8e8e129402969d37a6dc9d2f5445af9197019d82 100644
--- a/research/cv/ResNeSt50/src/models/splat.py
+++ b/research/cv/ResNeSt50/src/models/splat.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -25,13 +25,11 @@ conv_weight_init = 'HeUniform'
 class GroupConv(nn.Cell):
     """
     group convolution operation.
-
     Args:
         in_channels (int): Input channels of feature map.
         out_channels (int): Output channels of feature map.
         kernel_size (int): Size of convolution kernel.
         stride (int): Stride size for the group convolution layer.
-
     Returns:
         tensor, output tensor.
     """
@@ -41,18 +39,28 @@ class GroupConv(nn.Cell):
         assert in_channels % group == 0 and out_channels % group == 0
         self.group = group
         self.convs = nn.CellList()
+
         self.op_split = P.Split(axis=1, output_num=self.group)
         self.op_concat = P.Concat(axis=1)
         self.cast = P.Cast()
+
         for _ in range(group):
             self.convs.append(nn.Conv2d(in_channels//group, out_channels//group,
                                         kernel_size=kernel_size, stride=stride, has_bias=has_bias,
                                         padding=padding, pad_mode=pad_mode, group=1, weight_init=conv_weight_init))
 
+
     def construct(self, x):
-        features = self.op_split(x)
+
+        if self.group > 1:
+            features = (x[:, 0:(x.shape[1] // 2), :, :], x[:, (x.shape[1] // 2):, :, :])
+        else:
+            features = (x,)
         outputs = ()
+
         for i in range(self.group):
+            if len(features[i].shape) < 4:
+                print("error")
             outputs = outputs + (self.convs[i](self.cast(features[i], mstype.float32)),)
         out = self.op_concat(outputs)
         return out
@@ -83,7 +91,9 @@ class SplAtConv2d(nn.Cell):
         self.rsoftmax = rSoftMax(radix, groups)
         self.reshape = P.Reshape()
         self.split = P.Split(axis=1, output_num=self.radix)
+        self.split1 = P.Split(axis=1, output_num=2)
         self.sum = P.AddN()
+        self.cast = P.Cast()
 
     def construct(self, x):
         """Split attention construct"""
@@ -92,11 +102,14 @@ class SplAtConv2d(nn.Cell):
             x = self.bn0(x)
         x = self.relu(x)
         batch = x.shape[0]
-        splited = self.split(x)
+        splited = ()
         if self.radix > 1:
+            if self.radix != 2:
+                print("error")
             outputs = ()
+            splited = (x[:, 0: (x.shape[1]// 2), :, :], x[:, (x.shape[1]// 2):, :, :])
             for i in range(self.radix):
-                outputs = outputs + (splited[i],)
+                outputs = outputs + (self.cast(splited[i], mstype.float32),)
             gap = self.sum(outputs)
         else:
             gap = x
@@ -111,7 +124,10 @@ class SplAtConv2d(nn.Cell):
         atten = self.fc2(gap)
         atten = self.rsoftmax(atten)
         atten = self.reshape(atten, (batch, -1, 1, 1))
-        attens = self.split(atten)
+        if self.radix > 1:
+            attens = (atten[:, 0:(atten.shape[1]//2), :, :], atten[:, (atten.shape[1]//2):, :, :])
+        else:
+            attens = (atten,)
 
         if self.radix > 1:
             outputs = ()
diff --git a/research/cv/ResNeSt50/train.py b/research/cv/ResNeSt50/train.py
index 3a47055c03eb1a8e83c4f29fc7781104df39d37e..67fc92d9d1b2271740f1644a0e224e1ee8935649 100644
--- a/research/cv/ResNeSt50/train.py
+++ b/research/cv/ResNeSt50/train.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,11 +19,11 @@ import time
 import datetime
 import ast
 import argparse
-
+import numpy as np
 from mindspore import Tensor, context, nn
 from mindspore.context import ParallelMode
 from mindspore.communication.management import init, get_rank, get_group_size
-from mindspore.train.callback import ModelCheckpoint
+from mindspore.train.callback import ModelCheckpoint, LossMonitor
 from mindspore.train.callback import CheckpointConfig, Callback, TimeMonitor
 from mindspore.train.model import Model
 from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
@@ -103,7 +103,7 @@ def Parse(arguments=None):
                         help='device where the code will be implemented (default: Ascend)')
     parser.add_argument('--resume', type=bool, default=False,
                         help='whether to resume the pretrained model')
-    parser.add_argument('--resume_path', type=str, default="/home/lidongsheng/ckpt/resnest-30_2502.ckpt",
+    parser.add_argument('--pretrained_ckpt_path', type=str, default="/home/lidongsheng/ckpt/resnest-30_2502.ckpt",
                         help='put the path to resuming file if needed')
     # training parameters
     parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
@@ -115,23 +115,55 @@ def Parse(arguments=None):
                         help="Evaluation interval when run_eval is True, default is 10.")
     parser.add_argument("--eval_start_epoch", type=int, default=1,
                         help="Evaluation start epoch when run_eval is True, default is 120.")
+    parser.add_argument("--device_num", type=int, default=1)
     # modelarts
     parser.add_argument('--is_model_arts', type=ast.literal_eval, default=False)
     parser.add_argument('--data_url', type=str)
     parser.add_argument('--train_url', type=str)
+
     arguments = parser.parse_args()
 
     return arguments
 
 set_seed(1)
 
+class LossCallBack(LossMonitor):
+    """
+    Monitor the loss in training.
+    If the loss in NAN or INF terminating training.
+    """
+
+    def __init__(self, has_trained_epoch=0):
+        super(LossCallBack, self).__init__()
+        self.has_trained_epoch = has_trained_epoch
+
+    def step_end(self, run_context):
+        cb_params = run_context.original_args()
+        loss1 = cb_params.net_outputs
+
+        if isinstance(loss1, (tuple, list)):
+            if isinstance(loss1[0], Tensor) and isinstance(loss1[0].asnumpy(), np.ndarray):
+                loss1 = loss1[0]
+
+        if isinstance(loss1, Tensor) and isinstance(loss1.asnumpy(), np.ndarray):
+            loss1 = np.mean(loss1.asnumpy())
+
+        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+
+        if isinstance(loss1, float) and (np.isnan(loss1) or np.isinf(loss1)):
+            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
+                cb_params.cur_epoch_num, cur_step_in_epoch))
+        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
+            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num + int(self.has_trained_epoch),
+                                                      cur_step_in_epoch, loss1), flush=True)
 if __name__ == "__main__":
     print("================Start training================")
 
     args = Parse()
     target = args.device_target
-    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
-
+    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=True)
+    context.set_context(enable_graph_kernel=True)
+    config.lr = config.lr*args.device_num
     if args.run_distribute:
         if target == "Ascend":
             device_id = int(os.getenv('DEVICE_ID'))
@@ -149,7 +181,7 @@ if __name__ == "__main__":
             init()
             config.rank = get_rank()
             config.group_size = get_group_size()
-            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+            context.set_auto_parallel_context(device_num=args.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                               gradients_mean=True)
     else:
         try:
@@ -162,6 +194,7 @@ if __name__ == "__main__":
             config.group_size = 1
 
     # dataset
+
     if args.is_model_arts:
         import moxing as mox
         train_dataset_path = '/cache/dataset/'
@@ -179,6 +212,7 @@ if __name__ == "__main__":
                            batch_size=config.batch_size, num_parallel_workers=config.num_workers)
     config.steps_per_epoch = dataset.get_dataset_size()
 
+
     # net
     model_kwargs = {}
     if config.final_drop > 0.0:
@@ -190,9 +224,9 @@ if __name__ == "__main__":
     if args.resume:
         if args.is_model_arts:
             pretrained_ckpt_path = "/cache/pretrained/resnest50-270_2502.ckpt"
-            mox.file.copy_parallel(args.resume_path, pretrained_ckpt_path)
+            mox.file.copy_parallel(args.pretrained_ckpt_path, pretrained_ckpt_path)
         else:
-            pretrained_ckpt_path = args.resume_path
+            pretrained_ckpt_path = args.pretrained_ckpt_path
         net = get_network(config.net_name, args.resume, pretrained_ckpt_path, **model_kwargs)
     else:
         net = get_network(config.net_name, **model_kwargs)
@@ -229,23 +263,28 @@ if __name__ == "__main__":
     time_cb = TimeMonitor(data_size=config.steps_per_epoch)
     callbacks.append(time_cb)
 
+    if target == "GPU":
+        loss_cb = LossCallBack(0)
+        callbacks.append(loss_cb)
+
     # eval
     if args.run_eval:
-        if config.root is None or not os.path.isdir(config.root):
-            raise ValueError("{} is not a existing path.".format(config.root))
-        eval_dataset = ImageNet(config.root, mode="val",
-                                img_size=config.base_size, crop_size=config.crop_size,
-                                rank=config.rank, group_size=config.group_size, epoch=1,
-                                batch_size=config.batch_size, num_parallel_workers=config.num_workers)
-        val_step_size = eval_dataset.get_dataset_size()
-        eval_param_dict = {"model": model, "dataset": eval_dataset, "metrics_name": "acc"}
-        eval_callback = EvalCallBack(apply_eval,
-                                     eval_param_dict,
-                                     interval=args.eval_interval,
-                                     eval_start_epoch=args.eval_start_epoch,
-                                     metrics_name="acc"
-                                     )
-        callbacks.append(eval_callback)
+        if args.device_target == "Ascend":
+            if config.root is None or not os.path.isdir(config.root):
+                raise ValueError("{} is not a existing path.".format(config.root))
+            eval_dataset = ImageNet(config.root, mode="val",
+                                    img_size=config.base_size, crop_size=config.crop_size,
+                                    rank=config.rank, group_size=config.group_size, epoch=1,
+                                    batch_size=config.batch_size, num_parallel_workers=config.num_workers)
+            val_step_size = eval_dataset.get_dataset_size()
+            eval_param_dict = {"model": model, "dataset": eval_dataset, "metrics_name": "acc"}
+            eval_callback = EvalCallBack(apply_eval,
+                                         eval_param_dict,
+                                         interval=args.eval_interval,
+                                         eval_start_epoch=args.eval_start_epoch,
+                                         metrics_name="acc"
+                                         )
+            callbacks.append(eval_callback)
 
     if args.save_ckpt and config.rank == 0:
         ckpt_config = CheckpointConfig(save_checkpoint_steps=config.steps_per_epoch,