diff --git a/official/nlp/mass/src/utils/loss_monitor.py b/official/nlp/mass/src/utils/loss_monitor.py
index 1d3467363af31c47d1e697561531229125bced80..a62837f5d1ea2748336765f1b7dadd5575f730c4 100644
--- a/official/nlp/mass/src/utils/loss_monitor.py
+++ b/official/nlp/mass/src/utils/loss_monitor.py
@@ -15,7 +15,6 @@
 """Loss monitor."""
 import time
 from mindspore.train.callback import Callback
-from config import TransformerConfig
 
 
 class LossCallBack(Callback):
@@ -33,11 +32,10 @@ class LossCallBack(Callback):
     time_stamp_init = False
     time_stamp_first = 0
 
-    def __init__(self, config: TransformerConfig, per_print_times: int = 1, rank_id: int = 0):
+    def __init__(self, per_print_times: int = 1, rank_id: int = 0):
         super(LossCallBack, self).__init__()
         if not isinstance(per_print_times, int) or per_print_times < 0:
             raise ValueError("print_step must be int and >= 0.")
-        self.config = config
         self._per_print_times = per_print_times
         self.rank_id = rank_id
 
diff --git a/official/nlp/mass/train.py b/official/nlp/mass/train.py
index 22be51cdd9b7380a8787fe73d14a08df96372e03..cb5cb8ff5f0a4906f6e863f892fe6f3c2873b4ea 100644
--- a/official/nlp/mass/train.py
+++ b/official/nlp/mass/train.py
@@ -220,7 +220,7 @@ def _build_training_pipeline(pre_training_dataset=None,
     callbacks.append(time_cb)
     ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path)
     if rank_size is not None and int(rank_size) > 1:
-        loss_monitor = LossCallBack(config, rank_id=MultiAscend.get_rank())
+        loss_monitor = LossCallBack(rank_id=MultiAscend.get_rank())
         callbacks.append(loss_monitor)
         if MultiAscend.get_rank() % 8 == 0:
             ckpt_callback = ModelCheckpoint(
@@ -234,7 +234,7 @@ def _build_training_pipeline(pre_training_dataset=None,
             prefix=config.ckpt_prefix,
             directory=os.path.join(ckpt_save_dir, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
             config=ckpt_config)
-        loss_monitor = LossCallBack(config, rank_id=os.getenv('DEVICE_ID'))
+        loss_monitor = LossCallBack(rank_id=os.getenv('DEVICE_ID'))
         callbacks.append(loss_monitor)
         callbacks.append(ckpt_callback)