
Commit 803909e

[CINN] Clean some code (#69520)
* fused_attention
* Fix compile
* refine code
* refine code
* clean useless code
* clean useless code
* use const && some todo
1 parent ea2bc4d commit 803909e

4 files changed: +29 −202 lines changed


paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc

Lines changed: 3 additions & 2 deletions
@@ -2385,8 +2385,9 @@ bool WeightDequantizeOpInferSymbolicShape(
       common::errors::InvalidArgument(
           "The x tensor of dequantize op must be 2D, but got[%d]",
           x_shape.size()));
-  int group_size = op->attribute<pir::Int32Attribute>("group_size").data();
-  std::string algo = op->attribute<pir::StrAttribute>("algo").AsString();
+  const int group_size =
+      op->attribute<pir::Int32Attribute>("group_size").data();
+  const std::string algo = op->attribute<pir::StrAttribute>("algo").AsString();
   PADDLE_ENFORCE_EQ(
       (group_size == -1 || group_size == 64 || group_size == 128),
       true,
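
Across the three infer_symbolic_shape files touched here, the C++ changes follow one pattern: each attribute is read once into a const local and then validated. A minimal, self-contained sketch of that pattern follows; FakeOp and its accessors are hypothetical stand-ins for illustration, not the real pir::Operation API.

// Minimal sketch of the "const attribute read + validation" pattern used in
// this commit. FakeOp is a hypothetical stand-in, not the real pir API.
#include <cassert>
#include <iostream>
#include <map>
#include <string>

struct FakeOp {
  std::map<std::string, int> int_attrs;
  std::map<std::string, std::string> str_attrs;
  int int_attr(const std::string &name) const { return int_attrs.at(name); }
  const std::string &str_attr(const std::string &name) const {
    return str_attrs.at(name);
  }
};

bool InferWeightQuantizeLike(const FakeOp &op) {
  // Mirrors: const int group_size =
  //     op->attribute<pir::Int32Attribute>("group_size").data();
  const int group_size = op.int_attr("group_size");
  const std::string algo = op.str_attr("algo");

  // Same check as the PADDLE_ENFORCE_EQ in the diff:
  // group_size must be -1, 64, or 128.
  if (!(group_size == -1 || group_size == 64 || group_size == 128)) {
    std::cerr << "unsupported group_size: " << group_size << "\n";
    return false;
  }
  std::cout << "algo=" << algo << ", group_size=" << group_size << "\n";
  return true;
}

int main() {
  const FakeOp op{{{"group_size", 64}}, {{"algo", "weight_only_int8"}}};
  assert(InferWeightQuantizeLike(op));
  return 0;
}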

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc

Lines changed: 23 additions & 24 deletions
@@ -1676,14 +1676,14 @@ bool FusedAttentionOpInferSymbolicShape(
   symbol::DimExpr dim_head = 0;
   symbol::DimExpr hidden_size = 0;
   symbol::DimExpr nranks = 1;
-  bool transpose_qkv_wb =
+  const bool transpose_qkv_wb =
       op->attribute<pir::BoolAttribute>("transpose_qkv_wb").data();
-  int num_heads_ = op->attribute<pir::Int32Attribute>("num_heads").data();
+  const int num_heads_ = op->attribute<pir::Int32Attribute>("num_heads").data();
   symbol::DimExpr num_heads = symbol::DimExpr(num_heads_);
-  int ring_id = op->attribute<pir::Int32Attribute>("ring_id").data();
-  bool pre_layer_norm =
+  const int ring_id = op->attribute<pir::Int32Attribute>("ring_id").data();
+  const bool pre_layer_norm =
       op->attribute<pir::BoolAttribute>("pre_layer_norm").data();
-  bool is_test = op->attribute<pir::BoolAttribute>("is_test").data();
+  const bool is_test = op->attribute<pir::BoolAttribute>("is_test").data();
   if (transpose_qkv_wb) {
     PADDLE_ENFORCE_EQ(qkv_weight_shape.size(),
                       2,
@@ -1747,7 +1747,7 @@ bool FusedAttentionOpInferSymbolicShape(
           symbol::ShapeOrDataDimExprs{
               symbol::TensorShapeOrDataDimExprs(x_shape)});
     } else {
-      // The follwing three code used to set unoptional output value.
+      // The following three code used to set unoptional output value.
       // Now it's result related to the infermeta.
       infer_context->SetSymbolForValueByStaticShape(op->result(0));
       infer_context->SetSymbolForValueByStaticShape(op->result(1));
@@ -1793,6 +1793,8 @@ bool FusedAttentionOpInferSymbolicShape(
               x_shape[1],
               symbol::DimExpr(3) * num_heads * dim_head})});
     } else {
+      // The following code used to set unoptional output value.
+      // Now it's result related to the infermeta.
       infer_context->SetSymbolForValueByStaticShape(op->result(4));
     }
   } else {
@@ -1815,6 +1817,8 @@ bool FusedAttentionOpInferSymbolicShape(
               num_heads,
               dim_head})});
     } else {
+      // The following code used to set unoptional output value.
+      // Now it's result related to the infermeta.
       infer_context->SetSymbolForValueByStaticShape(op->result(4));
     }
   }
@@ -1863,7 +1867,7 @@ bool FusedAttentionOpInferSymbolicShape(
         symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(
             {x_shape[0], num_heads, x_shape[1], out_seq_len})});
   } else {
-    // The follwing code used to set unoptional output value.
+    // The following code used to set unoptional output value.
     // Now it's result related to the infermeta.
     infer_context->SetSymbolForValueByStaticShape(op->result(11));
   }
@@ -1878,7 +1882,7 @@ bool FusedAttentionOpInferSymbolicShape(
         symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(
             {x_shape[0], num_heads, x_shape[1], out_seq_len})});
   } else {
-    // The follwing code used to set unoptional output value.
+    // The following code used to set unoptional output value.
     // Now it's result related to the infermeta.
     infer_context->SetSymbolForValueByStaticShape(op->result(9));
   }
@@ -1910,7 +1914,7 @@ bool FusedAttentionOpInferSymbolicShape(
         symbol::ShapeOrDataDimExprs{
             symbol::TensorShapeOrDataDimExprs(x_shape)});
   } else {
-    // The follwing code used to set unoptional output value.
+    // The following code used to set unoptional output value.
     // Now it's result related to the infermeta.
     infer_context->SetSymbolForValueByStaticShape(op->result(14));
   }
@@ -3350,9 +3354,8 @@ bool LstmOpInferSymbolicShape(pir::Operation *op,
   const symbol::ShapeOrDataDimExprs &bias_shape_or_data =
       infer_context->GetShapeOrDataForValue(op->operand_source(4));
   const auto &bias_shape = bias_shape_or_data.shape();
-  bool use_peepholes =
+  const bool use_peepholes =
       op->attribute<pir::BoolAttribute>("use_peepholes").data();
-  bool is_test = op->attribute<pir::BoolAttribute>("is_test").data();
   PADDLE_ENFORCE_EQ(
       input_shape.size(),
       2,
@@ -3394,19 +3397,15 @@ bool LstmOpInferSymbolicShape(pir::Operation *op,
       symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_shape)};
   infer_context->SetShapeOrDataForValue(op->result(0), out_shape_or_data);
   infer_context->SetShapeOrDataForValue(op->result(1), out_shape_or_data);
-  if (!is_test) {
-    infer_context->SetShapeOrDataForValue(
-        op->result(2),
-        symbol::ShapeOrDataDimExprs{
-            symbol::TensorShapeOrDataDimExprs(input_shape)});
-    infer_context->SetShapeOrDataForValue(op->result(3), out_shape_or_data);
-  } else {
-    infer_context->SetShapeOrDataForValue(
-        op->result(2),
-        symbol::ShapeOrDataDimExprs{
-            symbol::TensorShapeOrDataDimExprs(input_shape)});
-    infer_context->SetShapeOrDataForValue(op->result(3), out_shape_or_data);
-  }
+
+  // Based on the kernel and infermeta, the inferred results are the same
+  // regardless of whether is_test is true or false.
+  infer_context->SetShapeOrDataForValue(
+      op->result(2),
+      symbol::ShapeOrDataDimExprs{
+          symbol::TensorShapeOrDataDimExprs(input_shape)});
+  infer_context->SetShapeOrDataForValue(op->result(3), out_shape_or_data);
+
   return true;
 }
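
In the LstmOp hunk above, the removed if (!is_test) branch had two identical arms, so the conditional carried no information; the new comment records that the inference does not depend on is_test. A small, compilable illustration of that simplification, with stand-in types and only suggestive output names (not the Paddle source):

// Toy illustration of the LstmOp cleanup: when both arms of a conditional
// perform exactly the same assignments, the branch can be removed without
// changing behavior. Shape/Results are stand-ins, not pir/Paddle types.
#include <cassert>
#include <vector>

using Shape = std::vector<int>;

struct Results {
  Shape result2;  // e.g. the shape set for op->result(2)
  Shape result3;  // e.g. the shape set for op->result(3)
};

// Before: is_test selected between two identical arms.
Results InferBefore(bool is_test, const Shape &input_shape,
                    const Shape &out_shape) {
  Results r;
  if (!is_test) {
    r.result2 = input_shape;
    r.result3 = out_shape;
  } else {
    r.result2 = input_shape;
    r.result3 = out_shape;
  }
  return r;
}

// After: the inferred shapes do not depend on is_test, so the flag (and the
// branch) are gone.
Results InferAfter(const Shape &input_shape, const Shape &out_shape) {
  return Results{input_shape, out_shape};
}

int main() {
  const Shape input{64, 28};
  const Shape out{64, 7};
  for (bool is_test : {false, true}) {
    const Results before = InferBefore(is_test, input, out);
    const Results after = InferAfter(input, out);
    assert(before.result2 == after.result2);
    assert(before.result3 == after.result3);
  }
  return 0;
}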

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc

Lines changed: 3 additions & 2 deletions
@@ -4454,8 +4454,9 @@ bool WeightQuantizeOpInferSymbolicShape(
                         x_shape_1));
   }
 
-  int group_size = op->attribute<pir::Int32Attribute>("group_size").data();
-  std::string algo = op->attribute<pir::StrAttribute>("algo").AsString();
+  const int group_size =
+      op->attribute<pir::Int32Attribute>("group_size").data();
+  const std::string algo = op->attribute<pir::StrAttribute>("algo").AsString();
   PADDLE_ENFORCE_EQ(
       ((group_size == -1) || (group_size == 64) || (group_size == 128)),
       true,

test/legacy_test/test_fused_attention_op_api.py

Lines changed: 0 additions & 174 deletions
@@ -20,24 +20,6 @@
 from paddle.incubate.nn.layer.fused_transformer import FusedMultiHeadAttention
 
 
-def check_symbolic_result(program, fetch_vars, outs, op_type):
-    if paddle.base.libpaddle.pir.all_ops_defined_symbol_infer(program):
-        shape_analysis = (
-            paddle.base.libpaddle.pir.get_shape_constraint_ir_analysis(program)
-        )
-        for i, var in enumerate(fetch_vars):
-            if var.is_dense_tensor_type() or var.is_selected_row_type():
-                shape_or_data = shape_analysis.get_shape_or_data_for_var(var)
-                expect_shape = outs[i].shape
-                expect_data = []
-                if not shape_or_data.is_equal(expect_shape, expect_data):
-                    raise AssertionError(
-                        f"The shape or data of Operator {op_type}'s result is different from expected."
-                    )
-            else:
-                pass
-
-
 def fc(x, weight):
     return np.matmul(x, weight)

@@ -425,20 +407,6 @@ def run_static(self):
                         fused_attn.pre_ln_scale,
                     ],
                 )
-                fetch_list = exe._check_fetch_list(
-                    [
-                        final_out,
-                        fused_attn.qkv_weight,
-                        fused_attn.linear_weight,
-                        fused_attn.pre_ln_scale,
-                    ]
-                )
-                check_symbolic_result(
-                    paddle.static.default_main_program(),
-                    fetch_list,
-                    [out, qkv_weight, out_linear_weight, ln_scale],
-                    'fused_attention',
-                )
             else:
                 out, qkv_weight, out_linear_weight, ln_2_scale = exe.run(
                     paddle.static.default_main_program(),
@@ -450,20 +418,6 @@ def run_static(self):
                         fused_attn.ln_scale,
                     ],
                 )
-                fetch_list = exe._check_fetch_list(
-                    [
-                        final_out,
-                        fused_attn.qkv_weight,
-                        fused_attn.linear_weight,
-                        fused_attn.ln_scale,
-                    ]
-                )
-                check_symbolic_result(
-                    paddle.static.default_main_program(),
-                    fetch_list,
-                    [out, qkv_weight, out_linear_weight, ln_2_scale],
-                    'fused_attention',
-                )
         else:
             if self.pre_layer_norm:
                 (
@@ -487,31 +441,6 @@ def run_static(self):
                         fused_attn.pre_ln_bias,
                     ],
                 )
-                fetch_list = exe._check_fetch_list(
-                    [
-                        final_out,
-                        fused_attn.qkv_weight,
-                        fused_attn.qkv_bias,
-                        fused_attn.linear_weight,
-                        fused_attn.linear_bias,
-                        fused_attn.pre_ln_scale,
-                        fused_attn.pre_ln_bias,
-                    ]
-                )
-                check_symbolic_result(
-                    paddle.static.default_main_program(),
-                    fetch_list,
-                    [
-                        out,
-                        qkv_weight,
-                        qkv_bias,
-                        out_linear_weight,
-                        linear_bias,
-                        ln_scale,
-                        ln_bias,
-                    ],
-                    'fused_attention',
-                )
             else:
                 (
                     out,
@@ -534,31 +463,6 @@ def run_static(self):
                         fused_attn.ln_bias,
                     ],
                 )
-                fetch_list = exe._check_fetch_list(
-                    [
-                        final_out,
-                        fused_attn.qkv_weight,
-                        fused_attn.qkv_bias,
-                        fused_attn.linear_weight,
-                        fused_attn.linear_bias,
-                        fused_attn.ln_scale,
-                        fused_attn.ln_bias,
-                    ]
-                )
-                check_symbolic_result(
-                    paddle.static.default_main_program(),
-                    fetch_list,
-                    [
-                        out,
-                        qkv_weight,
-                        qkv_bias,
-                        out_linear_weight,
-                        linear_bias,
-                        ln_2_scale,
-                        ln_2_bias,
-                    ],
-                    'fused_attention',
-                )
     else:
         if self.bias_attr is False:
             if self.pre_layer_norm:
@@ -574,20 +478,6 @@ def run_static(self):
                         fused_attn.pre_ln_scale,
                     ],
                 )
-                fetch_list = exe._check_fetch_list(
-                    [
-                        final_out,
-                        fused_attn.qkv_weight,
-                        fused_attn.linear_weight,
-                        fused_attn.pre_ln_scale,
-                    ]
-                )
-                check_symbolic_result(
-                    paddle.static.default_main_program(),
-                    fetch_list,
-                    [out, qkv_weight, out_linear_weight, ln_scale],
-                    'fused_attention',
-                )
             else:
                 out, qkv_weight, out_linear_weight, ln_2_scale = exe.run(
                     paddle.static.default_main_program(),
@@ -601,20 +491,6 @@ def run_static(self):
                         fused_attn.ln_scale,
                     ],
                 )
-                fetch_list = exe._check_fetch_list(
-                    [
-                        final_out,
-                        fused_attn.qkv_weight,
-                        fused_attn.linear_weight,
-                        fused_attn.ln_scale,
-                    ]
-                )
-                check_symbolic_result(
-                    paddle.static.default_main_program(),
-                    fetch_list,
-                    [out, qkv_weight, out_linear_weight, ln_2_scale],
-                    'fused_attention',
-                )
         else:
             if self.pre_layer_norm:
                 (
@@ -640,31 +516,6 @@ def run_static(self):
                         fused_attn.pre_ln_bias,
                     ],
                 )
-                fetch_list = exe._check_fetch_list(
-                    [
-                        final_out,
-                        fused_attn.qkv_weight,
-                        fused_attn.qkv_bias,
-                        fused_attn.linear_weight,
-                        fused_attn.linear_bias,
-                        fused_attn.pre_ln_scale,
-                        fused_attn.pre_ln_bias,
-                    ]
-                )
-                check_symbolic_result(
-                    paddle.static.default_main_program(),
-                    fetch_list,
-                    [
-                        out,
-                        qkv_weight,
-                        qkv_bias,
-                        out_linear_weight,
-                        linear_bias,
-                        ln_scale,
-                        ln_bias,
-                    ],
-                    'fused_attention',
-                )
             else:
                 (
                     out,
@@ -689,31 +540,6 @@ def run_static(self):
                         fused_attn.ln_bias,
                     ],
                 )
-                fetch_list = exe._check_fetch_list(
-                    [
-                        final_out,
-                        fused_attn.qkv_weight,
-                        fused_attn.qkv_bias,
-                        fused_attn.linear_weight,
-                        fused_attn.linear_bias,
-                        fused_attn.ln_scale,
-                        fused_attn.ln_bias,
-                    ]
-                )
-                check_symbolic_result(
-                    paddle.static.default_main_program(),
-                    fetch_list,
-                    [
-                        out,
-                        qkv_weight,
-                        qkv_bias,
-                        out_linear_weight,
-                        linear_bias,
-                        ln_2_scale,
-                        ln_2_bias,
-                    ],
-                    'fused_attention',
-                )
         return (
             out,
             qkv_weight,
