Skip to content
Snippets Groups Projects
Unverified Commit 4895d9b6 authored by leaves-zwx's avatar leaves-zwx Committed by GitHub
Browse files

Fix issues in point of MultiClientSession (#5469)


* address review

* lazy init scope stack in single-client, instantly init scope stack after MultiClientSession created in multi-client

* fix typo

* address review

* fix clear default session

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 6e2e6693
No related branches found
No related tags found
No related merge requests found
......@@ -65,6 +65,7 @@ oneflow._oneflow_internal.RegisterGILForeignLockHelper()
import oneflow.python.framework.env_util as env_util
import oneflow.python.framework.session_context as session_ctx
import oneflow.python.framework.scope_util as scope_util
from oneflow.python.framework.session_util import Session
from oneflow.python.framework.multi_client_session import MultiClientSession
......@@ -75,6 +76,7 @@ if env_util.HasAllMultiClientEnvVars():
session_ctx.OpenDefaultSession(
MultiClientSession(oneflow._oneflow_internal.NewSessionId())
)
scope_util.InitScopeStack()
else:
oneflow._oneflow_internal.SetIsMultiClient(False)
env_util.init_default_physical_env()
......
......@@ -80,8 +80,6 @@ def env_init():
scope_util.InitScopeStack()
else:
exit(0)
else:
scope_util.InitScopeStack()
return True
......
......@@ -13,6 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import enum
import inspect
from google.protobuf import text_format
import oneflow._oneflow_internal
......@@ -21,9 +23,12 @@ import oneflow.python.framework.c_api_util as c_api_util
class MultiClientSession(object):
class Status(enum.Enum):
CREATED = 1
INITED = 2
CLOSED = 3
def __init__(self, sess_id):
self.is_inited_ = False
self.is_closed_ = False
self.sess_ = oneflow._oneflow_internal.RegsiterSession(sess_id)
oneflow._oneflow_internal.CreateMultiClientSessionContext()
self.config_proto_ = self._make_config_proto()
......@@ -34,20 +39,28 @@ class MultiClientSession(object):
self.scope_attr_name2default_val_ = {}
self._update_scope_attr_name2defaultVal()
self.status_ = self.Status.CREATED
def __del__(self):
self.TryClose()
def TryInit(self):
if not self.is_inited_:
self._check_status(self.Status.CREATED, self.Status.INITED)
if self.status_ == self.Status.CREATED:
config_proto_str = text_format.MessageToString(self.config_proto)
oneflow._oneflow_internal.InitMultiClientSessionContext(config_proto_str)
self.is_inited_ = True
self.status_ = self.Status.INITED
def TryClose(self):
if not self.is_closed_:
if self.status_ != self.Status.CLOSED:
oneflow._oneflow_internal.DestroyMultiClientSessionContext()
oneflow._oneflow_internal.ClearSessionById(self.id)
self.is_closed_ = True
def __del__(self):
self.TryClose()
self.status_ = self.Status.CLOSED
@property
def status(self):
return self.status_
@property
def id(self):
......@@ -55,10 +68,13 @@ class MultiClientSession(object):
@property
def config_proto(self):
if self.config_proto_ is None:
self.config_proto_ = job_set_util.ConfigProto()
return self.config_proto_
@property
def resource(self):
self._check_status(self.Status.INITED)
return c_api_util.CurrentResource()
@property
def function_flag_name2default_val(self):
return self.function_flag_name2default_val_
......@@ -67,8 +83,34 @@ class MultiClientSession(object):
def scope_attr_name2default_val(self):
return self.scope_attr_name2default_val_
# compatible with single client session
@property
def is_running(self):
return self.status_ == self.Status.INITED
# compatible with single client session
def AnyGlobalFunctionDefined(self):
return False
def _check_status(self, *status):
check_success = False
for stat in status:
if self.status_ == stat:
check_success = True
break
if check_success is False:
caller_func_name = inspect.stack()[1].function
allowed_status = " or ".join([str(stat) for stat in status])
raise ValueError(
"The calling to {} is only allowed when status is {}, but current status is {}".format(
caller_func_name, allowed_status, self.status_
)
)
def _make_config_proto(self):
config_proto = job_set_util.ConfigProto()
config_proto.resource.SetInParent()
config_proto.session_id = self.id
return config_proto
......
......@@ -482,8 +482,12 @@ def api_clear_default_session() -> None:
@enable_if.condition(hob.in_normal_mode)
def clear_default_session():
session_ctx.TryCloseDefaultSession()
session_ctx.OpenDefaultSession(Session(oneflow._oneflow_internal.NewSessionId()))
is_multi_client = oneflow._oneflow_internal.IsMultiClient()
if not is_multi_client:
session_ctx.TryCloseDefaultSession()
session_ctx.OpenDefaultSession(
Session(oneflow._oneflow_internal.NewSessionId())
)
@oneflow_export("sync_default_session")
......@@ -516,9 +520,6 @@ def _GetDefaultConfigProto():
return config_proto
session_ctx.OpenDefaultSession(Session(oneflow._oneflow_internal.NewSessionId()))
@oneflow_export("InitEagerGlobalSession")
def TmpInitEagerGlobalSession():
config_pb = _GetDefaultConfigProto()
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import os
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12139"
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
import oneflow
import oneflow.experimental as flow
import oneflow.python.framework.session_context as session_ctx
from oneflow.python.framework.multi_client_session import MultiClientSession
class TestMultiClientSession(unittest.TestCase):
def test_case1(self):
# print("test_case1")
self.assertTrue(flow.distributed.is_multi_client())
# print(f"is_multi_client: {flow.distributed.is_multi_client()}")
sess = session_ctx.GetDefaultSession()
# print(f"sess type: {type(sess)}")
self.assertTrue(isinstance(sess, MultiClientSession))
sess.TryInit()
self.assertEqual(sess.status, sess.Status.INITED)
# sess.TryClose()
# self.assertEqual(sess.status, sess.Status.CLOSED)
def test_case2(self):
print("test_case2")
self.assertTrue(flow.distributed.is_multi_client())
sess = session_ctx.GetDefaultSession()
self.assertTrue(isinstance(sess, MultiClientSession))
sess.TryInit()
self.assertEqual(sess.status, sess.Status.INITED)
sess.TryClose()
self.assertEqual(sess.status, sess.Status.CLOSED)
if __name__ == "__main__":
unittest.main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment