diff --git a/official/recommend/tbnet/train.py b/official/recommend/tbnet/train.py index d7f350232ff153504808a6931e8ee97ed8e7521d..10d6d2f4eb40b9aec3ebd4fff316f129c7d5948b 100644 --- a/official/recommend/tbnet/train.py +++ b/official/recommend/tbnet/train.py @@ -28,7 +28,20 @@ from src import tbnet, config, metrics, dataset class MyLossMonitor(Callback): """My loss monitor definition.""" - def epoch_end(self, run_context): + def on_train_epoch_end(self, run_context): + """Print loss at each epoch end.""" + cb_params = run_context.original_args() + loss = cb_params.net_outputs + + if isinstance(loss, (tuple, list)): + if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray): + loss = loss[0] + + if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray): + loss = np.mean(loss.asnumpy()) + print('loss:' + str(loss)) + + def on_eval_epoch_end(self, run_context): """Print loss at each epoch end.""" cb_params = run_context.original_args() loss = cb_params.net_outputs