Unverified commit 081fb29d, authored by i-robot and committed by Gitee

!1344 add ESRGAN GPU

Merge pull request !1344 from yusi-wang/ESRGAN
parents 0a12b587 685f62d4
@@ -29,7 +29,11 @@ The ESRGAN contains a generator network and a discriminator network.
Dataset used to train ESRGAN: [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/)
- Note: Run the script extract_subimages.py to crop sub-images, then train the model on the sub-images (a sketch of the cropping idea follows these notes).
- Note: Before training, set the DIV2K dataset path in src/util/extract_subimages.py to match your local layout, then run the script:
```shell
python src/util/extract_subimages.py
```
- Note: Data will be processed in src/dataset/traindataset.py
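Conceptually, the sub-image step slides a fixed-size window over every training image and writes each window out as its own file, so the data loader later reads small patches instead of full 2K frames. A minimal sketch of the idea; the function name, crop size and stride below are illustrative assumptions, and the real values live in src/util/extract_subimages.py:

```python
import os
import cv2  # opencv-python

def crop_subimages(img_path, out_dir, crop_size=480, step=240):
    """Slide a crop_size window over one image with stride `step` and save every patch."""
    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if img is None:
        raise FileNotFoundError(img_path)
    h, w = img.shape[:2]
    name = os.path.splitext(os.path.basename(img_path))[0]
    os.makedirs(out_dir, exist_ok=True)
    idx = 0
    for y in range(0, max(h - crop_size, 0) + 1, step):
        for x in range(0, max(w - crop_size, 0) + 1, step):
            idx += 1
            cv2.imwrite(os.path.join(out_dir, f"{name}_s{idx:03d}.png"),
                        img[y:y + crop_size, x:x + crop_size])
```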
@@ -62,7 +66,10 @@ The process of training ESRGAN needs a VGG19 pretrained on ImageNet.
ESRGAN
├─ README.md                        # descriptions about ESRGAN
├─ scripts
│  ├─ run_distribute_train_gpu.sh   # launch GPU training (8 pcs)
│  ├─ run_eval_gpu.sh               # launch GPU eval
│  ├─ run_standalone_train_gpu.sh   # launch GPU training (1 pc)
│  ├─ run_distribute_train.sh       # launch Ascend training (8 pcs)
│  ├─ run_eval.sh                   # launch Ascend eval
│  └─ run_standalone_train.sh       # launch Ascend training (1 pc)
@@ -93,18 +100,39 @@ ESRGAN
```shell
# distributed training
Ascend:
Usage: bash run_distribute_train.sh [DEVICE_NUM] [DISTRIBUTE] [RANK_TABLE_FILE] [LRPATH] [GTPATH] [VGGCKPT] [VPSNRLRPATH] [VPSNRGTPATH] [VGANLRPATH] [VGANGTPATH]
# Parameter meanings: DEVICE_NUM (number of devices), DISTRIBUTE (whether to run distributed training), RANK_TABLE_FILE (rank table configuration file), LRPATH (LR training images directory), GTPATH (HR training images directory), VGGCKPT (pretrained VGG19 checkpoint), VPSNRLRPATH (Set5 LR images for PSNR validation), VPSNRGTPATH (Set5 HR images for PSNR validation), VGANLRPATH (Set14 LR images for GAN validation), VGANGTPATH (Set14 HR images for GAN validation)
eg: bash run_distribute_train.sh 8 1 ./hccl_8p.json /data/DIV2K/DIV2K_train_LR_bicubic/X4_sub /data/DIV2K/DIV2K_train_HR_sub /home/HEU_535/A8/used/GAN_MD/VGG.ckpt /data/DIV2K/Set5/LRbicx4 /data/DIV2K/Set5/GTmod12 /data/DIV2K/Set14/LRbicx4 /data/DIV2K/Set14/GTmod12
GPU:
Usage: bash run_distribute_train_gpu.sh [DEVICE_NUM] [LRPATH] [GTPATH] [VGGCKPT] [VPSNRLRPATH] [VPSNRGTPATH] [VGANLRPATH] [VGANGTPATH]
# Parameter meanings: DEVICE_NUM (number of GPUs), LRPATH (LR training images directory), GTPATH (HR training images directory), VGGCKPT (pretrained VGG19 checkpoint), VPSNRLRPATH (Set5 LR images for PSNR validation), VPSNRGTPATH (Set5 HR images for PSNR validation), VGANLRPATH (Set14 LR images for GAN validation), VGANGTPATH (Set14 HR images for GAN validation)
eg: bash run_distribute_train_gpu.sh 8 /data/DIV2K/DIV2K_train_LR_bicubic/X4_sub /data/DIV2K/DIV2K_train_HR_sub /home/HEU_535/A8/used/GAN_MD/VGG.ckpt /data/DIV2K/Set5/LRbicx4 /data/DIV2K/Set5/GTmod12 /data/DIV2K/Set14/LRbicx4 /data/DIV2K/Set14/GTmod12
# standalone training
Ascend:
Usage: bash run_standalone_train.sh [DEVICE_ID] [LRPATH] [GTPATH] [VGGCKPT] [VPSNRLRPATH] [VPSNRGTPATH] [VGANLRPATH] [VGANGTPATH]
# Parameter meanings: DEVICE_ID (ID of the device to use), LRPATH (LR training images directory), GTPATH (HR training images directory), VGGCKPT (pretrained VGG19 checkpoint), VPSNRLRPATH (Set5 LR images for PSNR validation), VPSNRGTPATH (Set5 HR images for PSNR validation), VGANLRPATH (Set14 LR images for GAN validation), VGANGTPATH (Set14 HR images for GAN validation)
eg: bash run_standalone_train.sh 0 /data/DIV2K/DIV2K_train_LR_bicubic/X4_sub /data/DIV2K/DIV2K_train_HR_sub /home/HEU_535/A8/used/GAN_MD/VGG.ckpt /data/DIV2K/Set5/LRbicx4 /data/DIV2K/Set5/GTmod12 /data/DIV2K/Set14/LRbicx4 /data/DIV2K/Set14/GTmod12
GPU:
Usage: bash run_standalone_train_gpu.sh [DEVICE_ID] [LRPATH] [GTPATH] [VGGCKPT] [VPSNRLRPATH] [VPSNRGTPATH] [VGANLRPATH] [VGANGTPATH]
# Parameter meanings: DEVICE_ID (ID of the GPU to use), LRPATH (LR training images directory), GTPATH (HR training images directory), VGGCKPT (pretrained VGG19 checkpoint), VPSNRLRPATH (Set5 LR images for PSNR validation), VPSNRGTPATH (Set5 HR images for PSNR validation), VGANLRPATH (Set14 LR images for GAN validation), VGANGTPATH (Set14 HR images for GAN validation)
eg: bash run_standalone_train_gpu.sh 0 /data/DIV2K/DIV2K_train_LR_bicubic/X4_sub /data/DIV2K/DIV2K_train_HR_sub /home/HEU_535/A8/used/GAN_MD/VGG.ckpt /data/DIV2K/Set5/LRbicx4 /data/DIV2K/Set5/GTmod12 /data/DIV2K/Set14/LRbicx4 /data/DIV2K/Set14/GTmod12
```
### [Training Result](#content)
@@ -113,13 +141,21 @@ Training result will be stored in ckpt/train_parallel0/ckpt. You can find checkp
### [Evaluation Script Parameters](#content)
- Run `run_eval.sh` (Ascend) or `run_eval_gpu.sh` (GPU) for evaluation.
```bash
# evaluation
Ascend:
Usage: bash run_eval.sh [CKPT] [EVALLRPATH] [EVALGTPATH] [DEVICE_ID]
eg: bash run_eval.sh /ckpt/psnr_best.ckpt /data/DIV2K/Set5/LRbicx4 /data/DIV2K/Set5/GTmod12 0
GPU:
Usage: bash run_eval_gpu.sh [CKPT] [EVALLRPATH] [EVALGTPATH] [DEVICE_ID]
eg: bash run_eval_gpu.sh /ckpt/psnr_best.ckpt /data/DIV2K/Set5/LRbicx4 /data/DIV2K/Set5/GTmod12 0
```
### [Evaluation result](#content)
@@ -132,20 +168,19 @@ Evaluation result will be stored in the ./result. Under this, you can find gener
### Training Performance
| Parameters                 | Ascend 910                                                   | NVIDIA GeForce RTX 3090                          |
| -------------------------- | ------------------------------------------------------------ | ------------------------------------------------ |
| Model Version              | V1                                                           | V1                                               |
| Resource                   | CentOS 8.2; Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G  |                                                  |
| MindSpore Version          | 1.3.0                                                        | 1.6.0                                            |
| Dataset                    | DIV2K                                                        | DIV2K                                            |
| Training Parameters        | step=1000000+400000, batch_size = 16                         | step=1000000+400000, batch_size = 16             |
| Optimizer                  | Adam                                                         | Adam                                             |
| Loss Function              | BCEWithLogitsLoss, L1Loss, VGGLoss                           | BCEWithLogitsLoss, L1Loss, VGGLoss               |
| Outputs                    | super-resolution pictures                                    | super-resolution pictures                        |
| Accuracy                   | Set5 PSNR 32.56, Set14 PSNR 26.23                            | Set5 PSNR 30.37, Set14 PSNR 26.51                |
| Speed                      | 1pc: 212,216 ms/step; 8pcs: 77,118 ms/step                   | 8pcs: 239 ms/step + 409 ms/step                  |
| Total time                 | 8pcs: 36h                                                    |                                                  |
| Checkpoint for Fine tuning | 64.86M (.ckpt file)                                          | 64.86M (.ckpt file)                              |
| Scripts                    | [esrgan script](https://gitee.com/mindspore/models/tree/master/research/cv/ESRGAN) |                           |
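The loss function above combines an adversarial term (BCEWithLogitsLoss), a pixel term (L1Loss), and a perceptual term (VGGLoss) computed on features from the pretrained VGG19 checkpoint supplied as VGGCKPT. A minimal MindSpore sketch of how such a perceptual term can be written, with illustrative class and argument names rather than the repository's actual ones:

```python
import mindspore.nn as nn

class PerceptualLoss(nn.Cell):
    """L1 distance between VGG19 feature maps of the SR output and the HR target."""
    def __init__(self, vgg_features):
        super().__init__()
        self.features = vgg_features  # truncated VGG19 backbone loaded from the VGG checkpoint
        self.l1 = nn.L1Loss()

    def construct(self, sr, hr):
        # compare the two images in feature space rather than pixel space
        return self.l1(self.features(sr), self.features(hr))
```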
# [ModelZoo Homepage](#contents)
......
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -32,10 +32,11 @@ parser.add_argument("--scale", type=int, default=4)
parser.add_argument("--generator_path", type=str, default='./ckpt/195_gan_generator.ckpt')
parser.add_argument("--mode", type=str, default='train')
parser.add_argument("--device_id", type=int, default=0, help="device id, default: 0.")
parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU', 'CPU'))
if __name__ == '__main__':
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_id=args.device_id, save_graphs=False)
context.set_context(device_target=args.platform)
test_ds = create_testdataset(1, args.test_LR_path, args.test_GT_path)
test_data_loader = test_ds.create_dict_iterator()
generator = RRDBNet(3, 3)
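For orientation, evaluation typically continues from this point by loading the generator checkpoint and averaging PSNR over the test iterator. A hedged sketch that reuses the `args`, `generator`, and `test_data_loader` defined above; the dataset keys (`"LR"`, `"HR"`) and the `psnr` helper are assumptions, not the repository's exact code:

```python
import numpy as np
from mindspore import load_checkpoint, load_param_into_net

def psnr(sr, hr, max_val=255.0):
    """Peak signal-to-noise ratio between two image arrays."""
    mse = np.mean((sr.astype(np.float64) - hr.astype(np.float64)) ** 2)
    return 100.0 if mse == 0 else 20.0 * np.log10(max_val / np.sqrt(mse))

load_param_into_net(generator, load_checkpoint(args.generator_path))
generator.set_train(False)

scores = []
for batch in test_data_loader:            # dict iterator created above
    sr = generator(batch["LR"])           # assumed key names
    scores.append(psnr(sr.asnumpy(), batch["HR"].asnumpy()))
print("mean PSNR:", np.mean(scores))
```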
......
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if [ $# != 8 ]
then
    echo "Usage: bash run_distribute_train_gpu.sh [DEVICE_NUM] [LRPATH] [GTPATH] [VGGCKPT] [VPSNRLRPATH] [VPSNRGTPATH] [VGANLRPATH] [VGANGTPATH]"
    exit 1
fi
echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
export RANK_SIZE=$1
export LRPATH=$2
export GTPATH=$3
export VGGCKPT=$4
export VPSNRLRPATH=$5
export VPSNRGTPATH=$6
export VGANLRPATH=$7
export VGANGTPATH=${8}
rm -rf ./train_parallel
mkdir ./train_parallel
cp -r ../src ./train_parallel
cp -r ../*.py ./train_parallel
cd ./train_parallel || exit
env > env.log
if [ ! -f "$VGGCKPT" ]; then
    echo "VGG checkpoint $VGGCKPT does not exist"
    exit 1
else
    echo "start training"
    mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
        python train.py --run_distribute=1 --device_num=$RANK_SIZE \
        --train_LR_path=$LRPATH --train_GT_path=$GTPATH --vgg_ckpt=$VGGCKPT \
        --val_PSNR_LR_path=$VPSNRLRPATH --val_PSNR_GT_path=$VPSNRGTPATH --val_GAN_LR_path=$VGANLRPATH \
        --val_GAN_GT_path=$VGANGTPATH --platform=GPU > log 2>&1 &
    cd ..
fi
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if [ $# != 4 ]
then
    echo "Usage: bash run_eval_gpu.sh [CKPT] [EVALLRPATH] [EVALGTPATH] [DEVICE_ID]"
    exit 1
fi
export CKPT=$1
export EVALLRPATH=$2
export EVALGTPATH=$3
export DEVICE_ID=$4
rm -rf ./eval
mkdir ./eval
cp -r ../src ./eval
cp -r ../*.py ./eval
cd ./eval || exit
env > env.log
python ./eval.py --generator_path=$CKPT --test_LR_path=$EVALLRPATH --device_id=$DEVICE_ID \
    --test_GT_path=$EVALGTPATH --platform=GPU > eval.log 2>&1 &
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if [ $# != 8 ]
then
    echo "Usage: bash run_standalone_train_gpu.sh [DEVICE_ID] [LRPATH] [GTPATH] [VGGCKPT] [VPSNRLRPATH] [VPSNRGTPATH] [VGANLRPATH] [VGANGTPATH]"
    exit 1
fi
echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
export DEVICE_ID=$1
export LRPATH=$2
export GTPATH=$3
export VGGCKPT=$4
export VPSNRLRPATH=$5
export VPSNRGTPATH=$6
export VGANLRPATH=$7
export VGANGTPATH=$8
rm -rf ./train_standalone
mkdir ./train_standalone
cp -r ../src ./train_standalone
cp -r ../*.py ./train_standalone
cd ./train_standalone || exit
if [ ! -f "$VGGCKPT" ]; then
    echo "VGG checkpoint $VGGCKPT does not exist"
    exit 1
else
    echo "start training"
    env > env.log
    python train.py --device_id=$DEVICE_ID \
        --train_LR_path=$LRPATH --train_GT_path=$GTPATH --vgg_ckpt=$VGGCKPT \
        --val_PSNR_LR_path=$VPSNRLRPATH --val_PSNR_GT_path=$VPSNRGTPATH --val_GAN_LR_path=$VGANLRPATH \
        --val_GAN_GT_path=$VGANGTPATH --platform=GPU > train.log 2>&1 &
    echo "1"
    cd ..
fi
\ No newline at end of file
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -50,12 +50,12 @@ parser.add_argument("--train_url", type=str, default='', help="train url.")
# dataset
parser.add_argument("--train_LR_path", type=str, default='/data/DIV2K/DIV2K_train_LR_bicubic/X4_sub')
parser.add_argument("--train_GT_path", type=str, default='/data/DIV2K/DIV2K_train_HR_sub')
parser.add_argument("--val_PSNR_LR_path", type=str, default='/data/DIV2K/Set5/LRbicx4')
parser.add_argument("--val_PSNR_GT_path", type=str, default='/data/DIV2K/Set5/GTmod12')
parser.add_argument("--val_GAN_LR_path", type=str, default='/data/DIV2K/Set14/LRbicx4')
parser.add_argument("--val_GAN_GT_path", type=str, default='/data/DIV2K/Set14/GTmod12')
parser.add_argument("--val_PSNR_LR_path", type=str, default='/data/Set5/LRbicx4')
parser.add_argument("--val_PSNR_GT_path", type=str, default='/data/Set5/GTmod12')
parser.add_argument("--val_GAN_LR_path", type=str, default='/data/Set14/LRbicx4')
parser.add_argument("--val_GAN_GT_path", type=str, default='/data/Set14/GTmod12')
parser.add_argument("--vgg_ckpt", type=str, default='/data/DIV2K/VGG.ckpt')
parser.add_argument("--vgg_ckpt", type=str, default='/ckpt/vgg19.ckpt')
parser.add_argument('--upscale_factor', default=4, type=int, choices=[2, 4, 8], help='super resolution upscale factor')
parser.add_argument("--image_size", type=int, default=128, help="Image size of high resolution image. (default: 128)")
@@ -72,7 +72,7 @@ parser.add_argument("--start_gan_epoch", default=0, type=int, metavar='N',
parser.add_argument("--gan_steps", default=400000, type=int, metavar="N",
help="Number of total gan epochs to run. (default: 400000)")
parser.add_argument("--sens", default=1024.0, type=float)
parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU', 'CPU'))
# distribute
parser.add_argument("--modelArts", type=int, default=0, help="Run cloud, default: false.")
parser.add_argument("--run_distribute", type=int, default=0, help="Run distribute, default: false.")
@@ -82,6 +82,7 @@ parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.
parser.add_argument("--version", type=int, default=0, help="version, default: 0.")
def evaluate(model, test_data_loader, cur_step, cur_epoch, cur_best_psnr, eval_mode, runs_path):
"""evaluate for every epoch"""
print("start valing:")
@@ -106,6 +107,7 @@
if __name__ == '__main__':
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
context.set_context(device_target=args.platform)
# distribute
if args.modelArts:
import moxing as mox
@@ -139,22 +141,33 @@ if __name__ == '__main__':
local_train_ckpt_path = './ckpt'
if args.run_distribute:
print("distribute")
if args.platform == 'Ascend':
rank_id = int(os.getenv('RANK_ID'))
device_id = int(os.getenv("DEVICE_ID"))
context.set_context(device_id=device_id)
device_num = args.device_num
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=device_num)
init()
if args.platform == 'GPU':
rank = get_rank()
rank_id = rank
shard_id = rank_id
num_shards = device_num
rank = get_rank()
else:
context.set_context(device_id=args.device_id)
device_num = args.device_num
shard_id = None
num_shards = None
if args.platform == 'GPU':
rank = 0
shard_id = 0
num_shards = 1
# for RRDBNet
# create dataset
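# when running distributed, split the global batch size evenly across the participating devices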
args.train_batch_size = int(args.train_batch_size // device_num) if args.run_distribute else args.train_batch_size
@@ -183,9 +196,9 @@ if __name__ == '__main__':
best_psnr = 0.0
if not os.path.exists("./ckpt"):
os.makedirs("./ckpt", exist_ok=True)
if not os.path.exists("./runs"):
os.makedirs("./runs", exist_ok=True)
print('start training:')
print('start training PSNR:')
@@ -212,13 +225,13 @@
# Check whether the evaluation index of the current model is the highest.
is_best = psnr > best_psnr
best_psnr = max(psnr, best_psnr)
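# in distributed runs, only the rank-0 process saves checkpoints so that workers do not write the same file concurrently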
if is_best and rank == 0:
print("best_psnr saving ckpt ", end="")
print(best_psnr)
save_checkpoint(generator, os.path.join(local_train_ckpt_path, 'psnr_best.ckpt'))
# save checkpoint every epoch
save_checkpoint(generator, os.path.join(local_train_ckpt_path, f'{epoch}_psnr_generator.ckpt'))
print(f"{epoch + 1}/{total_psnr_epochs} epoch finished")
# for esrgan
test_gan_ds = create_testdataset(args.val_batch_size, args.val_GAN_LR_path, args.val_GAN_GT_path)
test_gan_data_loader = test_gan_ds.create_dict_iterator()
@@ -269,15 +282,15 @@
is_best = psnr > best_psnr
best_psnr = max(psnr, best_psnr)
print(best_psnr)
if is_best and rank == 0:
print("best_psnr saving ckpt ", end="")
print(best_psnr)
save_checkpoint(generator, os.path.join(local_train_ckpt_path, 'gan_generator_best.ckpt'))
save_checkpoint(discriminator, os.path.join(local_train_ckpt_path, 'gan_discriminator_best.ckpt'))
# save checkpoint every epoch
save_checkpoint(generator, os.path.join(local_train_ckpt_path, f'{epoch}_gan_generator.ckpt'))
save_checkpoint(discriminator, os.path.join(local_train_ckpt_path, f'{epoch}_gan_discriminator.ckpt'))
if epoch == total_gan_epochs - 1 and rank == 0:
save_checkpoint(generator, os.path.join(local_train_ckpt_path, 'gan_generator.ckpt'))
print(f"{epoch + 1}/{total_gan_epochs} epoch finished")
print("all")
......