Skip to content
Snippets Groups Projects
Unverified Commit dbd9d76e authored by daquexian's avatar daquexian Committed by GitHub
Browse files

fix flow.save (#4941)


Signed-off-by: default avatardaquexian <daquexian566@gmail.com>
parent 7911d713
No related branches found
No related tags found
No related merge requests found
......@@ -206,10 +206,16 @@ def _ReadSlice(
if isinstance(container, oneflow.Tensor):
def ReadFromTensor(tensor, start_nd_idx, stop_nd_idx):
with tensor._placement_scope():
return _LogicalSlice(
tensor._blob_object, start_nd_idx, stop_nd_idx, None
start_nd_idx = list(map(int, start_nd_idx))
stop_nd_idx = list(map(int, stop_nd_idx))
return tensor[
tuple(
[
slice(start_nd_idx[i], stop_nd_idx[i])
for i in range(len(start_nd_idx))
]
)
].numpy()
yield from _ForEachSlice(container, ReadFromTensor)
elif isinstance(container, EagerBlobTrait):
......
......@@ -621,7 +621,7 @@ class Tensor:
if internal_tensor.is_consistent:
TODO()
if isinstance(other, Tensor):
if isinstance(other, (Tensor, check_point_v2.FileBackendVariableBlob)):
src_np = other.numpy()
else:
assert isinstance(other, np.ndarray)
......
......@@ -17,6 +17,7 @@ import collections.abc
from itertools import repeat
import unittest
from typing import Tuple, Union
import tempfile
import numpy as np
......@@ -166,6 +167,27 @@ class TestModule(flow.unittest.TestCase):
test_case.assertEqual(module_num, 2)
def test_save_state_dict(test_case):
class CustomModule(flow.nn.Module):
def __init__(self):
super().__init__()
self.param1 = flow.nn.Parameter(flow.Tensor(32, 1024, 1024))
self.param2 = flow.nn.Parameter(flow.Tensor(32, 1024, 1024))
def forward(self):
return self.param1 + self.param2
m = CustomModule()
res1 = m()
state_dict = m.state_dict()
with tempfile.TemporaryDirectory() as save_dir:
flow.save(state_dict, save_dir)
loaded_state_dict = flow.load(save_dir)
m.load_state_dict(loaded_state_dict)
res2 = m()
test_case.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))
if __name__ == "__main__":
unittest.main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment