diff --git a/research/cv/AttentionCluster/src/datasets/mnist_noisy.py b/research/cv/AttentionCluster/src/datasets/mnist_noisy.py index 943674a7643279172f96f8665111d913376385fe..a2aeee193997dad405c6b7ab26b22a6298d857ea 100644 --- a/research/cv/AttentionCluster/src/datasets/mnist_noisy.py +++ b/research/cv/AttentionCluster/src/datasets/mnist_noisy.py @@ -13,7 +13,10 @@ # limitations under the License. # ============================================================================ """mnist noisy dataset""" +import os import shutil +import numpy as np +from PIL import Image from src.datasets.mnist_sampler import load_pkl, dump_pkl, load_mnist, \ get_noisy_sampler, get_number_sampler, to_image, put_numbers diff --git a/research/cv/AttentionCluster/src/datasets/mnist_sampler.py b/research/cv/AttentionCluster/src/datasets/mnist_sampler.py index 23e4202b143fe26a25327fee0aa9cb66a4706172..6cba12a0f6e138c1c820322d291533fb658f60e7 100644 --- a/research/cv/AttentionCluster/src/datasets/mnist_sampler.py +++ b/research/cv/AttentionCluster/src/datasets/mnist_sampler.py @@ -44,8 +44,8 @@ def load_mnist(data_dir, usage='train'): class NoisySampler: """Noisy Sampler""" - def __init__(self, cnt): - self.cnt = cnt + def __init__(self, cnt=None): + self.cnt = cnt if cnt is not None else {} def add(self, img): for x in np.ndarray.flatten(img):