diff --git a/oneflow/core/eager/eager_oneflow.cpp b/oneflow/core/eager/eager_oneflow.cpp index 491fd2f1ae455b88d343f24c2a8a3958d1903b12..975153da1c823f818f371cb56ddbcb9008aaf353 100644 --- a/oneflow/core/eager/eager_oneflow.cpp +++ b/oneflow/core/eager/eager_oneflow.cpp @@ -24,6 +24,7 @@ limitations under the License. #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/vm/string_symbol.h" #include "oneflow/core/eager/eager_symbol.cfg.h" +#include "oneflow/core/job/env_desc.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/job/cluster_instruction.h" @@ -34,7 +35,6 @@ limitations under the License. #include "oneflow/core/operator/op_conf_symbol.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/util.h" -#include "oneflow/api/python/env/env.h" namespace oneflow { namespace vm { @@ -94,7 +94,7 @@ Maybe<void> EagerOneflow::RunPhysicalInstruction(vm::InstructionMsgList* instruc Maybe<void> EagerOneflow::RunLogicalInstruction(vm::InstructionMsgList* instruction_list, const vm::cfg::EagerSymbolList& eager_symbol_list) { - if (JUST(IsMultiClient())) { + if (JUST(GlobalMultiClientEnv())) { // NOTE(chengcheng): in Multi-Client LogicalRun will degenerate directly to PhysicalRun, // because each rank will process instructions ONLY from itself, NOT the master. return RunPhysicalInstruction(instruction_list, eager_symbol_list); diff --git a/oneflow/core/framework/instructions_builder.cpp b/oneflow/core/framework/instructions_builder.cpp index 81a0ebd5d3382c49ba73c1d2debe7463f297a348..2901ec2ef15e7c25162a849ae20d6cac6bc307ac 100644 --- a/oneflow/core/framework/instructions_builder.cpp +++ b/oneflow/core/framework/instructions_builder.cpp @@ -39,7 +39,7 @@ limitations under the License. #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/instruction_replay.h" -#include "oneflow/api/python/env/env.h" +#include "oneflow/core/job/env_desc.h" namespace oneflow { @@ -1578,7 +1578,7 @@ InstructionsBuilder::GetMut2OperandBlobObjects( } Maybe<void> LogicalRun(const std::function<Maybe<void>(InstructionsBuilder*)>& Build) { - if (JUST(IsMultiClient())) { + if (JUST(GlobalMultiClientEnv())) { // NOTE(chengcheng): in Multi-Client LogicalRun will degenerate directly to PhysicalRun, // because each rank will process instructions ONLY from itself, NOT the master. return PhysicalRun(Build); diff --git a/oneflow/core/framework/multi_client_session_context.cpp b/oneflow/core/framework/multi_client_session_context.cpp index 8ebca1ebfdccc1185e599ba796c2b26f6ca4130f..4227352769bfee3304bb627fcbc2acbd45a3d9b6 100644 --- a/oneflow/core/framework/multi_client_session_context.cpp +++ b/oneflow/core/framework/multi_client_session_context.cpp @@ -23,7 +23,6 @@ limitations under the License. #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" #include "oneflow/core/common/buffer_manager.h" #include "oneflow/core/rpc/include/global_process_ctx.h" -#include "oneflow/api/python/env/env.h" #ifdef WITH_CUDA #include <cuda.h> #endif // WITH_CUDA @@ -64,7 +63,7 @@ MultiClientSessionContext::~MultiClientSessionContext() { Maybe<void> MultiClientSessionContext::TryInit(const ConfigProto& config_proto) { if (!is_inited_) { - CHECK_OR_RETURN(JUST(IsMultiClient())); + CHECK_OR_RETURN(JUST(GlobalMultiClientEnv())); DumpVersionInfo(); Resource resource = config_proto.resource(); diff --git a/oneflow/core/job/env_desc.cpp b/oneflow/core/job/env_desc.cpp index ecafa5c72fd071e96d480fd6ec1f9e26b9d64597..064e4e13c500fb9a08e7b8ab535db08139c20f69 100644 --- a/oneflow/core/job/env_desc.cpp +++ b/oneflow/core/job/env_desc.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/env_desc.h" +#include "oneflow/core/job/global_for.h" namespace oneflow { @@ -49,4 +50,10 @@ int64_t EnvDesc::GetMachineId(const std::string& addr) const { return machine_id; } +Maybe<bool> GlobalMultiClientEnv() { + Maybe<bool>* is_multi_client = Global<Maybe<bool>, MultiClient>::Get(); + CHECK_NOTNULL_OR_RETURN(is_multi_client); + return *is_multi_client; +} + } // namespace oneflow diff --git a/oneflow/core/job/env_desc.h b/oneflow/core/job/env_desc.h index 4733202aa002b62ad1f74255814d48ef77bcf73d..9389471c4561b6cfe6957f74075ce3fae047455e 100644 --- a/oneflow/core/job/env_desc.h +++ b/oneflow/core/job/env_desc.h @@ -44,6 +44,8 @@ class EnvDesc final { EnvProto env_proto_; }; +Maybe<bool> GlobalMultiClientEnv(); + } // namespace oneflow #endif // ONEFLOW_CORE_JOB_CLUSTER_DESC_H_ diff --git a/oneflow/core/job/job_build_and_infer_ctx_mgr.cpp b/oneflow/core/job/job_build_and_infer_ctx_mgr.cpp index 9133c381c91688f7ecc0ba84f511ff5a8cef7bfd..ca6fbce2d139f38a59915e7cb74d3433e0115589 100644 --- a/oneflow/core/job/job_build_and_infer_ctx_mgr.cpp +++ b/oneflow/core/job/job_build_and_infer_ctx_mgr.cpp @@ -17,6 +17,7 @@ limitations under the License. #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/lazy_mode.h" +#include "oneflow/core/job/env_desc.h" #include "oneflow/core/common/util.h" #include <json.hpp> @@ -105,10 +106,15 @@ Maybe<void> EagerJobBuildAndInferCtxMgr::VirtualCloseJob() { bool EagerExecutionEnabled() { return *Global<bool, EagerExecution>::Get(); } Maybe<JobBuildAndInferCtxMgr*> GlobalJobBuildAndInferCtxMgr() { - if (EagerExecutionEnabled() && !LazyMode::is_enabled()) { - return JUST(GlobalMaybe<EagerJobBuildAndInferCtxMgr>()); - } else { + if (JUST(GlobalMultiClientEnv())) { return JUST(GlobalMaybe<LazyJobBuildAndInferCtxMgr>()); + } else { + // single-client + if (EagerExecutionEnabled()) { + return JUST(GlobalMaybe<EagerJobBuildAndInferCtxMgr>()); + } else { + return JUST(GlobalMaybe<LazyJobBuildAndInferCtxMgr>()); + } } } diff --git a/oneflow/core/vm/id_generator.cpp b/oneflow/core/vm/id_generator.cpp index 731f818b061ae45cf34f7582e7b501da2db4b0a9..ca8a9c1e8556176622fe0e7924f65acef3931dda 100644 --- a/oneflow/core/vm/id_generator.cpp +++ b/oneflow/core/vm/id_generator.cpp @@ -16,13 +16,13 @@ limitations under the License. #include "oneflow/core/vm/id_generator.h" #include "oneflow/core/vm/id_util.h" #include "oneflow/core/control/global_process_ctx.h" -#include "oneflow/api/python/env/env.h" +#include "oneflow/core/job/env_desc.h" namespace oneflow { namespace vm { Maybe<int64_t> LogicalIdGenerator::NewSymbolId() { - if (JUST(IsMultiClient())) { + if (JUST(GlobalMultiClientEnv())) { // NOTE(chengcheng): in Multi-Client LogicalIdGenerator will degenerate directly to // PhysicalIdGenerator, because each rank will generate id ONLY from itself, NOT the master. return IdUtil::NewPhysicalSymbolId(GlobalProcessCtx::Rank()); @@ -32,7 +32,7 @@ Maybe<int64_t> LogicalIdGenerator::NewSymbolId() { } Maybe<int64_t> LogicalIdGenerator::NewObjectId() { - if (JUST(IsMultiClient())) { + if (JUST(GlobalMultiClientEnv())) { // NOTE(chengcheng): in Multi-Client LogicalIdGenerator will degenerate directly to // PhysicalIdGenerator, because each rank will generate id ONLY from itself, NOT the master. return IdUtil::NewPhysicalObjectId(GlobalProcessCtx::Rank()); diff --git a/oneflow/python/test/graph/test_graph.py b/oneflow/python/test/graph/test_graph.py index e64633037ee776292a9b495c3484daf4ff0445e6..816ab199c5484393979860daf1fc82d44c8f605e 100644 --- a/oneflow/python/test/graph/test_graph.py +++ b/oneflow/python/test/graph/test_graph.py @@ -59,10 +59,6 @@ class CustomModule(flow.nn.Module): @flow.unittest.skip_unless_1n1d() -@unittest.skipIf( - not flow.unittest.env.eager_execution_enabled(), - ".numpy() doesn't work in lazy mode", -) class TestGraph(flow.unittest.TestCase): def test_add_nested_module(test_case): x = flow.Tensor(1, 1, 10, 10) diff --git a/oneflow/python/test/graph/test_input_op_expr.py b/oneflow/python/test/graph/test_input_op_expr.py index 22011b046697b6f6e3adc5f98da07eee4bab9241..befbb062495418fbfeb03656c04b5db5b80f7bc9 100644 --- a/oneflow/python/test/graph/test_input_op_expr.py +++ b/oneflow/python/test/graph/test_input_op_expr.py @@ -33,10 +33,6 @@ import oneflow.python.framework.c_api_util as c_api_util @flow.unittest.skip_unless_1n1d() -@unittest.skipIf( - not flow.unittest.env.eager_execution_enabled(), - "default use eager mode to test this case", -) class TestFeedInputTensor(unittest.TestCase): def test_feed_input_tensor(test_case): test_case.assertTrue(oneflow.distributed.is_multi_client()) diff --git a/oneflow/python/test/graph/test_variable_op_expr.py b/oneflow/python/test/graph/test_variable_op_expr.py index 9526666288c3527cfdb20d480910dba429a4d8f4..8e9fcd9e6e06db936a569f32cc7eac40c3c64492 100644 --- a/oneflow/python/test/graph/test_variable_op_expr.py +++ b/oneflow/python/test/graph/test_variable_op_expr.py @@ -33,10 +33,6 @@ import oneflow.python.framework.c_api_util as c_api_util @flow.unittest.skip_unless_1n1d() -@unittest.skipIf( - not flow.unittest.env.eager_execution_enabled(), - "default use eager mode to test this case", -) class TestFeedVariableTensor(unittest.TestCase): def test_feed_var_tensor(test_case): test_case.assertTrue(oneflow.distributed.is_multi_client())