Unverified Commit bce61374 authored by i-robot, committed by Gitee

!2322 Reduce remaining data memory.

Merge pull request !2322 from linqingke/bert
parents 59a3b3bc caea9d27
@@ -23,7 +23,7 @@ echo "For hyper parameter, please note that you should customize the scripts:
 '{CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini' "
 echo "=============================================================================================================="
 CUR_DIR=`pwd`
-ulimit -s 102400
+ulimit -s 302400
 python ${CUR_DIR}/scripts/ascend_distributed_launcher/get_distribute_pretrain_cmd.py \
     --run_script_dir=${CUR_DIR}/run_pretrain.py \
     --hyper_parameter_config_dir=${CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini \
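The launcher change is purely a resource limit: the requested stack soft limit grows from 102400 KB to 302400 KB. For reference, a rough Python equivalent of that shell setting, using the standard resource module (illustrative only, not part of this patch; the 302400 value is taken from the line above):

    import resource

    # `ulimit -s` works in KB; the resource module works in bytes.
    required = 302400 * 1024
    soft, hard = resource.getrlimit(resource.RLIMIT_STACK)
    if soft != resource.RLIM_INFINITY and soft < required:
        # Raising the soft limit can still fail if the hard limit is lower.
        resource.setrlimit(resource.RLIMIT_STACK, (required, hard))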
@@ -813,6 +813,18 @@ class BertNetworkMatchBucket(nn.Cell):
         bucket_list = [seq_length]
         self.bucket_list = [bucket for bucket in bucket_list if bucket <= seq_length]
+        if network.reducer_flag:
+            reuse_attr = 'reuse_communication_node'
+            if not network.grad_reducer.split_fusion:
+                hccl_op = network.grad_reducer.allreduce
+                network.grad_reducer.allreduce = hccl_op.add_prim_attr(reuse_attr, getattr(hccl_op, 'fusion'))
+            else:
+                new_op_list = []
+                for hccl_op in network.grad_reducer.op_list:
+                    new_op = hccl_op.add_prim_attr(reuse_attr, getattr(hccl_op, 'fusion'))
+                    new_op_list.append(new_op)
+                network.grad_reducer.op_list = new_op_list
+
     def construct(self,
                   input_ids,
                   input_mask,
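The new block tags every gradient AllReduce with a reuse_communication_node attribute carrying its fusion id, which apparently lets the sub-graphs compiled per bucket reuse one communication node rather than each allocating its own. A minimal sketch of the tagging pattern (a generic Primitive stands in for the real AllReduce here, since constructing that op needs an initialized HCCL group):

    from mindspore.ops import Primitive

    # Hypothetical stand-in for a fused gradient AllReduce.
    hccl_op = Primitive('AllReduce')
    hccl_op = hccl_op.add_prim_attr('fusion', 1)  # fusion id normally set by the grad reducer
    # Mirror the fusion id into the reuse attribute, as the patch does.
    hccl_op = hccl_op.add_prim_attr('reuse_communication_node', getattr(hccl_op, 'fusion'))
    assert getattr(hccl_op, 'reuse_communication_node') == 1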
@@ -32,9 +32,8 @@ class BucketDatasetGenerator:
         dataset (Dataset): The training dataset.
         batch_size (Int): The training batchsize.
         bucket_list (List): List of different sentence lengths, such as [128, 256, 512]. Default: None.
-        valid_dataset_len (Int): Prevent communication failure at the end of the dataset. Default: 0.35.
     """
-    def __init__(self, dataset, batch_size, bucket_list=None, valid_dataset_len=0.35):
+    def __init__(self, dataset, batch_size, bucket_list=None):
         self.dataset = dataset
         self.batch_size = batch_size
         self.bucket_list = bucket_list
@@ -42,14 +41,12 @@ class BucketDatasetGenerator:
         self.random_list = np.random.binomial(n=(bucket_size - 1), p=0.55, size=self.__len__())
         self.random_list = (self.random_list + 2) % bucket_size
         self.random_list = [bucket_list[i] for i in self.random_list]
-        valid_dataset_len = int(valid_dataset_len * self.__len__())
-        self.random_list = self.random_list[:valid_dataset_len] + [bucket_list[-1]] * self.__len__()
         self._init_variables()

     def _init_variables(self):
         self.data_bucket = {bucket: [] for bucket in self.bucket_list}
         self.iter = 0
-        self.remaining_data_size = 1
+        self.remaining_data = []
         self.stage = 0

     def __next__(self):
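For context on the sampling that survives the deletion above: bucket indices are drawn from a binomial skewed toward longer sequences and rotated by two so every bucket stays reachable; the removed lines had additionally truncated random_list and padded it with a full extra epoch of max-length entries, which is part of the memory this commit reclaims. A standalone sketch with hypothetical bucket values:

    import numpy as np

    bucket_list = [128, 256, 512]
    bucket_size = len(bucket_list)
    # Binomial draw over indices 0..bucket_size-1, biased toward the top.
    random_list = np.random.binomial(n=bucket_size - 1, p=0.55, size=8)
    random_list = (random_list + 2) % bucket_size
    print([bucket_list[i] for i in random_list])  # e.g. [512, 256, 512, 128, ...]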
@@ -68,6 +65,8 @@ class BucketDatasetGenerator:
                     self.iter += 1
                     return self._package_data(data, key)
         self.stage = 1
+        for value in self.data_bucket.values():
+            self.remaining_data += list(value)
         return self._process_remaining_data()

     def _package_data(self, data, key):
@@ -86,20 +85,16 @@ class BucketDatasetGenerator:
     def _process_remaining_data(self):
         """process remaining data."""
-        remaining_data_offset = self.remaining_data_size * self.batch_size
-        remaining_data = []
-        for value in self.data_bucket.values():
-            remaining_data += list(value)
-        if remaining_data_offset > len(remaining_data) or self.iter >= self.__len__():
+        if self.batch_size > len(self.remaining_data) or self.iter >= self.__len__():
             self._init_variables()
             raise StopIteration
-        self.remaining_data_size += 1
-        remaining_data = remaining_data[remaining_data_offset - self.batch_size : remaining_data_offset]
+        remaining_data = self.remaining_data[:self.batch_size]
+        self.remaining_data = self.remaining_data[self.batch_size:]
         self.iter += 1
         return self._package_data(remaining_data, self.bucket_list[-1])

     def __iter__(self):
-        self.iter = 0
+        self._init_variables()
         self.iterator = self.dataset.create_tuple_iterator(output_numpy=True)
         return self
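The net effect of the last two hunks: instead of re-concatenating every bucket and tracking a growing remaining_data_size offset on each call, the leftover samples are flattened once when stage flips to 1 and then consumed destructively, one batch_size slice at a time, so the tail never holds more than one copy of the data. A self-contained sketch of that drain pattern (hypothetical drain helper, illustrative values):

    def drain(remaining_data, batch_size):
        # Yield fixed-size batches from the front; stop when a full batch
        # can no longer be formed (mirrors the StopIteration branch above).
        while len(remaining_data) >= batch_size:
            batch, remaining_data = remaining_data[:batch_size], remaining_data[batch_size:]
            yield batch

    for batch in drain(list(range(10)), batch_size=4):
        print(batch)  # [0, 1, 2, 3] then [4, 5, 6, 7]; the final 2 samples are dropped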