Skip to content

Commit 2c7a583

Browse files
authored
Add manual op ap_op.add. (#72701)
1 parent 6428d9c commit 2c7a583

File tree

4 files changed

+63
-0
lines changed

4 files changed

+63
-0
lines changed

paddle/ap/include/paddle/hlir/manual_op.h

+18
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,23 @@ class IR_API FacadeOp
4242
bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
4343
};
4444

45+
class IR_API AddOp
46+
: public pir::Op<AddOp,
47+
pir::ImmutableLayoutTrait,
48+
::paddle::dialect::InferSymbolicShapeInterface> {
49+
public:
50+
using Op::Op;
51+
static const char *name() { return "ap_op.add"; }
52+
static constexpr uint32_t attributes_num = 0;
53+
static constexpr const char **attributes_name = nullptr;
54+
static void Build(pir::Builder &builder, // NOLINT
55+
pir::OperationArgument &argument, // NOLINT
56+
pir::Value lhs,
57+
pir::Value rhs);
58+
void VerifySig() const {}
59+
bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
60+
};
61+
4562
class IR_API UpSpiderOp
4663
: public pir::Op<UpSpiderOp,
4764
pir::SideEffectTrait,
@@ -152,6 +169,7 @@ class IR_API StoreToGlobalOp
152169
} // namespace ap::dialect
153170

154171
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::FacadeOp);
172+
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::AddOp);
155173
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::UpSpiderOp);
156174
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::DownSpiderOp);
157175
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::LoadFromRegisterOp);

paddle/ap/src/paddle/hlir/manual_op.cc

+29
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,34 @@ bool FacadeOp::InferSymbolicShape(
4141
return ApOpFacadeOpInferSymbolicShape(*this, infer_context);
4242
}
4343

44+
void AddOp::Build(pir::Builder& builder, // NOLINT
45+
pir::OperationArgument& argument, // NOLINT
46+
pir::Value lhs,
47+
pir::Value rhs) {
48+
argument.AddInput(lhs);
49+
argument.AddInput(rhs);
50+
}
51+
52+
bool AddOp::InferSymbolicShape(pir::InferSymbolicShapeContext* infer_context) {
53+
const auto& lhs_shape_or_data =
54+
infer_context->GetShapeOrDataForValue(operand_source(0));
55+
const auto& rhs_shape_or_data =
56+
infer_context->GetShapeOrDataForValue(operand_source(1));
57+
PADDLE_ENFORCE_GT(lhs_shape_or_data.shape().size(),
58+
rhs_shape_or_data.shape().size(),
59+
phi::errors::InvalidArgument(
60+
"lhs and rhs of ap_op.add should have same rank"));
61+
for (int i = 0; i < lhs_shape_or_data.shape().size(); ++i) {
62+
const auto& lhs_dim_expr = lhs_shape_or_data.shape().at(i);
63+
const auto& rhs_dim_expr = rhs_shape_or_data.shape().at(i);
64+
if (lhs_dim_expr != rhs_dim_expr) {
65+
infer_context->AddEqualCstr(lhs_dim_expr, rhs_dim_expr);
66+
}
67+
}
68+
infer_context->SetShapeOrDataForValue(result(0), rhs_shape_or_data);
69+
return true;
70+
}
71+
4472
void UpSpiderOp::Build(pir::Builder& builder, // NOLINT
4573
pir::OperationArgument& argument, // NOLINT
4674
pir::Value lhs,
@@ -165,6 +193,7 @@ bool StoreToGlobalOp::InferSymbolicShape(
165193
} // namespace ap::dialect
166194

167195
IR_DEFINE_EXPLICIT_TYPE_ID(ap::dialect::FacadeOp);
196+
IR_DEFINE_EXPLICIT_TYPE_ID(ap::dialect::AddOp);
168197
IR_DEFINE_EXPLICIT_TYPE_ID(ap::dialect::UpSpiderOp);
169198
IR_DEFINE_EXPLICIT_TYPE_ID(ap::dialect::DownSpiderOp);
170199
IR_DEFINE_EXPLICIT_TYPE_ID(ap::dialect::LoadFromRegisterOp);

paddle/ap/src/paddle/hlir/op_dialect.cc

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ OperatorDialect::OperatorDialect(::pir::IrContext *context)
2626

2727
void OperatorDialect::initialize() {
2828
RegisterOp<FacadeOp>();
29+
RegisterOp<AddOp>();
2930
RegisterOp<UpSpiderOp>();
3031
RegisterOp<DownSpiderOp>();
3132
RegisterOp<LoadFromRegisterOp>();

paddle/ap/src/paddle/pass/op_factory.cc

+15
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,17 @@ adt::Result<pir::Operation*> ConstructPdOpSum(
3636
return op;
3737
}
3838

39+
adt::Result<pir::Operation*> ConstructAddOp(
40+
pir::Builder* builder,
41+
const std::vector<pir::Value>& inputs,
42+
const pir::AttributeMap& attrs) {
43+
ADT_CHECK(inputs.size() == 2) << adt::errors::TypeError{
44+
std::string() + "'ap_op.add' op takes 2 arguments, but " +
45+
std::to_string(inputs.size()) + " were given"};
46+
auto op = builder->Build<ap::dialect::AddOp>(inputs.at(0), inputs.at(1));
47+
return op;
48+
}
49+
3950
adt::Result<pir::Operation*> ConstructUpSpiderOp(
4051
pir::Builder* builder,
4152
const std::vector<pir::Value>& inputs,
@@ -186,6 +197,10 @@ adt::Result<std::optional<pir::Operation*>> CreateOperation(
186197
ADT_LET_CONST_REF(ret, ConstructShadowOutputOp(builder, inputs, attrs));
187198
return ret;
188199
}
200+
if (op_name == "ap_op.add") {
201+
ADT_LET_CONST_REF(ret, ConstructAddOp(builder, inputs, attrs));
202+
return ret;
203+
}
189204
if (op_name == "ap_op.up_spider") {
190205
ADT_LET_CONST_REF(ret, ConstructUpSpiderOp(builder, inputs, attrs));
191206
return ret;

0 commit comments

Comments
 (0)