Skip to content

[DAGCombiner] Add combine for vector interleave of splats #151110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,11 @@ namespace {
return CombineTo(N, To, 2, AddTo);
}

SDValue CombineTo(SDNode *N, SmallVectorImpl<SDValue> *To,
bool AddTo = true) {
return CombineTo(N, To->data(), To->size(), AddTo);
}

void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);

private:
Expand Down Expand Up @@ -541,6 +546,7 @@ namespace {
SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
SDValue visitBUILD_VECTOR(SDNode *N);
SDValue visitCONCAT_VECTORS(SDNode *N);
SDValue visitVECTOR_INTERLEAVE(SDNode *N);
SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
SDValue visitVECTOR_SHUFFLE(SDNode *N);
SDValue visitSCALAR_TO_VECTOR(SDNode *N);
Expand Down Expand Up @@ -2021,6 +2027,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
case ISD::BUILD_VECTOR: return visitBUILD_VECTOR(N);
case ISD::CONCAT_VECTORS: return visitCONCAT_VECTORS(N);
case ISD::VECTOR_INTERLEAVE: return visitVECTOR_INTERLEAVE(N);
case ISD::EXTRACT_SUBVECTOR: return visitEXTRACT_SUBVECTOR(N);
case ISD::VECTOR_SHUFFLE: return visitVECTOR_SHUFFLE(N);
case ISD::SCALAR_TO_VECTOR: return visitSCALAR_TO_VECTOR(N);
Expand Down Expand Up @@ -25274,6 +25281,28 @@ static SDValue combineConcatVectorOfShuffleAndItsOperands(
return DAG.getVectorShuffle(VT, dl, ShufOps[0], ShufOps[1], Mask);
}

static SDValue combineConcatVectorOfSplats(SDNode *N, SelectionDAG &DAG,
const TargetLowering &TLI,
bool LegalTypes,
bool LegalOperations) {
EVT VT = N->getValueType(0);

// Post-legalization we can only create wider SPLAT_VECTOR operations if both
// the type and operation is legal. The Hexagon target has custom
// legalization for SPLAT_VECTOR that splits the operation into two parts and
// concatenates them. Therefore, custom lowering must also be rejected in
// order to avoid an infinite loop.
if ((LegalTypes && !TLI.isTypeLegal(VT)) ||
(LegalOperations && !TLI.isOperationLegal(ISD::SPLAT_VECTOR, VT)))
return SDValue();

SDValue Op0 = N->getOperand(0);
if (!llvm::all_equal(N->op_values()) || Op0.getOpcode() != ISD::SPLAT_VECTOR)
return SDValue();

return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, Op0.getOperand(0));
}

SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
// If we only have one input vector, we don't need to do any concatenation.
if (N->getNumOperands() == 1)
Expand Down Expand Up @@ -25397,6 +25426,10 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
return DAG.getBuildVector(VT, SDLoc(N), Opnds);
}

if (SDValue V =
combineConcatVectorOfSplats(N, DAG, TLI, LegalTypes, LegalOperations))
return V;

// Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
// FIXME: Add support for concat_vectors(bitcast(vec0),bitcast(vec1),...).
if (SDValue V = combineConcatVectorOfScalars(N, DAG))
Expand Down Expand Up @@ -25465,6 +25498,21 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
return SDValue();
}

SDValue DAGCombiner::visitVECTOR_INTERLEAVE(SDNode *N) {
// Check to see if all operands are identical.
if (!llvm::all_equal(N->op_values()))
return SDValue();

// Check to see if the identical operand is a splat.
if (!DAG.isSplatValue(N->getOperand(0)))
return SDValue();

// interleave splat(X), splat(X).... --> splat(X), splat(X)....
SmallVector<SDValue, 4> Ops;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps "interleave splat(X), splat(X).... --> splat(X), splat(X)...."?

Ops.append(N->op_values().begin(), N->op_values().end());
return CombineTo(N, &Ops);
}

// Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
// if the subvector can be sourced for free.
static SDValue getSubVectorSrc(SDValue V, unsigned Index, EVT SubVT) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,19 @@ target triple = "aarch64"
define %"class.std::complex" @complex_mul_v2f64(ptr %a, ptr %b) {
; CHECK-LABEL: complex_mul_v2f64:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: movi v0.2d, #0000000000000000
; CHECK-NEXT: movi v1.2d, #0000000000000000
; CHECK-NEXT: mov w8, #100 // =0x64
; CHECK-NEXT: cntd x9
; CHECK-NEXT: whilelo p1.d, xzr, x8
; CHECK-NEXT: cntd x9
; CHECK-NEXT: rdvl x10, #2
; CHECK-NEXT: mov x11, x9
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: zip2 z0.d, z1.d, z1.d
; CHECK-NEXT: zip1 z1.d, z1.d, z1.d
; CHECK-NEXT: mov x11, x9
; CHECK-NEXT: .LBB0_1: // %vector.body
; CHECK-NEXT: // =>This Inner Loop Header: Depth=1
; CHECK-NEXT: zip2 p2.d, p1.d, p1.d
; CHECK-NEXT: mov z6.d, z1.d
; CHECK-NEXT: mov z7.d, z0.d
; CHECK-NEXT: mov z6.d, z0.d
; CHECK-NEXT: mov z7.d, z1.d
; CHECK-NEXT: zip1 p1.d, p1.d, p1.d
; CHECK-NEXT: ld1d { z2.d }, p2/z, [x0, #1, mul vl]
; CHECK-NEXT: ld1d { z4.d }, p2/z, [x1, #1, mul vl]
Expand All @@ -39,14 +38,14 @@ define %"class.std::complex" @complex_mul_v2f64(ptr %a, ptr %b) {
; CHECK-NEXT: fcmla z6.d, p0/m, z5.d, z3.d, #0
; CHECK-NEXT: fcmla z7.d, p0/m, z4.d, z2.d, #90
; CHECK-NEXT: fcmla z6.d, p0/m, z5.d, z3.d, #90
; CHECK-NEXT: mov z0.d, p2/m, z7.d
; CHECK-NEXT: mov z1.d, p1/m, z6.d
; CHECK-NEXT: mov z1.d, p2/m, z7.d
; CHECK-NEXT: mov z0.d, p1/m, z6.d
; CHECK-NEXT: whilelo p1.d, x11, x8
; CHECK-NEXT: add x11, x11, x9
; CHECK-NEXT: b.mi .LBB0_1
; CHECK-NEXT: // %bb.2: // %exit.block
; CHECK-NEXT: uzp1 z2.d, z1.d, z0.d
; CHECK-NEXT: uzp2 z1.d, z1.d, z0.d
; CHECK-NEXT: uzp1 z2.d, z0.d, z1.d
; CHECK-NEXT: uzp2 z1.d, z0.d, z1.d
; CHECK-NEXT: faddv d0, p0, z2.d
; CHECK-NEXT: faddv d1, p0, z1.d
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
Expand Down Expand Up @@ -111,21 +110,20 @@ exit.block: ; preds = %vector.body
define %"class.std::complex" @complex_mul_predicated_v2f64(ptr %a, ptr %b, ptr %cond) {
; CHECK-LABEL: complex_mul_predicated_v2f64:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: movi v0.2d, #0000000000000000
; CHECK-NEXT: movi v1.2d, #0000000000000000
; CHECK-NEXT: cntd x9
; CHECK-NEXT: mov w11, #100 // =0x64
; CHECK-NEXT: neg x10, x9
; CHECK-NEXT: mov w11, #100 // =0x64
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: mov x8, xzr
; CHECK-NEXT: and x10, x10, x11
; CHECK-NEXT: rdvl x11, #2
; CHECK-NEXT: zip2 z0.d, z1.d, z1.d
; CHECK-NEXT: zip1 z1.d, z1.d, z1.d
; CHECK-NEXT: .LBB1_1: // %vector.body
; CHECK-NEXT: // =>This Inner Loop Header: Depth=1
; CHECK-NEXT: ld1w { z2.d }, p0/z, [x2, x8, lsl #2]
; CHECK-NEXT: mov z6.d, z1.d
; CHECK-NEXT: mov z7.d, z0.d
; CHECK-NEXT: mov z6.d, z0.d
; CHECK-NEXT: mov z7.d, z1.d
; CHECK-NEXT: add x8, x8, x9
; CHECK-NEXT: cmpne p1.d, p0/z, z2.d, #0
; CHECK-NEXT: cmp x10, x8
Expand All @@ -141,12 +139,12 @@ define %"class.std::complex" @complex_mul_predicated_v2f64(ptr %a, ptr %b, ptr %
; CHECK-NEXT: fcmla z6.d, p0/m, z5.d, z3.d, #0
; CHECK-NEXT: fcmla z7.d, p0/m, z4.d, z2.d, #90
; CHECK-NEXT: fcmla z6.d, p0/m, z5.d, z3.d, #90
; CHECK-NEXT: mov z0.d, p2/m, z7.d
; CHECK-NEXT: mov z1.d, p1/m, z6.d
; CHECK-NEXT: mov z1.d, p2/m, z7.d
; CHECK-NEXT: mov z0.d, p1/m, z6.d
; CHECK-NEXT: b.ne .LBB1_1
; CHECK-NEXT: // %bb.2: // %exit.block
; CHECK-NEXT: uzp1 z2.d, z1.d, z0.d
; CHECK-NEXT: uzp2 z1.d, z1.d, z0.d
; CHECK-NEXT: uzp1 z2.d, z0.d, z1.d
; CHECK-NEXT: uzp2 z1.d, z0.d, z1.d
; CHECK-NEXT: faddv d0, p0, z2.d
; CHECK-NEXT: faddv d1, p0, z1.d
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
Expand Down Expand Up @@ -213,21 +211,20 @@ exit.block: ; preds = %vector.body
define %"class.std::complex" @complex_mul_predicated_x2_v2f64(ptr %a, ptr %b, ptr %cond) {
; CHECK-LABEL: complex_mul_predicated_x2_v2f64:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: movi v0.2d, #0000000000000000
; CHECK-NEXT: movi v1.2d, #0000000000000000
; CHECK-NEXT: mov w8, #100 // =0x64
; CHECK-NEXT: cntd x9
; CHECK-NEXT: whilelo p1.d, xzr, x8
; CHECK-NEXT: cntd x9
; CHECK-NEXT: rdvl x10, #2
; CHECK-NEXT: cnth x11
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: cnth x11
; CHECK-NEXT: mov x12, x9
; CHECK-NEXT: zip2 z0.d, z1.d, z1.d
; CHECK-NEXT: zip1 z1.d, z1.d, z1.d
; CHECK-NEXT: .LBB2_1: // %vector.body
; CHECK-NEXT: // =>This Inner Loop Header: Depth=1
; CHECK-NEXT: ld1w { z2.d }, p1/z, [x2]
; CHECK-NEXT: mov z6.d, z1.d
; CHECK-NEXT: mov z7.d, z0.d
; CHECK-NEXT: mov z6.d, z0.d
; CHECK-NEXT: mov z7.d, z1.d
; CHECK-NEXT: add x2, x2, x11
; CHECK-NEXT: and z2.d, z2.d, #0xffffffff
; CHECK-NEXT: cmpne p1.d, p1/z, z2.d, #0
Expand All @@ -243,14 +240,14 @@ define %"class.std::complex" @complex_mul_predicated_x2_v2f64(ptr %a, ptr %b, pt
; CHECK-NEXT: fcmla z6.d, p0/m, z5.d, z3.d, #0
; CHECK-NEXT: fcmla z7.d, p0/m, z4.d, z2.d, #90
; CHECK-NEXT: fcmla z6.d, p0/m, z5.d, z3.d, #90
; CHECK-NEXT: mov z0.d, p2/m, z7.d
; CHECK-NEXT: mov z1.d, p1/m, z6.d
; CHECK-NEXT: mov z1.d, p2/m, z7.d
; CHECK-NEXT: mov z0.d, p1/m, z6.d
; CHECK-NEXT: whilelo p1.d, x12, x8
; CHECK-NEXT: add x12, x12, x9
; CHECK-NEXT: b.mi .LBB2_1
; CHECK-NEXT: // %bb.2: // %exit.block
; CHECK-NEXT: uzp1 z2.d, z1.d, z0.d
; CHECK-NEXT: uzp2 z1.d, z1.d, z0.d
; CHECK-NEXT: uzp1 z2.d, z0.d, z1.d
; CHECK-NEXT: uzp2 z1.d, z0.d, z1.d
; CHECK-NEXT: faddv d0, p0, z2.d
; CHECK-NEXT: faddv d1, p0, z1.d
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@ target triple = "aarch64"
define %"class.std::complex" @complex_mul_v2f64(ptr %a, ptr %b) {
; CHECK-LABEL: complex_mul_v2f64:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: movi v0.2d, #0000000000000000
; CHECK-NEXT: movi v1.2d, #0000000000000000
; CHECK-NEXT: cntd x8
; CHECK-NEXT: mov w10, #100 // =0x64
; CHECK-NEXT: neg x9, x8
; CHECK-NEXT: mov w10, #100 // =0x64
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: and x9, x9, x10
; CHECK-NEXT: rdvl x10, #2
; CHECK-NEXT: zip2 z0.d, z1.d, z1.d
; CHECK-NEXT: zip1 z1.d, z1.d, z1.d
; CHECK-NEXT: .LBB0_1: // %vector.body
; CHECK-NEXT: // =>This Inner Loop Header: Depth=1
; CHECK-NEXT: ldr z2, [x0, #1, mul vl]
Expand All @@ -32,14 +31,14 @@ define %"class.std::complex" @complex_mul_v2f64(ptr %a, ptr %b) {
; CHECK-NEXT: ldr z5, [x1]
; CHECK-NEXT: add x1, x1, x10
; CHECK-NEXT: add x0, x0, x10
; CHECK-NEXT: fcmla z1.d, p0/m, z5.d, z3.d, #0
; CHECK-NEXT: fcmla z0.d, p0/m, z4.d, z2.d, #0
; CHECK-NEXT: fcmla z1.d, p0/m, z5.d, z3.d, #90
; CHECK-NEXT: fcmla z0.d, p0/m, z4.d, z2.d, #90
; CHECK-NEXT: fcmla z0.d, p0/m, z5.d, z3.d, #0
; CHECK-NEXT: fcmla z1.d, p0/m, z4.d, z2.d, #0
; CHECK-NEXT: fcmla z0.d, p0/m, z5.d, z3.d, #90
; CHECK-NEXT: fcmla z1.d, p0/m, z4.d, z2.d, #90
; CHECK-NEXT: b.ne .LBB0_1
; CHECK-NEXT: // %bb.2: // %exit.block
; CHECK-NEXT: uzp1 z2.d, z1.d, z0.d
; CHECK-NEXT: uzp2 z1.d, z1.d, z0.d
; CHECK-NEXT: uzp1 z2.d, z0.d, z1.d
; CHECK-NEXT: uzp2 z1.d, z0.d, z1.d
; CHECK-NEXT: faddv d0, p0, z2.d
; CHECK-NEXT: faddv d1, p0, z1.d
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
Expand Down Expand Up @@ -183,17 +182,16 @@ exit.block: ; preds = %vector.body
define %"class.std::complex" @complex_mul_v2f64_unrolled(ptr %a, ptr %b) {
; CHECK-LABEL: complex_mul_v2f64_unrolled:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: movi v0.2d, #0000000000000000
; CHECK-NEXT: movi v1.2d, #0000000000000000
; CHECK-NEXT: cntw x8
; CHECK-NEXT: mov w10, #1000 // =0x3e8
; CHECK-NEXT: movi v2.2d, #0000000000000000
; CHECK-NEXT: movi v3.2d, #0000000000000000
; CHECK-NEXT: neg x9, x8
; CHECK-NEXT: mov w10, #1000 // =0x3e8
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: and x9, x9, x10
; CHECK-NEXT: rdvl x10, #4
; CHECK-NEXT: zip2 z0.d, z1.d, z1.d
; CHECK-NEXT: zip1 z1.d, z1.d, z1.d
; CHECK-NEXT: mov z2.d, z1.d
; CHECK-NEXT: mov z3.d, z0.d
; CHECK-NEXT: .LBB2_1: // %vector.body
; CHECK-NEXT: // =>This Inner Loop Header: Depth=1
; CHECK-NEXT: ldr z4, [x0, #1, mul vl]
Expand All @@ -207,20 +205,20 @@ define %"class.std::complex" @complex_mul_v2f64_unrolled(ptr %a, ptr %b) {
; CHECK-NEXT: ldr z18, [x1, #3, mul vl]
; CHECK-NEXT: ldr z19, [x1, #2, mul vl]
; CHECK-NEXT: add x1, x1, x10
; CHECK-NEXT: fcmla z1.d, p0/m, z16.d, z5.d, #0
; CHECK-NEXT: fcmla z0.d, p0/m, z7.d, z4.d, #0
; CHECK-NEXT: fcmla z0.d, p0/m, z16.d, z5.d, #0
; CHECK-NEXT: fcmla z1.d, p0/m, z7.d, z4.d, #0
; CHECK-NEXT: fcmla z3.d, p0/m, z18.d, z6.d, #0
; CHECK-NEXT: fcmla z2.d, p0/m, z19.d, z17.d, #0
; CHECK-NEXT: fcmla z1.d, p0/m, z16.d, z5.d, #90
; CHECK-NEXT: fcmla z0.d, p0/m, z7.d, z4.d, #90
; CHECK-NEXT: fcmla z0.d, p0/m, z16.d, z5.d, #90
; CHECK-NEXT: fcmla z1.d, p0/m, z7.d, z4.d, #90
; CHECK-NEXT: fcmla z3.d, p0/m, z18.d, z6.d, #90
; CHECK-NEXT: fcmla z2.d, p0/m, z19.d, z17.d, #90
; CHECK-NEXT: b.ne .LBB2_1
; CHECK-NEXT: // %bb.2: // %exit.block
; CHECK-NEXT: uzp1 z4.d, z2.d, z3.d
; CHECK-NEXT: uzp1 z5.d, z1.d, z0.d
; CHECK-NEXT: uzp1 z5.d, z0.d, z1.d
; CHECK-NEXT: uzp2 z2.d, z2.d, z3.d
; CHECK-NEXT: uzp2 z0.d, z1.d, z0.d
; CHECK-NEXT: uzp2 z0.d, z0.d, z1.d
; CHECK-NEXT: fadd z1.d, z4.d, z5.d
; CHECK-NEXT: fadd z2.d, z2.d, z0.d
; CHECK-NEXT: faddv d0, p0, z1.d
Expand Down Expand Up @@ -310,15 +308,15 @@ define dso_local %"class.std::complex" @reduction_mix(ptr %a, ptr %b, ptr noalia
; CHECK-LABEL: reduction_mix:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: movi v2.2d, #0000000000000000
; CHECK-NEXT: movi v0.2d, #0000000000000000
; CHECK-NEXT: cntd x9
; CHECK-NEXT: mov w11, #100 // =0x64
; CHECK-NEXT: movi v1.2d, #0000000000000000
; CHECK-NEXT: neg x10, x9
; CHECK-NEXT: mov w11, #100 // =0x64
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: mov x8, xzr
; CHECK-NEXT: and x10, x10, x11
; CHECK-NEXT: rdvl x11, #2
; CHECK-NEXT: zip2 z0.d, z2.d, z2.d
; CHECK-NEXT: zip1 z1.d, z2.d, z2.d
; CHECK-NEXT: .LBB3_1: // %vector.body
; CHECK-NEXT: // =>This Inner Loop Header: Depth=1
; CHECK-NEXT: ldr z3, [x0]
Expand All @@ -327,13 +325,13 @@ define dso_local %"class.std::complex" @reduction_mix(ptr %a, ptr %b, ptr noalia
; CHECK-NEXT: ld1w { z5.d }, p0/z, [x3, x8, lsl #2]
; CHECK-NEXT: add x8, x8, x9
; CHECK-NEXT: cmp x10, x8
; CHECK-NEXT: fadd z0.d, z4.d, z0.d
; CHECK-NEXT: fadd z1.d, z3.d, z1.d
; CHECK-NEXT: fadd z1.d, z4.d, z1.d
; CHECK-NEXT: fadd z0.d, z3.d, z0.d
; CHECK-NEXT: add z2.d, z5.d, z2.d
; CHECK-NEXT: b.ne .LBB3_1
; CHECK-NEXT: // %bb.2: // %middle.block
; CHECK-NEXT: uzp2 z3.d, z1.d, z0.d
; CHECK-NEXT: uzp1 z1.d, z1.d, z0.d
; CHECK-NEXT: uzp2 z3.d, z0.d, z1.d
; CHECK-NEXT: uzp1 z1.d, z0.d, z1.d
; CHECK-NEXT: uaddv d2, p0, z2.d
; CHECK-NEXT: faddv d0, p0, z3.d
; CHECK-NEXT: faddv d1, p0, z1.d
Expand Down
Loading