From b0c3d7e2a861fd34c994df0f3df2a5f01cdf5f16 Mon Sep 17 00:00:00 2001
From: cheng cheng <472491134@qq.com>
Date: Fri, 16 Jul 2021 03:38:12 +0800
Subject: [PATCH] LazyInterpret for FeedVariableOpExpr (#5490)

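LazyInterpreter::ApplyImpl for FeedVariableOpExpr re-declares the eager input
tensor as a VariableOpConf in the current lazy job: the op is added to and
inferred by the current JobBuildAndInferCtx, and both the lazy output tensor
and the eager input tensor are recorded in TensorNameScope under the
variable's logical blob name, so later ops can resolve either tensor to this
variable. A unit test covering the new path is added in
oneflow/python/test/graph/test_variable_op_expr.py.
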
Co-authored-by: Xiaoyu Xu <xiaoyulink@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 .../op_interpreter/lazy_op_interpreter.cpp    | 60 ++++++++++++-
 .../test/graph/test_variable_op_expr.py       | 89 +++++++++++++++++++
 2 files changed, 146 insertions(+), 3 deletions(-)
 create mode 100644 oneflow/python/test/graph/test_variable_op_expr.py

diff --git a/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp
index 9ad410d33..785090c9c 100644
--- a/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp
+++ b/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp
@@ -113,9 +113,63 @@ Maybe<void> LazyInterpreter::ApplyImpl(const FeedInputOpExpr& op_expr, const Ten
 
 Maybe<void> LazyInterpreter::ApplyImpl(const FeedVariableOpExpr& op_expr, const TensorTuple& inputs,
                                        TensorTuple* outputs, const OpExprInterpContext& ctx) const {
-  // TODO(chengcheng)
-  OF_UNIMPLEMENTED() << "The type " << op_expr.op_type_name()
-                     << " has not been supported in LazyInterpreter::Apply.";
+  // NOTE(chengcheng): inputs[0] is the EagerTensor
+  CHECK_EQ_OR_RETURN(inputs.size(), 1);
+  CHECK_EQ_OR_RETURN(op_expr.input_size(), 1);
+  const std::shared_ptr<Tensor>& input_tensor = inputs.at(0);
+  CHECK_OR_RETURN(input_tensor->is_eager());
+
+  const auto& scope = JUST(GetCurrentScope());
+  int64_t scope_symbol_id = JUST(scope->symbol_id());
+
+  OperatorConf op_conf;
+  op_conf.set_name(op_expr.op_name());           // constructed by Python nn.Graph
+  op_conf.set_scope_symbol_id(scope_symbol_id);  // TODO(chengcheng): NewScope by cur scope.
+  op_conf.set_device_tag(GetDeviceTagOfTensor(input_tensor));
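+  // The variable op inherits the eager tensor's device tag (e.g. "cpu" or "gpu").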
+  // NOTE(chengcheng):
+  //   We construct a VariableOpConf instead of a FeedVariableOpConf because FeedVariableOpExpr
+  //   exists JUST for getting the input EagerTensor.
+  VariableOpConf* var_conf = op_conf.mutable_variable_conf();
+  var_conf->set_out("out");
+  input_tensor->shape()->ToProto(var_conf->mutable_shape());
+  var_conf->set_data_type(input_tensor->dtype());
+  // NOTE(chengcheng): The VariableOpConf initializer_conf is unused because the variable is
+  //   initialized by the EagerTensor.
+  var_conf->mutable_initializer()->mutable_empty_conf();
+  if (input_tensor->is_consistent()) {
+    // TODO(chengcheng): GenerateParallelDistributionString by tensor.
+  }
+  if (!input_tensor->requires_grad()) { var_conf->set_trainable(false); }
+  // TODO(chengcheng, xuxiaoyu): Set L1/L2 RegularizerConf by nn.Graph Optimizer
+
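+  // NOTE: AddAndInferConsistentOp registers op_conf in the current job and returns the
+  //   inferred OpAttribute (output shape, dtype and parallel signature) used below.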
+  auto infer_ctx = JUST(GetCurInferCtx());
+  OpAttribute op_attr = *JUST(infer_ctx->AddAndInferConsistentOp(op_conf));
+
+  const std::string& op_name = op_conf.name();
+
+  // TODO(chengcheng): remove this temp debug log.
+  std::cout << "cclog: Lazy nn.Graph AddOpName: " << op_name << std::endl
+            << " and the origin op_conf is: " << op_conf.DebugString();
+
+  int64_t parallel_desc_sym_id = JUST(scope->GetParallelDescSymbolId(op_conf));
+  const std::shared_ptr<ParallelDesc>& blob_parallel_desc_sym =
+      JUST(GetSymbol<cfg::ParallelConf, ParallelDesc>(parallel_desc_sym_id));
+
+  // Check outputs num and setup output tensor properties.
+  CHECK_EQ_OR_RETURN(outputs->size(), 1);
+  CHECK_EQ_OR_RETURN(op_expr.output_size(), 1);
+
+  const std::string obn = "out";  // NOTE(chengcheng): obn is NOT op_expr.indexed_obns
+  const auto& parallel_attr =
+      JUST(compatible_py::GetOpArgParallelAttribute(blob_parallel_desc_sym, op_attr, obn));
+  const auto& blob_attr = JUST(compatible_py::GetOpArgBlobAttribute(op_attr, obn));
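+  // parallel_attr and blob_attr carry the inferred placement and (shape, dtype) of "out",
+  // from which the lazy output tensor is built.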
+
+  CHECK_OR_RETURN(!outputs->at(0).get());
+  (*outputs)[0] = JUST(OpInterpUtil::BuildTensor(blob_attr, parallel_attr, /*is_lazy=*/true));
+  // NOTE(chengcheng): Record the variable op's output LazyTensor.
+  TensorNameScope::Global()->Record(outputs->at(0), op_name + "/" + obn);
+  // NOTE(chengcheng): Also record the EagerTensor under the variable tensor name, so ops that
+  //   consume the eager tensor are resolved to this variable's output in the lazy graph.
+  TensorNameScope::Global()->Record(input_tensor, op_name + "/" + obn);
   return Maybe<void>::Ok();
 }
 
diff --git a/oneflow/python/test/graph/test_variable_op_expr.py b/oneflow/python/test/graph/test_variable_op_expr.py
new file mode 100644
index 000000000..952666628
--- /dev/null
+++ b/oneflow/python/test/graph/test_variable_op_expr.py
@@ -0,0 +1,89 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+
+import numpy as np
+import os
+
+os.environ["MASTER_ADDR"] = "127.0.0.1"
+os.environ["MASTER_PORT"] = "12139"
+os.environ["WORLD_SIZE"] = "1"
+os.environ["RANK"] = "0"
+os.environ["LOCAL_RANK"] = "0"
+
+import oneflow
+import oneflow.experimental as flow
+import oneflow.python.framework.session_context as session_ctx
+import oneflow._oneflow_internal
+from oneflow.python.framework.multi_client_session import MultiClientSession
+import oneflow.python.framework.c_api_util as c_api_util
+
+
+@flow.unittest.skip_unless_1n1d()
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    "this test case requires eager execution to be enabled",
+)
+class TestFeedVariableTensor(unittest.TestCase):
+    def test_feed_var_tensor(test_case):
+        test_case.assertTrue(oneflow.distributed.is_multi_client())
+        test_case.assertTrue(
+            oneflow.python.framework.env_util.HasAllMultiClientEnvVars()
+        )
+
+        x = flow.Tensor(1, 1, 10, 10)
+        flow.nn.init.uniform_(x, a=-1.0, b=1.0)
+
+        session = session_ctx.GetDefaultSession()
+        test_case.assertTrue(isinstance(session, MultiClientSession))
+        session.TryInit()
+
+        with oneflow._oneflow_internal.lazy_mode.gard(True):
+
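+            # Mimic what nn.Graph does under the hood: open a JobBuildAndInferCtx,
+            # attach a JobConfigProto, then apply op exprs under lazy mode.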
+            oneflow._oneflow_internal.JobBuildAndInferCtx_Open(
+                "cc_test_variable_op_expr_job"
+            )
+            job_conf = (
+                oneflow._oneflow_internal.oneflow.core.job.job_conf.JobConfigProto()
+            )
+            job_conf.set_job_name("cc_test_variable_op_expr_job")
+            job_conf.mutable_predict_conf()
+            c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)
+
+            op_name = "cc_Variable_0"
+            var_conf = (
+                oneflow._oneflow_internal.oneflow.core.operator.op_conf.FeedVariableOpConf()
+            )
+            var_conf.set_in_0("EagerTensorInput")
+            var_conf.set_out_0("out_0")
+
+            var_op = oneflow._oneflow_internal.one.FeedVariableOpExpr(
+                op_name, var_conf, ["in_0"], ["out_0"]
+            )
+            attrs = oneflow._oneflow_internal.MutableCfgAttrMap()
+
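+            # Materialize the eager tensor and fetch the underlying C++ tensor,
+            # which is what the C++ LazyInterpreter::ApplyImpl consumes.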
+            if not x.is_determined:
+                x.determine()
+            x_tensor_in_c = x._local_or_consistent_tensor
+
+            out_tensor = var_op.apply([x_tensor_in_c], attrs)[0]
+            test_case.assertEqual(out_tensor.shape, (1, 1, 10, 10))
+            test_case.assertTrue(out_tensor.is_lazy)
+            test_case.assertTrue(out_tensor.is_consistent)
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab