From a4a1c060f3a8ffadeaa3b97cf5e1a3f23526bd7a Mon Sep 17 00:00:00 2001
From: pawn_sxy <1542627907@qq.com>
Date: Thu, 23 Dec 2021 21:17:32 +0800
Subject: [PATCH] fix dataset of AttentionCluster

---
 research/cv/AttentionCluster/src/datasets/mnist_noisy.py   | 3 +++
 research/cv/AttentionCluster/src/datasets/mnist_sampler.py | 4 ++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/research/cv/AttentionCluster/src/datasets/mnist_noisy.py b/research/cv/AttentionCluster/src/datasets/mnist_noisy.py
index 943674a76..a2aeee193 100644
--- a/research/cv/AttentionCluster/src/datasets/mnist_noisy.py
+++ b/research/cv/AttentionCluster/src/datasets/mnist_noisy.py
@@ -13,7 +13,10 @@
 # limitations under the License.
 # ============================================================================
 """mnist noisy dataset"""
+import os
 import shutil
+import numpy as np
+from PIL import Image
 from src.datasets.mnist_sampler import load_pkl, dump_pkl, load_mnist, \
     get_noisy_sampler, get_number_sampler, to_image, put_numbers
 
diff --git a/research/cv/AttentionCluster/src/datasets/mnist_sampler.py b/research/cv/AttentionCluster/src/datasets/mnist_sampler.py
index 23e4202b1..6cba12a0f 100644
--- a/research/cv/AttentionCluster/src/datasets/mnist_sampler.py
+++ b/research/cv/AttentionCluster/src/datasets/mnist_sampler.py
@@ -44,8 +44,8 @@ def load_mnist(data_dir, usage='train'):
 
 class NoisySampler:
     """Noisy Sampler"""
-    def __init__(self, cnt):
-        self.cnt = cnt
+    def __init__(self, cnt=None):
+        self.cnt = cnt if cnt is not None else {}
 
     def add(self, img):
         for x in np.ndarray.flatten(img):
-- 
GitLab