
Commit 4907ed5

[CINN] Fix performance issue with gemm fusion in float32 dtype related to fused_gemm_epilogue_pass (#71226)
* [CINN] Fix performance issue with gemm fusion in float32 due to fused_gemm_epilogue_pass.
* Modified the original fp32 C++ unittest to fp16 so the fusion check can still be performed.
* Polish code.
* Modified the Python pass unittest to perform proper checks.
1 parent cdd9ed6 commit 4907ed5

File tree

3 files changed: +40 −18 lines changed

  paddle/fluid/pir/transforms/gpu/fused_gemm_epilogue_pass.cc
  test/cpp/pir/pattern_rewrite/drr_fuse_linear_test.cc
  test/ir/pir/fused_pass/test_fused_gemm_epilogue_pass.py

paddle/fluid/pir/transforms/gpu/fused_gemm_epilogue_pass.cc

Lines changed: 16 additions & 0 deletions

@@ -36,6 +36,12 @@ class FusedLinearPattern
   bool MatchAndRewrite(paddle::dialect::MatmulOp matmul,
                        pir::PatternRewriter &rewriter) const override {
     auto matmul_out = matmul->result(0);
+    // The datatype(without auto-promote) of matmul should not be float32 type,
+    // which may cause performance issue in some cases.
+    if (pir::GetDataTypeFromValue(matmul.x()).isa<pir::Float32Type>()) {
+      return false;
+    }
+
     // The result of matmul can only be uniquely used by an add OP.
     if (matmul_out.use_count() != 1) {
       return false;
@@ -99,6 +105,11 @@ class FusedLinearGradPattern
                        pir::PatternRewriter &rewriter) const override {
     auto matmul_grad_out = matmul_grad->operand_source(2);
 
+    // The datatype(without auto-promote) of matmul should not be float32 type,
+    // which may cause performance issue in some cases.
+    if (pir::GetDataTypeFromValue(matmul_grad.x()).isa<pir::Float32Type>()) {
+      return false;
+    }
     paddle::dialect::AddGradOp add_grad;
     if (add_grad = matmul_grad_out.defining_op()
                        ->dyn_cast<paddle::dialect::AddGradOp>()) {
@@ -175,6 +186,11 @@ class FusedLinearGradSinglePattern
                        pir::PatternRewriter &rewriter) const override {
     auto dout = matmul_grad->operand_source(2);
 
+    // The datatype(without auto-promote) of matmul should not be float32 type,
+    // which may cause performance issue in some cases.
+    if (pir::GetDataTypeFromValue(matmul_grad.x()).isa<pir::Float32Type>()) {
+      return false;
+    }
     if (pir::GetShapeFromValue(matmul_grad->operand_source(1)).size() != 2) {
       return false;
     }
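The same early-return guard now sits at the top of MatchAndRewrite in all three patterns (FusedLinearPattern, FusedLinearGradPattern, FusedLinearGradSinglePattern). As a minimal sketch, the check looks like the excerpt below; it uses only the pir helpers that appear in the diff above and is an excerpt, not a standalone translation unit:

    // Sketch of the guard added to each pattern's MatchAndRewrite:
    // skip the rewrite when the matmul input dtype is plain float32 (no
    // auto-promote), since the fused fp32 gemm + epilogue may be slower than
    // the unfused matmul + add in some cases.
    if (pir::GetDataTypeFromValue(matmul.x()).isa<pir::Float32Type>()) {
      return false;  // pattern does not match; keep the original ops
    }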

test/cpp/pir/pattern_rewrite/drr_fuse_linear_test.cc

Lines changed: 19 additions & 13 deletions

@@ -25,24 +25,27 @@
 
 void BuildProgram(pir::Builder &builder) {  // NOLINT
   paddle::dialect::FullOp full_input_op1 =
-      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{1, 512, 64},
-                                             1.5);
+      builder.Build<paddle::dialect::FullOp>(
+          std::vector<int64_t>{1, 512, 64}, 1.5, phi::DataType::FLOAT16);
   // linear 1
   paddle::dialect::FullOp full_weight_op1 =
-      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 64}, 1.5);
+      builder.Build<paddle::dialect::FullOp>(
+          std::vector<int64_t>{64, 64}, 1.5, phi::DataType::FLOAT16);
   paddle::dialect::FullOp full_bias_op1 =
-      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64}, 1.0);
+      builder.Build<paddle::dialect::FullOp>(
+          std::vector<int64_t>{64}, 1.0, phi::DataType::FLOAT16);
   paddle::dialect::MatmulOp matmul_op1 =
       builder.Build<paddle::dialect::MatmulOp>(full_input_op1.out(),
                                                full_weight_op1.out());
   paddle::dialect::AddOp add_op1 = builder.Build<paddle::dialect::AddOp>(
       matmul_op1.out(), full_bias_op1.out());
   // linear 2
   paddle::dialect::FullOp full_weight_op2 =
-      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 128},
-                                             1.5);
+      builder.Build<paddle::dialect::FullOp>(
+          std::vector<int64_t>{64, 128}, 1.5, phi::DataType::FLOAT16);
   paddle::dialect::FullOp full_bias_op2 =
-      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{128}, 1.0);
+      builder.Build<paddle::dialect::FullOp>(
+          std::vector<int64_t>{128}, 1.0, phi::DataType::FLOAT16);
   paddle::dialect::MatmulOp matmul_op2 =
       builder.Build<paddle::dialect::MatmulOp>(add_op1.out(),
                                                full_weight_op2.out());
@@ -52,10 +55,11 @@ void BuildProgram(pir::Builder &builder) {  // NOLINT
       builder.Build<paddle::dialect::ReluOp>(add_op2.out());
   // linear 3
   paddle::dialect::FullOp full_weight_op3 =
-      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{128, 64},
-                                             1.5);
+      builder.Build<paddle::dialect::FullOp>(
+          std::vector<int64_t>{128, 64}, 1.5, phi::DataType::FLOAT16);
   paddle::dialect::FullOp full_bias_op3 =
-      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64}, 1.0);
+      builder.Build<paddle::dialect::FullOp>(
+          std::vector<int64_t>{64}, 1.0, phi::DataType::FLOAT16);
   paddle::dialect::MatmulOp matmul_op3 =
       builder.Build<paddle::dialect::MatmulOp>(relu_op.out(),
                                                full_weight_op3.out());
@@ -65,9 +69,11 @@ void BuildProgram(pir::Builder &builder) {  // NOLINT
       builder.Build<paddle::dialect::GeluOp>(add_op3.out());
   // linear 4
   paddle::dialect::FullOp full_weight_op4 =
-      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 64}, 1.5);
+      builder.Build<paddle::dialect::FullOp>(
+          std::vector<int64_t>{64, 64}, 1.5, phi::DataType::FLOAT16);
   paddle::dialect::FullOp full_bias_op4 =
-      builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64}, 1.0);
+      builder.Build<paddle::dialect::FullOp>(
+          std::vector<int64_t>{64}, 1.0, phi::DataType::FLOAT16);
   paddle::dialect::MatmulOp matmul_op4 =
       builder.Build<paddle::dialect::MatmulOp>(gelu_op1.out(),
                                                full_weight_op4.out());
@@ -78,7 +84,7 @@ void BuildProgram(pir::Builder &builder) {  // NOLINT
 
   // backward
   paddle::dialect::FullOp full_grad_op = builder.Build<paddle::dialect::FullOp>(
-      std::vector<int64_t>{1, 512, 64}, 1.0);
+      std::vector<int64_t>{1, 512, 64}, 1.0, phi::DataType::FLOAT16);
 
   paddle::dialect::GeluGradOp gelu_op2_grad =
       builder.Build<paddle::dialect::GeluGradOp>(

test/ir/pir/fused_pass/test_fused_gemm_epilogue_pass.py

Lines changed: 5 additions & 5 deletions

@@ -52,20 +52,20 @@ def test_fused_gemm_epilogue_add(self):
         main_program = paddle.base.Program()
         with paddle.pir_utils.IrGuard():
             x_np = np.random.normal(3, 2.5, size=(1024, 1024)).astype(
-                np.float32
+                np.float16
             )
             y_np = x_np
-            z_np = np.random.normal(3, 2.5, size=(1024)).astype(np.float32)
+            z_np = np.random.normal(3, 2.5, size=(1024)).astype(np.float16)
             with paddle.base.program_guard(main_program):
                 with pir_op_role_guard(0), pir_chunk_id_guard(0):
                     x_ = paddle.static.data(
-                        name="x", shape=[1024, 1024], dtype="float32"
+                        name="x", shape=[1024, 1024], dtype="float16"
                     )
                     y_ = paddle.static.data(
-                        name="y", shape=[1024, 1024], dtype="float32"
+                        name="y", shape=[1024, 1024], dtype="float16"
                     )
                     z_ = paddle.static.data(
-                        name="z", shape=[1024], dtype="float32"
+                        name="z", shape=[1024], dtype="float16"
                     )
                     x_.stop_gradient = False
                     y_.stop_gradient = False
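For reference, a minimal sketch of the fp16 setup the unittest switches to (names, shapes, and APIs taken from the diff above; the op-role/chunk-id guards and the pass-checking logic of the test are omitted):

    import numpy as np
    import paddle

    # fp16 inputs so that fused_gemm_epilogue_pass can still fuse the
    # matmul + add after the new float32 guard; shapes mirror the unittest.
    x_np = np.random.normal(3, 2.5, size=(1024, 1024)).astype(np.float16)
    y_np = x_np
    z_np = np.random.normal(3, 2.5, size=(1024,)).astype(np.float16)

    main_program = paddle.base.Program()
    with paddle.pir_utils.IrGuard():
        with paddle.base.program_guard(main_program):
            x_ = paddle.static.data(name="x", shape=[1024, 1024], dtype="float16")
            y_ = paddle.static.data(name="y", shape=[1024, 1024], dtype="float16")
            z_ = paddle.static.data(name="z", shape=[1024], dtype="float16")
            x_.stop_gradient = False
            y_.stop_gradient = False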
