diff --git a/official/nlp/mass/src/utils/loss_monitor.py b/official/nlp/mass/src/utils/loss_monitor.py
index 1d3467363af31c47d1e697561531229125bced80..a62837f5d1ea2748336765f1b7dadd5575f730c4 100644
--- a/official/nlp/mass/src/utils/loss_monitor.py
+++ b/official/nlp/mass/src/utils/loss_monitor.py
@@ -15,7 +15,6 @@
 """Loss monitor."""
 import time
 from mindspore.train.callback import Callback
-from config import TransformerConfig
 
 
 class LossCallBack(Callback):
@@ -33,11 +32,10 @@ class LossCallBack(Callback):
     time_stamp_init = False
     time_stamp_first = 0
 
-    def __init__(self, config: TransformerConfig, per_print_times: int = 1, rank_id: int = 0):
+    def __init__(self, per_print_times: int = 1, rank_id: int = 0):
         super(LossCallBack, self).__init__()
         if not isinstance(per_print_times, int) or per_print_times < 0:
             raise ValueError("print_step must be int and >= 0.")
-        self.config = config
         self._per_print_times = per_print_times
         self.rank_id = rank_id
 
diff --git a/official/nlp/mass/train.py b/official/nlp/mass/train.py
index 22be51cdd9b7380a8787fe73d14a08df96372e03..cb5cb8ff5f0a4906f6e863f892fe6f3c2873b4ea 100644
--- a/official/nlp/mass/train.py
+++ b/official/nlp/mass/train.py
@@ -220,7 +220,7 @@ def _build_training_pipeline(pre_training_dataset=None,
     callbacks.append(time_cb)
     ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path)
     if rank_size is not None and int(rank_size) > 1:
-        loss_monitor = LossCallBack(config, rank_id=MultiAscend.get_rank())
+        loss_monitor = LossCallBack(rank_id=MultiAscend.get_rank())
         callbacks.append(loss_monitor)
         if MultiAscend.get_rank() % 8 == 0:
             ckpt_callback = ModelCheckpoint(
@@ -234,7 +234,7 @@ def _build_training_pipeline(pre_training_dataset=None,
             prefix=config.ckpt_prefix,
             directory=os.path.join(ckpt_save_dir, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
             config=ckpt_config)
-        loss_monitor = LossCallBack(config, rank_id=os.getenv('DEVICE_ID'))
+        loss_monitor = LossCallBack(rank_id=os.getenv('DEVICE_ID'))
         callbacks.append(loss_monitor)
         callbacks.append(ckpt_callback)
 