Commit 0becc4a

yinfan98 authored and committed
Add int4 quantize kernel
Squashed changes included in this commit:

* add int4_1 int4_2
* FLAGS_logging_pir_py_code and FLAGS_logging_pir_py_code_dir (PaddlePaddle#63981) (Co-authored-by: jiahy0825 <jiahongyu@baidu.com>)
* [Cleanup] Remove Flake8 config in `.editorconfig` (PaddlePaddle#64027)
* 【PIR Dist Op Reg No.19】 reg pull_box_sparse (PaddlePaddle#62982)
* [Dy2St][PIR] Hold backward program in GradNode (PaddlePaddle#63694) (Co-authored-by: xiongkun <xiongkun03@baidu.com>, Nyakku Shigure <sigure.qaq@gmail.com>)
* split test.cmake: add new test_cases.cmake (PaddlePaddle#64007)
* [PIR] Support sparse_slice and sparse_sum in pt (PaddlePaddle#64009)
* option for WITH_CPP_TEST (PaddlePaddle#63896)
* [PIR] Fix `attributes_num` of `SliceArrayOp` (PaddlePaddle#64013)
* [Dy2St] Use `full_graph=True` outside dy2st uts (part1) (PaddlePaddle#64058)
* [Dy2St] Use `full_graph=True` outside dy2st uts (part2) (PaddlePaddle#64059)
* fix typo (PaddlePaddle#64060) (Co-authored-by: jiahy0825 <jiahongyu@baidu.com>)
* update (PaddlePaddle#64042)
* Replace paddle/fluid/platform/device/gpu/gpu_dnn.h (PaddlePaddle#63819)
* Clean lookup_table_v2_op.h lookup_table_v2_op.cu (PaddlePaddle#64020)
* refine GetTensorListFromArgs (PaddlePaddle#64045)
* Revert "【Hackathon 6th Fundable Projects 3 No.60】Remove fluid operator chunk_…" (PaddlePaddle#64050); this reverts commit 88b1a6e
* [Prim][PIR] support floor_divide op forward in prim pir (PaddlePaddle#64023)
* [CINN] Reconstruct shape_analysis (PaddlePaddle#63790)
* [XPU] fix bugs in processing of attention_mask and fix_seed_offset on XPU (PaddlePaddle#64003)
* fix merge bug (PaddlePaddle#64069)
* 【Fix PIR Unittest No.125、147、481】Fix some 0D uts in PIR mode (part1) (PaddlePaddle#64064)
* [Prim][VJP] support autogen to remove unused composite in .yaml (PaddlePaddle#64054)
* [PIR] Fix typo `set_pit_tests_properties` -> `set_pir_tests_properties` (PaddlePaddle#64063)
* [Dy2St] Use `full_graph=True` outside dy2st uts (part3) (PaddlePaddle#64066)
* [PIR save/load] Open more tests for paddle.save and paddle.load (PaddlePaddle#64044)
* API Improvement for paddle.nn.functional.group_norm and paddle.nn.GroupNorm (PaddlePaddle#63881)
* Revert "【Hackathon 6th Fundable Projects 3 No.81】Remove fluid operators ctc_a…" (PaddlePaddle#64049); this reverts commit 2134ead
* Clean paddle/fluid/operators/fused/attention_layer_norm.h (PaddlePaddle#64051)
* Replace operators::math to phi::math in fluid/operators (PaddlePaddle#63854)
* [CINN] Clean usless loop_reorder_aligment tactic (PaddlePaddle#63998)
* 【Hackathon 6th Fundable Projects 3 No.396】fluid operator yolo_box_head (PaddlePaddle#63783)
* 【Hackathon 6th Fundable Projects 3 No.240】fluid operator moe (PaddlePaddle#63929)
* 【Hackathon 6th Fundable Projects 3 No.82】fluid operator cudnn_lstm (PaddlePaddle#63936)
* [CINN] Remove useless log (PaddlePaddle#64052)
* [pir_save_load] add pir for test_jit_save_load.py (PaddlePaddle#63958)
* Flashattention support qkvpacked and varlen (PaddlePaddle#63289) (Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>)
* 【PIR Dist Op Reg No.20】 reg global_gather (PaddlePaddle#63867)
* Fix backward program kwargs error when process inplace value (PaddlePaddle#63939)
* 【Hackathon 6th No.35】support kwargs for recompute when use_reentrant == True (PaddlePaddle#63880)
* merge main
* lint
* delete printf
* change flash attn version
1 parent eb7c5d1 commit 0becc4a
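
The headline change is the int4 quantize kernel named in the commit title (with the int4_1/int4_2 variants mentioned above); its source is not among the diffs shown on this page. As a rough, illustrative sketch of what symmetric int4 quantization typically involves (not Paddle's actual kernel or API), values are scaled into the signed 4-bit range and two of them are packed per byte:

// Illustrative sketch only, NOT the kernel added by this commit.
// Symmetric int4 quantization: scale floats into [-8, 7], then pack two
// 4-bit values per byte (low nibble = even index, high nibble = odd index).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<uint8_t> QuantizeInt4(const std::vector<float>& x, float scale) {
  std::vector<uint8_t> packed((x.size() + 1) / 2, 0);
  for (size_t i = 0; i < x.size(); ++i) {
    int q = static_cast<int>(std::lround(x[i] / scale));
    q = std::max(-8, std::min(7, q));  // clamp to the signed 4-bit range
    uint8_t nibble = static_cast<uint8_t>(q) & 0xF;
    packed[i / 2] |= (i % 2 == 0) ? nibble : static_cast<uint8_t>(nibble << 4);
  }
  return packed;
}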

File tree

264 files changed, +7793 −5088 lines

.editorconfig

+1-4
@@ -15,15 +15,12 @@ insert_final_newline = true
 [*.{c,cc,cxx,cpp,cu,cuh,h,hpp,hxx,kps}]
 indent_size = 2

-[*.{py,java,r}]
+[*.{py,pyi,java,r,toml}]
 indent_size = 4

 [Dockerfile.*]
 indent_size = 4

-[.flake8]
-indent_size = 4
-
 [*.go]
 indent_style = tab
 indent_size = 4

CMakeLists.txt

+1
@@ -68,6 +68,7 @@ option(WITH_PIP_CUDA_LIBRARIES
        "Paddle uses the CUDA library provided by NVIDIA" OFF)
 option(WITH_NIGHTLY_BUILD
        "Compile nightly paddle whl package of the develop branch" OFF)
+option(WITH_CPP_TEST "Compile PaddlePaddle skip cpp test" ON)
 find_package(Git REQUIRED)

 # config GIT_URL with github mirrors to speed up dependent repos clone
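
The new WITH_CPP_TEST switch corresponds to the "option for WITH_CPP_TEST (PaddlePaddle#63896)" entry in the commit message. Presumably it gates compilation of Paddle's C++ test targets; like any CMake option it can be overridden at configure time, e.g. with -DWITH_CPP_TEST=OFF.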

paddle/cinn/hlir/dialect/operator/ir/manual_op.cc

+12-12
@@ -116,14 +116,14 @@ void GroupOp::Print(pir::IrPrinter& printer) {
 }

 bool GroupOp::InferSymbolicShape(
-    ::pir::ShapeConstraintIRAnalysis* shape_analysis) {
-  ::pir::InferSymExprForBlock(*block(), shape_analysis);
+    ::pir::InferSymbolicShapeContext* infer_context) {
+  ::pir::InferSymExprForBlock(*block(), infer_context);

   for (uint32_t rst_idx = 0; rst_idx < num_results(); rst_idx++) {
     auto inner_yield_value = block()->back().operand_source(rst_idx);
     const auto& shape =
-        shape_analysis->GetShapeOrDataForValue(inner_yield_value);
-    shape_analysis->SetShapeOrDataForValue(result(rst_idx), shape);
+        infer_context->GetShapeOrDataForValue(inner_yield_value);
+    infer_context->SetShapeOrDataForValue(result(rst_idx), shape);
   }

   if (VLOG_IS_ON(4)) {
@@ -204,16 +204,16 @@ void YieldStoreOp::Build(pir::Builder& builder,
 void YieldStoreOp::VerifySig() {}

 bool YieldStoreOp::InferSymbolicShape(
-    pir::ShapeConstraintIRAnalysis* shape_analysis) {
-  shape_analysis->SetShapeOrDataForValue(
-      result(0), shape_analysis->GetShapeOrDataForValue(operand_source(0)));
+    pir::InferSymbolicShapeContext* infer_context) {
+  infer_context->SetShapeOrDataForValue(
+      result(0), infer_context->GetShapeOrDataForValue(operand_source(0)));
   return true;
 }

 bool ConcatOp::InferSymbolicShape(
-    pir::ShapeConstraintIRAnalysis* shape_analysis) {
+    pir::InferSymbolicShapeContext* infer_context) {
   VLOG(4) << "Infer symbolic shape for cinn_op.concat";
-  return ConcatOpInferSymbolicShape(this->operation(), shape_analysis);
+  return ConcatOpInferSymbolicShape(this->operation(), infer_context);
 }

 void ConcatOp::Build(pir::Builder& builder,  // NOLINT
@@ -476,7 +476,7 @@ GenerateShapeOp::ConvertAttributeToSymbolBindings(
 }

 bool GenerateShapeOp::InferSymbolicShape(
-    pir::ShapeConstraintIRAnalysis* shape_analysis) {
+    pir::InferSymbolicShapeContext* infer_context) {
   const auto attr_dim_exprs = [&] {
     std::vector<symbol::DimExpr> dim_exprs{};
     pir::Attribute dim_expr_attr = this->attributes().at("output_dim_exprs");
@@ -505,7 +505,7 @@ bool GenerateShapeOp::InferSymbolicShape(
   }();
   auto DimExprs4InputDim =
       [&](int input_idx) -> const symbol::ShapeOrDataDimExprs& {
-    return shape_analysis->GetShapeOrDataForValue(
+    return infer_context->GetShapeOrDataForValue(
         this->operand_source(input_idx));
   };
   auto DimExprs4SymbolName =
@@ -527,7 +527,7 @@ bool GenerateShapeOp::InferSymbolicShape(
   symbol::ShapeOrDataDimExprs shape_or_data_dim_exprs{
       symbol::TensorShapeOrDataDimExprs(shape, substituted_dim_exprs)};

-  shape_analysis->SetShapeOrDataForValue(this->out(), shape_or_data_dim_exprs);
+  infer_context->SetShapeOrDataForValue(this->out(), shape_or_data_dim_exprs);

   return true;
 }
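
The pattern repeated throughout this file (and in the header below) is that InferSymbolicShape no longer receives the whole ShapeConstraintIRAnalysis; it receives an InferSymbolicShapeContext and reads and writes per-value symbolic shapes through GetShapeOrDataForValue / SetShapeOrDataForValue. A minimal, self-contained model of that flow, with simplified stand-in types rather than the real pir classes, looks like this:

// Simplified model of context-based symbolic shape propagation. The real
// pir::InferSymbolicShapeContext and symbol::ShapeOrDataDimExprs are much
// richer; this only illustrates the read/propagate/write pattern.
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

using SymShape = std::vector<std::string>;  // e.g. {"S0", "128"}

class InferContext {
 public:
  const SymShape& GetShapeOrDataForValue(int value_id) const {
    return shapes_.at(value_id);  // assumed to exist after the reconstruct
  }
  void SetShapeOrDataForValue(int value_id, const SymShape& shape) {
    shapes_[value_id] = shape;
  }

 private:
  std::unordered_map<int, SymShape> shapes_;
};

// Mirrors YieldStoreOp::InferSymbolicShape: output shape == input shape.
bool InferYieldStoreShape(InferContext* ctx, int in_id, int out_id) {
  ctx->SetShapeOrDataForValue(out_id, ctx->GetShapeOrDataForValue(in_id));
  return true;
}

int main() {
  InferContext ctx;
  ctx.SetShapeOrDataForValue(0, {"S0", "128"});
  InferYieldStoreShape(&ctx, /*in_id=*/0, /*out_id=*/1);
  assert(ctx.GetShapeOrDataForValue(1).size() == 2);
  return 0;
}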

paddle/cinn/hlir/dialect/operator/ir/manual_op.h

+4-4
@@ -53,7 +53,7 @@ class IR_API GroupOp
   pir::Block *block() const;
   std::vector<pir::Operation *> GetOperators() const;

-  bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
+  bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);

   void VerifySig();
   void Print(pir::IrPrinter &printer);  // NOLINT
@@ -102,7 +102,7 @@ class IR_API YieldStoreOp

   void VerifySig();

-  bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
+  bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
 };

 class IR_API ConcatOp
@@ -123,7 +123,7 @@ class IR_API ConcatOp

   void VerifySig() const {}

-  bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
+  bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
 };

 class IR_API SplitOp : public pir::Op<SplitOp> {
@@ -177,7 +177,7 @@ class IR_API GenerateShapeOp

   pir::Value out() { return result(0); }

-  bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
+  bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);

   static pir::Attribute ConvertSymbolBindingsToAttribute(
       pir::Builder &builder, const SymbolBindings &symbol_bindings);  // NOLINT

paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc

+2-6
@@ -220,14 +220,10 @@ void ApplyCinnPass(::pir::Program* program,
   ApplyPdToCinnPass(program, CreatePassManager);
   ApplyCinnPreprocessPass(program, CreatePassManager);
   ApplyBuildGroupOpPass(program, CreatePassManager);
-  LOG(INFO) << "====[pir-to-py-code group-ops begin]===" << std::endl
-            << PirToPyCodeConverter().Convert(*program);
-  LOG(INFO) << "====[pir-to-py-code group-ops end]===";
+  PirToPyCodeConverter().SaveIfFlagEnabled("group_op_programs", *program);
   ApplyGroupOpPass(program, CreatePassManager);
   ApplyDivideGroupOpToFusionOpPass(program, CreatePassManager);
-  LOG(INFO) << "====[pir-to-py-code fusion-ops begin]===" << std::endl
-            << PirToPyCodeConverter().Convert(*program);
-  LOG(INFO) << "====[pir-to-py-code fusion-ops end]===";
+  PirToPyCodeConverter().SaveIfFlagEnabled("fusion_op_programs", *program);
   LOG(INFO) << "FusionOp count before lowering : *****[ "
             << GetOpCount<cinn::dialect::FusionOp>(program->module_op())
             << " ]*****";

paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc

+3-5
@@ -45,11 +45,9 @@ class AddYieldStoreInFusionOpPattern
       auto orignal_base = op->operand_source(i);
       op->operand(i).set_source(store_op.result(0));

-      if (shape_analysis.HasShapeOrDataForValue(orignal_base)) {
-        shape_analysis.SetShapeOrDataForValue(
-            store_op.result(0),
-            shape_analysis.GetShapeOrDataForValue(orignal_base));
-      }
+      shape_analysis.SetShapeOrDataForValue(
+          store_op.result(0),
+          shape_analysis.GetShapeOrDataForValue(orignal_base));
     }

     return true;
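
This pattern repeats in the passes below (cinn_group_cluster_pass, dynamic_reshape_pass, fold_manipulation_ops_pass, the convert_*_dim passes, and simplify_dim_expr_pass): the HasShapeOrDataForValue guards and CHECKs are dropped and GetShapeOrDataForValue is called unconditionally. Presumably the reconstructed shape analysis (PaddlePaddle#63790) now guarantees a ShapeOrData entry for every value these passes touch, so the defensive branches became dead code.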

paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_pass.cc

+4-1
@@ -144,7 +144,10 @@ class BlockDimExprsAsserter {
       PADDLE_THROW(phi::errors::Unimplemented(
           op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
     } else {
-      bool infer_result = interface.InferSymbolicShape(shape_analysis.get());
+      // TODO(Hongqing-work): delete this after the shape analysis reconstruct
+      // is done.
+      bool infer_result = interface.InferSymbolicShape(
+          shape_analysis->GetInferSymbolicShapeContext());
       PADDLE_ENFORCE_EQ(infer_result,
                         true,
                         ::common::errors::PreconditionNotMet(
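
Both this call site and the one in fuse_shape_ops_into_generate_shape_op_pass.cc below bridge the old ShapeConstraintIRAnalysis to the new interface through GetInferSymbolicShapeContext(); the TODO marks the adapter as temporary, to be removed once the shape-analysis reconstruction is finished.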

paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc

+6-10
@@ -182,11 +182,9 @@ ::pir::GroupOpsVec CloneOps(
       pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());

   for (size_t i = 0; i < op->num_results(); ++i) {
-    if (shape_analysis.HasShapeOrDataForValue(op->result(i))) {
-      shape_analysis.SetShapeOrDataForValue(
-          new_op->result(i),
-          shape_analysis.GetShapeOrDataForValue(op->result(i)));
-    }
+    shape_analysis.SetShapeOrDataForValue(
+        new_op->result(i),
+        shape_analysis.GetShapeOrDataForValue(op->result(i)));
   }

   vec_new_op_list.push_back(new_op);
@@ -357,11 +355,9 @@ class CinnGroupClusterPattern
       // update ir mapping
       for (size_t i = 0; i < output_values.size(); ++i) {
         ir_mapping.Add(output_values[i], new_group_op->result(i));
-        if (shape_analysis.HasShapeOrDataForValue(output_values[i])) {
-          shape_analysis.SetShapeOrDataForValue(
-              new_group_op->result(i),
-              shape_analysis.GetShapeOrDataForValue(output_values[i]));
-        }
+        shape_analysis.SetShapeOrDataForValue(
+            new_group_op->result(i),
+            shape_analysis.GetShapeOrDataForValue(output_values[i]));
       }
       for (size_t i = 0; i < output_values.size(); ++i) {
         auto find_it = all_output_values.find(output_values[i]);

paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.cc

+10-12
@@ -33,18 +33,16 @@ bool ReplaceOpWithReshapeOp(pir::Operation* op,
     std::vector<int> shape = phi::vectorize<int>(
         output.type().dyn_cast<pir::DenseTensorType>().dims());

-    if (shape_analysis->HasShapeOrDataForValue(op->result(0))) {
-      const auto& shape_info =
-          shape_analysis->GetShapeOrDataForValue(op->result(0)).shape();
-      int temp_dim = -1;
-
-      for (size_t i = 0; i < shape_info.size(); ++i) {
-        if (shape_info[i].isa<int64_t>()) {
-          shape[i] = shape_info[i].Get<int64_t>();
-        } else {
-          shape[i] = temp_dim;
-          temp_dim = 1;
-        }
+    const auto& shape_info =
+        shape_analysis->GetShapeOrDataForValue(op->result(0)).shape();
+    int temp_dim = -1;
+
+    for (size_t i = 0; i < shape_info.size(); ++i) {
+      if (shape_info[i].isa<int64_t>()) {
+        shape[i] = shape_info[i].Get<int64_t>();
+      } else {
+        shape[i] = temp_dim;
+        temp_dim = 1;
       }
     }
     return shape;
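
Note that the now-unguarded body keeps the original fallback for symbolic dimensions: the first non-constant dimension becomes -1 (left for reshape to infer) and any later non-constant dimensions fall back to 1, presumably because a reshape target shape can contain at most one -1.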

paddle/cinn/hlir/dialect/operator/transforms/fold_manipulation_ops_pass.cc

+5-9
@@ -53,15 +53,11 @@ bool RemoveOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
     if (has_dynamic_shape) {
       auto& shape_analysis =
           pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
-      if (shape_analysis.HasShapeOrDataForValue(input) &&
-          shape_analysis.HasShapeOrDataForValue(output)) {
-        auto input_sym_shape =
-            GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(input));
-        auto output_sym_shape =
-            GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(output));
-        return input_sym_shape == output_sym_shape;
-      }
-      return false;
+      auto input_sym_shape =
+          GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(input));
+      auto output_sym_shape =
+          GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(output));
+      return input_sym_shape == output_sym_shape;
     }
     return GetDims(input) == GetDims(output);
   };

paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc

+4-2
@@ -214,7 +214,10 @@ void InferSymbolicShapeForSubgraph(
     auto infer_symbolic_shape_interface =
         op->dyn_cast<paddle::dialect::InferSymbolicShapeInterface>();
     if (infer_symbolic_shape_interface) {
-      infer_symbolic_shape_interface.InferSymbolicShape(shape_analysis);
+      // TODO(Hongqing-work): delete this after the shape analysis reconstruct
+      // is done.
+      infer_symbolic_shape_interface.InferSymbolicShape(
+          shape_analysis->GetInferSymbolicShapeContext());
     } else {
       PADDLE_THROW(phi::errors::Unimplemented(
           op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
@@ -348,7 +351,6 @@ bool ReplaceShapeOpsToGenerateShape(
   auto ShapeOrDataDimExprs4Value =
       [&shape_analysis](
           pir::Value value) -> const symbol::ShapeOrDataDimExprs& {
-    CHECK(shape_analysis->HasShapeOrDataForValue(value));
     return shape_analysis->GetShapeOrDataForValue(value);
   };
   std::optional<pir::Value> opt_generated_shape =

paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_dynamic_to_static_dim_pass.cc

-14
@@ -104,9 +104,6 @@ class DynamicToStaticConverter {
   }

   bool Convert() {
-    if (!IsSymbolFullyInfered()) {
-      return false;
-    }
     bool updated = false;
     VisitEachValue(fusion_op_, [&](pir::Value value) {
       updated |= UpdateValueShape(value);
@@ -116,16 +113,6 @@ class DynamicToStaticConverter {
   }

  private:
-  bool IsSymbolFullyInfered() {
-    bool is_infered = true;
-    VisitEachValue(fusion_op_, [&](pir::Value value) {
-      if (!shape_analysis_->HasShapeOrDataForValue(value)) {
-        is_infered = false;
-      }
-    });
-    return is_infered;
-  }
-
   DimExpr4SymbolName InitDimExpr4SymbolName() {
     const auto* map = GetGlobalDynamicToStaticDimMap();
     CHECK(map->has_value());
@@ -178,7 +165,6 @@ class DynamicToStaticConverter {

   bool UpdateValueShape(pir::Value value) {
     bool update = false;
-    CHECK(shape_analysis_->HasShapeOrDataForValue(value));
     const auto& origin_shape = GetOriginValueShape(value);
     const auto& target_shape = GetTargetValueShape(value);
     PADDLE_ENFORCE_EQ(

paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.cc

+2-21
@@ -150,7 +150,6 @@ struct StaticDimToDynamicConverter {
         &pir::ShapeAnalysisManager::Instance().Get(
             this->fusion_op->GetParentProgram());
     ForEachValue([&](pir::Value value) {
-      CHECK(shape_analysis->HasShapeOrDataForValue(value));
       const auto& origin_shape = GetOriginValueShape(value);
       const auto& target_shape = GetTargetValueShape(
           shape_analysis->GetShapeOrDataForValue(value).shape());
@@ -369,26 +368,8 @@ struct StaticDimToDynamicConverter {
       pir::Value value,
       int64_t constant,
       const std::string& symbol) {
-    if (shape_analysis->HasShapeOrDataForValue(value)) {
-      const auto& old = shape_analysis->GetShapeOrDataForValue(value).shape();
-      return ConvertShapeOrDataDimExprs(Converter, old, constant, symbol);
-    } else {
-      auto& dims = value.type().dyn_cast<::pir::DenseTensorType>().dims();
-      const auto& int_dims = ::common::vectorize<int>(dims);
-      std::vector<symbol::DimExpr> old{};
-      for (int dim : int_dims) {
-        old.emplace_back(static_cast<std::int64_t>(dim));
-      }
-      const auto& opt_exprs =
-          ConvertShapeOrDataDimExprs(Converter, old, constant, symbol);
-      if (opt_exprs.has_value()) {
-        return opt_exprs.value();
-      } else {
-        return symbol::ShapeOrDataDimExprs{
-            symbol::TensorShapeOrDataDimExprs(old)};
-      }
-    }
-    PADDLE_THROW(phi::errors::Fatal("Dead code"));
+    const auto& old = shape_analysis->GetShapeOrDataForValue(value).shape();
+    return ConvertShapeOrDataDimExprs(Converter, old, constant, symbol);
   }

   template <typename ConverterT>

paddle/cinn/hlir/dialect/operator/transforms/group_merge/simplify_dim_expr_pass.cc

+8-14
@@ -101,20 +101,14 @@ void SimplifyDimExpr(pir::Operation* module_op) {

   VisitEachOp(module_op, [&](pir::Operation& op) {
     VisitEachValue(op, [&](pir::Value value) {
-      if (!shape_analysis->HasShapeOrDataForValue(value)) {
-        VLOG(4) << "SimplifyDimExpr: shape_analysis can't find ShapeOrData for "
-                   "value of the op:"
-                << op.name();
-      } else {
-        const symbol::ShapeOrDataDimExprs& shape_or_data =
-            shape_analysis->GetShapeOrDataForValue(value);
-        VLOG(8) << op.name() << " origin_shape_or_data: " << shape_or_data;
-        symbol::ShapeOrDataDimExprs simplified_shape_or_data =
-            SimplifyShapeOrData(shape_or_data);
-        VLOG(8) << op.name()
-                << " simplified_shape_or_data: " << simplified_shape_or_data;
-        shape_analysis->SetShapeOrDataForValue(value, simplified_shape_or_data);
-      }
+      const symbol::ShapeOrDataDimExprs& shape_or_data =
+          shape_analysis->GetShapeOrDataForValue(value);
+      VLOG(8) << op.name() << " origin_shape_or_data: " << shape_or_data;
+      symbol::ShapeOrDataDimExprs simplified_shape_or_data =
+          SimplifyShapeOrData(shape_or_data);
+      VLOG(8) << op.name()
+              << " simplified_shape_or_data: " << simplified_shape_or_data;
+      shape_analysis->SetShapeOrDataForValue(value, simplified_shape_or_data);
     });
     if (op.num_results() > 0) {
       pir::shape::SetShapeAttrForOp(
