[ARM] Have custom lowering for ucmp and scmp #149315

Open · wants to merge 4 commits into base: main
189 changes: 189 additions & 0 deletions llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -802,6 +802,12 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM_,
setOperationAction(ISD::BSWAP, VT, Expand);
}

if (!Subtarget->isThumb1Only() && !Subtarget->hasMVEIntegerOps())
setOperationAction(ISD::SCMP, MVT::i32, Custom);

if (!Subtarget->hasMVEIntegerOps())
setOperationAction(ISD::UCMP, MVT::i32, Custom);

setOperationAction(ISD::ConstantFP, MVT::f32, Custom);
setOperationAction(ISD::ConstantFP, MVT::f64, Custom);

@@ -1628,6 +1634,10 @@ bool ARMTargetLowering::useSoftFloat() const {
return Subtarget->useSoftFloat();
}

bool ARMTargetLowering::shouldExpandCmpUsingSelects(EVT VT) const {
return (!Subtarget->isThumb1Only() && VT.getSizeInBits() <= 32);
}

// FIXME: It might make sense to define the representative register class as the
// nearest super-register that has a non-null superset. For example, DPR_VFP2 is
// a super-register of SPR, and DPR is a superset if DPR_VFP2. Consequently,
@@ -10614,6 +10624,181 @@ SDValue ARMTargetLowering::LowerFP_TO_BF16(SDValue Op,
return DAG.getBitcast(MVT::i32, Res);
}

SDValue ARMTargetLowering::LowerSCMP(SDValue Op, SelectionDAG &DAG) const {
SDLoc dl(Op);
SDValue LHS = Op.getOperand(0);
SDValue RHS = Op.getOperand(1);

// For the ARM assembly pattern:
// subs r0, r0, r1 ; subtract RHS from LHS and set flags
// movgt r0, #1 ; if LHS > RHS, set result to 1
// mvnlt r0, #0 ; if LHS < RHS, set result to -1 (mvn #0 = -1)
// ; if LHS == RHS, result remains 0 from the subs

// Optimization: if RHS is a subtraction against 0, use ADDC instead of SUBC
// Check if RHS is (0 - something), and if so use ADDC with LHS + something
SDValue SubResult, Flags;
bool CanUseAdd = false;
SDValue AddOperand;

// Check if RHS is a subtraction against 0: (0 - X)
if (RHS.getOpcode() == ISD::SUB) {
SDValue SubLHS = RHS.getOperand(0);
SDValue SubRHS = RHS.getOperand(1);

// Check if it's 0 - X
if (isNullConstant(SubLHS)) {
// For SCMP: only if X is known to never be INT_MIN (to avoid overflow)
if (RHS->getFlags().hasNoSignedWrap() || !DAG.computeKnownBits(SubRHS)
.getSignedMinValue()
.isMinSignedValue()) {
CanUseAdd = true;
AddOperand = SubRHS; // Replace RHS with X, so we do LHS + X instead of
// LHS - (0 - X)
}
}
}

if (CanUseAdd) {
// Use ADDC: LHS + AddOperand (where RHS was 0 - AddOperand)
SDValue AddWithFlags = DAG.getNode(
ARMISD::ADDC, dl, DAG.getVTList(MVT::i32, FlagsVT), LHS, AddOperand);
SubResult = AddWithFlags.getValue(0); // The addition result
Flags = AddWithFlags.getValue(1); // The flags from ADDS
} else {
// Use ARMISD::SUBC to generate SUBS instruction (subtract with flags)
SDValue SubWithFlags = DAG.getNode(
ARMISD::SUBC, dl, DAG.getVTList(MVT::i32, FlagsVT), LHS, RHS);
SubResult = SubWithFlags.getValue(0); // The subtraction result
Flags = SubWithFlags.getValue(1); // The flags from SUBS
}

// Constants for conditional moves
SDValue One = DAG.getConstant(1, dl, MVT::i32);
SDValue MinusOne = DAG.getAllOnesConstant(dl, MVT::i32);

// movgt: if greater than, set to 1
SDValue GTCond = DAG.getConstant(ARMCC::GT, dl, MVT::i32);
SDValue Result1 =
DAG.getNode(ARMISD::CMOV, dl, MVT::i32, SubResult, One, GTCond, Flags);

// mvnlt: if less than, set to -1 (equivalent to mvn #0)
SDValue LTCond = DAG.getConstant(ARMCC::LT, dl, MVT::i32);
SDValue Result2 =
DAG.getNode(ARMISD::CMOV, dl, MVT::i32, Result1, MinusOne, LTCond, Flags);

if (Op.getValueType() != MVT::i32)
Result2 = DAG.getSExtOrTrunc(Result2, dl, Op.getValueType());

return Result2;
}

SDValue ARMTargetLowering::LowerUCMP(SDValue Op, SelectionDAG &DAG) const {
SDLoc dl(Op);
SDValue LHS = Op.getOperand(0);
SDValue RHS = Op.getOperand(1);

if (Subtarget->isThumb1Only()) {
// For Thumb unsigned comparison, use this sequence:
// subs r2, r0, r1 ; r2 = LHS - RHS, sets flags
// sbc r2, r2 ; r2 = r2 - r2 - !carry
// cmp r1, r0 ; compare RHS with LHS
// sbc r1, r1 ; r1 = r1 - r1 - !carry
// subs r0, r2, r1 ; r0 = r2 - r1 (final result)

// First subtraction: LHS - RHS
SDValue Sub1WithFlags = DAG.getNode(
ARMISD::SUBC, dl, DAG.getVTList(MVT::i32, FlagsVT), LHS, RHS);
SDValue Sub1Result = Sub1WithFlags.getValue(0);
SDValue Flags1 = Sub1WithFlags.getValue(1);

// SUBE: Sub1Result - Sub1Result - !carry
// This gives 0 if LHS >= RHS (unsigned), -1 if LHS < RHS (unsigned)
SDValue Sbc1 =
DAG.getNode(ARMISD::SUBE, dl, DAG.getVTList(MVT::i32, FlagsVT),
Sub1Result, Sub1Result, Flags1);
SDValue Sbc1Result = Sbc1.getValue(0);

// Second comparison: RHS vs LHS (reverse comparison)
SDValue CmpFlags = DAG.getNode(ARMISD::CMP, dl, FlagsVT, RHS, LHS);

// SUBE: RHS - RHS - !carry
// This gives 0 if RHS <= LHS (unsigned), -1 if RHS > LHS (unsigned)
SDValue Sbc2 = DAG.getNode(
ARMISD::SUBE, dl, DAG.getVTList(MVT::i32, FlagsVT), RHS, RHS, CmpFlags);
SDValue Sbc2Result = Sbc2.getValue(0);

// Final subtraction: Sbc1Result - Sbc2Result (no flags needed)
SDValue Result =
DAG.getNode(ISD::SUB, dl, MVT::i32, Sbc1Result, Sbc2Result);
if (Op.getValueType() != MVT::i32)
Result = DAG.getSExtOrTrunc(Result, dl, Op.getValueType());

return Result;
}

// For the ARM assembly pattern (unsigned version):
// subs r0, r0, r1 ; subtract RHS from LHS and set flags
// movhi r0, #1 ; if LHS > RHS (unsigned), set result to 1
// mvnlo r0, #0 ; if LHS < RHS (unsigned), set result to -1
// ; if LHS == RHS, result remains 0 from the subs

// Optimization: if RHS is a subtraction against 0, use ADDC instead of SUBC
// Check if RHS is (0 - something), and if so use ADDC with LHS + something
SDValue SubResult, Flags;
bool CanUseAdd = false;
SDValue AddOperand;

// Check if RHS is a subtraction against 0: (0 - X)
if (RHS.getOpcode() == ISD::SUB) {
SDValue SubLHS = RHS.getOperand(0);
SDValue SubRHS = RHS.getOperand(1);

// Check if it's 0 - X
if (isNullConstant(SubLHS)) {
// For UCMP: only if X is known to never be zero
if (DAG.isKnownNeverZero(SubRHS)) {
CanUseAdd = true;
AddOperand = SubRHS; // Replace RHS with X, so we do LHS + X instead of
// LHS - (0 - X)
}
}
}

if (CanUseAdd) {
// Use ADDC: LHS + AddOperand (where RHS was 0 - AddOperand)
SDValue AddWithFlags = DAG.getNode(
ARMISD::ADDC, dl, DAG.getVTList(MVT::i32, FlagsVT), LHS, AddOperand);
SubResult = AddWithFlags.getValue(0); // The addition result
Flags = AddWithFlags.getValue(1); // The flags from ADDS
} else {
// Use ARMISD::SUBC to generate SUBS instruction (subtract with flags)
SDValue SubWithFlags = DAG.getNode(
ARMISD::SUBC, dl, DAG.getVTList(MVT::i32, FlagsVT), LHS, RHS);
SubResult = SubWithFlags.getValue(0); // The subtraction result
Flags = SubWithFlags.getValue(1); // The flags from SUBS
}

// Constants for conditional moves
SDValue One = DAG.getConstant(1, dl, MVT::i32);
SDValue MinusOne = DAG.getAllOnesConstant(dl, MVT::i32);

// movhi: if higher (unsigned greater than), set to 1
SDValue HICond = DAG.getConstant(ARMCC::HI, dl, MVT::i32);
SDValue Result1 =
DAG.getNode(ARMISD::CMOV, dl, MVT::i32, SubResult, One, HICond, Flags);

// mvnlo: if lower (unsigned less than), set to -1
SDValue LOCond = DAG.getConstant(ARMCC::LO, dl, MVT::i32);
SDValue Result2 =
DAG.getNode(ARMISD::CMOV, dl, MVT::i32, Result1, MinusOne, LOCond, Flags);

if (Op.getValueType() != MVT::i32)
Result2 = DAG.getSExtOrTrunc(Result2, dl, Op.getValueType());

return Result2;
}

SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
LLVM_DEBUG(dbgs() << "Lowering node: "; Op.dump());
switch (Op.getOpcode()) {
@@ -10742,6 +10927,10 @@ SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::FP_TO_BF16:
return LowerFP_TO_BF16(Op, DAG);
case ARMISD::WIN__DBZCHK: return SDValue();
case ISD::SCMP:
return LowerSCMP(Op, DAG);
case ISD::UCMP:
return LowerUCMP(Op, DAG);
}
}

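For reference, the sequences built by LowerSCMP and LowerUCMP reduce to the model below (a sketch, not code from the patch). The flag-setting subtract doubles as the equal case, since its result is zero exactly when the operands compare equal, and the two predicated instructions overwrite it for the greater/less cases; the Thumb1 ucmp path assembles the same value from borrows because predicated moves are unavailable there.

```cpp
// Reference model of the lowered i32 sequences; narrower result
// types are sign-extended or truncated from these values.
#include <cstdint>

// ARM/Thumb2 scmp: subs r0, r0, r1; movgt r0, #1; mvnlt r0, #0
int32_t scmp_model(int32_t lhs, int32_t rhs) {
  if (lhs > rhs) return 1;   // ARMCC::GT
  if (lhs < rhs) return -1;  // ARMCC::LT
  return 0;                  // neither predicate fires; subs left 0
}

// ARM/Thumb2 ucmp: subs r0, r0, r1; movhi r0, #1; mvnlo r0, #0
int32_t ucmp_model(uint32_t lhs, uint32_t rhs) {
  if (lhs > rhs) return 1;   // ARMCC::HI
  if (lhs < rhs) return -1;  // ARMCC::LO
  return 0;
}

// Thumb1 ucmp: subs/sbc extracts -(lhs < rhs), cmp/sbc extracts
// -(rhs < lhs), and a final subs combines them.
int32_t ucmp_thumb1_model(uint32_t lhs, uint32_t rhs) {
  int32_t ltMask = lhs < rhs ? -1 : 0; // subs r2, r0, r1; sbc r2, r2
  int32_t gtMask = rhs < lhs ? -1 : 0; // cmp r1, r0;      sbc r1, r1
  return ltMask - gtMask;              // subs r0, r2, r1
}
```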
4 changes: 4 additions & 0 deletions llvm/lib/Target/ARM/ARMISelLowering.h
@@ -607,6 +607,8 @@ class VectorType;

bool preferZeroCompareBranch() const override { return true; }

bool shouldExpandCmpUsingSelects(EVT VT) const override;

bool isMaskAndCmp0FoldingBeneficial(const Instruction &AndI) const override;

bool hasAndNotCompare(SDValue V) const override {
@@ -903,6 +905,8 @@ class VectorType;
void LowerLOAD(SDNode *N, SmallVectorImpl<SDValue> &Results,
SelectionDAG &DAG) const;
SDValue LowerFP_TO_BF16(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSCMP(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerUCMP(SDValue Op, SelectionDAG &DAG) const;

Register getRegisterByName(const char* RegName, LLT VT,
const MachineFunction &MF) const override;
48 changes: 24 additions & 24 deletions llvm/test/CodeGen/ARM/scmp.ll
@@ -4,12 +4,9 @@
define i8 @scmp_8_8(i8 signext %x, i8 signext %y) nounwind {
; CHECK-LABEL: scmp_8_8:
; CHECK: @ %bb.0:
; CHECK-NEXT: cmp r0, r1
; CHECK-NEXT: mov r0, #0
; CHECK-NEXT: mov r2, #0
; CHECK-NEXT: movwlt r0, #1
; CHECK-NEXT: movwgt r2, #1
; CHECK-NEXT: sub r0, r2, r0
; CHECK-NEXT: subs r0, r0, r1
; CHECK-NEXT: movwgt r0, #1
; CHECK-NEXT: mvnlt r0, #0
; CHECK-NEXT: bx lr
%1 = call i8 @llvm.scmp(i8 %x, i8 %y)
ret i8 %1
@@ -18,12 +15,9 @@ define i8 @scmp_8_8(i8 signext %x, i8 signext %y) nounwind {
define i8 @scmp_8_16(i16 signext %x, i16 signext %y) nounwind {
; CHECK-LABEL: scmp_8_16:
; CHECK: @ %bb.0:
; CHECK-NEXT: cmp r0, r1
; CHECK-NEXT: mov r0, #0
; CHECK-NEXT: mov r2, #0
; CHECK-NEXT: movwlt r0, #1
; CHECK-NEXT: movwgt r2, #1
; CHECK-NEXT: sub r0, r2, r0
; CHECK-NEXT: subs r0, r0, r1
; CHECK-NEXT: movwgt r0, #1
; CHECK-NEXT: mvnlt r0, #0
; CHECK-NEXT: bx lr
%1 = call i8 @llvm.scmp(i16 %x, i16 %y)
ret i8 %1
@@ -32,12 +26,9 @@ define i8 @scmp_8_16(i16 signext %x, i16 signext %y) nounwind {
define i8 @scmp_8_32(i32 %x, i32 %y) nounwind {
; CHECK-LABEL: scmp_8_32:
; CHECK: @ %bb.0:
; CHECK-NEXT: cmp r0, r1
; CHECK-NEXT: mov r0, #0
; CHECK-NEXT: mov r2, #0
; CHECK-NEXT: movwlt r0, #1
; CHECK-NEXT: movwgt r2, #1
; CHECK-NEXT: sub r0, r2, r0
; CHECK-NEXT: subs r0, r0, r1
; CHECK-NEXT: movwgt r0, #1
; CHECK-NEXT: mvnlt r0, #0
; CHECK-NEXT: bx lr
%1 = call i8 @llvm.scmp(i32 %x, i32 %y)
ret i8 %1
@@ -92,17 +83,26 @@ define i8 @scmp_8_128(i128 %x, i128 %y) nounwind {
define i32 @scmp_32_32(i32 %x, i32 %y) nounwind {
; CHECK-LABEL: scmp_32_32:
; CHECK: @ %bb.0:
; CHECK-NEXT: cmp r0, r1
; CHECK-NEXT: mov r0, #0
; CHECK-NEXT: mov r2, #0
; CHECK-NEXT: movwlt r0, #1
; CHECK-NEXT: movwgt r2, #1
; CHECK-NEXT: sub r0, r2, r0
; CHECK-NEXT: subs r0, r0, r1
; CHECK-NEXT: movwgt r0, #1
; CHECK-NEXT: mvnlt r0, #0
; CHECK-NEXT: bx lr
%1 = call i32 @llvm.scmp(i32 %x, i32 %y)
ret i32 %1
}

define i32 @scmp_neg(i32 %x, i32 %y) nounwind {
; CHECK-LABEL: scmp_neg:
; CHECK: @ %bb.0:
; CHECK-NEXT: adds r0, r0, r1
; CHECK-NEXT: movwgt r0, #1
; CHECK-NEXT: mvnlt r0, #0
; CHECK-NEXT: bx lr
%yy = sub nsw i32 0, %y
%1 = call i32 @llvm.scmp(i32 %x, i32 %yy)
ret i32 %1
}

define i32 @scmp_32_64(i64 %x, i64 %y) nounwind {
; CHECK-LABEL: scmp_32_64:
; CHECK: @ %bb.0:
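The new scmp_neg test exercises the negated-RHS path in LowerSCMP: the IR negation carries nsw, which tells the backend it cannot wrap (equivalently, %y is never INT32_MIN), so the flags for x - (0 - y) can be produced by a single adds x, y. The exhaustive i8 check below (a sketch, not part of the patch; it models the ARM N/Z/V flags directly) illustrates why that guard is needed: GT/LT computed from the add agree with the subtract for every pair except when the negated operand is the minimum signed value.

```cpp
// Exhaustive i8 model of the signed adds fold guarded in LowerSCMP.
#include <cstdint>
#include <cstdio>

struct Flags { bool n, z, v; };  // C does not affect GT/LT

static int8_t wrap8(int wide) {  // two's-complement truncation to i8
  wide &= 0xFF;
  return int8_t(wide >= 128 ? wide - 256 : wide);
}

static Flags flagsFor(int mathResult) {
  int8_t r = wrap8(mathResult);
  return {r < 0, r == 0, mathResult < -128 || mathResult > 127};
}

static Flags addFlags(int8_t a, int8_t b) { return flagsFor(int(a) + int(b)); }
static Flags subFlags(int8_t a, int8_t b) { return flagsFor(int(a) - int(b)); }

static bool isGT(Flags f) { return !f.z && f.n == f.v; }  // ARMCC::GT
static bool isLT(Flags f) { return f.n != f.v; }          // ARMCC::LT

int main() {
  int mismatches = 0;
  for (int x = -128; x <= 127; ++x) {
    for (int y = -128; y <= 127; ++y) {
      int negY = (y == -128) ? -128 : -y;                // wrapping 0 - y on i8
      Flags viaSub = subFlags(int8_t(x), int8_t(negY));  // subs x, (0 - y)
      Flags viaAdd = addFlags(int8_t(x), int8_t(y));     // adds x, y
      if (isGT(viaSub) != isGT(viaAdd) || isLT(viaSub) != isLT(viaAdd))
        ++mismatches;  // only reached when y == -128 (the INT_MIN analogue)
    }
  }
  printf("mismatching pairs: %d (all have y == -128)\n", mismatches);
  return 0;
}
```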
36 changes: 12 additions & 24 deletions llvm/test/CodeGen/ARM/ucmp.ll
@@ -4,12 +4,9 @@
define i8 @ucmp_8_8(i8 zeroext %x, i8 zeroext %y) nounwind {
; CHECK-LABEL: ucmp_8_8:
; CHECK: @ %bb.0:
; CHECK-NEXT: cmp r0, r1
; CHECK-NEXT: mov r0, #0
; CHECK-NEXT: mov r2, #0
; CHECK-NEXT: movwlo r0, #1
; CHECK-NEXT: movwhi r2, #1
; CHECK-NEXT: sub r0, r2, r0
; CHECK-NEXT: subs r0, r0, r1
; CHECK-NEXT: movwhi r0, #1
; CHECK-NEXT: mvnlo r0, #0
; CHECK-NEXT: bx lr
%1 = call i8 @llvm.ucmp(i8 %x, i8 %y)
ret i8 %1
@@ -18,12 +15,9 @@ define i8 @ucmp_8_8(i8 zeroext %x, i8 zeroext %y) nounwind {
define i8 @ucmp_8_16(i16 zeroext %x, i16 zeroext %y) nounwind {
; CHECK-LABEL: ucmp_8_16:
; CHECK: @ %bb.0:
; CHECK-NEXT: cmp r0, r1
; CHECK-NEXT: mov r0, #0
; CHECK-NEXT: mov r2, #0
; CHECK-NEXT: movwlo r0, #1
; CHECK-NEXT: movwhi r2, #1
; CHECK-NEXT: sub r0, r2, r0
; CHECK-NEXT: subs r0, r0, r1
; CHECK-NEXT: movwhi r0, #1
; CHECK-NEXT: mvnlo r0, #0
; CHECK-NEXT: bx lr
%1 = call i8 @llvm.ucmp(i16 %x, i16 %y)
ret i8 %1
@@ -32,12 +26,9 @@ define i8 @ucmp_8_16(i16 zeroext %x, i16 zeroext %y) nounwind {
define i8 @ucmp_8_32(i32 %x, i32 %y) nounwind {
; CHECK-LABEL: ucmp_8_32:
; CHECK: @ %bb.0:
; CHECK-NEXT: cmp r0, r1
; CHECK-NEXT: mov r0, #0
; CHECK-NEXT: mov r2, #0
; CHECK-NEXT: movwlo r0, #1
; CHECK-NEXT: movwhi r2, #1
; CHECK-NEXT: sub r0, r2, r0
; CHECK-NEXT: subs r0, r0, r1
; CHECK-NEXT: movwhi r0, #1
; CHECK-NEXT: mvnlo r0, #0
; CHECK-NEXT: bx lr
%1 = call i8 @llvm.ucmp(i32 %x, i32 %y)
ret i8 %1
@@ -92,12 +83,9 @@ define i8 @ucmp_8_128(i128 %x, i128 %y) nounwind {
define i32 @ucmp_32_32(i32 %x, i32 %y) nounwind {
; CHECK-LABEL: ucmp_32_32:
; CHECK: @ %bb.0:
; CHECK-NEXT: cmp r0, r1
; CHECK-NEXT: mov r0, #0
; CHECK-NEXT: mov r2, #0
; CHECK-NEXT: movwlo r0, #1
; CHECK-NEXT: movwhi r2, #1
; CHECK-NEXT: sub r0, r2, r0
; CHECK-NEXT: subs r0, r0, r1
; CHECK-NEXT: movwhi r0, #1
; CHECK-NEXT: mvnlo r0, #0
; CHECK-NEXT: bx lr
%1 = call i32 @llvm.ucmp(i32 %x, i32 %y)
ret i32 %1
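LowerUCMP applies the analogous fold for the unsigned case, guarded by isKnownNeverZero: for y == 0 the flag-setting subtract of zero always sets the carry (no borrow), while adds x, 0 never carries, so HI/LO derived from the add would be wrong. A matching u8 sketch of the carry condition (again not part of the patch):

```cpp
// Exhaustive u8 model of the unsigned adds fold guarded in LowerUCMP.
#include <cstdint>
#include <cstdio>

struct UFlags { bool z, c; };  // N/V do not affect HI/LO

static UFlags subFlags(uint8_t a, uint8_t b) {
  // ARM subs sets C when no borrow occurs, i.e. a >= b unsigned.
  return {uint8_t(a - b) == 0, a >= b};
}

static UFlags addFlags(uint8_t a, uint8_t b) {
  unsigned wide = unsigned(a) + unsigned(b);
  return {uint8_t(wide) == 0, wide > 0xFF};  // C is the carry-out
}

static bool isHI(UFlags f) { return f.c && !f.z; }  // ARMCC::HI
static bool isLO(UFlags f) { return !f.c; }         // ARMCC::LO

int main() {
  int mismatches = 0;
  for (int x = 0; x <= 255; ++x) {
    for (int y = 0; y <= 255; ++y) {
      uint8_t negY = uint8_t(256 - y);                   // wrapping 0 - y on u8
      UFlags viaSub = subFlags(uint8_t(x), negY);        // subs x, (0 - y)
      UFlags viaAdd = addFlags(uint8_t(x), uint8_t(y));  // adds x, y
      if (isHI(viaSub) != isHI(viaAdd) || isLO(viaSub) != isLO(viaAdd))
        ++mismatches;  // only reached when y == 0
    }
  }
  printf("mismatching pairs: %d (all have y == 0)\n", mismatches);
  return 0;
}
```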