Skip to content
Snippets Groups Projects
Unverified Commit a9cbc22c authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!2428 [lstm]choice best checkpoint

Merge pull request !2428 from zhouneng/fix_issue_lstm
parents 0df0fba1 9f142b4b
No related branches found
No related tags found
No related merge requests found
......@@ -10,6 +10,9 @@ checkpoint_path: './checkpoint/'
checkpoint_file: './checkpoint/lstm-20_390.ckpt'
device_target: Ascend
enable_profiling: False
eval_start_epoch: 1
eval_interval: 1
run_eval: False
# ==============================================================================
# LSTM CONFIG IN ASCEND for 1p training
......
......@@ -10,6 +10,9 @@ checkpoint_path: './checkpoint/'
checkpoint_file: './checkpoint/lstm-20_390.ckpt'
device_target: Ascend
enable_profiling: False
eval_start_epoch: 1
eval_interval: 1
run_eval: False
# ==============================================================================
# LSTM CONFIG IN ASCEND for 8p training
......
......@@ -10,6 +10,9 @@ checkpoint_path: './checkpoint/'
checkpoint_file: './checkpoint/lstm-20_390.ckpt'
device_target: CPU
enable_profiling: False
eval_start_epoch: 1
eval_interval: 1
run_eval: False
# ==============================================================================
# LSTM CONFIG
......
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import stat
from mindspore import nn, save_checkpoint
from mindspore import log as logger
from mindspore.train.callback import Callback
def apply_eval(eval_param_dict):
metric = nn.F1()
net = eval_param_dict["net"]
eval_dataset = eval_param_dict["dataset"]
for _, data in enumerate(eval_dataset.create_dict_iterator(num_epochs=1)):
feature = data["feature"]
label = data["label"]
prediction = net(feature)
metric.update(prediction, label)
res = metric.eval().mean()
net.set_train(True)
return res
class EvalCallBack(Callback):
"""
Evaluation callback when training.
Args:
eval_function (function): evaluation function.
eval_param_dict (dict): evaluation parameters' configure dict.
interval (int): run evaluation interval, default is 1.
eval_start_epoch (int): evaluation start epoch, default is 1.
save_best_ckpt (bool): Whether to save best checkpoint, default is True.
besk_ckpt_name (str): bast checkpoint name, default is `best.ckpt`.
metrics_name (str): evaluation metrics name, default is `acc`.
Returns:
None
Examples:
>>> EvalCallBack(eval_function, eval_param_dict)
"""
def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
super(EvalCallBack, self).__init__()
self.eval_param_dict = eval_param_dict
self.eval_function = eval_function
self.eval_start_epoch = eval_start_epoch
if interval < 1:
raise ValueError("interval should >= 1.")
self.interval = interval
self.save_best_ckpt = save_best_ckpt
self.best_res = 0
self.best_epoch = 0
if not os.path.isdir(ckpt_directory):
os.makedirs(ckpt_directory)
self.best_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
self.metrics_name = metrics_name
def remove_ckpoint_file(self, file_name):
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
try:
os.chmod(file_name, stat.S_IWRITE)
os.remove(file_name)
except OSError:
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
except ValueError:
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
def epoch_end(self, run_context):
"""Callback when epoch end."""
cb_params = run_context.original_args()
cur_epoch = cb_params.cur_epoch_num
if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
res = self.eval_function(self.eval_param_dict)
print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
if res >= self.best_res:
self.best_res = res
self.best_epoch = cur_epoch
print("update best result: {}".format(res), flush=True)
if self.save_best_ckpt:
if os.path.exists(self.best_ckpt_path):
self.remove_ckpoint_file(self.best_ckpt_path)
save_checkpoint(cb_params.train_network, self.best_ckpt_path)
print("update best checkpoint at: {}".format(self.best_ckpt_path), flush=True)
def end(self, run_context):
print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
self.best_res,
self.best_epoch), flush=True)
......@@ -22,6 +22,7 @@ from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.dataset import convert_to_mindrecord
from src.dataset import lstm_create_dataset
from src.eval_callback import EvalCallBack, apply_eval
from src.lr_schedule import get_lr
from src.lstm import SentimentNet
......@@ -113,10 +114,21 @@ def train_lstm():
keep_checkpoint_max=config.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=config.ckpt_path, config=config_ck)
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
cb = [time_cb, ckpoint_cb, loss_cb]
if config.run_eval and rank == 0:
eval_net = network
eval_net.set_train(False)
eval_dataset = lstm_create_dataset(config.preprocess_path, config.batch_size, training=False)
eval_param_dict = {"net": eval_net, "dataset": eval_dataset}
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=config.eval_interval,
eval_start_epoch=config.eval_start_epoch, save_best_ckpt=True,
ckpt_directory=config.ckpt_path, besk_ckpt_name="lstm_best_f1.ckpt",
metrics_name="f1")
cb += [eval_cb]
if config.device_target == "CPU":
model.train(config.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb], dataset_sink_mode=False)
model.train(config.num_epochs, ds_train, callbacks=cb, dataset_sink_mode=False)
else:
model.train(config.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb])
model.train(config.num_epochs, ds_train, callbacks=cb)
print("============== Training Success ==============")
if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment