From 10bfd445c3968b10a100a578516c7a99b10c3adb Mon Sep 17 00:00:00 2001 From: zhou_lili <zhoulili20@huawei.com> Date: Thu, 23 Dec 2021 12:06:21 +0800 Subject: [PATCH] Improve accuracy of gan --- research/cv/gan/README_CN.md | 3 ++ research/cv/gan/eval.py | 6 +-- .../cv/gan/scripts/run_distributed_train.sh | 2 +- research/cv/gan/scripts/run_eval.sh | 1 + .../cv/gan/scripts/run_standalone_train.sh | 6 +++ research/cv/gan/src/dataset.py | 10 +---- research/cv/gan/src/gan.py | 44 ++++++++++++++++--- research/cv/gan/src/param_parse.py | 3 +- 8 files changed, 55 insertions(+), 20 deletions(-) diff --git a/research/cv/gan/README_CN.md b/research/cv/gan/README_CN.md index 0ca8b2fa1..2f753f80d 100644 --- a/research/cv/gan/README_CN.md +++ b/research/cv/gan/README_CN.md @@ -167,6 +167,7 @@ bash ./scripts/run_eval.sh [DEVICE_ID] python train.py > train.log 2>&1 & ``` +- 在训练之前，需要在src/param_parse.py下修改data_path为训练集路径 上述python命令将在后台运行，您可以通过train.log文件查看结果。 训练结束后，您可在默认脚本文件夹下找到检查点文件。采用以下方式达到损失值： @@ -186,6 +187,8 @@ bash ./scripts/run_eval.sh [DEVICE_ID] ``` +- 在推理之前，需要在src/param_parse.py下修改ckpt_path为真实推理ckpt的路径 + # 模型描述 ## 性能 diff --git a/research/cv/gan/eval.py b/research/cv/gan/eval.py index 57ee1e983..49372c2c8 100644 --- a/research/cv/gan/eval.py +++ b/research/cv/gan/eval.py @@ -152,15 +152,11 @@ def parzen(samples): context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - - test_latent_code_parzen = Tensor(np.random.normal(size=(10000, opt.latent_dim)), dtype=mstype.float32) if __name__ == '__main__': generator = Generator(opt.latent_dim) - - ckpt_file_name = 'checkpoints/' + str(opt.n_epochs-1) + '.ckpt' - param_dict = load_checkpoint(ckpt_file_name) + param_dict = load_checkpoint(opt.ckpt_path) load_param_into_net(generator, param_dict) imag = generator(test_latent_code_parzen) imag = imag * 127.5 + 127.5 diff --git a/research/cv/gan/scripts/run_distributed_train.sh 
b/research/cv/gan/scripts/run_distributed_train.sh index 0afbecfed..d4a37ffd5 100644 --- a/research/cv/gan/scripts/run_distributed_train.sh +++ b/research/cv/gan/scripts/run_distributed_train.sh @@ -37,6 +37,6 @@ do echo "Start training for rank $RANK_ID, device $DEVICE_ID" cd ./device$i env > env.log - nohup python train.py --device_id=$DEVICE_ID --distribute=True --data_path="../data/MNIST_Data/" > distributed_train.log 2>&1 & + nohup python train.py --device_id=$DEVICE_ID --distribute=True > distributed_train.log 2>&1 & cd .. done diff --git a/research/cv/gan/scripts/run_eval.sh b/research/cv/gan/scripts/run_eval.sh index 875ae929d..2d48bb25e 100644 --- a/research/cv/gan/scripts/run_eval.sh +++ b/research/cv/gan/scripts/run_eval.sh @@ -23,4 +23,5 @@ if [ ! -d "logs" ]; then mkdir logs fi +export DEVICE_ID=$1 nohup python -u eval.py > logs/eval.log 2>&1 & diff --git a/research/cv/gan/scripts/run_standalone_train.sh b/research/cv/gan/scripts/run_standalone_train.sh index 19df09997..870242657 100644 --- a/research/cv/gan/scripts/run_standalone_train.sh +++ b/research/cv/gan/scripts/run_standalone_train.sh @@ -19,6 +19,12 @@ if [[ $# -gt 1 ]]; then exit 1 fi +ulimit -u unlimited +export DEVICE_NUM=1 +export DEVICE_ID=$1 +export RANK_ID=0 +export RANK_SIZE=1 + if [ ! 
-d "logs" ]; then mkdir logs fi diff --git a/research/cv/gan/src/dataset.py b/research/cv/gan/src/dataset.py index f1b13a0fa..5872e2af5 100644 --- a/research/cv/gan/src/dataset.py +++ b/research/cv/gan/src/dataset.py @@ -127,9 +127,7 @@ class DatasetGenerator_valid: def create_dataset_train(batch_size=5, repeat_size=1, latent_size=100): """create dataset train""" dataset_generator = DatasetGenerator() - - dataset1 = ds.GeneratorDataset(dataset_generator, ["image", "label"], shuffle=False) - + dataset1 = ds.GeneratorDataset(dataset_generator, ["image", "label"], shuffle=True) mnist_ds = dataset1.map( operations=lambda x: ( x.astype("float32"), @@ -145,10 +143,8 @@ def create_dataset_train(batch_size=5, repeat_size=1, latent_size=100): def create_dataset_train_dis(batch_size=5, repeat_size=1, latent_size=100): """create dataset train""" dataset_generator = DatasetGenerator() - dataset1 = ds.GeneratorDataset(dataset_generator, ["image", "label"], - shuffle=False, num_shards=get_group_size(), shard_id=get_rank()) - + shuffle=True, num_shards=get_group_size(), shard_id=get_rank()) mnist_ds = dataset1.map( operations=lambda x: ( x.astype("float32"), @@ -165,9 +161,7 @@ def create_dataset_train_dis(batch_size=5, repeat_size=1, latent_size=100): def create_dataset_valid(batch_size=5, repeat_size=1, latent_size=100): """create dataset valid""" dataset_generator = DatasetGenerator_valid() - dataset = ds.GeneratorDataset(dataset_generator, ["image", "label"], shuffle=False) - mnist_ds = dataset.map( operations=lambda x: ( x[-10000:].astype("float32"), diff --git a/research/cv/gan/src/gan.py b/research/cv/gan/src/gan.py index d7af0a54e..10e5acd81 100644 --- a/research/cv/gan/src/gan.py +++ b/research/cv/gan/src/gan.py @@ -15,7 +15,10 @@ '''train the gan model''' from src.loss import GenWithLossCell from src.loss import DisWithLossCell +import numpy as np from mindspore import nn +from mindspore import Tensor, Parameter +from mindspore.common import initializer import 
mindspore.ops.operations as P import mindspore.ops.functional as F import mindspore.ops.composite as C @@ -29,6 +32,37 @@ class Reshape(nn.Cell): def construct(self, x): return self.reshape(x, self.shape) +class InstanceNorm2d(nn.Cell): + """InstanceNorm2d""" + + def __init__(self, channel): + super(InstanceNorm2d, self).__init__() + self.gamma = Parameter(initializer.initializer( + init=Tensor(np.ones(shape=[1, channel, 1, 1], dtype=np.float32)), shape=[1, channel, 1, 1]), + name='gamma') + self.beta = Parameter(initializer.initializer(init=initializer.Zero(), shape=[1, channel, 1, 1]), + name='beta') + self.reduceMean = P.ReduceMean(keep_dims=True) + self.square = P.Square() + self.sub = P.Sub() + self.add = P.Add() + self.rsqrt = P.Rsqrt() + self.mul = P.Mul() + self.tile = P.Tile() + self.reshape = P.Reshape() + self.eps = Tensor(np.ones(shape=[1, channel, 1, 1], dtype=np.float32) * 1e-5) + self.cast2fp32 = P.Cast() + + def construct(self, x): + mean = self.reduceMean(x, (2, 3)) + mean_stop_grad = F.stop_gradient(mean) + variance = self.reduceMean(self.square(self.sub(x, mean_stop_grad)), (2, 3)) + variance = variance + self.eps + inv = self.rsqrt(variance) + normalized = self.sub(x, mean) * inv + x_IN = self.add(self.mul(self.gamma, normalized), self.beta) + return x_IN + class Generator(nn.Cell): """generator""" @@ -38,15 +72,15 @@ class Generator(nn.Cell): self.network.append(nn.Dense(latent_size, 256 * 7 * 7, has_bias=False)) self.network.append(Reshape((-1, 256, 7, 7))) - self.network.append(nn.BatchNorm2d(256)) + self.network.append(InstanceNorm2d(256)) self.network.append(nn.ReLU()) self.network.append(nn.Conv2dTranspose(256, 128, 5, 1)) - self.network.append(nn.BatchNorm2d(128)) + self.network.append(InstanceNorm2d(128)) self.network.append(nn.ReLU()) self.network.append(nn.Conv2dTranspose(128, 64, 5, 2)) - self.network.append(nn.BatchNorm2d(64)) + self.network.append(InstanceNorm2d(64)) self.network.append(nn.ReLU()) 
self.network.append(nn.Conv2dTranspose(64, 1, 5, 2)) @@ -64,11 +98,11 @@ class Discriminator(nn.Cell): self.network = nn.SequentialCell() self.network.append(nn.Conv2d(1, 64, 5, 2)) - self.network.append(nn.BatchNorm2d(64)) + self.network.append(InstanceNorm2d(64)) self.network.append(nn.LeakyReLU()) self.network.append(nn.Conv2d(64, 128, 5, 2)) - self.network.append(nn.BatchNorm2d(128)) + self.network.append(InstanceNorm2d(128)) self.network.append(nn.LeakyReLU()) self.network.append(nn.Flatten()) diff --git a/research/cv/gan/src/param_parse.py b/research/cv/gan/src/param_parse.py index e2a35d584..958e9887e 100644 --- a/research/cv/gan/src/param_parse.py +++ b/research/cv/gan/src/param_parse.py @@ -41,6 +41,7 @@ def parameter_parser(): parser.add_argument("--batch_size_t", type=int, default=10, help="size of the test batches") parser.add_argument("--batch_size_v", type=int, default=1000, help="size of the valid batches") parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend (Default: 0)') - parser.add_argument("--data_path", type=str, default="data/MNIST_Data/", help="dataset path") + parser.add_argument("--data_path", type=str, default="mnist/", help="dataset path") # change to train data path + parser.add_argument("--ckpt_path", type=str, default="", help="eval ckpt path") # change to eval ckpt path parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.") return parser.parse_args() -- GitLab