From ed6b702758c05dd023a724ff0260589f3f023cb2 Mon Sep 17 00:00:00 2001
From: unknown <ng.ngai.fai@huawei.com>
Date: Tue, 7 Jun 2022 11:45:39 +0800
Subject: [PATCH] modify functions in class MyLossMonitor overriding Callback

---
 official/recommend/tbnet/train.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/official/recommend/tbnet/train.py b/official/recommend/tbnet/train.py
index d7f350232..10d6d2f4e 100644
--- a/official/recommend/tbnet/train.py
+++ b/official/recommend/tbnet/train.py
@@ -28,7 +28,20 @@ from src import tbnet, config, metrics, dataset
 class MyLossMonitor(Callback):
     """My loss monitor definition."""
 
-    def epoch_end(self, run_context):
+    def on_train_epoch_end(self, run_context):
+        """Print loss at each epoch end."""
+        cb_params = run_context.original_args()
+        loss = cb_params.net_outputs
+
+        if isinstance(loss, (tuple, list)):
+            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
+                loss = loss[0]
+
+        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
+            loss = np.mean(loss.asnumpy())
+        print('loss:' + str(loss))
+
+    def on_eval_epoch_end(self, run_context):
         """Print loss at each epoch end."""
         cb_params = run_context.original_args()
         loss = cb_params.net_outputs
-- 
GitLab