Skip to content

Commit b823dd2

Browse files
committed
[CINN] Register InferSymbolicShape for cinn_op.argmin/argmax
1 parent e9d8a8e commit b823dd2

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,16 @@ bool ReduceSumOpInferSymbolicShape(
260260
return ReduceInferSymbolicShape(op, infer_context);
261261
}
262262

263+
bool ArgminOpInferSymbolicShape(pir::Operation *op,
264+
pir::InferSymbolicShapeContext *infer_context) {
265+
return ReduceInferSymbolicShape(op, infer_context);
266+
}
267+
268+
bool ArgmaxOpInferSymbolicShape(pir::Operation *op,
269+
pir::InferSymbolicShapeContext *infer_context) {
270+
return ReduceInferSymbolicShape(op, infer_context);
271+
}
272+
263273
bool ReshapeOpInferSymbolicShape(
264274
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
265275
std::vector<int> shape =

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReduceMax)
2323
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReduceMin)
2424
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReduceProd)
2525
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReduceSum)
26+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Argmin)
27+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Argmax)
2628
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape)
2729
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Slice)
2830
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Split)

0 commit comments

Comments
 (0)