diff --git a/official/nlp/pangu_alpha/src/pangu_alpha.py b/official/nlp/pangu_alpha/src/pangu_alpha.py
index 0870d09d6efd2d725ce52440ba87e2c8c1335a24..2a98d31de0c43450ff086bf472999b97a55d740a 100644
--- a/official/nlp/pangu_alpha/src/pangu_alpha.py
+++ b/official/nlp/pangu_alpha/src/pangu_alpha.py
@@ -37,7 +37,7 @@ class EmbeddingLayer(nn.Cell):
         self.word_embedding = VocabEmbedding(vocab_size=config.vocab_size,
                                              embedding_size=config.hidden_size,
                                              param_init=initializer("normal", [config.vocab_size, config.hidden_size],
-                                                                    dtype=config.param_init_type),
+                                                                    dtype=mstype.float32),
                                              parallel_config=config.parallel_config.embedding_dp_mp_config)
         copied_parallel_config = copy.deepcopy(config.parallel_config)
         copied_parallel_config.vocab_emb_dp = True
@@ -45,7 +45,7 @@ class EmbeddingLayer(nn.Cell):
                                                  embedding_size=config.hidden_size,
                                                  param_init=initializer("normal", [config.seq_length, config.hidden_size],
-                                                                        dtype=config.param_init_type),
+                                                                        dtype=mstype.float32),
                                                  parallel_config=copied_parallel_config.embedding_dp_mp_config)
         self.add = P.Add().shard(
             ((config.parallel_config.data_parallel, 1, 1), (config.parallel_config.data_parallel, 1, 1)))
@@ -80,7 +80,7 @@ class QueryLayer(TransformerEncoderLayer):
                  hidden_dropout_rate=0.1,
                  post_layernorm_residual=False,
                  param_init_type=mstype.float32,
-                 hidden_act='gelu',
+                 hidden_act='fast_gelu',
                  use_past=False,
                  parallel_config=None,
                  softmax_compute_type=mstype.float32):
@@ -251,6 +251,7 @@ class PanguAlpha_Model(Cell):
                                          attention_dropout_rate=config.dropout_rate,
                                          hidden_dropout_rate=config.dropout_rate,
                                          lambda_func=set_parallel_configure_for_layer,
+                                         hidden_act=config.hidden_act,
                                          param_init_type=config.param_init_type,
                                          use_past=config.use_past,
                                          parallel_config=config.parallel_config,
@@ -262,7 +263,7 @@ class PanguAlpha_Model(Cell):
                                                    embedding_size=config.hidden_size,
                                                    param_init=initializer("normal", [config.seq_length, config.hidden_size],
-                                                                          dtype=config.param_init_type),
+                                                                          dtype=mstype.float32),
                                                    parallel_config=copied_parallel_config.embedding_dp_mp_config)
         self.top_query_embedding.pipeline_stage = config.parallel_config.pipeline_stage - 1
         if config.parallel_config.pipeline_stage > 1: