@@ -37,7 +37,7 @@ class RmsNormFusePattern : public paddle::drr::DrrPatternBase {
 
   std::string name() const override { return "RmsNormFusePattern"; }
 
-  uint32_t benefit() const override { return 2; }
+  uint32_t benefit() const override { return 3; }
 
   void operator()(paddle::drr::DrrPatternContext *ctx) const override {
     paddle::drr::SourcePattern pat = ctx->SourcePattern();
@@ -139,7 +139,14 @@ class RmsNormFusePattern : public paddle::drr::DrrPatternBase {
 };
 
 class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase {
+ private:
+  const bool extra_add_;
+
  public:
+  explicit AddRmsNormFusePattern(bool extra_add) : extra_add_(extra_add) {}
+
+  uint32_t benefit() const override { return extra_add_ ? 2 : 1; }
+
   std::string name() const override { return "AddRmsNormFusePattern"; }
 
   void operator()(paddle::drr::DrrPatternContext *ctx) const override {
@@ -157,16 +164,21 @@ class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase {
         });
     pat.Tensor("add_out") = add(pat.Tensor("x"), pat.Tensor("residual"));
     pat_rms_norm({&pat.Tensor("add_out"),
-                  &pat.InputNoneTensor(),
+                  &pat.Tensor("bias"),
                   &pat.InputNoneTensor(),
                   &pat.Tensor("w"),
                   &pat.InputNoneTensor()},
                  {&pat.Tensor("rms_norm_out"),
                   &pat.Tensor("residual_out_0"),
                   &pat.Tensor("inv_var_0")});
-
+    // TODO(bukejiyu): DRR support matching placeholder op,
+    // the following needs to be deleted
+    if (extra_add_) {
+      const auto &add1 = pat.Op(paddle::dialect::AddOp::name());
+      pat.Tensor("add_out1") =
+          add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
+    }
     paddle::drr::ResultPattern res = pat.ResultPattern();
-
     const auto &res_rms_norm =
         res.Op(paddle::dialect::RmsNormOp::name(),
                {
@@ -181,19 +193,25 @@ class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase {
     res_rms_norm(
         {
             &res.Tensor("x"),
-            &res.InputNoneTensor(),
+            &res.Tensor("bias"),
            &res.Tensor("residual"),
             &res.Tensor("w"),
             &res.InputNoneTensor(),
         },
         {&res.Tensor("rms_norm_out"),
-         &res.Tensor("residual_out"),
+         &res.Tensor("add_out"),
          &res.Tensor("inv_var")});
   }
 };
 
 class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
+ private:
+  const bool extra_add_;
+
  public:
+  explicit AddLayerNormFusePattern(bool extra_add) : extra_add_(extra_add) {}
+
+  uint32_t benefit() const override { return extra_add_ ? 2 : 1; }
   std::string name() const override { return "AddLayerNormFusePattern"; }
 
   void operator()(paddle::drr::DrrPatternContext *ctx) const override {
@@ -204,11 +222,17 @@ class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
                {{"epsilon", pat.Attr("epsilon")},
                 {"begin_norm_axis", pat.Attr("begin_norm_axis")}});
     pat.Tensor("add_out") = add(pat.Tensor("x"), pat.Tensor("residual"));
-    layer_norm(
-        {&pat.Tensor("add_out"), &pat.Tensor("w"), &pat.InputNoneTensor()},
-        {&pat.Tensor("layer_norm_out"),
-         &pat.Tensor("mean_out_0"),
-         &pat.Tensor("variance_out_0")});
+    layer_norm({&pat.Tensor("add_out"), &pat.Tensor("w"), &pat.Tensor("bias")},
+               {&pat.Tensor("layer_norm_out"),
+                &pat.Tensor("mean_out_0"),
+                &pat.Tensor("variance_out_0")});
+    // TODO(bukejiyu): DRR support matching placeholder op,
+    // the following needs to be deleted
+    if (extra_add_) {
+      const auto &add1 = pat.Op(paddle::dialect::AddOp::name());
+      pat.Tensor("add_out1") =
+          add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
+    }
 
     paddle::drr::ResultPattern res = pat.ResultPattern();
     const auto &fuse_layer_norm =
@@ -224,13 +248,13 @@ class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
     fuse_layer_norm(
         {
             &res.Tensor("x"),
-            &res.InputNoneTensor(),
+            &res.Tensor("bias"),
            &res.Tensor("residual"),
             &res.Tensor("w"),
             &res.InputNoneTensor(),
         },
         {&res.Tensor("layer_norm_out"),
-         &res.Tensor("residual_out"),
+         &res.Tensor("add_out"),
          &res.Tensor("mean_out"),
          &res.Tensor("variance_out")});
   }
@@ -248,16 +272,19 @@ class AddNormFusePass : public pir::PatternRewritePass {
     //    mul --->rms_norm
     //    w-----------------------------
     bool is_half_weight = true;
+    bool extra_add = true;
     ps.Add(paddle::drr::Create<RmsNormFusePattern>(context, !is_half_weight));
     ps.Add(paddle::drr::Create<RmsNormFusePattern>(context, is_half_weight));
     // x--------
     //          add-rms_norm ---> rms_norm
     // residual-
-    ps.Add(paddle::drr::Create<AddRmsNormFusePattern>(context));
+    ps.Add(paddle::drr::Create<AddRmsNormFusePattern>(context, !extra_add));
+    ps.Add(paddle::drr::Create<AddRmsNormFusePattern>(context, extra_add));
     // x--------
     //          add-layer_norm ----> fused_bias_residual_layernorm
     // residual-
-    ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(context));
+    ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(context, !extra_add));
+    ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(context, extra_add));
     return ps;
   }
 };
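
A note on the paired registrations in the last hunk: the two variants of each fused pattern differ only in benefit(), and the DRR rewrite driver tries higher-benefit patterns first, so the extra_add variant (benefit 2) takes precedence when the intermediate add_out feeds a second add and must survive as an output of the fused op, which appears to be why the result tensor is renamed from residual_out to add_out. The following is a minimal sketch of applying the pass to a program; the CreateAddNormFusePass() factory and the include paths are assumptions based on the usual PIR conventions, not part of this diff.

#include "paddle/pir/include/core/ir_context.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/pass/pass_manager.h"

// Minimal sketch: run AddNormFusePass over a PIR program.
// CreateAddNormFusePass() is the assumed registration factory for the pass
// defined in this file; include paths vary across Paddle versions.
void ApplyAddNormFuse(pir::Program *program) {
  pir::PassManager pm(pir::IrContext::Instance());
  pm.AddPass(pir::CreateAddNormFusePass());  // assumed factory name
  pm.Run(program);                           // executes the pattern rewrites
}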