Skip to content
Snippets Groups Projects
Commit ed6b7027 authored by unknown's avatar unknown Committed by TonyNG
Browse files

modify functions in class MyLossMonitor overriding Callback

parent 11c67d53
No related branches found
No related tags found
No related merge requests found
...@@ -28,7 +28,20 @@ from src import tbnet, config, metrics, dataset ...@@ -28,7 +28,20 @@ from src import tbnet, config, metrics, dataset
class MyLossMonitor(Callback): class MyLossMonitor(Callback):
"""My loss monitor definition.""" """My loss monitor definition."""
def epoch_end(self, run_context): def on_train_epoch_end(self, run_context):
"""Print loss at each epoch end."""
cb_params = run_context.original_args()
loss = cb_params.net_outputs
if isinstance(loss, (tuple, list)):
if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
loss = loss[0]
if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
loss = np.mean(loss.asnumpy())
print('loss:' + str(loss))
def on_eval_epoch_end(self, run_context):
"""Print loss at each epoch end.""" """Print loss at each epoch end."""
cb_params = run_context.original_args() cb_params = run_context.original_args()
loss = cb_params.net_outputs loss = cb_params.net_outputs
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment