diff --git a/official/cv/retinanet/train.py b/official/cv/retinanet/train.py index 59475849a3a808e5130865a57167103c55bda2df..fcb40b81f586ac74701205acc2b7f43004d0c0c6 100644 --- a/official/cv/retinanet/train.py +++ b/official/cv/retinanet/train.py @@ -126,6 +126,7 @@ def main(): config.lr_init = ast.literal_eval(config.lr_init) config.lr_end_rate = ast.literal_eval(config.lr_end_rate) device_id = get_device_id() + context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) if config.device_target == "Ascend": if context.get_context("mode") == context.PYNATIVE_MODE: context.set_context(mempool_block_size="31GB") @@ -136,7 +137,6 @@ def main(): config.distribute = False else: raise ValueError(f"device_target support ['Ascend', 'GPU', 'CPU'], but get {config.device_target}") - context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) if config.distribute: init() device_num = get_device_num()