diff --git a/official/cv/retinanet/train.py b/official/cv/retinanet/train.py
index 59475849a3a808e5130865a57167103c55bda2df..fcb40b81f586ac74701205acc2b7f43004d0c0c6 100644
--- a/official/cv/retinanet/train.py
+++ b/official/cv/retinanet/train.py
@@ -126,6 +126,7 @@ def main():
     config.lr_init = ast.literal_eval(config.lr_init)
     config.lr_end_rate = ast.literal_eval(config.lr_end_rate)
     device_id = get_device_id()
+    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
     if config.device_target == "Ascend":
         if context.get_context("mode") == context.PYNATIVE_MODE:
             context.set_context(mempool_block_size="31GB")
@@ -136,7 +137,6 @@ def main():
         config.distribute = False
     else:
         raise ValueError(f"device_target support ['Ascend', 'GPU', 'CPU'], but get {config.device_target}")
-    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
     if config.distribute:
         init()
         device_num = get_device_num()