
Commit c1f5c39

[PIR inference] update add_rms_norm pass (#63154)

* update add_rms_norm
* update
* fix timeout
1 parent 5c15379 commit c1f5c39

File tree

5 files changed: +99 −24 lines changed

paddle/fluid/inference/api/paddle_analysis_config.h (+10, −2)

@@ -1250,8 +1250,16 @@ struct PD_INFER_DECL AnalysisConfig {
                             bool custom_pass_only = false);
 
   ///
-  /// \brief Set passmanager opt level.Pass level lower than
-  /// opt level which will be added to passmanager
+  /// \brief Set the PIR optimization level.
+  /// \param opt_level The optimization level, in the range [0, 4];
+  /// the default is 2.
+  /// A higher optimization level allows the predictor to apply more passes.
+  /// If 0, only basic passes are supported.
+  /// If 1, functional passes are additionally supported.
+  /// If 2, fusion passes are additionally supported; these may affect
+  /// precision and speed.
+  /// If 3, layout passes are supported, etc.
+  /// If 4, radical optimizations are added, which may affect precision, etc.
   ///
   void SetOptimizationLevel(int opt_level);
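Usage note (not part of the diff): a minimal sketch of choosing a PIR optimization level through the C++ inference API, assuming a GPU build with the PIR pipeline enabled. The model directory is a placeholder; everything else is ordinary paddle_infer::Config (AnalysisConfig) usage.

#include "paddle_inference_api.h"  // public inference header; install path may vary

// Minimal sketch, assuming a GPU build; "./model_dir" is hypothetical.
paddle_infer::Config MakeConfig() {
  paddle_infer::Config config("./model_dir");
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/100, /*device_id=*/0);
  config.EnableNewIR(true);  // assumed available in this build; the level applies to the PIR pipeline
  // Per the doc comment above: 0 = basic passes only, 1 adds functional
  // passes, 2 (the default) adds fusion passes such as add_norm_fuse_pass,
  // 3 adds layout passes, 4 adds radical optimizations that may affect
  // precision.
  config.SetOptimizationLevel(2);
  return config;
}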

paddle/fluid/pir/drr/src/rewrite_pattern.cc (+14, −2)

@@ -324,7 +324,11 @@ bool DrrRewritePattern::MatchFromOutputToInput(
     }
     return false;
   };
-
+  // Check whether the DRR tensor and the IR value are both none.
+  const auto& IsNoneTensorAndValue = [](const Tensor* drr_input_tensor,
+                                        pir::Value ir_value) {
+    return drr_input_tensor->is_none() && ir_value == nullptr;
+  };
   // Step 1: Initialize DRR matched queue.
   bool matched = true;
   size_t step = 0;
@@ -348,7 +352,15 @@ bool DrrRewritePattern::MatchFromOutputToInput(
     auto ir_input_values = ir_node->operands_source();
     for (size_t i = 0; i < drr_input_tensors.size(); ++i) {
       if (drr_input_tensors[i]->is_none()) {
-        continue;
+        if (IsNoneTensorAndValue(drr_input_tensors[i], ir_input_values[i])) {
+          continue;
+        } else {
+          VLOG(8) << drr_node->name() << " match failed: drr_input[" << i
+                  << "] != pir_input[" << i << "], drr_input_tensor[" << i
+                  << "] is none.";
+          matched = false;
+          break;
+        }
       }
       if (HasVisitedOperands(drr_input_tensors[i], ir_input_values[i])) {
         matched = false;
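What the new check changes, as a standalone sketch (the types below are simplified stand-ins, not the real drr::Tensor / pir::Value): previously a none tensor in the source pattern was skipped unconditionally; now it only matches when the corresponding IR operand is itself absent, and a none pattern tensor paired with a real operand fails the match.

// Simplified stand-ins for illustration only.
struct DrrTensor { bool is_none; };
using IrValue = const void*;  // nullptr stands for "operand not present"

// Mirrors IsNoneTensorAndValue: skip the operand only when BOTH sides agree
// it is absent; otherwise the pattern match is rejected.
bool BothNone(const DrrTensor& t, IrValue v) {
  return t.is_none && v == nullptr;
}

// BothNone({true}, nullptr)    -> true:  operand skipped, matching continues
// BothNone({true}, some_value) -> false: the match now fails instead of skipping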

paddle/fluid/pir/transforms/gpu/add_norm_fuse_pass.cc (+42, −15)

@@ -37,7 +37,7 @@ class RmsNormFusePattern : public paddle::drr::DrrPatternBase {
 
   std::string name() const override { return "RmsNormFusePattern"; }
 
-  uint32_t benefit() const override { return 2; }
+  uint32_t benefit() const override { return 3; }
 
   void operator()(paddle::drr::DrrPatternContext *ctx) const override {
     paddle::drr::SourcePattern pat = ctx->SourcePattern();
@@ -139,7 +139,14 @@ class RmsNormFusePattern : public paddle::drr::DrrPatternBase {
 };
 
 class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase {
+ private:
+  const bool extra_add_;
+
  public:
+  explicit AddRmsNormFusePattern(bool extra_add) : extra_add_(extra_add) {}
+
+  uint32_t benefit() const override { return extra_add_ ? 2 : 1; }
+
   std::string name() const override { return "AddRmsNormFusePattern"; }
 
   void operator()(paddle::drr::DrrPatternContext *ctx) const override {
@@ -157,16 +164,21 @@ class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase {
     });
     pat.Tensor("add_out") = add(pat.Tensor("x"), pat.Tensor("residual"));
     pat_rms_norm({&pat.Tensor("add_out"),
-                  &pat.InputNoneTensor(),
+                  &pat.Tensor("bias"),
                   &pat.InputNoneTensor(),
                   &pat.Tensor("w"),
                   &pat.InputNoneTensor()},
                  {&pat.Tensor("rms_norm_out"),
                   &pat.Tensor("residual_out_0"),
                   &pat.Tensor("inv_var_0")});
-
+    // TODO(bukejiyu): once DRR supports matching placeholder ops,
+    // the following can be deleted.
+    if (extra_add_) {
+      const auto &add1 = pat.Op(paddle::dialect::AddOp::name());
+      pat.Tensor("add_out1") =
+          add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
+    }
     paddle::drr::ResultPattern res = pat.ResultPattern();
-
     const auto &res_rms_norm =
         res.Op(paddle::dialect::RmsNormOp::name(),
                {
@@ -181,19 +193,25 @@ class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase {
     res_rms_norm(
         {
             &res.Tensor("x"),
-            &res.InputNoneTensor(),
+            &res.Tensor("bias"),
             &res.Tensor("residual"),
             &res.Tensor("w"),
             &res.InputNoneTensor(),
         },
         {&res.Tensor("rms_norm_out"),
-         &res.Tensor("residual_out"),
+         &res.Tensor("add_out"),
          &res.Tensor("inv_var")});
   }
 };
 
 class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
+ private:
+  const bool extra_add_;
+
  public:
+  explicit AddLayerNormFusePattern(bool extra_add) : extra_add_(extra_add) {}
+
+  uint32_t benefit() const override { return extra_add_ ? 2 : 1; }
   std::string name() const override { return "AddLayerNormFusePattern"; }
 
   void operator()(paddle::drr::DrrPatternContext *ctx) const override {
@@ -204,11 +222,17 @@ class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
                {{"epsilon", pat.Attr("epsilon")},
                 {"begin_norm_axis", pat.Attr("begin_norm_axis")}});
     pat.Tensor("add_out") = add(pat.Tensor("x"), pat.Tensor("residual"));
-    layer_norm(
-        {&pat.Tensor("add_out"), &pat.Tensor("w"), &pat.InputNoneTensor()},
-        {&pat.Tensor("layer_norm_out"),
-         &pat.Tensor("mean_out_0"),
-         &pat.Tensor("variance_out_0")});
+    layer_norm({&pat.Tensor("add_out"), &pat.Tensor("w"), &pat.Tensor("bias")},
+               {&pat.Tensor("layer_norm_out"),
+                &pat.Tensor("mean_out_0"),
+                &pat.Tensor("variance_out_0")});
+    // TODO(bukejiyu): once DRR supports matching placeholder ops,
+    // the following can be deleted.
+    if (extra_add_) {
+      const auto &add1 = pat.Op(paddle::dialect::AddOp::name());
+      pat.Tensor("add_out1") =
+          add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
+    }
 
     paddle::drr::ResultPattern res = pat.ResultPattern();
     const auto &fuse_layer_norm =
@@ -224,13 +248,13 @@ class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
     fuse_layer_norm(
         {
             &res.Tensor("x"),
-            &res.InputNoneTensor(),
+            &res.Tensor("bias"),
             &res.Tensor("residual"),
             &res.Tensor("w"),
             &res.InputNoneTensor(),
         },
         {&res.Tensor("layer_norm_out"),
-         &res.Tensor("residual_out"),
+         &res.Tensor("add_out"),
          &res.Tensor("mean_out"),
          &res.Tensor("variance_out")});
   }
@@ -248,16 +272,19 @@ class AddNormFusePass : public pir::PatternRewritePass {
     //          mul --->rms_norm
     // w-----------------------------
     bool is_half_weight = true;
+    bool extra_add = true;
     ps.Add(paddle::drr::Create<RmsNormFusePattern>(context, !is_half_weight));
     ps.Add(paddle::drr::Create<RmsNormFusePattern>(context, is_half_weight));
     // x--------
     //          add-rms_norm ---> rms_norm
     // residual-
-    ps.Add(paddle::drr::Create<AddRmsNormFusePattern>(context));
+    ps.Add(paddle::drr::Create<AddRmsNormFusePattern>(context, !extra_add));
+    ps.Add(paddle::drr::Create<AddRmsNormFusePattern>(context, extra_add));
     // x--------
     //          add-layer_norm ----> fused_bias_residual_layernorm
     // residual-
-    ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(context));
+    ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(context, !extra_add));
+    ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(context, extra_add));
     return ps;
   }
 };
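For orientation, a reference sketch of the computation the add + rms_norm fusion collapses, following the formula the test below builds (mean of squares, rsqrt with epsilon, then the weight multiply). The fused op also returns the intermediate residual sum, which is why the result pattern now binds add_out as an output; AddRmsNormRef and its parameter names are illustrative, not Paddle API.

#include <cmath>
#include <cstddef>
#include <vector>

// Reference semantics for a single row in fp32 (illustrative only):
// norm_out[i] = w[i] * s[i] / sqrt(mean(s * s) + eps), with s = x + residual.
void AddRmsNormRef(const std::vector<float>& x,
                   const std::vector<float>& residual,
                   const std::vector<float>& w,
                   float eps,
                   std::vector<float>* add_out,   // second output of the fused op
                   std::vector<float>* norm_out) {
  const std::size_t n = x.size();
  add_out->resize(n);
  norm_out->resize(n);
  float mean_sq = 0.0f;
  for (std::size_t i = 0; i < n; ++i) {
    (*add_out)[i] = x[i] + residual[i];
    mean_sq += (*add_out)[i] * (*add_out)[i];
  }
  mean_sq /= static_cast<float>(n);
  const float inv_rms = 1.0f / std::sqrt(mean_sq + eps);
  for (std::size_t i = 0; i < n; ++i) {
    (*norm_out)[i] = w[i] * (*add_out)[i] * inv_rms;
  }
}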

test/ir/pir/fused_pass/CMakeLists.txt (+1)

@@ -19,6 +19,7 @@ foreach(target ${TEST_INTERP_CASES})
 endforeach()
 
 set_tests_properties(test_pir_multihead_matmul_fuse_pass PROPERTIES TIMEOUT 100)
+set_tests_properties(test_add_norm_fuse_pass PROPERTIES TIMEOUT 300)
 if(WITH_CUTLASS)
   set_tests_properties(test_fused_weight_only_linear_pass PROPERTIES TIMEOUT
                        300)

test/ir/pir/fused_pass/test_add_norm_fuse_pass.py (+32, −5)

@@ -176,7 +176,7 @@ def test_check_output(self):
         self.check_pass_correct(atol=1e-3, rtol=1e-3)
 
 
-class TestAddRmsNormFusePattern(TestRmsNormFusePattern):
+class TestAddRmsNormFusePatternWithResidual(TestRmsNormFusePattern):
     r"""
     x residual w
     |    |
@@ -222,12 +222,25 @@ def sample_program(self):
                     np.random.random(w_shape).astype(w_type)
                 ),
             )
+            w1 = create_parameter(
+                name="w1",
+                shape=w_shape,
+                dtype=w_type,
+                initializer=paddle.nn.initializer.Assign(
+                    np.random.random([4096, 4096]).astype(
+                        w_type
+                    )
+                ),
+            )
             add_out = paddle.add(residual, x)
+            add_out_1 = add_out
             variance = add_out.pow(2).mean(-1, keepdim=True)
             add_out = (
                 paddle.rsqrt(variance + 1e-6) * add_out
             )
-            out = add_out * w
+            mul_out = add_out * w
+            matmul_out = paddle.matmul(mul_out, w1)
+            out = paddle.add(add_out_1, matmul_out)
             out = paddle.assign(out)
             self.pass_list = ['add_norm_fuse_pass']
             self.feeds = {
@@ -240,7 +253,6 @@ def sample_program(self):
             }
             self.fetch_list = [out]
             self.valid_op_map = {
-                "pd_op.add": 0,
                 "pd_op.pow": 0,
                 "pd_op.mean": 0,
                 "pd_op.full": 0,
@@ -288,13 +300,26 @@ def sample_program(self):
                     mean=0.0, std=2.0
                 ),
             )
+            w1 = create_parameter(
+                name="w1",
+                shape=w_shape,
+                dtype=w_type,
+                initializer=paddle.nn.initializer.Assign(
+                    np.random.random([4096, 4096]).astype(
+                        w_type
+                    )
+                ),
+            )
             add_out = paddle.add(residual, x)
+            add_out_1 = add_out
             layer_norm = paddle.nn.LayerNorm(
                 add_out.shape[-1:],
                 epsilon=epilson,
                 weight_attr=w_attr,
             )
-            out = layer_norm(add_out)
+            layer_norm_out = layer_norm(add_out)
+            matmul_out = paddle.matmul(layer_norm_out, w1)
+            out = paddle.add(add_out_1, matmul_out)
             out = paddle.assign(out)
             self.pass_list = ['add_norm_fuse_pass']
             self.feeds = {
@@ -307,13 +332,15 @@ def sample_program(self):
             }
             self.fetch_list = [out]
             self.valid_op_map = {
-                "pd_op.add": 0,
                 "pd_op.layer_norm": 0,
                 "pd_op.fused_bias_residual_layernorm": 1,
             }
 
             yield [main_prog, start_prog], False
 
+    def test_check_output(self):
+        self.check_pass_correct(atol=1e-3, rtol=1e-3)
+
 
 if __name__ == "__main__":
     unittest.main()
