diff --git a/official/nlp/pangu_alpha/train.py b/official/nlp/pangu_alpha/train.py
index 5b092b8a852ffd39b146e64578a7c231e40bc3f3..c6697e79c6860bc8bfb0d9486998ba19e87bb97e 100644
--- a/official/nlp/pangu_alpha/train.py
+++ b/official/nlp/pangu_alpha/train.py
@@ -437,6 +437,7 @@ def run_train_pipeline(args_opt):
     config = PanguAlphaConfig(batch_size=batch_size // parallel_config.micro_batch_num // micro_batch_interleaved,
                               num_heads=args_opt.num_heads, hidden_size=args_opt.embedding_size,
                               seq_length=args_opt.seq_length, vocab_size=args_opt.vocab_size,
+                              use_moe=bool(args_opt.use_moe),
                               num_layers=args_opt.num_layers, ffn_hidden_size=args_opt.embedding_size * 4,
                               eod_token=bool(args_opt.eod_reset), load_ckpt_path=args_opt.load_ckpt_path,
                               param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
@@ -516,8 +517,6 @@ if __name__ == "__main__":
     if bool(opt.enable_alltoall) is True and bool(opt.use_moe) is False:
         raise ValueError("The alltoall communication is only effective when applying moe")
     if opt.stage_num > 1:
-        if bool(opt.use_moe) or bool(opt.opt_offload):
-            raise ValueError("Currently, moe and host device mode is not supported in pipeline parallel.")
         run_train_pipeline(opt)
     else:
         run_train(opt)