From ed6b702758c05dd023a724ff0260589f3f023cb2 Mon Sep 17 00:00:00 2001 From: unknown <ng.ngai.fai@huawei.com> Date: Tue, 7 Jun 2022 11:45:39 +0800 Subject: [PATCH] modify functions in class MyLossMonitor overriding Callback --- official/recommend/tbnet/train.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/official/recommend/tbnet/train.py b/official/recommend/tbnet/train.py index d7f350232..10d6d2f4e 100644 --- a/official/recommend/tbnet/train.py +++ b/official/recommend/tbnet/train.py @@ -28,7 +28,20 @@ from src import tbnet, config, metrics, dataset class MyLossMonitor(Callback): """My loss monitor definition.""" - def epoch_end(self, run_context): + def on_train_epoch_end(self, run_context): + """Print loss at each epoch end.""" + cb_params = run_context.original_args() + loss = cb_params.net_outputs + + if isinstance(loss, (tuple, list)): + if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray): + loss = loss[0] + + if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray): + loss = np.mean(loss.asnumpy()) + print('loss:' + str(loss)) + + def on_eval_epoch_end(self, run_context): """Print loss at each epoch end.""" cb_params = run_context.original_args() loss = cb_params.net_outputs -- GitLab