Skip to content
Snippets Groups Projects
Commit ed6b7027 authored by unknown's avatar unknown Committed by TonyNG
Browse files

modify functions in class MyLossMonitor overriding Callback

parent 11c67d53
No related branches found
No related tags found
No related merge requests found
......@@ -28,7 +28,20 @@ from src import tbnet, config, metrics, dataset
class MyLossMonitor(Callback):
"""My loss monitor definition."""
def epoch_end(self, run_context):
def on_train_epoch_end(self, run_context):
"""Print loss at each epoch end."""
cb_params = run_context.original_args()
loss = cb_params.net_outputs
if isinstance(loss, (tuple, list)):
if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
loss = loss[0]
if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
loss = np.mean(loss.asnumpy())
print('loss:' + str(loss))
def on_eval_epoch_end(self, run_context):
"""Print loss at each epoch end."""
cb_params = run_context.original_args()
loss = cb_params.net_outputs
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment