Skip to content

Commit 6975119

Browse files
authored
[mlir] Implement inferResultRanges for vector.transpose (#151537)
Implements the `inferResultRanges` method from the `InferIntRangeInterface` interface for `vector.transpose`. The result ranges simply match the source ranges. Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
1 parent 4adce33 commit 6975119

File tree

3 files changed

+15
-0
lines changed

3 files changed

+15
-0
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2595,6 +2595,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
25952595

25962596
def Vector_TransposeOp :
25972597
Vector_Op<"transpose", [Pure,
2598+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
25982599
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
25992600
PredOpTrait<"operand and result have same element type",
26002601
TCresVTEtIsSameAsOpBase<0, 0>>]> {

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6316,6 +6316,11 @@ std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
63166316
return llvm::to_vector<4>(getResultVectorType().getShape());
63176317
}
63186318

6319+
void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6320+
SetIntRangeFn setResultRanges) {
6321+
setResultRanges(getResult(), argRanges.front());
6322+
}
6323+
63196324
namespace {
63206325

63216326
// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.

mlir/test/Dialect/Vector/int-range-interface.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,15 @@ func.func @vector_shape_cast() -> vector<4x4xindex> {
5151
func.return %2 : vector<4x4xindex>
5252
}
5353

54+
// CHECK-LABEL: func @vector_transpose
55+
// CHECK: test.reflect_bounds {smax = 8 : index, smin = 7 : index, umax = 8 : index, umin = 7 : index}
56+
func.func @vector_transpose() -> vector<2x4xindex> {
57+
%0 = test.with_bounds { smax = 8 : index, smin = 7 : index, umax = 8 : index, umin = 7 : index } : vector<4x2xindex>
58+
%1 = vector.transpose %0, [1, 0] : vector<4x2xindex> to vector<2x4xindex>
59+
%2 = test.reflect_bounds %1 : vector<2x4xindex>
60+
func.return %2 : vector<2x4xindex>
61+
}
62+
5463
// CHECK-LABEL: func @vector_extract
5564
// CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index}
5665
func.func @vector_extract() -> index {

0 commit comments

Comments
 (0)