diff --git a/official/nlp/bert/src/dataset.py b/official/nlp/bert/src/dataset.py
index e0bb69b8297be14104ce5bcecfffdf52ba089ee2..370a12e5ba6c38b79c76d70d2aff79cdc0f89656 100644
--- a/official/nlp/bert/src/dataset.py
+++ b/official/nlp/bert/src/dataset.py
@@ -41,12 +41,10 @@ class BucketDatasetGenerator:
         self.random_list = np.random.binomial(n=(bucket_size - 1), p=0.55, size=self.__len__())
         self.random_list = (self.random_list + 2) % bucket_size
         self.random_list = [bucket_list[i] for i in self.random_list]
-        self.max_time = self.batch_size * 5
         self._init_variables()
 
     def _init_variables(self):
         self.data_bucket = {bucket: [] for bucket in self.bucket_list}
-        self.time_count = 0
         self.iter = 0
         self.remaining_data_size = 1
         self.stage = 0
@@ -62,13 +60,10 @@ class BucketDatasetGenerator:
                     break
             for key in self.data_bucket.keys():
                 data = self.data_bucket[key]
-                is_current_key = (self.random_list[self.iter] == key or self.time_count > self.max_time)
-                if len(data) >= self.batch_size and is_current_key:
-                    self.time_count = 0
+                if len(data) >= self.batch_size and self.random_list[self.iter] == key:
                     self.data_bucket[key] = self.data_bucket[key][self.batch_size:]
                     self.iter += 1
                     return self._package_data(data, key)
-                self.time_count += 1
         self.stage = 1
         return self._process_remaining_data()
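
For context, a minimal standalone sketch of the bucket-selection step as it reads after this change. The class name BucketSelectSketch and the try_emit helper are illustrative only; batch_size, bucket_list, random_list, and data_bucket mirror the names in dataset.py, while the rest of the generator (dataset iteration, padding, and the remaining-data stage) is omitted.

import numpy as np

class BucketSelectSketch:
    """Illustrative-only reduction of BucketDatasetGenerator's bucket selection."""

    def __init__(self, bucket_list, batch_size, size):
        self.bucket_list = bucket_list
        self.batch_size = batch_size
        bucket_size = len(bucket_list)
        # Same random bucket schedule as dataset.py: a binomial draw over the
        # bucket indices, shifted by 2, then mapped onto the bucket lengths.
        random_idx = np.random.binomial(n=(bucket_size - 1), p=0.55, size=size)
        random_idx = (random_idx + 2) % bucket_size
        self.random_list = [bucket_list[i] for i in random_idx]
        self.data_bucket = {bucket: [] for bucket in bucket_list}
        self.iter = 0

    def try_emit(self):
        # A batch is returned only when the bucket scheduled by
        # random_list[self.iter] has accumulated batch_size samples;
        # with this change there is no timeout fallback to another bucket.
        for key in self.data_bucket:
            data = self.data_bucket[key]
            if len(data) >= self.batch_size and self.random_list[self.iter] == key:
                self.data_bucket[key] = data[self.batch_size:]
                self.iter += 1
                return data[:self.batch_size], key
        return None  # caller keeps filling buckets or moves to the final stage

A caller would append incoming samples to data_bucket[length] and invoke try_emit after each append, emitting a batch whenever the scheduled bucket is full.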