diff --git a/official/cv/resnet/train.py b/official/cv/resnet/train.py
index 1f158548ca822c9366ec1acf6c3980f38bbdca02..166146ee7f956abb5be69a44e01dd1f1ab563598 100644
--- a/official/cv/resnet/train.py
+++ b/official/cv/resnet/train.py
@@ -203,8 +203,13 @@ def init_weight(net, param_dict):
     """init_weight"""
     if config.pre_trained:
         if param_dict:
-            config.has_trained_epoch = int(param_dict["epoch_num"].data.asnumpy())
-            config.has_trained_step = int(param_dict["step_num"].data.asnumpy())
+            if param_dict.get("epoch_num") and param_dict.get("step_num"):
+                config.has_trained_epoch = int(param_dict["epoch_num"].data.asnumpy())
+                config.has_trained_step = int(param_dict["step_num"].data.asnumpy())
+            else:
+                config.has_trained_epoch = 0
+                config.has_trained_step = 0
+
             if config.filter_weight:
                 filter_list = [x.name for x in net.end_point.get_parameters()]
                 filter_checkpoint_parameter_by_list(param_dict, filter_list)