diff --git a/official/recommend/tbnet/train.py b/official/recommend/tbnet/train.py
index d7f350232ff153504808a6931e8ee97ed8e7521d..10d6d2f4eb40b9aec3ebd4fff316f129c7d5948b 100644
--- a/official/recommend/tbnet/train.py
+++ b/official/recommend/tbnet/train.py
@@ -28,7 +28,20 @@ from src import tbnet, config, metrics, dataset
 class MyLossMonitor(Callback):
     """My loss monitor definition."""
 
-    def epoch_end(self, run_context):
+    def on_train_epoch_end(self, run_context):
+        """Print loss at each epoch end."""
+        cb_params = run_context.original_args()
+        loss = cb_params.net_outputs
+
+        if isinstance(loss, (tuple, list)):
+            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
+                loss = loss[0]
+
+        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
+            loss = np.mean(loss.asnumpy())
+        print('loss:' + str(loss))
+
+    def on_eval_epoch_end(self, run_context):
         """Print loss at each epoch end."""
         cb_params = run_context.original_args()
         loss = cb_params.net_outputs