Unverified commit 0b5dc89b authored by cheng cheng, committed by GitHub

LazyInterpret build LocalTensor if input is local (#5582)


* LazyInterpret build LocalTensor if input is local

* refine graph default scope

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent bd365af1
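In short: `OpInterpUtil::BuildTensor` gains an `is_local` parameter, and each lazy-interpreter call site forwards the locality of its input tensor, so a local (mirrored) input now produces a local output instead of an always-consistent one. A minimal standalone sketch of that propagation, using toy stand-in types rather than the real OneFlow classes:

    #include <iostream>
    #include <memory>

    // Toy stand-ins for oneflow::one::MirroredTensor / ConsistentTensor.
    struct Tensor {
      virtual bool is_local() const = 0;
      virtual ~Tensor() = default;
    };
    struct LocalTensor : Tensor { bool is_local() const override { return true; } };
    struct ConsistentTensor : Tensor { bool is_local() const override { return false; } };

    // Sketch of the new signature: locality is decided by the caller,
    // not inferred from the parallel attribute inside BuildTensor.
    std::shared_ptr<Tensor> BuildTensor(bool is_lazy, bool is_local) {
      (void)is_lazy;  // lazy-ness is orthogonal to locality in this sketch
      if (is_local) { return std::make_shared<LocalTensor>(); }
      return std::make_shared<ConsistentTensor>();
    }

    int main() {
      auto input = std::make_shared<LocalTensor>();
      // The FeedInput/FeedVariable/UserOp paths now forward input->is_local().
      auto output = BuildTensor(/*is_lazy=*/true, /*is_local=*/input->is_local());
      std::cout << std::boolalpha << output->is_local() << std::endl;  // true
    }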
@@ -109,7 +109,8 @@ Maybe<void> LazyInterpreter::ApplyImpl(const FeedInputOpExpr& op_expr, const Ten
   const auto& blob_attr = JUST(compatible_py::GetOpArgBlobAttribute(op_attr, obn));
   CHECK_OR_RETURN(!outputs->at(0).get());
-  (*outputs)[0] = JUST(OpInterpUtil::BuildTensor(blob_attr, parallel_attr, /*is_lazy=*/true));
+  (*outputs)[0] = JUST(OpInterpUtil::BuildTensor(blob_attr, parallel_attr, /* is_lazy= */ true,
+                                                 /* is_local= */ input_tensor->is_local()));
   TensorNameScope::Global()->Record(outputs->at(0), GenLogicalBlobName(op_conf.name(), obn));
   return Maybe<void>::Ok();
 }
@@ -163,7 +164,8 @@ Maybe<void> LazyInterpreter::ApplyImpl(const FeedVariableOpExpr& op_expr, const
   const auto& blob_attr = JUST(compatible_py::GetOpArgBlobAttribute(op_attr, obn));
   CHECK_OR_RETURN(!outputs->at(0).get());
-  (*outputs)[0] = JUST(OpInterpUtil::BuildTensor(blob_attr, parallel_attr, /*is_lazy=*/true));
+  (*outputs)[0] = JUST(OpInterpUtil::BuildTensor(blob_attr, parallel_attr, /* is_lazy= */ true,
+                                                 /* is_local= */ input_tensor->is_local()));
   // NOTE(chengcheng): Record variable op output LazyTensor
   TensorNameScope::Global()->Record(outputs->at(0), GenLogicalBlobName(op_conf.name(), obn));
   // NOTE(chengcheng): Record EagerTensor as variable tensor name
@@ -178,8 +180,6 @@ Maybe<void> LazyInterpreter::ApplyImpl(const FetchOutputOpExpr& op_expr, const T
   CHECK_EQ_OR_RETURN(op_expr.input_size(), 1);
   const std::shared_ptr<Tensor>& input_tensor = inputs.at(0);
   CHECK_OR_RETURN(input_tensor->is_lazy());
-  // NOTE(chengcheng): Lazy always consistent.
-  CHECK_OR_RETURN(input_tensor->is_consistent());
   const std::string& input_lbn = TensorNameScope::Global()->Lookup(input_tensor);
   CHECK_OR_RETURN(!input_lbn.empty());  // lbn must exist.
@@ -223,7 +223,8 @@ Maybe<void> LazyInterpreter::ApplyImpl(const FetchOutputOpExpr& op_expr, const T
   CHECK_OR_RETURN(!outputs->at(0).get());
   // TODO(chengcheng): Build EagerLocalTensor if parallel attr is this rank.
-  (*outputs)[0] = JUST(OpInterpUtil::BuildTensor(blob_attr, parallel_attr, /*is_lazy=*/false));
+  (*outputs)[0] = JUST(OpInterpUtil::BuildTensor(blob_attr, parallel_attr, /* is_lazy= */ false,
+                                                 /* is_local= */ input_tensor->is_local()));
   return Maybe<void>::Ok();
 }
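The fetch-output path follows the same rule: the eager output tensor inherits the lazy input's locality, and the old "lazy always consistent" assertion above is dropped. A self-contained toy model of just the flags involved (hypothetical `Meta`/`FetchOutput` names, not OneFlow's API):

    #include <cassert>

    // Toy model of the fetch-output path: only the two flags matter here.
    struct Meta { bool is_lazy; bool is_local; };

    // Sketch: the output is eager (is_lazy=false) but keeps the input's locality.
    Meta FetchOutput(const Meta& input) {
      assert(input.is_lazy);  // input must be lazy; it may now be local OR consistent
      return Meta{/*is_lazy=*/false, /*is_local=*/input.is_local};
    }

    int main() {
      Meta lazy_local{true, true};
      Meta out = FetchOutput(lazy_local);
      assert(!out.is_lazy && out.is_local);  // eager + local: previously ruled out
    }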
@@ -248,10 +249,12 @@ Maybe<void> LazyInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTu
   // if inputs size == 0, need handle in SourceUserOp impl.
   CHECK_GE_OR_RETURN(inputs.size(), 1);
   const std::string device_tag = GetDeviceTagOfTensor(inputs.at(0));
+  const bool is_local = inputs.at(0)->is_local();
   op_conf->set_device_tag(device_tag);
   for (int i = 0; i < inputs.size(); ++i) {
     const auto& input_tensor = inputs.at(i);
     CHECK_OR_RETURN(device_tag == GetDeviceTagOfTensor(input_tensor));
+    CHECK_EQ_OR_RETURN(is_local, input_tensor->is_local());
     const std::string& ibn = op_expr.indexed_ibns().at(i);
     const std::string& lbn = TensorNameScope::Global()->Lookup(inputs[i]);
     if (lbn.empty()) {
@@ -300,7 +303,8 @@ Maybe<void> LazyInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTu
         JUST(compatible_py::GetOpArgParallelAttribute(blob_parallel_desc_sym, op_attr, obn));
     const auto& blob_attr = JUST(compatible_py::GetOpArgBlobAttribute(op_attr, obn));
     if (!(outputs->at(i).get())) {
-      (*outputs)[i] = JUST(OpInterpUtil::BuildTensor(blob_attr, parallel_attr, /*is_lazy=*/true));
+      (*outputs)[i] = JUST(OpInterpUtil::BuildTensor(blob_attr, parallel_attr,
+                                                     /* is_lazy= */ true, is_local));
     } else {
       // TODO(chengcheng, hjchen2) Reset shape, dtype and so on for InplaceUserOp.
       UNIMPLEMENTED();
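For multi-input user ops, locality must be uniform across all inputs, mirroring the existing device-tag check: the first input fixes `is_local`, and every other input is checked against it. A hedged sketch of that validation loop (hypothetical `TensorMeta` type, not OneFlow's):

    #include <stdexcept>
    #include <string>
    #include <vector>

    // Hypothetical per-tensor metadata; the real code reads these off Tensor.
    struct TensorMeta {
      std::string device_tag;
      bool is_local;
    };

    // Mirrors the loop in LazyInterpreter::ApplyImpl(const UserOpExpr&, ...):
    // the first input fixes device tag and locality, the rest must match.
    void CheckInputsAgree(const std::vector<TensorMeta>& inputs) {
      if (inputs.empty()) throw std::invalid_argument("need at least one input");
      const std::string& device_tag = inputs.front().device_tag;
      const bool is_local = inputs.front().is_local;
      for (const auto& in : inputs) {
        if (in.device_tag != device_tag) throw std::runtime_error("mixed device tags");
        if (in.is_local != is_local) throw std::runtime_error("mixed local/consistent inputs");
      }
    }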
@@ -101,10 +101,10 @@ template<>
 /* static */ Maybe<Tensor> OpInterpUtil::BuildTensor(
     const std::shared_ptr<compatible_py::OpArgBlobAttribute>& blob_attr,
-    const std::shared_ptr<compatible_py::OpArgParallelAttribute>& parallel_attr,
-    const bool is_lazy) {
+    const std::shared_ptr<compatible_py::OpArgParallelAttribute>& parallel_attr, const bool is_lazy,
+    const bool is_local) {
   const auto& dtype = DataType(blob_attr->get_dtype());
-  if (parallel_attr->is_mirrored()) {
+  if (is_local) {
     const auto& device =
         JUST(Device::MakeDeviceByParallelDesc(*parallel_attr->parallel_desc_symbol()));
     const auto& tensor = JUST(MirroredTensor::MakeTensor(
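This hunk carries the core behavioral change: `BuildTensor` used to branch on `parallel_attr->is_mirrored()`, and now branches on the caller-supplied `is_local` flag, decoupling locality from the parallel description. A condensed, runnable sketch of the old versus new dispatch (toy `ParallelAttr` and helper names, for illustration only):

    #include <cassert>

    // Toy parallel attribute: is_mirrored() used to drive the dispatch.
    struct ParallelAttr { bool mirrored; bool is_mirrored() const { return mirrored; } };

    // Old dispatch (sketch): locality tied to the parallel attribute.
    bool BuildsLocalOld(const ParallelAttr& attr) { return attr.is_mirrored(); }

    // New dispatch (sketch): locality is an explicit caller decision.
    bool BuildsLocalNew(const ParallelAttr& /*attr*/, bool is_local) { return is_local; }

    int main() {
      ParallelAttr not_mirrored{false};
      // Old: a non-mirrored parallel attr always produced a consistent tensor.
      assert(!BuildsLocalOld(not_mirrored));
      // New: a local input forces a local output regardless of the attr.
      assert(BuildsLocalNew(not_mirrored, /*is_local=*/true));
    }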
@@ -55,7 +55,7 @@ class OpInterpUtil {
   static Maybe<Tensor> BuildTensor(
       const std::shared_ptr<compatible_py::OpArgBlobAttribute>& blob_attr,
      const std::shared_ptr<compatible_py::OpArgParallelAttribute>& parallel_attr,
-      const bool is_lazy);
+      const bool is_lazy, const bool is_local);
 };

 }  // namespace one
@@ -36,10 +36,10 @@ lazy_mode = oneflow._oneflow_internal.lazy_mode
 @contextmanager
 def graph_build_context(config_proto, session):
     prev_scope = oneflow._oneflow_internal.GetCurrentScope()
-    device_tag_and_ids = placement_util.GetDefaultMachineDeviceIds(session.resource)
     new_scope = scope_util.MakeInitialScope(
         config_proto,
-        *device_tag_and_ids,
+        "cpu",  # NOTE(chengcheng): graph init scope is useless, just set cpu 0:0 for test.
+        ["0:0"],
         None,  # TODO(): set hierarchy from user graph config
         False,  # is_mirrored
     )
@@ -72,7 +72,7 @@ class TestFeedInputTensor(unittest.TestCase):
         out_tensor = input_op.apply([x_tensor_in_c], attrs)[0]
         test_case.assertEqual(out_tensor.shape, (1, 1, 10, 10))
         test_case.assertTrue(out_tensor.is_lazy)
-        test_case.assertTrue(out_tensor.is_consistent)
+        test_case.assertTrue(out_tensor.is_local)

 if __name__ == "__main__":
@@ -82,12 +82,12 @@ class TestFetchOutputTensor(unittest.TestCase):
         lazy_tensor = input_op.apply([x_tensor_in_c], attrs)[0]
         test_case.assertEqual(lazy_tensor.shape, (1, 1, 10, 10))
         test_case.assertTrue(lazy_tensor.is_lazy)
-        test_case.assertTrue(lazy_tensor.is_consistent)
+        test_case.assertTrue(lazy_tensor.is_local)
         eager_tensor = output_op.apply([lazy_tensor], attrs)[0]
         test_case.assertEqual(eager_tensor.shape, (1, 1, 10, 10))
         test_case.assertTrue(not eager_tensor.is_lazy)
-        test_case.assertTrue(eager_tensor.is_consistent)
+        test_case.assertTrue(eager_tensor.is_local)

 if __name__ == "__main__":
@@ -72,7 +72,7 @@ class TestFeedVariableTensor(unittest.TestCase):
         out_tensor = var_op.apply([x_tensor_in_c], attrs)[0]
         test_case.assertEqual(out_tensor.shape, (1, 1, 10, 10))
         test_case.assertTrue(out_tensor.is_lazy)
-        test_case.assertTrue(out_tensor.is_consistent)
+        test_case.assertTrue(out_tensor.is_local)

 if __name__ == "__main__":