diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index b7e144c4b65f..d62e21e11ede 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -2052,9 +2052,24 @@ mlir::LogicalResult CIRToLLVMVecCmpOpLowering::matchAndRewrite( auto elementType = elementTypeIfVector(op.getLhs().getType()); mlir::Value bitResult; if (auto intType = mlir::dyn_cast(elementType)) { + + auto isCIRZeroVector = [](mlir::Value value) { + if (auto constantOp = value.getDefiningOp()) + if (auto zeroAttr = + mlir::dyn_cast(constantOp.getValue())) + return true; + return false; + }; + + bool shouldUseSigned = intType.isSigned(); + // Special treatment for sign-bit extraction patterns (lt comparison with + // zero), always use signed comparison to preserve the semantic intent + if (op.getKind() == cir::CmpOpKind::lt && isCIRZeroVector(op.getRhs())) + shouldUseSigned = true; + bitResult = rewriter.create( op.getLoc(), - convertCmpKindToICmpPredicate(op.getKind(), intType.isSigned()), + convertCmpKindToICmpPredicate(op.getKind(), shouldUseSigned), adaptor.getLhs(), adaptor.getRhs()); } else if (mlir::isa(elementType)) { bitResult = rewriter.create( diff --git a/clang/test/CIR/Lowering/vec-cmp.cir b/clang/test/CIR/Lowering/vec-cmp.cir index 3bc6570444a7..1ef2657f3266 100644 --- a/clang/test/CIR/Lowering/vec-cmp.cir +++ b/clang/test/CIR/Lowering/vec-cmp.cir @@ -3,6 +3,7 @@ !s16i = !cir.int !u16i = !cir.int +!u8i = !cir.int cir.func @vec_cmp(%0: !cir.vector, %1: !cir.vector) -> () { %2 = cir.vec.cmp(lt, %0, %1) : !cir.vector, !cir.vector x 16> @@ -14,3 +15,15 @@ cir.func @vec_cmp(%0: !cir.vector, %1: !cir.vector) -> ( // MLIR-NEXT: %{{[0-9]+}} = llvm.icmp "slt" %arg0, %arg1 : vector<16xi16> // MLIR-NEXT: %{{[0-9]+}} = llvm.bitcast %{{[0-9]+}} : vector<16xi1> to i16 // MLIR-NEXT: llvm.return + +cir.func @vec_cmp_zero(%0: !cir.vector) -> () { + %1 = cir.const #cir.zero : !cir.vector + %2 = cir.vec.cmp(lt, %0, %1) : !cir.vector, !cir.vector x 16> + %3 = cir.cast(bitcast, %2 : !cir.vector x 16>), !cir.int + + cir.return +} + +// MLIR: llvm.func @vec_cmp_zero +// MLIR: %{{[0-9]+}} = llvm.icmp "slt" %arg0, %{{[0-9]+}} : vector<16xi8> +// MLIR-NEXT: %{{[0-9]+}} = llvm.bitcast %{{[0-9]+}} : vector<16xi1> to i16