diff --git a/research/cv/gan/README_CN.md b/research/cv/gan/README_CN.md
index 0ca8b2fa127a9e1320358a57c61babe8fd51fe7c..2f753f80d71d394fc053a806ca79522a72c40094 100644
--- a/research/cv/gan/README_CN.md
+++ b/research/cv/gan/README_CN.md
@@ -167,6 +167,7 @@ bash ./scripts/run_eval.sh [DEVICE_ID]
 python train.py > train.log 2>&1 &
 ```
 
+- Before training, set data_path in src/param_parse.py to the path of the training dataset.
 The above python command will run in the background; you can view the results in the train.log file.
 
 After training, you can find the checkpoint files under the default scripts folder. The loss values are obtained as follows:
@@ -186,6 +187,8 @@ bash ./scripts/run_eval.sh [DEVICE_ID]
 ```
 
+- Before inference, set ckpt_path in src/param_parse.py to the path of the actual checkpoint used for inference.
+
 # Model Description
 
 ## Performance
diff --git a/research/cv/gan/eval.py b/research/cv/gan/eval.py
index 57ee1e9837d6caed4908f9d87631a5e4e64854d2..49372c2c83724343dd35ef4d904aa26a16d3b226 100644
--- a/research/cv/gan/eval.py
+++ b/research/cv/gan/eval.py
@@ -152,15 +152,11 @@ def parzen(samples):
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-
-
 test_latent_code_parzen = Tensor(np.random.normal(size=(10000, opt.latent_dim)), dtype=mstype.float32)
 
 if __name__ == '__main__':
     generator = Generator(opt.latent_dim)
-
-    ckpt_file_name = 'checkpoints/' + str(opt.n_epochs-1) + '.ckpt'
-    param_dict = load_checkpoint(ckpt_file_name)
+    param_dict = load_checkpoint(opt.ckpt_path)
     load_param_into_net(generator, param_dict)
     imag = generator(test_latent_code_parzen)
     imag = imag * 127.5 + 127.5
diff --git a/research/cv/gan/scripts/run_distributed_train.sh b/research/cv/gan/scripts/run_distributed_train.sh
index 0afbecfed73f1341cceab53dd5f776233f76301d..d4a37ffd57f00c24e5bd1ce226602b96ff64f49b 100644
--- a/research/cv/gan/scripts/run_distributed_train.sh
+++ b/research/cv/gan/scripts/run_distributed_train.sh
@@ -37,6 +37,6 @@ do
     echo "Start training for rank $RANK_ID, device $DEVICE_ID"
     cd ./device$i
     env > env.log
-    nohup python train.py --device_id=$DEVICE_ID --distribute=True --data_path="../data/MNIST_Data/" > distributed_train.log 2>&1 &
+    nohup python train.py --device_id=$DEVICE_ID --distribute=True > distributed_train.log 2>&1 &
     cd ..
 done
diff --git a/research/cv/gan/scripts/run_eval.sh b/research/cv/gan/scripts/run_eval.sh
index 875ae929dbf4eda68c8d1fb547a046717bb97768..2d48bb25e6908f16f394a57aa3605714759965cc 100644
--- a/research/cv/gan/scripts/run_eval.sh
+++ b/research/cv/gan/scripts/run_eval.sh
@@ -23,4 +23,5 @@ if [ ! -d "logs" ]; then
     mkdir logs
 fi
 
+export DEVICE_ID=$1
 nohup python -u eval.py > logs/eval.log 2>&1 &
diff --git a/research/cv/gan/scripts/run_standalone_train.sh b/research/cv/gan/scripts/run_standalone_train.sh
index 19df0999719b1d42a716a8f6c51f8ac76796283b..8702426575f114100002b5c59659c3a07fc225fb 100644
--- a/research/cv/gan/scripts/run_standalone_train.sh
+++ b/research/cv/gan/scripts/run_standalone_train.sh
@@ -19,6 +19,12 @@ if [[ $# -gt 1 ]]; then
     exit 1
 fi
 
+ulimit -u unlimited
+export DEVICE_NUM=1
+export DEVICE_ID=$1
+export RANK_ID=0
+export RANK_SIZE=1
+
 if [ ! -d "logs" ]; then
     mkdir logs
 fi
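Note on the eval.py hunk above: loading from opt.ckpt_path decouples evaluation from the n_epochs setting, so any checkpoint can be evaluated directly. A minimal sketch of the resulting eval-time flow follows; the import paths and the example --ckpt_path value are assumptions for illustration, not part of the patch.

```python
# Sketch of eval-time loading after this patch; assumes Generator lives in
# src/gan.py and the argument parser in src/param_parse.py, as in this repo.
import numpy as np
from mindspore import Tensor, load_checkpoint, load_param_into_net
import mindspore.common.dtype as mstype

from src.gan import Generator
from src.param_parse import parameter_parser

opt = parameter_parser()                     # run with e.g. --ckpt_path=checkpoints/199.ckpt
generator = Generator(opt.latent_dim)
param_dict = load_checkpoint(opt.ckpt_path)  # fails fast if --ckpt_path is left at the empty default
load_param_into_net(generator, param_dict)

latent = Tensor(np.random.normal(size=(16, opt.latent_dim)), dtype=mstype.float32)
images = generator(latent) * 127.5 + 127.5   # map generator output from [-1, 1] to [0, 255]
```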
-d "logs" ]; then mkdir logs fi diff --git a/research/cv/gan/src/dataset.py b/research/cv/gan/src/dataset.py index f1b13a0fa1f128b4d5c47535ae7cc68ed4b21b08..5872e2af5ca5c2e52786eb5b215dda8377b76d1a 100644 --- a/research/cv/gan/src/dataset.py +++ b/research/cv/gan/src/dataset.py @@ -127,9 +127,7 @@ class DatasetGenerator_valid: def create_dataset_train(batch_size=5, repeat_size=1, latent_size=100): """create dataset train""" dataset_generator = DatasetGenerator() - - dataset1 = ds.GeneratorDataset(dataset_generator, ["image", "label"], shuffle=False) - + dataset1 = ds.GeneratorDataset(dataset_generator, ["image", "label"], shuffle=True) mnist_ds = dataset1.map( operations=lambda x: ( x.astype("float32"), @@ -145,10 +143,8 @@ def create_dataset_train(batch_size=5, repeat_size=1, latent_size=100): def create_dataset_train_dis(batch_size=5, repeat_size=1, latent_size=100): """create dataset train""" dataset_generator = DatasetGenerator() - dataset1 = ds.GeneratorDataset(dataset_generator, ["image", "label"], - shuffle=False, num_shards=get_group_size(), shard_id=get_rank()) - + shuffle=True, num_shards=get_group_size(), shard_id=get_rank()) mnist_ds = dataset1.map( operations=lambda x: ( x.astype("float32"), @@ -165,9 +161,7 @@ def create_dataset_train_dis(batch_size=5, repeat_size=1, latent_size=100): def create_dataset_valid(batch_size=5, repeat_size=1, latent_size=100): """create dataset valid""" dataset_generator = DatasetGenerator_valid() - dataset = ds.GeneratorDataset(dataset_generator, ["image", "label"], shuffle=False) - mnist_ds = dataset.map( operations=lambda x: ( x[-10000:].astype("float32"), diff --git a/research/cv/gan/src/gan.py b/research/cv/gan/src/gan.py index d7af0a54ecaac99b5bbc114816483c90f245d288..10e5acd819a6daba9d26893ff3ec4cb83fa67386 100644 --- a/research/cv/gan/src/gan.py +++ b/research/cv/gan/src/gan.py @@ -15,7 +15,10 @@ '''train the gan model''' from src.loss import GenWithLossCell from src.loss import DisWithLossCell +import numpy as np from mindspore import nn +from mindspore import Tensor, Parameter +from mindspore.common import initializer import mindspore.ops.operations as P import mindspore.ops.functional as F import mindspore.ops.composite as C @@ -29,6 +32,37 @@ class Reshape(nn.Cell): def construct(self, x): return self.reshape(x, self.shape) +class InstanceNorm2d(nn.Cell): + """InstanceNorm2d""" + + def __init__(self, channel): + super(InstanceNorm2d, self).__init__() + self.gamma = Parameter(initializer.initializer( + init=Tensor(np.ones(shape=[1, channel, 1, 1], dtype=np.float32)), shape=[1, channel, 1, 1]), + name='gamma') + self.beta = Parameter(initializer.initializer(init=initializer.Zero(), shape=[1, channel, 1, 1]), + name='beta') + self.reduceMean = P.ReduceMean(keep_dims=True) + self.square = P.Square() + self.sub = P.Sub() + self.add = P.Add() + self.rsqrt = P.Rsqrt() + self.mul = P.Mul() + self.tile = P.Tile() + self.reshape = P.Reshape() + self.eps = Tensor(np.ones(shape=[1, channel, 1, 1], dtype=np.float32) * 1e-5) + self.cast2fp32 = P.Cast() + + def construct(self, x): + mean = self.reduceMean(x, (2, 3)) + mean_stop_grad = F.stop_gradient(mean) + variance = self.reduceMean(self.square(self.sub(x, mean_stop_grad)), (2, 3)) + variance = variance + self.eps + inv = self.rsqrt(variance) + normalized = self.sub(x, mean) * inv + x_IN = self.add(self.mul(self.gamma, normalized), self.beta) + return x_IN + class Generator(nn.Cell): """generator""" @@ -38,15 +72,15 @@ class Generator(nn.Cell): self.network.append(nn.Dense(latent_size, 256 * 7 * 7, 
diff --git a/research/cv/gan/src/gan.py b/research/cv/gan/src/gan.py
index d7af0a54ecaac99b5bbc114816483c90f245d288..10e5acd819a6daba9d26893ff3ec4cb83fa67386 100644
--- a/research/cv/gan/src/gan.py
+++ b/research/cv/gan/src/gan.py
@@ -15,7 +15,10 @@
 '''train the gan model'''
 from src.loss import GenWithLossCell
 from src.loss import DisWithLossCell
+import numpy as np
 from mindspore import nn
+from mindspore import Tensor, Parameter
+from mindspore.common import initializer
 import mindspore.ops.operations as P
 import mindspore.ops.functional as F
 import mindspore.ops.composite as C
@@ -29,6 +32,37 @@ class Reshape(nn.Cell):
     def construct(self, x):
         return self.reshape(x, self.shape)
 
+class InstanceNorm2d(nn.Cell):
+    """InstanceNorm2d"""
+
+    def __init__(self, channel):
+        super(InstanceNorm2d, self).__init__()
+        self.gamma = Parameter(initializer.initializer(
+            init=Tensor(np.ones(shape=[1, channel, 1, 1], dtype=np.float32)), shape=[1, channel, 1, 1]),
+            name='gamma')
+        self.beta = Parameter(initializer.initializer(init=initializer.Zero(), shape=[1, channel, 1, 1]),
+                              name='beta')
+        self.reduceMean = P.ReduceMean(keep_dims=True)
+        self.square = P.Square()
+        self.sub = P.Sub()
+        self.add = P.Add()
+        self.rsqrt = P.Rsqrt()
+        self.mul = P.Mul()
+        self.tile = P.Tile()
+        self.reshape = P.Reshape()
+        self.eps = Tensor(np.ones(shape=[1, channel, 1, 1], dtype=np.float32) * 1e-5)
+        self.cast2fp32 = P.Cast()
+
+    def construct(self, x):
+        mean = self.reduceMean(x, (2, 3))
+        mean_stop_grad = F.stop_gradient(mean)
+        variance = self.reduceMean(self.square(self.sub(x, mean_stop_grad)), (2, 3))
+        variance = variance + self.eps
+        inv = self.rsqrt(variance)
+        normalized = self.sub(x, mean) * inv
+        x_IN = self.add(self.mul(self.gamma, normalized), self.beta)
+        return x_IN
+
 class Generator(nn.Cell):
     """generator"""
@@ -38,15 +72,15 @@ class Generator(nn.Cell):
         self.network = nn.SequentialCell()
         self.network.append(nn.Dense(latent_size, 256 * 7 * 7,
                                      has_bias=False))
         self.network.append(Reshape((-1, 256, 7, 7)))
-        self.network.append(nn.BatchNorm2d(256))
+        self.network.append(InstanceNorm2d(256))
         self.network.append(nn.ReLU())
 
         self.network.append(nn.Conv2dTranspose(256, 128, 5, 1))
-        self.network.append(nn.BatchNorm2d(128))
+        self.network.append(InstanceNorm2d(128))
         self.network.append(nn.ReLU())
 
         self.network.append(nn.Conv2dTranspose(128, 64, 5, 2))
-        self.network.append(nn.BatchNorm2d(64))
+        self.network.append(InstanceNorm2d(64))
         self.network.append(nn.ReLU())
 
         self.network.append(nn.Conv2dTranspose(64, 1, 5, 2))
@@ -64,11 +98,11 @@ class Discriminator(nn.Cell):
         self.network = nn.SequentialCell()
         self.network.append(nn.Conv2d(1, 64, 5, 2))
-        self.network.append(nn.BatchNorm2d(64))
+        self.network.append(InstanceNorm2d(64))
         self.network.append(nn.LeakyReLU())
 
         self.network.append(nn.Conv2d(64, 128, 5, 2))
-        self.network.append(nn.BatchNorm2d(128))
+        self.network.append(InstanceNorm2d(128))
         self.network.append(nn.LeakyReLU())
 
         self.network.append(nn.Flatten())
diff --git a/research/cv/gan/src/param_parse.py b/research/cv/gan/src/param_parse.py
index e2a35d584959ae99843bafb3656bd6fd7fb6b40f..958e9887e148bf807d8d0649f51607165c5cfcca 100644
--- a/research/cv/gan/src/param_parse.py
+++ b/research/cv/gan/src/param_parse.py
@@ -41,6 +41,7 @@ def parameter_parser():
     parser.add_argument("--batch_size_t", type=int, default=10, help="size of the test batches")
     parser.add_argument("--batch_size_v", type=int, default=1000, help="size of the valid batches")
     parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend (Default: 0)')
-    parser.add_argument("--data_path", type=str, default="data/MNIST_Data/", help="dataset path")
+    parser.add_argument("--data_path", type=str, default="mnist/", help="dataset path")  # change to train data path
+    parser.add_argument("--ckpt_path", type=str, default="", help="eval ckpt path")  # change to eval ckpt path
     parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
     return parser.parse_args()
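The InstanceNorm2d cell added in gan.py normalizes each (sample, channel) plane over its spatial axes and then applies a learned affine, which is why it can drop in for nn.BatchNorm2d in both networks. As a sanity check on that arithmetic, the same computation in plain NumPy (an illustrative sketch, not code from the patch):

```python
# NumPy mirror of InstanceNorm2d.construct: normalize each (sample, channel)
# plane over the spatial axes, then apply the learned gamma/beta affine.
import numpy as np

def instance_norm_2d(x, gamma, beta, eps=1e-5):
    mean = x.mean(axis=(2, 3), keepdims=True)                 # reduceMean over (2, 3)
    var = ((x - mean) ** 2).mean(axis=(2, 3), keepdims=True)  # biased variance, as in the cell
    return gamma * (x - mean) / np.sqrt(var + eps) + beta     # rsqrt(var + eps) in the cell

x = np.random.randn(2, 64, 7, 7).astype(np.float32)
gamma = np.ones((1, 64, 1, 1), dtype=np.float32)   # same shapes as the cell's Parameters
beta = np.zeros((1, 64, 1, 1), dtype=np.float32)

out = instance_norm_2d(x, gamma, beta)
print(np.allclose(out.mean(axis=(2, 3)), 0, atol=1e-5))  # True: zero mean per plane
print(np.allclose(out.std(axis=(2, 3)), 1, atol=1e-2))   # True: unit variance per plane
```

Unlike batch normalization, the statistics depend only on the current sample, so the layer behaves the same at training and inference time and needs no cross-device synchronization in distributed runs.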