diff --git a/official/nlp/pangu_alpha/src/pangu_alpha.py b/official/nlp/pangu_alpha/src/pangu_alpha.py
index da85764a2a2266062ac9f963a01edf889ae23aef..2571ce09c3de8f1ef205deafe353fff19d62fdb3 100644
--- a/official/nlp/pangu_alpha/src/pangu_alpha.py
+++ b/official/nlp/pangu_alpha/src/pangu_alpha.py
@@ -26,7 +26,7 @@ from mindspore.nn import Cell
 from mindspore.nn.transformer.transformer import VocabEmbedding, TransformerEncoder, TransformerEncoderLayer, \
     AttentionMask
 from mindspore.nn.transformer import MoEConfig
-from mindspore.nn.transformer.layers import _LayerNorm, _Dropout
+from mindspore.nn.transformer.layers import _LayerNorm
 
 
 class EmbeddingLayer(nn.Cell):
@@ -49,8 +49,8 @@ class EmbeddingLayer(nn.Cell):
                                              parallel_config=copied_parallel_config.embedding_dp_mp_config)
         self.add = P.Add().shard(
             ((config.parallel_config.data_parallel, 1, 1), (config.parallel_config.data_parallel, 1, 1)))
-        self.dropout = _Dropout(1 - config.dropout_rate)
-        self.dropout.shard(((config.parallel_config.data_parallel, 1, 1),))
+        self.dropout = nn.Dropout(1 - config.dropout_rate)
+        self.dropout.dropout.shard(((config.parallel_config.data_parallel, 1, 1),))
         self.is_first_iteration = True
         self.use_past = config.use_past
         self.batch_size = config.batch_size
diff --git a/official/nlp/pangu_alpha/train.py b/official/nlp/pangu_alpha/train.py
index 38a8ff9f509a8ddf9e883a84e14743cef42debab..06c8701540e3726372ff1ecf11cc9c5af453ec9f 100644
--- a/official/nlp/pangu_alpha/train.py
+++ b/official/nlp/pangu_alpha/train.py
@@ -124,7 +124,7 @@ def run_train(args_opt):
     r"""The main training process."""
     os.environ['HCCL_CONNECT_TIMEOUT'] = str(args_opt.hccl_connect_time)
     # Set execution mode
-    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, variable_memory_max_size="30GB")
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, max_device_memory="30GB")
     # Set parallel context
     rank = 0
     device_num = 1
@@ -411,7 +411,7 @@ def run_train_pipeline(args_opt):
     os.environ['HCCL_CONNECT_TIMEOUT'] = str(args_opt.hccl_connect_time)
     context.set_context(save_graphs=False, mode=context.GRAPH_MODE,
                         device_target=args_opt.device_target)
-    context.set_context(variable_memory_max_size="30GB")
+    context.set_context(max_device_memory="30GB")
     rank_id = 0
     device_num = 1
     if args_opt.distribute == "true":
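
For context, a minimal standalone sketch of the two migrations this patch performs, assuming a MindSpore 1.x-era API where nn.Dropout takes keep_prob as its first argument and exposes its underlying P.Dropout primitive as a `dropout` attribute (which is why the patch shards `self.dropout.dropout`); `dropout_rate` and `data_parallel` below are hypothetical stand-ins for the values the model reads from its config:

    import mindspore.nn as nn
    from mindspore import context

    # Hypothetical stand-ins for config.dropout_rate and
    # config.parallel_config.data_parallel.
    dropout_rate = 0.1
    data_parallel = 8

    # Old (removed internal helper):
    #   from mindspore.nn.transformer.layers import _Dropout
    #   dropout = _Dropout(1 - dropout_rate)
    #   dropout.shard(((data_parallel, 1, 1),))
    # New: the public cell; nn.Dropout takes keep_prob here, hence 1 - rate.
    dropout = nn.Dropout(1 - dropout_rate)
    # The shard strategy goes on the wrapped P.Dropout primitive, not the cell.
    dropout.dropout.shard(((data_parallel, 1, 1),))

    # Old: context.set_context(..., variable_memory_max_size="30GB")
    # New: the renamed device-memory knob.
    context.set_context(mode=context.GRAPH_MODE, max_device_memory="30GB")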