Skip to content

Commit 1295b0d

Browse files
committed
[compiler] fix error cast in shape reification pass
1 parent 624a7d1 commit 1295b0d

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

compiler/include/byteir/Transforms/Passes.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,8 @@ def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
440440
let constructor = "mlir::createByteIRShapeReificationPass()";
441441
let dependentDialects = [
442442
"mlir::shape::ShapeDialect",
443-
"mlir::tensor::TensorDialect"
443+
"mlir::tensor::TensorDialect",
444+
"mlir::arith::ArithDialect",
444445
];
445446
}
446447

compiler/lib/Transforms/PassDetail.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ namespace memref {
4343
class MemRefDialect;
4444
} // namespace memref
4545

46+
namespace arith {
47+
class ArithDialect;
48+
} // namespace arith
49+
4650
namespace mhlo {
4751
class MhloDialect;
4852
} // namespace mhlo

compiler/lib/Transforms/ShapeReification.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
2121
#include "byteir/Dialect/mhlo/Util/ShapeInferUtil.h"
2222
#include "mhlo/IR/hlo_ops.h"
23+
#include "mlir/Dialect/Arith/IR/Arith.h"
2324
#include "mlir/Dialect/Func/IR/FuncOps.h"
2425
#include "mlir/Dialect/Shape/IR/Shape.h"
2526
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -59,8 +60,8 @@ struct ShapeReificationOnTensorDimPattern
5960

6061
// Insert cast, if needed.
6162
if (dimOfShape.getType() != op.getType()) {
62-
dimOfShape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
63-
dimOfShape);
63+
dimOfShape = rewriter.create<arith::IndexCastOp>(
64+
op.getLoc(), op.getType(), dimOfShape);
6465
}
6566

6667
rewriter.replaceOp(op, dimOfShape);

0 commit comments

Comments
 (0)