Skip to content
Snippets Groups Projects
Commit a4a1c060 authored by pawn_sxy's avatar pawn_sxy
Browse files

fix dataset of AttentionCluster

parent a468c884
No related branches found
No related tags found
No related merge requests found
......@@ -13,7 +13,10 @@
# limitations under the License.
# ============================================================================
"""mnist noisy dataset"""
import os
import shutil
import numpy as np
from PIL import Image
from src.datasets.mnist_sampler import load_pkl, dump_pkl, load_mnist, \
get_noisy_sampler, get_number_sampler, to_image, put_numbers
......
......@@ -44,8 +44,8 @@ def load_mnist(data_dir, usage='train'):
class NoisySampler:
"""Noisy Sampler"""
def __init__(self, cnt):
self.cnt = cnt
def __init__(self, cnt=None):
self.cnt = cnt if cnt is not None else {}
def add(self, img):
for x in np.ndarray.flatten(img):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment