diff --git a/dataset.py b/dataset.py index 4e035206a509bdc83483aa36a635af5e6321552c..fed90cb3b9dca489bfbefb3011e56b5d81f32e11 100644 --- a/dataset.py +++ b/dataset.py @@ -17,8 +17,8 @@ class GetDatasetGenerator: line = f.readline() # 璇诲彇涓嬩竴琛� def __getitem__(self, index): - size1 = 1024 - size2 = 1024 + size1 = 2048 + size2 = 2048 r1 = random.randint(0, 6800 - size1) r2 = random.randint(0, 7200 - size2) image = np.float32(cv2.imread(self.path+"/JPEGImages/"+self.__data[index]+".bmp")) diff --git a/main.py b/main.py index 1a2985f84bfa25e42dd215b3b2692a3263c85ae8..fb6287909574bca25c836aa4f5270af1840614bb 100644 --- a/main.py +++ b/main.py @@ -29,7 +29,7 @@ dataset_path = '/home/GXkaifa1/hexiangdong' train_dataset_generator = GetDatasetGenerator(dataset_path + '/datasets', 'train') train_dataset = ds.GeneratorDataset(train_dataset_generator, ["data", "label"], shuffle=True) -train_dataset = train_dataset.batch(2, drop_remainder=True) +train_dataset = train_dataset.batch(4, drop_remainder=True) # lr_iter = exponential_lr(3e-5, 20, 0.9, 100, staircase=True)