Unverified commit 08de4a0f, authored by i-robot and committed by Gitee

!3227 modify some network scripts

Merge pull request !3227 from anzhengqi/modify-networks
parents 44806c9b 39e5e035
@@ -61,7 +61,7 @@ def infer_net():
     # create dataset
     dataset = create_dataset(dataset_path=config.data_path, do_train=False, batch_size=config.batch_size,
-                             target=target)
+                             image_size=config.eval_image_size, target=target)
     step_size = dataset.get_dataset_size()

     # define net
@@ -89,7 +89,8 @@ class ImgDataset:
         return len(self.data)


-def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False):
+def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, image_size=224,
+                   target="Ascend", distribute=False):
     """
     create a train or eval imagenet2012 dataset for resnet50
@@ -123,7 +124,6 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
                                        num_parallel_workers=8, shuffle=True,
                                        num_shards=device_num, shard_id=rank_id)
-    image_size = 224
     mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
     std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
@@ -163,7 +163,8 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
     return data_set


-def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False):
+def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, image_size=224,
+                    target="Ascend", distribute=False):
     """
     create a train or eval imagenet2012 dataset for resnet101
     Args:
@@ -195,7 +196,6 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
     data_set = ds.GeneratorDataset(source=dataset_generator, column_names=["label", "image", "filename"],
                                    num_parallel_workers=8, shuffle=True,
                                    num_shards=device_num, shard_id=rank_id)
-    image_size = 224
     mean = [0.475 * 255, 0.451 * 255, 0.392 * 255]
     std = [0.275 * 255, 0.267 * 255, 0.278 * 255]
@@ -232,7 +232,8 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
     return data_set


-def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False):
+def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, image_size=224,
+                    target="Ascend", distribute=False):
     """
     create a train or eval imagenet2012 dataset for se-resnet50
@@ -264,7 +265,6 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target=
     data_set = ds.GeneratorDataset(source=dataset_generator, column_names=["label", "image", "filename"],
                                    num_parallel_workers=8, shuffle=True,
                                    num_shards=device_num, shard_id=rank_id)
-    image_size = 224
     mean = [123.68, 116.78, 103.94]
     std = [1.0, 1.0, 1.0]
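Taken together, the three dataset hunks above lift the hard-coded `image_size = 224` out of the function bodies into a keyword argument that defaults to 224, so training behaves exactly as before while evaluation can pass its own resolution. A minimal sketch of the resulting call sites, reusing `config` and `target` as they appear in the scripts in this commit:

```python
# Training: omit image_size and keep the previous 224x224 behavior.
train_dataset = create_dataset(dataset_path=config.data_path, do_train=True,
                               batch_size=config.batch_size, target=target)

# Evaluation: the input size now comes from the config instead of a constant.
eval_dataset = create_dataset(dataset_path=config.data_path, do_train=False,
                              batch_size=config.batch_size,
                              image_size=config.eval_image_size, target=target)
```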
@@ -18,6 +18,7 @@ pre_trained: False
 num_classes: 2
 batch_size: 64
 epoch_size: 4
+sink_size: -1
 weight_decay: 3e-5
 keep_checkpoint_max: 1
 checkpoint_path: "./checkpoint/"
@@ -18,6 +18,7 @@ pre_trained: False
 num_classes: 2
 batch_size: 64
 epoch_size: 4
+sink_size: -1
 weight_decay: 3e-5
 keep_checkpoint_max: 1
 checkpoint_path: "./checkpoint/"
@@ -14,11 +14,40 @@
 # ============================================================================
 """TextCNN"""
+import os
 import mindspore.ops as ops
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.nn.cell import Cell
+import mindspore
+from mindspore.train.callback import Callback
+
+
+class EvalCallback(Callback):
+    """
+    Evaluate after each epoch, and save the checkpoint with the best accuracy.
+    """
+    def __init__(self, model, eval_ds, begin_eval_epoch=1, save_path="./"):
+        self.model = model
+        self.eval_ds = eval_ds
+        self.begin_eval_epoch = begin_eval_epoch
+        self.best_acc = 0
+        self.save_path = save_path
+
+    def epoch_end(self, run_context):
+        """
+        Evaluate at epoch end.
+        """
+        cb_params = run_context.original_args()
+        cur_epoch = cb_params.cur_epoch_num
+        if cur_epoch >= self.begin_eval_epoch:
+            res = self.model.eval(self.eval_ds)
+            acc = res["acc"]
+            if acc > self.best_acc:
+                self.best_acc = acc
+                mindspore.save_checkpoint(cb_params.train_network,
+                                          os.path.join(self.save_path, "best_acc.ckpt"))
+                print("the best epoch is", cur_epoch, "best acc is", self.best_acc)
+
+
 class SoftmaxCrossEntropyExpand(Cell):
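For reference, a minimal sketch of how this callback is wired up, assuming a `net`, `loss`, `opt`, and datasets built as in the train.py hunks further below. Note that the `'acc'` key given to `Model` is a contract: it must match the `res["acc"]` lookup in `epoch_end`:

```python
from mindspore import Model
from mindspore.nn.metrics import Accuracy

# The metrics key 'acc' is what EvalCallback reads out of model.eval()'s result dict.
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc': Accuracy()})
eval_callback = EvalCallback(model, eval_dataset, begin_eval_epoch=1, save_path="./checkpoint/")
model.train(config.epoch_size, dataset, callbacks=[eval_callback])
```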
@@ -18,8 +18,9 @@ pre_trained: False
 num_classes: 2
 batch_size: 64
 epoch_size: 4
+sink_size: 200
 weight_decay: 3e-5
-keep_checkpoint_max: 1
+keep_checkpoint_max: 10
 checkpoint_path: "./checkpoint/"
 checkpoint_file_path: "train_textcnn-4_149.ckpt"
 word_len: 51
@@ -18,6 +18,7 @@ pre_trained: False
 num_classes: 2
 batch_size: 64
 epoch_size: 4
+sink_size: -1
 weight_decay: 3e-5
 keep_checkpoint_max: 1
 checkpoint_path: "./checkpoint/"
@@ -30,7 +30,7 @@ from model_utils.moxing_adapter import moxing_wrapper
 from model_utils.device_adapter import get_device_id, get_rank_id
 from model_utils.config import config
 from src.textcnn import TextCNN
-from src.textcnn import SoftmaxCrossEntropyExpand
+from src.textcnn import SoftmaxCrossEntropyExpand, EvalCallback
 from src.dataset import MovieReview, SST2, Subjectivity

 set_seed(1)
@@ -55,7 +55,10 @@ def train_net():
         instance = SST2(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
     dataset = instance.create_train_dataset(batch_size=config.batch_size, epoch_size=config.epoch_size)
+    eval_dataset = instance.create_test_dataset(batch_size=config.batch_size)
     batch_num = dataset.get_dataset_size()
+    if config.sink_size == -1:
+        config.sink_size = batch_num

     base_lr = float(config.base_lr)
     learning_rate = []
@@ -79,13 +82,19 @@ def train_net():
     model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc': Accuracy()})

-    config_ck = CheckpointConfig(save_checkpoint_steps=int(config.epoch_size * batch_num / 2),
+    config_ck = CheckpointConfig(save_checkpoint_steps=config.sink_size,
                                  keep_checkpoint_max=config.keep_checkpoint_max)
     time_cb = TimeMonitor(data_size=batch_num)
     ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path)
     ckpoint_cb = ModelCheckpoint(prefix="train_textcnn", directory=ckpt_save_dir, config=config_ck)
     loss_cb = LossMonitor()
-    model.train(config.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
+    eval_callback = EvalCallback(model, eval_dataset, save_path=ckpt_save_dir)
+    if config.device_target == "CPU":
+        model.train(config.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
+    else:
+        epoch_count = config.epoch_size * batch_num // config.sink_size
+        model.train(epoch_count, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb, eval_callback],
+                    sink_size=config.sink_size)
     print("train success")
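A quick worked example of the sink-mode arithmetic above, using the SST2 values from this commit (`epoch_size: 4`, `sink_size: 200`) and the 149 steps per epoch implied by the default checkpoint name `train_textcnn-4_149.ckpt`:

```python
epoch_size, batch_num, sink_size = 4, 149, 200

# In sink mode, model.train() interprets its first argument as the number of
# sink iterations, each of which processes sink_size steps.
epoch_count = epoch_size * batch_num // sink_size   # 596 // 200 = 2
print(epoch_count * sink_size)                      # 400 of the 596 steps actually run
```

Because of the floor division, up to `sink_size - 1` trailing steps can be dropped (196 here), so a `sink_size` that divides `epoch_size * batch_num` evenly, or the `-1` default that sets it to one full epoch, avoids losing batches.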
@@ -508,3 +508,7 @@ Refer to the [ModelZoo FAQ](https://gitee.com/mindspore/models#FAQ) for some com
   **A**: In the final stage of training, model accuracy usually drifts irregularly. Because we have to use a third-party Perl script for evaluation, we cannot identify the best checkpoint as soon as training finishes.
   You can try evaluating the last several checkpoints to find the best one.
+
+- **Q: Why does a shape mismatch error such as "For 'Add', x.shape and y.shape need to broadcast." occur?**
+
+  **A**: The existing parameters are configured for the datasets listed in the README. If you switch to a new dataset, please update the related parameters at the same time.
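For the first answer above, a sketch of sweeping the last few checkpoints with MindSpore's standard loading APIs. The `build_model_and_eval_set` helper, the checkpoint names, and the `"acc"` metric key are hypothetical placeholders for this model's actual setup:

```python
from mindspore import load_checkpoint, load_param_into_net

net, model, eval_dataset = build_model_and_eval_set()  # hypothetical helper

best_acc, best_ckpt = 0.0, None
for ckpt in ["train-46_1000.ckpt", "train-47_1000.ckpt", "train-48_1000.ckpt"]:
    param_dict = load_checkpoint(ckpt)      # read the saved weights from disk
    load_param_into_net(net, param_dict)    # push them into the network
    acc = model.eval(eval_dataset)["acc"]
    if acc > best_acc:
        best_acc, best_ckpt = acc, ckpt
print("best checkpoint:", best_ckpt, "accuracy:", best_acc)
```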
@@ -481,3 +481,7 @@ train.py already sets some seeds to avoid the randomness from dataset shuffling and weight initialization
 - **Q: Why is the accuracy of my final checkpoint poor?**

   **A**: Because evaluation relies on a third-party Perl script, we cannot pick the best checkpoint during training. You can evaluate the last several checkpoints and choose the best one among them.
+
+- **Q: Why does a shape mismatch error such as "For 'Add', x.shape and y.shape need to broadcast." occur?**
+
+  **A**: The existing parameters are configured for the datasets provided in the README. After switching to a new dataset, reconfigure the related parameters according to their definitions.
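As a concrete illustration of the answer above, suppose the TextCNN SST2 task is pointed at a dataset with five classes and longer sentences. The values below are hypothetical; the point is that fields such as `num_classes` and `word_len` in the yaml config must be updated together with the dataset:

```yaml
# Hypothetical adjustments for a new 5-class dataset with longer sentences.
num_classes: 5      # was 2; must match the new dataset's label count
batch_size: 64
word_len: 100       # was 51; must cover the new maximum sentence length
```

If these stay at their old values while the dataset changes, the shapes produced by the data pipeline no longer match the network's parameters, which surfaces as broadcast errors like the one quoted in the question.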