From dbd9d76e34d1acf7e0fb8b35abbbbd6df554b51b Mon Sep 17 00:00:00 2001 From: daquexian <daquexian566@gmail.com> Date: Fri, 21 May 2021 13:14:27 +0800 Subject: [PATCH] fix flow.save (#4941) Signed-off-by: daquexian <daquexian566@gmail.com> --- oneflow/python/framework/check_point_v2.py | 12 +++++++++--- oneflow/python/framework/tensor.py | 2 +- oneflow/python/test/modules/test_module.py | 22 ++++++++++++++++++++++ 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/oneflow/python/framework/check_point_v2.py b/oneflow/python/framework/check_point_v2.py index a3814ad97..8b8bea7e6 100644 --- a/oneflow/python/framework/check_point_v2.py +++ b/oneflow/python/framework/check_point_v2.py @@ -206,10 +206,16 @@ def _ReadSlice( if isinstance(container, oneflow.Tensor): def ReadFromTensor(tensor, start_nd_idx, stop_nd_idx): - with tensor._placement_scope(): - return _LogicalSlice( - tensor._blob_object, start_nd_idx, stop_nd_idx, None + start_nd_idx = list(map(int, start_nd_idx)) + stop_nd_idx = list(map(int, stop_nd_idx)) + return tensor[ + tuple( + [ + slice(start_nd_idx[i], stop_nd_idx[i]) + for i in range(len(start_nd_idx)) + ] ) + ].numpy() yield from _ForEachSlice(container, ReadFromTensor) elif isinstance(container, EagerBlobTrait): diff --git a/oneflow/python/framework/tensor.py b/oneflow/python/framework/tensor.py index 408898f14..1b5d5b4a0 100644 --- a/oneflow/python/framework/tensor.py +++ b/oneflow/python/framework/tensor.py @@ -621,7 +621,7 @@ class Tensor: if internal_tensor.is_consistent: TODO() - if isinstance(other, Tensor): + if isinstance(other, (Tensor, check_point_v2.FileBackendVariableBlob)): src_np = other.numpy() else: assert isinstance(other, np.ndarray) diff --git a/oneflow/python/test/modules/test_module.py b/oneflow/python/test/modules/test_module.py index 6f7576b1f..670b22c8e 100644 --- a/oneflow/python/test/modules/test_module.py +++ b/oneflow/python/test/modules/test_module.py @@ -17,6 +17,7 @@ import collections.abc from itertools import repeat import unittest from typing import Tuple, Union +import tempfile import numpy as np @@ -166,6 +167,27 @@ class TestModule(flow.unittest.TestCase): test_case.assertEqual(module_num, 2) + def test_save_state_dict(test_case): + class CustomModule(flow.nn.Module): + def __init__(self): + super().__init__() + self.param1 = flow.nn.Parameter(flow.Tensor(32, 1024, 1024)) + self.param2 = flow.nn.Parameter(flow.Tensor(32, 1024, 1024)) + + def forward(self): + return self.param1 + self.param2 + + m = CustomModule() + + res1 = m() + state_dict = m.state_dict() + with tempfile.TemporaryDirectory() as save_dir: + flow.save(state_dict, save_dir) + loaded_state_dict = flow.load(save_dir) + m.load_state_dict(loaded_state_dict) + res2 = m() + test_case.assertTrue(np.array_equal(res1.numpy(), res2.numpy())) + if __name__ == "__main__": unittest.main() -- GitLab