Skip to content

Commit 8ae6a33

Browse files
authored
Create README.md and README_cn.md test=document_fix (PaddlePaddle#58631)
1 parent 6677117 commit 8ae6a33

File tree

2 files changed

+463
-0
lines changed

2 files changed

+463
-0
lines changed

paddle/fluid/pir/drr/README.md

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
# DRR (Declarative Rewrite Rule) Tool User Manual
2+
---
3+
## 1. Related Background
4+
5+
PASS is a crucial component for optimizing intermediate representations (IR), and the transformation of DAG-to-DAG (Replace a subgraph of the directed acyclic graph (DAG) type in the original graph with another subgraph) is the most common type of Pass. The transformation of DAG-to-DAG can be divided into two steps: matching and rewriting. Matching refers to the complete matching of a known subgraph to the corresponding target subgraph in the Program, while rewriting refers to replacing the matched graph with a new subgraph.
6+
7+
DRR can reduce the development cost of PASS, allowing developers to focus on processing optimization logic without caring about the data structure of the underlying IR. After the developer declares the pattern of the target subgraph and the new subgraph to be replaced through a set of simple and easy-to-use interfaces, DRR can automatically match the original subgraph in the Program and replace it with the new subgraph.
8+
9+
Taking PASS to eliminate redundant CastOp as an example, the code example developed using DRR is as follows:
10+
~~~ c++
11+
// 1. Inherit specialized template class from DrPatternBase
12+
class RemoveRedundentCastPattern
13+
: public pir::drr::DrrPatternBase<RemoveRedundentCastPattern> {
14+
// 2. Overload operator()
15+
void operator()(pir::drr::DrrPatternContext *ctx) const override {
16+
// 3. Define a SourcePattern containing two consecutive CastOps using Op, Tensor, and Attribute
17+
auto pat = ctx->SourcePattern();
18+
19+
pat.Tensor("tmp") = // CastOp output Tensor named "tmp"
20+
pat.Op(paddle::dialect::CastOp::name(), // Pass in the name of the CastOp
21+
{{"dtype", pat.Attr("dtype1")}}) // The corresponding globally unique ID of the "dtype" attribute of CastOp is "dtype1"
22+
(pat.Tensor("arg0")); // The input Tensor of CastOp is "arg0"
23+
pat.Tensor("ret") =
24+
pat.Op(paddle::dialect::CastOp::name(),
25+
{{"dtype", pat.Attr("dtype2")}})(pat.Tensor("tmp"));
26+
// 4. Define Constrain
27+
pat.RequireEqual(pat("tmp").dtype(), pat.Tensor("ret").dtype());
28+
29+
// 5. Define ResultPattern
30+
auto res = pat.ResultPattern();
31+
res.Tensor("ret") =
32+
res.Op(paddle::dialect::CastOp::name(),
33+
{{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0"));
34+
}
35+
};
36+
~~~
37+
38+
DRR PASS contains the following three parts:
39+
+ `Source Pattern`:used to describe the target subgraph to be matched in Program
40+
+ `Constrains`:used to specify constraints for SourcePattern matching(nonessential)
41+
+ `Result Pattern`:Used to describe the subgraph that needs to be replaced by
42+
Developers only need to define `SourcePattern`, `Constrains` and `ResultPattern` to implement a complete PASS.
43+
44+
**Note:**
45+
1. **DRR only supports matching and replacing the closed SourcePattern and ResultPattern (except for the Pattern input and output Tensor, all internal Tensors cannot be used by the Pattern external Op). If the defined Pattern is not closed in the Program, the matching will fail.**
46+
2. **The input and output of ResultPattern need to be a subset of the input and output of SourcePattern.**
47+
## 2. Interface List
48+
<table>
49+
<tr>
50+
<th> Class </th>
51+
<th> Function </th>
52+
<th> Function Description </th>
53+
<th> Parameter Interpretation </th>
54+
</tr>
55+
<tr>
56+
<td rowspan="1">DrrPatternBase</td>
57+
<td> <pre> virtual void operator()(
58+
pir::drr::DrrPatternContext* ctx) const </pre></td>
59+
<td> Implement the entry function of DRR PASS </td>
60+
<td> ctx: Context parameters required to create Patten</td>
61+
</tr>
62+
<tr>
63+
<td rowspan="6"> SourcePattern</td>
64+
<td><pre> const drr::Op& Op(
65+
const std::string& op_type,
66+
const std::unordered_map&lt;std::string, Attribute&gt;& attributes)</pre></td>
67+
<td> Define an Op in the SourcePattern</td>
68+
<td> op_type: The defined Op name. Can be obtained through paddle::dialect::xxOp::name() interface <br> attributes : Attribute information of the created Op </td>
69+
</tr>
70+
<tr>
71+
<td><pre> const drr::Tensor& Tensor(
72+
const std::string& tensor_name) </pre></td>
73+
<td> Define a tensor named tensor_name in SourcePattern</td>
74+
<td> tensor_name: The name of the defined Tensor must be unique within the SourcePattern </td>
75+
</tr>
76+
<tr>
77+
<td> <pre> Attribute Attr(
78+
const std::string& attr_name) const </pre></td>
79+
<td> Define an attribute named attr_name in SourcePattern</td>
80+
<td> attr_name: The name of the attribute, which needs to be unique within SourcePattern </td>
81+
</tr>
82+
<tr>
83+
<td><pre> void RequireEqual(
84+
const TensorShape& first,
85+
const TensorShape& second)</pre></td>
86+
<td> Requires the TensorShape of the two Tensors in SourcePattern to be the same</td>
87+
<td> first: first TensorShape <br> second : second TensorShape</td>
88+
</tr>
89+
<tr>
90+
<td><pre> void RequireEqual(
91+
const TensorDataType& first,
92+
const TensorDataType& second)</pre></td>
93+
<td> The data types of the two Tensors in SourcePattern are required to be the same</td>
94+
<td> first: DataType of the first Tensor <br> second : DataType of the second Tensor</td>
95+
</tr>
96+
<tr>
97+
<td> <pre>void RequireNativeCall(
98+
const std::function&lt;bool(const MatchContext&)&gt;& custom_fn)</pre></td>
99+
<td> Define a constraint in SourcePattern. You can use this interface and lambda expressions to implement custom constraints on SourcePattern.</td>
100+
<td> custom_fn: Customized constraint functions</td>
101+
</tr>
102+
<tr>
103+
<td rowspan="5"> ResultPattern</td>
104+
<td><pre> const drr::Op& Op(
105+
const std::string& op_type,
106+
const std::unordered_map&lt;std::string, Attribute&gt;& attributes) </pre></td>
107+
<td> Define an Op in ResultPattern </td>
108+
<td> op_type: The defined Op name. Can be obtained through paddle::dialect::xxOp::name() interface<br> attributes : Attribute information of the created Op </td>
109+
</tr>
110+
<tr>
111+
<td> <pre>const drr::Tensor& Tensor(
112+
const std::string& tensor_name)</pre></td>
113+
<td> Define a tensor named tensor_name in ResultPattern</td>
114+
<td> tensor_name: The name of the defined Tensor must be unique within the ResultPattern </td>
115+
</tr>
116+
<tr>
117+
<td><pre>Attribute Attr(
118+
const std::string& attr_name) const </pre></td>
119+
<td> Define an attribute named attr_name in ResultPattern </td>
120+
<td> attr_name: The name of the attribute must be unique within the ResultPattern </td>
121+
</tr>
122+
<tr>
123+
<td><pre>using AttrComputeFunc = std::function&lt;std::any(const MatchContext&)&gt;;
124+
Attribute Attr(const AttrComputeFunc& attr_compute_func) const</pre></td>
125+
<td> Create an Attribute through a custom calculation logic AttrComputeFunc</td>
126+
<td>attr_compute_func: Customized calculation logic</td>
127+
</tr>
128+
<tr>
129+
<td> <pre>drr::Tensor& NoneTensor()</pre></td>
130+
<td> When the input Tensor of an Op is optional and not needed, NoneTensor needs to be used to occupy the place.</td>
131+
<td> / </td>
132+
</tr>
133+
<tr>
134+
<td rowspan="2"> TensorShape</td>
135+
<td><pre>explicit TensorShape(
136+
const std::string& tensor_name) </pre></td>
137+
<td> Abstract the class that describes the shape of Tensor </td>
138+
<td> tensor_name: The name of the Tensor being described </td>
139+
</tr>
140+
<tr>
141+
<td><pre> const std::string& tensor_name() const</pre></td>
142+
<td> Obtain the name of Tensor </td>
143+
<td> / </td>
144+
</tr>
145+
<tr>
146+
<td rowspan="2"> TensorDataType</td>
147+
<td><pre>explicit TensorDataType(
148+
const std::string& tensor_name)</pre></td>
149+
<td> An abstract class that describes the data types of elements in Tensor </td>
150+
<td> tensor_name: The name of the Tensor being described </td>
151+
</tr>
152+
<tr>
153+
<td><pre> const std::string& tensor_name() const</pre></td>
154+
<td> Obtain the name of Tensor </td>
155+
<td> / </td>
156+
</tr>
157+
<tr>
158+
<td rowspan="1"> DrrPatternContext</td>
159+
<td><pre>drr::SourcePattern DrrPatternContext::SourcePattern()</pre> </td>
160+
<td> Create a SourcePattern object and return </td>
161+
<td> / </td>
162+
</tr>
163+
</table>
164+
165+
## 3 Example
166+
Example 1: Matmul + Add -> FusedGemmEpilogue
167+
~~~ c++
168+
class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
169+
public:
170+
void operator()(pir::drr::DrrPatternContext *ctx) const override {
171+
// Define SourcePattern
172+
pir::drr::SourcePattern pat = ctx->SourcePattern();
173+
const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(),
174+
{{"transpose_x", pat.Attr("trans_x")},
175+
{"transpose_y", pat.Attr("trans_y")}});
176+
const auto &add = pat.Op(paddle::dialect::AddOp::name());
177+
178+
pat.Tensor("tmp") = matmul(pat.Tensor("x"), pat.Tensor("w"));
179+
pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias"));
180+
181+
// Define ResultPattern
182+
pir::drr::ResultPattern res = pat.ResultPattern();
183+
// Define Constrain
184+
const auto &act_attr =
185+
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
186+
return "none";
187+
});
188+
const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(),
189+
{{{"trans_x", pat.Attr("trans_x")},
190+
{"trans_y", pat.Attr("trans_y")},
191+
{"activation", act_attr}}});
192+
fused_gemm_epilogue(
193+
{&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")},
194+
{&res.Tensor("out")});
195+
}
196+
};
197+
~~~
198+
199+
Example 2: Full + Expand -> Full
200+
~~~ c++
201+
class FoldExpandToConstantPattern
202+
: public pir::drr::DrrPatternBase<FoldExpandToConstantPattern> {
203+
public:
204+
void operator()(pir::drr::DrrPatternContext *ctx) const override {
205+
// Define SourcePattern
206+
pir::drr::SourcePattern pat = ctx->SourcePattern();
207+
const auto &full1 = pat.Op(paddle::dialect::FullOp::name(),
208+
{{"shape", pat.Attr("shape_1")},
209+
{"value", pat.Attr("value_1")},
210+
{"dtype", pat.Attr("dtype_1")},
211+
{"place", pat.Attr("place_1")}});
212+
const auto &full_int_array1 =
213+
pat.Op(paddle::dialect::FullIntArrayOp::name(),
214+
{{"value", pat.Attr("expand_shape_value")},
215+
{"dtype", pat.Attr("dtype_2")},
216+
{"place", pat.Attr("place_2")}});
217+
const auto &expand = pat.Op(paddle::dialect::ExpandOp::name());
218+
pat.Tensor("ret") = expand(full1(), full_int_array1());
219+
220+
// Define ResultPattern
221+
pir::drr::ResultPattern res = pat.ResultPattern();
222+
const auto &full2 = res.Op(paddle::dialect::FullOp::name(),
223+
{{"shape", pat.Attr("expand_shape_value")},
224+
{"value", pat.Attr("value_1")},
225+
{"dtype", pat.Attr("dtype_1")},
226+
{"place", pat.Attr("place_1")}});
227+
res.Tensor("ret") = full2();
228+
}
229+
};
230+
~~~

0 commit comments

Comments
 (0)