From 7f06b5cb8d382ac475c28e9148a7414d3fc1b7c6 Mon Sep 17 00:00:00 2001 From: Qi Zhao Date: Thu, 31 Jul 2025 16:59:20 +0800 Subject: [PATCH 1/5] [LoongArch] Optimize extractelement containing variable index --- .../LoongArch/LoongArchISelLowering.cpp | 19 +++++++++- .../Target/LoongArch/LoongArchISelLowering.h | 1 + .../LoongArch/LoongArchLASXInstrInfo.td | 10 ++++++ .../lasx/ir-instruction/extractelement.ll | 36 +++++-------------- 4 files changed, 37 insertions(+), 29 deletions(-) diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp index 6583a0fef3d61..597650c8229a7 100644 --- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp +++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp @@ -2608,13 +2608,29 @@ SDValue LoongArchTargetLowering::lowerCONCAT_VECTORS(SDValue Op, SDValue LoongArchTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const { - EVT VecTy = Op->getOperand(0)->getValueType(0); + MVT EltVT = Op.getSimpleValueType(); + SDValue Vec = Op->getOperand(0); + EVT VecTy = Vec->getValueType(0); SDValue Idx = Op->getOperand(1); unsigned NumElts = VecTy.getVectorNumElements(); + SDLoc DL(Op); + + assert(VecTy.is256BitVector() && "Unexpected EXTRACT_VECTOR_ELT vector type"); if (isa<ConstantSDNode>(Idx) && Idx->getAsZExtVal() < NumElts) return Op; + // TODO: Deal with other legal 256-bits vector types? 
+ if (!isa<ConstantSDNode>(Idx) && + (VecTy == MVT::v8i32 || VecTy == MVT::v8f32)) { + SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, Idx); + SDValue SplatValue = + DAG.getNode(LoongArchISD::XVPERM, DL, VecTy, Vec, SplatIdx); + + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SplatValue, + DAG.getConstant(0, DL, Subtarget.getGRLenVT())); + } + return SDValue(); } @@ -6632,6 +6648,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(VREPLVEI) NODE_NAME_CASE(VREPLGR2VR) NODE_NAME_CASE(XVPERMI) + NODE_NAME_CASE(XVPERM) NODE_NAME_CASE(VPICK_SEXT_ELT) NODE_NAME_CASE(VPICK_ZEXT_ELT) NODE_NAME_CASE(VREPLVE) diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h index f79ba7450cc36..075ccc04eca9e 100644 --- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h +++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h @@ -141,6 +141,7 @@ enum NodeType : unsigned { VREPLVEI, VREPLGR2VR, XVPERMI, + XVPERM, // Extended vector element extraction VPICK_SEXT_ELT, diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td index d8bb16fe9b94d..9dfd059a78286 100644 --- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td +++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td @@ -10,8 +10,12 @@ // //===----------------------------------------------------------------------===// +def SDT_LoongArchXVPERM : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0, 1>, + SDTCisVec<2>, SDTCisInt<2>]>; + // Target nodes. 
def loongarch_xvpermi: SDNode<"LoongArchISD::XVPERMI", SDT_LoongArchV1RUimm>; +def loongarch_xvperm: SDNode<"LoongArchISD::XVPERM", SDT_LoongArchXVPERM>; def loongarch_xvmskltz: SDNode<"LoongArchISD::XVMSKLTZ", SDT_LoongArchVMSKCOND>; def loongarch_xvmskgez: SDNode<"LoongArchISD::XVMSKGEZ", SDT_LoongArchVMSKCOND>; def loongarch_xvmskeqz: SDNode<"LoongArchISD::XVMSKEQZ", SDT_LoongArchVMSKCOND>; @@ -1834,6 +1838,12 @@ def : Pat<(loongarch_xvpermi v4i64:$xj, immZExt8: $ui8), def : Pat<(loongarch_xvpermi v4f64:$xj, immZExt8: $ui8), (XVPERMI_D v4f64:$xj, immZExt8: $ui8)>; +// XVPERM_W +def : Pat<(loongarch_xvperm v8i32:$xj, v8i32:$xk), + (XVPERM_W v8i32:$xj, v8i32:$xk)>; +def : Pat<(loongarch_xvperm v8f32:$xj, v8i32:$xk), + (XVPERM_W v8f32:$xj, v8i32:$xk)>; + // XVREPLVE0_{W/D} def : Pat<(lasxsplatf32 FPR32:$fj), (XVREPLVE0_W (SUBREG_TO_REG (i64 0), FPR32:$fj, sub_32))>; diff --git a/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll b/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll index 2e1618748688a..b191a9d08ab2d 100644 --- a/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll +++ b/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll @@ -126,21 +126,11 @@ define void @extract_16xi16_idx(ptr %src, ptr %dst, i32 %idx) nounwind { define void @extract_8xi32_idx(ptr %src, ptr %dst, i32 %idx) nounwind { ; CHECK-LABEL: extract_8xi32_idx: ; CHECK: # %bb.0: -; CHECK-NEXT: addi.d $sp, $sp, -96 -; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill -; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill -; CHECK-NEXT: addi.d $fp, $sp, 96 -; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0 ; CHECK-NEXT: xvld $xr0, $a0, 0 -; CHECK-NEXT: xvst $xr0, $sp, 32 -; CHECK-NEXT: addi.d $a0, $sp, 32 -; CHECK-NEXT: bstrins.d $a0, $a2, 4, 2 -; CHECK-NEXT: ld.w $a0, $a0, 0 -; CHECK-NEXT: st.w $a0, $a1, 0 -; CHECK-NEXT: addi.d $sp, $fp, -96 -; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload -; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded 
Reload -; CHECK-NEXT: addi.d $sp, $sp, 96 +; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 0 +; CHECK-NEXT: xvreplgr2vr.w $xr1, $a0 +; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1 +; CHECK-NEXT: xvstelm.w $xr0, $a1, 0, 0 ; CHECK-NEXT: ret %v = load volatile <8 x i32>, ptr %src %e = extractelement <8 x i32> %v, i32 %idx @@ -176,21 +166,11 @@ define void @extract_4xi64_idx(ptr %src, ptr %dst, i32 %idx) nounwind { define void @extract_8xfloat_idx(ptr %src, ptr %dst, i32 %idx) nounwind { ; CHECK-LABEL: extract_8xfloat_idx: ; CHECK: # %bb.0: -; CHECK-NEXT: addi.d $sp, $sp, -96 -; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill -; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill -; CHECK-NEXT: addi.d $fp, $sp, 96 -; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0 ; CHECK-NEXT: xvld $xr0, $a0, 0 -; CHECK-NEXT: xvst $xr0, $sp, 32 -; CHECK-NEXT: addi.d $a0, $sp, 32 -; CHECK-NEXT: bstrins.d $a0, $a2, 4, 2 -; CHECK-NEXT: fld.s $fa0, $a0, 0 -; CHECK-NEXT: fst.s $fa0, $a1, 0 -; CHECK-NEXT: addi.d $sp, $fp, -96 -; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload -; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload -; CHECK-NEXT: addi.d $sp, $sp, 96 +; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 0 +; CHECK-NEXT: xvreplgr2vr.w $xr1, $a0 +; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1 +; CHECK-NEXT: xvstelm.w $xr0, $a1, 0, 0 ; CHECK-NEXT: ret %v = load volatile <8 x float>, ptr %src %e = extractelement <8 x float> %v, i32 %idx From 88c7440dbf0131215f9e9795f70be2b41dbb9f28 Mon Sep 17 00:00:00 2001 From: Qi Zhao Date: Fri, 1 Aug 2025 21:04:10 +0800 Subject: [PATCH 2/5] deal with other lasx types --- .../LoongArch/LoongArchISelLowering.cpp | 68 +++++++++++++-- .../LoongArch/LoongArchLASXInstrInfo.td | 2 +- .../lasx/ir-instruction/extractelement.ll | 86 +++++++------------ 3 files changed, 90 insertions(+), 66 deletions(-) diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp index 597650c8229a7..547a3163249cd 100644 --- 
a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp +++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp @@ -2614,21 +2614,71 @@ LoongArchTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op, SDValue Idx = Op->getOperand(1); unsigned NumElts = VecTy.getVectorNumElements(); SDLoc DL(Op); + MVT GRLenVT = Subtarget.getGRLenVT(); assert(VecTy.is256BitVector() && "Unexpected EXTRACT_VECTOR_ELT vector type"); if (isa<ConstantSDNode>(Idx) && Idx->getAsZExtVal() < NumElts) return Op; - // TODO: Deal with other legal 256-bits vector types? - if (!isa<ConstantSDNode>(Idx) && - (VecTy == MVT::v8i32 || VecTy == MVT::v8f32)) { - SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, Idx); - SDValue SplatValue = - DAG.getNode(LoongArchISD::XVPERM, DL, VecTy, Vec, SplatIdx); - - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SplatValue, - DAG.getConstant(0, DL, Subtarget.getGRLenVT())); + if (!isa<ConstantSDNode>(Idx)) { + switch (VecTy.getSimpleVT().SimpleTy) { + default: + llvm_unreachable("Unexpected type"); + case MVT::v32i8: + case MVT::v16i16: { + SDValue NewVec = DAG.getBitcast(MVT::v8i32, Vec); + SDValue NewIdx = DAG.getNode( + LoongArchISD::BSTRPICK, DL, GRLenVT, Idx, + DAG.getConstant(31, DL, GRLenVT), + DAG.getConstant(((VecTy == MVT::v32i8) ? 2 : 1), DL, GRLenVT)); + SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, NewIdx); + SDValue SplatValue = + DAG.getNode(LoongArchISD::XVPERM, DL, MVT::v8i32, NewVec, SplatIdx); + SDValue SplatVec = DAG.getBitcast(VecTy, SplatValue); + + SDValue LocalIdx = DAG.getNode( + ISD::AND, DL, GRLenVT, Idx, + DAG.getConstant(((VecTy == MVT::v32i8) ? 
3 : 1), DL, GRLenVT)); + SDValue ExtractVec = + DAG.getNode(LoongArchISD::VREPLVE, DL, VecTy, SplatVec, LocalIdx); + + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, ExtractVec, + DAG.getConstant(0, DL, GRLenVT)); + } + case MVT::v8i32: + case MVT::v8f32: { + SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, Idx); + SDValue SplatValue = + DAG.getNode(LoongArchISD::XVPERM, DL, VecTy, Vec, SplatIdx); + + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SplatValue, + DAG.getConstant(0, DL, GRLenVT)); + } + case MVT::v4i64: + case MVT::v4f64: { + SDValue NewVec = DAG.getBitcast(MVT::v8i32, Vec); + SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, Idx); + SDValue SplatIdxLo = + DAG.getNode(LoongArchISD::VSLLI, DL, MVT::v8i32, SplatIdx, + DAG.getConstant(1, DL, GRLenVT)); + SDValue SplatIdxHi = + DAG.getNode(ISD::ADD, DL, MVT::v8i32, SplatIdxLo, + DAG.getSplatBuildVector(MVT::v8i32, DL, + DAG.getConstant(1, DL, GRLenVT))); + + SDValue SplatVecLo = + DAG.getNode(LoongArchISD::XVPERM, DL, MVT::v8i32, NewVec, SplatIdxLo); + SDValue SplatVecHi = + DAG.getNode(LoongArchISD::XVPERM, DL, MVT::v8i32, NewVec, SplatIdxHi); + SDValue SplatValue = DAG.getNode(LoongArchISD::VILVL, DL, MVT::v8i32, + SplatVecHi, SplatVecLo); + SDValue ExtractVec = DAG.getBitcast(VecTy, SplatValue); + + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, ExtractVec, + DAG.getConstant(0, DL, GRLenVT)); + } + } } return SDValue(); diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td index 9dfd059a78286..506665538742b 100644 --- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td +++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td @@ -11,7 +11,7 @@ //===----------------------------------------------------------------------===// def SDT_LoongArchXVPERM : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0, 1>, - SDTCisVec<2>, SDTCisInt<2>]>; + SDTCisVec<2>, SDTCisInt<2>]>; // Target nodes. 
def loongarch_xvpermi: SDNode<"LoongArchISD::XVPERMI", SDT_LoongArchV1RUimm>; diff --git a/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll b/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll index b191a9d08ab2d..9b88ca85a170b 100644 --- a/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll +++ b/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll @@ -76,21 +76,14 @@ define void @extract_4xdouble(ptr %src, ptr %dst) nounwind { define void @extract_32xi8_idx(ptr %src, ptr %dst, i32 %idx) nounwind { ; CHECK-LABEL: extract_32xi8_idx: ; CHECK: # %bb.0: -; CHECK-NEXT: addi.d $sp, $sp, -96 -; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill -; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill -; CHECK-NEXT: addi.d $fp, $sp, 96 -; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0 ; CHECK-NEXT: xvld $xr0, $a0, 0 -; CHECK-NEXT: xvst $xr0, $sp, 32 -; CHECK-NEXT: addi.d $a0, $sp, 32 -; CHECK-NEXT: bstrins.d $a0, $a2, 4, 0 -; CHECK-NEXT: ld.b $a0, $a0, 0 -; CHECK-NEXT: st.b $a0, $a1, 0 -; CHECK-NEXT: addi.d $sp, $fp, -96 -; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload -; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload -; CHECK-NEXT: addi.d $sp, $sp, 96 +; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 0 +; CHECK-NEXT: bstrpick.d $a0, $a0, 31, 2 +; CHECK-NEXT: xvreplgr2vr.w $xr1, $a0 +; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1 +; CHECK-NEXT: andi $a0, $a2, 3 +; CHECK-NEXT: xvreplve.b $xr0, $xr0, $a0 +; CHECK-NEXT: xvstelm.b $xr0, $a1, 0, 0 ; CHECK-NEXT: ret %v = load volatile <32 x i8>, ptr %src %e = extractelement <32 x i8> %v, i32 %idx @@ -101,21 +94,14 @@ define void @extract_32xi8_idx(ptr %src, ptr %dst, i32 %idx) nounwind { define void @extract_16xi16_idx(ptr %src, ptr %dst, i32 %idx) nounwind { ; CHECK-LABEL: extract_16xi16_idx: ; CHECK: # %bb.0: -; CHECK-NEXT: addi.d $sp, $sp, -96 -; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill -; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill -; CHECK-NEXT: addi.d $fp, 
$sp, 96 -; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0 ; CHECK-NEXT: xvld $xr0, $a0, 0 -; CHECK-NEXT: xvst $xr0, $sp, 32 -; CHECK-NEXT: addi.d $a0, $sp, 32 -; CHECK-NEXT: bstrins.d $a0, $a2, 4, 1 -; CHECK-NEXT: ld.h $a0, $a0, 0 -; CHECK-NEXT: st.h $a0, $a1, 0 -; CHECK-NEXT: addi.d $sp, $fp, -96 -; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload -; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload -; CHECK-NEXT: addi.d $sp, $sp, 96 +; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 0 +; CHECK-NEXT: bstrpick.d $a0, $a0, 31, 1 +; CHECK-NEXT: xvreplgr2vr.w $xr1, $a0 +; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1 +; CHECK-NEXT: andi $a0, $a2, 1 +; CHECK-NEXT: xvreplve.h $xr0, $xr0, $a0 +; CHECK-NEXT: xvstelm.h $xr0, $a1, 0, 0 ; CHECK-NEXT: ret %v = load volatile <16 x i16>, ptr %src %e = extractelement <16 x i16> %v, i32 %idx @@ -141,21 +127,15 @@ define void @extract_8xi32_idx(ptr %src, ptr %dst, i32 %idx) nounwind { define void @extract_4xi64_idx(ptr %src, ptr %dst, i32 %idx) nounwind { ; CHECK-LABEL: extract_4xi64_idx: ; CHECK: # %bb.0: -; CHECK-NEXT: addi.d $sp, $sp, -96 -; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill -; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill -; CHECK-NEXT: addi.d $fp, $sp, 96 -; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0 ; CHECK-NEXT: xvld $xr0, $a0, 0 -; CHECK-NEXT: xvst $xr0, $sp, 32 -; CHECK-NEXT: addi.d $a0, $sp, 32 -; CHECK-NEXT: bstrins.d $a0, $a2, 4, 3 -; CHECK-NEXT: ld.d $a0, $a0, 0 -; CHECK-NEXT: st.d $a0, $a1, 0 -; CHECK-NEXT: addi.d $sp, $fp, -96 -; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload -; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload -; CHECK-NEXT: addi.d $sp, $sp, 96 +; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 0 +; CHECK-NEXT: xvreplgr2vr.w $xr1, $a0 +; CHECK-NEXT: xvslli.w $xr1, $xr1, 1 +; CHECK-NEXT: xvperm.w $xr2, $xr0, $xr1 +; CHECK-NEXT: xvaddi.wu $xr1, $xr1, 1 +; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1 +; CHECK-NEXT: xvilvl.w $xr0, $xr0, $xr2 +; CHECK-NEXT: xvstelm.d $xr0, $a1, 0, 0 ; CHECK-NEXT: ret %v = load 
volatile <4 x i64>, ptr %src %e = extractelement <4 x i64> %v, i32 %idx @@ -181,21 +161,15 @@ define void @extract_8xfloat_idx(ptr %src, ptr %dst, i32 %idx) nounwind { define void @extract_4xdouble_idx(ptr %src, ptr %dst, i32 %idx) nounwind { ; CHECK-LABEL: extract_4xdouble_idx: ; CHECK: # %bb.0: -; CHECK-NEXT: addi.d $sp, $sp, -96 -; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill -; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill -; CHECK-NEXT: addi.d $fp, $sp, 96 -; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0 ; CHECK-NEXT: xvld $xr0, $a0, 0 -; CHECK-NEXT: xvst $xr0, $sp, 32 -; CHECK-NEXT: addi.d $a0, $sp, 32 -; CHECK-NEXT: bstrins.d $a0, $a2, 4, 3 -; CHECK-NEXT: fld.d $fa0, $a0, 0 -; CHECK-NEXT: fst.d $fa0, $a1, 0 -; CHECK-NEXT: addi.d $sp, $fp, -96 -; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload -; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload -; CHECK-NEXT: addi.d $sp, $sp, 96 +; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 0 +; CHECK-NEXT: xvreplgr2vr.w $xr1, $a0 +; CHECK-NEXT: xvslli.w $xr1, $xr1, 1 +; CHECK-NEXT: xvperm.w $xr2, $xr0, $xr1 +; CHECK-NEXT: xvaddi.wu $xr1, $xr1, 1 +; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1 +; CHECK-NEXT: xvilvl.w $xr0, $xr0, $xr2 +; CHECK-NEXT: xvstelm.d $xr0, $a1, 0, 0 ; CHECK-NEXT: ret %v = load volatile <4 x double>, ptr %src %e = extractelement <4 x double> %v, i32 %idx From 15842c115f9b3ecb68104174d538a3e40617bd68 Mon Sep 17 00:00:00 2001 From: Qi Zhao Date: Sat, 2 Aug 2025 11:49:31 +0800 Subject: [PATCH 3/5] add comments --- .../Target/LoongArch/LoongArchISelLowering.cpp | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp index 547a3163249cd..247069243cb3c 100644 --- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp +++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp @@ -2627,7 +2627,12 @@ LoongArchTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op, llvm_unreachable("Unexpected 
type"); case MVT::v32i8: case MVT::v16i16: { + // Consider the source vector as v8i32 type. SDValue NewVec = DAG.getBitcast(MVT::v8i32, Vec); + + // Compute the adjusted index and use it to broadcast the vector. + // The original desired i8/i16 element is now replicated in each + // i32 lane of the splatted vector. SDValue NewIdx = DAG.getNode( LoongArchISD::BSTRPICK, DL, GRLenVT, Idx, DAG.getConstant(31, DL, GRLenVT), @@ -2637,6 +2642,9 @@ LoongArchTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op, DAG.getNode(LoongArchISD::XVPERM, DL, MVT::v8i32, NewVec, SplatIdx); SDValue SplatVec = DAG.getBitcast(VecTy, SplatValue); + // Compute the local index of the original i8/i16 element within the + // i32 element and then use it to broadcast the vector. Each elements + // of the vector will be the desired element. SDValue LocalIdx = DAG.getNode( ISD::AND, DL, GRLenVT, Idx, DAG.getConstant(((VecTy == MVT::v32i8) ? 3 : 1), DL, GRLenVT)); @@ -2657,7 +2665,11 @@ LoongArchTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op, } case MVT::v4i64: case MVT::v4f64: { + // Consider the source vector as v8i32 type. SDValue NewVec = DAG.getBitcast(MVT::v8i32, Vec); + + // Split the original element index into low and high parts: + // Lo = Idx * 2, Hi = Idx * 2 + 1. SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, Idx); SDValue SplatIdxLo = DAG.getNode(LoongArchISD::VSLLI, DL, MVT::v8i32, SplatIdx, @@ -2667,10 +2679,15 @@ LoongArchTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op, DAG.getSplatBuildVector(MVT::v8i32, DL, DAG.getConstant(1, DL, GRLenVT))); + // Use the broadcasted index to broadcast the low and high parts of the + // vector separately. SDValue SplatVecLo = DAG.getNode(LoongArchISD::XVPERM, DL, MVT::v8i32, NewVec, SplatIdxLo); SDValue SplatVecHi = DAG.getNode(LoongArchISD::XVPERM, DL, MVT::v8i32, NewVec, SplatIdxHi); + + // Combine the low and high i32 parts to reconstruct the original i64/f64 + // element. 
SDValue SplatValue = DAG.getNode(LoongArchISD::VILVL, DL, MVT::v8i32, SplatVecHi, SplatVecLo); SDValue ExtractVec = DAG.getBitcast(VecTy, SplatValue); From 4704d3a1e04e134e181de7a56be7a091067453f1 Mon Sep 17 00:00:00 2001 From: Qi Zhao Date: Mon, 4 Aug 2025 09:19:20 +0800 Subject: [PATCH 4/5] address comments --- .../LoongArch/LoongArchISelLowering.cpp | 152 +++++++++--------- 1 file changed, 73 insertions(+), 79 deletions(-) diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp index 247069243cb3c..77b55b115beaa 100644 --- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp +++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp @@ -2612,93 +2612,87 @@ LoongArchTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op, SDValue Vec = Op->getOperand(0); EVT VecTy = Vec->getValueType(0); SDValue Idx = Op->getOperand(1); - unsigned NumElts = VecTy.getVectorNumElements(); SDLoc DL(Op); MVT GRLenVT = Subtarget.getGRLenVT(); assert(VecTy.is256BitVector() && "Unexpected EXTRACT_VECTOR_ELT vector type"); - if (isa<ConstantSDNode>(Idx) && Idx->getAsZExtVal() < NumElts) + if (isa<ConstantSDNode>(Idx)) return Op; - if (!isa<ConstantSDNode>(Idx)) { - switch (VecTy.getSimpleVT().SimpleTy) { - default: - llvm_unreachable("Unexpected type"); - case MVT::v32i8: - case MVT::v16i16: { - // Consider the source vector as v8i32 type. - SDValue NewVec = DAG.getBitcast(MVT::v8i32, Vec); - - // Compute the adjusted index and use it to broadcast the vector. - // The original desired i8/i16 element is now replicated in each - // i32 lane of the splatted vector. - SDValue NewIdx = DAG.getNode( - LoongArchISD::BSTRPICK, DL, GRLenVT, Idx, - DAG.getConstant(31, DL, GRLenVT), - DAG.getConstant(((VecTy == MVT::v32i8) ? 
2 : 1), DL, GRLenVT)); - SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, NewIdx); - SDValue SplatValue = - DAG.getNode(LoongArchISD::XVPERM, DL, MVT::v8i32, NewVec, SplatIdx); - SDValue SplatVec = DAG.getBitcast(VecTy, SplatValue); - - // Compute the local index of the original i8/i16 element within the - // i32 element and then use it to broadcast the vector. Each elements - // of the vector will be the desired element. - SDValue LocalIdx = DAG.getNode( - ISD::AND, DL, GRLenVT, Idx, - DAG.getConstant(((VecTy == MVT::v32i8) ? 3 : 1), DL, GRLenVT)); - SDValue ExtractVec = - DAG.getNode(LoongArchISD::VREPLVE, DL, VecTy, SplatVec, LocalIdx); - - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, ExtractVec, - DAG.getConstant(0, DL, GRLenVT)); - } - case MVT::v8i32: - case MVT::v8f32: { - SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, Idx); - SDValue SplatValue = - DAG.getNode(LoongArchISD::XVPERM, DL, VecTy, Vec, SplatIdx); - - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SplatValue, - DAG.getConstant(0, DL, GRLenVT)); - } - case MVT::v4i64: - case MVT::v4f64: { - // Consider the source vector as v8i32 type. - SDValue NewVec = DAG.getBitcast(MVT::v8i32, Vec); - - // Split the original element index into low and high parts: - // Lo = Idx * 2, Hi = Idx * 2 + 1. - SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, Idx); - SDValue SplatIdxLo = - DAG.getNode(LoongArchISD::VSLLI, DL, MVT::v8i32, SplatIdx, - DAG.getConstant(1, DL, GRLenVT)); - SDValue SplatIdxHi = - DAG.getNode(ISD::ADD, DL, MVT::v8i32, SplatIdxLo, - DAG.getSplatBuildVector(MVT::v8i32, DL, - DAG.getConstant(1, DL, GRLenVT))); - - // Use the broadcasted index to broadcast the low and high parts of the - // vector separately. 
- SDValue SplatVecLo = - DAG.getNode(LoongArchISD::XVPERM, DL, MVT::v8i32, NewVec, SplatIdxLo); - SDValue SplatVecHi = - DAG.getNode(LoongArchISD::XVPERM, DL, MVT::v8i32, NewVec, SplatIdxHi); - - // Combine the low and high i32 parts to reconstruct the original i64/f64 - // element. - SDValue SplatValue = DAG.getNode(LoongArchISD::VILVL, DL, MVT::v8i32, - SplatVecHi, SplatVecLo); - SDValue ExtractVec = DAG.getBitcast(VecTy, SplatValue); - - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, ExtractVec, - DAG.getConstant(0, DL, GRLenVT)); - } - } + switch (VecTy.getSimpleVT().SimpleTy) { + default: + llvm_unreachable("Unexpected type"); + case MVT::v32i8: + case MVT::v16i16: { + // Consider the source vector as v8i32 type. + SDValue NewVec = DAG.getBitcast(MVT::v8i32, Vec); + + // Compute the adjusted index and use it to broadcast the vector. + // The original desired i8/i16 element is now replicated in each + // i32 lane of the splatted vector. + SDValue NewIdx = DAG.getNode( + LoongArchISD::BSTRPICK, DL, GRLenVT, Idx, + DAG.getConstant(31, DL, GRLenVT), + DAG.getConstant(((VecTy == MVT::v32i8) ? 2 : 1), DL, GRLenVT)); + SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, NewIdx); + SDValue SplatValue = + DAG.getNode(LoongArchISD::XVPERM, DL, MVT::v8i32, NewVec, SplatIdx); + SDValue SplatVec = DAG.getBitcast(VecTy, SplatValue); + + // Compute the local index of the original i8/i16 element within the + // i32 element and then use it to broadcast the vector. Each elements + // of the vector will be the desired element. + SDValue LocalIdx = DAG.getNode( + ISD::AND, DL, GRLenVT, Idx, + DAG.getConstant(((VecTy == MVT::v32i8) ? 
3 : 1), DL, GRLenVT)); + SDValue ExtractVec = + DAG.getNode(LoongArchISD::VREPLVE, DL, VecTy, SplatVec, LocalIdx); + + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, ExtractVec, + DAG.getConstant(0, DL, GRLenVT)); + } + case MVT::v8i32: + case MVT::v8f32: { + SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, Idx); + SDValue SplatValue = + DAG.getNode(LoongArchISD::XVPERM, DL, VecTy, Vec, SplatIdx); + + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SplatValue, + DAG.getConstant(0, DL, GRLenVT)); + } + case MVT::v4i64: + case MVT::v4f64: { + // Consider the source vector as v8i32 type. + SDValue NewVec = DAG.getBitcast(MVT::v8i32, Vec); + + // Split the original element index into low and high parts: + // Lo = Idx * 2, Hi = Idx * 2 + 1. + SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, Idx); + SDValue SplatIdxLo = DAG.getNode(LoongArchISD::VSLLI, DL, MVT::v8i32, + SplatIdx, DAG.getConstant(1, DL, GRLenVT)); + SDValue SplatIdxHi = + DAG.getNode(ISD::ADD, DL, MVT::v8i32, SplatIdxLo, + DAG.getSplatBuildVector(MVT::v8i32, DL, + DAG.getConstant(1, DL, GRLenVT))); + + // Use the broadcasted index to broadcast the low and high parts of the + // vector separately. + SDValue SplatVecLo = + DAG.getNode(LoongArchISD::XVPERM, DL, MVT::v8i32, NewVec, SplatIdxLo); + SDValue SplatVecHi = + DAG.getNode(LoongArchISD::XVPERM, DL, MVT::v8i32, NewVec, SplatIdxHi); + + // Combine the low and high i32 parts to reconstruct the original i64/f64 + // element. 
+ SDValue SplatValue = DAG.getNode(LoongArchISD::VILVL, DL, MVT::v8i32, + SplatVecHi, SplatVecLo); + SDValue ExtractVec = DAG.getBitcast(VecTy, SplatValue); + + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, ExtractVec, + DAG.getConstant(0, DL, GRLenVT)); + } } - - return SDValue(); } SDValue From f8b7d4ce77ad46cdbb890a6b1f42d65619847706 Mon Sep 17 00:00:00 2001 From: Qi Zhao Date: Mon, 4 Aug 2025 11:15:53 +0800 Subject: [PATCH 5/5] perform combine for extract_vector_elt --- .../LoongArch/LoongArchISelLowering.cpp | 43 +++++++++++++++++++ .../lasx/ir-instruction/extractelement.ll | 18 +++----- 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp index 77b55b115beaa..026b1e5981233 100644 --- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp +++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp @@ -412,6 +412,11 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM, setTargetDAGCombine(ISD::BITCAST); } + // Set DAG combine for 'LASX' feature. + + if (Subtarget.hasExtLASX()) + setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT); + // Compute derived properties from the register classes. 
computeRegisterProperties(Subtarget.getRegisterInfo()); @@ -5907,6 +5912,42 @@ performSPLIT_PAIR_F64Combine(SDNode *N, SelectionDAG &DAG, return SDValue(); } +static SDValue +performEXTRACT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const LoongArchSubtarget &Subtarget) { + if (!DCI.isBeforeLegalize()) + return SDValue(); + + MVT EltVT = N->getSimpleValueType(0); + SDValue Vec = N->getOperand(0); + EVT VecTy = Vec->getValueType(0); + SDValue Idx = N->getOperand(1); + unsigned IdxOp = Idx.getOpcode(); + SDLoc DL(N); + + if (!VecTy.is256BitVector() || isa<ConstantSDNode>(Idx)) + return SDValue(); + + // Combine: + // t2 = truncate t1 + // t3 = {zero/sign/any}_extend t2 + // t4 = extract_vector_elt t0, t3 + // to: + // t4 = extract_vector_elt t0, t1 + if (IdxOp == ISD::ZERO_EXTEND || IdxOp == ISD::SIGN_EXTEND || + IdxOp == ISD::ANY_EXTEND) { + SDValue IdxOrig = Idx.getOperand(0); + if (!(IdxOrig.getOpcode() == ISD::TRUNCATE)) + return SDValue(); + + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Vec, + IdxOrig.getOperand(0)); + } + + return SDValue(); +} + SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const { SelectionDAG &DAG = DCI.DAG; @@ -5936,6 +5977,8 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N, return performVMSKLTZCombine(N, DAG, DCI, Subtarget); case LoongArchISD::SPLIT_PAIR_F64: return performSPLIT_PAIR_F64Combine(N, DAG, DCI, Subtarget); + case ISD::EXTRACT_VECTOR_ELT: + return performEXTRACT_VECTOR_ELTCombine(N, DAG, DCI, Subtarget); } return SDValue(); } diff --git a/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll b/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll index 9b88ca85a170b..72542df328250 100644 --- a/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll +++ b/llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll @@ -77,8 +77,7 @@ define void @extract_32xi8_idx(ptr %src, ptr %dst, i32 %idx) 
nounwind { ; CHECK-LABEL: extract_32xi8_idx: ; CHECK: # %bb.0: ; CHECK-NEXT: xvld $xr0, $a0, 0 -; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 0 -; CHECK-NEXT: bstrpick.d $a0, $a0, 31, 2 +; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 2 ; CHECK-NEXT: xvreplgr2vr.w $xr1, $a0 ; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1 ; CHECK-NEXT: andi $a0, $a2, 3 @@ -95,8 +94,7 @@ define void @extract_16xi16_idx(ptr %src, ptr %dst, i32 %idx) nounwind { ; CHECK-LABEL: extract_16xi16_idx: ; CHECK: # %bb.0: ; CHECK-NEXT: xvld $xr0, $a0, 0 -; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 0 -; CHECK-NEXT: bstrpick.d $a0, $a0, 31, 1 +; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 1 ; CHECK-NEXT: xvreplgr2vr.w $xr1, $a0 ; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1 ; CHECK-NEXT: andi $a0, $a2, 1 @@ -113,8 +111,7 @@ define void @extract_8xi32_idx(ptr %src, ptr %dst, i32 %idx) nounwind { ; CHECK-LABEL: extract_8xi32_idx: ; CHECK: # %bb.0: ; CHECK-NEXT: xvld $xr0, $a0, 0 -; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 0 -; CHECK-NEXT: xvreplgr2vr.w $xr1, $a0 +; CHECK-NEXT: xvreplgr2vr.w $xr1, $a2 ; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1 ; CHECK-NEXT: xvstelm.w $xr0, $a1, 0, 0 ; CHECK-NEXT: ret @@ -128,8 +125,7 @@ define void @extract_4xi64_idx(ptr %src, ptr %dst, i32 %idx) nounwind { ; CHECK-LABEL: extract_4xi64_idx: ; CHECK: # %bb.0: ; CHECK-NEXT: xvld $xr0, $a0, 0 -; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 0 -; CHECK-NEXT: xvreplgr2vr.w $xr1, $a0 +; CHECK-NEXT: xvreplgr2vr.w $xr1, $a2 ; CHECK-NEXT: xvslli.w $xr1, $xr1, 1 ; CHECK-NEXT: xvperm.w $xr2, $xr0, $xr1 ; CHECK-NEXT: xvaddi.wu $xr1, $xr1, 1 @@ -147,8 +143,7 @@ define void @extract_8xfloat_idx(ptr %src, ptr %dst, i32 %idx) nounwind { ; CHECK-LABEL: extract_8xfloat_idx: ; CHECK: # %bb.0: ; CHECK-NEXT: xvld $xr0, $a0, 0 -; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 0 -; CHECK-NEXT: xvreplgr2vr.w $xr1, $a0 +; CHECK-NEXT: xvreplgr2vr.w $xr1, $a2 ; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1 ; CHECK-NEXT: xvstelm.w $xr0, $a1, 0, 0 ; CHECK-NEXT: ret @@ -162,8 +157,7 @@ define void 
@extract_4xdouble_idx(ptr %src, ptr %dst, i32 %idx) nounwind { ; CHECK-LABEL: extract_4xdouble_idx: ; CHECK: # %bb.0: ; CHECK-NEXT: xvld $xr0, $a0, 0 -; CHECK-NEXT: bstrpick.d $a0, $a2, 31, 0 -; CHECK-NEXT: xvreplgr2vr.w $xr1, $a0 +; CHECK-NEXT: xvreplgr2vr.w $xr1, $a2 ; CHECK-NEXT: xvslli.w $xr1, $xr1, 1 ; CHECK-NEXT: xvperm.w $xr2, $xr0, $xr1 ; CHECK-NEXT: xvaddi.wu $xr1, $xr1, 1