diff --git a/research/cv/wdsr/README_CN.md b/research/cv/wdsr/README_CN.md
index 8c9cb7faa51e8fb3d0059b2133eded4031ca74b0..60ee554f8f36167d95dc207f662687f8758ab56b 100644
--- a/research/cv/wdsr/README_CN.md
+++ b/research/cv/wdsr/README_CN.md
@@ -111,7 +111,7 @@ The WDSR network is built from a few basic blocks (including convolution and pooling layers).
 
 # Environment Requirements
 
-- Hardware (Ascend)
+- Hardware (Ascend/GPU)
     - Set up the hardware environment with Ascend or GPU processors.
 - Framework
     - [MindSpore](https://www.mindspore.cn/install/en)
@@ -124,18 +124,24 @@ The WDSR network is built from a few basic blocks (including convolution and pooling layers).
 After installing MindSpore from the official website, you can follow the steps below for training and evaluation:
 
 ```shell
-#standalone training
+# standalone training
+# Ascend
 sh run_ascend_standalone.sh [TRAIN_DATA_DIR]
+# GPU
+bash run_gpu_standalone.sh [TRAIN_DATA_DIR]
 ```
 
 ```shell
-#distributed training
+# distributed training
+# Ascend
 sh run_ascend_distribute.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
+# GPU
+bash run_gpu_distribute.sh [TRAIN_DATA_DIR] [DEVICE_NUM]
 ```
 
 ```shell
 # evaluation
-sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]
+bash run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]
 ```
 
 # Script Description
 
@@ -145,11 +151,11 @@ sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]
 ```bash
 WDSR
 ├── README_CN.md                    // README file
-├── eval.py                         // evaluation script
-├── export.py
 ├── script
 │   ├── run_ascend_distribute.sh    // Ascend distributed training shell script
 │   ├── run_ascend_standalone.sh    // Ascend standalone training shell script
+│   ├── run_gpu_distribute.sh       // GPU distributed training shell script
+│   ├── run_gpu_standalone.sh       // GPU standalone training shell script
 │   └── run_eval.sh                 // evaluation shell script
 ├── src
 │   ├── args.py                     // hyper-parameters
@@ -160,8 +166,10 @@ WDSR
 │   │   └── srdata.py               // all datasets
 │   ├── metrics.py                  // PSNR and SSIM calculators
 │   ├── model.py                    // WDSR network
-│   └── utils.py                    // training script
-└── train.py                        // training script
+│   └── utils.py                    // helper functions
+├── train.py                        // training script
+├── eval.py                         // evaluation script
+└── export.py
 ```
 
 ## Script Parameters
@@ -169,43 +177,30 @@ WDSR
 The main parameters are as follows:
 
 ```python
-  -h, --help            show this help message and exit
-  --dir_data DIR_DATA   dataset directory
-  --data_train DATA_TRAIN
-                        train dataset name
-  --data_test DATA_TEST
-                        test dataset name
-  --data_range DATA_RANGE
-                        train/test data range
-  --ext EXT             dataset file extension
-  --scale SCALE         super resolution scale
-  --patch_size PATCH_SIZE
-                        output patch size
-  --rgb_range RGB_RANGE
-                        maximum value of RGB
-  --n_colors N_COLORS   number of color channels to use
-  --no_augment          do not use data augmentation
-  --model MODEL         model name
-  --n_resblocks N_RESBLOCKS
-                        number of residual blocks
-  --n_feats N_FEATS     number of feature maps
-  --res_scale RES_SCALE
-                        residual scaling
-  --test_every TEST_EVERY
-                        do test per every N batches
-  --epochs EPOCHS       number of epochs to train
-  --batch_size BATCH_SIZE
-                        input batch size for training
-  --test_only           set this option to test the model
-  --lr LR               learning rate
-  --ckpt_save_path CKPT_SAVE_PATH
-                        path to save ckpt
-  --ckpt_save_interval CKPT_SAVE_INTERVAL
-                        save ckpt frequency, unit is epoch
-  --ckpt_save_max CKPT_SAVE_MAX
-                        max number of saved ckpt
-  --ckpt_path CKPT_PATH
-                        path of saved ckpt
+  -h, --help                               show this help message and exit
+  --dir_data DIR_DATA                      dataset directory
+  --data_train DATA_TRAIN                  train dataset name
+  --data_test DATA_TEST                    test dataset name
+  --data_range DATA_RANGE                  train/test data range
+  --ext EXT                                dataset file extension
+  --scale SCALE                            super-resolution scale
+  --patch_size PATCH_SIZE                  output patch size
+  --rgb_range RGB_RANGE                    maximum value of RGB
+  --n_colors N_COLORS                      number of color channels to use
+  --no_augment                             do not use data augmentation
+  --model MODEL                            model name
+  --n_resblocks N_RESBLOCKS                number of residual blocks
+  --n_feats N_FEATS                        number of feature maps
+  --res_scale RES_SCALE                    residual scaling
+  --test_every TEST_EVERY                  do test per every N batches
+  --epochs EPOCHS                          number of epochs to train
+  --batch_size BATCH_SIZE                  input batch size for training
+  --test_only                              set this option to test the model
+  --lr LR                                  learning rate
+  --ckpt_path CKPT_PATH                    path of saved ckpt
+  --ckpt_save_path CKPT_SAVE_PATH          path to save ckpt
+  --ckpt_save_interval CKPT_SAVE_INTERVAL  save ckpt frequency, unit is epoch
+  --ckpt_save_max CKPT_SAVE_MAX            max number of saved ckpt
   --task_id TASK_ID
 ```
 
@@ -220,6 +215,12 @@ WDSR
   sh run_ascend_standalone.sh [TRAIN_DATA_DIR]
   ```
 
+- Running on GPU
+
+  ```bash
+  bash run_gpu_standalone.sh [TRAIN_DATA_DIR]
+  ```
+
 The above command will run in the background; you can check the results in the train.log file.
 
 ### Distributed Training
@@ -230,6 +231,12 @@ WDSR
   sh run_ascend_distribute.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
   ```
 
+- Running on GPU
+
+  ```bash
+  bash run_gpu_distribute.sh [TRAIN_DATA_DIR] [DEVICE_NUM]
+  ```
+
 TRAIN_DATA_DIR = "~DATA/".
 
 ## Evaluation Process
@@ -260,33 +267,33 @@ FILE_FORMAT can be chosen from ['MINDIR', 'AIR', 'ONNX']; default is ['MINDIR'].
 
 ### Training Performance
 
-| Parameters                 | Ascend                                          |
-| -------------------------- | ----------------------------------------------- |
-| Resource                   | Ascend 910                                      |
-| Uploaded date              | 2021-7-4                                        |
-| MindSpore version          | 1.2.0                                           |
-| Dataset                    | DIV2K                                           |
-| Training parameters        | epoch=1000, steps=100, batch_size=16, lr=0.0001 |
-| Optimizer                  | Adam                                            |
-| Loss function              | L1                                              |
-| Outputs                    | super-resolution images                         |
-| Loss                       | 3.5                                             |
-| Speed                      | 8 devices: ~130 ms/step                         |
-| Total time                 | 8 devices: 0.5 h                                |
-| Checkpoint for fine-tuning | 35 MB (.ckpt file)                              |
-| Script                     | [WDSR](https://gitee.com/mindspore/models/tree/master/research/cv/wdsr) |
+| Parameters                 | Ascend                                          | GPU                                 |
+| -------------------------- | ----------------------------------------------- | ----------------------------------- |
+| Resource                   | Ascend 910                                      | NVIDIA GeForce RTX 3090             |
+| Uploaded date              | 2021-7-4                                        | 2021-11-22                          |
+| MindSpore version          | 1.2.0                                           | 1.5.0                               |
+| Dataset                    | DIV2K                                           | DIV2K                               |
+| Training parameters        | epoch=1000, steps=100, batch_size=16, lr=0.0001 | epoch=300, batch_size=16, lr=0.0005 |
+| Optimizer                  | Adam                                            | Adam                                |
+| Loss function              | L1                                              | L1                                  |
+| Outputs                    | super-resolution images                         | super-resolution images             |
+| Loss                       | 3.5                                             | 3.3                                 |
+| Speed                      | 8 devices: ~130 ms/step                         | 8 devices: ~140 ms/step             |
+| Total time                 | 8 devices: 0.5 h                                | 8 devices: 1.5 h                    |
+| Checkpoint for fine-tuning | 35 MB (.ckpt file)                              | 14 MB (.ckpt file)                  |
+| Script                     | [WDSR](https://gitee.com/mindspore/models/tree/master/research/cv/wdsr) | [WDSR](https://gitee.com/mindspore/models/tree/master/research/cv/wdsr) |
 
 ### Evaluation Performance
 
-| Parameters        | Ascend                  |
-| ----------------- | ----------------------- |
-| Resource          | Ascend 910              |
-| Uploaded date     | 2021-7-4                |
-| MindSpore version | 1.2.0                   |
-| Dataset           | DIV2K                   |
-| batch_size        | 1                       |
-| Outputs           | super-resolution images |
-| PSNR              | DIV2K 34.7780           |
+| Parameters        | Ascend                  | GPU                     |
+| ----------------- | ----------------------- | ----------------------- |
+| Resource          | Ascend 910              | NVIDIA GeForce RTX 3090 |
+| Uploaded date     | 2021-7-4                | 2021-11-22              |
+| MindSpore version | 1.2.0                   | 1.5.0                   |
+| Dataset           | DIV2K                   | DIV2K                   |
+| batch_size        | 1                       | 1                       |
+| Outputs           | super-resolution images | super-resolution images |
+| PSNR              | DIV2K 34.7780           | DIV2K 35.9735           |
 
 # Description of Random Situation
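The PSNR figures in the tables above (34.7780 dB on Ascend, 35.9735 dB on GPU) come from the repository's `src/metrics.py` (`calc_psnr`/`quantize`), which this patch does not touch. Below is a minimal NumPy sketch of the metric for orientation only; the repository version may additionally shave image borders by the scale factor before comparing.

```python
# Minimal PSNR sketch (NumPy). Illustrative only: the repository computes
# this in src/metrics.py via calc_psnr/quantize, whose exact signatures
# are not shown in this patch.
import numpy as np

def quantize(img, rgb_range=255):
    """Clamp and round a super-resolved image to valid pixel values."""
    return np.clip(np.round(img), 0, rgb_range)

def psnr(sr, hr, rgb_range=255):
    """Peak signal-to-noise ratio (dB) between SR output and HR ground truth."""
    diff = (quantize(sr, rgb_range) - hr).astype(np.float64)
    mse = np.mean(diff ** 2)
    return float('inf') if mse == 0 else 10 * np.log10(rgb_range ** 2 / mse)

hr = np.random.randint(0, 256, (3, 48, 48)).astype(np.float64)
sr = hr + np.random.normal(0, 4.0, hr.shape)  # stand-in for reconstruction error
print(f"PSNR: {psnr(sr, hr):.4f} dB")
```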
diff --git a/research/cv/wdsr/eval.py b/research/cv/wdsr/eval.py
index 97a785fc6c0e98a12e5c545f8a433834cbdd34a5..2f2c0f428db3719b93b6858475158f3b1fba95b6 100644
--- a/research/cv/wdsr/eval.py
+++ b/research/cv/wdsr/eval.py
@@ -20,13 +20,21 @@ from mindspore import Tensor, context
 from mindspore.common import dtype as mstype
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.args import args
-import src.model as wdsr
+from src.model import WDSR
 from src.data.srdata import SRData
 from src.data.div2k import DIV2K
 from src.metrics import calc_psnr, quantize, calc_ssim
-device_id = int(os.getenv('DEVICE_ID', '0'))
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
-context.set_context(max_call_depth=10000)
+
+if args.device_target == 'GPU':
+    context.set_context(mode=context.GRAPH_MODE,
+                        device_target=args.device_target,
+                        save_graphs=False)
+    context.set_context(max_call_depth=10000)
+elif args.device_target == 'Ascend':
+    device_id = int(os.getenv('DEVICE_ID', '0'))
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
+    context.set_context(max_call_depth=10000)
+
 def eval_net():
     """eval"""
     if args.epochs == 0:
@@ -43,7 +51,7 @@ def eval_net():
     train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR'], shuffle=False)
     train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
     train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
-    net_m = wdsr.WDSR()
+    net_m = WDSR(scale=args.scale[args.task_id], n_resblocks=args.n_resblocks, n_feats=args.n_feats)
     if args.ckpt_path:
         param_dict = load_checkpoint(args.ckpt_path)
         load_param_into_net(net_m, param_dict)
diff --git a/research/cv/wdsr/script/run_eval.sh b/research/cv/wdsr/script/run_eval.sh
index f8828810f9e9398794e3cff33fa3fd1150b0695a..2942fec300e7b9a85c983132320cce3a9eb8eb98 100644
--- a/research/cv/wdsr/script/run_eval.sh
+++ b/research/cv/wdsr/script/run_eval.sh
@@ -57,5 +57,6 @@ python eval.py \
     --ext "img" \
     --data_test=${DATASET_TYPE} \
     --ckpt_path=${PATH2} \
+    --data_range "801-900" \
     --task_id 0 \
     --scale 2 > eval.log 2>&1 &
diff --git a/research/cv/wdsr/script/run_gpu_distribute.sh b/research/cv/wdsr/script/run_gpu_distribute.sh
new file mode 100644
index 0000000000000000000000000000000000000000..836d431e1c5fa647300453f0d3c2415b61d768aa
--- /dev/null
+++ b/research/cv/wdsr/script/run_gpu_distribute.sh
@@ -0,0 +1,59 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 2 ]; then
+    echo "Usage: bash run_gpu_distribute.sh [TRAIN_DATA_DIR] [DEVICE_NUM]"
+    exit 1
+fi
+
+get_real_path() {
+    if [ "${1:0:1}" == "/" ]; then
+        echo "$1"
+    else
+        echo "$(realpath -m $PWD/$1)"
+    fi
+}
+
+PATH1=$(get_real_path $1)
+DEVICE_NUM=$2
+
+if [ ! -d $PATH1 ]; then
+    echo "error: TRAIN_DATA_DIR=$PATH1 is not a directory"
+    exit 1
+fi
+
+if [ -d "train_parallel" ]; then
+    rm -rf ./train_parallel
+fi
+mkdir ./train_parallel
+cp ../*.py ./train_parallel
+cp -r ../src ./train_parallel
+cd ./train_parallel || exit
+
+env >env.log
+
+nohup mpirun --allow-run-as-root -n $DEVICE_NUM \
+python train.py \
+    --run_distribute 1 \
+    --device_num $DEVICE_NUM \
+    --batch_size 16 \
+    --lr 5e-4 \
+    --scale 2 \
+    --task_id 0 \
+    --dir_data $PATH1 \
+    --epochs 300 \
+    --test_every 1000 \
+    --patch_size 48 > train.log 2>&1 &
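`run_gpu_distribute.sh` starts one training process per GPU via `mpirun`; each process then reads a disjoint shard of DIV2K through `GeneratorDataset(num_shards=..., shard_id=...)` (see the `train.py` hunk at the end of this patch). A plain-Python sketch of that split follows; the round-robin assignment is an assumption used only to show that the shards are disjoint and balanced, not MindSpore's documented sampling order.

```python
# Sketch of the per-rank data split behind GeneratorDataset(num_shards, shard_id).
# 800 samples matches the DIV2K train range "1-800"; the round-robin policy is
# an illustrative assumption.
num_samples = 800
device_num = 8  # one process per GPU under `mpirun -n 8`

shards = {rank: list(range(rank, num_samples, device_num))
          for rank in range(device_num)}

for rank in (0, 1, 7):
    print(f"rank {rank}: {len(shards[rank])} samples, first 3 -> {shards[rank][:3]}")

# Every sample lands in exactly one shard:
assert sorted(i for s in shards.values() for i in s) == list(range(num_samples))
```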
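Both launch scripts pass `--lr` and `--epochs` straight to `train.py`, which builds a per-step learning-rate list that halves every `--decay` epochs (see the `train.py` hunk at the end of this patch). The sketch below replays that schedule with the standalone defaults; `steps_per_epoch` is a placeholder, since `train.py` derives it from the dataset size.

```python
# Reproduces the stepped schedule from train.py:
#   cur_lr = args.lr / (2 ** ((i + 1) // args.decay))
# base_lr/epochs mirror run_gpu_standalone.sh (--lr 1e-4, --epochs 300);
# decay=200 is the args.py default; steps_per_epoch is a placeholder.
base_lr, epochs, decay = 1e-4, 300, 200
steps_per_epoch = 50

lr_per_step = []
for i in range(epochs):
    cur_lr = base_lr / (2 ** ((i + 1) // decay))
    lr_per_step.extend([cur_lr] * steps_per_epoch)

print(lr_per_step[0])                      # 1e-4 for epochs 1-199
print(lr_per_step[199 * steps_per_epoch])  # 5e-5 from epoch 200 onward
```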
-d $PATH1 ]; then + echo "error: TRAIN_DATA_DIR=$PATH1 is not a directory" + exit 1 +fi + + +if [ -d "train" ]; then + rm -rf ./train +fi +mkdir ./train +cp ../*.py ./train +cp -r ../src ./train +cd ./train || exit + +env >env.log + +nohup python train.py \ + --batch_size 16 \ + --lr 1e-4 \ + --scale 2 \ + --task_id 0 \ + --dir_data $PATH1 \ + --epochs 300 \ + --test_every 1000 \ + --patch_size 48 > train.log 2>&1 & diff --git a/research/cv/wdsr/src/args.py b/research/cv/wdsr/src/args.py index 70b3bd43d08cfce76a665e8d306f7fadce39d4ad..358bd3ec74c63a90b138ab64e26c2c78f9c2f243 100644 --- a/research/cv/wdsr/src/args.py +++ b/research/cv/wdsr/src/args.py @@ -22,7 +22,7 @@ parser.add_argument('--data_train', type=str, default='DIV2K', help='train dataset name') parser.add_argument('--data_test', type=str, default='DIV2K', help='test dataset name') -parser.add_argument('--data_range', type=str, default='1-800/801-900', +parser.add_argument('--data_range', type=str, default='1-800/801-810', help='train/test data range') parser.add_argument('--ext', type=str, default='sep', help='dataset file extension') @@ -61,7 +61,7 @@ parser.add_argument('--init_loss_scale', type=float, default=65536., help='scaling factor') parser.add_argument('--loss_scale', type=float, default=1024.0, help='loss_scale') -parser.add_argument('--decay', type=str, default='200', +parser.add_argument('--decay', type=int, default=200, help='learning rate decay type') parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), help='ADAM beta') @@ -80,8 +80,21 @@ parser.add_argument('--ckpt_save_max', type=int, default=5, help='max number of saved ckpt') parser.add_argument('--ckpt_path', type=str, default='', help='path of saved ckpt') +# sr result specifications +parser.add_argument('--save_dir', type=str, default='result', + help='file name to save') +parser.add_argument('--save_result', action='store_true', + help='save output results') # alltask parser.add_argument('--task_id', type=int, default=0) +parser.add_argument('--pre_trained', type=str, default='', help='model_path, local pretrained model to load') +parser.add_argument('--device_target', type=str, default='GPU', choices=("GPU"), + help="Device target, support GPU.") +parser.add_argument("--run_distribute", type=int, default=False, + help="Run distribute, default: false.") +parser.add_argument('--device_num', type=int, default=1, help='Device num.') +parser.add_argument("--device_id", type=int, default=0, help="device id, default: 0.") +parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.") # rgb_mean parser.add_argument('--r_mean', type=float, default=0.4488, help='r_mean') @@ -89,6 +102,7 @@ parser.add_argument('--g_mean', type=float, default=0.4371, help='g_mean') parser.add_argument('--b_mean', type=float, default=0.4040, help='b_mean') + args, unparsed = parser.parse_known_args() args.scale = [int(x) for x in args.scale.split("+")] args.data_train = args.data_train.split('+') diff --git a/research/cv/wdsr/src/model.py b/research/cv/wdsr/src/model.py index 50922dcaa50fa3a6a1d42a5e026378e55303a04e..b10a950aafd178b4910dabc1fbb1ff28bcc79b07 100644 --- a/research/cv/wdsr/src/model.py +++ b/research/cv/wdsr/src/model.py @@ -65,11 +65,8 @@ class PixelShuffle(nn.Cell): class WDSR(nn.Cell): """main structure of wdsr""" - def __init__(self): + def __init__(self, scale=2, n_resblocks=8, n_feats=64): super(WDSR, self).__init__() - scale = 2 - n_resblocks = 8 - n_feats = 64 self.sub_mean = MeanShift(255) self.add_mean = MeanShift(255, sign=1) 
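With these changes, `args.py` parses `--scale` into a list (`args.scale = [int(x) for x in args.scale.split("+")]`) and `WDSR` takes `scale`/`n_resblocks`/`n_feats` as constructor arguments, so `train.py` and `eval.py` can pick the scale by `--task_id`. A sketch of that selection logic: the `--scale` default of `'2+3+4'` is hypothetical (the real default is not shown in this patch), while `n_resblocks=8` and `n_feats=64` match the values `model.py` previously hard-coded. The constructor call is printed as a string so the snippet runs without MindSpore.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--scale', type=str, default='2+3+4')  # hypothetical default
parser.add_argument('--task_id', type=int, default=0)
parser.add_argument('--n_resblocks', type=int, default=8)  # former hard-coded value
parser.add_argument('--n_feats', type=int, default=64)     # former hard-coded value
args = parser.parse_args([])  # use defaults; the CLI would override them

# Same post-processing as src/args.py:
args.scale = [int(x) for x in args.scale.split("+")]

scale = args.scale[args.task_id]  # task 0 -> x2 super-resolution
print(f"WDSR(scale={scale}, n_resblocks={args.n_resblocks}, n_feats={args.n_feats})")
```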
diff --git a/research/cv/wdsr/train.py b/research/cv/wdsr/train.py
index a858bb6d6b0bc85ca19f2f8bb5569c8c16065452..345043579c21cc0adae57a8137457beb13ec4800 100644
--- a/research/cv/wdsr/train.py
+++ b/research/cv/wdsr/train.py
@@ -19,7 +19,7 @@ from mindspore import dataset as ds
 import mindspore.nn as nn
 from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.communication.management import init
+from mindspore.communication.management import init, get_rank
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.common import set_seed
 from mindspore.train.model import Model
@@ -31,21 +31,48 @@ from src.model import WDSR
 def train_net():
     """train wdsr"""
     set_seed(1)
-    device_id = int(os.getenv('DEVICE_ID', '0'))
-    rank_id = int(os.getenv('RANK_ID', '0'))
-    device_num = int(os.getenv('RANK_SIZE', '1'))
+    if args.device_target == 'GPU':
+        context.set_context(mode=context.GRAPH_MODE,
+                            device_target=args.device_target,
+                            device_id=args.device_id,
+                            save_graphs=False)
+    elif args.device_target == 'Ascend':
+        device_id = int(os.getenv('DEVICE_ID', '0'))
+        rank_id = int(os.getenv('RANK_ID', '0'))
+        device_num = int(os.getenv('RANK_SIZE', '1'))
+        context.set_context(mode=context.GRAPH_MODE,
+                            device_target="Ascend",
+                            save_graphs=False,
+                            device_id=device_id)
+    rank = 0
     # if distribute:
-    if device_num > 1:
-        init()
-        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          device_num=device_num, gradients_mean=True)
-    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
+    if args.run_distribute:
+        print("distribute")
+        if args.device_target == 'GPU':
+            init("nccl")
+            device_num = args.device_num
+            context.reset_auto_parallel_context()
+            rank = get_rank()
+            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True)
+        elif args.device_target == 'Ascend':
+            init()
+            rank = rank_id
+            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              device_num=device_num, gradients_mean=True)
+
     train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
     train_dataset.set_scale(args.task_id)
-    train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=device_num,
-                                           shard_id=rank_id, shuffle=True)
+    if args.device_target == 'GPU':
+        train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"],
+                                               num_shards=args.device_num,
+                                               shard_id=rank, shuffle=True)
+    elif args.device_target == 'Ascend':
+        train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"],
+                                               num_shards=device_num,
+                                               shard_id=rank_id, shuffle=True)
     train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)
-    net_m = WDSR()
+    net_m = WDSR(scale=args.scale[args.task_id], n_resblocks=args.n_resblocks, n_feats=args.n_feats)
     print("Init net successfully")
     if args.ckpt_path:
         param_dict = load_checkpoint(args.ckpt_path)
@@ -54,7 +81,7 @@ def train_net():
     step_size = train_de_dataset.get_dataset_size()
     lr = []
     for i in range(0, args.epochs):
-        cur_lr = args.lr / (2 ** ((i + 1)//200))
+        cur_lr = args.lr / (2 ** ((i + 1)//args.decay))
         lr.extend([cur_lr] * step_size)
     opt = nn.Adam(net_m.trainable_params(), learning_rate=lr, loss_scale=args.loss_scale)
     loss = nn.L1Loss()
@@ -67,7 +94,7 @@ def train_net():
     config_ck = CheckpointConfig(save_checkpoint_steps=args.ckpt_save_interval * step_size,
                                  keep_checkpoint_max=args.ckpt_save_max)
     ckpt_cb = ModelCheckpoint(prefix="wdsr", directory=args.ckpt_save_path, config=config_ck)
-    if device_id == 0:
+    if rank == 0:
         cb += [ckpt_cb]
     model.train(args.epochs, train_de_dataset, callbacks=cb, dataset_sink_mode=True)
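The final hunk gates checkpoint saving on `rank == 0`, so a data-parallel run writes one set of `.ckpt` files instead of one per process. A hedged sketch of that callback assembly follows, using only the callback classes `train.py` already imports; `step_size` and the other numbers are placeholders, and running it requires MindSpore.

```python
from mindspore.train.callback import (CheckpointConfig, LossMonitor,
                                      ModelCheckpoint, TimeMonitor)

def build_callbacks(rank, step_size=100, save_interval=10, keep_max=5, save_path='./ckpt'):
    """Every process monitors time and loss; only rank 0 writes checkpoints."""
    cb = [TimeMonitor(data_size=step_size), LossMonitor()]
    if rank == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=save_interval * step_size,
                                     keep_checkpoint_max=keep_max)
        cb.append(ModelCheckpoint(prefix="wdsr", directory=save_path, config=config_ck))
    return cb

print([type(c).__name__ for c in build_callbacks(rank=0)])  # includes ModelCheckpoint
print([type(c).__name__ for c in build_callbacks(rank=1)])  # monitors only
```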