diff --git a/official/nlp/pangu_alpha/src/pangu_alpha.py b/official/nlp/pangu_alpha/src/pangu_alpha.py
index 2571ce09c3de8f1ef205deafe353fff19d62fdb3..45a17a3aaf2070404e158e166d3fe35c922fecf7 100644
--- a/official/nlp/pangu_alpha/src/pangu_alpha.py
+++ b/official/nlp/pangu_alpha/src/pangu_alpha.py
@@ -80,7 +80,7 @@ class QueryLayer(TransformerEncoderLayer):
                  hidden_dropout_rate=0.1,
                  post_layernorm_residual=False,
                  param_init_type=mstype.float32,
-                 hidden_act='gelu',
+                 hidden_act='fast_gelu',
                  use_past=False,
                  parallel_config=None,
                  softmax_compute_type=mstype.float32):
@@ -252,6 +252,7 @@ class PanguAlpha_Model(Cell):
             attention_dropout_rate=config.dropout_rate,
             hidden_dropout_rate=config.dropout_rate,
             lambda_func=set_parallel_configure_for_layer,
+            hidden_act=config.hidden_act,
             param_init_type=config.param_init_type,
             use_past=config.use_past,
             parallel_config=config.parallel_config,
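
A minimal sketch of what these two hunks accomplish together (the classes below are simplified stand-ins, not the repository's actual MindSpore code): QueryLayer's default activation becomes 'fast_gelu', and PanguAlpha_Model now forwards config.hidden_act to its encoder blocks instead of relying on the encoder's built-in default.

# Minimal sketch of the behaviour after this patch, using simplified
# stand-in classes rather than the repository's actual MindSpore classes.

class SimpleConfig:
    """Stand-in for the PanGu-Alpha config: hidden_act now rides on the config."""
    def __init__(self, hidden_act='fast_gelu'):
        self.hidden_act = hidden_act

class SimpleQueryLayer:
    """Stand-in for QueryLayer: the default activation changes to 'fast_gelu'."""
    def __init__(self, hidden_act='fast_gelu'):
        self.hidden_act = hidden_act

class SimpleEncoderStack:
    """Stand-in for TransformerEncoder: the activation is now passed explicitly."""
    def __init__(self, hidden_act):
        self.hidden_act = hidden_act

class SimplePanguModel:
    """Stand-in for PanguAlpha_Model: forwards config.hidden_act to the blocks."""
    def __init__(self, config):
        self.blocks = SimpleEncoderStack(hidden_act=config.hidden_act)
        self.top_query_layer = SimpleQueryLayer()  # picks up the new default

model = SimplePanguModel(SimpleConfig(hidden_act='fast_gelu'))
assert model.blocks.hidden_act == 'fast_gelu'           # propagated from config
assert model.top_query_layer.hidden_act == 'fast_gelu'  # new default value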