diff --git a/official/cv/brdnet/README_CN.md b/official/cv/brdnet/README_CN.md index 2b24fa9da927505b9a648b0314821a29e492b8f6..2683fb912110d3ff14034a78fe489f0db55cb962 100644 --- a/official/cv/brdnet/README_CN.md +++ b/official/cv/brdnet/README_CN.md @@ -118,6 +118,7 @@ Ascend训练:生成[RANK_TABLE_FILE](https://gitee.com/mindspore/models/tree/m ├── scripts │ ├──run_distribute_train.sh // Ascend 8卡训练脚本 │ ├──run_eval.sh // 推理启动脚本 + │ ├──run_eval_onnx_gpu.sh // ONNX模型推理启动脚本 │ ├──run_train.sh // 训练启动脚本 │ ├──run_infer_310.sh // 启动310推理的脚本 │ ├──docker_start.sh // 使用 MindX 推理时的 docker 启动脚本 @@ -125,9 +126,11 @@ Ascend训练:生成[RANK_TABLE_FILE](https://gitee.com/mindspore/models/tree/m │ ├──dataset.py // 数据集处理 │ ├──logger.py // 日志打印文件 │ ├──models.py // 模型结构 + │ ├──models_onnx.py // 为适配ONNX关闭自动Padding的模型结构 ├── export.py // 将权重文件导出为 MINDIR 等格式的脚本 ├── train.py // 训练脚本 ├── eval.py // 推理脚本 + ├── eval_onnx.py // onnx模型推理脚本 ├── cal_psnr.py // 310推理时计算最终PSNR值的脚本 ├── preprocess.py // 310推理时为测试图片添加噪声的脚本 ``` @@ -170,6 +173,18 @@ eval.py 中的主要参数如下: --outer_path: 输出到 obs 外部的目录(仅在 modelarts 上运行时有效) --device_target: 运行设备(默认 "Ascend") +eval_onnx.py 中的主要参数如下: +--test_dir: 测试数据集路径(必须以"/"结尾)。 +--sigma: 高斯噪声强度 +--channel: 推理类型(3:彩色图;1:灰度图) +--onnx_name: 需要使用的brdnet的ONNX模型的路径 +--use_modelarts: 是否使用 modelarts(1 for True, 0 for False; 设置为 1 时将使用 moxing 从 obs 拷贝数据) +--train_url: ( modelsarts 需要的参数,但因该名称存在歧义而在代码中未使用) +--data_url: ( modelsarts 需要的参数,但因该名称存在歧义而在代码中未使用) +--output_path: 日志等文件输出目录 +--outer_path: 输出到 obs 外部的目录(仅在 modelarts 上运行时有效) +--device_target: 运行设备(默认 "GPU") + export.py 中的主要参数如下: --batch_size: 批次大小 --channel: 训练类型(3:彩色图;1:灰度图) @@ -402,7 +417,7 @@ cal_psnr.py 中的主要参数如下: --output_path=./output/ \ --is_distributed=0 > log.txt 2>&1 & - #通过 bash 命令启动评估 (对 test_dir 等参数的路径格式无要求,内部会自动转为绝对路径以及以"/"结尾) + #通过 bash 命令启动评估 bash run_eval.sh [train_code_path] [test_dir] [sigma] [channel] [pretrain_path] [ckpt_name] ``` @@ -489,6 +504,62 @@ cal_psnr.py 中的主要参数如下: | 67 | 68 | 24.61875916 | 32.24455261 | 0.623004258 | 0.97843051 | | 68 | Average | 24.61875916 | 34.05390495 | 0.555872787 | 0.935704286 | +- 导出ONNX模型 + + ```python + 假设基于waterloo数据集,图片分辨率为50*50 + 当前目录brdnet,运行: + python export.py --image_height=50 --image_width=50 --file_format=ONNX --device_target=GPU --ckpt_file=ckpt模型路径 + 即可得到brdnet.onnx文件 + ``` + +- ONNX模型评估 + + ```python + #通过直接运行brdnet文件夹下的eval_onnx.py脚本进行评估,例子如下: + 测试用的数据集存放于brdnet文件夹下的waterloo5050step40colorimage文件夹 + 导出的ONNX模型存放于brdnet文件夹下的brdnet.onnx文件 + 测试在GPU和CPU上运行通过,当前目录brdnet,运行: + python eval_onnx.py --test_dir=./waterloo5050step40colorimage/ --onnx_name=./brdnet.onnx --device_target=GPU + + !!!注意:其它可修改参数如sigma、channel、output_path等的修改办法 + 1. 可以通过直接修改default_config.yaml中对应参数的值进行修改 + 2. 在运行脚本时在后面加上 --参数名=设置值 进行修改 + ``` + + ```python + #通过 bash 命令运行brdnet/scripts/run_eval_onnx_gpu.sh启动评估 (对 test_dir 等参数的路径格式无要求,内部会自动转为绝对路径以及以"/"结尾) + bash run_eval_onnx_gpu.sh [ONNX_NAME] [TESTSET_PATH] + 具体例子 + 测试用的数据集存放于brdnet文件夹下的waterloo5050step40colorimage文件夹 + 导出的ONNX模型存放于brdnet文件夹下的brdnet.onnx文件 + 当前目录brdnet/scripts,运行: + bash run_eval_onnx_gpu.sh ../brdnet.onnx ../waterloo5050step40colorimage/ + 运行结果将保存于brdnet/eval/output文件夹下 + + !!!注意:使用run_eval_onnx_gpu.sh启动评估将在后台运行,可使用ps -u找到在后台运行评估的进程 + ``` + + ```python + 评估完成后,您可以在 --output_path 参数指定的目录下找到 加高斯噪声后的图片和经过模型去除高斯噪声后的图片,图片命名方式代表了处理结果。例如 00001_sigma15_psnr24.62.bmp 是加噪声后的图片(加噪声后 psnr=24.62),00001_psnr31.18.bmp 是去噪声后的图片(去噪后 psnr=31.18)。 + 同时,您还可以在 --output_path 参数指定的目录下找到测试结果记录文件,结果记录文件按照日期-时间的方式命名,如2022-07-14_time_16_29_22_rank_0.log,通过 bash 命令运行的结果默认放在brdnet/eval文件夹下。 + 打开log文件即可查看ONNX模型验证的结果,验证结果示例如下: + 2022-07-14 16:31:52,317:INFO:Start to test on /data1/models/official/cv/brdnet/waterloo5050step40colorimage + 2022-07-14 16:31:53,317:INFO:Start to test on load test weights from /data1/models/official/cv/brdnet/brdnet.onnx + 2022-07-14 16:31:54,528:INFO:start testing.... + 2022-07-14 18:42:41,267:INFO:Before denoise: Average PSNR_b = 24.712742, SSIM_b = 0.520469;After denoise: Average PSNR = 36.140112, SSIM = 0.943783 + 2022-07-14 18:42:41,294:INFO:testing finished.... + 2022-07-14 18:42:41,294:INFO:time cost: 7846.765776634216 seconds! + + 以下为官网原始权重brdnet_ascend_v170_waterloo_official_cv_PSNR36.14.ckpt的验证结果,可以看到计算结果相同。 + 2022-07-15 11:16:58,964:INFO:Start to test on ./waterloo5050step40colorimage/ + 2022-07-15 11:16:59,089:INFO:load test weights from brdnet_ascend_v170_waterloo_official_cv_PSNR36.14.ckpt + 2022-07-15 11:17:01,400:INFO:start testing.... + 2022-07-15 14:15:06,262:INFO:Before denoise: Average PSNR_b = 24.712742, SSIM_b = 0.520469;After denoise: Average PSNR = 36.140112, SSIM = 0.943783 + 2022-07-15 14:15:06,262:INFO:testing finished.... + 2022-07-15 14:15:06,262:INFO:time cost: 10684.862426519394 seconds! + ``` + ### 310推理 - 在 Ascend 310 处理器环境运行 diff --git a/official/cv/brdnet/default_config.yaml b/official/cv/brdnet/default_config.yaml index 6a43860ad80e6cd5e5658d3957108f666cb13f30..365b57dfcf1fad16bb33e72f468de9da9c94c299 100644 --- a/official/cv/brdnet/default_config.yaml +++ b/official/cv/brdnet/default_config.yaml @@ -16,6 +16,8 @@ output_path: './output/' outer_path: 's3://output/' test_dir: '' +onnx_name: '' + pretrain_path: '' ckpt_name: '' @@ -24,4 +26,4 @@ is_distributed: False rank: 0 group_size: 1 is_save_on_master: True -ckpt_save_max: 5 \ No newline at end of file +ckpt_save_max: 5 diff --git a/official/cv/brdnet/eval_onnx.py b/official/cv/brdnet/eval_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e3784459e2171a5f272c58cf64459b91352d25 --- /dev/null +++ b/official/cv/brdnet/eval_onnx.py @@ -0,0 +1,175 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import datetime +import os +import time +import glob +import pandas as pd +import numpy as np +import PIL.Image as Image + +import mindspore +import mindspore.nn as nn +from mindspore import context +from mindspore.common import set_seed +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor + +from src.logger import get_logger +from src.config import config as cfg + +import onnxruntime as ort + +def create_session(checkpoint_path, target_device): + if target_device == 'GPU': + providers = ['CUDAExecutionProvider'] + elif target_device == 'CPU': + providers = ['CPUExecutionProvider'] + else: + raise ValueError(f'Unsupported target device {target_device!r}. Expected one of: "CPU", "GPU"') + session = ort.InferenceSession(checkpoint_path, providers=providers) + input_name = session.get_inputs()[0].name + return session, input_name + +def test(model_path): + session, input_name = create_session(model_path, cfg.device_target) + cfg.logger.info('Start to test on %s', str(cfg.test_dir)) + out_dir = os.path.join(save_dir, cfg.test_dir.split('/')[-2]) # cfg.test_dir must end by '/' + # print(out_dir) + if not cfg.use_modelarts and not os.path.exists(out_dir): + os.makedirs(out_dir) + cfg.logger.info('load test weights from %s', str(model_path)) + + name = [] + psnr = [] #after denoise + ssim = [] #after denoise + psnr_b = [] #before denoise + ssim_b = [] #before denoise + + if cfg.use_modelarts: + cfg.logger.info("copying test dataset from obs to cache....") + mox.file.copy_parallel(cfg.test_dir, 'cache/test') + cfg.logger.info("copying test dataset finished....") + cfg.test_dir = 'cache/test/' + + file_list = glob.glob(os.path.join(cfg.test_dir, "*")) + + cast = P.Cast() + compare_psnr = nn.PSNR() + compare_ssim = nn.SSIM() + + cfg.logger.info("start testing....") + start_time = time.time() + for file in file_list: + suffix = file.split('.')[-1] + # read image + if cfg.channel == 3: + img_clean = np.array(Image.open(file), dtype='float32') / 255.0 + else: + img_clean = np.expand_dims(np.array(Image.open(file).convert('L'), \ + dtype='float32') / 255.0, axis=2) + np.random.seed(0) #obtain the same random data when it is in the test phase + img_test = img_clean + np.random.normal(0, cfg.sigma/255.0, img_clean.shape) + + + img_clean = np.float32(np.expand_dims(np.transpose(img_clean, (2, 0, 1)), 0)) + img_test = np.float32(np.expand_dims(np.transpose(img_test, (2, 0, 1)), 0)) + + y_predict = session.run(None, {input_name: img_test})[0] + # print(y_predict.shape) + img_clean = Tensor(img_clean, mindspore.float32) + img_test = Tensor(img_test, mindspore.float32) + y_predict = Tensor(y_predict, mindspore.float32) + + img_out = C.clip_by_value(y_predict, 0, 1) + + psnr_noise, psnr_denoised = compare_psnr(img_clean, img_test), compare_psnr(img_clean, img_out) + ssim_noise, ssim_denoised = compare_ssim(img_clean, img_test), compare_ssim(img_clean, img_out) + + + psnr.append(psnr_denoised.asnumpy()[0]) + ssim.append(ssim_denoised.asnumpy()[0]) + psnr_b.append(psnr_noise.asnumpy()[0]) + ssim_b.append(ssim_noise.asnumpy()[0]) + + # save images + filename = file.split('/')[-1].split('.')[0] # get the name of image file + name.append(filename) + + if not cfg.use_modelarts: + # inner the operation 'Image.save', it will first check the file \# + # existence of same name, that is not allowed on modelarts + img_test = cast(img_test*255, mindspore.uint8).asnumpy() + img_test = img_test.squeeze(0).transpose((1, 2, 0)) #turn into HWC to save as an image + img_test = Image.fromarray(img_test) + img_test.save(os.path.join(out_dir, filename+'_sigma'+'{}_psnr{:.2f}.'\ + .format(cfg.sigma, psnr_noise.asnumpy()[0])+str(suffix))) + img_out = cast(img_out*255, mindspore.uint8).asnumpy() + img_out = img_out.squeeze(0).transpose((1, 2, 0)) #turn into HWC to save as an image + img_out = Image.fromarray(img_out) + img_out.save(os.path.join(out_dir, filename+'_psnr{:.2f}.'.format(psnr_denoised.asnumpy()[0])+str(suffix))) + + psnr_avg = sum(psnr)/len(psnr) + ssim_avg = sum(ssim)/len(ssim) + psnr_avg_b = sum(psnr_b)/len(psnr_b) + ssim_avg_b = sum(ssim_b)/len(ssim_b) + name.append('Average') + psnr.append(psnr_avg) + ssim.append(ssim_avg) + psnr_b.append(psnr_avg_b) + ssim_b.append(ssim_avg_b) + cfg.logger.info('Before denoise: Average PSNR_b = {0:.6f}, SSIM_b = {1:.6f};After denoise: Average PSNR = {2:.6f},' + ' SSIM = {3:.6f}'.format(psnr_avg_b, ssim_avg_b, psnr_avg, ssim_avg)) + cfg.logger.info("testing finished....") + time_used = time.time() - start_time + cfg.logger.info("time cost: %s seconds!", str(time_used)) + if not cfg.use_modelarts: + pd.DataFrame({'name': np.array(name), 'psnr_b': np.array(psnr_b), \ + 'psnr': np.array(psnr), 'ssim_b': np.array(ssim_b), \ + 'ssim': np.array(ssim)}).to_csv(out_dir+'/metrics.csv', index=True) + +if __name__ == '__main__': + set_seed(1) #设置种子 + + save_dir = os.path.join(cfg.output_path, 'sigma_' + str(cfg.sigma) + \ + '_' + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) + + if not cfg.use_modelarts and not os.path.exists(save_dir): + os.makedirs(save_dir) + + context.set_context(mode=context.GRAPH_MODE, + device_target=cfg.device_target, save_graphs=False) + + if cfg.device_target == "Ascend": + device_id = int(os.getenv('DEVICE_ID', '0')) + context.set_context(device_id=device_id) + + cfg.logger = get_logger(save_dir, "BRDNet", 0) + cfg.logger.save_args(cfg) + + if cfg.use_modelarts: + import moxing as mox + cfg.logger.info("copying test weights from obs to cache....") + mox.file.copy_parallel(cfg.pretrain_path, 'cache/weight') + cfg.logger.info("copying test weights finished....") + cfg.pretrain_path = 'cache/weight/' + + test(cfg.onnx_name) # onnx model url + + if cfg.use_modelarts: + cfg.logger.info("copying files from cache to obs....") + mox.file.copy_parallel(save_dir, cfg.outer_path) + cfg.logger.info("copying finished....") diff --git a/official/cv/brdnet/export.py b/official/cv/brdnet/export.py index 569915f985b1c527669d3c1855e13d37a3ef003a..6c5c94ab1e73ec17edb02a0436fed26a1708c214 100644 --- a/official/cv/brdnet/export.py +++ b/official/cv/brdnet/export.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ import mindspore as ms from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export from src.models import BRDNet - +from src.models_onnx import BRDNet_onnx ## Params parser = argparse.ArgumentParser() @@ -42,11 +42,14 @@ parser.add_argument("--device_id", type=int, default=0, help="Device id") args_opt = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, + device_id=args_opt.device_id) if __name__ == '__main__': - - net = BRDNet(args_opt.channel) + if args_opt.file_format == "ONNX": + net = BRDNet_onnx(args_opt.channel) + else: + net = BRDNet(args_opt.channel) param_dict = load_checkpoint(args_opt.ckpt_file) load_param_into_net(net, param_dict) diff --git a/official/cv/brdnet/scripts/run_eval_onnx_gpu.sh b/official/cv/brdnet/scripts/run_eval_onnx_gpu.sh new file mode 100644 index 0000000000000000000000000000000000000000..c9ace871f003104419e8ab3327029bee8b48ca51 --- /dev/null +++ b/official/cv/brdnet/scripts/run_eval_onnx_gpu.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ]; then + echo "Usage: + bash run_eval_onnx_gpu.sh [ONNX_NAME] [TESTSET_PATH] + " + exit 1 +fi + +get_real_path() { + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" +fi +} + +BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd) + +ONNX_NAME=$(get_real_path $1) +echo "ONNX_NAME: "$ONNX_NAME + +TESTSET_PATH=$(get_real_path $2) +echo "TESTSET_PATH: "$TESTSET_PATH + + +if [ ! -f $ONNX_NAME ] +then + echo "error: ONNX_NAME=$ONNX_NAME is not a file." +exit 1 +fi + +if [ ! -d $TESTSET_PATH ] +then + echo "error: TESTSET_PATH=$TESTSET_PATH is not a directory." +exit 1 +fi + +export PYTHONPATH=${BASE_PATH}:$PYTHONPATH + +if [ -d "../eval" ]; +then + rm -rf ../eval +fi +mkdir ../eval +cd ../eval || exit + +echo "Evaluating on GPU..." +echo +env > env.log +pwd +echo +python ${BASE_PATH}/../eval_onnx.py --onnx_name=$ONNX_NAME --test_dir=$TESTSET_PATH --device_target=GPU &> eval.log & diff --git a/official/cv/brdnet/src/models_onnx.py b/official/cv/brdnet/src/models_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..fff420cb97cd670024cc017ced631c04b40a41a1 --- /dev/null +++ b/official/cv/brdnet/src/models_onnx.py @@ -0,0 +1,149 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import mindspore.nn as nn +import mindspore.ops as ops +#from batch_renorm import BatchRenormalization + +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore import context +from mindspore.context import ParallelMode +from mindspore.parallel._auto_parallel_context import auto_parallel_context +from mindspore.communication.management import get_group_size +from mindspore.ops import functional as F + +class BRDNet_onnx(nn.Cell): + """ + args: + channel: 3 for color, 1 for gray + """ + def __init__(self, channel): + super(BRDNet_onnx, self).__init__() + + self.Conv2d_1 = nn.Conv2d(channel, 64, kernel_size=(3, 3), stride=(1, 1), has_bias=True) + self.BRN_1 = nn.BatchNorm2d(64, eps=1e-3) + self.layer1 = self.make_layer1(15) + self.Conv2d_2 = nn.Conv2d(64, channel, kernel_size=(3, 3), stride=(1, 1), has_bias=True) + self.Conv2d_3 = nn.Conv2d(channel, 64, kernel_size=(3, 3), stride=(1, 1), has_bias=True) + self.BRN_2 = nn.BatchNorm2d(64, eps=1e-3) + self.layer2 = self.make_layer2(7) + self.Conv2d_4 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), has_bias=True) + self.BRN_3 = nn.BatchNorm2d(64, eps=1e-3) + self.layer3 = self.make_layer2(6) + self.Conv2d_5 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), has_bias=True) + self.BRN_4 = nn.BatchNorm2d(64, eps=1e-3) + self.Conv2d_6 = nn.Conv2d(64, channel, kernel_size=(3, 3), stride=(1, 1), has_bias=True) + self.Conv2d_7 = nn.Conv2d(channel*2, channel, kernel_size=(3, 3), stride=(1, 1), has_bias=True) + self.relu = nn.ReLU() + self.sub = ops.Sub() + self.concat = ops.Concat(axis=1)#NCHW + + def make_layer1(self, nums): + layers = [] + assert nums > 0 + for _ in range(nums): + layers.append(nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), has_bias=True)) + layers.append(nn.BatchNorm2d(64, eps=1e-3)) + layers.append(nn.ReLU()) + return nn.SequentialCell(layers) + + def make_layer2(self, nums): + layers = [] + assert nums > 0 + for _ in range(nums): + layers.append(nn.Conv2d(64, 64, kernel_size=(3, 3), has_bias=True)) + layers.append(nn.ReLU()) + return nn.SequentialCell(layers) + + def construct(self, inpt): + #inpt-----> 'NCHW' + x = self.Conv2d_1(inpt) + x = self.BRN_1(x) + x = self.relu(x) + # 15 layers, Conv+BN+relu + x = self.layer1(x) + + # last layer, Conv + x = self.Conv2d_2(x) #for output channel, gray is 1 color is 3 + x = self.sub(inpt, x) # input - noise + + y = self.Conv2d_3(inpt) + y = self.BRN_2(y) + y = self.relu(y) + + # first Conv+relu's + y = self.layer2(y) + + y = self.Conv2d_4(y) + y = self.BRN_3(y) + y = self.relu(y) + + # second Conv+relu's + y = self.layer3(y) + + y = self.Conv2d_5(y) + y = self.BRN_4(y) + y = self.relu(y) + + y = self.Conv2d_6(y)#for output channel, gray is 1 color is 3 + y = self.sub(inpt, y) # input - noise + + o = self.concat((x, y)) + z = self.Conv2d_7(o)#gray is 1 color is 3 + z = self.sub(inpt, z) + + return z + +class BRDWithLossCell(nn.Cell): + def __init__(self, network): + super(BRDWithLossCell, self).__init__() + self.network = network + self.loss = nn.MSELoss(reduction='sum') #we use 'sum' instead of 'mean' to avoid + #the loss becoming too small + def construct(self, images, targets): + output = self.network(images) + return self.loss(output, targets) + +class TrainingWrapper(nn.Cell): + """Training wrapper.""" + def __init__(self, network, optimizer, sens=1.0): + super(TrainingWrapper, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation(get_by_list=True, sens_param=True) + self.sens = sens + self.reducer_flag = False + self.grad_reducer = None + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + if self.reducer_flag: + mean = context.get_auto_parallel_context("gradients_mean") + if auto_parallel_context().get_device_num_is_set(): + degree = context.get_auto_parallel_context("device_num") + else: + degree = get_group_size() + self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) + + def construct(self, *args): + weights = self.weights + loss = self.network(*args) + sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) + grads = self.grad(self.network, weights)(*args, sens) + if self.reducer_flag: + grads = self.grad_reducer(grads) + return F.depend(loss, self.optimizer(grads))