Skip to content
Snippets Groups Projects
Unverified Commit c202924b authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!2579 Fix import error in mass

Merge pull request !2579 from chenhaozhe/fix-mass-config
parents 2af5d3d0 688b18c1
No related branches found
No related tags found
No related merge requests found
......@@ -15,7 +15,6 @@
"""Loss monitor."""
import time
from mindspore.train.callback import Callback
from config import TransformerConfig
class LossCallBack(Callback):
......@@ -33,11 +32,10 @@ class LossCallBack(Callback):
time_stamp_init = False
time_stamp_first = 0
def __init__(self, config: TransformerConfig, per_print_times: int = 1, rank_id: int = 0):
def __init__(self, per_print_times: int = 1, rank_id: int = 0):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self.config = config
self._per_print_times = per_print_times
self.rank_id = rank_id
......
......@@ -220,7 +220,7 @@ def _build_training_pipeline(pre_training_dataset=None,
callbacks.append(time_cb)
ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path)
if rank_size is not None and int(rank_size) > 1:
loss_monitor = LossCallBack(config, rank_id=MultiAscend.get_rank())
loss_monitor = LossCallBack(rank_id=MultiAscend.get_rank())
callbacks.append(loss_monitor)
if MultiAscend.get_rank() % 8 == 0:
ckpt_callback = ModelCheckpoint(
......@@ -234,7 +234,7 @@ def _build_training_pipeline(pre_training_dataset=None,
prefix=config.ckpt_prefix,
directory=os.path.join(ckpt_save_dir, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
config=ckpt_config)
loss_monitor = LossCallBack(config, rank_id=os.getenv('DEVICE_ID'))
loss_monitor = LossCallBack(rank_id=os.getenv('DEVICE_ID'))
callbacks.append(loss_monitor)
callbacks.append(ckpt_callback)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment