Commit bd6464ab authored by huangxinjing

Fix loss callback

parent c724c792
@@ -16,9 +16,9 @@
 Callbacks
 """
 import time
-import math
+import numpy as np
 from mindspore.train.callback import Callback
 from mindspore import context
 from mindspore.context import ParallelMode
@@ -30,14 +30,16 @@ class LossCallBack(Callback):
     If the loss in NAN or INF terminating training.
     """
-    def __init__(self, dataset_size=-1, local_rank=0, has_trained_epoch=0, has_trained_step=0, micro_size=1):
+    def __init__(self, dataset_size=-1, local_rank=0, has_trained_epoch=0, has_trained_step=0, micro_size=1,
+                 is_last_stage=True):
         super(LossCallBack, self).__init__()
         self._dataset_size = dataset_size
         self.local_rank = local_rank
         self.has_trained_epoch = has_trained_epoch
         self.has_trained_step = has_trained_step
         self.micro_size = micro_size
-        print("load has trained epoch :{} and step: {}".format(has_trained_epoch, has_trained_step), flush=True)
+        self.is_last_stage = is_last_stage
+        print("Load the trained epoch :{} and step: {}".format(has_trained_epoch, has_trained_step), flush=True)

     def step_end(self, run_context):
         """
@@ -50,8 +52,11 @@ class LossCallBack(Callback):
         if percent == 0:
             epoch_num -= 1
         date = time.asctime(time.localtime(time.time()))
-        loss_value = cb_params.net_outputs[0].asnumpy() / self.micro_size
-        print("time: {} local_rank: {}, epoch: {}, step: {}, output is {}, overflow is {}, scale is {}".
+        loss_value = 'no loss for this stage'
+        if self.is_last_stage:
+            loss_value = cb_params.net_outputs[0].asnumpy() / self.micro_size
+            loss_value = np.mean(loss_value)
+        print("time: {} local_rank: {}, epoch: {}, step: {}, loss is {}, overflow is {}, loss scale is {}".
               format(date, int(self.local_rank), int(epoch_num) + int(self.has_trained_epoch),
                      cb_params.cur_step_num + int(self.has_trained_step), loss_value,
                      cb_params.net_outputs[1].asnumpy(), cb_params.net_outputs[2].asnumpy()))
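The substance of the fix: under pipeline parallelism only the last stage computes the cross-entropy loss, and the value the train cell returns is accumulated over micro batches. A minimal plain-Python sketch of the new reporting logic (format_loss and its inputs are illustrative helpers, not part of the commit):

import numpy as np

def format_loss(net_output, micro_size, is_last_stage):
    """Return a printable loss; only the last pipeline stage holds a real one."""
    if not is_last_stage:
        # Intermediate stages never see the loss, so report a placeholder
        # instead of printing a meaningless tensor.
        return 'no loss for this stage'
    # The output is accumulated over `micro_size` micro batches: divide to
    # recover the per-micro-batch loss, then reduce to a scalar for logging.
    return np.mean(np.asarray(net_output) / micro_size)

print(format_loss([8.4, 8.8], micro_size=4, is_last_stage=True))   # ~2.15
print(format_loss(None, micro_size=4, is_last_stage=False))        # placeholder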
@@ -37,7 +37,7 @@ class EmbeddingLayer(nn.Cell):
         self.word_embedding = VocabEmbedding(vocab_size=config.vocab_size,
                                              embedding_size=config.hidden_size,
                                              param_init=initializer("normal", [config.vocab_size, config.hidden_size],
-                                                                    dtype=config.param_init_type),
+                                                                    dtype=mstype.float32),
                                              parallel_config=config.parallel_config.embedding_dp_mp_config)
         copied_parallel_config = copy.deepcopy(config.parallel_config)
         copied_parallel_config.vocab_emb_dp = True
@@ -45,7 +45,7 @@ class EmbeddingLayer(nn.Cell):
                                                  embedding_size=config.hidden_size,
                                                  param_init=initializer("normal",
                                                                         [config.seq_length, config.hidden_size],
-                                                                        dtype=config.param_init_type),
+                                                                        dtype=mstype.float32),
                                                  parallel_config=copied_parallel_config.embedding_dp_mp_config)
         self.add = P.Add().shard(
             ((config.parallel_config.data_parallel, 1, 1), (config.parallel_config.data_parallel, 1, 1)))
@@ -262,7 +262,7 @@ class PanguAlpha_Model(Cell):
                                                  embedding_size=config.hidden_size,
                                                  param_init=initializer("normal",
                                                                         [config.seq_length, config.hidden_size],
-                                                                        dtype=config.param_init_type),
+                                                                        dtype=mstype.float32),
                                                  parallel_config=copied_parallel_config.embedding_dp_mp_config)
         self.top_query_embedding.pipeline_stage = config.parallel_config.pipeline_stage - 1
         if config.parallel_config.pipeline_stage > 1:
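All three embedding tables now pass mstype.float32 to the initializer instead of config.param_init_type, so the "normal" init is sampled in full precision even when the model otherwise initializes parameters in fp16. A minimal sketch of the call, assuming only that MindSpore is installed (the shape is made up):

from mindspore import dtype as mstype
from mindspore.common.initializer import initializer

# The table is created in fp32 regardless of the configured param_init_type,
# avoiding precision loss when sampling the normal distribution in fp16.
table = initializer("normal", [40000, 2560], dtype=mstype.float32)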
@@ -112,8 +112,7 @@ def run_train(args_opt):
     r"""The main training process."""
     os.environ['HCCL_CONNECT_TIMEOUT'] = "6000"
     # Set execution mode
-    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
-    context.set_context(variable_memory_max_size="30GB")
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, variable_memory_max_size="30GB")
     # Set parallel context
     rank = 0
     device_num = 1
@@ -177,16 +176,16 @@ def run_train(args_opt):
                                             param_init_type=config.param_init_type)
     else:
         optimizer = FP32StateAdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
-    # Initial scaling sens
-    loss_scale_value = math.pow(2, 32)
     epoch_num = args_opt.epoch_size
     # Dataset loading mindrecord files
     ds = create_dataset(config.batch_size * micro_batch_interleaved, data_path=cache_url, data_start_index=0,
                         eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch), eod_id=args_opt.eod_id,
                         device_num=device_num, rank=rank, column_name=args_opt.data_column_name, epoch=epoch_num)
-    actual_epoch_num = int(epoch_num * ds.get_dataset_size() / args_opt.sink_size)
-    callback = [TimeMonitor(args_opt.sink_size), LossCallBack(args_opt.sink_size, rank, 0, 0)]
-    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
+    step_per_epoch = ds.get_dataset_size()
+    actual_epoch_num = int(epoch_num * step_per_epoch / args_opt.sink_size)
+    loss_callback = LossCallBack(step_per_epoch, rank, 0, 0, micro_size=micro_batch_interleaved)
+    callback = [TimeMonitor(args_opt.sink_size), loss_callback]
+    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=math.pow(2, 32), scale_factor=2, scale_window=1000)
     pangu_alpha_with_grads = PanguAlphaTrainOneStepWithLossScaleCell(
         pangu_alpha_with_loss, optimizer=optimizer, scale_update_cell=update_cell, enable_global_norm=True,
         config=config)
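The DynamicLossScaleUpdateCell here starts the scale at 2**32, halves it on overflow, and doubles it after 1000 consecutive clean steps (scale_factor=2, scale_window=1000). A plain-Python sketch of that schedule, for intuition only, not the MindSpore implementation:

class DynamicLossScale:
    """Toy model of the update cell configured above."""
    def __init__(self, scale=2.0 ** 32, factor=2.0, window=1000):
        self.scale, self.factor, self.window = scale, factor, window
        self.good_steps = 0

    def update(self, overflow):
        if overflow:
            # Gradients hit inf/nan: back off and restart the window.
            self.scale = max(self.scale / self.factor, 1.0)
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps >= self.window:
                # A full window of clean steps: safe to grow again.
                self.scale *= self.factor
                self.good_steps = 0
        return self.scale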
@@ -207,9 +206,8 @@ def run_train(args_opt):
     if not flag:
         restore_checkpoint(args_opt, args_opt.sink_size, ds, model,
                            pangu_alpha_with_grads, epoch=actual_epoch_num)
-        callback = [TimeMonitor(args_opt.sink_size), LossCallBack(args_opt.sink_size, rank, args_opt.has_trained_epoches,
-                                                                  args_opt.has_trained_steps)]
+        loss_callback.has_trained_epoch = args_opt.has_trained_epoch
+        loss_callback.has_trained_step = args_opt.has_trained_step
     add_checkpoint_callback_policy(args_opt, callback, rank)
     if args_opt.incremental_training:
         strategy = model.infer_train_layout(train_dataset=ds, sink_size=args_opt.sink_size)
@@ -219,7 +217,7 @@ def run_train(args_opt):
                           range(0, 512)]
         print(f"Loading from path {ckpt_file_list[0]}", flush=True)
         load_distributed_checkpoint(model.train_network, ckpt_file_list, strategy)
-    print("Dataset size: {}, actual_epoch_num: {}".format(ds.get_dataset_size(), actual_epoch_num), flush=True)
+    print("Dataset size: {}, actual_epoch_num: {}".format(step_per_epoch, actual_epoch_num), flush=True)
     model.train(actual_epoch_num, ds, callbacks=callback, sink_size=args_opt.sink_size, dataset_sink_mode=True)
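With dataset sinking, model.train counts one "epoch" per sink of sink_size steps, which is why actual_epoch_num is rescaled above. Worked numbers (all values illustrative):

epoch_num = 1          # args_opt.epoch_size
step_per_epoch = 4096  # ds.get_dataset_size()
sink_size = 2          # args_opt.sink_size
# Each call into the device sinks `sink_size` steps, so the trainer must be
# invoked step_per_epoch / sink_size times per real pass over the data.
actual_epoch_num = int(epoch_num * step_per_epoch / sink_size)  # 2048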
@@ -417,8 +415,9 @@ def run_train_pipeline(args_opt):
     model_parallel_num = args_opt.op_level_model_parallel_num
     stage_device_num = int(device_num / args_opt.stage_num)
     data_parallel_num = int(stage_device_num / model_parallel_num)
+    is_last_stage = (rank_id // stage_device_num) == args_opt.stage_num - 1
     if data_parallel_num <= 1 and args_opt.optimizer_shard == 1:
-        raise ValueError("The dp must large than 1 when applying optimizer shard.")
+        raise ValueError("The data parallel number must be larger than 1 when applying optimizer shard.")
     per_batch_size = args_opt.per_batch_size
     batch_size = per_batch_size * data_parallel_num * args_opt.micro_size
     micro_batch_interleaved = args_opt.micro_batch_interleaved
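The new is_last_stage flag relies on ranks being laid out stage-major: integer-dividing the global rank by the per-stage device count yields the pipeline stage index. A small sketch with made-up numbers:

def stage_of(rank_id, device_num, stage_num):
    """Map a global rank to its pipeline stage under a stage-major layout."""
    stage_device_num = device_num // stage_num
    return rank_id // stage_device_num

# 8 devices split into 4 stages -> 2 devices per stage.
assert stage_of(0, 8, 4) == 0
assert stage_of(7, 8, 4) == 4 - 1   # ranks 6 and 7 form the last stage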
@@ -435,14 +434,14 @@ def run_train_pipeline(args_opt):
                          param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
                          enable_offload=bool(args_opt.opt_offload), parallel_config=parallel_config)
-    print("===config is: ", config, flush=True)
+    print("[Configure] is: ", config, flush=True)
     pangu_alpha = PanguAlphaModel(config=config)
     loss = CrossEntropyLoss(config.parallel_config.dp_mp_config)
     pangu_alpha_with_loss_net = PipelineCell(MicroBatchInterleaved(PanGUAlphaWithLoss(config, pangu_alpha, loss),
                                                                    micro_batch_interleaved),
                                              config.parallel_config.micro_batch_num)
     pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net)
-    print("=====args_opt is: ", args_opt, flush=True)
+    print("[args_opt] is: ", args_opt, flush=True)
     lr = LearningRate(learning_rate=args_opt.start_lr, end_learning_rate=args_opt.end_lr,
                       warmup_steps=args_opt.warmup_step, decay_steps=args_opt.decay_steps)
     params = pangu_alpha.infer_param_pipeline_stage()
@@ -464,8 +463,9 @@ def run_train_pipeline(args_opt):
     step_per_epoch = ds.get_dataset_size()
     callback_size = args_opt.sink_size
     actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
-    callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id,
-                                                         micro_size=parallel_config.micro_batch_num)]
+    loss_callback = LossCallBack(step_per_epoch, rank_id, is_last_stage=is_last_stage,
+                                 micro_size=parallel_config.micro_batch_num * micro_batch_interleaved)
+    callback = [TimeMonitor(callback_size), loss_callback]
     loss_scale_value = math.pow(2, 32)
     update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
     pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(
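Note that micro_size is now the product of the pipeline micro-batch count and the interleaving factor, presumably because MicroBatchInterleaved splits each micro batch again and the reported loss is accumulated over every sub-batch. Illustrative arithmetic with made-up sizes:

micro_batch_num = 16         # parallel_config.micro_batch_num
micro_batch_interleaved = 2  # args_opt.micro_batch_interleaved
# The loss tensor returned by the train cell sums over every sub-batch,
# so the callback divides by the full product to report a per-batch loss.
micro_size = micro_batch_num * micro_batch_interleaved  # 32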
@@ -492,8 +492,8 @@ def run_train_pipeline(args_opt):
     if not flag:
         restore_checkpoint(args_opt, callback_size, ds, model, pangu_alpha_with_grads, epoch=actual_epoch_num)
-        callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id, args_opt.has_trained_epoches,
-                                                             args_opt.has_trained_steps)]
+        loss_callback.has_trained_epoch = args_opt.has_trained_epoch
+        loss_callback.has_trained_step = args_opt.has_trained_step
     add_checkpoint_callback_policy(args_opt, callback, rank_id)
     model.train(actual_epoch_num, ds, callbacks=callback,
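Design note: both restore paths now mutate the already-registered loss_callback instead of rebuilding the callback list, which keeps a single instance in the list passed to add_checkpoint_callback_policy and preserves the micro_size and is_last_stage settings given at construction (the old rebuild dropped them).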