diff --git a/oneflow/python/framework/typing.py b/oneflow/python/framework/typing.py index c52e9c7d8ca2b37559c22344cfdc796e5570e5e2..7d1a0de094af3477edf771fadd53e6848bbae5b1 100644 --- a/oneflow/python/framework/typing.py +++ b/oneflow/python/framework/typing.py @@ -192,6 +192,16 @@ class Callback(typing.Generic[typing.TypeVar("T")]): pass +@oneflow_export("typing.Bundle") +class Bundle(typing.Generic[typing.TypeVar("T")]): + """ + One or a collection of typing.Numpy/typing.ListNumpy/typing.ListListNumpy, + such as x, [x], (x,), {"key": x} and the mixed form of them. + """ + + pass + + def OriginFrom(parameterised, generic): if inspect.isclass(parameterised) and inspect.isclass(generic): return issubclass(parameterised, generic) @@ -209,5 +219,7 @@ def OriginFrom(parameterised, generic): return parameterised.__origin__ is list if generic == Callback: return parameterised.__origin__ is Callback + if generic == Bundle: + return parameterised.__origin__ is Bundle raise NotImplementedError("python typing is a monster torturing everyone.") diff --git a/oneflow/python/framework/typing_util.py b/oneflow/python/framework/typing_util.py index a21bafa054aab9aab0aa33c3a1fde2a893a44bdf..629e0ffab0f5d6a6080c9fa833e68905e5dd4b52 100644 --- a/oneflow/python/framework/typing_util.py +++ b/oneflow/python/framework/typing_util.py @@ -18,6 +18,8 @@ from __future__ import absolute_import import typing import inspect import oneflow.python.framework.remote_blob as remote_blob_util +import oneflow.python.framework.local_blob as local_blob_util +import oneflow.python.framework.pull_util as pull_util import oneflow.python.framework.typing as oft import oneflow.python.experimental.enable_typing_check as enable_typing_check @@ -68,6 +70,14 @@ def CheckGlobalFunctionReturnAnnotation(cls): ), "T in oneflow.typing.Callback[T] cannot be omitted" assert len(cls.__args__) == 1 _CheckGlobalFunctionReturnAnnotation(cls.__args__[0]) + elif oft.OriginFrom(cls, oft.Bundle): + assert cls.__args__[0] in ( + oft.Numpy, + oft.ListNumpy, + oft.ListListNumpy, + ), "T in oneflow.typing.Bundle[T] must be one of (oneflow.typing.Numpy, oneflow.typing.ListNumpy, oneflow.typing.ListListNumpy)" + assert len(cls.__args__) == 1 + _CheckGlobalFunctionReturnAnnotation(cls.__args__[0]) else: _CheckGlobalFunctionReturnAnnotation(cls) @@ -103,6 +113,17 @@ def CheckReturnByAnnotation(function_name, ret, annotation): assert ret is None, error_str elif oft.OriginFrom(annotation, oft.Callback): _CheckReturnByAnnotation(function_name, ret, annotation.__args__[0]) + elif oft.OriginFrom(annotation, oft.Bundle): + if isinstance(ret, remote_blob_util.BlobDef): + _CheckReturnByAnnotation(function_name, ret, annotation.__args__[0]) + elif isinstance(ret, (list, tuple)): + for elem in ret: + CheckReturnByAnnotation(function_name, elem, annotation) + elif type(ret) is dict: + for val in ret.values(): + CheckReturnByAnnotation(function_name, val, annotation) + else: + raise NotImplementedError("invalid return %s found" % (type(ret))) else: _CheckReturnByAnnotation(function_name, ret, annotation) @@ -174,10 +195,38 @@ def TransformGlobalFunctionResult(future_blob, annotation): return lambda x: f(TransformReturnedLocalBlob(x, annotation)) return lambda f: future_blob.async_get(Transform(f)) + elif oft.OriginFrom(annotation, oft.Bundle): + return TransformReturnedBundle(future_blob.get(), annotation) else: return TransformReturnedLocalBlob(future_blob.get(), annotation) +def TransformReturnedBundle(bundle_blob, annotation): + """ + Transform returned bundle blob from global_function(job_func), + the returned bundle blob could be the form like x, [x], (x, ), + {"key": x} or the mixed form of them. + """ + if isinstance( + bundle_blob, + (local_blob_util.LocalMirroredTensor, local_blob_util.LocalMirroredTensorList), + ): + return TransformReturnedLocalBlob(bundle_blob, annotation.__args__[0]) + elif isinstance(bundle_blob, (list, tuple)): + return type(bundle_blob)( + TransformReturnedBundle(elem, annotation) for elem in bundle_blob + ) + elif type(bundle_blob) is dict: + return { + key: TransformReturnedBundle(val, annotation) + for key, val in bundle_blob.items() + } + else: + raise NotImplementedError( + "invalid return %s : %s found" % (bundle_blob, type(bundle_blob)) + ) + + def TransformReturnedLocalBlob(local_blob, annotation): if oft.OriginFrom(annotation, typing.Tuple): assert type(local_blob) is tuple diff --git a/oneflow/python/test/ops/test_global_function_signature.py b/oneflow/python/test/ops/test_global_function_signature.py index 622cf7fe9116e53f00cca84ce58acbf8960825bf..cfb3c7c3a0469e6e9a0af5f2db10d2027a98e5b8 100644 --- a/oneflow/python/test/ops/test_global_function_signature.py +++ b/oneflow/python/test/ops/test_global_function_signature.py @@ -344,6 +344,206 @@ def test_annotation_Callback_Tuple_ListListNumpy(test_case): foo([[data]])(Test) +def test_annotation_Bundle_Numpy(test_case): + flow.config.gpu_device_num(1) + + @flow.global_function() + def foo(x: oft.Numpy.Placeholder((10,))) -> oft.Bundle[oft.Numpy]: + return x + + data = np.ones((10,), dtype=np.float32) + test_case.assertTrue(np.array_equal(foo(data), data)) + + +def test_annotation_Bundle_List_Numpy(test_case): + flow.config.gpu_device_num(1) + + @flow.global_function() + def foo(x: oft.Numpy.Placeholder((10,))) -> oft.Bundle[oft.Numpy]: + return [x] + + data = np.ones((10,), dtype=np.float32) + test_case.assertTrue(np.array_equal(foo(data)[0], data)) + + +def test_annotation_Bundle_Dict_Numpy(test_case): + flow.config.gpu_device_num(1) + + @flow.global_function() + def foo(x: oft.Numpy.Placeholder((10,))) -> oft.Bundle[oft.Numpy]: + return {"x": x} + + data = np.ones((10,), dtype=np.float32) + test_case.assertTrue(np.array_equal(foo(data)["x"], data)) + + +def test_annotation_Bundle_Tuple_Numpy(test_case): + flow.config.gpu_device_num(1) + + @flow.global_function() + def foo(x: oft.Numpy.Placeholder((10,))) -> oft.Bundle[oft.Numpy]: + return (x,) + + data = np.ones((10,), dtype=np.float32) + test_case.assertTrue(np.array_equal(foo(data)[0], data)) + + +def test_annotation_Bundle_Mix_Nesting_Numpy(test_case): + flow.config.gpu_device_num(1) + + @flow.global_function() + def foo(x: oft.Numpy.Placeholder((10,))) -> oft.Bundle[oft.Numpy]: + return (x, (x,), [x, x, x], {"x": {256: x}}) + + data = np.ones((10,), dtype=np.float32) + test_case.assertTrue(np.array_equal(foo(data)[0], data)) + test_case.assertTrue(np.array_equal(foo(data)[1][0], data)) + test_case.assertTrue(np.array_equal(foo(data)[2][0], data)) + test_case.assertTrue(np.array_equal(foo(data)[2][1], data)) + test_case.assertTrue(np.array_equal(foo(data)[2][2], data)) + test_case.assertTrue(np.array_equal(foo(data)[3]["x"][256], data)) + + +def test_annotation_Bundle_ListNumpy(test_case): + flow.config.gpu_device_num(1) + func_config = flow.FunctionConfig() + func_config.default_logical_view(flow.scope.mirrored_view()) + + @flow.global_function(function_config=func_config) + def foo(x: oft.ListNumpy.Placeholder((10,))) -> oft.Bundle[oft.ListNumpy]: + return x + + data = np.ones((10,), dtype=np.float32) + test_case.assertTrue(np.array_equal(foo([data])[0], data)) + + +def test_annotation_Bundle_List_ListNumpy(test_case): + flow.config.gpu_device_num(1) + func_config = flow.FunctionConfig() + func_config.default_logical_view(flow.scope.mirrored_view()) + + @flow.global_function(function_config=func_config) + def foo(x: oft.ListNumpy.Placeholder((10,))) -> oft.Bundle[oft.ListNumpy]: + return [x] + + data = np.ones((10,), dtype=np.float32) + test_case.assertTrue(np.array_equal(foo([data])[0][0], data)) + + +def test_annotation_Bundle_Dict_ListNumpy(test_case): + flow.config.gpu_device_num(1) + func_config = flow.FunctionConfig() + func_config.default_logical_view(flow.scope.mirrored_view()) + + @flow.global_function(function_config=func_config) + def foo(x: oft.ListNumpy.Placeholder((10,))) -> oft.Bundle[oft.ListNumpy]: + return {"x": x} + + data = np.ones((10,), dtype=np.float32) + test_case.assertTrue(np.array_equal(foo([data])["x"][0], data)) + + +def test_annotation_Bundle_Tuple_ListNumpy(test_case): + flow.config.gpu_device_num(1) + func_config = flow.FunctionConfig() + func_config.default_logical_view(flow.scope.mirrored_view()) + + @flow.global_function(function_config=func_config) + def foo(x: oft.ListNumpy.Placeholder((10,))) -> oft.Bundle[oft.ListNumpy]: + return (x,) + + data = np.ones((10,), dtype=np.float32) + test_case.assertTrue(np.array_equal(foo([data])[0][0], data)) + + +def test_annotation_Bundle_Mix_Nesting_ListNumpy(test_case): + flow.config.gpu_device_num(1) + func_config = flow.FunctionConfig() + func_config.default_logical_view(flow.scope.mirrored_view()) + + @flow.global_function(function_config=func_config) + def foo(x: oft.ListNumpy.Placeholder((10,))) -> oft.Bundle[oft.ListNumpy]: + return (x, (x,), [x, x, x], {"x": {256: x}}) + + data = np.ones((10,), dtype=np.float32) + test_case.assertTrue(np.array_equal(foo([data])[0][0], data)) + test_case.assertTrue(np.array_equal(foo([data])[1][0][0], data)) + test_case.assertTrue(np.array_equal(foo([data])[2][0][0], data)) + test_case.assertTrue(np.array_equal(foo([data])[2][1][0], data)) + test_case.assertTrue(np.array_equal(foo([data])[2][2][0], data)) + test_case.assertTrue(np.array_equal(foo([data])[3]["x"][256][0], data)) + + +def test_annotation_Bundle_ListListNumpy(test_case): + flow.config.gpu_device_num(1) + func_config = flow.FunctionConfig() + func_config.default_logical_view(flow.scope.mirrored_view()) + + @flow.global_function(function_config=func_config) + def foo(x: oft.ListListNumpy.Placeholder((10,))) -> oft.Bundle[oft.ListListNumpy]: + return x + + data = np.ones((10,), dtype=np.float32) + test_case.assertTrue(np.array_equal(foo([[data]])[0][0], data)) + + +def test_annotation_Bundle_List_ListListNumpy(test_case): + flow.config.gpu_device_num(1) + func_config = flow.FunctionConfig() + func_config.default_logical_view(flow.scope.mirrored_view()) + + @flow.global_function(function_config=func_config) + def foo(x: oft.ListListNumpy.Placeholder((10,))) -> oft.Bundle[oft.ListListNumpy]: + return [x] + + data = np.ones((10,), dtype=np.float32) + test_case.assertTrue(np.array_equal(foo([[data]])[0][0][0], data)) + + +def test_annotation_Bundle_Dict_ListListNumpy(test_case): + flow.config.gpu_device_num(1) + func_config = flow.FunctionConfig() + func_config.default_logical_view(flow.scope.mirrored_view()) + + @flow.global_function(function_config=func_config) + def foo(x: oft.ListListNumpy.Placeholder((10,))) -> oft.Bundle[oft.ListListNumpy]: + return {"x": x} + + data = np.ones((10,), dtype=np.float32) + test_case.assertTrue(np.array_equal(foo([[data]])["x"][0][0], data)) + + +def test_annotation_Bundle_Tuple_ListListNumpy(test_case): + flow.config.gpu_device_num(1) + func_config = flow.FunctionConfig() + func_config.default_logical_view(flow.scope.mirrored_view()) + + @flow.global_function(function_config=func_config) + def foo(x: oft.ListListNumpy.Placeholder((10,))) -> oft.Bundle[oft.ListListNumpy]: + return (x,) + + data = np.ones((10,), dtype=np.float32) + test_case.assertTrue(np.array_equal(foo([[data]])[0][0][0], data)) + + +def test_annotation_Bundle_Mix_Nesting_ListListNumpy(test_case): + flow.config.gpu_device_num(1) + func_config = flow.FunctionConfig() + func_config.default_logical_view(flow.scope.mirrored_view()) + + @flow.global_function(function_config=func_config) + def foo(x: oft.ListListNumpy.Placeholder((10,))) -> oft.Bundle[oft.ListListNumpy]: + return (x, (x,), [x, x, x], {"x": {256: x}}) + + data = np.ones((10,), dtype=np.float32) + test_case.assertTrue(np.array_equal(foo([[data]])[0][0][0], data)) + test_case.assertTrue(np.array_equal(foo([[data]])[1][0][0][0], data)) + test_case.assertTrue(np.array_equal(foo([[data]])[2][0][0][0], data)) + test_case.assertTrue(np.array_equal(foo([[data]])[2][1][0][0], data)) + test_case.assertTrue(np.array_equal(foo([[data]])[2][2][0][0], data)) + test_case.assertTrue(np.array_equal(foo([[data]])[3]["x"][256][0][0], data)) + + def test_annotation_return_List_Numpy(test_case): data = np.ones((10,), dtype=np.float32) @@ -364,11 +564,7 @@ def test_annotation_return_List_ListNumpy(test_case): data = np.ones((10,), dtype=np.float32) flow.clear_default_session() - flow.config.gpu_device_num(1) - func_config = flow.FunctionConfig() - func_config.default_logical_view(flow.scope.mirrored_view()) - @flow.global_function(function_config=func_config) def foo(x: oft.ListNumpy.Placeholder(shape=data.shape)) -> List[oft.ListNumpy]: return [x, x] @@ -381,11 +577,7 @@ def test_annotation_return_List_ListListNumpy(test_case): data = np.ones((10,), dtype=np.float32) flow.clear_default_session() - flow.config.gpu_device_num(1) - func_config = flow.FunctionConfig() - func_config.default_logical_view(flow.scope.mirrored_view()) - @flow.global_function(function_config=func_config) def foo( x: oft.ListListNumpy.Placeholder(shape=data.shape), ) -> List[oft.ListListNumpy]: @@ -401,11 +593,7 @@ def test_annotation_return_List_Nesting_Tuple(test_case): y = np.random.rand(10).astype(np.float32) flow.clear_default_session() - flow.config.gpu_device_num(1) - func_config = flow.FunctionConfig() - func_config.default_logical_view(flow.scope.mirrored_view()) - @flow.global_function(function_config=func_config) def foo( x: oft.Numpy.Placeholder(shape=x.shape), y: oft.ListNumpy.Placeholder(shape=y.shape),