Skip to content
Snippets Groups Projects
Unverified Commit 3563ba46 authored by Li Xinqi's avatar Li Xinqi Committed by GitHub
Browse files

Dev inplace obn graph (#1868)

* InplaceObnGraph

* more checks in InplaceObnGraph::InitNodes

* framework of InplaceObnGraph::ComputeSafeInplaceObns

* refine InplaceObnGraph::ComputeSafeInplaceObns

* replace InplaceObnGraph with InplaceLbiGraph

* fix three types of mut_ref conflicts

* InplaceLbiGraph::FindFirstConstRefConflictMutRefEdge

* fix bugs in InplaceLbiGraph::ComputeSafeInplaceObns

* InplaceLbiGraph::DisconnectUnReachabeAndDataMutableEdge

* InplaceLbiGraph::FixMutRefConflictsFromSourceOpNode

* InplaceLbiGraph::FixMutRefConflictsFromSourceOpNode

* Graph::FindFirstBackEdgeDstNode

* more CHECK_ISNULL

* fix a bug in Graph::FindFirstBackEdgeDstNode()

a

* fix bugs in Graph<NodeType, EdgeType>::ForEachConnectedComponent

* rename GetIsMutableIbnConsumer => FindSoleMutableIbnConsumer

* refine InplaceLbiGraph::IsConstRefConflictMutRefNode

* there could be no mut_ref node found in InplaceLbiGraph::FindFirstInterOpRefConflictMutRefEdge

* refine InplaceLbiGraph::FindFirstInterOpRefConflictMutRefEdge

* remove unnecessary CHECK in InplaceLbiGraph::GetSafeInplaceObnEdges

* fix a line of comment in InplaceLbiGraph::GetSafeInplaceObnEdges

* shouldn't delete the edge to updt_node

* refine InplaceLbiGraph::FixMutRefConflictsFromSourceOpNode

* refine FindFirstIntraOpRefConflictMutRefEdge

* fix a bug in InplaceLbiGraph::FindFirstIntraOpRefConflictMutRefEdge

* CheckSubGraph

* change some lambdas to functions
parent 606d8446
No related branches found
No related tags found
No related merge requests found
......@@ -217,7 +217,7 @@ int Actor::HandlerNormal(const ActorMsg& msg) {
} else if (inplace_consumed_rs_.HasRegstDescId(regst->regst_desc_id())) {
CHECK_EQ(0, inplace_consumed_rs_.TryPushBackRegst(regst));
CHECK(regst->packed_blob()->dptr()
== inplace_produced_rs_.Front(regst->regst_desc_id())->packed_blob()->dptr());
== inplace_produced_rs_.Front(regst->regst_desc_id())->packed_blob()->dptr());
} else if (TryUpdtStateAsProducedRegst(regst) == 0) {
// do nothing
} else {
......@@ -236,12 +236,14 @@ int Actor::HandlerNormal(const ActorMsg& msg) {
UNIMPLEMENTED();
}
// handler halts
bool naive_exist_and_eord_and_empty = is_naive_consumed_eord_
&& naive_consumed_rs_.available_regst_desc_cnt() == 0;
bool naive_not_exist_and_customized_eord = naive_consumed_rs_.total_regst_desc_cnt() == 0
&& IsCustomizedReadAlwaysUnReadyFromNow();
bool inplace_condition = inplace_consumed_rs_.total_regst_desc_cnt() != 0 ?
(is_inplace_consumed_eord_ && inplace_consumed_rs_.available_regst_desc_cnt() == 0) : (true);
bool naive_exist_and_eord_and_empty =
is_naive_consumed_eord_ && naive_consumed_rs_.available_regst_desc_cnt() == 0;
bool naive_not_exist_and_customized_eord =
naive_consumed_rs_.total_regst_desc_cnt() == 0 && IsCustomizedReadAlwaysUnReadyFromNow();
bool inplace_condition =
inplace_consumed_rs_.total_regst_desc_cnt() != 0
? (is_inplace_consumed_eord_ && inplace_consumed_rs_.available_regst_desc_cnt() == 0)
: (true);
if (inplace_condition
&& (naive_exist_and_eord_and_empty || naive_not_exist_and_customized_eord)) {
CHECK_EQ(naive_consumed_rs_.available_regst_desc_cnt(), 0);
......@@ -330,7 +332,7 @@ void Actor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {
}
void Actor::AsyncSendInplaceProducedRegstMsgToConsumer() {
std::vector<int64_t>regst_desc_ids;
std::vector<int64_t> regst_desc_ids;
inplace_produced_rs_.ForEachFrontRegst([&](Regst* regst) {
CHECK(regst->regst_desc()->regst_desc_type().has_data_regst_desc());
int64_t real_consumer_cnt = HandleRegstToConsumer(regst, [](int64_t) { return true; });
......@@ -402,14 +404,14 @@ int64_t Actor::HandleRegstToConsumer(Regst* regst, std::function<bool(int64_t)>
return real_consumer_cnt;
}
bool Actor::IsReadReady() { return naive_consumed_rs_.IsCurSlotReady()
&& inplace_consumed_rs_.IsCurSlotReady()
&& IsCustomizedReadReady(); }
bool Actor::IsReadReady() {
return naive_consumed_rs_.IsCurSlotReady() && inplace_consumed_rs_.IsCurSlotReady()
&& IsCustomizedReadReady();
}
bool Actor::IsWriteReady() {
return naive_produced_rs_.IsCurSlotReady()
&& inplace_produced_rs_.IsCurSlotReady()
&& IsCustomizedWriteReady();
return naive_produced_rs_.IsCurSlotReady() && inplace_produced_rs_.IsCurSlotReady()
&& IsCustomizedWriteReady();
}
void Actor::AsyncLaunchKernel(const KernelCtx& kernel_ctx,
......
......@@ -180,7 +180,6 @@ class Actor {
}
virtual void AsyncSendCustomizedConsumedRegstMsgToProducer() {}
int64_t actor_id_;
int64_t act_id_;
std::unique_ptr<ParallelContext> parallel_ctx_;
......
......@@ -28,6 +28,8 @@
DECLARE_string(log_dir);
#define CHECK_ISNULL(e) CHECK((e) == nullptr)
namespace std {
template<typename T0, typename T1>
struct hash<std::pair<T0, T1>> {
......
......@@ -26,6 +26,11 @@ class Graph {
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachNext,
const std::function<void(NodeType*)>& Handler) const;
void DfsForEachNode(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachNext,
const std::function<void(NodeType*)>& Handler) const;
void TopoForEachNode(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
......@@ -52,6 +57,21 @@ class Graph {
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode)
const;
void ForEachConnectedComponent(
const std::function<void(const HashSet<NodeType*>&)>& Handler) const;
void ForEachConnectedComponent(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachConnected,
const std::function<void(const HashSet<NodeType*>&)>& Handler) const;
NodeType* FindFirstBackEdgeDstNode(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachNext)
const;
NodeType* FindFirstBackEdgeDstNode() const;
// Getters
std::list<NodeType*> source_nodes() const;
std::list<NodeType*> sink_nodes() const;
......@@ -78,6 +98,16 @@ class Graph {
void ToDotWithAutoFilePath();
private:
void ForEachConnectedComponent(
const std::function<void(const std::function<void(NodeType*)>&)>& ForEachPotentialStart,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachConnected,
const std::function<void(const HashSet<NodeType*>&)>& Handler) const;
NodeType* FindFirstBackEdgeDstNode(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachNext,
size_t* node_cnt) const;
std::vector<std::unique_ptr<NodeType>> nodes_;
std::vector<std::unique_ptr<EdgeType>> edges_;
};
......@@ -232,6 +262,74 @@ void Graph<NodeType, EdgeType>::BfsForEachNode(
}
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::DfsForEachNode(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachNext,
const std::function<void(NodeType*)>& Handler) const {
HashSet<NodeType*> visited_nodes;
std::stack<NodeType*> stack;
for (NodeType* start : starts) { stack.push(start); }
while (!stack.empty()) {
NodeType* cur_node = stack.top();
stack.pop();
if (visited_nodes.find(cur_node) == visited_nodes.end()) {
Handler(cur_node);
visited_nodes.insert(cur_node);
ForEachNext(cur_node, [&](NodeType* next) {
if (visited_nodes.find(next) == visited_nodes.end()) { stack.push(next); }
});
}
}
}
template<typename NodeType, typename EdgeType>
NodeType* Graph<NodeType, EdgeType>::FindFirstBackEdgeDstNode() const {
if (nodes_.empty()) { return nullptr; }
const auto& starts = source_nodes();
if (starts.empty()) { return nodes_.at(0).get(); }
size_t node_cnt = 0;
auto ForEachNext = &NodeType::ForEachNodeOnOutEdge;
NodeType* ret = FindFirstBackEdgeDstNode(starts, ForEachNext, &node_cnt);
if (ret == nullptr && node_cnt != nodes_.size()) {
HashSet<NodeType*> visited_nodes;
BfsForEachNode(starts, ForEachNext, [&](NodeType* node) { visited_nodes.emplace(node); });
for (const auto& node : nodes_) {
if (visited_nodes.find(node.get()) == visited_nodes.end()) { return node.get(); }
}
}
return ret;
}
template<typename NodeType, typename EdgeType>
NodeType* Graph<NodeType, EdgeType>::FindFirstBackEdgeDstNode(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachNext)
const {
size_t node_cnt = 0;
return FindFirstBackEdgeDstNode(starts, ForEachNext, &node_cnt);
}
template<typename NodeType, typename EdgeType>
NodeType* Graph<NodeType, EdgeType>::FindFirstBackEdgeDstNode(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachNext,
size_t* node_cnt) const {
NodeType* back_edge_dst_node = nullptr;
HashSet<NodeType*> visited_nodes;
*node_cnt = 0;
DfsForEachNode(starts, ForEachNext, [&](NodeType* node) {
++*node_cnt;
if (back_edge_dst_node != nullptr) { return; }
visited_nodes.emplace(node);
ForEachNext(node, [&](NodeType* next_node) {
if (back_edge_dst_node != nullptr) { return; }
if (visited_nodes.find(next_node) != visited_nodes.end()) { back_edge_dst_node = next_node; }
});
});
return back_edge_dst_node;
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::TopoForEachNode(
const std::list<NodeType*>& starts,
......@@ -268,23 +366,25 @@ void Graph<NodeType, EdgeType>::DfsTopoForEachNodeSortByDistanceToSink(
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
const std::function<void(NodeType*)>& Handler) const {
std::list<NodeType*> nodes;
TopoForEachNode(starts, ForEachInNode, ForEachOutNode,
[&](NodeType* node) { nodes.push_back(node); });
std::list<NodeType*> sinks;
for (NodeType* node : nodes) {
bool is_sink = true;
ForEachOutNode(node, [&](NodeType* out_node) { is_sink = false; });
if (is_sink) { sinks.push_back(node); }
}
HashMap<NodeType*, int64_t> node2distance_to_sink;
TopoForEachNode(sinks, ForEachOutNode, ForEachInNode, [&](NodeType* node) {
int64_t distance_to_sink = -1;
ForEachOutNode(node, [&](NodeType* out_node) {
distance_to_sink = std::max(distance_to_sink, node2distance_to_sink[out_node]);
{
std::list<NodeType*> nodes;
TopoForEachNode(starts, ForEachInNode, ForEachOutNode,
[&](NodeType* node) { nodes.push_back(node); });
std::list<NodeType*> sinks;
for (NodeType* node : nodes) {
bool is_sink = true;
ForEachOutNode(node, [&](NodeType* out_node) { is_sink = false; });
if (is_sink) { sinks.push_back(node); }
}
TopoForEachNode(sinks, ForEachOutNode, ForEachInNode, [&](NodeType* node) {
int64_t distance_to_sink = -1;
ForEachOutNode(node, [&](NodeType* out_node) {
distance_to_sink = std::max(distance_to_sink, node2distance_to_sink[out_node]);
});
node2distance_to_sink[node] = distance_to_sink + 1;
});
node2distance_to_sink[node] = distance_to_sink + 1;
});
}
auto ForEachOutNodeSortedByDistanceToSink = [&](NodeType* node,
const std::function<void(NodeType*)>& Handler) {
std::vector<NodeType*> out_nodes;
......@@ -352,6 +452,44 @@ Graph<NodeType, EdgeType>::MakePredicatorIsReachable(
};
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ForEachConnectedComponent(
const std::function<void(const HashSet<NodeType*>&)>& Handler) const {
ForEachConnectedComponent(
[&](const std::function<void(NodeType*)>& Handler) { ForEachNode(Handler); },
&NodeType::ForEachNodeOnInOutEdge, Handler);
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ForEachConnectedComponent(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachConnected,
const std::function<void(const HashSet<NodeType*>&)>& Handler) const {
auto ForEachPotentialStart = [&](const std::function<void(NodeType*)>& Handler) {
BfsForEachNode(starts, ForEachConnected, Handler);
};
ForEachConnectedComponent(ForEachPotentialStart, ForEachConnected, Handler);
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ForEachConnectedComponent(
const std::function<void(const std::function<void(NodeType*)>&)>& ForEachPotentialStart,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachConnected,
const std::function<void(const HashSet<NodeType*>&)>& Handler) const {
HashMap<NodeType*, int32_t> node2component_id;
int32_t cur_component_id = 0;
ForEachPotentialStart([&](NodeType* start) {
if (node2component_id.find(start) != node2component_id.end()) { return; }
++cur_component_id;
BfsForEachNode({start}, ForEachConnected, [&](NodeType* node) {
CHECK(node2component_id.emplace(node, cur_component_id).second);
});
});
HashMap<int32_t, HashSet<NodeType*>> component_id2nodes;
for (const auto& pair : node2component_id) { component_id2nodes[pair.second].insert(pair.first); }
for (const auto& pair : component_id2nodes) { Handler(pair.second); }
}
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_GRAPH_H_
This diff is collapsed.
#ifndef ONEFLOW_CORE_GRAPH_INPLACE_LBI_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_INPLACE_LBI_GRAPH_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/graph/graph.h"
namespace oneflow {
class InplaceLbiEdge;
class InplaceLbiNode : public Node<InplaceLbiNode, InplaceLbiEdge> {
public:
virtual ~InplaceLbiNode() = default;
const LogicalBlobId& lbi() const { return lbi_; }
const InplaceLbiEdge* GetValidInEdge(
const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) const;
const InplaceLbiEdge* GetSoleValidInEdge(
const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) const;
void ForEachNodeOnValidOutEdge(const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,
const std::function<void(const InplaceLbiNode*)>& Handler) const;
virtual bool IsMutRef(const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) const;
bool IsConstRef(const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) const;
protected:
OF_DISALLOW_COPY_AND_MOVE(InplaceLbiNode);
explicit InplaceLbiNode(const LogicalBlobId& lbi) : lbi_(lbi) {}
private:
LogicalBlobId lbi_;
};
class NormalInplaceLbiNode final : public InplaceLbiNode {
public:
OF_DISALLOW_COPY_AND_MOVE(NormalInplaceLbiNode);
explicit NormalInplaceLbiNode(const LogicalBlobId& lbi) : InplaceLbiNode(lbi) {}
~NormalInplaceLbiNode() override = default;
bool IsMutRef(const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) const override;
};
class SourceOpInplaceLbiNode final : public InplaceLbiNode {
public:
OF_DISALLOW_COPY_AND_MOVE(SourceOpInplaceLbiNode);
explicit SourceOpInplaceLbiNode(const LogicalBlobId& lbi) : InplaceLbiNode(lbi) {}
~SourceOpInplaceLbiNode() = default;
};
class UpdateInplaceLbiNode final : public InplaceLbiNode {
public:
OF_DISALLOW_COPY_AND_MOVE(UpdateInplaceLbiNode);
explicit UpdateInplaceLbiNode(const LogicalBlobId& lbi) : InplaceLbiNode(lbi) {}
~UpdateInplaceLbiNode() = default;
};
class InplaceLbiEdge final : public Edge<InplaceLbiNode, InplaceLbiEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(InplaceLbiEdge);
InplaceLbiEdge(const Operator* op, const std::string& ibn, const std::string& obn)
: op_(op), ibn_(ibn), obn_(obn) {}
~InplaceLbiEdge() = default;
const Operator& op() const { return *op_; }
const std::string& ibn() const { return ibn_; }
const std::string& obn() const { return obn_; }
bool IsMutRef() const;
bool IsConstRef() const { return !IsMutRef(); }
private:
const Operator* op_;
const std::string ibn_;
const std::string obn_;
};
class InplaceLbiGraph final : public Graph<const InplaceLbiNode, const InplaceLbiEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(InplaceLbiGraph);
InplaceLbiGraph(const OpBlobArgList& obas,
const std::function<const Operator*(const std::string&)>& Op4OpName) {
Init(obas, Op4OpName);
}
~InplaceLbiGraph() = default;
void ComputeSafeInplaceObns(OpBlobArgList* obas,
const std::function<bool(const LogicalBlobId&, const std::string&)>&
IsReachableFromLbiToOpName) const;
private:
void Init(const OpBlobArgList& obas,
const std::function<const Operator*(const std::string&)>& Op4OpName);
std::function<InplaceLbiNode*(const LogicalBlobId&)> MakeMutFindOrCreateNode(
std::function<const Operator*(const std::string&)> Op4OpName);
void ComputeSafeInplaceObns(const std::function<bool(const LogicalBlobId&, const std::string&)>&
IsReachableFromLbiToOpName,
const std::function<void(const InplaceLbiEdge*)>& Handler) const;
void ComputeSafeInplaceObns(const HashSet<const InplaceLbiNode*>& nodes,
const std::function<bool(const LogicalBlobId&, const std::string&)>&
IsReachableFromLbiToOpName,
const std::function<void(const InplaceLbiEdge*)>& Handler) const;
void GetSafeInplaceObnEdges(const HashSet<const InplaceLbiNode*>& nodes,
const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,
const std::function<bool(const LogicalBlobId&, const std::string&)>&
IsReachableFromLbiToOpName,
HashSet<const InplaceLbiEdge*>* cur_disabled_edges) const;
const InplaceLbiEdge* FindFirstConstRefConflictMutRefEdge(
const HashSet<const InplaceLbiNode*>& nodes,
const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,
const std::function<bool(const LogicalBlobId&, const std::string&)>&
IsReachableFromLbiToOpName) const;
const InplaceLbiEdge* FindFirstIntraOpRefConflictMutRefEdge(
const HashSet<const InplaceLbiNode*>& nodes,
const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge) const;
const InplaceLbiEdge* FindFirstInterOpRefConflictMutRefEdge(
const HashSet<const InplaceLbiNode*>& nodes,
const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,
const std::function<bool(const LogicalBlobId&, const std::string&)>&
IsReachableFromLbiToOpName) const;
bool IsConstRefConflictMutRefNode(
const InplaceLbiNode* mut_ref_node, const HashSet<const InplaceLbiNode*>& nodes,
const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,
const std::function<bool(const LogicalBlobId&, const std::string&)>&
IsReachableFromLbiToOpName) const;
void FixConstRefOrMutRefConflictsToUpdtNode(
const HashSet<const InplaceLbiNode*>& nodes,
const std::function<bool(const LogicalBlobId&, const std::string&)>&
IsReachableFromLbiToOpName,
HashSet<const InplaceLbiEdge*>* cur_disabled_edges) const;
void FixMutRefConflictsFromSourceOpNode(
const SourceOpInplaceLbiNode* root,
const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,
HashSet<const InplaceLbiEdge*>* cur_disabled_edges) const;
void ForEachTree(const HashSet<const InplaceLbiNode*>& nodes,
const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,
const std::function<void(const HashSet<const InplaceLbiNode*>&)>& Handler) const;
void FindAllEdges(const HashSet<const InplaceLbiNode*>& nodes,
const std::function<bool(const InplaceLbiEdge*)>& IsValidEdge,
HashSet<const InplaceLbiEdge*>* cur_disabled_edges) const;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_INPLACE_LBI_GRAPH_H_
......@@ -265,6 +265,7 @@ void OpGraph::Init(const Job& job) {
ForEachNode(
[&](OpNode* node) { CHECK(op_name2op_node_.emplace(node->op().op_name(), node).second); });
InitEdges();
CHECK_ISNULL(FindFirstBackEdgeDstNode());
FixOpParallelDesc();
UpdateOpNodeHasInDiff();
InferTimeShape();
......@@ -541,26 +542,8 @@ void OpGraph::ForEachChainFamily(
if (Is121Edge(edge)) { Handler(edge->dst_node()); }
}
};
ForEachComponent(ForEachConnectedWithSameSbp7ParallelDesc7TimeShape, Handler);
}
void OpGraph::ForEachComponent(
const std::function<void(OpNode*, const std::function<void(OpNode*)>&)>& ForEachConnected,
const std::function<void(const HashSet<OpNode*>&)>& Handler) const {
HashMap<OpNode*, int32_t> op_node2component_id;
int32_t cur_component_id = 0;
ForEachNode([&](OpNode* start) {
if (op_node2component_id.find(start) != op_node2component_id.end()) { return; }
++cur_component_id;
BfsForEachNode({start}, ForEachConnected, [&](OpNode* node) {
CHECK(op_node2component_id.emplace(node, cur_component_id).second);
});
});
HashMap<int32_t, HashSet<OpNode*>> component_id2op_nodes;
for (const auto& pair : op_node2component_id) {
component_id2op_nodes[pair.second].insert(pair.first);
}
for (const auto& pair : component_id2op_nodes) { Handler(pair.second); }
ForEachConnectedComponent(source_nodes(), ForEachConnectedWithSameSbp7ParallelDesc7TimeShape,
Handler);
}
void OpGraph::ForEachPseudoChain(
......
......@@ -130,6 +130,9 @@ class OpGraph final : public Graph<OpNode, OpEdge> {
// a set of nodes is called a chain family if they can divided into several connected chains
void ForEachChainFamily(const std::function<void(const HashSet<OpNode*>&)>& Handler) const;
void ForEachDataAndCtrlInNode(OpNode* node, const std::function<void(OpNode*)>& Handler) const;
void ForEachDataAndCtrlOutNode(OpNode* node, const std::function<void(OpNode*)>& Handler) const;
private:
void Init(const Job& job);
void InitNodes(const Job& job);
......@@ -149,9 +152,6 @@ class OpGraph final : public Graph<OpNode, OpEdge> {
void ReverseTopoGetPseudoChain(
const HashSet<OpNode*>& op_nodes, HashSet<OpNode*>* chain,
const std::function<bool(OpNode* src, OpNode* dst)>& IsReachable) const;
void ForEachComponent(
const std::function<void(OpNode*, const std::function<void(OpNode*)>&)>& ForEachConnected,
const std::function<void(const HashSet<OpNode*>&)>& Handler) const;
int64_t GetSplitNum(const std::string& op_name, const LogicalBlobId& lbi) const;
HashMap<std::string, OpNode*> op_name2op_node_;
......
......@@ -13,6 +13,10 @@ message InputBlobModifier {
message OutputBlobModifier {
optional bool is_mutable = 1 [default = false];
optional bool requires_grad = 2 [default = false];
oneof inplace_type {
string mutable_inplace_ibn = 3;
string const_inplace_ibn = 4;
}
}
message OpAttribute {
......
......@@ -15,3 +15,7 @@ message OpBlobArgPair {
message OpBlobArgPairs {
repeated OpBlobArgPair pair = 1;
}
message OpBlobArgList {
repeated OpBlobArg oba = 1;
}
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment