Skip to content
Snippets Groups Projects
Unverified Commit 0fef6004 authored by i-robot, committed by Gitee
Browse files

!3468 Fix memory leak issue of Bert_Large_Boost Ascend.

Merge pull request !3468 from archer2049/master
parents 67ff439a 5e5bd29a
No related branches found
No related tags found
No related merge requests found
...@@ -34,6 +34,7 @@ class BucketDatasetGenerator: ...@@ -34,6 +34,7 @@ class BucketDatasetGenerator:
bucket_list (List): List of different sentence lengths, such as [128, 256, 512]. Default: None. bucket_list (List): List of different sentence lengths, such as [128, 256, 512]. Default: None.
valid_dataset_len (Int): Prevent communication failure at the end of the dataset. Default: 0.35. valid_dataset_len (Int): Prevent communication failure at the end of the dataset. Default: 0.35.
""" """
def __init__(self, dataset, batch_size, bucket_list=None, valid_dataset_len=0.35): def __init__(self, dataset, batch_size, bucket_list=None, valid_dataset_len=0.35):
self.dataset = dataset self.dataset = dataset
self.batch_size = batch_size self.batch_size = batch_size
...@@ -49,6 +50,7 @@ class BucketDatasetGenerator: ...@@ -49,6 +50,7 @@ class BucketDatasetGenerator:
def _init_variables(self): def _init_variables(self):
self.data_bucket = {bucket: [] for bucket in self.bucket_list} self.data_bucket = {bucket: [] for bucket in self.bucket_list}
self.iter = 0 self.iter = 0
self.remaining_data = []
self.remaining_data_size = 1 self.remaining_data_size = 1
self.stage = 0 self.stage = 0
...@@ -86,20 +88,28 @@ class BucketDatasetGenerator: ...@@ -86,20 +88,28 @@ class BucketDatasetGenerator:
def _process_remaining_data(self): def _process_remaining_data(self):
"""process remaining data.""" """process remaining data."""
remaining_data_offset = self.remaining_data_size * self.batch_size for key in self.data_bucket.keys():
remaining_data = [] data = self.data_bucket[key]
if len(data) >= self.batch_size:
self.data_bucket[key] = self.data_bucket[key][self.batch_size:]
self.iter += 1
return self._package_data(data, key)
for value in self.data_bucket.values(): for value in self.data_bucket.values():
remaining_data += list(value) self.remaining_data += list(value)
if remaining_data_offset > len(remaining_data) or self.iter >= self.__len__(): self.data_bucket = dict()
if self.batch_size > len(self.remaining_data) or self.iter >= self.__len__():
self._init_variables() self._init_variables()
raise StopIteration raise StopIteration
self.remaining_data_size += 1
remaining_data = remaining_data[remaining_data_offset - self.batch_size : remaining_data_offset] remaining_data = self.remaining_data[:self.batch_size]
self.remaining_data = self.remaining_data[self.batch_size:]
self.iter += 1 self.iter += 1
return self._package_data(remaining_data, self.bucket_list[-1]) return self._package_data(remaining_data, self.bucket_list[-1])
def __iter__(self): def __iter__(self):
self.iter = 0 self._init_variables()
self.iterator = self.dataset.create_tuple_iterator(output_numpy=True) self.iterator = self.dataset.create_tuple_iterator(output_numpy=True)
return self return self
......
...@@ -33,6 +33,7 @@ class BucketDatasetGenerator: ...@@ -33,6 +33,7 @@ class BucketDatasetGenerator:
batch_size (Int): The training batchsize. batch_size (Int): The training batchsize.
bucket_list (List): List of different sentence lengths, such as [128, 256, 512]. Default: None. bucket_list (List): List of different sentence lengths, such as [128, 256, 512]. Default: None.
""" """
def __init__(self, dataset, batch_size, bucket_list=None): def __init__(self, dataset, batch_size, bucket_list=None):
self.dataset = dataset self.dataset = dataset
self.batch_size = batch_size self.batch_size = batch_size
...@@ -85,9 +86,21 @@ class BucketDatasetGenerator: ...@@ -85,9 +86,21 @@ class BucketDatasetGenerator:
def _process_remaining_data(self): def _process_remaining_data(self):
"""process remaining data.""" """process remaining data."""
for key in self.data_bucket.keys():
data = self.data_bucket[key]
if len(data) >= self.batch_size:
self.data_bucket[key] = self.data_bucket[key][self.batch_size:]
self.iter += 1
return self._package_data(data, key)
for value in self.data_bucket.values():
self.remaining_data += list(value)
self.data_bucket = dict()
if self.batch_size > len(self.remaining_data) or self.iter >= self.__len__(): if self.batch_size > len(self.remaining_data) or self.iter >= self.__len__():
self._init_variables() self._init_variables()
raise StopIteration raise StopIteration
remaining_data = self.remaining_data[:self.batch_size] remaining_data = self.remaining_data[:self.batch_size]
self.remaining_data = self.remaining_data[self.batch_size:] self.remaining_data = self.remaining_data[self.batch_size:]
self.iter += 1 self.iter += 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment