Skip to content
Snippets Groups Projects
Commit 10bfd445 authored by zhou_lili's avatar zhou_lili
Browse files

Improve accuracy of gan

parent 6cad9cec
Branches
Tags
No related merge requests found
......@@ -167,6 +167,7 @@ bash ./scripts/run_eval.sh [DEVICE_ID]
python train.py > train.log 2>&1 &
```
- 在训练之前,需要在src/param_parser.py下修改data_path为训练集路径
上述python命令将在后台运行,您可以通过train.log文件查看结果。
训练结束后,您可在默认脚本文件夹下找到检查点文件。采用以下方式达到损失值:
......@@ -186,6 +187,8 @@ bash ./scripts/run_eval.sh [DEVICE_ID]
```
- 在推理之前,需要在src/param_parser.py下修改ckpt_path为真实推理ckpt的路径
# 模型描述
## 性能
......
......@@ -152,15 +152,11 @@ def parzen(samples):
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_latent_code_parzen = Tensor(np.random.normal(size=(10000, opt.latent_dim)), dtype=mstype.float32)
if __name__ == '__main__':
generator = Generator(opt.latent_dim)
ckpt_file_name = 'checkpoints/' + str(opt.n_epochs-1) + '.ckpt'
param_dict = load_checkpoint(ckpt_file_name)
param_dict = load_checkpoint(opt.ckpt_path)
load_param_into_net(generator, param_dict)
imag = generator(test_latent_code_parzen)
imag = imag * 127.5 + 127.5
......
......@@ -37,6 +37,6 @@ do
echo "Start training for rank $RANK_ID, device $DEVICE_ID"
cd ./device$i
env > env.log
nohup python train.py --device_id=$DEVICE_ID --distribute=True --data_path="../data/MNIST_Data/" > distributed_train.log 2>&1 &
nohup python train.py --device_id=$DEVICE_ID --distribute=True > distributed_train.log 2>&1 &
cd ..
done
......@@ -23,4 +23,5 @@ if [ ! -d "logs" ]; then
mkdir logs
fi
export DEVICE_ID=$1
nohup python -u eval.py > logs/eval.log 2>&1 &
......@@ -19,6 +19,12 @@ if [[ $# -gt 1 ]]; then
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=$1
export RANK_ID=0
export RANK_SIZE=1
if [ ! -d "logs" ]; then
mkdir logs
fi
......
......@@ -127,9 +127,7 @@ class DatasetGenerator_valid:
def create_dataset_train(batch_size=5, repeat_size=1, latent_size=100):
"""create dataset train"""
dataset_generator = DatasetGenerator()
dataset1 = ds.GeneratorDataset(dataset_generator, ["image", "label"], shuffle=False)
dataset1 = ds.GeneratorDataset(dataset_generator, ["image", "label"], shuffle=True)
mnist_ds = dataset1.map(
operations=lambda x: (
x.astype("float32"),
......@@ -145,10 +143,8 @@ def create_dataset_train(batch_size=5, repeat_size=1, latent_size=100):
def create_dataset_train_dis(batch_size=5, repeat_size=1, latent_size=100):
"""create dataset train"""
dataset_generator = DatasetGenerator()
dataset1 = ds.GeneratorDataset(dataset_generator, ["image", "label"],
shuffle=False, num_shards=get_group_size(), shard_id=get_rank())
shuffle=True, num_shards=get_group_size(), shard_id=get_rank())
mnist_ds = dataset1.map(
operations=lambda x: (
x.astype("float32"),
......@@ -165,9 +161,7 @@ def create_dataset_train_dis(batch_size=5, repeat_size=1, latent_size=100):
def create_dataset_valid(batch_size=5, repeat_size=1, latent_size=100):
"""create dataset valid"""
dataset_generator = DatasetGenerator_valid()
dataset = ds.GeneratorDataset(dataset_generator, ["image", "label"], shuffle=False)
mnist_ds = dataset.map(
operations=lambda x: (
x[-10000:].astype("float32"),
......
......@@ -15,7 +15,10 @@
'''train the gan model'''
from src.loss import GenWithLossCell
from src.loss import DisWithLossCell
import numpy as np
from mindspore import nn
from mindspore import Tensor, Parameter
from mindspore.common import initializer
import mindspore.ops.operations as P
import mindspore.ops.functional as F
import mindspore.ops.composite as C
......@@ -29,6 +32,37 @@ class Reshape(nn.Cell):
def construct(self, x):
    """Reshape the input tensor to the fixed target shape stored in self.shape."""
    return self.reshape(x, self.shape)
class InstanceNorm2d(nn.Cell):
    """Instance normalization over the spatial dimensions of an NCHW tensor.

    Each (sample, channel) pair is normalized independently using its own
    spatial mean and variance, then scaled by a learnable per-channel
    ``gamma`` and shifted by a learnable per-channel ``beta``.

    Args:
        channel (int): number of input channels C; determines the shape of
            the affine parameters ([1, C, 1, 1], broadcast over N, H, W).
    """

    def __init__(self, channel):
        super(InstanceNorm2d, self).__init__()
        # Learnable affine parameters: gamma initialized to ones, beta to zeros.
        self.gamma = Parameter(initializer.initializer(
            init=Tensor(np.ones(shape=[1, channel, 1, 1], dtype=np.float32)), shape=[1, channel, 1, 1]),
                               name='gamma')
        self.beta = Parameter(initializer.initializer(init=initializer.Zero(), shape=[1, channel, 1, 1]),
                              name='beta')
        self.reduceMean = P.ReduceMean(keep_dims=True)
        self.square = P.Square()
        self.sub = P.Sub()
        self.add = P.Add()
        self.rsqrt = P.Rsqrt()
        self.mul = P.Mul()
        # Numerical-stability epsilon added to the variance before Rsqrt.
        # NOTE: the unused operators self.tile, self.reshape and self.cast2fp32
        # from the original implementation were removed — they were never
        # referenced in construct().
        self.eps = Tensor(np.ones(shape=[1, channel, 1, 1], dtype=np.float32) * 1e-5)

    def construct(self, x):
        """Return the instance-normalized, affine-transformed tensor.

        Args:
            x (Tensor): input of shape (N, C, H, W) — assumed NCHW; TODO confirm.

        Returns:
            Tensor: same shape as ``x``.
        """
        # Per-(sample, channel) spatial mean; keep_dims keeps it 4-D for broadcasting.
        mean = self.reduceMean(x, (2, 3))
        # Variance is computed around a gradient-stopped copy of the mean, so
        # gradients reach the mean only through the normalization term below.
        mean_stop_grad = F.stop_gradient(mean)
        variance = self.reduceMean(self.square(self.sub(x, mean_stop_grad)), (2, 3))
        variance = self.add(variance, self.eps)
        inv = self.rsqrt(variance)
        normalized = self.mul(self.sub(x, mean), inv)
        return self.add(self.mul(self.gamma, normalized), self.beta)
class Generator(nn.Cell):
"""generator"""
......@@ -38,15 +72,15 @@ class Generator(nn.Cell):
self.network.append(nn.Dense(latent_size, 256 * 7 * 7, has_bias=False))
self.network.append(Reshape((-1, 256, 7, 7)))
self.network.append(nn.BatchNorm2d(256))
self.network.append(InstanceNorm2d(256))
self.network.append(nn.ReLU())
self.network.append(nn.Conv2dTranspose(256, 128, 5, 1))
self.network.append(nn.BatchNorm2d(128))
self.network.append(InstanceNorm2d(128))
self.network.append(nn.ReLU())
self.network.append(nn.Conv2dTranspose(128, 64, 5, 2))
self.network.append(nn.BatchNorm2d(64))
self.network.append(InstanceNorm2d(64))
self.network.append(nn.ReLU())
self.network.append(nn.Conv2dTranspose(64, 1, 5, 2))
......@@ -64,11 +98,11 @@ class Discriminator(nn.Cell):
self.network = nn.SequentialCell()
self.network.append(nn.Conv2d(1, 64, 5, 2))
self.network.append(nn.BatchNorm2d(64))
self.network.append(InstanceNorm2d(64))
self.network.append(nn.LeakyReLU())
self.network.append(nn.Conv2d(64, 128, 5, 2))
self.network.append(nn.BatchNorm2d(128))
self.network.append(InstanceNorm2d(128))
self.network.append(nn.LeakyReLU())
self.network.append(nn.Flatten())
......
......@@ -41,6 +41,7 @@ def parameter_parser():
parser.add_argument("--batch_size_t", type=int, default=10, help="size of the test batches")
parser.add_argument("--batch_size_v", type=int, default=1000, help="size of the valid batches")
parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend (Default: 0)')
parser.add_argument("--data_path", type=str, default="data/MNIST_Data/", help="dataset path")
parser.add_argument("--data_path", type=str, default="mnist/", help="dataset path") # change to train data path
parser.add_argument("--ckpt_path", type=str, default="", help="eval ckpt path") # change to eval ckpt path
parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
return parser.parse_args()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment