diff --git a/oneflow/python/test/dataloader/test_numpy_dataset.py b/oneflow/python/test/dataloader/test_numpy_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2289be657560ba3f3c41e066fc1d5cc1f46ef009 --- /dev/null +++ b/oneflow/python/test/dataloader/test_numpy_dataset.py @@ -0,0 +1,51 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +import numpy as np + +import oneflow.experimental as flow +import oneflow.python.utils.data as Data + + +class ScpDataset(Data.Dataset): + def __init__(self, chunksize=200, dim=81, length=2000): + self.chunksize = chunksize + self.dim = dim + self.length = length + + def __getitem__(self, index): + np.random.seed(index) + return np.random.randn(self.chunksize, self.dim) + + def __len__(self): + return self.length + + +@flow.unittest.skip_unless_1n1d() +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestNumpyDataset(flow.unittest.TestCase): + def test_numpy_dataset(test_case): + dataset = ScpDataset() + dataloader = Data.DataLoader(dataset, batch_size=16, shuffle=True) + for X in dataloader: + test_case.assertEqual(X.shape, flow.Size([16, 200, 81])) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/python/test/dataloader/test_tensor_dataset.py b/oneflow/python/test/dataloader/test_tensor_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6566b39f62c6715078f8ad660119c37143ce7783 --- /dev/null +++ b/oneflow/python/test/dataloader/test_tensor_dataset.py @@ -0,0 +1,81 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import unittest +import numpy as np + +import oneflow.experimental as flow +import oneflow.python.utils.data as Data +import oneflow.experimental.nn as nn + + +class LinearNet(nn.Module): + def __init__(self, n_feature): + super(LinearNet, self).__init__() + self.linear = nn.Linear(n_feature, 1) + + def forward(self, x): + y = self.linear(x) + return y + + +@flow.unittest.skip_unless_1n1d() +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestTensorDataset(flow.unittest.TestCase): + def test_tensor_dataset(test_case): + + num_inputs = 2 + num_examples = 1000 + true_w = [2, -3.4] + true_b = 4.2 + + net = LinearNet(num_inputs) + flow.nn.init.normal_(net.linear.weight, mean=0, std=0.01) + flow.nn.init.constant_(net.linear.bias, val=0) + + loss = nn.MSELoss() + optimizer = flow.optim.SGD(net.parameters(), lr=0.03) + + features = flow.tensor( + np.random.normal(0, 1, (num_examples, num_inputs)), dtype=flow.float + ) + labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b + labels += flow.tensor( + np.random.normal(0, 0.01, size=labels.size()), dtype=flow.float + ) + + batch_size = 10 + # combine features and labels + dataset = Data.TensorDataset(features, labels) + # random get small batch + data_iter = Data.DataLoader(dataset, batch_size, shuffle=True, num_workers=0) + + num_epochs = 10 + for epoch in range(1, num_epochs + 1): + for X, y in data_iter: + output = net(X) + l = loss(output, y) + optimizer.zero_grad() + l.backward() + optimizer.step() + if epoch == num_epochs: + test_case.assertLess(l.numpy(), 0.00019) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/python/utils/__init__.py b/oneflow/python/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/oneflow/python/utils/data/__init__.py b/oneflow/python/utils/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..565e7edba44b18f479dce230c73ac8cfd4a319e5 --- /dev/null +++ b/oneflow/python/utils/data/__init__.py @@ -0,0 +1,58 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +from oneflow.python.utils.data.sampler import ( + Sampler, + SequentialSampler, + RandomSampler, + SubsetRandomSampler, + BatchSampler, +) +from oneflow.python.utils.data.dataset import ( + Dataset, + IterableDataset, + TensorDataset, + ConcatDataset, + Subset, + random_split, +) +from oneflow.python.utils.data.dataset import IterableDataset as IterDataPipe +from oneflow.python.utils.data.dataloader import DataLoader, _DatasetKind +from oneflow.python.utils.data.decorator import ( + functional_datapipe, + guaranteed_datapipes_determinism, + non_deterministic, +) + + +__all__ = [ + "Sampler", + "SequentialSampler", + "RandomSampler", + "SubsetRandomSampler", + "BatchSampler", + "Dataset", + "IterableDataset", + "TensorDataset", + "ConcatDataset", + "Subset", + "random_split", + "DataLoader", + "_DatasetKind", + "IterDataPipe", + "functional_datapipe", + "guaranteed_datapipes_determinism", + "non_deterministic", +] diff --git a/oneflow/python/utils/data/_utils/__init__.py b/oneflow/python/utils/data/_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d2ee7cbaaa9d339aae624a263504e9ddade8bbf --- /dev/null +++ b/oneflow/python/utils/data/_utils/__init__.py @@ -0,0 +1,57 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +r"""Utility classes & functions for data loading. Code in this folder is mostly +used by ../dataloder.py. + +A lot of multiprocessing is used in data loading, which only supports running +functions defined in global environment (py2 can't serialize static methods). +Therefore, for code tidiness we put these functions into different files in this +folder. +""" + +import sys +import atexit + + +IS_WINDOWS = sys.platform == "win32" + +MP_STATUS_CHECK_INTERVAL = 5.0 +r"""Interval (in seconds) to check status of processes to avoid hanging in + multiprocessing data loading. This is mainly used in getting data from + another process, in which case we need to periodically check whether the + sender is alive to prevent hanging.""" + + +python_exit_status = False +r"""Whether Python is shutting down. This flag is guaranteed to be set before +the Python core library resources are freed, but Python may already be exiting +for some time when this is set. + +Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar +hook in Python 3.7 multiprocessing library: +https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327 +""" + + +def _set_python_exit_flag(): + global python_exit_status + python_exit_status = True + + +atexit.register(_set_python_exit_flag) + + +from . import collate, fetch diff --git a/oneflow/python/utils/data/_utils/collate.py b/oneflow/python/utils/data/_utils/collate.py new file mode 100644 index 0000000000000000000000000000000000000000..7b2291524ea215fb210f29ee12b4c6bd144d2e59 --- /dev/null +++ b/oneflow/python/utils/data/_utils/collate.py @@ -0,0 +1,114 @@ +""" +Copyright 2020 The OneFlow Authors. 
All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
+collate samples fetched from the dataset into Tensor(s).
+
+These **need** to be in global scope since Py2 doesn't support serializing
+static methods.
+"""
+
+import oneflow as flow
+import re
+import collections
+import oneflow.python.utils as utils
+
+string_classes = (str, bytes)
+
+np_str_obj_array_pattern = re.compile(r"[SaUO]")
+
+
+def default_convert(data):
+    r"""Converts each NumPy array data field into a tensor"""
+    elem_type = type(data)
+    if isinstance(data, (flow.Tensor, flow._oneflow_internal.Tensor)):
+        return data
+    elif (
+        elem_type.__module__ == "numpy"
+        and elem_type.__name__ != "str_"
+        and elem_type.__name__ != "string_"
+    ):
+        # array of string classes and object
+        if (
+            elem_type.__name__ == "ndarray"
+            and np_str_obj_array_pattern.search(data.dtype.str) is not None
+        ):
+            return data
+        return flow.tensor(data)
+    elif isinstance(data, collections.abc.Mapping):
+        return {key: default_convert(data[key]) for key in data}
+    elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
+        return elem_type(*(default_convert(d) for d in data))
+    elif isinstance(data, collections.abc.Sequence) and not isinstance(
+        data, string_classes
+    ):
+        return [default_convert(d) for d in data]
+    else:
+        # NOTE: torch simply returns the data here and does not raise any exception.
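+        # Raising here instead surfaces unsupported element types to the caller
+        # immediately, rather than silently passing them through the pipeline.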
+ raise TypeError(default_convert_err_msg_format.format(elem_type)) + + +default_collate_err_msg_format = ( + "default_collate: batch must contain tensors, numpy arrays, numbers, " + "dicts or lists; found {}" +) + +default_convert_err_msg_format = ( + "default_convert: batch must contain tensors, numpy arrays, numbers, " + "dicts or lists; found {}" +) + + +def default_collate(batch): + r"""Puts each data field into a tensor with outer dimension batch size""" + + elem = batch[0] + elem_type = type(elem) + if isinstance(elem, (flow.Tensor, flow._oneflow_internal.Tensor)): + # TODO: tensor.storage()._new_shared(numel) + return flow.experimental.stack(batch, dim=0) + elif ( + elem_type.__module__ == "numpy" + and elem_type.__name__ != "str_" + and elem_type.__name__ != "string_" + ): + if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap": + # array of string classes and object + if np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError(default_collate_err_msg_format.format(elem.dtype)) + + return default_collate([flow.Tensor(b) for b in batch]) + elif elem.shape == (): # scalars + return flow.Tensor(batch) + elif isinstance(elem, float): + return flow.tensor(batch, dtype=flow.float64) + elif isinstance(elem, int): + return flow.tensor(batch) + elif isinstance(elem, string_classes): + return batch + elif isinstance(elem, collections.abc.Mapping): + return {key: default_collate([d[key] for d in batch]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple + return elem_type(*(default_collate(samples) for samples in zip(*batch))) + elif isinstance(elem, collections.abc.Sequence): + # check to make sure that the elements in batch have consistent size + it = iter(batch) + elem_size = len(next(it)) + if not all(len(elem) == elem_size for elem in it): + raise RuntimeError("each element in list of batch should be of equal size") + transposed = zip(*batch) + return [default_collate(samples) for samples in transposed] + + raise TypeError(default_collate_err_msg_format.format(elem_type)) diff --git a/oneflow/python/utils/data/_utils/fetch.py b/oneflow/python/utils/data/_utils/fetch.py new file mode 100644 index 0000000000000000000000000000000000000000..0db20628455fd46833c6d5199cfe71ff27c086a8 --- /dev/null +++ b/oneflow/python/utils/data/_utils/fetch.py @@ -0,0 +1,68 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +r""""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch +data from an iterable-style or map-style dataset. This logic is shared in both +single- and multi-processing data loading. 
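+
+_MapDatasetFetcher indexes a map-style dataset with the sampled indices, while
+_IterableDatasetFetcher repeatedly pulls samples from the dataset iterator; both
+then hand the fetched samples to ``collate_fn``.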
+""" + + +class _BaseDatasetFetcher(object): + def __init__(self, dataset, auto_collation, collate_fn, drop_last): + self.dataset = dataset + self.auto_collation = auto_collation + self.collate_fn = collate_fn + self.drop_last = drop_last + + def fetch(self, possibly_batched_index): + raise NotImplementedError() + + +class _IterableDatasetFetcher(_BaseDatasetFetcher): + def __init__(self, dataset, auto_collation, collate_fn, drop_last): + super(_IterableDatasetFetcher, self).__init__( + dataset, auto_collation, collate_fn, drop_last + ) + self.dataset_iter = iter(dataset) + + def fetch(self, possibly_batched_index): + if self.auto_collation: + data = [] + for _ in possibly_batched_index: + try: + data.append(next(self.dataset_iter)) + except StopIteration: + break + if len(data) == 0 or ( + self.drop_last and len(data) < len(possibly_batched_index) + ): + raise StopIteration + else: + data = next(self.dataset_iter) + return self.collate_fn(data) + + +class _MapDatasetFetcher(_BaseDatasetFetcher): + def __init__(self, dataset, auto_collation, collate_fn, drop_last): + super(_MapDatasetFetcher, self).__init__( + dataset, auto_collation, collate_fn, drop_last + ) + + def fetch(self, possibly_batched_index): + if self.auto_collation: + data = [self.dataset[idx] for idx in possibly_batched_index] + else: + data = self.dataset[possibly_batched_index] + return self.collate_fn(data) diff --git a/oneflow/python/utils/data/dataloader.py b/oneflow/python/utils/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..58a58763391b5070fb55e2281433be38f49826db --- /dev/null +++ b/oneflow/python/utils/data/dataloader.py @@ -0,0 +1,528 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import sys +import traceback +import os +import warnings +from typing import Any, Callable, TypeVar, Generic, Sequence, List, Optional + +import oneflow as flow + + +class ExceptionWrapper(object): + r"""Wraps an exception plus traceback to communicate across threads""" + + def __init__(self, exc_info=None, where="in background"): + # It is important that we don't store exc_info, see + # NOTE [ Python Traceback Reference Cycle Problem ] + if exc_info is None: + exc_info = sys.exc_info() + self.exc_type = exc_info[0] + self.exc_msg = "".join(traceback.format_exception(*exc_info)) + self.where = where + + def reraise(self): + r"""Reraises the wrapped exception in the current thread""" + # Format a message such as: "Caught ValueError in DataLoader worker + # process 2. Original Traceback:", followed by the traceback. + msg = "Caught {} {}.\nOriginal {}".format( + self.exc_type.__name__, self.where, self.exc_msg + ) + if self.exc_type == KeyError: + # KeyError calls repr() on its argument (usually a dict key). This + # makes stack traces unreadable. It will not be changed in Python + # (https://bugs.python.org/issue2651), so we work around it. 
+ msg = KeyErrorMessage(msg) + elif getattr(self.exc_type, "message", None): + # Some exceptions have first argument as non-str but explicitly + # have message field + raise self.exc_type(message=msg) + raise self.exc_type(msg) + + +string_classes = (str, bytes) + +from . import ( + IterableDataset, + Sampler, + SequentialSampler, + RandomSampler, + BatchSampler, + Dataset, +) +from . import _utils + +T_co = TypeVar("T_co", covariant=True) +T = TypeVar("T") +_worker_init_fn_t = Callable[[int], None] + +# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that +# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'. +# See https://github.com/python/mypy/issues/3737. +_collate_fn_t = Callable[[List[T]], Any] + + +# This function used to be defined in this file. However, it was moved to +# _utils/collate.py. Although it is rather hard to access this from user land +# (one has to explicitly directly `import flow.utils.data.dataloader`), there +# probably is user code out there using it. This aliasing maintains BC in this +# aspect. +default_collate: _collate_fn_t = _utils.collate.default_collate + + +class _DatasetKind(object): + Map = 0 + Iterable = 1 + + @staticmethod + def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last): + if kind == _DatasetKind.Map: + return _utils.fetch._MapDatasetFetcher( + dataset, auto_collation, collate_fn, drop_last + ) + else: + return _utils.fetch._IterableDatasetFetcher( + dataset, auto_collation, collate_fn, drop_last + ) + + +class _InfiniteConstantSampler(Sampler): + r"""Analogous to ``itertools.repeat(None, None)``. + Used as sampler for :class:`~flow.utils.data.IterableDataset`. + + Args: + data_source (Dataset): dataset to sample from + """ + + def __init__(self): + super(_InfiniteConstantSampler, self).__init__(None) + + def __iter__(self): + while True: + yield None + + +class DataLoader(Generic[T_co]): + r""" + Data loader. Combines a dataset and a sampler, and provides an iterable over + the given dataset. + + The :class:`~flow.utils.data.DataLoader` supports both map-style and + iterable-style datasets with single- or multi-process loading, customizing + loading order and optional automatic batching (collation) and memory pinning. + + See :py:mod:`flow.utils.data` documentation page for more details. + + Args: + dataset (Dataset): dataset from which to load the data. + batch_size (int, optional): how many samples per batch to load + (default: ``1``). + shuffle (bool, optional): set to ``True`` to have the data reshuffled + at every epoch (default: ``False``). + sampler (Sampler or Iterable, optional): defines the strategy to draw + samples from the dataset. Can be any ``Iterable`` with ``__len__`` + implemented. If specified, :attr:`shuffle` must not be specified. + batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but + returns a batch of indices at a time. Mutually exclusive with + :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, + and :attr:`drop_last`. + num_workers (int, optional): how many subprocesses to use for data + loading. ``0`` means that the data will be loaded in the main process. + (default: ``0``) + collate_fn (callable, optional): merges a list of samples to form a + mini-batch of Tensor(s). Used when using batched loading from a + map-style dataset. + drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. 
If ``False`` and + the size of dataset is not divisible by the batch size, then the last batch + will be smaller. (default: ``False``) + timeout (numeric, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative. (default: ``0``) + worker_init_fn (callable, optional): If not ``None``, this will be called on each + worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as + input, after seeding and before data loading. (default: ``None``) + prefetch_factor (int, optional, keyword-only arg): Number of samples loaded + in advance by each worker. ``2`` means there will be a total of + 2 * num_workers samples prefetched across all workers. (default: ``2``) + persistent_workers (bool, optional): If ``True``, the data loader will not shutdown + the worker processes after a dataset has been consumed once. This allows to + maintain the workers `Dataset` instances alive. (default: ``False``) + + + .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` + cannot be an unpicklable object, e.g., a lambda function. See + :ref:`multiprocessing-best-practices` on more details related + to multiprocessing in OneFlow. + + .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used. + When :attr:`dataset` is an :class:`~flow.utils.data.IterableDataset`, + it instead returns an estimate based on ``len(dataset) / batch_size``, with proper + rounding depending on :attr:`drop_last`, regardless of multi-process loading + configurations. This represents the best guess OneFlow can make because OneFlow + trusts user :attr:`dataset` code in correctly handling multi-process + loading to avoid duplicate data. + + However, if sharding results in multiple workers having incomplete last batches, + this estimate can still be inaccurate, because (1) an otherwise complete batch can + be broken into multiple ones and (2) more than one batch worth of samples can be + dropped when :attr:`drop_last` is set. Unfortunately, OneFlow can not detect such + cases in general. + + See `Dataset Types`_ for more details on these two types of datasets and how + :class:`~flow.utils.data.IterableDataset` interacts with + `Multi-process data loading`_. + + .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and + :ref:`data-loading-randomness` notes for random seed related questions. + """ + dataset: Dataset[T_co] + batch_size: Optional[int] + num_workers: int + drop_last: bool + timeout: float + sampler: Sampler + prefetch_factor: int + _iterator: Optional["_BaseDataLoaderIter"] + __initialized = False + + def __init__( + self, + dataset: Dataset[T_co], + batch_size: Optional[int] = 1, + shuffle: bool = False, + sampler: Optional[Sampler[int]] = None, + batch_sampler: Optional[Sampler[Sequence[int]]] = None, + num_workers: int = 0, + collate_fn: Optional[_collate_fn_t] = None, + drop_last: bool = False, + timeout: float = 0, + worker_init_fn: Optional[_worker_init_fn_t] = None, + generator=None, + *, + prefetch_factor: int = 2, + persistent_workers: bool = False + ): + + if num_workers < 0: + raise ValueError( + "num_workers option should be non-negative; " + "use num_workers=0 to disable multiprocessing." + ) + + if num_workers >= 1: + warnings.warn( + "Not support multiprocessing dataloader yet, we will temporary set num_workers=0!" 
+ ) + num_workers = 0 + + if timeout < 0: + raise ValueError("timeout option should be non-negative") + + if num_workers == 0 and prefetch_factor != 2: + raise ValueError( + "prefetch_factor option could only be specified in multiprocessing." + "let num_workers > 0 to enable multiprocessing." + ) + assert prefetch_factor > 0 + + if persistent_workers and num_workers == 0: + raise ValueError("persistent_workers option needs num_workers > 0") + + self.dataset = dataset + self.num_workers = num_workers + self.prefetch_factor = prefetch_factor + self.timeout = timeout + self.worker_init_fn = worker_init_fn + + # Arg-check dataset related before checking samplers because we want to + # tell users that iterable-style datasets are incompatible with custom + # samplers first, so that they don't learn that this combo doesn't work + # after spending time fixing the custom sampler errors. + if isinstance(dataset, IterableDataset): + self._dataset_kind = _DatasetKind.Iterable + # NOTE [ Custom Samplers and IterableDataset ] + # + # `IterableDataset` does not support custom `batch_sampler` or + # `sampler` since the key is irrelevant (unless we support + # generator-style dataset one day...). + # + # For `sampler`, we always create a dummy sampler. This is an + # infinite sampler even when the dataset may have an implemented + # finite `__len__` because in multi-process data loading, naive + # settings will return duplicated data (which may be desired), and + # thus using a sampler with length matching that of dataset will + # cause data lost (you may have duplicates of the first couple + # batches, but never see anything afterwards). Therefore, + # `Iterabledataset` always uses an infinite sampler, an instance of + # `_InfiniteConstantSampler` defined above. + # + # A custom `batch_sampler` essentially only controls the batch size. + # However, it is unclear how useful it would be since an iterable-style + # dataset can handle that within itself. Moreover, it is pointless + # in multi-process data loading as the assignment order of batches + # to workers is an implementation detail so users can not control + # how to batchify each worker's iterable. Thus, we disable this + # option. If this turns out to be useful in future, we can re-enable + # this, and support custom samplers that specify the assignments to + # specific workers. 
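+            #
+            # The identity check against ``False`` below means any value other than
+            # the argument's default (e.g. ``shuffle=True``) is rejected for
+            # iterable-style datasets rather than being silently ignored.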
+ if shuffle is not False: + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + "shuffle option, but got shuffle={}".format(shuffle) + ) + elif sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + "sampler option, but got sampler={}".format(sampler) + ) + elif batch_sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + "batch_sampler option, but got batch_sampler={}".format( + batch_sampler + ) + ) + else: + self._dataset_kind = _DatasetKind.Map + + if sampler is not None and shuffle: + raise ValueError("sampler option is mutually exclusive with " "shuffle") + + if batch_sampler is not None: + # auto_collation with custom batch_sampler + if batch_size != 1 or shuffle or sampler is not None or drop_last: + raise ValueError( + "batch_sampler option is mutually exclusive " + "with batch_size, shuffle, sampler, and " + "drop_last" + ) + batch_size = None + drop_last = False + elif batch_size is None: + # no auto_collation + if drop_last: + raise ValueError( + "batch_size=None option disables auto-batching " + "and is mutually exclusive with drop_last" + ) + + if sampler is None: # give default samplers + if self._dataset_kind == _DatasetKind.Iterable: + # See NOTE [ Custom Samplers and IterableDataset ] + sampler = _InfiniteConstantSampler() + else: # map-style + if shuffle: + # Cannot statically verify that dataset is Sized + # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + sampler = RandomSampler(dataset, generator=generator) # type: ignore + else: + sampler = SequentialSampler(dataset) + + if batch_size is not None and batch_sampler is None: + # auto_collation without custom batch_sampler + batch_sampler = BatchSampler(sampler, batch_size, drop_last) + + self.batch_size = batch_size + self.drop_last = drop_last + self.sampler = sampler + self.batch_sampler = batch_sampler + self.generator = generator + + if collate_fn is None: + if self._auto_collation: + collate_fn = _utils.collate.default_collate + else: + collate_fn = _utils.collate.default_convert + + self.collate_fn = collate_fn + self.persistent_workers = persistent_workers + + self.__initialized = True + self._IterableDataset_len_called = ( + None # See NOTE [ IterableDataset and __len__ ] + ) + + self._iterator = None + + def _get_iterator(self) -> "_BaseDataLoaderIter": + if self.num_workers == 0 or self.num_workers == 1: + return _SingleProcessDataLoaderIter(self) + else: + raise NotImplementedError("Multiprocessing dataloader is not support yet!") + + def __setattr__(self, attr, val): + if self.__initialized and attr in ( + "batch_size", + "batch_sampler", + "sampler", + "drop_last", + "dataset", + "persistent_workers", + ): + raise ValueError( + "{} attribute should not be set after {} is " + "initialized".format(attr, self.__class__.__name__) + ) + + super(DataLoader, self).__setattr__(attr, val) + + # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up + # since '_BaseDataLoaderIter' references 'DataLoader'. 
+ def __iter__(self) -> "_BaseDataLoaderIter": + # When using a single worker the returned iterator should be + # created everytime to avoid reseting its state + # However, in the case of a multiple workers iterator + # the iterator is only created once in the lifetime of the + # DataLoader object so that workers can be reused + if self.persistent_workers and self.num_workers > 0: + if self._iterator is None: + self._iterator = self._get_iterator() + else: + self._iterator._reset(self) + return self._iterator + else: + return self._get_iterator() + + @property + def _auto_collation(self): + return self.batch_sampler is not None + + @property + def _index_sampler(self): + # The actual sampler used for generating indices for `_DatasetFetcher` + # (see _utils/fetch.py) to read data at each time. This would be + # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise. + # We can't change `.sampler` and `.batch_sampler` attributes for BC + # reasons. + if self._auto_collation: + return self.batch_sampler + else: + return self.sampler + + def __len__(self) -> int: + if self._dataset_kind == _DatasetKind.Iterable: + # NOTE [ IterableDataset and __len__ ] + # + # For `IterableDataset`, `__len__` could be inaccurate when one naively + # does multi-processing data loading, since the samples will be duplicated. + # However, no real use case should be actually using that behavior, so + # it should count as a user error. We should generally trust user + # code to do the proper thing (e.g., configure each replica differently + # in `__iter__`), and give us the correct `__len__` if they choose to + # implement it (this will still throw if the dataset does not implement + # a `__len__`). + # + # To provide a further warning, we track if `__len__` was called on the + # `DataLoader`, save the returned value in `self._len_called`, and warn + # if the iterator ends up yielding more than this number of samples. 
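+            #
+            # For example, with ``len(dataset) == 10`` and ``batch_size == 3`` this
+            # method reports 4 batches (or 3 when ``drop_last=True``), even though
+            # multi-worker sharding may change what is actually yielded.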
+ + # Cannot statically verify that dataset is Sized + length = self._IterableDataset_len_called = len(self.dataset) # type: ignore + if ( + self.batch_size is not None + ): # IterableDataset doesn't allow custom sampler or batch_sampler + from math import ceil + + if self.drop_last: + length = length // self.batch_size + else: + length = ceil(length / self.batch_size) + return length + else: + return len(self._index_sampler) + + +class _BaseDataLoaderIter(object): + def __init__(self, loader: DataLoader) -> None: + self._dataset = loader.dataset + self._dataset_kind = loader._dataset_kind + self._IterableDataset_len_called = loader._IterableDataset_len_called + self._auto_collation = loader._auto_collation + self._drop_last = loader.drop_last + self._index_sampler = loader._index_sampler + self._num_workers = loader.num_workers + self._prefetch_factor = loader.prefetch_factor + self._timeout = loader.timeout + self._collate_fn = loader.collate_fn + self._sampler_iter = iter(self._index_sampler) + self._base_seed = flow.Tensor([0], dtype=flow.int64).uniform_().numpy().item() + # TODO: flow.empty() + # self._base_seed = flow.empty((), dtype=flow.int64).random_(generator=loader.generator).item() + self._persistent_workers = loader.persistent_workers + self._num_yielded = 0 + self._profile_name = "enumerate(DataLoader)#{}.__next__".format( + self.__class__.__name__ + ) + + def __iter__(self) -> "_BaseDataLoaderIter": + return self + + def _reset(self, loader, first_iter=False): + self._sampler_iter = iter(self._index_sampler) + self._num_yielded = 0 + self._IterableDataset_len_called = loader._IterableDataset_len_called + + def _next_index(self): + return next(self._sampler_iter) # may raise StopIteration + + def _next_data(self): + raise NotImplementedError + + def __next__(self) -> Any: + if self._sampler_iter is None: + self._reset() + data = self._next_data() + self._num_yielded += 1 + if ( + self._dataset_kind == _DatasetKind.Iterable + and self._IterableDataset_len_called is not None + and self._num_yielded > self._IterableDataset_len_called + ): + warn_msg = ( + "Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} " + "samples have been fetched. " + ).format(self._dataset, self._IterableDataset_len_called, self._num_yielded) + if self._num_workers > 1: + warn_msg += "Multiprocessing dataloader is not support yet!" + warnings.warn(warn_msg) + return data + + def __len__(self) -> int: + return len(self._index_sampler) + + def __getstate__(self): + raise NotImplementedError("{} cannot be pickled", self.__class__.__name__) + + +class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): + def __init__(self, loader): + super(_SingleProcessDataLoaderIter, self).__init__(loader) + assert self._timeout == 0 + assert 0 <= self._num_workers <= 1 + + self._dataset_fetcher = _DatasetKind.create_fetcher( + self._dataset_kind, + self._dataset, + self._auto_collation, + self._collate_fn, + self._drop_last, + ) + + def _next_data(self): + index = self._next_index() # may raise StopIteration + return self._dataset_fetcher.fetch(index) diff --git a/oneflow/python/utils/data/dataset.py b/oneflow/python/utils/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e7133c973eb0d0f0eac1ad226fa6a05865db9417 --- /dev/null +++ b/oneflow/python/utils/data/dataset.py @@ -0,0 +1,312 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import bisect +import warnings +import functools +from typing import ( + TypeVar, + Generic, + Iterable, + Iterator, + Sequence, + List, + Optional, + Tuple, + Dict, + Callable, +) + +import oneflow as flow +from oneflow.python.framework.tensor import Tensor + + +default_generator = flow.Generator() + +# Taken from python 3.5 docs +def _accumulate(iterable, fn=lambda x, y: x + y): + "Return running totals" + # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15 + # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120 + it = iter(iterable) + try: + total = next(it) + except StopIteration: + return + yield total + for element in it: + total = fn(total, element) + yield total + + +T_co = TypeVar("T_co", covariant=True) +T = TypeVar("T") + + +class Dataset(Generic[T_co]): + r"""An abstract class representing a :class:`Dataset`. + + All datasets that represent a map from keys to data samples should subclass + it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a + data sample for a given key. Subclasses could also optionally overwrite + :meth:`__len__`, which is expected to return the size of the dataset by many + :class:`~flow.utils.data.Sampler` implementations and the default options + of :class:`~flow.utils.data.DataLoader`. + + .. note:: + :class:`~flow.utils.data.DataLoader` by default constructs a index + sampler that yields integral indices. To make it work with a map-style + dataset with non-integral indices/keys, a custom sampler must be provided. + """ + + def __getitem__(self, index) -> T_co: + raise NotImplementedError + + def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]": + return ConcatDataset([self, other]) + + +class IterableDataset(Dataset[T_co]): + r"""An iterable Dataset. + + All datasets that represent an iterable of data samples should subclass it. + Such form of datasets is particularly useful when data come from a stream. + + All subclasses should overwrite :meth:`__iter__`, which would return an + iterator of samples in this dataset. + + When a subclass is used with :class:`~flow.utils.data.DataLoader`, each + item in the dataset will be yielded from the :class:`~flow.utils.data.DataLoader` + iterator. When :attr:`num_workers > 0`, each worker process will have a + different copy of the dataset object, so it is often desired to configure + each copy independently to avoid having duplicate data returned from the + workers. + + Example 1: splitting workload across all workers in :meth:`__iter__`:: + + >>> class MyIterableDataset(flow.utils.data.IterableDataset): + ... def __init__(self, start, end): + ... super(MyIterableDataset).__init__() + ... assert end > start, "this example code only works with end >= start" + ... self.start = start + ... self.end = end + ... + ... def __iter__(self): + ... iter_start = self.start + ... iter_end = self.end + ... return iter(range(iter_start, iter_end)) + ... + >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. 
+ >>> ds = MyIterableDataset(start=3, end=7) + + >>> # Single-process loading + >>> print(list(flow.utils.data.DataLoader(ds, num_workers=0))) + [3, 4, 5, 6] + + + Example 2: splitting workload across all workers using :attr:`worker_init_fn`:: + + >>> class MyIterableDataset(flow.utils.data.IterableDataset): + ... def __init__(self, start, end): + ... super(MyIterableDataset).__init__() + ... assert end > start, "this example code only works with end >= start" + ... self.start = start + ... self.end = end + ... + ... def __iter__(self): + ... return iter(range(self.start, self.end)) + ... + >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. + >>> ds = MyIterableDataset(start=3, end=7) + + >>> # Single-process loading + >>> print(list(flow.utils.data.DataLoader(ds, num_workers=0))) + [3, 4, 5, 6] + + """ + functions: Dict[str, Callable] = {} + reduce_ex_hook: Optional[Callable] = None + + def __iter__(self) -> Iterator[T_co]: + raise NotImplementedError + + # TODO: + # def __add__(self, other: Dataset[T_co]): + # return ChainDataset([self, other]) + + def __getattr__(self, attribute_name): + if attribute_name in IterableDataset.functions: + function = functools.partial( + IterableDataset.functions[attribute_name], self + ) + return function + else: + raise AttributeError + + @classmethod + def register_function(cls, function_name, function): + IterableDataset.functions[function_name] = function + + @classmethod + def register_datapipe_as_function(cls, function_name, cls_to_register): + if function_name in IterableDataset.functions: + raise Exception( + "Unable to add DataPipe function name {} as it is already taken".format( + function_name + ) + ) + + def class_function(cls, source_dp, *args, **kwargs): + return cls(source_dp, *args, **kwargs) + + function = functools.partial(class_function, cls_to_register) + IterableDataset.functions[function_name] = function + + def __reduce_ex__(self, *args, **kwargs): + if IterableDataset.reduce_ex_hook is not None: + try: + return IterableDataset.reduce_ex_hook(self) + except NotImplementedError: + pass + return super().__reduce_ex__(*args, **kwargs) + + @classmethod + def set_reduce_ex_hook(cls, hook_fn): + if IterableDataset.reduce_ex_hook is not None and hook_fn is not None: + raise Exception("Attempt to override existing reduce_ex_hook") + IterableDataset.reduce_ex_hook = hook_fn + + +class TensorDataset(Dataset[Tuple[Tensor, ...]]): + r"""Dataset wrapping tensors. + + Each sample will be retrieved by indexing tensors along the first dimension. + + Args: + *tensors (Tensor): tensors that have the same size of the first dimension. + """ + tensors: Tuple[Tensor, ...] + + def __init__(self, *tensors: Tensor) -> None: + assert all( + tensors[0].size(0) == tensor.size(0) for tensor in tensors + ), "Size mismatch between tensors" + self.tensors = tensors + + def __getitem__(self, index): + return tuple(tensor[index] for tensor in self.tensors) + + def __len__(self): + return self.tensors[0].size(0) + + +class ConcatDataset(Dataset[T_co]): + r"""Dataset as a concatenation of multiple datasets. + + This class is useful to assemble different existing datasets. 
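+    Indexing is resolved with a bisect lookup over the cumulative sizes of the
+    concatenated datasets, so ``__getitem__`` costs O(log n) in the number of
+    datasets.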
+ + Args: + datasets (sequence): List of datasets to be concatenated + """ + datasets: List[Dataset[T_co]] + cumulative_sizes: List[int] + + @staticmethod + def cumsum(sequence): + r, s = [], 0 + for e in sequence: + l = len(e) + r.append(l + s) + s += l + return r + + def __init__(self, datasets: Iterable[Dataset]) -> None: + super(ConcatDataset, self).__init__() + # Cannot verify that datasets is Sized + assert len(datasets) > 0, "datasets should not be an empty iterable" # type: ignore + self.datasets = list(datasets) + for d in self.datasets: + assert not isinstance( + d, IterableDataset + ), "ConcatDataset does not support IterableDataset" + self.cumulative_sizes = self.cumsum(self.datasets) + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + if idx < 0: + if -idx > len(self): + raise ValueError( + "absolute value of index should not exceed dataset length" + ) + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx] + + +class Subset(Dataset[T_co]): + r""" + Subset of a dataset at specified indices. + + Args: + dataset (Dataset): The whole Dataset + indices (sequence): Indices in the whole set selected for subset + """ + dataset: Dataset[T_co] + indices: Sequence[int] + + def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None: + self.dataset = dataset + self.indices = indices + + def __getitem__(self, idx): + return self.dataset[self.indices[idx]] + + def __len__(self): + return len(self.indices) + + +def random_split( + dataset: Dataset[T], + lengths: Sequence[int], + generator: Optional[flow.Generator] = default_generator, +) -> List[Subset[T]]: + r""" + Randomly split a dataset into non-overlapping new datasets of given lengths. + Optionally fix the generator for reproducible results, e.g.: + + >>> random_split(range(10), [3, 7], generator=flow.Generator().manual_seed(42)) + + Args: + dataset (Dataset): Dataset to be split + lengths (sequence): lengths of splits to be produced + generator (Generator): Generator used for the random permutation. + """ + # Cannot verify that dataset is Sized + if sum(lengths) != len(dataset): # type: ignore + raise ValueError( + "Sum of input lengths does not equal the length of the input dataset!" + ) + + indices = flow.randperm(sum(lengths), generator=generator).tolist() + return [ + Subset(dataset, indices[offset - length : offset]) + for offset, length in zip(_accumulate(lengths), lengths) + ] diff --git a/oneflow/python/utils/data/decorator.py b/oneflow/python/utils/data/decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..20fc77af5942d115eec6b980adcf3a6d57a1c953 --- /dev/null +++ b/oneflow/python/utils/data/decorator.py @@ -0,0 +1,126 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +from typing import Any, Callable, Optional, Type, Union +from oneflow.python.utils.data import IterDataPipe + + +class functional_datapipe(object): + name: str + + def __init__(self, name: str) -> None: + self.name = name + + def __call__(self, cls): + if isinstance(cls, Type): # type: ignore + if not issubclass(cls, IterDataPipe): + raise TypeError("`functional_datapipe` can only decorate IterDataPipe") + # with non_deterministic decorator + else: + if not isinstance(cls, non_deterministic) and not ( + hasattr(cls, "__self__") and isinstance(cls.__self__, non_deterministic) + ): + raise TypeError("`functional_datapipe` can only decorate IterDataPipe") + IterDataPipe.register_datapipe_as_function(self.name, cls) + return cls + + +_determinism: bool = False + + +class guaranteed_datapipes_determinism(object): + prev: bool + + def __init__(self) -> None: + global _determinism + self.prev = _determinism + _determinism = True + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + global _determinism + _determinism = self.prev + + +class non_deterministic(object): + cls: Optional[Type[IterDataPipe]] = None + # TODO: Lambda for picking + deterministic_fn: Callable[[], bool] + + def __init__(self, arg: Union[Type[IterDataPipe], Callable[[], bool]]) -> None: + # 1. Decorator doesn't have any argument + if isinstance(arg, Type): # type: ignore + if not issubclass(arg, IterDataPipe): # type: ignore + raise TypeError( + "Only `IterDataPipe` can be decorated with `non_deterministic`" + ", but {} is found".format(arg.__name__) + ) + self.cls = arg # type: ignore + # 2. Decorator has an argument of a function + # This class should behave differently given different inputs. Use this + # function to verify the determinism for each instance. + # When the function returns True, the instance is non-deterministic. Otherwise, + # the instance is a deterministic DataPipe. + elif isinstance(arg, Callable): # type:ignore + self.deterministic_fn = arg # type: ignore + else: + raise TypeError("{} can not be decorated by non_deterministic".format(arg)) + + def __call__(self, *args, **kwargs): + global _determinism + # Decorate IterDataPipe + if self.cls is not None: + if _determinism: + raise TypeError( + "{} is non-deterministic, but you set 'guaranteed_datapipes_determinism'. " + "You can turn off determinism for this DataPipe if that is acceptable " + "for your application".format(self.cls.__name__) + ) + return self.cls(*args, **kwargs) # type: ignore + + # Decorate with a functional argument + if not ( + isinstance(args[0], Type) + and issubclass( # type: ignore + args[0], IterDataPipe + ) + ): + raise TypeError( + "Only `IterDataPipe` can be decorated, but {} is found".format( + args[0].__name__ + ) + ) + self.cls = args[0] + return self.deterministic_wrapper_fn + + def deterministic_wrapper_fn(self, *args, **kwargs) -> IterDataPipe: + res = self.deterministic_fn(*args, **kwargs) # type: ignore + if not isinstance(res, bool): + raise TypeError( + "deterministic_fn of `non_deterministic` decorator is required " + "to return a boolean value, but {} is found".format(type(res)) + ) + global _determinism + if _determinism and res: + raise TypeError( + "{} is non-deterministic with the inputs, but you set " + "'guaranteed_datapipes_determinism'. 
You can turn off determinism " + "for this DataPipe if that is acceptable for your application".format( + self.cls.__name__ + ) + ) # type: ignore + return self.cls(*args, **kwargs) # type: ignore diff --git a/oneflow/python/utils/data/sampler.py b/oneflow/python/utils/data/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..ae201477c8e454992d8d2bdeb126ef6c667c29f0 --- /dev/null +++ b/oneflow/python/utils/data/sampler.py @@ -0,0 +1,250 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import builtins +from typing import Iterator, Union, Optional, Sequence, List, TypeVar, Generic, Sized +import numpy as np +import oneflow as flow +from oneflow.python.framework.tensor import Tensor + + +T_co = TypeVar("T_co", covariant=True) + + +class Sampler(Generic[T_co]): + r"""Base class for all Samplers. + + Every Sampler subclass has to provide an :meth:`__iter__` method, providing a + way to iterate over indices of dataset elements, and a :meth:`__len__` method + that returns the length of the returned iterators. + + .. note:: The :meth:`__len__` method isn't strictly required by + :class:`~flow.utils.data.DataLoader`, but is expected in any + calculation involving the length of a :class:`~flow.utils.data.DataLoader`. + """ + + def __init__(self, data_source: Optional[Sized]) -> None: + pass + + def __iter__(self) -> Iterator[T_co]: + raise NotImplementedError + + # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + # + # Many times we have an abstract class representing a collection/iterable of + # data, e.g., `flow.utils.data.Sampler`, with its subclasses optionally + # implementing a `__len__` method. In such cases, we must make sure to not + # provide a default implementation, because both straightforward default + # implementations have their issues: + # + # + `return NotImplemented`: + # Calling `len(subclass_instance)` raises: + # TypeError: 'NotImplementedType' object cannot be interpreted as an integer + # + # + `raise NotImplementedError()`: + # This prevents triggering some fallback behavior. E.g., the built-in + # `list(X)` tries to call `len(X)` first, and executes a different code + # path if the method is not found or `NotImplemented` is returned, while + # raising an `NotImplementedError` will propagate and and make the call + # fail where it could have use `__iter__` to complete the call. + # + # Thus, the only two sensible things to do are + # + # + **not** provide a default `__len__`. + # + # + raise a `TypeError` instead, which is what Python uses when users call + # a method that is not defined on an object. + # (@ssnl verifies that this works on at least Python 3.7.) + + +class SequentialSampler(Sampler[int]): + r"""Samples elements sequentially, always in the same order. 
+ + Args: + data_source (Dataset): dataset to sample from + """ + data_source: Sized + + def __init__(self, data_source): + self.data_source = data_source + + def __iter__(self): + return iter(range(len(self.data_source))) + + def __len__(self) -> int: + return len(self.data_source) + + +class RandomSampler(Sampler[int]): + r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. + If with replacement, then user can specify :attr:`num_samples` to draw. + + Args: + data_source (Dataset): dataset to sample from + replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False`` + num_samples (int): number of samples to draw, default=`len(dataset)`. This argument + is supposed to be specified only when `replacement` is ``True``. + generator (Generator): Generator used in sampling. + """ + data_source: Sized + replacement: bool + + def __init__( + self, + data_source: Sized, + replacement: bool = False, + num_samples: Optional[int] = None, + generator=None, + ) -> None: + self.data_source = data_source + self.replacement = replacement + self._num_samples = num_samples + self.generator = generator + + if not isinstance(self.replacement, bool): + raise TypeError( + "replacement should be a boolean value, but got " + "replacement={}".format(self.replacement) + ) + + if self._num_samples is not None and not replacement: + raise ValueError( + "With replacement=False, num_samples should not be specified, " + "since a random permute will be performed." + ) + + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError( + "num_samples should be a positive integer " + "value, but got num_samples={}".format(self.num_samples) + ) + + @property + def num_samples(self) -> int: + # dataset size might change at runtime + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __iter__(self): + n = len(self.data_source) + if self.generator is None: + generator = flow.Generator() + generator.manual_seed( + int(flow.Tensor(1, dtype=flow.int64).xavier_uniform_().numpy()[0]) + ) + else: + generator = self.generator + if self.replacement: + raise NotImplementedError("Not support replacement yet!") + # # TODO: flow.randint + # for _ in range(self.num_samples // 32): + # yield from flow.randint( + # high=n, size=(32,), dtype=flow.int64, generator=generator + # ).tolist() + # yield from flow.randint( + # high=n, + # size=(self.num_samples % 32,), + # dtype=flow.int64, + # generator=generator, + # ).tolist() + else: + yield from np.random.permutation(n).tolist() + # TODO: yield from flow.randperm(n, generator=generator).tolist() + + def __len__(self): + return self.num_samples + + +class SubsetRandomSampler(Sampler[int]): + r"""Samples elements randomly from a given list of indices, without replacement. + + Args: + indices (sequence): a sequence of indices + generator (Generator): Generator used in sampling. + """ + indices: Sequence[int] + + def __init__(self, indices: Sequence[int], generator=None) -> None: + self.indices = indices + self.generator = generator + + def __iter__(self): + return ( + self.indices[i] + for i in flow.randperm(len(self.indices), generator=self.generator) + ) + + def __len__(self): + return len(self.indices) + + +class BatchSampler(Sampler[List[int]]): + r"""Wraps another sampler to yield a mini-batch of indices. + + Args: + sampler (Sampler or Iterable): Base sampler. Can be any iterable object + batch_size (int): Size of mini-batch. 
+ drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size`` + + Example: + >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + """ + + def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None: + # Since collections.abc.Iterable does not check for `__getitem__`, which + # is one way for an object to be an iterable, we don't do an `isinstance` + # check here. + if ( + not isinstance(batch_size, int) + or isinstance(batch_size, bool) + or batch_size <= 0 + ): + raise ValueError( + "batch_size should be a positive integer value, " + "but got batch_size={}".format(batch_size) + ) + if not isinstance(drop_last, bool): + raise ValueError( + "drop_last should be a boolean value, but got " + "drop_last={}".format(drop_last) + ) + self.sampler = sampler + self.batch_size = batch_size + self.drop_last = drop_last + + def __iter__(self): + batch = [] + for idx in self.sampler: + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] + if len(batch) > 0 and not self.drop_last: + yield batch + + def __len__(self): + # Can only be called if self.sampler has __len__ implemented + # We cannot enforce this condition, so we turn off typechecking for the + # implementation below. + # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + if self.drop_last: + return len(self.sampler) // self.batch_size # type: ignore + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore