Unverified commit 30a37272 authored by Luyang, committed by GitHub

Dev flow.utils.data part1 (#5406)


* refine and add test case

* support ellipsis type slice

* refine

* refine

* support slice assign ellipsis type

* refine

* register fn to localtensor

* basic implementation of dataloader

* format

* refine as comments

* fix comments

* remove useless code

* add test case into unittest

* refine as comments

* refine as comments

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent b0dd35e5
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import oneflow.experimental as flow
import oneflow.python.utils.data as Data
class ScpDataset(Data.Dataset):
def __init__(self, chunksize=200, dim=81, length=2000):
self.chunksize = chunksize
self.dim = dim
self.length = length
def __getitem__(self, index):
np.random.seed(index)
return np.random.randn(self.chunksize, self.dim)
def __len__(self):
return self.length
@flow.unittest.skip_unless_1n1d()
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestNumpyDataset(flow.unittest.TestCase):
def test_numpy_dataset(test_case):
dataset = ScpDataset()
dataloader = Data.DataLoader(dataset, batch_size=16, shuffle=True)
for X in dataloader:
test_case.assertEqual(X.shape, flow.Size([16, 200, 81]))
if __name__ == "__main__":
unittest.main()
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import oneflow.experimental as flow
import oneflow.python.utils.data as Data
import oneflow.experimental.nn as nn
class LinearNet(nn.Module):
def __init__(self, n_feature):
super(LinearNet, self).__init__()
self.linear = nn.Linear(n_feature, 1)
def forward(self, x):
y = self.linear(x)
return y
@flow.unittest.skip_unless_1n1d()
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestTensorDataset(flow.unittest.TestCase):
def test_tensor_dataset(test_case):
num_inputs = 2
num_examples = 1000
true_w = [2, -3.4]
true_b = 4.2
net = LinearNet(num_inputs)
flow.nn.init.normal_(net.linear.weight, mean=0, std=0.01)
flow.nn.init.constant_(net.linear.bias, val=0)
loss = nn.MSELoss()
optimizer = flow.optim.SGD(net.parameters(), lr=0.03)
features = flow.tensor(
np.random.normal(0, 1, (num_examples, num_inputs)), dtype=flow.float
)
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
labels += flow.tensor(
np.random.normal(0, 0.01, size=labels.size()), dtype=flow.float
)
batch_size = 10
# combine features and labels
dataset = Data.TensorDataset(features, labels)
        # randomly fetch a small batch
data_iter = Data.DataLoader(dataset, batch_size, shuffle=True, num_workers=0)
num_epochs = 10
for epoch in range(1, num_epochs + 1):
for X, y in data_iter:
output = net(X)
l = loss(output, y)
optimizer.zero_grad()
l.backward()
optimizer.step()
if epoch == num_epochs:
test_case.assertLess(l.numpy(), 0.00019)
if __name__ == "__main__":
unittest.main()
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from oneflow.python.utils.data.sampler import (
Sampler,
SequentialSampler,
RandomSampler,
SubsetRandomSampler,
BatchSampler,
)
from oneflow.python.utils.data.dataset import (
Dataset,
IterableDataset,
TensorDataset,
ConcatDataset,
Subset,
random_split,
)
from oneflow.python.utils.data.dataset import IterableDataset as IterDataPipe
from oneflow.python.utils.data.dataloader import DataLoader, _DatasetKind
from oneflow.python.utils.data.decorator import (
functional_datapipe,
guaranteed_datapipes_determinism,
non_deterministic,
)
__all__ = [
"Sampler",
"SequentialSampler",
"RandomSampler",
"SubsetRandomSampler",
"BatchSampler",
"Dataset",
"IterableDataset",
"TensorDataset",
"ConcatDataset",
"Subset",
"random_split",
"DataLoader",
"_DatasetKind",
"IterDataPipe",
"functional_datapipe",
"guaranteed_datapipes_determinism",
"non_deterministic",
]
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
r"""Utility classes & functions for data loading. Code in this folder is mostly
used by ../dataloder.py.
A lot of multiprocessing is used in data loading, which only supports running
functions defined in global environment (py2 can't serialize static methods).
Therefore, for code tidiness we put these functions into different files in this
folder.
"""
import sys
import atexit
IS_WINDOWS = sys.platform == "win32"
MP_STATUS_CHECK_INTERVAL = 5.0
r"""Interval (in seconds) to check status of processes to avoid hanging in
multiprocessing data loading. This is mainly used in getting data from
another process, in which case we need to periodically check whether the
sender is alive to prevent hanging."""
python_exit_status = False
r"""Whether Python is shutting down. This flag is guaranteed to be set before
the Python core library resources are freed, but Python may already be exiting
for some time when this is set.
Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar
hook in Python 3.7 multiprocessing library:
https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327
"""
def _set_python_exit_flag():
global python_exit_status
python_exit_status = True
atexit.register(_set_python_exit_flag)
from . import collate, fetch
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).
These **needs** to be in global scope since Py2 doesn't support serializing
static methods.
"""
import oneflow as flow
import re
import collections
import oneflow.python.utils as utils
string_classes = (str, bytes)
np_str_obj_array_pattern = re.compile(r"[SaUO]")
def default_convert(data):
r"""Converts each NumPy array data field into a tensor"""
elem_type = type(data)
if isinstance(data, (flow.Tensor, flow._oneflow_internal.Tensor)):
return data
elif (
elem_type.__module__ == "numpy"
and elem_type.__name__ != "str_"
and elem_type.__name__ != "string_"
):
# array of string classes and object
if (
elem_type.__name__ == "ndarray"
and np_str_obj_array_pattern.search(data.dtype.str) is not None
):
return data
return flow.tensor(data)
elif isinstance(data, collections.abc.Mapping):
return {key: default_convert(data[key]) for key in data}
elif isinstance(data, tuple) and hasattr(data, "_fields"): # namedtuple
return elem_type(*(default_convert(d) for d in data))
elif isinstance(data, collections.abc.Sequence) and not isinstance(
data, string_classes
):
return [default_convert(d) for d in data]
else:
        # NOTE: torch just returns the data here and does not raise any exception!
raise TypeError(default_convert_err_msg_format.format(elem_type))
default_collate_err_msg_format = (
"default_collate: batch must contain tensors, numpy arrays, numbers, "
"dicts or lists; found {}"
)
default_convert_err_msg_format = (
"default_convert: batch must contain tensors, numpy arrays, numbers, "
"dicts or lists; found {}"
)
def default_collate(batch):
r"""Puts each data field into a tensor with outer dimension batch size"""
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, (flow.Tensor, flow._oneflow_internal.Tensor)):
# TODO: tensor.storage()._new_shared(numel)
return flow.experimental.stack(batch, dim=0)
elif (
elem_type.__module__ == "numpy"
and elem_type.__name__ != "str_"
and elem_type.__name__ != "string_"
):
if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
# array of string classes and object
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
return default_collate([flow.Tensor(b) for b in batch])
elif elem.shape == (): # scalars
return flow.Tensor(batch)
elif isinstance(elem, float):
return flow.tensor(batch, dtype=flow.float64)
elif isinstance(elem, int):
return flow.tensor(batch)
elif isinstance(elem, string_classes):
return batch
elif isinstance(elem, collections.abc.Mapping):
return {key: default_collate([d[key] for d in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
return elem_type(*(default_collate(samples) for samples in zip(*batch)))
elif isinstance(elem, collections.abc.Sequence):
# check to make sure that the elements in batch have consistent size
it = iter(batch)
elem_size = len(next(it))
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError("each element in list of batch should be of equal size")
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
raise TypeError(default_collate_err_msg_format.format(elem_type))
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
r""""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch
data from an iterable-style or map-style dataset. This logic is shared in both
single- and multi-processing data loading.
"""
class _BaseDatasetFetcher(object):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
self.dataset = dataset
self.auto_collation = auto_collation
self.collate_fn = collate_fn
self.drop_last = drop_last
def fetch(self, possibly_batched_index):
raise NotImplementedError()
class _IterableDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_IterableDatasetFetcher, self).__init__(
dataset, auto_collation, collate_fn, drop_last
)
self.dataset_iter = iter(dataset)
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = []
for _ in possibly_batched_index:
try:
data.append(next(self.dataset_iter))
except StopIteration:
break
if len(data) == 0 or (
self.drop_last and len(data) < len(possibly_batched_index)
):
raise StopIteration
else:
data = next(self.dataset_iter)
return self.collate_fn(data)
class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(
dataset, auto_collation, collate_fn, drop_last
)
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
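

# Usage sketch (illustrative, mirroring how the DataLoader wires a fetcher up;
# not part of the original file):
#
#   fetcher = _MapDatasetFetcher(dataset, auto_collation=True,
#                                collate_fn=default_collate, drop_last=False)
#   batch = fetcher.fetch([0, 1, 2, 3])  # collated batch for these indices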
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import bisect
import warnings
import functools
from typing import (
TypeVar,
Generic,
Iterable,
Iterator,
Sequence,
List,
Optional,
Tuple,
Dict,
Callable,
)
import oneflow as flow
from oneflow.python.framework.tensor import Tensor
default_generator = flow.Generator()
# Taken from python 3.5 docs
def _accumulate(iterable, fn=lambda x, y: x + y):
"Return running totals"
# _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
# _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
it = iter(iterable)
try:
total = next(it)
except StopIteration:
return
yield total
for element in it:
total = fn(total, element)
yield total
T_co = TypeVar("T_co", covariant=True)
T = TypeVar("T")
class Dataset(Generic[T_co]):
r"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~flow.utils.data.Sampler` implementations and the default options
of :class:`~flow.utils.data.DataLoader`.
.. note::
        :class:`~flow.utils.data.DataLoader` by default constructs an index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
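
    Example (a minimal sketch of a map-style dataset; ``SquaresDataset`` is
    illustrative, not part of the library)::

        >>> class SquaresDataset(Dataset):
        ...     def __getitem__(self, index):
        ...         return index ** 2
        ...     def __len__(self):
        ...         return 10
        >>> SquaresDataset()[3]
        9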
"""
def __getitem__(self, index) -> T_co:
raise NotImplementedError
def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
return ConcatDataset([self, other])
class IterableDataset(Dataset[T_co]):
r"""An iterable Dataset.
All datasets that represent an iterable of data samples should subclass it.
Such form of datasets is particularly useful when data come from a stream.
All subclasses should overwrite :meth:`__iter__`, which would return an
iterator of samples in this dataset.
When a subclass is used with :class:`~flow.utils.data.DataLoader`, each
item in the dataset will be yielded from the :class:`~flow.utils.data.DataLoader`
iterator. When :attr:`num_workers > 0`, each worker process will have a
different copy of the dataset object, so it is often desired to configure
each copy independently to avoid having duplicate data returned from the
workers.
Example 1: splitting workload across all workers in :meth:`__iter__`::
>>> class MyIterableDataset(flow.utils.data.IterableDataset):
... def __init__(self, start, end):
        ...         super(MyIterableDataset, self).__init__()
        ...         assert end > start, "this example code only works with end > start"
... self.start = start
... self.end = end
...
... def __iter__(self):
... iter_start = self.start
... iter_end = self.end
... return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)
>>> # Single-process loading
>>> print(list(flow.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
Example 2: splitting workload across all workers using :attr:`worker_init_fn`::
>>> class MyIterableDataset(flow.utils.data.IterableDataset):
... def __init__(self, start, end):
        ...         super(MyIterableDataset, self).__init__()
        ...         assert end > start, "this example code only works with end > start"
... self.start = start
... self.end = end
...
... def __iter__(self):
... return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)
>>> # Single-process loading
>>> print(list(flow.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
"""
functions: Dict[str, Callable] = {}
reduce_ex_hook: Optional[Callable] = None
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
# TODO:
# def __add__(self, other: Dataset[T_co]):
# return ChainDataset([self, other])
def __getattr__(self, attribute_name):
if attribute_name in IterableDataset.functions:
function = functools.partial(
IterableDataset.functions[attribute_name], self
)
return function
else:
raise AttributeError
@classmethod
def register_function(cls, function_name, function):
IterableDataset.functions[function_name] = function
@classmethod
def register_datapipe_as_function(cls, function_name, cls_to_register):
if function_name in IterableDataset.functions:
raise Exception(
"Unable to add DataPipe function name {} as it is already taken".format(
function_name
)
)
def class_function(cls, source_dp, *args, **kwargs):
return cls(source_dp, *args, **kwargs)
function = functools.partial(class_function, cls_to_register)
IterableDataset.functions[function_name] = function
def __reduce_ex__(self, *args, **kwargs):
if IterableDataset.reduce_ex_hook is not None:
try:
return IterableDataset.reduce_ex_hook(self)
except NotImplementedError:
pass
return super().__reduce_ex__(*args, **kwargs)
@classmethod
def set_reduce_ex_hook(cls, hook_fn):
if IterableDataset.reduce_ex_hook is not None and hook_fn is not None:
raise Exception("Attempt to override existing reduce_ex_hook")
IterableDataset.reduce_ex_hook = hook_fn
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
r"""Dataset wrapping tensors.
Each sample will be retrieved by indexing tensors along the first dimension.
Args:
        *tensors (Tensor): tensors that have the same size in the first dimension.
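
    Example (an illustrative sketch; the arrays below are made up)::

        >>> import numpy as np
        >>> inputs = flow.Tensor(np.random.randn(4, 2))
        >>> targets = flow.Tensor(np.random.randn(4))
        >>> ds = TensorDataset(inputs, targets)
        >>> x, y = ds[0]  # one (input, target) pair
        >>> len(ds)
        4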
"""
tensors: Tuple[Tensor, ...]
def __init__(self, *tensors: Tensor) -> None:
assert all(
tensors[0].size(0) == tensor.size(0) for tensor in tensors
), "Size mismatch between tensors"
self.tensors = tensors
def __getitem__(self, index):
return tuple(tensor[index] for tensor in self.tensors)
def __len__(self):
return self.tensors[0].size(0)
class ConcatDataset(Dataset[T_co]):
r"""Dataset as a concatenation of multiple datasets.
This class is useful to assemble different existing datasets.
Args:
datasets (sequence): List of datasets to be concatenated
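
    Example (illustrative; ``ds_a`` and ``ds_b`` stand for any map-style datasets)::

        >>> ds = ConcatDataset([ds_a, ds_b])
        >>> len(ds) == len(ds_a) + len(ds_b)
        True
        >>> sample = ds[len(ds_a)]  # first sample of ds_b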
"""
datasets: List[Dataset[T_co]]
cumulative_sizes: List[int]
@staticmethod
def cumsum(sequence):
r, s = [], 0
for e in sequence:
l = len(e)
r.append(l + s)
s += l
return r
def __init__(self, datasets: Iterable[Dataset]) -> None:
super(ConcatDataset, self).__init__()
# Cannot verify that datasets is Sized
assert len(datasets) > 0, "datasets should not be an empty iterable" # type: ignore
self.datasets = list(datasets)
for d in self.datasets:
assert not isinstance(
d, IterableDataset
), "ConcatDataset does not support IterableDataset"
self.cumulative_sizes = self.cumsum(self.datasets)
def __len__(self):
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
if idx < 0:
if -idx > len(self):
raise ValueError(
"absolute value of index should not exceed dataset length"
)
idx = len(self) + idx
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx][sample_idx]
class Subset(Dataset[T_co]):
r"""
Subset of a dataset at specified indices.
Args:
dataset (Dataset): The whole Dataset
indices (sequence): Indices in the whole set selected for subset
"""
dataset: Dataset[T_co]
indices: Sequence[int]
def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
self.dataset = dataset
self.indices = indices
def __getitem__(self, idx):
return self.dataset[self.indices[idx]]
def __len__(self):
return len(self.indices)
def random_split(
dataset: Dataset[T],
lengths: Sequence[int],
generator: Optional[flow.Generator] = default_generator,
) -> List[Subset[T]]:
r"""
Randomly split a dataset into non-overlapping new datasets of given lengths.
Optionally fix the generator for reproducible results, e.g.:
>>> random_split(range(10), [3, 7], generator=flow.Generator().manual_seed(42))
Args:
dataset (Dataset): Dataset to be split
lengths (sequence): lengths of splits to be produced
generator (Generator): Generator used for the random permutation.
"""
# Cannot verify that dataset is Sized
if sum(lengths) != len(dataset): # type: ignore
raise ValueError(
"Sum of input lengths does not equal the length of the input dataset!"
)
indices = flow.randperm(sum(lengths), generator=generator).tolist()
return [
Subset(dataset, indices[offset - length : offset])
for offset, length in zip(_accumulate(lengths), lengths)
]
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Any, Callable, Optional, Type, Union
from oneflow.python.utils.data import IterDataPipe
class functional_datapipe(object):
name: str
def __init__(self, name: str) -> None:
self.name = name
def __call__(self, cls):
if isinstance(cls, Type): # type: ignore
if not issubclass(cls, IterDataPipe):
raise TypeError("`functional_datapipe` can only decorate IterDataPipe")
# with non_deterministic decorator
else:
if not isinstance(cls, non_deterministic) and not (
hasattr(cls, "__self__") and isinstance(cls.__self__, non_deterministic)
):
raise TypeError("`functional_datapipe` can only decorate IterDataPipe")
IterDataPipe.register_datapipe_as_function(self.name, cls)
return cls
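

# Usage sketch (illustrative; ``MyMapperPipe`` and ``map_with`` are hypothetical
# names, not part of the library):
#
#   @functional_datapipe("map_with")
#   class MyMapperPipe(IterDataPipe):
#       ...
#
# After registration, ``some_pipe.map_with(fn)`` resolves through
# IterableDataset.__getattr__ and constructs MyMapperPipe(some_pipe, fn).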
_determinism: bool = False
class guaranteed_datapipes_determinism(object):
prev: bool
def __init__(self) -> None:
global _determinism
self.prev = _determinism
_determinism = True
def __enter__(self) -> None:
pass
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
global _determinism
_determinism = self.prev
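
# Usage sketch (illustrative): inside the context, constructing an IterDataPipe
# decorated with `non_deterministic` raises a TypeError:
#
#   with guaranteed_datapipes_determinism():
#       dp = SomeNonDeterministicPipe()  # hypothetical pipe; raises TypeError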
class non_deterministic(object):
cls: Optional[Type[IterDataPipe]] = None
# TODO: Lambda for picking
deterministic_fn: Callable[[], bool]
def __init__(self, arg: Union[Type[IterDataPipe], Callable[[], bool]]) -> None:
# 1. Decorator doesn't have any argument
if isinstance(arg, Type): # type: ignore
if not issubclass(arg, IterDataPipe): # type: ignore
raise TypeError(
"Only `IterDataPipe` can be decorated with `non_deterministic`"
", but {} is found".format(arg.__name__)
)
self.cls = arg # type: ignore
# 2. Decorator has an argument of a function
# This class should behave differently given different inputs. Use this
# function to verify the determinism for each instance.
# When the function returns True, the instance is non-deterministic. Otherwise,
# the instance is a deterministic DataPipe.
elif isinstance(arg, Callable): # type:ignore
self.deterministic_fn = arg # type: ignore
else:
raise TypeError("{} can not be decorated by non_deterministic".format(arg))
def __call__(self, *args, **kwargs):
global _determinism
# Decorate IterDataPipe
if self.cls is not None:
if _determinism:
raise TypeError(
"{} is non-deterministic, but you set 'guaranteed_datapipes_determinism'. "
"You can turn off determinism for this DataPipe if that is acceptable "
"for your application".format(self.cls.__name__)
)
return self.cls(*args, **kwargs) # type: ignore
# Decorate with a functional argument
if not (
isinstance(args[0], Type)
and issubclass( # type: ignore
args[0], IterDataPipe
)
):
raise TypeError(
"Only `IterDataPipe` can be decorated, but {} is found".format(
args[0].__name__
)
)
self.cls = args[0]
return self.deterministic_wrapper_fn
def deterministic_wrapper_fn(self, *args, **kwargs) -> IterDataPipe:
res = self.deterministic_fn(*args, **kwargs) # type: ignore
if not isinstance(res, bool):
raise TypeError(
"deterministic_fn of `non_deterministic` decorator is required "
"to return a boolean value, but {} is found".format(type(res))
)
global _determinism
if _determinism and res:
raise TypeError(
"{} is non-deterministic with the inputs, but you set "
"'guaranteed_datapipes_determinism'. You can turn off determinism "
"for this DataPipe if that is acceptable for your application".format(
self.cls.__name__
)
) # type: ignore
return self.cls(*args, **kwargs) # type: ignore
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import builtins
from typing import Iterator, Union, Optional, Sequence, List, TypeVar, Generic, Sized
import numpy as np
import oneflow as flow
from oneflow.python.framework.tensor import Tensor
T_co = TypeVar("T_co", covariant=True)
class Sampler(Generic[T_co]):
r"""Base class for all Samplers.
Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
way to iterate over indices of dataset elements, and a :meth:`__len__` method
that returns the length of the returned iterators.
.. note:: The :meth:`__len__` method isn't strictly required by
:class:`~flow.utils.data.DataLoader`, but is expected in any
calculation involving the length of a :class:`~flow.utils.data.DataLoader`.
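
    Example (an illustrative custom sampler that yields indices in reverse)::

        >>> class ReverseSampler(Sampler[int]):
        ...     def __init__(self, data_source):
        ...         self.data_source = data_source
        ...     def __iter__(self):
        ...         return iter(range(len(self.data_source) - 1, -1, -1))
        ...     def __len__(self):
        ...         return len(self.data_source)
        >>> list(ReverseSampler(range(4)))
        [3, 2, 1, 0]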
"""
def __init__(self, data_source: Optional[Sized]) -> None:
pass
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
# NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
#
# Many times we have an abstract class representing a collection/iterable of
# data, e.g., `flow.utils.data.Sampler`, with its subclasses optionally
# implementing a `__len__` method. In such cases, we must make sure to not
# provide a default implementation, because both straightforward default
# implementations have their issues:
#
# + `return NotImplemented`:
# Calling `len(subclass_instance)` raises:
# TypeError: 'NotImplementedType' object cannot be interpreted as an integer
#
# + `raise NotImplementedError()`:
# This prevents triggering some fallback behavior. E.g., the built-in
# `list(X)` tries to call `len(X)` first, and executes a different code
# path if the method is not found or `NotImplemented` is returned, while
#     raising a `NotImplementedError` will propagate and make the call
#     fail where it could have used `__iter__` to complete the call.
#
# Thus, the only two sensible things to do are
#
# + **not** provide a default `__len__`.
#
# + raise a `TypeError` instead, which is what Python uses when users call
# a method that is not defined on an object.
# (@ssnl verifies that this works on at least Python 3.7.)
class SequentialSampler(Sampler[int]):
r"""Samples elements sequentially, always in the same order.
Args:
data_source (Dataset): dataset to sample from
"""
data_source: Sized
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self) -> int:
return len(self.data_source)
class RandomSampler(Sampler[int]):
r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then the user can specify :attr:`num_samples` to draw.
Args:
data_source (Dataset): dataset to sample from
replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
is supposed to be specified only when `replacement` is ``True``.
generator (Generator): Generator used in sampling.
"""
data_source: Sized
replacement: bool
def __init__(
self,
data_source: Sized,
replacement: bool = False,
num_samples: Optional[int] = None,
generator=None,
) -> None:
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
self.generator = generator
if not isinstance(self.replacement, bool):
raise TypeError(
"replacement should be a boolean value, but got "
"replacement={}".format(self.replacement)
)
if self._num_samples is not None and not replacement:
raise ValueError(
"With replacement=False, num_samples should not be specified, "
"since a random permute will be performed."
)
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
raise ValueError(
"num_samples should be a positive integer "
"value, but got num_samples={}".format(self.num_samples)
)
@property
def num_samples(self) -> int:
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
def __iter__(self):
n = len(self.data_source)
if self.generator is None:
generator = flow.Generator()
generator.manual_seed(
int(flow.Tensor(1, dtype=flow.int64).xavier_uniform_().numpy()[0])
)
else:
generator = self.generator
if self.replacement:
raise NotImplementedError("Not support replacement yet!")
# # TODO: flow.randint
# for _ in range(self.num_samples // 32):
# yield from flow.randint(
# high=n, size=(32,), dtype=flow.int64, generator=generator
# ).tolist()
# yield from flow.randint(
# high=n,
# size=(self.num_samples % 32,),
# dtype=flow.int64,
# generator=generator,
# ).tolist()
else:
yield from np.random.permutation(n).tolist()
# TODO: yield from flow.randperm(n, generator=generator).tolist()
def __len__(self):
return self.num_samples
class SubsetRandomSampler(Sampler[int]):
r"""Samples elements randomly from a given list of indices, without replacement.
Args:
indices (sequence): a sequence of indices
generator (Generator): Generator used in sampling.
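
    Example (illustrative; the output order is random)::

        >>> list(SubsetRandomSampler([3, 7, 9]))  # doctest: +SKIP
        [9, 3, 7]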
"""
indices: Sequence[int]
def __init__(self, indices: Sequence[int], generator=None) -> None:
self.indices = indices
self.generator = generator
def __iter__(self):
return (
self.indices[i]
for i in flow.randperm(len(self.indices), generator=self.generator)
)
def __len__(self):
return len(self.indices)
class BatchSampler(Sampler[List[int]]):
r"""Wraps another sampler to yield a mini-batch of indices.
Args:
sampler (Sampler or Iterable): Base sampler. Can be any iterable object
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``
Example:
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
"""
def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None:
# Since collections.abc.Iterable does not check for `__getitem__`, which
# is one way for an object to be an iterable, we don't do an `isinstance`
# check here.
if (
not isinstance(batch_size, int)
or isinstance(batch_size, bool)
or batch_size <= 0
):
raise ValueError(
"batch_size should be a positive integer value, "
"but got batch_size={}".format(batch_size)
)
if not isinstance(drop_last, bool):
raise ValueError(
"drop_last should be a boolean value, but got "
"drop_last={}".format(drop_last)
)
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self):
# Can only be called if self.sampler has __len__ implemented
# We cannot enforce this condition, so we turn off typechecking for the
# implementation below.
# Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
if self.drop_last:
return len(self.sampler) // self.batch_size # type: ignore
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore