diff --git a/official/cv/resnet/train.py b/official/cv/resnet/train.py index 1f158548ca822c9366ec1acf6c3980f38bbdca02..166146ee7f956abb5be69a44e01dd1f1ab563598 100644 --- a/official/cv/resnet/train.py +++ b/official/cv/resnet/train.py @@ -203,8 +203,13 @@ def init_weight(net, param_dict): """init_weight""" if config.pre_trained: if param_dict: - config.has_trained_epoch = int(param_dict["epoch_num"].data.asnumpy()) - config.has_trained_step = int(param_dict["step_num"].data.asnumpy()) + if param_dict.get("epoch_num") and param_dict.get("step_num"): + config.has_trained_epoch = int(param_dict["epoch_num"].data.asnumpy()) + config.has_trained_step = int(param_dict["step_num"].data.asnumpy()) + else: + config.has_trained_epoch = 0 + config.has_trained_step = 0 + if config.filter_weight: filter_list = [x.name for x in net.end_point.get_parameters()] filter_checkpoint_parameter_by_list(param_dict, filter_list)