Skip to content

Commit 6752369

Browse files
authored
[LV] Unify interleaved load handling for fixed and scalable VFs. nfc (#146914)
This patch modifies VPInterleaveRecipe::execute to handle both fixed and scalable VFs using a single loop.
1 parent 6f240d5 commit 6752369

File tree

1 file changed

+17
-35
lines changed

1 file changed

+17
-35
lines changed

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 17 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3445,7 +3445,6 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
34453445
VPValue *BlockInMask = getMask();
34463446
VPValue *Addr = getAddr();
34473447
Value *ResAddr = State.get(Addr, VPLane(0));
3448-
Value *PoisonVec = PoisonValue::get(VecTy);
34493448

34503449
auto CreateGroupMask = [&BlockInMask, &State,
34513450
&InterleaveFactor](Value *MaskForGaps) -> Value * {
@@ -3484,6 +3483,7 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
34843483
Instruction *NewLoad;
34853484
if (BlockInMask || MaskForGaps) {
34863485
Value *GroupMask = CreateGroupMask(MaskForGaps);
3486+
Value *PoisonVec = PoisonValue::get(VecTy);
34873487
NewLoad = State.Builder.CreateMaskedLoad(VecTy, ResAddr,
34883488
Group->getAlign(), GroupMask,
34893489
PoisonVec, "wide.masked.vec");
@@ -3493,57 +3493,39 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
34933493
Group->addMetadata(NewLoad);
34943494

34953495
ArrayRef<VPValue *> VPDefs = definedValues();
3496-
const DataLayout &DL = State.CFG.PrevBB->getDataLayout();
34973496
if (VecTy->isScalableTy()) {
34983497
// Scalable vectors cannot use arbitrary shufflevectors (only splats),
34993498
// so must use intrinsics to deinterleave.
35003499
assert(InterleaveFactor <= 8 &&
35013500
"Unsupported deinterleave factor for scalable vectors");
3502-
Value *Deinterleave = State.Builder.CreateIntrinsic(
3501+
NewLoad = State.Builder.CreateIntrinsic(
35033502
getDeinterleaveIntrinsicID(InterleaveFactor), NewLoad->getType(),
35043503
NewLoad,
35053504
/*FMFSource=*/nullptr, "strided.vec");
3505+
}
35063506

3507-
for (unsigned I = 0, J = 0; I < InterleaveFactor; ++I) {
3508-
Instruction *Member = Group->getMember(I);
3509-
Value *StridedVec = State.Builder.CreateExtractValue(Deinterleave, I);
3510-
if (!Member) {
3511-
// This value is not needed as it's not used
3512-
cast<Instruction>(StridedVec)->eraseFromParent();
3513-
continue;
3514-
}
3515-
// If this member has different type, cast the result type.
3516-
if (Member->getType() != ScalarTy) {
3517-
VectorType *OtherVTy = VectorType::get(Member->getType(), State.VF);
3518-
StridedVec =
3519-
createBitOrPointerCast(State.Builder, StridedVec, OtherVTy, DL);
3520-
}
3521-
3522-
if (Group->isReverse())
3523-
StridedVec = State.Builder.CreateVectorReverse(StridedVec, "reverse");
3524-
3525-
State.set(VPDefs[J], StridedVec);
3526-
++J;
3527-
}
3507+
auto CreateStridedVector = [&InterleaveFactor, &State,
3508+
&NewLoad](unsigned Index) -> Value * {
3509+
assert(Index < InterleaveFactor && "Illegal group index");
3510+
if (State.VF.isScalable())
3511+
return State.Builder.CreateExtractValue(NewLoad, Index);
35283512

3529-
return;
3530-
}
3531-
assert(!State.VF.isScalable() && "VF is assumed to be non scalable.");
3513+
// For fixed length VF, use shuffle to extract the sub-vectors from the
3514+
// wide load.
3515+
auto StrideMask =
3516+
createStrideMask(Index, InterleaveFactor, State.VF.getFixedValue());
3517+
return State.Builder.CreateShuffleVector(NewLoad, StrideMask,
3518+
"strided.vec");
3519+
};
35323520

3533-
// For each member in the group, shuffle out the appropriate data from the
3534-
// wide loads.
3535-
unsigned J = 0;
3536-
for (unsigned I = 0; I < InterleaveFactor; ++I) {
3521+
for (unsigned I = 0, J = 0; I < InterleaveFactor; ++I) {
35373522
Instruction *Member = Group->getMember(I);
35383523

35393524
// Skip the gaps in the group.
35403525
if (!Member)
35413526
continue;
35423527

3543-
auto StrideMask =
3544-
createStrideMask(I, InterleaveFactor, State.VF.getFixedValue());
3545-
Value *StridedVec =
3546-
State.Builder.CreateShuffleVector(NewLoad, StrideMask, "strided.vec");
3528+
Value *StridedVec = CreateStridedVector(I);
35473529

35483530
// If this member has different type, cast the result type.
35493531
if (Member->getType() != ScalarTy) {

0 commit comments

Comments
 (0)