Commit 10bfd445 authored by zhou_lili

Improve accuracy of gan

parent 6cad9cec
@@ -167,6 +167,7 @@ bash ./scripts/run_eval.sh [DEVICE_ID]
 python train.py > train.log 2>&1 &
 ```
+- Before training, set data_path in src/param_parser.py to the path of the training dataset.
 The above python command runs in the background; you can check the results in the train.log file.
 After training finishes, you can find the checkpoint files in the default script folder. The loss values are obtained as follows:
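Note that data_path is also exposed as a command-line flag (see the param_parser.py hunk below), so the setting can be overridden at launch without editing the file; the path here is illustrative:

```bash
python train.py --data_path=data/MNIST_Data/ > train.log 2>&1 &
```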
@@ -186,6 +187,8 @@ bash ./scripts/run_eval.sh [DEVICE_ID]
 ```
+- Before inference, set ckpt_path in src/param_parser.py to the path of the actual checkpoint to be evaluated.
 # Model Description
 ## Performance
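Since ckpt_path defaults to an empty string, this edit is mandatory before inference. A hypothetical example of the edited line in src/param_parser.py, with an illustrative checkpoint name:

```python
parser.add_argument("--ckpt_path", type=str, default="checkpoints/199.ckpt", help="eval ckpt path")
```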
@@ -152,15 +152,11 @@ def parzen(samples):
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 test_latent_code_parzen = Tensor(np.random.normal(size=(10000, opt.latent_dim)), dtype=mstype.float32)

 if __name__ == '__main__':
     generator = Generator(opt.latent_dim)
-    ckpt_file_name = 'checkpoints/' + str(opt.n_epochs-1) + '.ckpt'
-    param_dict = load_checkpoint(ckpt_file_name)
+    param_dict = load_checkpoint(opt.ckpt_path)
     load_param_into_net(generator, param_dict)
     imag = generator(test_latent_code_parzen)
     imag = imag * 127.5 + 127.5
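Evaluation now loads whatever checkpoint --ckpt_path names instead of reconstructing a file name from n_epochs, so any saved epoch can be scored directly. The checkpoint name below is illustrative; the old code's naming scheme ('<epoch>.ckpt' under checkpoints/) suggests the form:

```bash
python eval.py --ckpt_path=checkpoints/199.ckpt
```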
@@ -37,6 +37,6 @@ do
     echo "Start training for rank $RANK_ID, device $DEVICE_ID"
     cd ./device$i
     env > env.log
-    nohup python train.py --device_id=$DEVICE_ID --distribute=True --data_path="../data/MNIST_Data/" > distributed_train.log 2>&1 &
+    nohup python train.py --device_id=$DEVICE_ID --distribute=True > distributed_train.log 2>&1 &
     cd ..
 done
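With the hard-coded --data_path gone, the distributed launcher now picks up the data_path default from src/param_parser.py. The previous behavior remains available by passing the flag explicitly:

```bash
nohup python train.py --device_id=$DEVICE_ID --distribute=True --data_path="../data/MNIST_Data/" > distributed_train.log 2>&1 &
```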
@@ -23,4 +23,5 @@ if [ ! -d "logs" ]; then
     mkdir logs
 fi
+export DEVICE_ID=$1
 nohup python -u eval.py > logs/eval.log 2>&1 &
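DEVICE_ID now comes from the script's first positional argument, matching the usage shown in the README:

```bash
bash ./scripts/run_eval.sh 0   # evaluate on Ascend device 0
```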
@@ -19,6 +19,12 @@ if [[ $# -gt 1 ]]; then
     exit 1
 fi
+ulimit -u unlimited
+export DEVICE_NUM=1
+export DEVICE_ID=$1
+export RANK_ID=0
+export RANK_SIZE=1
 if [ ! -d "logs" ]; then
     mkdir logs
 fi
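These exports pin the job to a single Ascend device: DEVICE_NUM and RANK_SIZE declare a one-process run, RANK_ID=0 is that process's rank, DEVICE_ID is taken from the first argument, and ulimit -u unlimited lifts the per-user process limit. Assuming this script follows the same calling convention as run_eval.sh (the script name below is hypothetical):

```bash
bash ./scripts/run_standalone_train.sh 0   # hypothetical name; train on Ascend device 0
```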
@@ -127,9 +127,7 @@ class DatasetGenerator_valid:
 def create_dataset_train(batch_size=5, repeat_size=1, latent_size=100):
     """create dataset train"""
     dataset_generator = DatasetGenerator()
-    dataset1 = ds.GeneratorDataset(dataset_generator, ["image", "label"], shuffle=False)
+    dataset1 = ds.GeneratorDataset(dataset_generator, ["image", "label"], shuffle=True)
     mnist_ds = dataset1.map(
         operations=lambda x: (
             x.astype("float32"),
@@ -145,10 +143,8 @@ def create_dataset_train(batch_size=5, repeat_size=1, latent_size=100):
 def create_dataset_train_dis(batch_size=5, repeat_size=1, latent_size=100):
     """create dataset train"""
     dataset_generator = DatasetGenerator()
     dataset1 = ds.GeneratorDataset(dataset_generator, ["image", "label"],
-                                   shuffle=False, num_shards=get_group_size(), shard_id=get_rank())
+                                   shuffle=True, num_shards=get_group_size(), shard_id=get_rank())
     mnist_ds = dataset1.map(
         operations=lambda x: (
             x.astype("float32"),
@@ -165,9 +161,7 @@ def create_dataset_train_dis(batch_size=5, repeat_size=1, latent_size=100):
 def create_dataset_valid(batch_size=5, repeat_size=1, latent_size=100):
     """create dataset valid"""
     dataset_generator = DatasetGenerator_valid()
     dataset = ds.GeneratorDataset(dataset_generator, ["image", "label"], shuffle=False)
     mnist_ds = dataset.map(
         operations=lambda x: (
             x[-10000:].astype("float32"),
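Turning on shuffle=True for both training pipelines means samples are drawn in a fresh random order each epoch, decorrelating the batches the generator and discriminator see, while validation stays deterministic with shuffle=False. A minimal sketch of the two behaviors, using a toy in-memory source standing in for the repository's DatasetGenerator:

```python
import numpy as np
import mindspore.dataset as ds

# Ten (image, label) pairs standing in for MNIST samples.
source = [(np.full((28, 28), i, dtype=np.float32), i) for i in range(10)]

train = ds.GeneratorDataset(source, ["image", "label"], shuffle=True)   # new order every epoch
valid = ds.GeneratorDataset(source, ["image", "label"], shuffle=False)  # fixed order
for image, label in train.create_tuple_iterator():
    print(label.asnumpy())  # labels appear in shuffled order
```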
@@ -15,7 +15,10 @@
 '''train the gan model'''
 from src.loss import GenWithLossCell
 from src.loss import DisWithLossCell
+import numpy as np
 from mindspore import nn
+from mindspore import Tensor, Parameter
+from mindspore.common import initializer
 import mindspore.ops.operations as P
 import mindspore.ops.functional as F
 import mindspore.ops.composite as C
@@ -29,6 +32,37 @@ class Reshape(nn.Cell):
     def construct(self, x):
         return self.reshape(x, self.shape)

+class InstanceNorm2d(nn.Cell):
+    """InstanceNorm2d: normalizes each sample's feature maps independently."""
+    def __init__(self, channel):
+        super(InstanceNorm2d, self).__init__()
+        self.gamma = Parameter(initializer.initializer(
+            init=Tensor(np.ones(shape=[1, channel, 1, 1], dtype=np.float32)), shape=[1, channel, 1, 1]),
+                               name='gamma')
+        self.beta = Parameter(initializer.initializer(init=initializer.Zero(), shape=[1, channel, 1, 1]),
+                              name='beta')
+        self.reduceMean = P.ReduceMean(keep_dims=True)
+        self.square = P.Square()
+        self.sub = P.Sub()
+        self.add = P.Add()
+        self.rsqrt = P.Rsqrt()
+        self.mul = P.Mul()
+        self.tile = P.Tile()
+        self.reshape = P.Reshape()
+        self.eps = Tensor(np.ones(shape=[1, channel, 1, 1], dtype=np.float32) * 1e-5)
+        self.cast2fp32 = P.Cast()
+
+    def construct(self, x):
+        # Statistics are taken per sample and per channel, over the spatial axes (2, 3).
+        mean = self.reduceMean(x, (2, 3))
+        mean_stop_grad = F.stop_gradient(mean)
+        variance = self.reduceMean(self.square(self.sub(x, mean_stop_grad)), (2, 3))
+        variance = variance + self.eps
+        inv = self.rsqrt(variance)
+        normalized = self.sub(x, mean) * inv
+        x_IN = self.add(self.mul(self.gamma, normalized), self.beta)
+        return x_IN
+
 class Generator(nn.Cell):
     """generator"""
@@ -38,15 +72,15 @@ class Generator(nn.Cell):
         self.network.append(nn.Dense(latent_size, 256 * 7 * 7, has_bias=False))
         self.network.append(Reshape((-1, 256, 7, 7)))
-        self.network.append(nn.BatchNorm2d(256))
+        self.network.append(InstanceNorm2d(256))
         self.network.append(nn.ReLU())
         self.network.append(nn.Conv2dTranspose(256, 128, 5, 1))
-        self.network.append(nn.BatchNorm2d(128))
+        self.network.append(InstanceNorm2d(128))
         self.network.append(nn.ReLU())
         self.network.append(nn.Conv2dTranspose(128, 64, 5, 2))
-        self.network.append(nn.BatchNorm2d(64))
+        self.network.append(InstanceNorm2d(64))
         self.network.append(nn.ReLU())
         self.network.append(nn.Conv2dTranspose(64, 1, 5, 2))
@@ -64,11 +98,11 @@ class Discriminator(nn.Cell):
         self.network = nn.SequentialCell()
         self.network.append(nn.Conv2d(1, 64, 5, 2))
-        self.network.append(nn.BatchNorm2d(64))
+        self.network.append(InstanceNorm2d(64))
         self.network.append(nn.LeakyReLU())
         self.network.append(nn.Conv2d(64, 128, 5, 2))
-        self.network.append(nn.BatchNorm2d(128))
+        self.network.append(InstanceNorm2d(128))
         self.network.append(nn.LeakyReLU())
         self.network.append(nn.Flatten())
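Together with the shuffle change in the dataset pipelines, swapping BatchNorm2d for InstanceNorm2d in both networks is the core of this commit: each sample is now normalized independently of the rest of its batch. A quick shape check for the modified generator (the import path is hypothetical, since the diff does not name the file; latent size 100 matches the default in the dataset helpers):

```python
import numpy as np
import mindspore as ms
from src.gan import Generator  # hypothetical import path

g = Generator(100)
z = ms.Tensor(np.random.normal(size=(4, 100)).astype(np.float32))
print(g(z).shape)  # expected: (4, 1, 28, 28), MNIST-sized outputs
```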
@@ -41,6 +41,7 @@ def parameter_parser():
     parser.add_argument("--batch_size_t", type=int, default=10, help="size of the test batches")
     parser.add_argument("--batch_size_v", type=int, default=1000, help="size of the valid batches")
     parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend (Default: 0)')
-    parser.add_argument("--data_path", type=str, default="data/MNIST_Data/", help="dataset path")
+    parser.add_argument("--data_path", type=str, default="mnist/", help="dataset path")  # change to train data path
+    parser.add_argument("--ckpt_path", type=str, default="", help="eval ckpt path")  # change to eval ckpt path
     parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
     return parser.parse_args()
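One caveat while touching this parser: argparse's type=bool does not parse strings as booleans, since bool of any non-empty string is True, so --distribute=False would still enable the distributed branch. The scripts in this commit only ever pass --distribute=True, which sidesteps the issue. A minimal demonstration:

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument("--distribute", type=bool, default=False)

print(p.parse_args(["--distribute=False"]).distribute)  # True: bool("False") is truthy
print(p.parse_args([]).distribute)                      # False (the default)
```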