Skip to content

Commit 7bbd2e0

Browse files
authored
Merge branch 'develop' into fix_randomness_in_fusion
2 parents d7370f6 + c3ba53c commit 7bbd2e0

File tree

74 files changed

+886
-383
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

74 files changed

+886
-383
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ repos:
5353
rev: v1.27.3
5454
hooks:
5555
- id: typos
56-
args: []
56+
args: [--force-exclude]
5757
# For Python files
5858
- repo: https://github.com/psf/black-pre-commit-mirror
5959
rev: 24.8.0

_typos.toml

+6-7
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
[files]
2+
# The following files will be excluded from spell check during commits
3+
extend-exclude = [
4+
"test/dataset/imikolov_test.py"
5+
]
6+
17
[default.extend-words]
28
# PaddlePaddle specific words
39
lod = "lod"
@@ -71,7 +77,6 @@ Prepar = 'Prepar'
7177
precent = 'precent'
7278
sheduler = 'sheduler'
7379
outpus = 'outpus'
74-
atrribute = 'atrribute'
7580
normlize = 'normlize'
7681
Costum = 'Costum'
7782
differnt = 'differnt'
@@ -154,7 +159,6 @@ CACH = 'CACH'
154159
endianess = 'endianess'
155160
VAILD = 'VAILD'
156161
ues = 'ues'
157-
aer = 'aer'
158162
elemenents = 'elemenents'
159163
CANN = 'CANN'
160164
pathes = 'pathes'
@@ -293,10 +297,8 @@ wiil = 'wiil'
293297
configurated = 'configurated'
294298
perfome = 'perfome'
295299
consructor = 'consructor'
296-
attribtue = 'attribtue'
297300
quitted = 'quitted'
298301
attribtes = 'attribtes'
299-
automatical = 'automatical'
300302
orignal = 'orignal'
301303
furture = 'furture'
302304
Indext = 'Indext'
@@ -467,7 +469,6 @@ channnel = 'channnel'
467469
Suger = 'Suger'
468470
Actural = 'Actural'
469471
subsituted = 'subsituted'
470-
automaticly = 'automaticly'
471472
Minium = 'Minium'
472473
sequnece = 'sequnece'
473474
payed = 'payed'
@@ -615,7 +616,6 @@ theads = 'theads'
615616
postive = 'postive'
616617
progrss = 'progrss'
617618
diffrent = 'diffrent'
618-
attritube = 'attritube'
619619
compability = 'compability'
620620
hge = 'hge'
621621
Funcion = 'Funcion'
@@ -772,7 +772,6 @@ Localy = 'Localy'
772772
PARM = 'PARM'
773773
thi = 'thi'
774774
Oll = 'Oll'
775-
Auxillary = 'Auxillary'
776775
Infor = 'Infor'
777776
statment = 'statment'
778777
varn = 'varn'

cmake/external/pybind11.cmake

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ set(PYBIND_PATCH_COMMAND "")
2727
if(LINUX
2828
AND (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
2929
AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9)
30-
set(PYBIND_TAG v2.12.0)
30+
set(PYBIND_TAG v2.13.6)
3131
file(TO_NATIVE_PATH
3232
${PADDLE_SOURCE_DIR}/patches/pybind/detail/internals.h.patch native_dst)
3333
# Note: [Why calling some `git` commands before `patch`?]

paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc

+11-5
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ void CollectSymbolNames(const std::vector<symbol::DimExpr>& dim_exprs,
539539

540540
template <typename SymbolBindingsT>
541541
void AppendSymbolBindings(const std::vector<symbol::DimExpr>& dim_exprs,
542-
const std::set<std::string>& symbol_names,
542+
std::set<std::string>* remain_symbol_names_to_bind,
543543
int in_tensor_idx,
544544
SymbolBindings* symbol_bindings) {
545545
for (int in_tensor_dim_idx = 0; in_tensor_dim_idx < dim_exprs.size();
@@ -550,7 +550,10 @@ void AppendSymbolBindings(const std::vector<symbol::DimExpr>& dim_exprs,
550550
"The type of dim_expr is not atomic"));
551551
if (!dim_expr.isa<std::string>()) continue;
552552
const auto& sym_name = dim_expr.dyn_cast<std::string>();
553-
if (symbol_names.find(sym_name) == symbol_names.end()) continue;
553+
if (remain_symbol_names_to_bind->find(sym_name) ==
554+
remain_symbol_names_to_bind->end())
555+
continue;
556+
remain_symbol_names_to_bind->erase(sym_name);
554557
symbol_bindings->emplace_back(SymbolBindingsT{
555558
/*.symbol_name=*/sym_name,
556559
/*.input_tensor_idx=*/in_tensor_idx,
@@ -564,14 +567,17 @@ void GenerateSymbolBindings(
564567
const std::vector<pir::Value>& input_tensors,
565568
const std::set<std::string>& symbol_names,
566569
SymbolBindings* symbol_bindings) {
570+
std::set<std::string> remain_symbol_names_to_bind = symbol_names;
567571
for (int i = 0; i < input_tensors.size(); ++i) {
568572
const auto& input_tensor = input_tensors.at(i);
569573
const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor);
570574
AppendSymbolBindings<ShapeSymbolBinding>(
571-
dim_exprs.shape(), symbol_names, i, symbol_bindings);
575+
dim_exprs.shape(), &remain_symbol_names_to_bind, i, symbol_bindings);
572576
if (dim_exprs.data().has_value()) {
573-
AppendSymbolBindings<DataSymbolBinding>(
574-
dim_exprs.data().value(), symbol_names, i, symbol_bindings);
577+
AppendSymbolBindings<DataSymbolBinding>(dim_exprs.data().value(),
578+
&remain_symbol_names_to_bind,
579+
i,
580+
symbol_bindings);
575581
}
576582
}
577583
}

paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.cc

+9
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class MergeParallelMatmulPattern
4444
if (!op->dyn_cast<paddle::dialect::MatmulOp>()) {
4545
return false;
4646
}
47+
4748
bool trans_x =
4849
op->attribute("transpose_x").dyn_cast<pir::BoolAttribute>().data();
4950
bool trans_y =
@@ -54,6 +55,10 @@ class MergeParallelMatmulPattern
5455
return false;
5556
}
5657

58+
auto IsFirstInput = [&](pir::Operation* op, pir::Value in_x) -> bool {
59+
return in_x == op->operand_source(0);
60+
};
61+
5762
auto VectorPrefixEqual = [](const std::vector<std::int64_t>& a,
5863
const std::vector<std::int64_t>& b) {
5964
return std::vector<std::int64_t>(a.begin(), a.end() - 1) ==
@@ -74,6 +79,10 @@ class MergeParallelMatmulPattern
7479
if (!ValidMatmulTranspose(it->owner())) {
7580
continue;
7681
}
82+
83+
if (!IsFirstInput(it->owner(), input_x)) {
84+
continue;
85+
}
7786
if (!pre_dim.has_value()) {
7887
pre_dim = ::common::vectorize(
7988
it->owner()

paddle/cinn/hlir/framework/pir/trivial_op_util.cc

+75-20
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "paddle/cinn/hlir/framework/pir/trivial_op_util.h"
1616

17+
#include "paddle/cinn/common/dim_expr_converter.h"
1718
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
1819
#include "paddle/cinn/hlir/framework/compile_error.h"
1920
#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h"
@@ -547,9 +548,6 @@ ExprTransformer RemoveVarInScheduleBlockRealize(const ir::Var& target_vars,
547548
* remove it in axes.bind()
548549
*/
549550
const auto& f = [=](const ir::Expr& e) -> ir::Expr {
550-
VLOG(4) << "Start RemoveVarInScheduleBlockRealize(" << target_vars << ", "
551-
<< replaced_expr << ")";
552-
VLOG(4) << " Input is " << e;
553551
PADDLE_ENFORCE_NE(
554552
e.As<ir::ScheduleBlockRealize>(),
555553
nullptr,
@@ -562,22 +560,11 @@ ExprTransformer RemoveVarInScheduleBlockRealize(const ir::Var& target_vars,
562560
auto block_bound_vars = copied_ir.As<ir::ScheduleBlockRealize>()
563561
->schedule_block.As<ir::ScheduleBlock>()
564562
->iter_vars;
565-
for (const auto& i_var : schedule_block_iter_vars) {
566-
PADDLE_ENFORCE_EQ(
567-
i_var.is_var(),
568-
true,
569-
::common::errors::InvalidArgument("RemoveVarInScheduleBlockRealize: "
570-
"axes.bind rhs is is not a Var."));
571-
}
572563
// find replace idx
573564
int target_idx = -1;
574565
for (int i = 0; i < schedule_block_iter_vars.size(); ++i) {
575-
VLOG(4) << "RemoveVarInScheduleBlockRealize: compare with "
576-
<< schedule_block_iter_vars[i] << " vs " << target_vars
577-
<< ", and equality is: "
578-
<< (schedule_block_iter_vars[i].as_var()->name ==
579-
target_vars->name);
580-
if (schedule_block_iter_vars[i].as_var()->name == target_vars->name) {
566+
if (schedule_block_iter_vars[i].is_var() &&
567+
schedule_block_iter_vars[i].as_var()->name == target_vars->name) {
581568
target_idx = i;
582569
}
583570
}
@@ -688,8 +675,6 @@ ExprTransformer RemoveOneTransformer(int one) {
688675
.GetSingle(copied);
689676
const ir::Expr& target_block =
690677
ExprSetFinderUtils::DirectlyFather(copied).GetSingle(target_for);
691-
VLOG(4) << "RemoveOneTransformer: directly target_block of for is "
692-
<< target_block;
693678
if (target_block.As<ir::ScheduleBlockRealize>() != nullptr) {
694679
VLOG(4) << "RemoveOneTransformer: father block is root realize";
695680
ir::Expr shedule_block =
@@ -708,7 +693,6 @@ ExprTransformer RemoveOneTransformer(int one) {
708693
shedule_block.As<ir::ScheduleBlock>()->body = for_body;
709694
}
710695
} else if (target_block.As<ir::Block>() != nullptr) {
711-
VLOG(4) << "RemoveOneTransformer: father block is Block";
712696
std::vector<ir::Expr> new_bodies;
713697
for (const auto& expr : target_block.As<ir::Block>()->stmts) {
714698
if (expr != target_for) {
@@ -728,7 +712,6 @@ ExprTransformer RemoveOneTransformer(int one) {
728712
"RemoveOneTransformer: target for father should be a ir::Block or "
729713
"ir::ScheduleBlockRealize."));
730714
}
731-
VLOG(4) << "Remove Var to 0 in ScheduleBlockRealizer: " << copied;
732715
// Remove var to 0 in ScheduleBlockRealizer
733716
InplaceMutateSingleExpr(
734717
&copied,
@@ -949,6 +932,10 @@ std::vector<ir::Var> GetAllLoopVars(const ir::Expr& root) {
949932

950933
ir::Expr GetBodyBlock(const ir::Expr& root) {
951934
const auto& iters = GetNonReduceLoopVars(root);
935+
if (iters.empty()) {
936+
return ir::Block::Make(
937+
{ExprSetFinderUtils::ChildScheduleBlockRealizes.GetSingle(root)});
938+
}
952939
const size_t reduce_size =
953940
std::count_if(iters.begin(), iters.end(), [](const ir::Var& v) {
954941
return v->is_reduce_axis;
@@ -965,6 +952,74 @@ ir::Expr GetBodyBlock(const ir::Expr& root) {
965952
->body;
966953
}
967954

955+
ir::Expr ReshapeLoop(const ir::Expr& root,
956+
const std::vector<symbol::DimExpr>& in_shape,
957+
const std::vector<symbol::DimExpr>& out_shape) {
958+
auto copied = ir::ir_utils::IRCopy(root);
959+
960+
ir::ModuleExpr mod_expr({copied});
961+
ir::IRSchedule ir_sch(
962+
mod_expr, -1, false, cinn::utils::ErrorMessageLevel::kGeneral, true);
963+
964+
const auto block_realize =
965+
(ExprSetFinderUtils::ChildScheduleBlockRealizes).GetSingle(copied);
966+
const auto block_name = block_realize.As<ir::ScheduleBlockRealize>()
967+
->schedule_block.As<ir::ScheduleBlock>()
968+
->name;
969+
const auto shape_partion = fusion::PartionReshapeAxes(in_shape, out_shape);
970+
971+
for (int idx = shape_partion.size() - 1; idx > 0; --idx) {
972+
const auto& in_s = shape_partion[idx - 1].first;
973+
const auto& in_e = shape_partion[idx].first;
974+
const auto& out_s = shape_partion[idx - 1].second;
975+
const auto& out_e = shape_partion[idx].second;
976+
977+
std::vector<int> fuse_indices;
978+
for (int i = in_e - 1; i >= in_s; --i) {
979+
if (in_shape[i] != symbol::DimExpr(1)) {
980+
fuse_indices.insert(fuse_indices.begin(), i);
981+
} else {
982+
VLOG(4) << "Remove index[" << i << "]: " << in_shape[i]
983+
<< " for expr: \n"
984+
<< copied;
985+
copied = ExprTransformerUtils::RemoveOneTransformer(i)(copied);
986+
ir_sch.SetExprs({copied});
987+
for (auto& index : fuse_indices) {
988+
index--;
989+
}
990+
}
991+
}
992+
if (fuse_indices.size() > 1) {
993+
VLOG(4) << "fuse_indices: " << cinn::utils::Join(fuse_indices, ",");
994+
ir_sch.Fuse(block_name, fuse_indices);
995+
}
996+
997+
std::vector<ir::Expr> split_shapes;
998+
for (int i = out_s; i < out_e; ++i) {
999+
if (out_shape[i] != symbol::DimExpr(1)) {
1000+
split_shapes.push_back(
1001+
cinn::common::DimExprConverter().ConvertToIrExpr(out_shape[i]));
1002+
}
1003+
}
1004+
if (split_shapes.size() > 1) {
1005+
ir_sch.Split(ir_sch.GetLoops(block_name)[in_s], split_shapes)[0];
1006+
}
1007+
}
1008+
1009+
std::vector<int> insert_axis;
1010+
std::vector<ir::Var> ones_var;
1011+
for (int i = 0; i < out_shape.size(); ++i) {
1012+
if (out_shape[i] == symbol::DimExpr(1)) {
1013+
insert_axis.push_back(i);
1014+
ones_var.push_back(ir::Var(1, "one_" + std::to_string(ones_var.size())));
1015+
}
1016+
}
1017+
copied = ExprTransformerUtils::InsertForsTransformer(insert_axis,
1018+
ones_var)(copied);
1019+
1020+
return copied;
1021+
}
1022+
9681023
} // namespace trivial_fusion_detail
9691024
} // namespace pir
9701025
} // namespace framework

paddle/cinn/hlir/framework/pir/trivial_op_util.h

+4
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,10 @@ std::vector<ir::Var> GetAllLoopVars(const ir::Expr& root);
297297

298298
ir::Expr GetBodyBlock(const ir::Expr& root);
299299

300+
ir::Expr ReshapeLoop(const ir::Expr& root,
301+
const std::vector<symbol::DimExpr>& in_shape,
302+
const std::vector<symbol::DimExpr>& out_shape);
303+
300304
} // namespace trivial_fusion_detail
301305
} // namespace pir
302306
} // namespace framework

paddle/cinn/ir/group_schedule/config/group_tile_util.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ bool GetCanApplyGridReduce(const std::vector<ir::Expr>& op_compute_bodies,
6363
auto* block = expr_block.As<ir::ScheduleBlockRealize>();
6464
auto& iter_vars = block->schedule_block.As<ir::ScheduleBlock>()->iter_vars;
6565
for (int i = 0; i < iter_vars.size(); i++) {
66-
ir::Var loop_var = block->iter_values[i].as_var_ref();
67-
if (reduce_loop_vars.count(loop_var->name) > 0) {
66+
if (block->iter_values[i].is_var() &&
67+
reduce_loop_vars.count(block->iter_values[i].as_var()->name) > 0) {
6868
reduce_iter_vars.insert(iter_vars[i]->name);
6969
}
7070
}

paddle/cinn/operator_fusion/fusion_tracker/expr_utils.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ ir::Expr ApplyItersTransform::operator()(const TransposeItersTransform& trans) {
3131

3232
ir::Expr ApplyItersTransform::operator()(const RemoveOnesTransform& trans) {
3333
VLOG(4) << "[ItersTransform] Before RemoveOnesTransform("
34-
<< utils::Join(trans.ones_, ",") << "'): " << expr_;
34+
<< utils::Join(trans.ones_, ",") << "): " << expr_;
3535
auto result = RemoveOnesTransformer(trans.ones_)(expr_);
3636
VLOG(4) << "[ItersTransform] After RemoveOnesTransform: " << result;
3737
return result;

paddle/cinn/operator_fusion/fusion_tracker/interpreter.cc

+18
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,20 @@ void RunItersTransformInstr(const std::shared_ptr<ItersTransformInstr>& instr,
159159
interpreter->scope[instr->target_] = new_pattern;
160160
}
161161

162+
void RunReshapeAlignInstr(const std::shared_ptr<ReshapeAlignInstr>& instr,
163+
FusionInterpreter* interpreter) {
164+
const auto expr = std::visit(
165+
FusibleOp2Expr(), interpreter->scope[instr->input_]->fusion_ops[0])[0];
166+
VLOG(4) << "Before RunReshapeAlignInstr: \n" << expr;
167+
auto result = cinn::hlir::framework::pir::trivial_fusion_detail::ReshapeLoop(
168+
expr, instr->in_shape_, instr->out_shape_);
169+
170+
auto new_pattern = std::make_shared<ScopeElement>();
171+
new_pattern->fusion_ops.emplace_back(TrivialOp(result));
172+
interpreter->scope[instr->result_] = new_pattern;
173+
VLOG(4) << "After ReshapeAlignInstr: \n" << result;
174+
}
175+
162176
void RunPaddingInstr(const std::shared_ptr<PaddingInstr>& instr,
163177
FusionInterpreter* interpreter) {
164178
ScopeElementPtr new_pattern = std::make_shared<ScopeElement>();
@@ -229,6 +243,10 @@ std::vector<ir::Expr> FusionInterpreter::Run() {
229243
RunItersTransformInstr(
230244
dynamic_cast_instr_with_err<ItersTransformInstr>(instr), this);
231245
break;
246+
case T_ReshapeAlign:
247+
RunReshapeAlignInstr(
248+
dynamic_cast_instr_with_err<ReshapeAlignInstr>(instr), this);
249+
break;
232250
default:
233251
PADDLE_THROW(
234252
::common::errors::Unavailable("Unsupported Fusion Instrution"));

0 commit comments

Comments
 (0)