Commit 35956f82 authored by huangxinjing

Add gelu or fast_gelu according to the device

parent 5175ab29
@@ -101,8 +101,7 @@ def run_train(args_opt):
     model_parallel_num = args_opt.op_level_model_parallel_num
     data_parallel_num = int(device_num / model_parallel_num)
     batch_size = args_opt.per_batch_size * data_parallel_num
-    parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
-                                                  model_parallel=model_parallel_num,
+    parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, model_parallel=model_parallel_num,
                                                   pipeline_stage=args_opt.stage_num,
                                                   micro_batch_num=args_opt.micro_size,
                                                   optimizer_shard=bool(args_opt.optimizer_shard),
@@ -116,15 +115,14 @@ def run_train(args_opt):
                               load_ckpt_path=args_opt.load_ckpt_path,
                               param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
                               enable_offload=bool(args_opt.opt_offload),
+                              hidden_act='fast_gelu' if args_opt.device_taget != "GPU" else 'gelu',
                               parallel_config=parallel_config)
     print("===config is: ", config, flush=True)
     # Define network
     pangu_alpha = PanguAlphaModel(config=config)
     loss = CrossEntropyLoss(config.parallel_config.dp_mp_config)
     pangu_alpha_with_loss_net = PanGUAlphaWithLoss(config, pangu_alpha, loss)
     pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net)
     print("=====args_opt is: ", args_opt, flush=True)
     # Warm-up and cosine decay learning rate
     lr = LearningRate(learning_rate=args_opt.start_lr,
......
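
For reference, the change swaps the transformer's activation based on the target device: the exact erf-based GELU on GPU, and the cheaper fast GELU approximation elsewhere (i.e. Ascend). Below is a minimal standalone sketch in plain Python of the two functions and the selection logic; the fast_gelu formula follows the approximation documented for MindSpore's FastGelu operator, and pick_hidden_act is a hypothetical helper that mirrors the commit's ternary, not a function from this repository.

import math

def gelu(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def fast_gelu(x: float) -> float:
    # Approximation used by MindSpore's FastGelu (assumed formula):
    # x * exp(0.851 * (x - |x|)) / (1 + exp(-1.702 * |x|))
    return x * math.exp(0.851 * (x - abs(x))) / (1.0 + math.exp(-1.702 * abs(x)))

def pick_hidden_act(device_target: str) -> str:
    # Hypothetical helper mirroring the commit: exact gelu only on GPU.
    return 'fast_gelu' if device_target != "GPU" else 'gelu'

if __name__ == "__main__":
    for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
        print(f"x={x:+.1f}  gelu={gelu(x):+.6f}  fast_gelu={fast_gelu(x):+.6f}")
    print(pick_hidden_act("Ascend"), pick_hidden_act("GPU"))

The two curves track each other closely, so switching by device trades a small numerical difference for what is presumably a cheaper kernel on the non-GPU backend.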