From a4a1c060f3a8ffadeaa3b97cf5e1a3f23526bd7a Mon Sep 17 00:00:00 2001 From: pawn_sxy <1542627907@qq.com> Date: Thu, 23 Dec 2021 21:17:32 +0800 Subject: [PATCH] fix dataset of AttentionCluster --- research/cv/AttentionCluster/src/datasets/mnist_noisy.py | 3 +++ research/cv/AttentionCluster/src/datasets/mnist_sampler.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/research/cv/AttentionCluster/src/datasets/mnist_noisy.py b/research/cv/AttentionCluster/src/datasets/mnist_noisy.py index 943674a76..a2aeee193 100644 --- a/research/cv/AttentionCluster/src/datasets/mnist_noisy.py +++ b/research/cv/AttentionCluster/src/datasets/mnist_noisy.py @@ -13,7 +13,10 @@ # limitations under the License. # ============================================================================ """mnist noisy dataset""" +import os import shutil +import numpy as np +from PIL import Image from src.datasets.mnist_sampler import load_pkl, dump_pkl, load_mnist, \ get_noisy_sampler, get_number_sampler, to_image, put_numbers diff --git a/research/cv/AttentionCluster/src/datasets/mnist_sampler.py b/research/cv/AttentionCluster/src/datasets/mnist_sampler.py index 23e4202b1..6cba12a0f 100644 --- a/research/cv/AttentionCluster/src/datasets/mnist_sampler.py +++ b/research/cv/AttentionCluster/src/datasets/mnist_sampler.py @@ -44,8 +44,8 @@ def load_mnist(data_dir, usage='train'): class NoisySampler: """Noisy Sampler""" - def __init__(self, cnt): - self.cnt = cnt + def __init__(self, cnt=None): + self.cnt = cnt if cnt is not None else {} def add(self, img): for x in np.ndarray.flatten(img): -- GitLab