Skip to content

Commit 76f24dc

Browse files
authored
[CINN] Complete the op_role and chunk_id attributes of the operator in decompose process (#69612)
* refine code * add ut * refine
1 parent 023d881 commit 76f24dc

File tree

4 files changed

+59
-4
lines changed

4 files changed

+59
-4
lines changed

paddle/fluid/primitive/base/decomp_trans.cc

+15-4
Original file line numberDiff line numberDiff line change
@@ -491,8 +491,18 @@ void DecompProgram::decomp_block(
491491
if (enable_prim) {
492492
VLOG(4) << "[Prim] decomp op name " << op->name();
493493
check_decomp_dynamic_shape(op);
494-
auto& builder = *(paddle::dialect::ApiBuilder::Instance().GetBuilder());
495-
builder.set_insertion_point(op);
494+
std::shared_ptr<pir::Builder> builder =
495+
paddle::dialect::ApiBuilder::Instance().GetBuilder();
496+
builder->set_insertion_point(op);
497+
498+
int op_role = (op->attribute<pir::Int32Attribute>("op_role"))
499+
? op->attribute<pir::Int32Attribute>("op_role").data()
500+
: -1;
501+
int chunk_id = (op->attribute<pir::Int32Attribute>("chunk_id"))
502+
? op->attribute<pir::Int32Attribute>("chunk_id").data()
503+
: -1;
504+
pir::BuilderAttrGuard guard(builder, op_role, chunk_id);
505+
496506
std::vector<std::vector<pir::Value>> decomp_res = call_decomp_rule(op);
497507
if (decomp_res.size() == 0) {
498508
// if we don't decomp this op, then leave it intact.
@@ -563,8 +573,9 @@ void DecompProgram::decomp_block(
563573
tar_vars[i] = src_vars_[i];
564574
}
565575
}
566-
auto& builder = *(paddle::dialect::ApiBuilder::Instance().GetBuilder());
567-
builder.SetInsertionPointToBlockEnd(block);
576+
std::shared_ptr<pir::Builder> builder =
577+
paddle::dialect::ApiBuilder::Instance().GetBuilder();
578+
builder->SetInsertionPointToBlockEnd(block);
568579
}
569580

570581
} // namespace paddle

paddle/pir/include/core/builder.h

+16
Original file line numberDiff line numberDiff line change
@@ -193,4 +193,20 @@ OpTy Builder::Build(Args &&...args) {
193193
return OpTy(op);
194194
}
195195

196+
class BuilderAttrGuard {
197+
public:
198+
BuilderAttrGuard(std::shared_ptr<Builder> builder, int op_role, int chunk_id);
199+
200+
~BuilderAttrGuard();
201+
202+
// forbid copy and operator=
203+
BuilderAttrGuard(const BuilderAttrGuard &guard) = delete;
204+
BuilderAttrGuard &operator=(const BuilderAttrGuard &guard) = delete;
205+
206+
private:
207+
std::shared_ptr<Builder> builder_;
208+
int pre_op_role_ = -1;
209+
int pre_chunk_id_ = -1;
210+
};
211+
196212
} // namespace pir

paddle/pir/src/core/builder.cc

+19
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,23 @@ TensorNameAttribute Builder::tensor_name_attr(const std::string &value) {
106106
return TensorNameAttribute::get(context_, value);
107107
}
108108

109+
BuilderAttrGuard::BuilderAttrGuard(std::shared_ptr<Builder> builder,
110+
int op_role,
111+
int chunk_id)
112+
: builder_(builder),
113+
pre_op_role_(builder_->op_role()),
114+
pre_chunk_id_(builder_->chunk_id()) {
115+
if (pre_op_role_ != op_role) {
116+
builder_->set_op_role(op_role);
117+
}
118+
if (pre_chunk_id_ != chunk_id) {
119+
builder_->set_chunk_id(chunk_id);
120+
}
121+
}
122+
123+
BuilderAttrGuard::~BuilderAttrGuard() { // NOLINT
124+
builder_->set_op_role(pre_op_role_);
125+
builder_->set_chunk_id(pre_chunk_id_);
126+
}
127+
109128
} // namespace pir

test/prim/pir_prim/test_decomp_op.py

+9
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ def get_ir_program():
3838
y_s = paddle.mean(y_s)
3939
y_s = paddle.tanh(y_s)
4040
pir_program = pir.translate_to_pir(main_program.desc)
41+
42+
all_ops = pir_program.global_block().ops
43+
for op in all_ops:
44+
op.op_role = 1
45+
4146
return pir_program
4247

4348

@@ -68,6 +73,10 @@ def test_build_op(self):
6873
'pd_op.tanh',
6974
],
7075
)
76+
op_role_list = [op.op_role for op in pir_program.global_block().ops]
77+
self.assertEqual(
78+
all(op_role == 1 for op_role in op_role_list), True
79+
)
7180

7281

7382
if __name__ == "__main__":

0 commit comments

Comments
 (0)