Skip to content

Commit a32ce8c

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into port_pybind11
2 parents 507f479 + 5d2834f commit a32ce8c

File tree

10 files changed

+736
-7
lines changed

10 files changed

+736
-7
lines changed

doc/fluid/new_docs/advanced_usage/deploy/native_infer.rst

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Paddle 预测 API
44
为了更简单方便的预测部署,Fluid 提供了一套高层 API
55
用来隐藏底层不同的优化实现。
66

7-
`预测库相关代码 <https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/contrib/inference>`__
7+
`预测库相关代码 <https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/fluid/inference/api>`__
88
包括
99

1010
- 头文件 ``paddle_inference_api.h`` 定义了所有的接口
@@ -104,5 +104,5 @@ engine
104104
------------
105105

106106
- `inference
107-
demos <https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/contrib/inference/demo>`__
108-
- `复杂单线程/多线程例子 <https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/contrib/inference/test_paddle_inference_api_impl.cc>`__
107+
demos <https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/fluid/inference/api/demo_ci>`__
108+
- `复杂单线程/多线程例子 <https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/inference/api/api_impl_tester.cc>`__

paddle/fluid/framework/ir/CMakeLists.txt

+3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ cc_library(graph SRCS graph.cc DEPS node)
33
cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
44
cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
55
cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper)
6+
cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
7+
cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits)
68

79
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
810
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
911
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
12+
cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <string>
16+
#include <vector>
17+
18+
#include "paddle/fluid/framework/ir/graph_helper.h"
19+
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
20+
#include "paddle/fluid/framework/ir/graph_traits.h"
21+
#include "paddle/fluid/platform/enforce.h"
22+
23+
namespace paddle {
24+
namespace framework {
25+
namespace ir {
26+
27+
PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
28+
nodes_.emplace_back(new PDNode(std::move(teller), name));
29+
auto* cur = nodes_.back().get();
30+
return cur;
31+
}
32+
33+
void PDPattern::AddEdge(PDNode* a, PDNode* b) {
34+
PADDLE_ENFORCE(a);
35+
PADDLE_ENFORCE(b);
36+
PADDLE_ENFORCE(a != b, "can't connect to the same nodes.");
37+
edges_.emplace_back(a, b);
38+
}
39+
40+
void GraphPatternDetecter::operator()(Graph* graph,
41+
GraphPatternDetecter::handle_t handler) {
42+
if (!MarkPDNodesInGraph(*graph)) return;
43+
auto subgraphs = DetectPatterns();
44+
UniquePatterns(&subgraphs);
45+
RemoveOverlappedMatch(&subgraphs);
46+
47+
for (auto& g : subgraphs) {
48+
handler(g, graph);
49+
}
50+
}
51+
52+
bool GraphPatternDetecter::MarkPDNodesInGraph(const ir::Graph& graph) {
53+
if (graph.Nodes().empty()) return false;
54+
55+
for (auto& node : GraphTraits::DFS(graph)) {
56+
for (const auto& pdnode : pattern_.nodes()) {
57+
if (pdnode->Tell(&node)) {
58+
pdnodes2nodes_[pdnode.get()].insert(&node);
59+
}
60+
}
61+
}
62+
return !pdnodes2nodes_.empty();
63+
}
64+
65+
struct HitGroup {
66+
std::unordered_map<PDNode*, Node*> roles;
67+
68+
bool Match(Node* node, PDNode* pat) {
69+
return !roles.count(pat) || roles.at(pat) == node;
70+
}
71+
72+
void Register(Node* node, PDNode* pat) { roles[pat] = node; }
73+
};
74+
75+
// Tell whether Node a links to b.
76+
bool IsNodesLink(Node* a, Node* b) {
77+
for (auto* node : a->outputs) {
78+
if (b == node) {
79+
return true;
80+
}
81+
}
82+
return false;
83+
}
84+
85+
std::vector<GraphPatternDetecter::subgraph_t>
86+
GraphPatternDetecter::DetectPatterns() {
87+
// Init empty subgraphs.
88+
std::vector<GraphPatternDetecter::subgraph_t> result;
89+
std::vector<HitGroup> init_groups;
90+
PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
91+
auto* first_pnode = pattern_.edges().front().first;
92+
if (!pdnodes2nodes_.count(first_pnode)) return result;
93+
for (auto* node : pdnodes2nodes_[first_pnode]) {
94+
HitGroup group;
95+
group.roles[first_pnode] = node;
96+
init_groups.emplace_back(group);
97+
}
98+
99+
int step = 0;
100+
std::array<std::vector<HitGroup>, 2> bi_records;
101+
bi_records[0] = std::move(init_groups);
102+
103+
// Extend a PDNode to subgraphs by deducing the connection relations defined
104+
// in edges of PDNodes.
105+
for (const auto& edge : pattern_.edges()) {
106+
// Each role has two PDNodes, which indicates two roles.
107+
// Detect two Nodes that can match these two roles and they are connected.
108+
auto& pre_groups = bi_records[step % 2];
109+
auto& cur_groups = bi_records[1 - (step++ % 2)];
110+
cur_groups.clear();
111+
// source -> target
112+
for (Node* source : pdnodes2nodes_[edge.first]) {
113+
for (Node* target : pdnodes2nodes_[edge.second]) {
114+
// TODO(Superjomn) add some prune strategies.
115+
for (const auto& group : pre_groups) {
116+
HitGroup new_group = group;
117+
if (IsNodesLink(source, target) &&
118+
new_group.Match(source, edge.first)) {
119+
new_group.Register(source, edge.first);
120+
if (new_group.Match(target, edge.second)) {
121+
new_group.Register(target, edge.second);
122+
cur_groups.push_back(new_group);
123+
// TODO(Superjomn) need to unique
124+
}
125+
}
126+
}
127+
}
128+
}
129+
}
130+
131+
for (auto& group : bi_records[step % 2]) {
132+
GraphPatternDetecter::subgraph_t subgraph;
133+
for (auto& role : group.roles) {
134+
subgraph.emplace(role.first, role.second);
135+
}
136+
result.emplace_back(subgraph);
137+
}
138+
return result;
139+
}
140+
141+
void GraphPatternDetecter::UniquePatterns(
142+
std::vector<GraphPatternDetecter::subgraph_t>* subgraphs) {
143+
if (subgraphs->empty()) return;
144+
std::vector<GraphPatternDetecter::subgraph_t> result;
145+
146+
std::unordered_set<size_t> set;
147+
for (auto& g : *subgraphs) {
148+
size_t key = 0;
149+
for (auto& item : g) {
150+
key ^= std::hash<void*>{}(item.first);
151+
key ^= std::hash<void*>{}(item.second);
152+
}
153+
if (!set.count(key)) {
154+
result.emplace_back(g);
155+
set.insert(key);
156+
}
157+
}
158+
*subgraphs = result;
159+
}
160+
161+
void GraphPatternDetecter::RemoveOverlappedMatch(
162+
std::vector<subgraph_t>* subgraphs) {
163+
std::vector<subgraph_t> result;
164+
std::unordered_set<Node*> node_set;
165+
166+
for (const auto& subgraph : *subgraphs) {
167+
bool valid = true;
168+
for (auto& item : subgraph) {
169+
if (node_set.count(item.second)) {
170+
valid = false;
171+
break;
172+
}
173+
}
174+
if (valid) {
175+
for (auto& item : subgraph) {
176+
node_set.insert(item.second);
177+
}
178+
result.push_back(subgraph);
179+
}
180+
}
181+
*subgraphs = result;
182+
}
183+
184+
} // namespace ir
185+
} // namespace framework
186+
} // namespace paddle
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#ifdef PADDLE_WITH_TESTING
18+
#include <gtest/gtest_prod.h>
19+
#endif
20+
21+
#include <numeric>
22+
#include "paddle/fluid/framework/ir/graph.h"
23+
#include "paddle/fluid/framework/ir/node.h"
24+
25+
namespace paddle {
26+
namespace framework {
27+
namespace ir {
28+
29+
// Some basic torminolygies:
30+
// - PDPattern: a pattern defined as a data flow graph.
31+
// - PDNode: the node in the pattern, each PDNode represents an `ir::Node`
32+
// that meets some conditions defined in `PDNode.teller`.
33+
// - A pattern is defined with PDNodes with edges.
34+
35+
// Pattern detector node. This node helps to build a pattern.
36+
struct PDNode {
37+
// tell whether an ir::Node* is a candidation for a PDNode.
38+
using teller_t = std::function<bool(Node*)>;
39+
40+
PDNode(teller_t&& teller, const std::string& name = "")
41+
: teller_(teller), name_(name) {
42+
PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set.");
43+
}
44+
45+
PDNode(PDNode&& other) = default;
46+
47+
std::vector<PDNode*> inlinks;
48+
std::vector<PDNode*> outlinks;
49+
50+
bool Tell(Node* node) const {
51+
PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode");
52+
return teller_(node);
53+
}
54+
55+
const std::string& name() const { return name_; }
56+
57+
PDNode(const PDNode&) = delete;
58+
PDNode& operator=(const PDNode&) = delete;
59+
60+
private:
61+
teller_t teller_;
62+
std::string name_;
63+
};
64+
65+
/*
66+
* A pattern in a graph, which defined with PDNode and edges. Most graph
67+
* patterns can be divided into PDNodes and link relations between them.
68+
*
69+
* For example, the FC fusion need to filter the MUL and ELEMENTWISE_ADD
70+
* operators from the computation graph, the MUL's output should have only one
71+
* consumer which is the ELEMENTWISE_ADD.
72+
* This pattern can be defined as with the following pseudo codes
73+
*
74+
* // Create two operator PDNodes.
75+
* MUL = PDPattern.NewNode()
76+
* ELE = PDPattern.NewNode()
77+
* // Create the variable PDNodes.
78+
* MUL_out = PDPattern.NewNode()
79+
* // Add teller to define some rules that help to filter the target Nodes.
80+
* MUL.teller = lambda(node): node->IsOp() && node->Op()->Type == "mul";
81+
* ELE.teller = lambda(node): \
82+
* node->IsOp() && node->Op()->Type == "elementwise_add";
83+
* MUL_out.teller = lambda(node): node->IsVar() && (MUL in node->inputs)
84+
* && (ELE in node->outputs)
85+
*
86+
* One can add more specific tellers for PDNodes or edges, both the Operator
87+
* and Variable Nodes can be ruled in PDNode.teller.
88+
*
89+
* PDPattern can record the general patterns, such as the pattern represents
90+
* - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place.
91+
* - Ops whose inputs and outputs share the same variables
92+
*/
93+
class PDPattern {
94+
public:
95+
using edge_t = std::pair<PDNode*, PDNode*>;
96+
97+
void AddEdge(PDNode* a, PDNode* b);
98+
99+
PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = "");
100+
101+
const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; }
102+
const std::vector<edge_t>& edges() const { return edges_; }
103+
104+
private:
105+
#ifdef PADDLE_WITH_TESTING
106+
FRIEND_TEST(PDPattern, AddEdge);
107+
FRIEND_TEST(PDPattern, NewNode);
108+
#endif
109+
110+
std::vector<std::unique_ptr<PDNode>> nodes_;
111+
std::vector<edge_t> edges_;
112+
};
113+
114+
/*
115+
* GraphPatternDetecter helps to detect the specific patterns in the graph.
116+
* Input a pattern, output a list of the matched subgraphs/nodes.
117+
* This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.).
118+
*
119+
* The algorithm has three phases:
120+
* 1. Mark the nodes that match the defined PDNodes in a PDPattern,
121+
* 2. Extend a PDNode to subgraphs by deducing the connection relation defined
122+
* in PAPattern(the edges),
123+
* 3. Get the filtered subgraphs and treat them with a pre-defined handler.
124+
*
125+
* Usage:
126+
* // Create a detector
127+
* GraphPatternDetecter detector;
128+
* // Define the detector's pattern, by adding PDNode and define the edges.
129+
* auto* node0 = detector.mutable_pattern().AddNode(...)
130+
* auto* node1 = detector.mutable_pattern().AddNode(...)
131+
* node0->teller = some lambda.
132+
* node1->teller = some lambda.
133+
* detector.mutable_pattern().AddEdge(node0, node1);
134+
* // Create an handler, to define the behavior of treating the filtered
135+
* // subgraphs that comply with the patterns.
136+
* GraphPatternDetecter::handle_t handler = some labmda
137+
* // Execute the detector.
138+
* detector(&graph, handler);
139+
*/
140+
class GraphPatternDetecter {
141+
public:
142+
using subgraph_t = std::unordered_map<PDNode*, Node*>;
143+
144+
// Operate on the detected pattern.
145+
using handle_t =
146+
std::function<void(const subgraph_t& /*hitted pattern*/, Graph*)>;
147+
148+
void operator()(Graph* graph, handle_t handler);
149+
150+
const PDPattern& pattern() const { return pattern_; }
151+
PDPattern* mutable_pattern() { return &pattern_; }
152+
153+
private:
154+
// Mark the nodes that fits the pattern.
155+
bool MarkPDNodesInGraph(const ir::Graph& graph);
156+
157+
// Detect all the pattern and output the hit records.
158+
std::vector<subgraph_t> DetectPatterns();
159+
160+
// Remove duplicate patterns.
161+
void UniquePatterns(std::vector<subgraph_t>* subgraphs);
162+
163+
// Remove overlapped match subgraphs, when overlapped, keep the previous one.
164+
void RemoveOverlappedMatch(std::vector<subgraph_t>* subgraphs);
165+
166+
#ifdef PADDLE_WITH_TESTING
167+
FRIEND_TEST(GraphPatternDetecter, MarkPDNodesInGraph);
168+
FRIEND_TEST(GraphPatternDetecter, DetectPatterns);
169+
#endif
170+
171+
private:
172+
using hit_rcd_t =
173+
std::pair<Node* /*node in graph*/, PDNode* /*node in pattern*/>;
174+
PDPattern pattern_;
175+
std::vector<hit_rcd_t> marked_records_;
176+
std::unordered_map<const PDNode*, std::unordered_set<Node*>> pdnodes2nodes_;
177+
};
178+
179+
} // namespace ir
180+
} // namespace framework
181+
} // namespace paddle

0 commit comments

Comments
 (0)