diff --git a/research/cv/wgan_gp/README_CN.md b/research/cv/wgan_gp/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..2396ebd717179e96bdf79faae21da595a2c90bee
--- /dev/null
+++ b/research/cv/wgan_gp/README_CN.md
@@ -0,0 +1,184 @@
+# Contents
+
+<!-- TOC -->
+
+- [Contents](#contents)
+- [WGAN-GP Description](#wgan-gp-description)
+- [Model Architecture](#model-architecture)
+- [Dataset](#dataset)
+- [Environment Requirements](#environment-requirements)
+- [Quick Start](#quick-start)
+- [Script Description](#script-description)
+    - [Script and Sample Code](#script-and-sample-code)
+    - [Script Parameters](#script-parameters)
+    - [Training Process](#training-process)
+        - [Standalone Training](#standalone-training)
+    - [Inference Process](#inference-process)
+        - [Inference](#inference)
+- [Model Description](#model-description)
+    - [Performance](#performance)
+        - [Training Performance](#training-performance)
+- [Description of Random Situation](#description-of-random-situation)
+- [ModelZoo Homepage](#modelzoo-homepage)
+
+<!-- /TOC -->
+
+# WGAN-GP Description
+
+WGAN-GP (Wasserstein GAN with Gradient Penalty) is a generative adversarial network whose discriminator (critic) and generator both follow the DCGAN architecture. Building on WGAN, it replaces weight clipping with a gradient penalty: the critic loss gains a term that penalizes deviations of the norm of the critic's gradient, taken with respect to its input, from 1, thereby enforcing the Lipschitz constraint. This resolves WGAN's occasional failure to converge and its poor sample quality.
+
+[Paper](https://arxiv.org/pdf/1704.00028v3.pdf): Improved Training of Wasserstein GANs
+
+# Model Architecture
+
+The WGAN-GP network consists of two parts: a generator network and a discriminator network. The discriminator uses the convolutional DCGAN architecture, i.e. stacked 2D convolution layers; the generator uses the DCGAN generator structure. The inputs are real images and noise: the real CIFAR-10 images are resized to 32*32, and the noise is randomly generated.
+
+# Dataset
+
+[CIFAR-10](<http://www.cs.toronto.edu/~kriz/cifar.html>)
+
+- Dataset size: 175M, 60,000 color images in 10 classes
+    - Training set: 146M, 50,000 images.
+    - Note: for generative adversarial networks, inference feeds noise into the generator to produce images, so the test set is not used.
+- Data format: binary files
+
+# Environment Requirements
+
+- Hardware (Ascend)
+    - Prepare a hardware environment with Ascend processors.
+- Framework
+    - [MindSpore](https://www.mindspore.cn/install)
+- For more information, please check the resources below:
+    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/docs/zh-CN/master/index.html)
+
+# Quick Start
+
+After installing MindSpore via the official website, you can start training and evaluation as follows:
+
+- Running on Ascend
+
+  ```bash
+  # run the standalone training example:
+  bash run_train.sh [DATAROOT] [DEVICE_ID]
+
+  # run the evaluation example
+  bash run_eval.sh [DEVICE_ID] [CKPT_FILE_PATH] [OUTPUT_DIR]
+  ```
+
+# Script Description
+
+## Script and Sample Code
+
+```bash
+├── model_zoo
+    ├── README.md               // descriptions about all the models
+    ├── WGAN-GP
+        ├── README_CN.md        // descriptions about WGAN-GP
+        ├── scripts
+        │   ├── run_train.sh    // shell script for standalone training on Ascend
+        │   ├── run_eval.sh     // shell script for evaluation on Ascend
+        ├── src
+        │   ├── dataset.py      // dataset creation and preprocessing
+        │   ├── model.py        // WGAN-GP generator and discriminator definition
+        │   ├── args.py         // argument configuration
+        │   ├── cell.py         // one-step training cells
+        ├── train.py            // training script
+        ├── eval.py             // evaluation script
+```
+
+## Script Parameters
+
+Training, evaluation and export parameters are all configured in args.py.
+
+  ```python
+  # common_config
+  'device_target': 'Ascend', # device where the code runs
+  'device_id': 0,            # device ID used to train or evaluate the dataset
+
+  # train_config
+  'dataroot': None,          # dataset path (required, must not be empty)
+  'workers': 8,              # number of data loading workers
+  'batchSize': 64,           # batch size
+  'imageSize': 32,           # image size
+  'DIM': 128,                # hidden dimension of the GAN network
+  'niter': 1200,             # number of training epochs
+  'save_iterations': 1000,   # save checkpoints every N generator iterations
+  'lrD': 0.0001,             # initial learning rate of the discriminator
+  'lrG': 0.0001,             # initial learning rate of the generator
+  'beta1': 0.5,              # beta1 of the Adam optimizer
+  'beta2': 0.9,              # beta2 of the Adam optimizer
+  'netG': '',                # generator ckpt file path to resume training
+  'netD': '',                # discriminator ckpt file path to resume training
+  'Diters': 5,               # number of discriminator iterations per generator iteration
+  'experiment': None,        # where to save models and generated images; a default path is used if unspecified
+
+  # eval_config
+  'ckpt_file': None,         # path of the generator .ckpt file saved during training (required)
+  'output_dir': None,        # output path for generated images (required)
+  ```
+
+For more configuration details, please refer to the script `args.py`.
+
+## Training Process
+
+### Standalone Training
+
+- Running on Ascend
+
+  ```bash
+  bash run_train.sh [DATAROOT] [DEVICE_ID]
+  ```
+
+  The command above runs in the background; you can view the results in the train/train.log file.
+
+  After training, you can find the generated images and checkpoint files in the output folder (./samples by default). The loss values are printed as follows:
+
+  ```bash
+  [0/1200][230/937][23] Loss_D: -379.555344 Loss_G: -33.761238
+  [0/1200][235/937][24] Loss_D: -214.557617 Loss_G: -23.762344
+  ...
+  ```
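+
+  The two values correspond to the losses defined in `src/cell.py`: `Loss_D = mean(D(fake)) - mean(D(real)) + LAMBDA * GP` and `Loss_G = -mean(D(fake))`. A minimal, self-contained sketch of the gradient-penalty term GP (NumPy only; the function name is illustrative):
+
+  ```python
+  import numpy as np
+
+  def gradient_penalty(grads):
+      """grads: (N, C*H*W) critic gradients at points interpolated between real and fake samples."""
+      return np.mean((np.linalg.norm(grads, axis=1) - 1) ** 2)
+
+  # for a perfectly 1-Lipschitz critic every per-sample gradient norm is 1, so the penalty vanishes:
+  print(gradient_penalty(np.ones((4, 8)) / np.sqrt(8)))  # 0.0
+  ```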
+
+## Inference Process
+
+### Inference
+
+- Evaluation on Ascend
+
+  Before running the command below, please check the checkpoint file path used for inference and set the output path for the generated images.
+
+  ```bash
+  bash run_eval.sh [DEVICE_ID] [CKPT_FILE_PATH] [OUTPUT_DIR]
+  ```
+
+  The command above runs in the background; you can view the log in the eval/eval.log file and find the generated images in the output path.
+
+# Model Description
+
+## Performance
+
+### Training Performance
+
+| Parameters                | Ascend                                            |
+| ------------------------- | ------------------------------------------------- |
+| Resource                  | Ascend 910; CPU 2.60GHz, 192 cores; memory 755G   |
+| Uploaded Date             | 2022-08-01                                        |
+| MindSpore Version         | 1.8.0                                             |
+| Dataset                   | CIFAR-10                                          |
+| Training Parameters       | max_epoch=1200, batch_size=64, lr_init=0.0001     |
+| Optimizer                 | Adam                                              |
+| Loss Function             | custom loss function                              |
+| Outputs                   | generated images                                  |
+| Speed                     | 0.06 s/step (1 pc)                                |
+
+Sample generated images:
+
+![fake_samples](imgs/fake_samples_200000.png)
+
+# Description of Random Situation
+
+We set the random seed in train.py.
+
+# ModelZoo Homepage
+
+Please check the official [homepage](https://gitee.com/mindspore/models).
diff --git a/research/cv/wgan_gp/eval.py b/research/cv/wgan_gp/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c1a53dab7e0c9cc4721dc07a1ebd2acdcbd072f
--- /dev/null
+++ b/research/cv/wgan_gp/eval.py
@@ -0,0 +1,71 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Generate images with a trained WGAN-GP generator."""
+
+import mindspore as ms
+import mindspore.ops as ops
+from mindspore import Tensor
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore import context
+import numpy as np
+from PIL import Image
+
+from src.model import DcganG
+from src.args import get_args
+
+def save_image(img, img_path, IMAGE_SIZE):
+    """Tile a batch of images into an 8x8 grid and save it."""
+    mul = ops.Mul()
+    add = ops.Add()
+    if isinstance(img, Tensor):
+        # map values from [-1, 1] back to [0, 255]
+        img = mul(img, 255 * 0.5)
+        img = add(img, 255 * 0.5)
+        img = img.asnumpy().astype(np.uint8).transpose((0, 2, 3, 1))
+    elif not isinstance(img, np.ndarray):
+        raise ValueError("img should be Tensor or numpy array, but get {}".format(type(img)))
+
+    IMAGE_ROW = 8  # row num
+    IMAGE_COLUMN = 8  # column num
+    PADDING = 2  # interval between small pictures
+    to_image = Image.new('RGB', (IMAGE_COLUMN * IMAGE_SIZE + PADDING * (IMAGE_COLUMN + 1),
+                                 IMAGE_ROW * IMAGE_SIZE + PADDING * (IMAGE_ROW + 1)))  # create a new picture
+    # paste each small picture into the grid
+    ii = 0
+    for y in range(1, IMAGE_ROW + 1):
+        for x in range(1, IMAGE_COLUMN + 1):
+            from_image = Image.fromarray(img[ii])
+            to_image.paste(from_image, ((x - 1) * IMAGE_SIZE + PADDING * x, (y - 1) * IMAGE_SIZE + PADDING * y))
+            ii = ii + 1
+
+    to_image.save(img_path)  # save
+
+if __name__ == "__main__":
+
+    args_opt = get_args()
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
+    context.set_context(device_id=args_opt.device_id)
+
+    netG = DcganG(args_opt.DIM)
+
+    # load weights
+    load_param_into_net(netG, load_checkpoint(args_opt.ckpt_file))
+
+    # the noise dimension must match the generator input dimension DIM
+    fixed_noise = Tensor(np.random.normal(0, 1, size=[args_opt.batchSize, args_opt.DIM]), dtype=ms.float32)
+
+    fake = netG(fixed_noise)
+    save_image(fake, '{0}/generated_samples.png'.format(args_opt.output_dir), args_opt.imageSize)
+
+    print("Generate images success!")
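+
+# A hypothetical invocation (paths are placeholders); keep batchSize >= 64,
+# since save_image always tiles an 8x8 grid:
+#   python eval.py --device_id 0 --ckpt_file ./samples/debug_netG_giter_200000.ckpt --output_dir ./output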
diff --git a/research/cv/wgan_gp/imgs/fake_samples_200000.png b/research/cv/wgan_gp/imgs/fake_samples_200000.png
new file mode 100644
index 0000000000000000000000000000000000000000..7bd6f304c38021c7a4e39f279f4bc0e2b9a7fe04
Binary files /dev/null and b/research/cv/wgan_gp/imgs/fake_samples_200000.png differ
diff --git a/research/cv/wgan_gp/requirements.txt b/research/cv/wgan_gp/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..adf4746ea1793617fe8ca72a5622d958893d310a
--- /dev/null
+++ b/research/cv/wgan_gp/requirements.txt
@@ -0,0 +1,2 @@
+Pillow
+onnxruntime-gpu
\ No newline at end of file
diff --git a/research/cv/wgan_gp/scripts/run_eval.sh b/research/cv/wgan_gp/scripts/run_eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a34da980cc18bdb69ac90b6b2f6002b988772329
--- /dev/null
+++ b/research/cv/wgan_gp/scripts/run_eval.sh
@@ -0,0 +1,49 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash run_eval.sh device_id ckpt_file output_dir"
+echo "For example: bash run_eval.sh 0 /path/to/netG.ckpt /path/to/output"
+echo "It is better to use the absolute path."
+echo "=============================================================================================================="
+
+EXEC_PATH=$(pwd)
+echo "$EXEC_PATH"
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+cd ../
+rm -rf eval
+mkdir eval
+cd ./eval
+mkdir src
+cd ../
+cp ./*.py ./eval
+cp ./src/*.py ./eval/src
+cd ./eval
+
+env > env0.log
+
+echo "eval begin."
+python eval.py --device_id $1 --ckpt_file $2 --output_dir $3 > ./eval.log 2>&1 &
+
+if [ $? -eq 0 ];then
+    echo "eval started, check eval/eval.log for progress"
+else
+    echo "eval failed to start"
+    exit 2
+fi
+echo "finish"
+cd ../
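+
+# Note: checkpoints saved by train.py follow the pattern samples/debug_netG_giter_<N>.ckpt,
+# e.g. (hypothetical iteration number):
+#   bash run_eval.sh 0 /abs/path/samples/debug_netG_giter_200000.ckpt /abs/path/output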
diff --git a/research/cv/wgan_gp/scripts/run_train.sh b/research/cv/wgan_gp/scripts/run_train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..02d0203080d3501d8075701fe555768952df4179
--- /dev/null
+++ b/research/cv/wgan_gp/scripts/run_train.sh
@@ -0,0 +1,49 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash run_train.sh dataroot device_id"
+echo "For example: bash run_train.sh /home/cifar10/cifar-10-batches-bin/ 3"
+echo "It is better to use the absolute path."
+echo "=============================================================================================================="
+
+EXEC_PATH=$(pwd)
+echo "$EXEC_PATH"
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+cd ../
+rm -rf train
+mkdir train
+cd ./train
+mkdir src
+cd ../
+cp ./*.py ./train
+cp ./src/*.py ./train/src
+cd ./train
+
+env > env0.log
+
+echo "train begin."
+python train.py --dataroot $1 --device_id $2 > ./train.log 2>&1 &
+
+if [ $? -eq 0 ];then
+    echo "training started, check train/train.log for progress"
+else
+    echo "training failed to start"
+    exit 2
+fi
+echo "finish"
+cd ../
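+
+# To resume training, train.py also accepts --netG/--netD checkpoint paths saved by a
+# previous run, e.g. (paths and iteration number are hypothetical):
+#   python train.py --dataroot /abs/path/cifar-10-batches-bin --netG samples/debug_netG_giter_1000.ckpt --netD samples/debug_netD_giter_1000.ckpt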
diff --git a/research/cv/wgan_gp/src/args.py b/research/cv/wgan_gp/src/args.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a934994b765b632146faecf13842038b64641ae
--- /dev/null
+++ b/research/cv/wgan_gp/src/args.py
@@ -0,0 +1,45 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""get args"""
+import argparse
+
+def get_args():
+    """Define the common options that are used in training."""
+    parser = argparse.ArgumentParser(description='WGAN-GP')
+    parser.add_argument('--device_target', default='Ascend', help='device target, Ascend by default')
+    parser.add_argument('--device_id', type=int, default=0)
+
+    parser.add_argument('--dataroot', default=None, help='path to dataset')
+
+    parser.add_argument('--workers', type=int, help='number of data loading workers', default=8)
+    parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
+    parser.add_argument('--imageSize', type=int, default=32, help='the height/width of the input image to network')
+    parser.add_argument('--DIM', type=int, default=128, help='model feature dimension; also the noise vector length')
+    parser.add_argument('--niter', type=int, default=1200, help='number of epochs to train for')
+    parser.add_argument('--save_iterations', type=int, default=1000, help='num of gen iterations to save model')
+    parser.add_argument('--lrD', type=float, default=0.0001, help='learning rate for Critic, default=0.0001')
+    parser.add_argument('--lrG', type=float, default=0.0001, help='learning rate for Generator, default=0.0001')
+    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
+    parser.add_argument('--beta2', type=float, default=0.9, help='beta2 for adam. default=0.9')
+    parser.add_argument('--netG', default='', help="path to netG (to continue training)")
+    parser.add_argument('--netD', default='', help="path to netD (to continue training)")
+    parser.add_argument('--Diters', type=int, default=5, help='number of D iters per each G iter')
+    parser.add_argument('--experiment', default="samples", help='where to store samples and models')
+    parser.add_argument('--ckpt_file', default=None, help='path to pretrained ckpt model file')
+    parser.add_argument('--output_dir', default=None, help='output path of generated images')
+
+    args_opt = parser.parse_args()
+    return args_opt
diff --git a/research/cv/wgan_gp/src/cell.py b/research/cv/wgan_gp/src/cell.py
new file mode 100644
index 0000000000000000000000000000000000000000..09b39299afb2577dcd29b0b5fbc6bf40bce7cf6a
--- /dev/null
+++ b/research/cv/wgan_gp/src/cell.py
@@ -0,0 +1,67 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Loss cells wrapping the WGAN-GP generator and discriminator."""
+
+from mindspore import ops, nn
+import mindspore.numpy as mnp
+
+class GenWithLossCell(nn.Cell):
+    """Generator with loss (wrapped)."""
+
+    def __init__(self, netG, netD):
+        super(GenWithLossCell, self).__init__()
+        self.netG = netG
+        self.netD = netD
+
+    def construct(self, noise):
+        # the generator maximizes the critic score of fake samples
+        fake = self.netG(noise)
+        errG = self.netD(fake)
+        return -errG
+
+
+class DisWithLossCell(nn.Cell):
+    """Discriminator (critic) with loss (wrapped)."""
+
+    def __init__(self, netG, netD):
+        super(DisWithLossCell, self).__init__()
+        self.netG = netG
+        self.netD = netD
+        self.gradop = ops.GradOperation()
+        self.LAMBDA = 100  # gradient penalty coefficient (the WGAN-GP paper uses 10)
+        self.uniform = ops.UniformReal()
+
+    def compute_gradient_penalty(self, real_samples, fake_samples):
+        """Calculates the gradient penalty loss for WGAN-GP."""
+        # get random interpolation between real and fake samples
+        alpha = self.uniform((real_samples.shape[0], 1, 1, 1))
+        interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples))
+
+        # gradient of the critic output w.r.t. the interpolated inputs
+        grad_fn = self.gradop(self.netD)
+        gradients = grad_fn(interpolates)
+        gradients = gradients.view(gradients.shape[0], -1)
+        # penalize deviations of the per-sample gradient norm from 1
+        gradient_penalty = ops.reduce_mean(((mnp.norm(gradients, 2, axis=1) - 1) ** 2))
+        return gradient_penalty
+
+    def construct(self, real, noise):
+        errD_real = self.netD(real)
+        fake = self.netG(noise)
+        fake = ops.stop_gradient(fake)  # do not backprop into the generator
+        errD_fake = self.netD(fake)
+
+        gradient_penalty = self.compute_gradient_penalty(real, fake)
+
+        return errD_fake - errD_real + gradient_penalty * self.LAMBDA
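+
+# A minimal smoke test (assumes a configured MindSpore context; shapes follow train.py):
+#
+#   import numpy as np
+#   import mindspore as ms
+#   from mindspore import Tensor
+#   from src.model import DcganG, DcgannobnD
+#   loss_cell = DisWithLossCell(DcganG(128), DcgannobnD(128))
+#   real = Tensor(np.random.uniform(-1, 1, (64, 3, 32, 32)), ms.float32)
+#   noise = Tensor(np.random.normal(0, 1, (64, 128)), ms.float32)
+#   print(loss_cell(real, noise))  # scalar critic loss including the gradient penalty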
diff --git a/research/cv/wgan_gp/src/dataset.py b/research/cv/wgan_gp/src/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b555cd8d4946bed66fc44a59fb532ac333260e78
--- /dev/null
+++ b/research/cv/wgan_gp/src/dataset.py
@@ -0,0 +1,39 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""CIFAR-10 dataset creation and preprocessing."""
+
+import mindspore as ms
+import mindspore.dataset as ds
+import mindspore.dataset.vision as C
+import mindspore.dataset.transforms as C2
+
+def create_dataset(dataroot, batchSize, imageSize, repeat_num=1, workers=8, target='Ascend'):
+    """Create the CIFAR-10 training dataset: resize, normalize to [-1, 1], HWC -> CHW."""
+    # define map operations
+    resize_op = C.Resize(imageSize)
+    normalize_op = C.Normalize(mean=(0.5*255, 0.5*255, 0.5*255), std=(0.5*255, 0.5*255, 0.5*255))
+    hwc2chw_op = C.HWC2CHW()
+
+    data_set = ds.Cifar10Dataset(dataroot, num_parallel_workers=workers, shuffle=True)
+    transform = [resize_op, normalize_op, hwc2chw_op]
+
+    type_cast_op = C2.TypeCast(ms.int32)
+
+    data_set = data_set.map(input_columns='image', operations=transform, num_parallel_workers=workers)
+    data_set = data_set.map(input_columns='label', operations=type_cast_op, num_parallel_workers=workers)
+
+    data_set = data_set.batch(batchSize, drop_remainder=True)
+    data_set = data_set.repeat(repeat_num)
+
+    return data_set
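+
+# Example usage (dataroot is a placeholder):
+#   data_set = create_dataset('/data/cifar-10-batches-bin', batchSize=64, imageSize=32)
+#   for batch in data_set.create_dict_iterator():
+#       print(batch['image'].shape)  # (64, 3, 32, 32), values in [-1, 1]
+#       break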
diff --git a/research/cv/wgan_gp/src/model.py b/research/cv/wgan_gp/src/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c203ac34d99312b612b36154558ad3e61f476cf
--- /dev/null
+++ b/research/cv/wgan_gp/src/model.py
@@ -0,0 +1,100 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""WGAN-GP generator and discriminator definitions."""
+
+import mindspore.nn as nn
+import mindspore.ops as ops
+
+class DcgannobnD(nn.Cell):
+    """DCGAN discriminator (critic) without BatchNorm layers."""
+    def __init__(self, DIM):
+        super(DcgannobnD, self).__init__()
+
+        self.DIM = DIM
+        KERNEL_SIZE = 5
+        STRIDE = 2
+
+        main = nn.SequentialCell()
+        main.append(nn.Conv2d(3, self.DIM, KERNEL_SIZE, STRIDE, 'same'))
+        main.append(nn.LeakyReLU(0.2))
+
+        main.append(nn.Conv2d(self.DIM, self.DIM*2, KERNEL_SIZE, STRIDE, 'same'))
+        main.append(nn.LeakyReLU(0.2))
+
+        main.append(nn.Conv2d(self.DIM*2, self.DIM*4, KERNEL_SIZE, STRIDE, 'same'))
+        main.append(nn.LeakyReLU(0.2))
+        self.main = main
+        self.linear = nn.Dense(4*4*4*self.DIM, 1)
+
+    def construct(self, input1):
+        # three stride-2 convs: 32x32 -> 16x16 -> 8x8 -> 4x4, then a scalar critic score
+        output = self.main(input1)
+        output = output.view(-1, 4*4*4*self.DIM)
+        output = self.linear(output)
+        output = ops.reduce_mean(output)
+        return output
+
+class DcganG(nn.Cell):
+    """DCGAN generator."""
+
+    def __init__(self, DIM):
+        super(DcganG, self).__init__()
+
+        self.DIM = DIM
+        KERNEL_SIZE = 5
+        STRIDE = 2
+
+        self.linear = nn.Dense(self.DIM, 4*4*4*self.DIM)
+        self.bn = nn.BatchNorm2d(4*4*4*self.DIM)
+        self.relu = nn.ReLU()
+
+        main = nn.SequentialCell()
+        main.append(nn.Conv2dTranspose(
+            self.DIM*4,
+            self.DIM*2,
+            KERNEL_SIZE,
+            stride=STRIDE,
+            weight_init='normal',
+            pad_mode='same'))
+        main.append(nn.BatchNorm2d(self.DIM*2))
+        main.append(nn.ReLU())
+
+        main.append(nn.Conv2dTranspose(
+            self.DIM*2,
+            self.DIM,
+            KERNEL_SIZE,
+            stride=STRIDE,
+            weight_init='normal',
+            pad_mode='same'))
+        main.append(nn.BatchNorm2d(self.DIM))
+        main.append(nn.ReLU())
+
+        main.append(nn.Conv2dTranspose(
+            self.DIM,
+            3,
+            KERNEL_SIZE,
+            stride=STRIDE,
+            weight_init='normal',
+            pad_mode='same'))
+        main.append(nn.Tanh())
+        self.main = main
+
+    def construct(self, input1):
+        # use -1 instead of a hardcoded batch size of 64 so any batch size works
+        output = self.linear(input1)
+        output = output.view(-1, 4*4*4*self.DIM, 1, 1)
+        output = self.bn(output)
+        output = self.relu(output)
+        output = output.view(-1, 4*self.DIM, 4, 4)
+        output = self.main(output)
+        return output
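+
+# Shape walk-through for DIM=128: noise (N, 128) -> Dense -> (N, 8192) -> BatchNorm
+# -> reshape (N, 512, 4, 4) -> three stride-2 deconvs upsample 4x4 -> 8x8 -> 16x16 -> 32x32
+# -> Tanh output (N, 3, 32, 32) with values in [-1, 1].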
diff --git a/research/cv/wgan_gp/train.py b/research/cv/wgan_gp/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7eb97c465dfe4a0a891f943b7665ea8c717a296
--- /dev/null
+++ b/research/cv/wgan_gp/train.py
@@ -0,0 +1,191 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Train WGAN-GP on CIFAR-10."""
+
+import os
+import time
+import mindspore as ms
+from mindspore import Tensor
+import mindspore.nn as nn
+import mindspore.ops as ops
+from mindspore.common import initializer as init
+import mindspore.common.dtype as mstype
+from mindspore import context
+from mindspore.train.serialization import load_checkpoint, load_param_into_net, save_checkpoint
+from PIL import Image
+import numpy as np
+
+from src.dataset import create_dataset
+from src.model import DcganG, DcgannobnD
+from src.cell import GenWithLossCell, DisWithLossCell
+from src.args import get_args
+
+if __name__ == '__main__':
+
+    t_begin = time.time()
+    args_opt = get_args()
+
+    if args_opt.experiment is None:
+        args_opt.experiment = 'samples'
+    os.system('rm -rf {0}'.format(args_opt.experiment))
+    os.system('mkdir {0}'.format(args_opt.experiment))
+
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
+                        device_id=int(args_opt.device_id))
+    ms.set_seed(0)
+    dataset = create_dataset(args_opt.dataroot, args_opt.batchSize, args_opt.imageSize, 1,
+                             args_opt.workers, args_opt.device_target)
+
+    def init_weight(net):
+        """Initialize conv/deconv/dense weights with N(0, 0.02) and BatchNorm gamma with N(1, 0.02)."""
+        for _, cell in net.cells_and_names():
+            if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
+                cell.weight.set_data(init.initializer(init.Normal(0.02), cell.weight.shape))
+            elif isinstance(cell, nn.BatchNorm2d):
+                cell.gamma.set_data(init.initializer(Tensor(np.random.normal(1, 0.02, cell.gamma.shape), \
+                mstype.float32), cell.gamma.shape))
+                cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
+            elif isinstance(cell, nn.Dense):
+                cell.weight.set_data(init.initializer(init.Normal(0.02), cell.weight.shape))
+
+    def save_image(img, img_path, IMAGE_SIZE):
+        """Tile a batch of images into an 8x8 grid and save it."""
+        mul = ops.Mul()
+        add = ops.Add()
+        if isinstance(img, Tensor):
+            # map values from [-1, 1] back to [0, 255]
+            img = mul(img, 255 * 0.5)
+            img = add(img, 255 * 0.5)
+            img = img.asnumpy().astype(np.uint8).transpose((0, 2, 3, 1))
+        elif not isinstance(img, np.ndarray):
+            raise ValueError("img should be Tensor or numpy array, but get {}".format(type(img)))
+
+        IMAGE_ROW = 8  # row num
+        IMAGE_COLUMN = 8  # column num
+        PADDING = 2  # interval between small pictures
+        to_image = Image.new('RGB', (IMAGE_COLUMN * IMAGE_SIZE + PADDING * (IMAGE_COLUMN + 1),
+                                     IMAGE_ROW * IMAGE_SIZE + PADDING * (IMAGE_ROW + 1)))  # create a new picture
+        # paste each small picture into the grid
+        ii = 0
+        for y in range(1, IMAGE_ROW + 1):
+            for x in range(1, IMAGE_COLUMN + 1):
+                from_image = Image.fromarray(img[ii])
+                to_image.paste(from_image, ((x - 1) * IMAGE_SIZE + PADDING * x, (y - 1) * IMAGE_SIZE + PADDING * y))
+                ii = ii + 1
+
+        to_image.save(img_path)  # save
+
+    # define net ---------------------------------------------------------------------------------------------
+    # Generator
+    netG = DcganG(args_opt.DIM)
+
+    init_weight(netG)
+
+    if args_opt.netG != '':  # load checkpoint if needed
+        load_param_into_net(netG, load_checkpoint(args_opt.netG))
+
+    netD = DcgannobnD(args_opt.DIM)
+    init_weight(netD)
+
+    if args_opt.netD != '':
+        load_param_into_net(netD, load_checkpoint(args_opt.netD))
+
+    # the noise dimension must match the generator input dimension DIM
+    fixed_noise = Tensor(np.random.normal(0, 1, size=[args_opt.batchSize, args_opt.DIM]), dtype=ms.float32)
+
+    # setup optimizers
+    optimizerD = nn.Adam(
+        netD.trainable_params(),
+        learning_rate=args_opt.lrD,
+        beta1=args_opt.beta1,
+        beta2=args_opt.beta2)
+    optimizerG = nn.Adam(
+        netG.trainable_params(),
+        learning_rate=args_opt.lrG,
+        beta1=args_opt.beta1,
+        beta2=args_opt.beta2)
+
+    netG_train = nn.TrainOneStepCell(GenWithLossCell(netG, netD), optimizerG)
+    netD_train = nn.TrainOneStepCell(DisWithLossCell(netG, netD), optimizerD)
+
+    netG_train.set_train()
+    netD_train.set_train()
+
+    gen_iterations = 0
+
+    t0 = time.time()
+    # Train
+    for epoch in range(args_opt.niter):
+        data_iter = dataset.create_dict_iterator()
+        length = dataset.get_dataset_size()
+        i = 0
+        while i < length:
+            ############################
+            # (1) Update D network
+            ############################
+            # reset requires_grad over all parameters (trainable_params() would miss
+            # the ones set to False below in the netG update)
+            for p in netD.get_parameters():
+                p.requires_grad = True
+
+            # train the discriminator Diters times
+            Diters = args_opt.Diters
+
+            j = 0
+            while j < Diters and i < length:
+                j += 1
+
+                data = next(data_iter)
+                i += 1
+
+                # train with real and fake; the noise length must equal DIM
+                real = data['image']
+                noise = Tensor(np.random.normal(0, 1, size=[args_opt.batchSize, args_opt.DIM]), dtype=ms.float32)
+                loss_D = netD_train(real, noise)
+
+                print('epoch %d loss_D: %.4f ' % (epoch, float(loss_D)))
+
+            ############################
+            # (2) Update G network
+            ############################
+            for p in netD.get_parameters():
+                p.requires_grad = False  # to avoid computation
+
+            noise = Tensor(np.random.normal(0, 1, size=[args_opt.batchSize, args_opt.DIM]), dtype=ms.float32)
+
+            loss_G = netG_train(noise)
+            gen_iterations += 1
+
+            t1 = time.time()
+            print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f'
+                  % (epoch, args_opt.niter, i, length, gen_iterations,
+                     loss_D.asnumpy(), loss_G.asnumpy()))
+            print('step_cost: %.4f seconds' % (float(t1 - t0)))
+            t0 = t1
+
+            if gen_iterations % args_opt.save_iterations == 0:
+                fake = netG(fixed_noise)
+                save_image(
+                    real,
+                    '{0}/real_samples.png'.format(args_opt.experiment),
+                    args_opt.imageSize)
+                save_image(
+                    fake,
+                    '{0}/fake_samples_{1}.png'.format(args_opt.experiment, gen_iterations),
+                    args_opt.imageSize)
+
+                save_checkpoint(netD, '{0}/debug_netD_giter_{1}.ckpt'.format(args_opt.experiment, gen_iterations))
+                save_checkpoint(netG, '{0}/debug_netG_giter_{1}.ckpt'.format(args_opt.experiment, gen_iterations))
+
+    t_end = time.time()
+    print('total_cost: %.4f seconds' % (float(t_end - t_begin)))
+    print("Train success!")
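+
+# Direct launch example (dataroot is a placeholder; run_train.sh wraps this call):
+#   python train.py --dataroot /data/cifar-10-batches-bin --device_id 0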