# DRR (Declarative Rewrite Rule) Tool User Manual
---
## 1. Related Background

A Pass is a crucial component for optimizing intermediate representation (IR), and the DAG-to-DAG transformation (replacing one subgraph of a directed acyclic graph with another) is the most common kind of Pass. A DAG-to-DAG transformation consists of two steps: matching and rewriting. Matching finds a complete occurrence of a known subgraph in the Program; rewriting replaces the matched subgraph with a new subgraph.

DRR reduces the development cost of such Passes, letting developers focus on the optimization logic without worrying about the data structures of the underlying IR. After the developer declares the pattern of the target subgraph and of the new subgraph that replaces it through a set of simple, easy-to-use interfaces, DRR automatically matches the original subgraph in the Program and replaces it with the new subgraph.

Taking a Pass that eliminates redundant CastOps as an example, the code developed with DRR is as follows:
~~~ c++
// 1. Inherit from the specialized (CRTP) template class DrrPatternBase
class RemoveRedundentCastPattern
    : public pir::drr::DrrPatternBase<RemoveRedundentCastPattern> {
  // 2. Overload operator()
  void operator()(pir::drr::DrrPatternContext *ctx) const override {
    // 3. Define a SourcePattern containing two consecutive CastOps, using Op, Tensor, and Attribute
    auto pat = ctx->SourcePattern();

    pat.Tensor("tmp") =                          // the output Tensor of the first CastOp, named "tmp"
        pat.Op(paddle::dialect::CastOp::name(),  // pass in the name of the CastOp
               {{"dtype", pat.Attr("dtype1")}})  // bind the "dtype" attribute of this CastOp to the pattern-unique ID "dtype1"
        (pat.Tensor("arg0"));                    // the input Tensor of the CastOp is "arg0"
    pat.Tensor("ret") =
        pat.Op(paddle::dialect::CastOp::name(),
               {{"dtype", pat.Attr("dtype2")}})(pat.Tensor("tmp"));
    // 4. Define Constraints
    pat.RequireEqual(pat.Tensor("tmp").dtype(), pat.Tensor("ret").dtype());

    // 5. Define the ResultPattern
    auto res = pat.ResultPattern();
    res.Tensor("ret") =
        res.Op(paddle::dialect::CastOp::name(),
               {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0"));
  }
};
~~~

A DRR Pass consists of the following three parts:
+ `SourcePattern`: used to describe the target subgraph to be matched in the Program
+ `Constraints`: used to specify extra conditions on the SourcePattern match (optional)
+ `ResultPattern`: used to describe the subgraph that replaces the matched one

Developers only need to define `SourcePattern`, `Constraints` and `ResultPattern` to implement a complete Pass.

**Note:**
1. **DRR only supports matching and replacing closed SourcePatterns and ResultPatterns (apart from the Pattern's input and output Tensors, no internal Tensor may be used by an Op outside the Pattern). If the defined Pattern is not closed in the Program, the match fails.**
2. **The inputs and outputs of the ResultPattern must be a subset of the inputs and outputs of the SourcePattern.**
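
A pattern defined this way takes effect only after it is registered with the pattern-rewrite driver. The following is a minimal sketch of that wiring; the `pir::PatternRewritePass` base class, its `InitializePatterns` hook, and `DrrPatternBase::Build` are assumptions based on typical PIR usage and may differ across versions:
~~~ c++
// A minimal registration sketch (assumed API, see the note above).
class RemoveRedundentCastPass : public pir::PatternRewritePass {
 public:
  RemoveRedundentCastPass()
      : pir::PatternRewritePass("remove_redundent_cast_pass", /*opt_level=*/1) {}

  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
    pir::RewritePatternSet ps(context);
    // Build() lowers the declarative pattern into a concrete RewritePattern.
    ps.Add(RemoveRedundentCastPattern().Build(context));
    return ps;
  }
};
~~~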
## 2. Interface List
<table>
  <tr>
    <th> Class </th>
    <th> Function </th>
    <th> Function Description </th>
    <th> Parameter Interpretation </th>
  </tr>
  <tr>
    <td rowspan="1">DrrPatternBase</td>
    <td> <pre> virtual void operator()(
    pir::drr::DrrPatternContext* ctx) const </pre></td>
    <td> Entry function of the DRR Pass; override it to define the pattern </td>
    <td> ctx: the context required to create the Pattern</td>
  </tr>
  <tr>
    <td rowspan="6"> SourcePattern</td>
    <td><pre> const drr::Op& Op(
    const std::string& op_type,
    const std::unordered_map&lt;std::string, Attribute&gt;& attributes)</pre></td>
    <td> Define an Op in the SourcePattern</td>
    <td> op_type: the name of the defined Op; it can be obtained through the paddle::dialect::xxOp::name() interface <br> attributes: the attribute information of the created Op </td>
  </tr>
  <tr>
    <td><pre> const drr::Tensor& Tensor(
    const std::string& tensor_name) </pre></td>
    <td> Define a Tensor named tensor_name in the SourcePattern</td>
    <td> tensor_name: the name of the defined Tensor; must be unique within the SourcePattern </td>
  </tr>
  <tr>
    <td> <pre> Attribute Attr(
    const std::string& attr_name) const </pre></td>
    <td> Define an attribute named attr_name in the SourcePattern</td>
    <td> attr_name: the name of the attribute; must be unique within the SourcePattern </td>
  </tr>
  <tr>
    <td><pre> void RequireEqual(
    const TensorShape& first,
    const TensorShape& second)</pre></td>
    <td> Require the TensorShapes of two Tensors in the SourcePattern to be the same</td>
    <td> first: the first TensorShape <br> second: the second TensorShape</td>
  </tr>
  <tr>
    <td><pre> void RequireEqual(
    const TensorDataType& first,
    const TensorDataType& second)</pre></td>
    <td> Require the data types of two Tensors in the SourcePattern to be the same</td>
    <td> first: the DataType of the first Tensor <br> second: the DataType of the second Tensor</td>
  </tr>
  <tr>
    <td> <pre>void RequireNativeCall(
    const std::function&lt;bool(const MatchContext&)&gt;& custom_fn)</pre></td>
    <td> Define a custom constraint on the SourcePattern; this interface can be combined with a lambda expression to implement arbitrary constraints.</td>
    <td> custom_fn: the customized constraint function</td>
  </tr>
  <tr>
    <td rowspan="5"> ResultPattern</td>
    <td><pre> const drr::Op& Op(
    const std::string& op_type,
    const std::unordered_map&lt;std::string, Attribute&gt;& attributes) </pre></td>
    <td> Define an Op in the ResultPattern </td>
    <td> op_type: the name of the defined Op; it can be obtained through the paddle::dialect::xxOp::name() interface<br> attributes: the attribute information of the created Op </td>
  </tr>
  <tr>
    <td> <pre>const drr::Tensor& Tensor(
    const std::string& tensor_name)</pre></td>
    <td> Define a Tensor named tensor_name in the ResultPattern</td>
    <td> tensor_name: the name of the defined Tensor; must be unique within the ResultPattern </td>
  </tr>
  <tr>
    <td><pre>Attribute Attr(
    const std::string& attr_name) const </pre></td>
    <td> Define an attribute named attr_name in the ResultPattern </td>
    <td> attr_name: the name of the attribute; must be unique within the ResultPattern </td>
  </tr>
  <tr>
    <td><pre>using AttrComputeFunc = std::function&lt;std::any(const MatchContext&)&gt;;
Attribute Attr(const AttrComputeFunc& attr_compute_func) const</pre></td>
    <td> Create an Attribute through custom computation logic (an AttrComputeFunc)</td>
    <td>attr_compute_func: the customized computation logic</td>
  </tr>
  <tr>
    <td> <pre>drr::Tensor& NoneTensor()</pre></td>
    <td> When an optional input Tensor of an Op is not needed, NoneTensor is used as a placeholder.</td>
    <td> / </td>
  </tr>
  <tr>
    <td rowspan="2"> TensorShape</td>
    <td><pre>explicit TensorShape(
    const std::string& tensor_name) </pre></td>
    <td> A class that abstracts the description of a Tensor's shape </td>
    <td> tensor_name: the name of the Tensor being described </td>
  </tr>
  <tr>
    <td><pre> const std::string& tensor_name() const</pre></td>
    <td> Obtain the name of the Tensor </td>
    <td> / </td>
  </tr>
  <tr>
    <td rowspan="2"> TensorDataType</td>
    <td><pre>explicit TensorDataType(
    const std::string& tensor_name)</pre></td>
    <td> A class that abstracts the description of the data type of a Tensor's elements </td>
    <td> tensor_name: the name of the Tensor being described </td>
  </tr>
  <tr>
    <td><pre> const std::string& tensor_name() const</pre></td>
    <td> Obtain the name of the Tensor </td>
    <td> / </td>
  </tr>
  <tr>
    <td rowspan="1"> DrrPatternContext</td>
    <td><pre>drr::SourcePattern DrrPatternContext::SourcePattern()</pre> </td>
    <td> Create and return a SourcePattern object </td>
    <td> / </td>
  </tr>
</table>
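
For conditions that go beyond shape or dtype equality, `RequireNativeCall` accepts an arbitrary predicate over the match. As a minimal sketch, the constraint below would restrict the CastOp example from Section 1 to pairs of casts with different target dtypes; it assumes that `MatchContext` exposes a templated `Attr<T>()` accessor for attributes captured in the SourcePattern and that the "dtype" attribute is held as `phi::DataType`:
~~~ c++
// Sketch of a custom constraint (assumed MatchContext::Attr<T>() accessor).
pat.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool {
  // Only match when the two casts target different dtypes.
  return match_ctx.Attr<phi::DataType>("dtype1") !=
         match_ctx.Attr<phi::DataType>("dtype2");
});
~~~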

## 3. Examples
Example 1: Matmul + Add -> FusedGemmEpilogue
~~~ c++
class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
 public:
  void operator()(pir::drr::DrrPatternContext *ctx) const override {
    // Define the SourcePattern
    pir::drr::SourcePattern pat = ctx->SourcePattern();
    const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(),
                                {{"transpose_x", pat.Attr("trans_x")},
                                 {"transpose_y", pat.Attr("trans_y")}});
    const auto &add = pat.Op(paddle::dialect::AddOp::name());

    pat.Tensor("tmp") = matmul(pat.Tensor("x"), pat.Tensor("w"));
    pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias"));

    // Define the ResultPattern
    pir::drr::ResultPattern res = pat.ResultPattern();
    // Create a constant "activation" attribute for the fused Op
    const auto &act_attr =
        res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
          return "none";
        });
    const auto &fused_gemm_epilogue =
        res.Op(paddle::dialect::FusedGemmEpilogueOp::name(),
               {{"trans_x", pat.Attr("trans_x")},
                {"trans_y", pat.Attr("trans_y")},
                {"activation", act_attr}});
    fused_gemm_epilogue(
        {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")},
        {&res.Tensor("out")});
  }
};
~~~
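
Note how `act_attr` is created with the `Attr(const AttrComputeFunc&)` overload from the interface list: the lambda receives the `MatchContext` and returns the attribute value as a `std::any`, here the constant activation name `"none"`. The attributes captured in the SourcePattern (`trans_x`, `trans_y`) are reused directly in the ResultPattern via `pat.Attr(...)`, so the fused Op inherits the transpose flags of the matched MatmulOp.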

Example 2: Full + Expand -> Full
~~~ c++
class FoldExpandToConstantPattern
    : public pir::drr::DrrPatternBase<FoldExpandToConstantPattern> {
 public:
  void operator()(pir::drr::DrrPatternContext *ctx) const override {
    // Define the SourcePattern
    pir::drr::SourcePattern pat = ctx->SourcePattern();
    const auto &full1 = pat.Op(paddle::dialect::FullOp::name(),
                               {{"shape", pat.Attr("shape_1")},
                                {"value", pat.Attr("value_1")},
                                {"dtype", pat.Attr("dtype_1")},
                                {"place", pat.Attr("place_1")}});
    const auto &full_int_array1 =
        pat.Op(paddle::dialect::FullIntArrayOp::name(),
               {{"value", pat.Attr("expand_shape_value")},
                {"dtype", pat.Attr("dtype_2")},
                {"place", pat.Attr("place_2")}});
    const auto &expand = pat.Op(paddle::dialect::ExpandOp::name());
    pat.Tensor("ret") = expand(full1(), full_int_array1());

    // Define the ResultPattern
    pir::drr::ResultPattern res = pat.ResultPattern();
    const auto &full2 = res.Op(paddle::dialect::FullOp::name(),
                               {{"shape", pat.Attr("expand_shape_value")},
                                {"value", pat.Attr("value_1")},
                                {"dtype", pat.Attr("dtype_1")},
                                {"place", pat.Attr("place_1")}});
    res.Tensor("ret") = full2();
  }
};
~~~
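
To run such a pattern over a Program, the patterns are typically collected into a set, frozen, and handed to the greedy pattern-rewrite driver. The sketch below assumes the `pir::FrozenRewritePatternSet` / `pir::ApplyPatternsGreedily` driver API and `DrrPatternBase::Build`; the exact names and signatures may differ across versions:
~~~ c++
// Sketch of driving the rewrite (assumed driver API, see the note above).
void FoldExpandToConstant(pir::IrContext *ctx, pir::Program *program) {
  pir::RewritePatternSet ps(ctx);
  ps.Add(FoldExpandToConstantPattern().Build(ctx));

  // Freeze the set and greedily apply it until no pattern matches anymore.
  pir::FrozenRewritePatternSet frozen_ps(std::move(ps));
  pir::ApplyPatternsGreedily(program->module_op(), frozen_ps,
                             pir::GreedyRewriteConfig());
}
~~~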