Skip to content
Snippets Groups Projects
Commit f10e2d71 authored by zhaoting's avatar zhaoting
Browse files

fix STGAN log save

parent 517e3346
No related branches found
No related tags found
No related merge requests found
Showing
with 51 additions and 70 deletions
......@@ -69,22 +69,22 @@ After installing MindSpore via the official website, you can start training and
```python
# train STGAN
sh scripts/run_standalone_train.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID]
bash scripts/run_standalone_train.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID]
# distributed training
sh scripts/run_distribute_train.sh [RANK_TABLE_FILE] [EXPERIMENT_NAME] [DATA_PATH]
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [EXPERIMENT_NAME] [DATA_PATH]
# evaluate STGAN
sh scripts/run_eval.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID] [CHECKPOINT_PATH]
bash scripts/run_eval.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID] [CHECKPOINT_PATH]
```
- running on GPU
```python
# train STGAN
sh scripts/run_standalone_train_gpu.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID]
bash scripts/run_standalone_train_gpu.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID]
# distributed training
sh scripts/run_distribute_train_gpu.sh [EXPERIMENT_NAME] [DATA_PATH]
bash scripts/run_distribute_train_gpu.sh [EXPERIMENT_NAME] [DATA_PATH]
# evaluate STGAN, if you want to evaluate distributed training result, you should enter ./train_parallel
sh scripts/run_eval_gpu.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID] [CHECKPOINT_PATH]
bash scripts/run_eval_gpu.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID] [CHECKPOINT_PATH]
```
## [Script Description](#contents)
......@@ -143,9 +143,9 @@ Major parameters in train.py and utils/args.py as follows:
```bash
python train.py --dataroot ./dataset --experiment_name 128 > log 2>&1 &
# or run the script
sh scripts/run_standalone_train.sh ./dataset 128 0
bash scripts/run_standalone_train.sh ./dataset 128 0
# distributed training
sh scripts/run_distribute_train.sh ./config/rank_table_8pcs.json 128 /data/dataset
bash scripts/run_distribute_train.sh ./config/rank_table_8pcs.json 128 /data/dataset
```
- running on GPU
......@@ -153,9 +153,9 @@ Major parameters in train.py and utils/args.py as follows:
```bash
python train.py --dataroot ./dataset --experiment_name 128 --platform="GPU" > log 2>&1 &
# or run the script
sh scripts/run_standalone_train_gpu.sh ./dataset 128 0
bash scripts/run_standalone_train_gpu.sh ./dataset 128 0
# distributed training
sh scripts/run_distribute_train_gpu.sh 128 /data/dataset
bash scripts/run_distribute_train_gpu.sh 128 /data/dataset
```
After training, the loss value will be achieved as follows:
......@@ -183,7 +183,7 @@ Before running the command below, please check the checkpoint path used for eval
```bash
python eval.py --dataroot ./dataset --experiment_name 128 > eval_log.txt 2>&1 &
# or run the script
sh scripts/run_eval.sh ./dataset 128 0 ./ckpt/generator.ckpt
bash scripts/run_eval.sh ./dataset 128 0 ./ckpt/generator.ckpt
```
- running on GPU
......@@ -191,7 +191,7 @@ Before running the command below, please check the checkpoint path used for eval
```bash
python eval.py --dataroot ./dataset --experiment_name 128 --platform="GPU" > eval_log.txt 2>&1 &
# or run the script (if you want to evaluate distributed training result, you should enter ./train_parallel, then run the script)
sh scripts/run_eval_gpu.sh ./dataset 128 0 ./ckpt/generator.ckpt
bash scripts/run_eval_gpu.sh ./dataset 128 0 ./ckpt/generator.ckpt
```
You can view the results in the output directory, which contains a batch of result sample images.
......
......@@ -14,14 +14,13 @@
# ============================================================================
""" Model Export """
import numpy as np
from mindspore import context, Tensor
from mindspore import Tensor
from mindspore.train.serialization import export
from src.models import STGANModel
from src.utils import get_args
if __name__ == '__main__':
args = get_args("test")
context.set_context(mode=context.GRAPH_MODE, device_id=args.device_id)
model = STGANModel(args)
input_shp = [1, 3, 128, 128]
input_shp_2 = [1, 4]
......
......@@ -151,17 +151,15 @@ class STGANModel(BaseModel):
self.train_G(self.real_x, self.label_org, self.label_trg, attr_diff)
# saving losses
if (self.current_iteration / 5) % self.args.print_freq == 0:
with open(os.path.join(self.train_log_path, 'loss.log'),
'a+') as f:
f.write('Iter: %s\n' % self.current_iteration)
f.write(
'loss D: %s, loss D_real: %s, loss D_fake: %s, loss D_gp: %s, loss D_adv: %s, loss D_cls: %s \n'
% (loss_D, loss_real_D, loss_fake_D, loss_gp_D,
loss_adv_D, loss_cls_D))
f.write(
'loss G: %s, loss G_rec: %s, loss G_fake: %s, loss G_adv: %s, loss G_cls: %s \n\n'
% (loss_G, loss_rec_G, loss_fake_G, loss_adv_G,
loss_cls_G))
print('Iter: %s\n' % self.current_iteration)
print(
'loss D: %s, loss D_real: %s, loss D_fake: %s, loss D_gp: %s, loss D_adv: %s, loss D_cls: %s \n'
% (loss_D, loss_real_D, loss_fake_D, loss_gp_D,
loss_adv_D, loss_cls_D))
print(
'loss G: %s, loss G_rec: %s, loss G_fake: %s, loss G_adv: %s, loss G_cls: %s \n\n'
% (loss_G, loss_rec_G, loss_fake_G, loss_adv_G,
loss_cls_G))
def eval(self, data_loader):
""" Eval function of STGAN
......
......@@ -17,7 +17,7 @@
if [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [EXPERIMENT_NAME] [DATA_PATH]"
echo "Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [EXPERIMENT_NAME] [DATA_PATH]"
exit 1
fi
......
......@@ -17,7 +17,7 @@
if [ $# != 2 ]
then
echo "Usage: sh run_distribute_train_gpu.sh [EXPERIMENT_NAME] [DATA_PATH]"
echo "Usage: bash run_distribute_train_gpu.sh [EXPERIMENT_NAME] [DATA_PATH]"
exit 1
fi
......
......@@ -17,7 +17,7 @@
if [ $# != 4 ]
then
echo "Usage: sh run_eval.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID] [CHECKPOINT_PATH]"
echo "Usage: bash run_eval.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID] [CHECKPOINT_PATH]"
exit 1
fi
......
......@@ -17,7 +17,7 @@
if [ $# != 4 ]
then
echo "Usage: sh run_eval_gpu.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID] [CHECKPOINT_PATH]"
echo "Usage: bash run_eval_gpu.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID] [CHECKPOINT_PATH]"
exit 1
fi
......
......@@ -17,7 +17,7 @@
if [ $# != 3 ]
then
echo "Usage: sh run_standalone_train.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID]"
echo "Usage: bash run_standalone_train.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID]"
exit 1
fi
......
......@@ -17,7 +17,7 @@
if [ $# != 3 ]
then
echo "Usage: sh run_standalone_train_gpu.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID]"
echo "Usage: bash run_standalone_train_gpu.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID]"
exit 1
fi
......
......@@ -61,9 +61,12 @@ class BaseModel(ABC):
'w') as f:
f.write(json.dumps(vars(self.args)))
if self.current_iteration == -1:
with open(os.path.join(self.config_save_path, 'latest.conf'),
'r') as f:
self.current_iteration = int(f.read())
if os.path.exists(os.path.join(self.config_save_path, 'latest.conf')):
with open(os.path.join(self.config_save_path, 'latest.conf'),
'r') as f:
self.current_iteration = int(f.read())
else:
self.current_iteration = 0
# sample save path
if self.isTrain:
......
......@@ -413,8 +413,7 @@ class TrainOneStepGenerator(nn.Cell):
grads = self.grad(self.network, self.weights)(real_x, c_org, c_trg,
attr_diff, sens)
grads = self.grad_reducer(grads)
self.optimizer(grads)
return (loss_G, fake_x, loss_G,
return (F.depend(loss_G, self.optimizer(grads)), fake_x, loss_G,
loss_fake_G, loss_cls_G, loss_rec_G, loss_adv_G)
......@@ -452,6 +451,5 @@ class TrainOneStepDiscriminator(nn.Cell):
grads = self.grad(self.network, self.weights)(real_x, c_org, c_trg,
attr_diff, alpha, sens)
grads = self.grad_reducer(grads)
self.optimizer(grads)
return (loss_D, loss_D, loss_real_D,
return (F.depend(loss_D, self.optimizer(grads)), loss_D, loss_real_D,
loss_fake_D, loss_cls_D, loss_gp_D, loss_adv_D, attr_diff)
......@@ -151,17 +151,15 @@ class STGANModel(BaseModel):
self.train_G(self.real_x, self.label_org, self.label_trg, attr_diff)
# saving losses
if (self.current_iteration / 5) % self.args.print_freq == 0:
with open(os.path.join(self.train_log_path, 'loss.log'),
'a+') as f:
f.write('Iter: %s\n' % self.current_iteration)
f.write(
'loss D: %s, loss D_real: %s, loss D_fake: %s, loss D_gp: %s, loss D_adv: %s, loss D_cls: %s \n'
% (loss_D, loss_real_D, loss_fake_D, loss_gp_D,
loss_adv_D, loss_cls_D))
f.write(
'loss G: %s, loss G_rec: %s, loss G_fake: %s, loss G_adv: %s, loss G_cls: %s \n\n'
% (loss_G, loss_rec_G, loss_fake_G, loss_adv_G,
loss_cls_G))
print('Iter: %s\n' % self.current_iteration)
print(
'loss D: %s, loss D_real: %s, loss D_fake: %s, loss D_gp: %s, loss D_adv: %s, loss D_cls: %s \n'
% (loss_D, loss_real_D, loss_fake_D, loss_gp_D,
loss_adv_D, loss_cls_D))
print(
'loss G: %s, loss G_rec: %s, loss G_fake: %s, loss G_adv: %s, loss G_cls: %s \n\n'
% (loss_G, loss_rec_G, loss_fake_G, loss_adv_G,
loss_cls_G))
def eval(self, data_loader):
""" Eval function of STGAN
......
......@@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
"""arguments"""
import os
import argparse
import ast
import datetime
......@@ -266,19 +265,7 @@ def get_args(phase):
assert args.experiment_name != default_experiment_name, "--experiment_name should be assigned in test mode"
if args.continue_train:
assert args.experiment_name != default_experiment_name, "--experiment_name should be assigned in continue"
if args.device_num > 1 and args.platform == "Ascend":
context.set_context(mode=context.GRAPH_MODE,
device_target=args.platform,
save_graphs=args.save_graphs,
device_id=int(os.environ["DEVICE_ID"]))
context.reset_auto_parallel_context()
context.set_auto_parallel_context(
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True,
device_num=args.device_num)
init()
args.rank = int(os.environ["DEVICE_ID"])
elif args.device_num > 1 and args.platform == "GPU":
if args.device_num > 1:
init()
context.reset_auto_parallel_context()
args.rank = get_rank()
......
......@@ -14,8 +14,6 @@
# ============================================================================
""" STGAN TRAIN"""
import time
import tqdm
from mindspore.common import set_seed
from src.models import STGANModel
......@@ -45,9 +43,9 @@ def train():
model = STGANModel(args)
it_count = 0
for _ in tqdm.trange(args.n_epochs, desc='Epoch Loop', unit='epoch'):
for epoch in range(args.n_epochs):
start_epoch_time = time.time()
for _ in tqdm.trange(iter_per_epoch, desc='Step Loop', unit='step'):
for _ in range(iter_per_epoch):
if model.current_iteration > it_count:
it_count += 1
continue
......@@ -66,14 +64,14 @@ def train():
model.eval(data_loader)
except KeyboardInterrupt:
logger.info('You have entered CTRL+C.. Wait to finalize')
print('You have entered CTRL+C.. Wait to finalize')
model.save_networks()
it_count += 1
model.current_iteration = it_count
if args.rank == 0:
with open('performance.log', "a") as f:
f.write('average speed: {}ms/step\n'.format((time.time() - start_epoch_time)*1000/iter_per_epoch))
print('epoch: {}, average speed: {}ms/step'.format(
epoch, (time.time() - start_epoch_time) * 1000 / iter_per_epoch))
model.save_networks()
print('\n\n=============== finish training ===============\n\n')
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment