Unverified commit d699a57c authored by i-robot, committed by Gitee

!2643 [PanGuAlpha] Fix softmax dtype and eod token error

Merge pull request !2643 from huangxinjing/fix_softmax_error
parents 1b3b087c fd7bfa1d
@@ -281,6 +281,7 @@ class PanguAlpha_Model(Cell):
             hidden_act=config.hidden_act,
             param_init_type=config.param_init_type,
             use_past=config.use_past,
+            softmax_compute_type=config.softmax_compute_type,
             parallel_config=config.parallel_config)
         if config.parallel_config.recompute:
             self.top_query_layer.recompute()
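The new softmax_compute_type argument forwards the configured softmax precision to the top query layer, so it no longer falls back to the layer's default dtype. A minimal NumPy sketch (illustrative only, not PanGu-Alpha code) of why the softmax dtype matters: a naive float16 softmax overflows on moderately large logits, while the same computation in float32 stays finite.

import numpy as np

def naive_softmax(logits, dtype):
    # Exponentiate directly in the requested dtype, without the usual
    # max-subtraction trick, to expose the precision/overflow risk.
    x = logits.astype(dtype)
    e = np.exp(x)
    return e / e.sum()

logits = np.array([12.0, 11.0, 10.0])
print(naive_softmax(logits, np.float16))  # exp(12) > 65504 overflows to inf, result contains nan
print(naive_softmax(logits, np.float32))  # ~[0.665, 0.245, 0.090], all finite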
@@ -176,7 +176,7 @@ if __name__ == '__main__':
     parser.add_argument('--file_batch_size', type=int, default=1024)
     parser.add_argument('--num_process', type=int, default=64)
     parser.add_argument('--seq_length', type=int, default=1025)
-    parser.add_argument('--eot', type=int, default=50256)
+    parser.add_argument('--eot', type=int, default=3, help="Eod of text depends on the vocab file.")
     parser.add_argument('--data_column_name', type=str, default='input_ids')
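The --eot default drops from GPT-2's 50256 to 3 because the end-of-text/end-of-document id is defined by the tokenizer's vocab file rather than a fixed constant. A hypothetical sketch of looking the id up from a one-token-per-line vocab instead of hard-coding it (the vocab format, file name, and '<eod>' piece are assumptions, not taken from this repo):

import argparse

def lookup_eod_id(vocab_path, eod_piece='<eod>'):
    # Return the line index of the end-of-document piece in a plain-text
    # vocab file with one token per line; fail loudly if it is absent.
    with open(vocab_path, encoding='utf-8') as f:
        for idx, line in enumerate(f):
            if line.strip() == eod_piece:
                return idx
    raise ValueError(f'{eod_piece!r} not found in {vocab_path}')

parser = argparse.ArgumentParser()
parser.add_argument('--vocab_file', type=str, default='vocab.txt')
parser.add_argument('--eot', type=int, default=3,
                    help='Eod of text depends on the vocab file.')
args = parser.parse_args()
# e.g. eot_id = lookup_eod_id(args.vocab_file) keeps the id in sync with the vocab.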
@@ -167,13 +167,13 @@ def run_train(args_opt):
     config = PanguAlphaConfig(batch_size=batch_size // micro_batch_interleaved, num_heads=args_opt.num_heads,
                               hidden_size=args_opt.embedding_size, seq_length=args_opt.seq_length,
                               vocab_size=args_opt.vocab_size, num_layers=args_opt.num_layers,
-                              ffn_hidden_size=args_opt.embedding_size * 4, eod_token=bool(args_opt.eod_reset),
+                              ffn_hidden_size=args_opt.embedding_size * 4, eod_reset=bool(args_opt.eod_reset),
                               load_ckpt_path=args_opt.load_ckpt_path, expert_num=args_opt.expert_num,
                               param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
                               enable_offload=bool(args_opt.opt_offload), use_moe=bool(args_opt.use_moe),
                               per_token_num_experts_chosen=args_opt.per_token_num_experts_chosen,
                               hidden_act='fast_gelu' if args_opt.device_target != "GPU" else 'gelu',
-                              parallel_config=parallel_config)
+                              parallel_config=parallel_config, eod_token=args_opt.eod_id)
     print("===config is: ", config, flush=True)
     # Define network
     pangu_alpha = PanguAlphaModel(config=config)
@@ -440,9 +440,9 @@ def run_train_pipeline(args_opt):
     config = PanguAlphaConfig(batch_size=batch_size // parallel_config.micro_batch_num // micro_batch_interleaved,
                               num_heads=args_opt.num_heads, hidden_size=args_opt.embedding_size,
                               seq_length=args_opt.seq_length, vocab_size=args_opt.vocab_size,
-                              use_moe=bool(args_opt.use_moe),
+                              use_moe=bool(args_opt.use_moe), eod_token=args_opt.eod_id,
                               num_layers=args_opt.num_layers, ffn_hidden_size=args_opt.embedding_size * 4,
-                              eod_token=bool(args_opt.eod_reset), load_ckpt_path=args_opt.load_ckpt_path,
+                              eod_reset=bool(args_opt.eod_reset), load_ckpt_path=args_opt.load_ckpt_path,
                               param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
                               enable_offload=bool(args_opt.opt_offload), parallel_config=parallel_config)
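Both training entry points previously passed bool(args_opt.eod_reset) as eod_token, conflating the on/off switch with the integer token id; the fix routes the flag to eod_reset and the id (args_opt.eod_id) to eod_token. A minimal sketch of the kind of behaviour eod_reset gates (plain NumPy, not the model code): when enabled, position ids restart after every EOD token so documents packed into one sequence are treated as separate contexts.

import numpy as np

def reset_position_ids(input_ids, eod_token):
    # Position ids count up within a document and restart from 0 right
    # after each EOD token in the packed sequence.
    pos = np.zeros_like(input_ids)
    start = 0
    for i, tok in enumerate(input_ids):
        pos[i] = i - start
        if tok == eod_token:
            start = i + 1  # the next document begins at the following position
    return pos

ids = np.array([7, 8, 3, 5, 6, 9, 3, 4])      # two documents, EOD id = 3
print(reset_position_ids(ids, eod_token=3))   # [0 1 2 0 1 2 3 0]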