Unverified commit 19f94fa6, authored by i-robot, committed by Gitee

!1973 Models: PGGAN GPU

Merge pull request !1973 from adenisov/models-pr-pggan
parents 6c3053ac 6763e2c6
Showing 1204 additions and 151 deletions
@@ -9,6 +9,7 @@ data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
gpu_distribute_training: 0
# ==============================================================================
@@ -17,11 +18,14 @@ dataset_name: "celeba"
name: "celeba"
resume_load_scale: -1
batch_size: 16
batch_size_list: []
train_data_path: ""
resume_check_d: "checkpoint_path/"
resume_check_g: "checkpoint_path/"
ckpt_save_dir: "./checkpoint"
eval_img_save_dir: "./eval_img"
model_save_step: 10000
save_ckpt_from_device_with_id: 0
#network
scales: [4, 8, 16, 32, 64, 128]
depth: [512, 512, 512, 512, 256, 128]
@@ -30,8 +34,9 @@ alpha_jumps: [0, 600, 600, 600, 600, 600]
alpha_size_jumps: [32, 32, 32, 32, 32, 32]
# optimizer and lr related
lr: 0.001
lr_list: []
# loss related
loss_scale_value: 12
scale_factor: 10
scale_window: 1000
# export option
@@ -48,6 +53,36 @@ train_url: "Url for modelarts"
data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: 'Target device type'
gpu_distribute_training: 'Use more than 1 device to train the network. 1 - true, 0 - false. Used only for GPU'
enable_profiling: 'Whether to enable profiling while training, default: False'
dataset_name: "the name of the dataset being used" #unused
name: "the name of the dataset being used" #unused
resume_load_scale: "resume training from this scale, must be in [4, 8, 16, 32, 64, 128]"
batch_size: "batch size used for training unless the batch_size_list is specified"
batch_size_list: "if not empty, specifies the bach size for each input scale.
Must have the same length as the scales list"
train_data_path: "path to folder with images for network training"
resume_check_d: "discriminator checkpoint path"
resume_check_g: "generator checkpoint path"
ckpt_save_dir: "folder to save checkpoints"
eval_img_save_dir: "folder to save validation images"
model_save_step: "save the model every time through this number of steps"
save_ckpt_from_device_with_id: "save checkpoint from device with this id"
#network
scales: "generated image sizes"
depth: "input channel numbers for convolution in each layer"
num_batch: "Number of iterations for each image scale"
alpha_jumps: "do not change this parameter unless you know exactly what you are doing"
alpha_size_jumps: "do not change this parameter unless you know exactly what you are doing"
# optimizer and lr related
lr: "learning rate used for training unless the lr_list is specified"
lr_list: "if not empty, specifies the learning rate for each input scale. Must have the same length as the scales list"
# loss related
loss_scale_value: "parameter for model optimizer"
scale_factor: "parameter for model optimizer"
scale_window: "parameter for model optimizer"
# export option
ckpt_file: "" #unused
file_name: "file name for model export"
file_format: "choices in ['AIR', 'ONNX', 'MINDIR']" file_format: "choices in ['AIR', 'ONNX', 'MINDIR']"
# Contents
<!-- TOC -->
- [Contents](#contents)
- [Model introduction](#model-introduction)
- [Model architecture](#model-architecture)
- [Dataset](#dataset)
- [Environmental requirements](#environmental-requirements)
- [Quick start](#quick-start)
- [Script description](#script-description)
- [Script and sample code](#script-and-sample-code)
- [Training process](#training-process)
- [Training](#training)
- [Distributed training](#distributed-training)
- [Evaluation process](#evaluation-process)
- [Evaluation](#evaluation)
- [Inference process](#inference-process)
- [Export MindIR](#export-mindir)
- [Perform inference on Ascend 310](#perform-inference-on-ascend-310)
- [Result](#result)
- [Model description](#model-description)
- [Performance](#performance)
- [Evaluate performance](#evaluate-performance)
- [PGAN on CelebA](#pgan-on-celeba)
- [ModelZoo homepage](#modelzoo-homepage)
# Model introduction
PGAN stands for Progressive Growing of GANs for Improved Quality, Stability, and Variation. The network is
characterized by the progressive generation of face images.
[Paper](https://arxiv.org/abs/1710.10196): Progressive Growing of GANs for Improved Quality, Stability,
and Variation // ICLR 2018
[Reference github](https://github.com/tkarras/progressive_growing_of_gans)
# Model architecture
The network consists of a generator and a discriminator. The core idea is to first generate images at a low
resolution, add new layers as training progresses, and gradually generate more and more detailed images.
This both speeds up training and stabilizes it. In addition, this code implements the core tricks from the
paper, such as the equalized learning rate, an exponential running average of the generator weights, the
residual structure, and the WGAN-GP gradient penalty (WGANGPGradientPenalty).
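For intuition, below is a minimal NumPy sketch of the equalized-learning-rate trick (an illustration only, not the repository's `EqualizedConv2d`/`EqualizedLinear` implementation): weights keep their unit-variance initialization and the per-layer He constant is applied at every forward pass, so gradient updates have a similar relative size in all layers.

```python
import numpy as np

def equalized_linear(x, weight, bias):
    """Forward pass of an 'equalized' fully connected layer: the He
    constant sqrt(2 / fan_in) is applied at runtime instead of at
    initialization time."""
    fan_in = weight.shape[1]
    he_scale = np.sqrt(2.0 / fan_in)
    return x @ (weight * he_scale).T + bias

x = np.random.randn(4, 512)    # batch of latent vectors
w = np.random.randn(256, 512)  # unit-variance initialization
b = np.zeros(256)
print(equalized_linear(x, w, b).shape)  # (4, 256)
```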
# Dataset
Dataset web-page: [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
> Note: For this task we use the "Align&Cropped Images" dataset (from the "Downloads" section on the official web page).
Dataset link (1.34 GB): [Celeba Aligned and Cropped Images](https://drive.google.com/file/d/0B7EVK8r0v71pZjFTYXZWM3FlRnM/view?usp=sharing&resourcekey=0-dYn9z10tMJOBAkviAcfdyQ)
After unpacking the dataset, it should look as follows:
```text
.
└── Celeba
└── img_align_celeba
├── 000001.jpg
├── 000002.jpg
└── ...
```
CelebFaces Attributes Dataset (CelebA) is a large-scale face attributes dataset with over 200K celebrity images,
each with 40 attribute annotations. CelebA is diverse, large, and richly annotated, including
- 10,177 identities,
- 202,599 face images, each with 5 landmark locations and 40 binary attribute annotations.
This dataset can be used as a training and test set for the following computer vision tasks: face attribute recognition,
face detection, and face editing and synthesis.
# Environmental requirements
- Hardware (Ascend, GPU)
- Use Ascend or GPU to build the hardware environment.
- Framework
- [MindSpore](https://www.mindspore.cn/install)
- For details, see the following resources:
- [MindSpore Tutorial](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
# Quick start
After installing MindSpore through the official website, you can follow the steps below for training and evaluation:
- Ascend processor environment to run
```bash
# run the training example
export DEVICE_ID=0
export RANK_SIZE=1
python train.py --train_data_path /path/data/image --config_path ./910_config.yaml
OR
bash run_standalone_train.sh /path/data/image device_id ./910_config.yaml
# run the distributed training example
bash run_distributed_train.sh /path/data/image /path/hccl_config_file ./910_config.yaml
# run the evaluation example
export DEVICE_ID=0
export RANK_SIZE=1
python eval.py --checkpoint_g=/path/checkpoint --device_id=0
OR
bash run_eval.sh /path/checkpoint 0
```
- GPU environment to run
```bash
# run the training example
python ./train.py --config_path /path/to/gpu_config.yaml --train_data_path /path/to/img_align_celeba > ./train.log 2>&1 &
OR
bash script/run_standalone_train_gpu.sh ./gpu_config.yaml /path/to/img_align_celeba
# run the distributed training example
bash script/run_distribute_train_gpu.sh /path/to/gpu_config.yaml /path/to/img_align_celeba
# run the evaluation example
python -u eval.py \
--checkpoint_g=/path/to/checkpoint \
--device_target GPU \
--device_id=0 \
--measure_ms_ssim=True \
--original_img_dir=/path/to/img_align_celeba
OR
bash script/run_eval_gpu.sh /path/to/checkpoint 0 True /path/to/img_align_celeba
```
For the evaluation scripts, the training script places the checkpoint files by default in the
`output/{scale}/checkpoint` directory; pass the name of the generator checkpoint file
as a parameter when executing the script.
# Script description
## Script and sample code
```text
.
└─ cv
  └─ PGAN
    ├─ script
    │ ├─ run_distribute_train_gpu.sh     # Distributed training on GPU shell script
    │ ├─ run_distributed_train_ascend.sh # Distributed training on Ascend shell script
    │ ├─ run_infer_310.sh                # Inference on Ascend 310
    │ ├─ run_standalone_train.sh         # Shell script for single-card Ascend training
    │ ├─ run_standalone_train_gpu.sh     # Shell script for single-GPU training
    │ ├─ run_eval_ascend.sh              # Ascend evaluation script
    │ └─ run_eval_gpu.sh                 # GPU evaluation script
    ├─ src
    │ ├─ customer_layer.py               # Basic cells
    │ ├─ dataset.py                      # Data loading
    │ ├─ image_transform.py              # Image processing functions
    │ ├─ metric.py                       # Metric functions (MS-SSIM)
    │ ├─ network_D.py                    # Discriminator network
    │ ├─ network_G.py                    # Generator network
    │ ├─ optimizer.py                    # Loss calculation
    │ └─ time_monitor.py                 # Time monitor
    ├─ eval.py                           # Evaluation script
    ├─ export.py                         # MINDIR model export script
    ├─ 910_config.yaml                   # Ascend config
    ├─ gpu_config.yaml                   # GPU config
    ├─ modelarts_config.yaml             # ModelArts config
    ├─ README_CN.md                      # PGAN description (Chinese)
    └─ README.md                         # PGAN description (English)
```
## Training process
### Training
- Ascend processor environment to run
```bash
export DEVICE_ID=0
export RANK_SIZE=1
python train.py --train_data_path /path/data/image --config_path ./910_config.yaml
# or
bash run_standalone_train.sh /path/data/image device_id ./910_config.yaml
```
- GPU environment to run
```bash
python train.py --config_path ./gpu_config.yaml --train_data_path /path/to/img_align_celeba
# or
bash run_standalone_train_gpu.sh /path/to/gpu_config.yaml /path/to/img_align_celeba
```
After training, an output directory is created in the current directory. Inside it, subdirectories
are created according to the `ckpt_save_dir` parameter you set, and the parameters of each scale
are saved during training.
### Distributed training
- Ascend processor environment to run
```bash
bash run_distributed_train.sh /path/to/img_align_celeba /path/hccl_config_file ./910_config.yaml
```
- GPU environment to run
```bash
bash script/run_distribute_train_gpu.sh /path/to/gpu_config.yaml /path/to/img_align_celeba
```
The above shell script will run distributed training in the background. The script will generate the corresponding
LOG{RANK_ID} directory under the script directory, and the output of each process will be recorded in the
log_distribute file under the corresponding LOG{RANK_ID} directory. The checkpoint file is saved under
output/rank{RANK_ID}.
## Evaluation process
### Evaluation
- The evaluation script generates 64 face images.
  When evaluating, select a saved checkpoint file and pass it to the evaluation script as a parameter;
  the corresponding parameter is `checkpoint_g` (the saved generator checkpoint).
- Use a checkpoint whose name starts with `AvG` (for example, AvG_12000.ckpt).
- Ascend processor environment to run
```bash
bash run_eval.sh /path/to/avg/checkpoint 0
```
- GPU environment to run
```bash
bash script/run_eval_gpu.sh [CKPT_PATH] [DEVICE_ID] [MEASURE_MSSIM] [DATASET_DIR]
```
- CKPT_PATH - path to the checkpoint
- DEVICE_ID - device ID
- MEASURE_MSSIM - flag to calculate objective metrics. If True, MS-SSIM is calculated;
otherwise, the script only generates images of faces.
- DATASET_DIR - path to the dataset images
After the test script is executed, the generated images are stored in `img_eval/`.
## Inference process
### Export MindIR
```bash
python export.py --checkpoint_g [GENERATOR_CKPT_NAME] --device_id [DEVICE_ID] --device_target [DEVICE_TARGET]
```
- GENERATOR_CKPT_NAME - path to the trained checkpoint (Use AvG_xx.ckpt)
- DEVICE_TARGET - Device target: Ascend or GPU
- DEVICE_ID - Device ID
The script will generate the corresponding MINDIR file in the current directory.
### Perform inference on Ascend 310
Before running inference, the MINDIR model must be exported via the export script. The following shows how
to generate images on Ascend 310:
```bash
bash run_infer_310.sh [MINDIR_PATH] [NEED_PREPROCESS] [NIMAGES] [DEVICE_ID]
```
- `MINDIR_PATH` path to the MINDIR file
- `NEED_PREPROCESS` indicates whether preprocessing is needed; choose y or n (it needs to be set to y when
inference is executed for the first time).
- `NIMAGES` indicates the number of generated images.
- `DEVICE_ID` is optional, default is 0.
### Result
The inference results are saved in the directory where the script is executed: the generated images are
saved in the `result_Files/` directory, and the inference time statistics are saved in the `time_Result/`
directory. Each generated image is saved as `generated_{NUMBER}.png`.
# Model description
## Performance
### Evaluate performance
#### PGAN on CelebA
| Parameters          | Ascend 910                                                                     | GPU                                                                            |
|---------------------|--------------------------------------------------------------------------------|--------------------------------------------------------------------------------|
| Model Version       | PGAN                                                                           | PGAN                                                                           |
| Resources           | Ascend                                                                         | GPU                                                                            |
| Upload Date         | 09/31/2021 (month/day/year)                                                    | 02/08/2022 (month/day/year)                                                    |
| MindSpore Version   | 1.3.0                                                                          | 1.5.0                                                                          |
| Datasets            | CelebA                                                                         | CelebA                                                                         |
| Training parameters | batch_size=16, lr=0.001                                                        | batch_size=16 for scales 4-64, batch_size=8 for scale 128, lr=0.002            |
| Optimizer           | Adam                                                                           | Adam                                                                           |
| Generator output    | image                                                                          | image                                                                          |
| Speed               | 8p: 9h 26m 54s; 1p: 76h 23m 39s; 1.1 s/step                                    | 8p: 10h 28m 37s; 1p: 83h 45m 34s                                               |
| Convergence loss    | G: [-232.61, 273.87]; D: [-27.736, 2.601]                                      | G: [-232.61, 273.87]; D: [-27.736, 2.601]                                      |
| MS-SSIM metric      |                                                                                | 0.2948                                                                         |
| Script              | [PGAN script](https://gitee.com/mindspore/models/tree/master/research/cv/PGAN) | [PGAN script](https://gitee.com/mindspore/models/tree/master/research/cv/PGAN) |
> Note: For measuring the metrics and generating the images we are using the checkpoint with prefix AvG (AvG_xxxx.ckpt)
# ModelZoo homepage
Please visit the official website [homepage](https://gitee.com/mindspore/models)
@@ -233,7 +233,7 @@ bash run_infer_310.sh [MINDIR_PATH] [NEED_PREPROCESS] [NIMAGES] [DEVICE_ID]
| Upload Date         | 09/31/2021 (month/day/year)                 |
| MindSpore Version   | 1.3.0                                       |
| Datasets            | CelebA                                      |
| Training parameters | batch_size=128, lr=0.001                    |
| Optimizer           | Adam                                        |
| Generator output    | image                                       |
| Speed               | 8p: 9h 26m 54s; 1p: 76h 23m 39s; 1.1 s/step |
......
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,47 +13,70 @@
# limitations under the License.
# ============================================================================
"""eval PGAN"""
import argparse
import os
import random

import numpy as np
from PIL import Image
from mindspore import Tensor
from mindspore import context
from mindspore import dtype as mstype
from mindspore.common import set_seed
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.image_transform import Crop
from src.image_transform import Normalize, TransporeAndMul, Resize
from src.metric import msssim
from src.network_G import GNet4_4_Train, GNet4_4_last, GNetNext_Train, GNetNext_Last
def set_every(num):
    """set random seed"""
    random.seed(num)
    set_seed(num)
    np.random.seed(num)


set_every(1)


def pre_launch():
    """parse the console argument"""
    parser = argparse.ArgumentParser(description='MindSpore PGAN training')
    parser.add_argument('--device_target', type=str, default='Ascend',
                        help='Target device (Ascend or GPU, default Ascend)')
    parser.add_argument('--device_id', type=int, default=0,
                        help='device id of Ascend (Default: 0)')
    parser.add_argument('--checkpoint_g', type=str, default='',
                        help='checkpoint of g net (default: none)')
    parser.add_argument('--img_out_dir', type=str,
                        default='img_eval', help='the dir of output img')
    # note: argparse's type=bool treats any non-empty string as True
    parser.add_argument('--measure_ms_ssim', type=bool,
                        default=False, help='measure ms-ssim metric flag')
    parser.add_argument('--original_img_dir', type=str,
                        default='', help='the dir of real img')
    args = parser.parse_args()
    context.set_context(device_id=args.device_id,
                        mode=context.GRAPH_MODE,
                        device_target=args.device_target)
    # create the output image directory if it does not exist
    if not os.path.exists(args.img_out_dir):
        os.mkdir(args.img_out_dir)
    return args
def build_noise_data(n_samples):
    """build_noise_data
    Returns:
        output.
    """
    input_latent = np.random.randn(n_samples, 512)
    input_latent = Tensor(input_latent, mstype.float32)
    return input_latent
def image_compose(out_images, size=(8, 8)):
@@ -70,30 +93,41 @@ def image_compose(out_images, size=(8, 8)):
    return to_image


def to_img_list(out_images):
    """to_img_list
    Returns:
        output.
    """
    img_list = []
    for img in out_images:
        img_list.append(Image.fromarray(img))
    return img_list


def resize_tensor(data, out_size_image):
    """resize_tensor
    Returns:
        output.
    """
    out_data_size = (data.shape[0], data.shape[1], out_size_image[0], out_size_image[1])
    outdata = []
    data = data.asnumpy()
    data = np.clip(data, a_min=-1, a_max=1)
    transform_list = [Normalize((-1., -1., -1.), (2, 2, 2)),
                      TransporeAndMul(), Resize(out_size_image)]
    for img in range(out_data_size[0]):
        processed = data[img]
        for transform in transform_list:
            processed = transform(processed)
        processed = np.array(processed)
        outdata.append(processed)
    return outdata


def construct_gnet():
    """construct_gnet"""
    scales = [4, 8, 16, 32, 64, 128]
    depth = [512, 512, 512, 512, 256, 128]
    for scale_index, scale in enumerate(scales):
@@ -105,13 +139,57 @@ def main():
        else:
            last_avg_gnet = GNetNext_Last(avg_gnet)
            avg_gnet = GNetNext_Train(depth[scale_index], last_avg_gnet, leakyReluLeak=0.2, dimOutput=3)
    return avg_gnet
def load_original_images(original_img_dir, img_number):
    """load_original_images"""
    file_names = [f for f in os.listdir(original_img_dir)
                  if os.path.isfile(os.path.join(original_img_dir, f)) and '.jpg' in f]
    file_names = random.sample(file_names, img_number)
    crop = Crop()
    img_list = []
    for im_name in file_names:
        img = Image.open(os.path.join(original_img_dir, im_name))
        img = np.array(crop(img))
        img_list.append(img)
    return img_list
def main():
    """main"""
    print("Creating evaluation image...")
    args = pre_launch()
    avg_gnet = construct_gnet()
    param_dict_g = load_checkpoint(args.checkpoint_g)
    load_param_into_net(avg_gnet, param_dict_g)
    input_noise = build_noise_data(64)
    gen_imgs_eval = avg_gnet(input_noise, 0.0)
    out_images = resize_tensor(gen_imgs_eval, (128, 128))
    to_image = image_compose(out_images)
    to_image.save(os.path.join(args.img_out_dir, "result.jpg"))
    if args.measure_ms_ssim:
        print("Preparing images for metric calculation...")
        n_eval_batch = 200
        real_img_list = load_original_images(args.original_img_dir, n_eval_batch * 64 * 2)
        real_img_list = np.stack(real_img_list)
        fake_img_list = []
        for _ in range(n_eval_batch):
            input_noise = build_noise_data(64)
            gen_imgs_eval = avg_gnet(input_noise, 0.0)
            out_images = resize_tensor(gen_imgs_eval, (128, 128))
            fake_img_list += to_img_list(out_images)
        fake_img_list = np.stack(fake_img_list)
        print("Calculating metrics...")
        mssim_real = msssim(real_img_list[::2], real_img_list[1::2])
        mssim_fake = msssim(fake_img_list, real_img_list[1::2])
        print(f"Structure similarity for reals with reals: {mssim_real}")
        print(f"Structure similarity for reals with fakes: {mssim_fake}")
if __name__ == '__main__':
......
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,12 +24,11 @@ from src.network_G import GNet4_4_Train, GNet4_4_last, GNetNext_Train, GNetNext_Last

def preLauch():
    """parse the console argument"""
    parser = argparse.ArgumentParser(description='MindSpore PGAN training')
    parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend (Default: 0)')
    parser.add_argument('--device_target', type=str, required=True, choices=['Ascend', 'GPU'], help='Device target')
    parser.add_argument('--checkpoint_g', type=str, default='ckpt', help='checkpoint dir of PGAN')
    args = parser.parse_args()
    context.set_context(device_id=args.device_id, mode=context.GRAPH_MODE, device_target=args.device_target)
    return args
......
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "GPU"
gpu_distribute_training: 1
# ==============================================================================
# options
dataset_name: "celeba"
name: "celeba"
resume_load_scale: -1
batch_size: 16
batch_size_list: [16, 16, 16, 16, 16, 8]
train_data_path: "/path/to/dataset/images"
resume_check_d: "chekpoint_from_huawei/D_12000.ckpt"
resume_check_g: "chekpoint_from_huawei/AvG_12000.ckpt"
ckpt_save_dir: "./checkpoint"
eval_img_save_dir: "./eval_img"
model_save_step: 1000
save_ckpt_from_device_with_id: 0
#network
scales: [4, 8, 16, 32, 64, 128]
depth: [512, 512, 512, 512, 256, 128]
num_batch: [48000, 96000, 96000, 96000, 96000, 96000]
alpha_jumps: [0, 600, 600, 600, 600, 600]
alpha_size_jumps: [32, 32, 32, 32, 32, 32]
# optimizer and lr related
lr: 0.002
lr_list: []
# loss related
loss_scale_value: 12
scale_factor: 10
scale_window: 1000
# export option
ckpt_file: ""
file_name: "PGAN"
file_format: "MINDIR"
---
# Help description for each configuration
enable_modelarts: "Whether training on modelarts, default: False"
data_url: "Url for modelarts"
train_url: "Url for modelarts"
data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: 'Target device type'
gpu_distribute_training: 'Use more than 1 device to train the network. 1 - true, 0 - false. Used only for GPU'
enable_profiling: 'Whether to enable profiling while training, default: False'
dataset_name: "the name of the dataset being used" #unused
name: "the name of the dataset being used" #unused
resume_load_scale: "resume training from this scale, must be in [4, 8, 16, 32, 64, 128]"
batch_size: "batch size used for training unless the batch_size_list is specified"
batch_size_list: "if not empty, specifies the bach size for each input scale.
Must have the same length as the scales list"
train_data_path: "path to folder with images for network training"
resume_check_d: "discriminator checkpoint path"
resume_check_g: "generator checkpoint path"
ckpt_save_dir: "folder to save checkpoints"
eval_img_save_dir: "folder to save validation images"
model_save_step: "save a checkpoint every this many steps"
save_ckpt_from_device_with_id: "save checkpoint from device with this id"
#network
scales: "generated image sizes"
depth: "input channel numbers for convolution in each layer"
num_batch: "Number of iterations for each image scale"
alpha_jumps: "do not change this parameter unless you know exactly what you are doing"
alpha_size_jumps: "do not change this parameter unless you know exactly what you are doing"
# optimizer and lr related
lr: "learning rate used for training unless the lr_list is specified"
lr_list: "if not empty, specifies the learning rate for each input scale. Must have the same length as the scales list"
# loss related
loss_scale_value: "parameter for model optimizer"
scale_factor: "parameter for model optimizer"
scale_window: "parameter for model optimizer"
# export option
ckpt_file: "" #unused
file_name: "file name for model export"
file_format: "choices in ['AIR', 'ONNX', 'MINDIR']"
@@ -9,6 +9,7 @@ data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
gpu_distribute_training: 0
# ==============================================================================
@@ -17,11 +18,14 @@ dataset_name: "celeba"
name: "celeba"
resume_load_scale: -1
batch_size: 16
batch_size_list: []
train_data_path: "/cache/data/"
resume_check_d: "/cache/checkpoint_path/"
resume_check_g: "/cache/checkpoint_path/"
ckpt_save_dir: "checkpoint"
eval_img_save_dir: "./eval_img"
model_save_step: 10000
save_ckpt_from_device_with_id: 0
#network
scales: [4, 8, 16, 32, 64, 128]
depth: [512, 512, 512, 512, 256, 128]
@@ -30,6 +34,7 @@ alpha_jumps: [0, 600, 600, 600, 600, 600]
alpha_size_jumps: [32, 32, 32, 32, 32, 32]
# optimizer and lr related
lr: 0.001
lr_list: []
# loss related
loss_scale_value: 12
scale_factor: 10
@@ -48,6 +53,36 @@ train_url: "Url for modelarts"
data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: 'Target device type'
gpu_distribute_training: 'Use more than 1 device to train the network. 1 - true, 0 - false. Used only for GPU'
enable_profiling: 'Whether to enable profiling while training, default: False'
dataset_name: "the name of the dataset being used" #unused
name: "the name of the dataset being used" #unused
resume_load_scale: "resume training from this scale, must be in [4, 8, 16, 32, 64, 128]"
batch_size: "batch size used for training unless the batch_size_list is specified"
batch_size_list: "if not empty, specifies the batch size for each input scale. Must have the same length as the scales list"
train_data_path: "path to folder with images for network training"
resume_check_d: "discriminator checkpoint path"
resume_check_g: "generator checkpoint path"
ckpt_save_dir: "folder to save checkpoints"
eval_img_save_dir: "folder to save validation images"
model_save_step: "save a checkpoint every this many steps"
save_ckpt_from_device_with_id: "save checkpoint from device with this id"
#network
scales: "generated image sizes"
depth: "input channel numbers for convolution in each layer"
num_batch: "Number of iterations for each image scale"
alpha_jumps: "do not change this parameter unless you know exactly what you are doing"
alpha_size_jumps: "do not change this parameter unless you know exactly what you are doing"
# optimizer and lr related
lr: "learning rate used for training unless the lr_list is specified"
lr_list: "if not empty, specifies the learning rate for each input scale. Must have the same length as the scales list"
# loss related
loss_scale_value: "parameter for model optimizer"
scale_factor: "parameter for model optimizer"
scale_window: "parameter for model optimizer"
# export option
ckpt_file: "" #unused
file_name: "file name for model export"
file_format: "choices in ['AIR', 'ONNX', 'MINDIR']"
Pillow
matplotlib >= 3.4.2
imageio >= 2.9.0
pyyaml
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "===================================================================================================="
echo "Please run the script as: "
echo "bash script/run_distribute_train_gpu.sh CONFIG_PATH DATASET_PATH"
echo "for example: bash script/run_distribute_train_gpu.sh /path/to/gpu_config.yaml /path/to/dataset/images"
echo "===================================================================================================="
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
realpath -m "$PWD/$1"
fi
}
CONFIG_PATH=$(get_real_path "$1")
DATASET_PATH=$(get_real_path "$2")
if [ -d logs ]
then
rm -rf logs
fi
mkdir ./logs
echo "Start OpenMPI"
export RANK_SIZE=8
export DEVICE_NUM=$RANK_SIZE
mpirun --output-filename logs -merge-stderr-to-stdout \
-np $DEVICE_NUM --allow-run-as-root \
python ./train.py --config_path "$CONFIG_PATH" \
--train_data_path "$DATASET_PATH" \
> ./logs/train.log 2>&1 &
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ]
then
echo "Usage: bash script/run_eval_gpu.sh [CKPT_PATH] [DEVICE_ID] [MEASURE_MSSIM] [DATASET_DIR]"
exit 1
fi
export CKPT=$1
export DEVICE_ID=$2
export MEASURE_MSSIM=$3
export DATASET_DIR=$4
python -u eval.py \
--checkpoint_g="$CKPT" \
--device_target GPU \
--device_id="$DEVICE_ID" \
--measure_ms_ssim="$MEASURE_MSSIM" \
--original_img_dir="$DATASET_DIR" \
> eval.log 2>&1 &
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "===================================================================================================="
echo "Please run the script as: "
echo "bash script/run_standalone_train_gpu.sh CONFIG_PATH DATASET_PATH"
echo "for example: bash script/run_standalone_train_gpu.sh /path/to/gpu_config.yaml /path/to/dataset/images"
echo "===================================================================================================="
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
realpath -m "$PWD/$1"
fi
}
CONFIG_PATH=$(get_real_path "$1")
DATASET_PATH=$(get_real_path "$2")
if [ -d logs ]
then
rm -rf logs
fi
mkdir logs
python ./train.py --config_path "$CONFIG_PATH" --train_data_path "$DATASET_PATH" > ./logs/train.log 2>&1 &
@@ -14,6 +14,7 @@
# ============================================================================
"""image transform"""
import os

import numpy as np
from PIL import Image
@@ -36,6 +37,22 @@ class NumpyResize():
        return np_image


class Crop():
    """Crop"""

    def __init__(self, cx=89, cy=121):
        self.cx = cx
        self.cy = cy

    def __call__(self, img):
        r"""
        Args:
            img (np array): image to be cropped
        Returns:
            np.Array: cropped image
        """
        return np.array(img)[self.cy - 64: self.cy + 64, self.cx - 64: self.cx + 64]


class Resize():
    """Resize"""
    def __init__(self, size):
......
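For reference, a minimal sketch of how the new `Crop` transform can be used (the sample file name is hypothetical): it cuts a fixed 128x128 window centered at (cx, cy) = (89, 121), which matches the face position in the aligned CelebA images.

```python
import numpy as np
from PIL import Image

from src.image_transform import Crop

crop = Crop()                                    # default center (89, 121)
img = Image.open("img_align_celeba/000001.jpg")  # aligned CelebA image, 178x218
face = crop(img)                                 # numpy array of shape (128, 128, 3)
print(face.shape)
```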
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metric"""
import numpy as np
from scipy import signal
def FSpecialGauss(size, sigma):
    """Function to mimic the 'fspecial' gaussian MATLAB function."""
    radius = size // 2
    offset = 0.0
    start, stop = -radius, radius + 1
    if size % 2 == 0:
        offset = 0.5
        stop -= 1
    x, y = np.mgrid[offset + start:stop, offset + start:stop]
    assert len(x) == size
    g = np.exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2)))
    return g / g.sum()
def SSIMForMultiScale(img1, img2, max_val=255, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
    """Return the Structural Similarity Map between `img1` and `img2`.

    This function attempts to match the functionality of ssim_index_new.m by
    Zhou Wang: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip

    Arguments:
        img1: Numpy array holding the first RGB image batch.
        img2: Numpy array holding the second RGB image batch.
        max_val: the dynamic range of the images (i.e., the difference between
            the maximum and the minimum allowed values).
        filter_size: Size of blur kernel to use (will be reduced for small images).
        filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced
            for small images).
        k1: Constant used to maintain stability in the SSIM calculation (0.01 in
            the original paper).
        k2: Constant used to maintain stability in the SSIM calculation (0.03 in
            the original paper).

    Returns:
        Pair containing the mean SSIM and contrast sensitivity between `img1` and
        `img2`.

    Raises:
        RuntimeError: If input images don't have the same shape or don't have four
            dimensions: [batch_size, height, width, depth].
    """
    if img1.shape != img2.shape:
        raise RuntimeError('Input images must have the same shape (%s vs. %s).' % (img1.shape, img2.shape))
    if img1.ndim != 4:
        raise RuntimeError('Input images must have four dimensions, not %d' % img1.ndim)
    img1 = img1.astype(np.float32)
    img2 = img2.astype(np.float32)
    _, height, width, _ = img1.shape
    # Filter size can't be larger than height or width of images.
    size = min(filter_size, height, width)
    # Scale down sigma if a smaller filter size is used.
    sigma = size * filter_sigma / filter_size if filter_size else 0
    if filter_size:
        window = np.reshape(FSpecialGauss(size, sigma), (1, size, size, 1))
        mu1 = signal.fftconvolve(img1, window, mode='valid')
        mu2 = signal.fftconvolve(img2, window, mode='valid')
        sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid')
        sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid')
        sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid')
    else:
        # Empty blur kernel, so no need to convolve.
        mu1, mu2 = img1, img2
        sigma11 = img1 * img1
        sigma22 = img2 * img2
        sigma12 = img1 * img2
    mu11 = mu1 * mu1
    mu22 = mu2 * mu2
    mu12 = mu1 * mu2
    sigma11 -= mu11
    sigma22 -= mu22
    sigma12 -= mu12
    # Calculate intermediate values used by both ssim and cs_map.
    c1 = (k1 * max_val) ** 2
    c2 = (k2 * max_val) ** 2
    v1 = 2.0 * sigma12 + c2
    v2 = sigma11 + sigma22 + c2
    ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2)),
                   axis=(1, 2, 3))  # Return for each image individually.
    cs = np.mean(v1 / v2, axis=(1, 2, 3))
    return ssim, cs
def HoxDownsample(img):
    """Box-downsample a batch by a factor of 2: average each 2x2 pixel block."""
    return (img[:, 0::2, 0::2, :] + img[:, 1::2, 0::2, :] + img[:, 0::2, 1::2, :] + img[:, 1::2, 1::2, :]) * 0.25
def msssim(img1, img2, max_val=255, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03, weights=None):
    """Return the MS-SSIM score between `img1` and `img2`.

    This function implements Multi-Scale Structural Similarity (MS-SSIM) Image
    Quality Assessment according to Zhou Wang's paper, "Multi-scale structural
    similarity for image quality assessment" (2003).
    Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf
    Author's MATLAB implementation:
    http://www.cns.nyu.edu/~lcv/ssim/msssim.zip

    Arguments:
        img1: Numpy array holding the first RGB image batch.
        img2: Numpy array holding the second RGB image batch.
        max_val: the dynamic range of the images (i.e., the difference between
            the maximum and the minimum allowed values).
        filter_size: Size of blur kernel to use (will be reduced for small images).
        filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced
            for small images).
        k1: Constant used to maintain stability in the SSIM calculation (0.01 in
            the original paper).
        k2: Constant used to maintain stability in the SSIM calculation (0.03 in
            the original paper).
        weights: List of weights for each level; if none, use five levels and the
            weights from the original paper.

    Returns:
        MS-SSIM score between `img1` and `img2`.

    Raises:
        RuntimeError: If input images don't have the same shape or don't have four
            dimensions: [batch_size, height, width, depth].
    """
    if img1.shape != img2.shape:
        raise RuntimeError('Input images must have the same shape (%s vs. %s).' % (img1.shape, img2.shape))
    if img1.ndim != 4:
        raise RuntimeError('Input images must have four dimensions, not %d' % img1.ndim)
    # Note: default weights don't sum to 1.0 but do match the paper / matlab code.
    weights = np.array(weights if weights else [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
    levels = weights.size
    im1, im2 = [x.astype(np.float32) for x in [img1, img2]]
    mssim = []
    mcs = []
    for _ in range(levels):
        ssim, cs = SSIMForMultiScale(
            im1, im2, max_val=max_val, filter_size=filter_size,
            filter_sigma=filter_sigma, k1=k1, k2=k2)
        mssim.append(ssim)
        mcs.append(cs)
        im1, im2 = [HoxDownsample(x) for x in [im1, im2]]
    # Clip to zero. Otherwise we get NaNs.
    mssim = np.clip(np.asarray(mssim), 0.0, np.inf)
    mcs = np.clip(np.asarray(mcs), 0.0, np.inf)
    # Average over images only at the end.
    return np.mean(np.prod(mcs[:-1, :] ** weights[:-1, np.newaxis], axis=0) * (mssim[-1, :] ** weights[-1]))
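A minimal usage sketch of `msssim` (random data for illustration only): both inputs are equally shaped uint8 image batches in [batch, height, width, channels] layout.

```python
import numpy as np

from src.metric import msssim

rng = np.random.default_rng(0)
batch_a = rng.integers(0, 256, size=(8, 128, 128, 3), dtype=np.uint8)
batch_b = rng.integers(0, 256, size=(8, 128, 128, 3), dtype=np.uint8)

# Identical batches score 1.0; unrelated batches score much lower.
print(msssim(batch_a, batch_a))  # ~1.0
print(msssim(batch_a, batch_b))
```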
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,11 +13,36 @@
# limitations under the License.
# ============================================================================
"""Dnet define"""
from mindspore import context
from mindspore import dtype as mstype
from mindspore import ops, nn

from src.customer_layer import EqualizedConv2d, EqualizedLinear, num_flat_features
class CustomMaxPool2x2(nn.Cell):
    """CustomMaxPool2x2"""

    def __init__(self):
        super().__init__()
        self.kernel_size = 2
        self.stride = 2
        self.max_op = ops.ReduceMax()

    def construct(self, x):
        """CustomMaxPool2x2
        Returns:
            output.
        """
        shape = x.shape
        batch, channel, height, width = shape
        x_reshaped = x.reshape((batch, channel, height // 2, 2, width // 2, 2))
        x_pooled = self.max_op(x_reshaped, 5)
        x_pooled = self.max_op(x_pooled, 3)
        return x_pooled
class DNet4_4_Train(nn.Cell):
    """DNet4_4_Train"""
@@ -28,7 +53,7 @@ class DNet4_4_Train(nn.Cell):
                 dimInput=3):
        super(DNet4_4_Train, self).__init__()
        self.dimInput = dimInput
        self.depthScale0 = depthScale0
        self.fromRGBLayers = EqualizedConv2d(dimInput, depthScale0, 1, padding=0, pad_mode="same", has_bias=True)
        self.dimEntryScale0 = depthScale0
        self.groupScale0 = EqualizedConv2d(self.dimEntryScale0, depthScale0, 3, padding=1, pad_mode="pad",
@@ -104,7 +129,12 @@ class DNetNext_Train(nn.Cell):
        self.groupScale0 = EqualizedConv2d(depthNewScale, depthNewScale, 3, padding=1, pad_mode="pad", has_bias=True)
        self.groupScale1 = EqualizedConv2d(depthNewScale, depthLastScale, 3, padding=1, pad_mode="pad", has_bias=True)
        self.leakyRelu = nn.LeakyReLU(leakyReluLeak)
        if context.get_context('device_target') != "GPU":
            self.avgPool2d = ops.MaxPool(kernel_size=2, strides=2)
        else:
            # GPU does not support the second derivative for MaxPool (MaxPoolGradGrad).
            # Thus, we use our custom implementation based on the reshape and ReduceMax operations.
            self.avgPool2d = CustomMaxPool2x2()
        self.cast = ops.Cast()
    def construct(self, x, alpha=0):
@@ -113,14 +143,14 @@ class DNetNext_Train(nn.Cell):
        Returns:
            output.
        """
        mid = self.cast(x, mstype.float16)
        y = self.avgPool2d(mid)
        y = self.cast(y, mstype.float32)
        y = self.leakyRelu(self.last_fromRGBLayers(y))
        x = self.leakyRelu(self.fromRGBLayers(x))
        x = self.leakyRelu(self.groupScale0(x))
        x = self.leakyRelu(self.groupScale1(x))
        x = self.cast(x, mstype.float16)
        x = self.avgPool2d(x)
        x = alpha * y + (1 - alpha) * x
        out = self.last_Dnet(x)
@@ -140,7 +170,12 @@ class DNetNext_Last(nn.Cell):
        self.groupScale0 = dNetNext_Train.groupScale0
        self.groupScale1 = dNetNext_Train.groupScale1
        self.leakyRelu = dNetNext_Train.leakyRelu
        if context.get_context('device_target') != "GPU":
            self.avgPool2d = ops.MaxPool(kernel_size=2, strides=2)
        else:
            # GPU does not support the second derivative for MaxPool (MaxPoolGradGrad).
            # Thus, we use our custom implementation based on the reshape and ReduceMax operations.
            self.avgPool2d = CustomMaxPool2x2()
        self.cast = ops.Cast()

    def construct(self, x):
@@ -151,8 +186,8 @@ class DNetNext_Last(nn.Cell):
        """
        x = self.leakyRelu(self.groupScale0(x))
        x = self.leakyRelu(self.groupScale1(x))
        x = self.cast(x, mstype.float16)
        x = self.avgPool2d(x)
        x = self.cast(x, mstype.float32)
        out = self.last_Dnet(x)
        return out
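As a sanity check for the reshape-based pooling in `CustomMaxPool2x2`, here is a minimal NumPy sketch of the same 2x2/stride-2 max-pool trick (illustration only, independent of MindSpore):

```python
import numpy as np

def maxpool2x2(x):
    """2x2, stride-2 max pooling over an NCHW batch: split H and W into
    (H//2, 2) and (W//2, 2) blocks, then reduce over both block axes,
    mirroring the reshape + ReduceMax steps of CustomMaxPool2x2."""
    b, c, h, w = x.shape
    return x.reshape(b, c, h // 2, 2, w // 2, 2).max(axis=(3, 5))

x = np.random.randn(1, 3, 8, 8).astype(np.float32)
print(maxpool2x2(x).shape)  # (1, 3, 4, 4)
```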
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TimeMonitor class."""
import time
class TimeMonitor:
"""
Monitor the time in training.
steps_log_interval (int): Number of steps between logging the intermediate performance values.
"""
def __init__(self, steps_log_interval):
self.num_steps = 0
self.num_epochs = 0
self.steps_log_interval = steps_log_interval
self.epoch_time = None
self.step_start_time = None
self.data_iter_time_mark = None
self.steps_accumulated_time = 0
self.data_iter_accumulated_time = 0
def epoch_begin(self):
"""
Record time at the begin of epoch.
"""
self.num_steps = 0
self.steps_accumulated_time = 0
self.data_iter_accumulated_time = 0
self.epoch_time = time.time()
self.data_iter_time_mark = self.epoch_time
def step_start(self):
self.step_start_time = time.time()
self.data_iter_accumulated_time += self.step_start_time - self.data_iter_time_mark
if self.num_steps == 0:
print(f'Dataset first iteration time: {self.data_iter_accumulated_time * 1000:5.3f} ms', flush=True)
def data_iter_end(self):
"""Record the time of the data iteration end
(for computing the data loader time)
"""
self.data_iter_time_mark = time.time()
def step_end(self):
"""Step end callback"""
self.num_steps += 1
self.steps_accumulated_time += time.time() - self.step_start_time
if self.num_steps % self.steps_log_interval == 0:
print(
f'Intermediate: epoch {self.num_epochs} step {self.num_steps}, '
f'per_step_time {self.steps_accumulated_time / self.steps_log_interval * 1000:5.3f} ms, '
f'(not including the data loader time per step '
f'{self.data_iter_accumulated_time / self.steps_log_interval * 1000:5.3f} ms)',
flush=True,
)
self.steps_accumulated_time = 0
self.data_iter_accumulated_time = 0
def epoch_end(self):
"""
Print process cost time at the end of epoch.
"""
if self.epoch_time is None:
return
epoch_seconds = (time.time() - self.epoch_time) * 1000
if not isinstance(self.num_steps, int) or self.num_steps < 1:
raise ValueError("data_size must be positive int.")
step_seconds = epoch_seconds / self.num_steps
print(f"epoch {self.num_epochs} time: {epoch_seconds:5.3f} ms, "
f"per step time: {step_seconds:5.3f} ms")
self.num_epochs += 1
self.epoch_time = None
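A minimal usage sketch of `TimeMonitor` in a hand-written training loop (the loop body is a placeholder for the real forward/backward pass):

```python
import time

from src.time_monitor import TimeMonitor

monitor = TimeMonitor(steps_log_interval=10)
for epoch in range(2):
    monitor.epoch_begin()
    for batch in range(100):
        monitor.data_iter_end()   # the next batch has just been fetched
        monitor.step_start()
        time.sleep(0.001)         # placeholder for the training step
        monitor.step_end()
    monitor.epoch_end()
```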
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,30 +15,42 @@
"""
#################train pgan########################
"""
import datetime
import os
import pathlib

import numpy as np
from mindspore import Tensor
from mindspore import context
from mindspore import dataset as ds
from mindspore import dtype as mstype
from mindspore import load_checkpoint, load_param_into_net, save_checkpoint
from mindspore import nn
from mindspore.common import set_seed
from mindspore.communication.management import init, get_group_size, get_rank
from mindspore.context import ParallelMode

from model_utils.config import config
from model_utils.device_adapter import get_device_id, get_device_num
from model_utils.moxing_adapter import moxing_wrapper
from src.dataset import ImageDataset
from src.image_transform import Normalize, NumpyResize, TransporeAndDiv, Crop
from src.network_D import DNet4_4_Train, DNetNext_Train, DNet4_4_Last, DNetNext_Last
from src.network_G import GNet4_4_Train, GNet4_4_last, GNetNext_Train, GNetNext_Last
from src.optimizer import AllLossD, AllLossG
from src.time_monitor import TimeMonitor


def set_every(num):
    """set random seed"""
    set_seed(num)
    np.random.seed(num)


set_every(1)
ds.config.set_prefetch_size(16)


def _get_rank_info():
    """
    get rank size and rank id
@@ -51,36 +63,55 @@ def _get_rank_info():
        rank_size = rank_id = None
    return rank_size, rank_id
def cell_deepcopy(gnet, avg_gnet):
    """cell_deepcopy"""
    for param, avg_param in zip(gnet.trainable_params(),
                                avg_gnet.trainable_params()):
        avg_param.set_data(param.clone())


def cell_deepcopy_update(gnet, avg_gnet):
    """cell_deepcopy_update"""
    for param, avg_param in zip(gnet.trainable_params(),
                                avg_gnet.trainable_params()):
        new_p = avg_param * 0.999 + param * 0.001
        avg_param.set_data(new_p)
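`cell_deepcopy_update` maintains the exponential running average of the generator weights mentioned in the README (avg = 0.999 * avg + 0.001 * current after each step); the `AvG_*.ckpt` checkpoints used for evaluation come from this averaged copy. A minimal NumPy sketch of the same update rule (illustration only):

```python
import numpy as np

def ema_update(avg_params, params, decay=0.999):
    """One exponential-moving-average step over parallel lists of arrays."""
    return [decay * a + (1.0 - decay) * p for a, p in zip(avg_params, params)]

avg = [np.zeros(3)]
for _ in range(1000):                  # pretend training steps
    avg = ema_update(avg, [np.ones(3)])
print(avg[0])                          # approaches 1.0 over time
```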


def save_checkpoints(avg, gnet, dnet, ckpt_dir, i_batch):
    """Save the generator, averaged generator and discriminator checkpoints."""
    pathlib.Path(ckpt_dir).mkdir(parents=True, exist_ok=True)
    save_checkpoint(gnet, os.path.join(ckpt_dir, f"G_{i_batch}.ckpt"))
    save_checkpoint(avg, os.path.join(ckpt_dir, f"AvG_{i_batch}.ckpt"))
    save_checkpoint(dnet, os.path.join(ckpt_dir, f"D_{i_batch}.ckpt"))


def load_checkpoints(gnet, dnet, cfg):
    """Load generator and discriminator weights to resume training."""
    param_dict_g = load_checkpoint(cfg.resume_check_g)
    param_dict_d = load_checkpoint(cfg.resume_check_d)
    load_param_into_net(gnet, param_dict_g)
    load_param_into_net(dnet, param_dict_d)
    return gnet, dnet
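# With the default config (ckpt_save_dir: "./checkpoint"), a run at scale 64
# therefore produces files such as (the batch index here is illustrative):
#
#   ./checkpoint/64/G_9999.ckpt
#   ./checkpoint/64/AvG_9999.ckpt
#   ./checkpoint/64/D_9999.ckpt
#
# and, in a multi-device run, a rank_<id> subfolder per device.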


def modelarts_pre_process():
    """modelarts pre process function."""
    config.ckpt_save_dir = os.path.join(config.output_path, config.ckpt_save_dir)


def get_dataset(args, size=None):
    """Build the training image dataset resized to the given (H, W)."""
    transform_list = [Crop(), NumpyResize(size), TransporeAndDiv(),
                      Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return ImageDataset(args.train_data_path, transform=transform_list)
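# The transforms run in list order: Crop, then NumpyResize to the target
# (scale, scale) shape, then TransporeAndDiv (by its name, an HWC -> CHW
# transpose plus a division that brings pixels into [0, 1] -- an assumption,
# the class lives in src.image_transform), and finally Normalize with
# mean/std 0.5, which maps [0, 1] inputs to roughly [-1, 1].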


def cal_each_batch_alpha():
    """Precompute the blending factor alpha for every batch at every scale."""
    each_batch_alpha = []
    # ... (the per-scale schedule is elided in the diff view; it is built
    #      from config.alpha_jumps and config.alpha_size_jumps) ...
    each_batch_alpha.append(new_batch_alpha)
    return each_batch_alpha


def get_optimize_d(dnet, lr, cfg):
    """Build the loss-scaled one-step training cell for the discriminator."""
    manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** cfg.loss_scale_value,
                                            scale_factor=cfg.scale_factor, scale_window=cfg.scale_window)
    loss_cell = AllLossD(dnet)
    opti = nn.Adam(dnet.trainable_params(), beta1=0.0001, beta2=0.99, learning_rate=lr)
    train_network = nn.TrainOneStepWithLossScaleCell(loss_cell, opti, scale_sense=manager)
    return train_network
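# Dynamic loss scaling guards the half-precision gradients: the scale starts
# at 2 ** loss_scale_value (2 ** 12 = 4096 with the default config), is
# divided by scale_factor whenever an overflow is detected, and is multiplied
# by scale_factor again after scale_window consecutive overflow-free steps.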


def get_optimizer_g(gnet, dnet, args):
    """Build the loss-scaled one-step training cell for the generator."""
    manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** args.loss_scale_value,
                                            scale_factor=args.scale_factor, scale_window=args.scale_window)
    loss_cell = AllLossG(gnet, dnet)
    opti = nn.Adam(gnet.trainable_params(),
                   beta1=0.0001, beta2=0.99, learning_rate=args.lr)
    train_network = nn.TrainOneStepWithLossScaleCell(loss_cell, opti, scale_sense=manager)
    return train_network


def build_noise_data(n_samples):
    """Sample a batch of 512-dimensional latent vectors from N(0, 1)."""
    input_latent = np.random.randn(n_samples, 512)
    input_latent = Tensor(input_latent, mstype.float32)
    return input_latent
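# Usage: build_noise_data(16) returns a float32 Tensor of shape (16, 512),
# one latent vector per image the generator is asked to synthesize.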


def prepare_context(cfg):
    """Configure the device and the (optional) data-parallel context."""
    if cfg.device_target == "Ascend":
        device_id = get_device_id()
        context.set_context(device_id=device_id)
        if cfg.device_num > 1:
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=cfg.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)
            init()
    if cfg.device_target == "GPU":
        device_id = get_device_id()
        context.set_context(device_id=device_id)
        if cfg.device_num > 1:
            init()
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=cfg.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)
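# On GPU, init() brings up the NCCL communication backend, so a multi-device
# run is normally launched through mpirun (OpenMPI); on Ascend it initializes
# HCCL instead. A hypothetical two-GPU launch (exact flags depend on the
# run scripts shipped with the model) could look like:
#
#   mpirun -n 2 python train.py
#
# with gpu_distribute_training set to 1 in the config, per its help text.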


def construct_network(gnet, dnet, avg_gnet, depth, scale):
    """Grow the generator/discriminator pair by one resolution step."""
    if scale == 4:
        dnet = DNet4_4_Train(depth, leakyReluLeak=0.2, sizeDecisionLayer=1, dimInput=3)
        gnet = GNet4_4_Train(512, depth, leakyReluLeak=0.2, dimOutput=3)
        avg_gnet = GNet4_4_Train(512, depth, leakyReluLeak=0.2, dimOutput=3)
    elif scale == 8:
        last_dnet = DNet4_4_Last(dnet)
        last_gnet = GNet4_4_last(gnet)
        dnet = DNetNext_Train(depth, last_Dnet=last_dnet, leakyReluLeak=0.2, dimInput=3)
        gnet = GNetNext_Train(depth, last_Gnet=last_gnet, leakyReluLeak=0.2, dimOutput=3)
        last_avg_gnet = GNet4_4_last(avg_gnet)
        avg_gnet = GNetNext_Train(depth, last_Gnet=last_avg_gnet, leakyReluLeak=0.2, dimOutput=3)
    else:
        last_dnet = DNetNext_Last(dnet)
        last_gnet = GNetNext_Last(gnet)
        dnet = DNetNext_Train(depth, last_Dnet=last_dnet, leakyReluLeak=0.2, dimInput=3)
        gnet = GNetNext_Train(depth, last_gnet, leakyReluLeak=0.2, dimOutput=3)
        last_avg_gnet = GNetNext_Last(avg_gnet)
        avg_gnet = GNetNext_Train(depth, last_avg_gnet, leakyReluLeak=0.2, dimOutput=3)
    cell_deepcopy(gnet, avg_gnet)
    return gnet, dnet, avg_gnet
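# This is the progressive-growing step of PGGAN: training starts with 4x4
# networks, and at each new scale the previous networks are wrapped into
# their *_Last form and embedded inside a larger *Next_Train network, so the
# already-trained layers are reused while a new resolution block is added.
# The freshly built avg_gnet is synchronized with gnet via cell_deepcopy.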


@moxing_wrapper(pre_process=modelarts_pre_process)
def run_train():
    """run_train"""
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
    config.device_num = get_device_num()
    prepare_context(config)
    if config.lr_list:
        if len(config.lr_list) != len(config.scales):
            raise ValueError("len(config.lr_list) and len(config.scales) must be the same")
    else:
        config.lr_list = [config.lr] * len(config.scales)
    if config.batch_size_list:
        if len(config.batch_size_list) != len(config.scales):
            raise ValueError("len(config.batch_size_list) and len(config.scales) must be the same")
    else:
        config.batch_size_list = [config.batch_size] * len(config.scales)
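    # Per-scale overrides are optional; when given they must have one entry
    # per scale. A hypothetical example for the six default scales:
    #
    #   lr_list: [0.001, 0.001, 0.001, 0.001, 0.0008, 0.0005]
    #   batch_size_list: [16, 16, 16, 16, 16, 8]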
    gnet, dnet, avg_gnet = None, None, None
    each_batch_alpha = cal_each_batch_alpha()
    time_monitor = TimeMonitor(200)
    for scale_index, scale in enumerate(config.scales):
        print('Scale', scale, flush=True)
        this_scale_checkpoint_path = os.path.join(config.ckpt_save_dir, str(scale))

        gnet, dnet, avg_gnet = construct_network(gnet, dnet, avg_gnet, config.depth[scale_index], scale)
        if config.resume_load_scale != -1 and scale <= config.resume_load_scale:
            # skip scales below the resume point; restore weights at the resume scale
            if scale == config.resume_load_scale:
                gnet, dnet = load_checkpoints(gnet, dnet, config)
            continue

        optimizer_d = get_optimize_d(dnet, config.lr_list[scale_index], config)
        optimizer_g = get_optimizer_g(gnet, dnet, config)
        rank_size, rank_id = _get_rank_info()
        if rank_id:
            this_scale_checkpoint_path = os.path.join(this_scale_checkpoint_path, f"rank_{rank_id}")

        db_loader = get_dataset(config, (scale, scale))
        dataset = ds.GeneratorDataset(db_loader, column_names=["data", "label"], shuffle=True,
                                      num_parallel_workers=4, num_shards=rank_size, shard_id=rank_id)
        dataset = dataset.batch(batch_size=config.batch_size_list[scale_index], drop_remainder=True)
        dataset_iter = dataset.create_tuple_iterator()
        print('Dataset size', dataset.get_dataset_size(), flush=True)
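        # With num_shards/shard_id set, each device reads a disjoint
        # 1/rank_size share of the images, and dataset.get_dataset_size()
        # reports the number of batches per epoch for this shard.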
        i_batch = 0
        while i_batch < config.num_batch[scale_index] / config.device_num:
            time_monitor.epoch_begin()
            for data in dataset_iter:
                time_monitor.step_start()
                alpha = each_batch_alpha[scale_index][i_batch]
                alpha = Tensor(alpha, mstype.float32)
                inputs_real = data[0]
                n_samples = inputs_real.shape[0]
                fake_image = gnet(build_noise_data(n_samples), alpha)
                loss_d, overflow, _ = optimizer_d(inputs_real, fake_image.copy(), alpha)
                loss_g, overflow, _ = optimizer_g(build_noise_data(n_samples), alpha)
                cell_deepcopy_update(gnet=gnet, avg_gnet=avg_gnet)
                i_batch += 1
                time_monitor.step_end()
                if i_batch >= config.num_batch[scale_index] / config.device_num:
                    break
                if i_batch % 100 == 0:
                    time_now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
                    print(f'batch_i:{i_batch} alpha:{alpha} loss G:{loss_g} '
                          f'loss D:{loss_d} overflow:{overflow} time:{time_now}')
                if (i_batch + 1) % config.model_save_step == 0:
                    save_checkpoints(avg_gnet, gnet, dnet, this_scale_checkpoint_path, i_batch)
            time_monitor.data_iter_end()
            time_monitor.epoch_end()
        save_checkpoints(avg_gnet, gnet, dnet, this_scale_checkpoint_path, i_batch)
        dataset.close_pool()


if __name__ == '__main__':
    run_train()