|
| 1 | +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#pragma once |
| 16 | + |
| 17 | +#ifdef PADDLE_WITH_TESTING |
| 18 | +#include <gtest/gtest_prod.h> |
| 19 | +#endif |
| 20 | + |
| 21 | +#include <numeric> |
| 22 | +#include "paddle/fluid/framework/ir/graph.h" |
| 23 | +#include "paddle/fluid/framework/ir/node.h" |
| 24 | + |
| 25 | +namespace paddle { |
| 26 | +namespace framework { |
| 27 | +namespace ir { |
| 28 | + |
| 29 | +// Some basic torminolygies: |
| 30 | +// - PDPattern: a pattern defined as a data flow graph. |
| 31 | +// - PDNode: the node in the pattern, each PDNode represents an `ir::Node` |
| 32 | +// that meets some conditions defined in `PDNode.teller`. |
| 33 | +// - A pattern is defined with PDNodes with edges. |
| 34 | + |
| 35 | +// Pattern detector node. This node helps to build a pattern. |
| 36 | +struct PDNode { |
| 37 | + // tell whether an ir::Node* is a candidation for a PDNode. |
| 38 | + using teller_t = std::function<bool(Node*)>; |
| 39 | + |
| 40 | + PDNode(teller_t&& teller, const std::string& name = "") |
| 41 | + : teller_(teller), name_(name) { |
| 42 | + PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set."); |
| 43 | + } |
| 44 | + |
| 45 | + PDNode(PDNode&& other) = default; |
| 46 | + |
| 47 | + std::vector<PDNode*> inlinks; |
| 48 | + std::vector<PDNode*> outlinks; |
| 49 | + |
| 50 | + bool Tell(Node* node) const { |
| 51 | + PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode"); |
| 52 | + return teller_(node); |
| 53 | + } |
| 54 | + |
| 55 | + const std::string& name() const { return name_; } |
| 56 | + |
| 57 | + PDNode(const PDNode&) = delete; |
| 58 | + PDNode& operator=(const PDNode&) = delete; |
| 59 | + |
| 60 | + private: |
| 61 | + teller_t teller_; |
| 62 | + std::string name_; |
| 63 | +}; |
| 64 | + |
| 65 | +/* |
| 66 | + * A pattern in a graph, which defined with PDNode and edges. Most graph |
| 67 | + * patterns can be divided into PDNodes and link relations between them. |
| 68 | + * |
| 69 | + * For example, the FC fusion need to filter the MUL and ELEMENTWISE_ADD |
| 70 | + * operators from the computation graph, the MUL's output should have only one |
| 71 | + * consumer which is the ELEMENTWISE_ADD. |
| 72 | + * This pattern can be defined as with the following pseudo codes |
| 73 | + * |
| 74 | + * // Create two operator PDNodes. |
| 75 | + * MUL = PDPattern.NewNode() |
| 76 | + * ELE = PDPattern.NewNode() |
| 77 | + * // Create the variable PDNodes. |
| 78 | + * MUL_out = PDPattern.NewNode() |
| 79 | + * // Add teller to define some rules that help to filter the target Nodes. |
| 80 | + * MUL.teller = lambda(node): node->IsOp() && node->Op()->Type == "mul"; |
| 81 | + * ELE.teller = lambda(node): \ |
| 82 | + * node->IsOp() && node->Op()->Type == "elementwise_add"; |
| 83 | + * MUL_out.teller = lambda(node): node->IsVar() && (MUL in node->inputs) |
| 84 | + * && (ELE in node->outputs) |
| 85 | + * |
| 86 | + * One can add more specific tellers for PDNodes or edges, both the Operator |
| 87 | + * and Variable Nodes can be ruled in PDNode.teller. |
| 88 | + * |
| 89 | + * PDPattern can record the general patterns, such as the pattern represents |
| 90 | + * - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place. |
| 91 | + * - Ops whose inputs and outputs share the same variables |
| 92 | + */ |
| 93 | +class PDPattern { |
| 94 | + public: |
| 95 | + using edge_t = std::pair<PDNode*, PDNode*>; |
| 96 | + |
| 97 | + void AddEdge(PDNode* a, PDNode* b); |
| 98 | + |
| 99 | + PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = ""); |
| 100 | + |
| 101 | + const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; } |
| 102 | + const std::vector<edge_t>& edges() const { return edges_; } |
| 103 | + |
| 104 | + private: |
| 105 | +#ifdef PADDLE_WITH_TESTING |
| 106 | + FRIEND_TEST(PDPattern, AddEdge); |
| 107 | + FRIEND_TEST(PDPattern, NewNode); |
| 108 | +#endif |
| 109 | + |
| 110 | + std::vector<std::unique_ptr<PDNode>> nodes_; |
| 111 | + std::vector<edge_t> edges_; |
| 112 | +}; |
| 113 | + |
| 114 | +/* |
| 115 | + * GraphPatternDetecter helps to detect the specific patterns in the graph. |
| 116 | + * Input a pattern, output a list of the matched subgraphs/nodes. |
| 117 | + * This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.). |
| 118 | + * |
| 119 | + * The algorithm has three phases: |
| 120 | + * 1. Mark the nodes that match the defined PDNodes in a PDPattern, |
| 121 | + * 2. Extend a PDNode to subgraphs by deducing the connection relation defined |
| 122 | + * in PAPattern(the edges), |
| 123 | + * 3. Get the filtered subgraphs and treat them with a pre-defined handler. |
| 124 | + * |
| 125 | + * Usage: |
| 126 | + * // Create a detector |
| 127 | + * GraphPatternDetecter detector; |
| 128 | + * // Define the detector's pattern, by adding PDNode and define the edges. |
| 129 | + * auto* node0 = detector.mutable_pattern().AddNode(...) |
| 130 | + * auto* node1 = detector.mutable_pattern().AddNode(...) |
| 131 | + * node0->teller = some lambda. |
| 132 | + * node1->teller = some lambda. |
| 133 | + * detector.mutable_pattern().AddEdge(node0, node1); |
| 134 | + * // Create an handler, to define the behavior of treating the filtered |
| 135 | + * // subgraphs that comply with the patterns. |
| 136 | + * GraphPatternDetecter::handle_t handler = some labmda |
| 137 | + * // Execute the detector. |
| 138 | + * detector(&graph, handler); |
| 139 | + */ |
| 140 | +class GraphPatternDetecter { |
| 141 | + public: |
| 142 | + using subgraph_t = std::unordered_map<PDNode*, Node*>; |
| 143 | + |
| 144 | + // Operate on the detected pattern. |
| 145 | + using handle_t = |
| 146 | + std::function<void(const subgraph_t& /*hitted pattern*/, Graph*)>; |
| 147 | + |
| 148 | + void operator()(Graph* graph, handle_t handler); |
| 149 | + |
| 150 | + const PDPattern& pattern() const { return pattern_; } |
| 151 | + PDPattern* mutable_pattern() { return &pattern_; } |
| 152 | + |
| 153 | + private: |
| 154 | + // Mark the nodes that fits the pattern. |
| 155 | + bool MarkPDNodesInGraph(const ir::Graph& graph); |
| 156 | + |
| 157 | + // Detect all the pattern and output the hit records. |
| 158 | + std::vector<subgraph_t> DetectPatterns(); |
| 159 | + |
| 160 | + // Remove duplicate patterns. |
| 161 | + void UniquePatterns(std::vector<subgraph_t>* subgraphs); |
| 162 | + |
| 163 | + // Remove overlapped match subgraphs, when overlapped, keep the previous one. |
| 164 | + void RemoveOverlappedMatch(std::vector<subgraph_t>* subgraphs); |
| 165 | + |
| 166 | +#ifdef PADDLE_WITH_TESTING |
| 167 | + FRIEND_TEST(GraphPatternDetecter, MarkPDNodesInGraph); |
| 168 | + FRIEND_TEST(GraphPatternDetecter, DetectPatterns); |
| 169 | +#endif |
| 170 | + |
| 171 | + private: |
| 172 | + using hit_rcd_t = |
| 173 | + std::pair<Node* /*node in graph*/, PDNode* /*node in pattern*/>; |
| 174 | + PDPattern pattern_; |
| 175 | + std::vector<hit_rcd_t> marked_records_; |
| 176 | + std::unordered_map<const PDNode*, std::unordered_set<Node*>> pdnodes2nodes_; |
| 177 | +}; |
| 178 | + |
| 179 | +} // namespace ir |
| 180 | +} // namespace framework |
| 181 | +} // namespace paddle |
0 commit comments