diff --git a/official/nlp/pangu_alpha/src/pangu_alpha.py b/official/nlp/pangu_alpha/src/pangu_alpha.py
index da85764a2a2266062ac9f963a01edf889ae23aef..2571ce09c3de8f1ef205deafe353fff19d62fdb3 100644
--- a/official/nlp/pangu_alpha/src/pangu_alpha.py
+++ b/official/nlp/pangu_alpha/src/pangu_alpha.py
@@ -26,7 +26,7 @@ from mindspore.nn import Cell
 from mindspore.nn.transformer.transformer import VocabEmbedding, TransformerEncoder, TransformerEncoderLayer, \
     AttentionMask
 from mindspore.nn.transformer import MoEConfig
-from mindspore.nn.transformer.layers import _LayerNorm, _Dropout
+from mindspore.nn.transformer.layers import _LayerNorm
 
 
 class EmbeddingLayer(nn.Cell):
@@ -49,8 +49,8 @@ class EmbeddingLayer(nn.Cell):
                                                  parallel_config=copied_parallel_config.embedding_dp_mp_config)
         self.add = P.Add().shard(
             ((config.parallel_config.data_parallel, 1, 1), (config.parallel_config.data_parallel, 1, 1)))
-        self.dropout = _Dropout(1 - config.dropout_rate)
-        self.dropout.shard(((config.parallel_config.data_parallel, 1, 1),))
+        self.dropout = nn.Dropout(1 - config.dropout_rate)
+        self.dropout.dropout.shard(((config.parallel_config.data_parallel, 1, 1),))
         self.is_first_iteration = True
         self.use_past = config.use_past
         self.batch_size = config.batch_size
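
The dropout change above assumes a MindSpore 1.x build in which nn.Dropout takes keep_prob (hence the 1 - config.dropout_rate argument) and wraps a P.Dropout primitive in its `dropout` attribute, which is why the shard strategy is now applied through self.dropout.dropout.shard(...). A minimal sketch of the new pattern in isolation; the rate and data-parallel degree below are illustrative, not taken from the patch:

    import mindspore.nn as nn

    dropout_rate = 0.1      # illustrative value
    data_parallel = 8       # illustrative data-parallel degree

    # nn.Dropout(keep_prob): pass 1 - dropout_rate, matching the old _Dropout call.
    dropout = nn.Dropout(1 - dropout_rate)
    # Shard the wrapped P.Dropout primitive along the data-parallel dimension,
    # replacing the former _Dropout.shard(...) call.
    dropout.dropout.shard(((data_parallel, 1, 1),))
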
diff --git a/official/nlp/pangu_alpha/train.py b/official/nlp/pangu_alpha/train.py
index 38a8ff9f509a8ddf9e883a84e14743cef42debab..06c8701540e3726372ff1ecf11cc9c5af453ec9f 100644
--- a/official/nlp/pangu_alpha/train.py
+++ b/official/nlp/pangu_alpha/train.py
@@ -124,7 +124,7 @@ def run_train(args_opt):
     r"""The main training process."""
     os.environ['HCCL_CONNECT_TIMEOUT'] = str(args_opt.hccl_connect_time)
     # Set execution mode
-    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, variable_memory_max_size="30GB")
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, max_device_memory="30GB")
     # Set parallel context
     rank = 0
     device_num = 1
@@ -411,7 +411,7 @@ def run_train_pipeline(args_opt):
     os.environ['HCCL_CONNECT_TIMEOUT'] = str(args_opt.hccl_connect_time)
 
     context.set_context(save_graphs=False, mode=context.GRAPH_MODE, device_target=args_opt.device_target)
-    context.set_context(variable_memory_max_size="30GB")
+    context.set_context(max_device_memory="30GB")
     rank_id = 0
     device_num = 1
     if args_opt.distribute == "true":
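
Both train.py hunks switch the memory cap from variable_memory_max_size to the max_device_memory argument of context.set_context. A minimal sketch of the new call; the device target and size here are illustrative:

    from mindspore import context

    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",    # illustrative target
                        max_device_memory="30GB")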