diff --git a/oneflow/api/python/framework/nn_graph.cpp b/oneflow/api/python/framework/nn_graph.cpp
index f6fd51b1cace79e57ca06a84aea29369f9583b68..6e63bdd4f1f318b5640ad26358d92394d0e4d3cc 100644
--- a/oneflow/api/python/framework/nn_graph.cpp
+++ b/oneflow/api/python/framework/nn_graph.cpp
@@ -15,7 +15,9 @@ limitations under the License.
 */
 #include <pybind11/pybind11.h>
 #include <string>
+#include "oneflow/api/python/job_build/job_build_and_infer.h"
 #include "oneflow/api/python/of_api_registry.h"
+#include "oneflow/core/framework/tensor.h"
 #include "oneflow/core/framework/nn_graph.h"
 #include "oneflow/core/job/runtime.h"
 #include "oneflow/core/register/blob.h"
@@ -23,9 +25,9 @@ limitations under the License.
 namespace py = pybind11;
 
 namespace oneflow {
-ONEFLOW_API_PYBIND11_MODULE("", m) {
+ONEFLOW_API_PYBIND11_MODULE("nn.graph.", m) {
   using namespace oneflow;
-  py::class_<NNGraph, std::shared_ptr<NNGraph>>(m, "NNGraph")
+  py::class_<NNGraph, std::shared_ptr<NNGraph>>(m, "CNNGraph")
       .def(py::init<const std::string&>())
       .def_property_readonly("name", &NNGraph::job_name)
       .def("register_input_op_names",
@@ -51,5 +53,7 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
                              const std::shared_ptr<NNGraph>& nn_graph) {
     return RunLazyNNGraph(inputs, outputs, parameters, nn_graph).GetOrThrow();
   });
+  m.def("AddTensorAsGraphLoss",
+        [](const std::shared_ptr<one::Tensor>& t) { return AddTensorAsGraphLoss(t).GetOrThrow(); });
 }
 }  // namespace oneflow
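(For reference, a minimal sketch of how the renamed bindings are reached from Python, mirroring the call sites changed later in this diff; the lazy graph-build context that makes AddTensorAsGraphLoss valid is assumed and elided here:

    import oneflow._oneflow_internal as _internal

    # The C++ graph object now lives under the new nn.graph submodule
    # (previously oneflow._oneflow_internal.NNGraph).
    c_graph = _internal.nn.graph.CNNGraph("my_job")

    # Inside a lazy graph build, a lazy tensor is registered as the graph loss via:
    #     _internal.nn.graph.AddTensorAsGraphLoss(lazy_loss_tensor)
)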
diff --git a/oneflow/api/python/job_build/job_build_and_infer.h b/oneflow/api/python/job_build/job_build_and_infer.h
index 05896ac6ea31d89420e1d9f33cc5b7333f61fde6..3d88c27d10980c0fe7a84052f53eed13dc2de409 100644
--- a/oneflow/api/python/job_build/job_build_and_infer.h
+++ b/oneflow/api/python/job_build/job_build_and_infer.h
@@ -18,11 +18,14 @@ limitations under the License.
 
 #include "oneflow/core/job/global_for.h"
 #include "oneflow/core/common/protobuf.h"
+#include "oneflow/core/framework/tensor.h"
+#include "oneflow/core/framework/tensor_name_scope.h"
 #include "oneflow/core/job/job_build_and_infer_ctx.h"
-#include "oneflow/core/record/record.pb.h"
 #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h"
 #include "oneflow/core/job/job.pb.h"
 #include "oneflow/core/job/job_conf.cfg.h"
+#include "oneflow/core/job/lazy_mode.h"
+#include "oneflow/core/record/record.pb.h"
 
 namespace oneflow {
 
@@ -166,6 +169,14 @@ inline Maybe<std::string> JobBuildAndInferCtx_GetOpBlobLbn(const std::string& jo
   return job_ctx->GetOpBlobLbn(op_name, bn_in_op);
 }
 
+inline Maybe<void> AddTensorAsGraphLoss(const std::shared_ptr<one::Tensor>& t) {
+  CHECK_OR_RETURN(t->is_lazy());
+  CHECK_OR_RETURN(LazyMode::is_enabled());
+  const std::string& loss_lbn = one::TensorNameScope::Global()->Lookup(t);
+  CHECK_OR_RETURN("" != loss_lbn);
+  return JUST(GetCurInferCtx())->AddLossLogicalBlobName(loss_lbn);
+}
+
 }  // namespace oneflow
 
 #endif  // ONEFLOW_API_PYTHON_JOB_BUILD_JOB_BUILD_AND_INFER_H_
diff --git a/oneflow/core/framework/tensor.cpp b/oneflow/core/framework/tensor.cpp
index 2f980ebd9e0d0e2839f4126984790a9ba410d122..5024dbde7c8a05902412e8e3006b6f29bde4831c 100644
--- a/oneflow/core/framework/tensor.cpp
+++ b/oneflow/core/framework/tensor.cpp
@@ -14,7 +14,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/tensor.h"
+#include "oneflow/core/common/maybe.h"
 #include "oneflow/core/job/parallel_desc.h"
+#include "oneflow/core/job/job_build_and_infer_ctx_mgr.h"
+#include "oneflow/core/job/job_build_and_infer_ctx.h"
 #include "oneflow/core/framework/device.h"
 #include "oneflow/core/framework/dtype.h"
 #include "oneflow/core/framework/tensor_tuple.h"
diff --git a/oneflow/core/framework/tensor_name_scope.cpp b/oneflow/core/framework/tensor_name_scope.cpp
index ecd84b1140588cf2b831d6f2cbb38ec93ddbb2f3..ff0c484cb814f532fefb82a9fc95c0d3f57f4ec8 100644
--- a/oneflow/core/framework/tensor_name_scope.cpp
+++ b/oneflow/core/framework/tensor_name_scope.cpp
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/tensor_name_scope.h"
+#include <cstdint>
 
 namespace oneflow {
 namespace one {
@@ -24,8 +25,8 @@ namespace one {
 }
 
 const std::string& TensorNameScope::Lookup(const std::shared_ptr<Tensor>& tensor) const {
-  std::lock_guard<std::mutex> lock(mutex_);
   uint64_t key = reinterpret_cast<uint64_t>(tensor.get());
+  std::lock_guard<std::mutex> lock(mutex_);
   const auto& it = tensor_names_.find(key);
   if (it != tensor_names_.end()) {
     return it->second;
diff --git a/oneflow/python/framework/tensor.py b/oneflow/python/framework/tensor.py
index 58230af9a6c02e3ad42d672155cebc19fabde738..9e2cf2466e991dcd0c428f14faecf0e05740af18 100644
--- a/oneflow/python/framework/tensor.py
+++ b/oneflow/python/framework/tensor.py
@@ -19,6 +19,7 @@ from oneflow._oneflow_internal.exception import IndexException
 from oneflow.python.oneflow_export import oneflow_export
 import oneflow.python.framework.remote_blob as remote_blob_util
 import oneflow._oneflow_internal
+import oneflow._oneflow_internal.lazy_mode as lazy_mode
 import numpy as np
 import inspect
 from typing import Union
@@ -397,7 +398,13 @@ class Tensor:
     @_auto_determine
     @register_local_tensor_method()
     def backward(self, gradient=None, retain_graph=False, create_graph=False):
-        flow.autograd.backward(self, gradient, retain_graph, create_graph)
+        if not lazy_mode.is_enabled():
+            flow.autograd.backward(self, gradient, retain_graph, create_graph)
+        else:
+            assert (
+                self.is_lazy
+            ), "nn.Graph only accept lazy tensor to call backward() in lazy mode."
+            flow._oneflow_internal.nn.graph.AddTensorAsGraphLoss(self)
 
     @register_local_tensor_method()
     def _transform_ellipsis_type(self, key):
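(For context, a minimal end-to-end sketch of the lazy path this dispatch enables; it mirrors the new test added at the bottom of this diff, and the names and hyper-parameters are illustrative only:

    import oneflow.experimental as flow

    class Model(flow.nn.Module):
        def __init__(self):
            super().__init__()
            self.w = flow.nn.Parameter(flow.Tensor(1, 4))

        def forward(self, x):
            return x

    m = Model()
    sgd = flow.optim.SGD([{"params": [m.w], "lr": 0.1, "momentum": 0.9, "scale": 1.0}])

    class TrainGraph(flow.nn.Graph):
        def __init__(self):
            super().__init__()
            self.m = m
            self.add_optimizer("sgd", sgd)

        def build(self, x):
            out = self.m(x)
            # out is lazy here, so backward() registers it as the graph loss via
            # AddTensorAsGraphLoss instead of running eager autograd.
            out.backward()
            return out
)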
diff --git a/oneflow/python/nn/graph.py b/oneflow/python/nn/graph.py
index 1f34da4e49f60c98174c4a514aaafdda27be98a1..b3e46e415ee2d089b37c43815a9f29d02308aeba 100644
--- a/oneflow/python/nn/graph.py
+++ b/oneflow/python/nn/graph.py
@@ -15,17 +15,17 @@ limitations under the License.
 """
 from __future__ import absolute_import
 from collections import OrderedDict
+from typing import Dict
 from functools import partial
 
 import oneflow._oneflow_internal
 import oneflow.python.framework.c_api_util as c_api_util
 import oneflow.python.framework.graph_build_util as graph_build_util
 import oneflow.python.framework.session_context as session_ctx
-import oneflow.python.framework.tensor_tuple_util as tensor_tuple_util
 from oneflow._oneflow_internal import Tensor as InternalTensor
 from oneflow.python.oneflow_export import oneflow_export, experimental_api
 from oneflow.python.framework.multi_client_session import MultiClientSession
-from oneflow.python.nn.graph_block import Block
+from oneflow.python.nn.graph_block import Block, BlockType
 from oneflow.python.nn.graph_optimizer import OptimizerConfig
 from oneflow.python.nn.module import Module
 from oneflow.python.nn.optimizer.optimizer import Optimizer
@@ -42,11 +42,11 @@ class Graph(object):
         self.config = GraphConfig()
         self._generate_name()
         self.config.proto.set_job_name(self._name)
-        self._c_nn_graph = oneflow._oneflow_internal.NNGraph(self._name)
+        self._c_nn_graph = oneflow._oneflow_internal.nn.graph.CNNGraph(self._name)
         self._blocks = OrderedDict()
         self._optimizers = OrderedDict()
         self._is_compiled = False
-        self._state_tensortuple = None
+        self._var2var_op_name = dict()
         self._job_proto = None
 
     @property
@@ -72,8 +72,14 @@ class Graph(object):
         grad_clipping_conf=None,
         weight_decay_conf=None,
     ):
+        assert name is not None, "name cannot be None"
+        assert type(name) is str, "name must be an instance of str"
+        assert optimizer is not None, "optimizer cannot be None"
+        assert isinstance(
+            optimizer, Optimizer
+        ), "optimizer must be an instance of Optimizer"
         self._optimizers[name] = OptimizerConfig(
-            optimizer, lr_scheduler, grad_clipping_conf, weight_decay_conf
+            name, optimizer, lr_scheduler, grad_clipping_conf, weight_decay_conf
         )
 
     def _generate_name(self):
@@ -92,18 +98,34 @@ class Graph(object):
             for bu in bu_gen:
                 yield bu
 
+    def _preprocess_state(self):
+        state_list = list()
+        for state_block in self._state():
+            state_list.append(state_block.origin)
+            if state_block.type == BlockType.PARAMETER:
+                self._var2var_op_name[state_block.origin] = (
+                    state_block.name_prefix + state_block.name
+                )
+
+    def _complete_graph_config(self):
+        if len(self._optimizers):
+            self.config._train(True)
+        # TODO(xuxiaoyu): save each variable name and its l2 if the optimizer has
+        # weight decay, which is used as l2.
+        for name, opt_config in self._optimizers.items():
+            self.config.add_optimizer_config(opt_config, self._var2var_op_name)
+
     def _compile(self, *args):
         assert not self._is_compiled, (
             "nn.Graph " + self._name + " has already been compiled."
         )
-        state = tuple(s.origin for s in self._state())
-        if len(state) > 0:
-            self._state_tensortuple = tensor_tuple_util.convert_to_tensor_tuple(state)
+
+        self._preprocess_state()
+        self._complete_graph_config()
 
         session = session_ctx.GetDefaultSession()
         assert type(session) is MultiClientSession
         session.TryInit()
-
         with graph_build_util.graph_build_context(self.config.proto, session):
             # Deal with input
             lazy_args = []
@@ -188,6 +210,8 @@ class Graph(object):
             raise KeyError('module name can\'t contain ".", got: {}'.format(name))
         elif name == "":
             raise KeyError('module name can\'t be empty string ""')
+        # TODO(xuxiaoyu): Add a dict mapping Parameter id to Parameter Block, so a
+        # Parameter Block can be queried by id.
         self._blocks[name] = Block("", name, module)
 
     def __setattr__(self, name: str, value=None):
@@ -195,8 +219,8 @@ class Graph(object):
             self._add_block(name, value)
         elif isinstance(value, Optimizer):
             raise AttributeError(
-                "'{}' object are not allowed to set Optimizer attribute named '{}', \
-                 please use add_optimizer(...) instead.".format(
+                "'{}' object are not allowed to set Optimizer attribute named '{}', "
+                "please use add_optimizer(...) instead.".format(
                     type(self).__name__, name
                 )
             )
@@ -243,14 +267,22 @@ class GraphConfig(FunctionConfig):
 
     @property
     def training(self):
-        if self.function_desc.job_config_proto.has_train_conf():
+        if self.proto.has_train_conf():
             return True
-        if self.function_desc.job_config_proto.has_predict_conf():
+        if self.proto.has_predict_conf():
             return False
         raise NotImplementedError
 
     def _train(self, mode: bool = True):
         if mode:
-            self.function_desc.job_config_proto.mutable_train_conf()
+            self.proto.mutable_train_conf()
+            self.proto.mutable_train_conf().set_loss_scale_factor(1.0)
         else:
-            self.function_desc.job_config_proto.mutable_predict_conf()
+            self.proto.mutable_predict_conf()
+
+    def add_optimizer_config(
+        self, optimizer_config: OptimizerConfig = None, var2var_op_name: Dict = None
+    ):
+        optimizer_config.optimizer.add_to_graph_train_config(
+            self.proto.mutable_train_conf(), var2var_op_name
+        )
diff --git a/oneflow/python/nn/optimizer/sgd.py b/oneflow/python/nn/optimizer/sgd.py
index 08f7d2ef5fd736e689cc627a55d28c3607fd3197..35082cf5f038d7d733ab38bc678c87da17be8881 100644
--- a/oneflow/python/nn/optimizer/sgd.py
+++ b/oneflow/python/nn/optimizer/sgd.py
@@ -16,6 +16,7 @@ limitations under the License.
 
 from typing import List, Dict, Callable, Union, Iterator
 import collections
+import math
 
 import oneflow as flow
 
@@ -135,3 +136,29 @@ class SGD(Optimizer):
 
             self._state["step"] = self._state["step"] + 1
             return loss
+
+    def add_to_graph_train_config(self, train_conf, var2var_op_name_dict):
+        for param_group in self.param_groups:
+            optimizer_conf = train_conf.mutable_optimizer_conf().Add()
+            lr = param_group["lr"]
+            beta = param_group["momentum"]
+            scale = param_group["scale"]
+            # TODO(): optimizer_conf needs a loss_scale_factor field to support multiple scale factors
+            base_scale = train_conf.loss_scale_factor()
+            assert math.isclose(base_scale, 1, rel_tol=1e-4) or math.isclose(
+                scale, base_scale, rel_tol=1e-4
+            ), "nn.Graph only support one scale factor at the moment, base_scale {} vs scale {}".format(
+                base_scale, scale
+            )
+
+            train_conf.set_loss_scale_factor(scale)
+            optimizer_conf.set_base_learning_rate(lr)
+            if beta == 0:
+                optimizer_conf.mutable_naive_conf()
+            else:
+                optimizer_conf.mutable_momentum_conf().set_beta(beta)
+
+            for param in param_group.parameters:
+                if not param.requires_grad:
+                    continue
+                optimizer_conf.add_variable_op_names(var2var_op_name_dict[param])
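(For orientation, roughly the train_conf this produces for a single param group with lr=0.1, momentum=0.2, scale=0.3 and one trainable variable whose op name maps to "m.para0"; this is a text-format sketch inferred from the setters above, and the exact proto layout may differ:

    loss_scale_factor: 0.3
    optimizer_conf {
      base_learning_rate: 0.1
      momentum_conf {
        beta: 0.2
      }
      variable_op_names: "m.para0"
    }
)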
diff --git a/oneflow/python/test/graph/test_forward_graph.py b/oneflow/python/test/graph/test_forward_graph.py
index 552644e921f091112ef47fa42b8e55815b0a1c65..02177e77bb52136b0869c8de34c9db0d989c2241 100644
--- a/oneflow/python/test/graph/test_forward_graph.py
+++ b/oneflow/python/test/graph/test_forward_graph.py
@@ -49,7 +49,7 @@ class CustomModule(flow.nn.Module):
 
 
 @flow.unittest.skip_unless_1n1d()
-class TestGraph(flow.unittest.TestCase):
+class TestForwardGraph(flow.unittest.TestCase):
     def test_forward_graph(test_case):
         class CustomGraph(flow.nn.Graph):
             def __init__(self, module):
diff --git a/oneflow/python/test/graph/test_graph_optimizer.py b/oneflow/python/test/graph/test_graph_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bc78f94d1c7c9b28fca8e0dd7343b651d240a63
--- /dev/null
+++ b/oneflow/python/test/graph/test_graph_optimizer.py
@@ -0,0 +1,95 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import unittest
+import os
+
+import numpy as np
+
+import oneflow
+import oneflow.experimental as flow
+
+
+@flow.unittest.skip_unless_1n1d()
+class TestGraphOptimizer(flow.unittest.TestCase):
+    def test_optimizer(test_case):
+        class CustomModule(flow.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.para0 = flow.nn.Parameter(flow.Tensor(1, 4))
+                self.para1 = flow.nn.Parameter(flow.Tensor(1, 4))
+                self.para2 = flow.nn.Parameter(flow.Tensor(1, 4))
+                self.para2.requires_grad_(False)
+                self.para3 = flow.nn.Parameter(flow.Tensor(1, 4))
+                self.para4 = flow.nn.Parameter(flow.Tensor(1, 4))
+
+            def forward(self, x):
+                return x
+
+        m = CustomModule()
+        learning_rate = 0.1
+        momentum = 0.2
+        scale = 0.3
+        sgd0 = flow.optim.SGD(
+            [
+                {
+                    "params": [m.para0, m.para1, m.para2],
+                    "lr": learning_rate,
+                    "momentum": momentum,
+                    "scale": scale,
+                }
+            ]
+        )
+        sgd1 = flow.optim.SGD(
+            [
+                {
+                    "params": [m.para3],
+                    "lr": learning_rate,
+                    "momentum": momentum,
+                    "scale": scale,
+                },
+                {
+                    "params": [m.para4],
+                    "lr": learning_rate,
+                    "momentum": momentum,
+                    "scale": scale,
+                },
+            ]
+        )
+
+        class CustomGraph0(flow.nn.Graph):
+            def __init__(self):
+                super().__init__()
+                self.m = m
+                self.add_optimizer("sgd0", sgd0)
+                self.add_optimizer("sgd1", sgd1)
+
+            def build(self, x):
+                out = self.m(x)
+                out.backward()
+                return out
+
+        g = CustomGraph0()
+        x = flow.Tensor(1, 1, 10, 10)
+        flow.nn.init.uniform_(x, a=-1.0, b=1.0)
+        z = g._compile(x)
+        print(repr(g))
+        print("g.config.proto: \n", g.config.proto)
+        print("graph proto: \n", g._graph_proto)
+
+
+if __name__ == "__main__":
+    unittest.main()