From dbd9d76e34d1acf7e0fb8b35abbbbd6df554b51b Mon Sep 17 00:00:00 2001
From: daquexian <daquexian566@gmail.com>
Date: Fri, 21 May 2021 13:14:27 +0800
Subject: [PATCH] fix flow.save (#4941)

Signed-off-by: daquexian <daquexian566@gmail.com>
---
 oneflow/python/framework/check_point_v2.py | 12 +++++++++---
 oneflow/python/framework/tensor.py         |  2 +-
 oneflow/python/test/modules/test_module.py | 22 ++++++++++++++++++++++
 3 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/oneflow/python/framework/check_point_v2.py b/oneflow/python/framework/check_point_v2.py
index a3814ad97..8b8bea7e6 100644
--- a/oneflow/python/framework/check_point_v2.py
+++ b/oneflow/python/framework/check_point_v2.py
@@ -206,10 +206,16 @@ def _ReadSlice(
     if isinstance(container, oneflow.Tensor):
 
         def ReadFromTensor(tensor, start_nd_idx, stop_nd_idx):
-            with tensor._placement_scope():
-                return _LogicalSlice(
-                    tensor._blob_object, start_nd_idx, stop_nd_idx, None
+            start_nd_idx = list(map(int, start_nd_idx))
+            stop_nd_idx = list(map(int, stop_nd_idx))
+            return tensor[
+                tuple(
+                    [
+                        slice(start_nd_idx[i], stop_nd_idx[i])
+                        for i in range(len(start_nd_idx))
+                    ]
                 )
+            ].numpy()
 
         yield from _ForEachSlice(container, ReadFromTensor)
     elif isinstance(container, EagerBlobTrait):
diff --git a/oneflow/python/framework/tensor.py b/oneflow/python/framework/tensor.py
index 408898f14..1b5d5b4a0 100644
--- a/oneflow/python/framework/tensor.py
+++ b/oneflow/python/framework/tensor.py
@@ -621,7 +621,7 @@ class Tensor:
         if internal_tensor.is_consistent:
             TODO()
 
-        if isinstance(other, Tensor):
+        if isinstance(other, (Tensor, check_point_v2.FileBackendVariableBlob)):
             src_np = other.numpy()
         else:
             assert isinstance(other, np.ndarray)
diff --git a/oneflow/python/test/modules/test_module.py b/oneflow/python/test/modules/test_module.py
index 6f7576b1f..670b22c8e 100644
--- a/oneflow/python/test/modules/test_module.py
+++ b/oneflow/python/test/modules/test_module.py
@@ -17,6 +17,7 @@ import collections.abc
 from itertools import repeat
 import unittest
 from typing import Tuple, Union
+import tempfile
 
 import numpy as np
 
@@ -166,6 +167,27 @@ class TestModule(flow.unittest.TestCase):
 
         test_case.assertEqual(module_num, 2)
 
+    def test_save_state_dict(test_case):
+        class CustomModule(flow.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.param1 = flow.nn.Parameter(flow.Tensor(32, 1024, 1024))
+                self.param2 = flow.nn.Parameter(flow.Tensor(32, 1024, 1024))
+
+            def forward(self):
+                return self.param1 + self.param2
+
+        m = CustomModule()
+
+        res1 = m()
+        state_dict = m.state_dict()
+        with tempfile.TemporaryDirectory() as save_dir:
+            flow.save(state_dict, save_dir)
+            loaded_state_dict = flow.load(save_dir)
+            m.load_state_dict(loaded_state_dict)
+        res2 = m()
+        test_case.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))
+
 
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab