Skip to content
Snippets Groups Projects
Commit 2561d89b authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!189 [resnet]increase the robustness of breakpoint continuous training function

Merge pull request !189 from zhouneng/code_docs_enable_continue_train
parents f502233c 9122224a
No related branches found
No related tags found
No related merge requests found
......@@ -203,8 +203,13 @@ def init_weight(net, param_dict):
"""init_weight"""
if config.pre_trained:
if param_dict:
if param_dict.get("epoch_num") and param_dict.get("step_num"):
config.has_trained_epoch = int(param_dict["epoch_num"].data.asnumpy())
config.has_trained_step = int(param_dict["step_num"].data.asnumpy())
else:
config.has_trained_epoch = 0
config.has_trained_step = 0
if config.filter_weight:
filter_list = [x.name for x in net.end_point.get_parameters()]
filter_checkpoint_parameter_by_list(param_dict, filter_list)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment