From 10bfd445c3968b10a100a578516c7a99b10c3adb Mon Sep 17 00:00:00 2001
From: zhou_lili <zhoulili20@huawei.com>
Date: Thu, 23 Dec 2021 12:06:21 +0800
Subject: [PATCH] Improve accuracy of gan

---
 research/cv/gan/README_CN.md                  |  3 ++
 research/cv/gan/eval.py                       |  6 +--
 .../cv/gan/scripts/run_distributed_train.sh   |  2 +-
 research/cv/gan/scripts/run_eval.sh           |  1 +
 .../cv/gan/scripts/run_standalone_train.sh    |  6 +++
 research/cv/gan/src/dataset.py                | 10 +----
 research/cv/gan/src/gan.py                    | 44 ++++++++++++++++---
 research/cv/gan/src/param_parse.py            |  3 +-
 8 files changed, 55 insertions(+), 20 deletions(-)

diff --git a/research/cv/gan/README_CN.md b/research/cv/gan/README_CN.md
index 0ca8b2fa1..2f753f80d 100644
--- a/research/cv/gan/README_CN.md
+++ b/research/cv/gan/README_CN.md
@@ -167,6 +167,7 @@ bash ./scripts/run_eval.sh [DEVICE_ID]
   python train.py > train.log 2>&1 &
   ```
 
+- 鍦ㄨ缁冧箣鍓嶏紝闇€瑕佸湪src/param_parser.py涓嬩慨鏀筪ata_path涓鸿缁冮泦璺緞
  上述python命令将在后台运行，您可以通过train.log文件查看结果。
 
   璁粌缁撴潫鍚庯紝鎮ㄥ彲鍦ㄩ粯璁よ剼鏈枃浠跺す涓嬫壘鍒版鏌ョ偣鏂囦欢銆傞噰鐢ㄤ互涓嬫柟寮忚揪鍒版崯澶卞€硷細
@@ -186,6 +187,8 @@ bash ./scripts/run_eval.sh [DEVICE_ID]
 
   ```
 
+- 在推理之前，需要在src/param_parse.py下修改ckpt_path为真实推理ckpt的路径
+
 # 模型描述
 
 ## 性能
diff --git a/research/cv/gan/eval.py b/research/cv/gan/eval.py
index 57ee1e983..49372c2c8 100644
--- a/research/cv/gan/eval.py
+++ b/research/cv/gan/eval.py
@@ -152,15 +152,11 @@ def parzen(samples):
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 
-
-
 test_latent_code_parzen = Tensor(np.random.normal(size=(10000, opt.latent_dim)), dtype=mstype.float32)
 
 if __name__ == '__main__':
     generator = Generator(opt.latent_dim)
-
-    ckpt_file_name = 'checkpoints/' + str(opt.n_epochs-1) + '.ckpt'
-    param_dict = load_checkpoint(ckpt_file_name)
+    param_dict = load_checkpoint(opt.ckpt_path)
     load_param_into_net(generator, param_dict)
     imag = generator(test_latent_code_parzen)
     imag = imag * 127.5 + 127.5
diff --git a/research/cv/gan/scripts/run_distributed_train.sh b/research/cv/gan/scripts/run_distributed_train.sh
index 0afbecfed..d4a37ffd5 100644
--- a/research/cv/gan/scripts/run_distributed_train.sh
+++ b/research/cv/gan/scripts/run_distributed_train.sh
@@ -37,6 +37,6 @@ do
     echo "Start training for rank $RANK_ID, device $DEVICE_ID"
     cd ./device$i
     env > env.log
-    nohup python train.py --device_id=$DEVICE_ID --distribute=True --data_path="../data/MNIST_Data/" > distributed_train.log 2>&1 &
+    nohup python train.py --device_id=$DEVICE_ID --distribute=True > distributed_train.log 2>&1 &
     cd ..
 done
diff --git a/research/cv/gan/scripts/run_eval.sh b/research/cv/gan/scripts/run_eval.sh
index 875ae929d..2d48bb25e 100644
--- a/research/cv/gan/scripts/run_eval.sh
+++ b/research/cv/gan/scripts/run_eval.sh
@@ -23,4 +23,5 @@ if [ ! -d "logs" ]; then
         mkdir logs
 fi
 
+export DEVICE_ID=$1
 nohup python -u eval.py > logs/eval.log 2>&1 &
diff --git a/research/cv/gan/scripts/run_standalone_train.sh b/research/cv/gan/scripts/run_standalone_train.sh
index 19df09997..870242657 100644
--- a/research/cv/gan/scripts/run_standalone_train.sh
+++ b/research/cv/gan/scripts/run_standalone_train.sh
@@ -19,6 +19,12 @@ if [[ $# -gt 1 ]]; then
 exit 1
 fi
 
+ulimit -u unlimited
+export DEVICE_NUM=1
+export DEVICE_ID=$1
+export RANK_ID=0
+export RANK_SIZE=1
+
 if [ ! -d "logs" ]; then
         mkdir logs
 fi
diff --git a/research/cv/gan/src/dataset.py b/research/cv/gan/src/dataset.py
index f1b13a0fa..5872e2af5 100644
--- a/research/cv/gan/src/dataset.py
+++ b/research/cv/gan/src/dataset.py
@@ -127,9 +127,7 @@ class DatasetGenerator_valid:
 def create_dataset_train(batch_size=5, repeat_size=1, latent_size=100):
     """create dataset train"""
     dataset_generator = DatasetGenerator()
-
-    dataset1 = ds.GeneratorDataset(dataset_generator, ["image", "label"], shuffle=False)
-
+    dataset1 = ds.GeneratorDataset(dataset_generator, ["image", "label"], shuffle=True)
     mnist_ds = dataset1.map(
         operations=lambda x: (
             x.astype("float32"),
@@ -145,10 +143,8 @@ def create_dataset_train(batch_size=5, repeat_size=1, latent_size=100):
 def create_dataset_train_dis(batch_size=5, repeat_size=1, latent_size=100):
     """create dataset train"""
     dataset_generator = DatasetGenerator()
-
     dataset1 = ds.GeneratorDataset(dataset_generator, ["image", "label"],
-                                   shuffle=False, num_shards=get_group_size(), shard_id=get_rank())
-
+                                   shuffle=True, num_shards=get_group_size(), shard_id=get_rank())
     mnist_ds = dataset1.map(
         operations=lambda x: (
             x.astype("float32"),
@@ -165,9 +161,7 @@ def create_dataset_train_dis(batch_size=5, repeat_size=1, latent_size=100):
 def create_dataset_valid(batch_size=5, repeat_size=1, latent_size=100):
     """create dataset valid"""
     dataset_generator = DatasetGenerator_valid()
-
     dataset = ds.GeneratorDataset(dataset_generator, ["image", "label"], shuffle=False)
-
     mnist_ds = dataset.map(
         operations=lambda x: (
             x[-10000:].astype("float32"),
diff --git a/research/cv/gan/src/gan.py b/research/cv/gan/src/gan.py
index d7af0a54e..10e5acd81 100644
--- a/research/cv/gan/src/gan.py
+++ b/research/cv/gan/src/gan.py
@@ -15,7 +15,10 @@
 '''train the gan model'''
 from src.loss import GenWithLossCell
 from src.loss import DisWithLossCell
+import numpy as np
 from mindspore import nn
+from mindspore import Tensor, Parameter
+from mindspore.common import initializer
 import mindspore.ops.operations as P
 import mindspore.ops.functional as F
 import mindspore.ops.composite as C
@@ -29,6 +32,37 @@ class Reshape(nn.Cell):
     def construct(self, x):
         return self.reshape(x, self.shape)
 
+class InstanceNorm2d(nn.Cell):
+    """InstanceNorm2d"""
+
+    def __init__(self, channel):
+        super(InstanceNorm2d, self).__init__()
+        self.gamma = Parameter(initializer.initializer(
+            init=Tensor(np.ones(shape=[1, channel, 1, 1], dtype=np.float32)), shape=[1, channel, 1, 1]),
+                               name='gamma')
+        self.beta = Parameter(initializer.initializer(init=initializer.Zero(), shape=[1, channel, 1, 1]),
+                              name='beta')
+        self.reduceMean = P.ReduceMean(keep_dims=True)
+        self.square = P.Square()
+        self.sub = P.Sub()
+        self.add = P.Add()
+        self.rsqrt = P.Rsqrt()
+        self.mul = P.Mul()
+        self.tile = P.Tile()
+        self.reshape = P.Reshape()
+        self.eps = Tensor(np.ones(shape=[1, channel, 1, 1], dtype=np.float32) * 1e-5)
+        self.cast2fp32 = P.Cast()
+
+    def construct(self, x):
+        mean = self.reduceMean(x, (2, 3))
+        mean_stop_grad = F.stop_gradient(mean)
+        variance = self.reduceMean(self.square(self.sub(x, mean_stop_grad)), (2, 3))
+        variance = variance + self.eps
+        inv = self.rsqrt(variance)
+        normalized = self.sub(x, mean) * inv
+        x_IN = self.add(self.mul(self.gamma, normalized), self.beta)
+        return x_IN
+
 class Generator(nn.Cell):
     """generator"""
 
@@ -38,15 +72,15 @@ class Generator(nn.Cell):
 
         self.network.append(nn.Dense(latent_size, 256 * 7 * 7, has_bias=False))
         self.network.append(Reshape((-1, 256, 7, 7)))
-        self.network.append(nn.BatchNorm2d(256))
+        self.network.append(InstanceNorm2d(256))
         self.network.append(nn.ReLU())
 
         self.network.append(nn.Conv2dTranspose(256, 128, 5, 1))
-        self.network.append(nn.BatchNorm2d(128))
+        self.network.append(InstanceNorm2d(128))
         self.network.append(nn.ReLU())
 
         self.network.append(nn.Conv2dTranspose(128, 64, 5, 2))
-        self.network.append(nn.BatchNorm2d(64))
+        self.network.append(InstanceNorm2d(64))
         self.network.append(nn.ReLU())
 
         self.network.append(nn.Conv2dTranspose(64, 1, 5, 2))
@@ -64,11 +98,11 @@ class Discriminator(nn.Cell):
         self.network = nn.SequentialCell()
 
         self.network.append(nn.Conv2d(1, 64, 5, 2))
-        self.network.append(nn.BatchNorm2d(64))
+        self.network.append(InstanceNorm2d(64))
         self.network.append(nn.LeakyReLU())
 
         self.network.append(nn.Conv2d(64, 128, 5, 2))
-        self.network.append(nn.BatchNorm2d(128))
+        self.network.append(InstanceNorm2d(128))
         self.network.append(nn.LeakyReLU())
 
         self.network.append(nn.Flatten())
diff --git a/research/cv/gan/src/param_parse.py b/research/cv/gan/src/param_parse.py
index e2a35d584..958e9887e 100644
--- a/research/cv/gan/src/param_parse.py
+++ b/research/cv/gan/src/param_parse.py
@@ -41,6 +41,7 @@ def parameter_parser():
     parser.add_argument("--batch_size_t", type=int, default=10, help="size of the test batches")
     parser.add_argument("--batch_size_v", type=int, default=1000, help="size of the valid batches")
     parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend (Default: 0)')
-    parser.add_argument("--data_path", type=str, default="data/MNIST_Data/", help="dataset path")
+    parser.add_argument("--data_path", type=str, default="mnist/", help="dataset path")  # change to train data path
+    parser.add_argument("--ckpt_path", type=str, default="", help="eval ckpt path")  # change to eval ckpt path
     parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
     return parser.parse_args()
-- 
GitLab