From 1a8c16be7ca7755f7b2a222faf270167ba863d28 Mon Sep 17 00:00:00 2001
From: Jim Lin
Date: Thu, 17 Jul 2025 09:23:19 +0800
Subject: [PATCH 1/2] [RISCV] Implement load/store support for XAndesBFHCvt

We use lhu to load 2 bytes from memory into a GPR, then OR this GPR with
-65536 (0xffff0000) to emulate the NaN-boxing behavior, and then move the
value from the GPR to an FPR using `fmv.w.x`. To move the value back from
the FPR to a GPR, we use `fmv.x.w`, and finally `sh` stores the lower
2 bytes back to memory.

If Zfh is enabled at the same time, we can just use flh/fsh to load/store
bf16 directly.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 54 +++++++++++++++++++
 llvm/lib/Target/RISCV/RISCVISelLowering.h     |  3 ++
 llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td | 33 ++++++++++++
 llvm/test/CodeGen/RISCV/xandesbfhcvt.ll       | 52 +++++++++++++++++-
 4 files changed, 140 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 3918dd21bc09d..8b5ae01282293 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1618,6 +1618,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     }
   }
 
+  // Customize load and store operations for bf16 if zfh isn't enabled.
+  if (Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh()) {
+    setOperationAction(ISD::LOAD, MVT::bf16, Custom);
+    setOperationAction(ISD::STORE, MVT::bf16, Custom);
+  }
+
   // Function alignments.
   const Align FunctionAlignment(Subtarget.hasStdExtZca() ? 2 : 4);
   setMinFunctionAlignment(FunctionAlignment);
@@ -7216,6 +7222,47 @@ static SDValue SplitStrictFPVectorOp(SDValue Op, SelectionDAG &DAG) {
   return DAG.getMergeValues({V, HiRes.getValue(1)}, DL);
 }
 
+SDValue
+RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Load(SDValue Op,
+                                                   SelectionDAG &DAG) const {
+  assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
+         "Unexpected bfloat16 load lowering");
+
+  SDLoc DL(Op);
+  LoadSDNode *LD = cast<LoadSDNode>(Op.getNode());
+  EVT MemVT = LD->getMemoryVT();
+  SDValue Load = DAG.getExtLoad(
+      ISD::ZEXTLOAD, DL, Subtarget.getXLenVT(), LD->getChain(),
+      LD->getBasePtr(),
+      EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits()),
+      LD->getMemOperand());
+  // OR with a mask to NaN-box the bf16 value when we don't have the flh
+  // instruction. -65536 is treated as a small negative number, so it can be
+  // materialized directly with a single lui.
+  SDValue mask = DAG.getSignedConstant(-65536, DL, Subtarget.getXLenVT());
+  SDValue OrSixteenOne =
+      DAG.getNode(ISD::OR, DL, Load.getValueType(), {Load, mask});
+  SDValue ConvertedResult =
+      DAG.getNode(RISCVISD::NDS_FMV_BF16_X, DL, MVT::bf16, OrSixteenOne);
+  return DAG.getMergeValues({ConvertedResult, Load.getValue(1)}, DL);
+}
+
+SDValue
+RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Store(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
+         "Unexpected bfloat16 store lowering");
+
+  StoreSDNode *ST = cast<StoreSDNode>(Op.getNode());
+  SDLoc DL(Op);
+  SDValue FMV = DAG.getNode(RISCVISD::NDS_FMV_X_ANYEXTBF16, DL,
+                            Subtarget.getXLenVT(), ST->getValue());
+  return DAG.getTruncStore(
+      ST->getChain(), DL, FMV, ST->getBasePtr(),
+      EVT::getIntegerVT(*DAG.getContext(), ST->getMemoryVT().getSizeInBits()),
+      ST->getMemOperand());
+}
+
 SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
                                             SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -7914,6 +7961,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
       return DAG.getMergeValues({Pair, Chain}, DL);
     }
 
+    if (VT == MVT::bf16)
+      return lowerXAndesBfHCvtBFloat16Load(Op, DAG);
+
     // Handle normal vector tuple load.
     if (VT.isRISCVVectorTuple()) {
       SDLoc DL(Op);
@@ -7998,6 +8048,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
           {Store->getChain(), Lo, Hi, Store->getBasePtr()}, MVT::i64,
           Store->getMemOperand());
     }
+
+    if (VT == MVT::bf16)
+      return lowerXAndesBfHCvtBFloat16Store(Op, DAG);
+
     // Handle normal vector tuple store.
     if (VT.isRISCVVectorTuple()) {
       SDLoc DL(Op);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index f0447e02191ae..ca70c46988b4e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -578,6 +578,9 @@ class RISCVTargetLowering : public TargetLowering {
   SDValue lowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue lowerXAndesBfHCvtBFloat16Load(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerXAndesBfHCvtBFloat16Store(SDValue Op, SelectionDAG &DAG) const;
+
   bool isEligibleForTailCallOptimization(
       CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
       const SmallVector<CCValAssign, 16> &ArgLocs) const;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td
index 5220815336441..5d0b66aab5320 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td
@@ -10,6 +10,20 @@
 //
 //===----------------------------------------------------------------------===//
 
+//===----------------------------------------------------------------------===//
+// RISC-V specific DAG Nodes.
+//===----------------------------------------------------------------------===//
+
+def SDT_NDS_FMV_BF16_X
+    : SDTypeProfile<1, 1, [SDTCisVT<0, bf16>, SDTCisVT<1, XLenVT>]>;
+def SDT_NDS_FMV_X_ANYEXTBF16
+    : SDTypeProfile<1, 1, [SDTCisVT<0, XLenVT>, SDTCisVT<1, bf16>]>;
+
+def riscv_nds_fmv_bf16_x
+    : SDNode<"RISCVISD::NDS_FMV_BF16_X", SDT_NDS_FMV_BF16_X>;
+def riscv_nds_fmv_x_anyextbf16
+    : SDNode<"RISCVISD::NDS_FMV_X_ANYEXTBF16", SDT_NDS_FMV_X_ANYEXTBF16>;
+
 //===----------------------------------------------------------------------===//
 // Operand and SDNode transformation definitions.
 //===----------------------------------------------------------------------===//
@@ -774,6 +788,25 @@
 def : Pat<(bf16 (fpround FPR32:$rs)), (NDS_FCVT_BF16_S FPR32:$rs)>;
 } // Predicates = [HasVendorXAndesBFHCvt]
 
+let isCodeGenOnly = 1 in {
+def NDS_FMV_BF16_X : FPUnaryOp_r<0b1111000, 0b00000, 0b000, FPR16, GPR, "fmv.w.x">,
+                     Sched<[WriteFMovI32ToF32, ReadFMovI32ToF32]>;
+def NDS_FMV_X_BF16 : FPUnaryOp_r<0b1110000, 0b00000, 0b000, GPR, FPR16, "fmv.x.w">,
+                     Sched<[WriteFMovF32ToI32, ReadFMovF32ToI32]>;
+}
+
+let Predicates = [HasVendorXAndesBFHCvt] in {
+def : Pat<(riscv_nds_fmv_bf16_x GPR:$src), (NDS_FMV_BF16_X GPR:$src)>;
+def : Pat<(riscv_nds_fmv_x_anyextbf16 (bf16 FPR16:$src)),
+          (NDS_FMV_X_BF16 (bf16 FPR16:$src))>;
+} // Predicates = [HasVendorXAndesBFHCvt]
+
+// Use flh/fsh to load/store bf16 if zfh is enabled.
+let Predicates = [HasStdExtZfh, HasVendorXAndesBFHCvt] in {
+def : LdPat<load, FLH, bf16>;
+def : StPat<store, FSH, FPR16, bf16>;
+} // Predicates = [HasStdExtZfh, HasVendorXAndesBFHCvt]
+
 let Predicates = [HasVendorXAndesVBFHCvt] in {
 defm PseudoNDS_VFWCVT_S_BF16 : VPseudoVWCVT_S_BF16;
 defm PseudoNDS_VFNCVT_BF16_S : VPseudoVNCVT_BF16_S;
diff --git a/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll b/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
index 854d0b659ea73..c0c15172676fd 100644
--- a/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
+++ b/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
@@ -1,8 +1,12 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=riscv32 -mattr=+xandesbfhcvt -target-abi ilp32f \
-; RUN:   -verify-machineinstrs < %s | FileCheck %s
+; RUN:   -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,XANDESBFHCVT %s
+; RUN: llc -mtriple=riscv32 -mattr=+zfh,+xandesbfhcvt -target-abi ilp32f \
+; RUN:   -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,ZFH %s
 ; RUN: llc -mtriple=riscv64 -mattr=+xandesbfhcvt -target-abi lp64f \
-; RUN:   -verify-machineinstrs < %s | FileCheck %s
+; RUN:   -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,XANDESBFHCVT %s
+; RUN: llc -mtriple=riscv64 -mattr=+zfh,+xandesbfhcvt -target-abi lp64f \
+; RUN:   -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,ZFH %s
 
 define float @fcvt_s_bf16(bfloat %a) nounwind {
 ; CHECK-LABEL: fcvt_s_bf16:
@@ -21,3 +25,47 @@ define bfloat @fcvt_bf16_s(float %a) nounwind {
   %1 = fptrunc float %a to bfloat
   ret bfloat %1
 }
+
+@sf = dso_local global float 0.000000e+00, align 4
+@bf = dso_local global bfloat 0xR0000, align 2
+
+; Check load and store to bf16.
+define void @loadstorebf16() nounwind {
+; XANDESBFHCVT-LABEL: loadstorebf16:
+; XANDESBFHCVT:       # %bb.0: # %entry
+; XANDESBFHCVT-NEXT:    lui a0, %hi(.L_MergedGlobals)
+; XANDESBFHCVT-NEXT:    lhu a1, %lo(.L_MergedGlobals)(a0)
+; XANDESBFHCVT-NEXT:    lui a2, 1048560
+; XANDESBFHCVT-NEXT:    or a1, a1, a2
+; XANDESBFHCVT-NEXT:    fmv.w.x fa5, a1
+; XANDESBFHCVT-NEXT:    addi a1, a0, %lo(.L_MergedGlobals)
+; XANDESBFHCVT-NEXT:    nds.fcvt.s.bf16 fa5, fa5
+; XANDESBFHCVT-NEXT:    fsw fa5, 4(a1)
+; XANDESBFHCVT-NEXT:    flw fa5, 4(a1)
+; XANDESBFHCVT-NEXT:    nds.fcvt.bf16.s fa5, fa5
+; XANDESBFHCVT-NEXT:    fmv.x.w a1, fa5
+; XANDESBFHCVT-NEXT:    sh a1, %lo(.L_MergedGlobals)(a0)
+; XANDESBFHCVT-NEXT:    ret
+;
+; ZFH-LABEL: loadstorebf16:
+; ZFH:       # %bb.0: # %entry
+; ZFH-NEXT:    lui a0, %hi(.L_MergedGlobals)
+; ZFH-NEXT:    flh fa5, %lo(.L_MergedGlobals)(a0)
+; ZFH-NEXT:    addi a1, a0, %lo(.L_MergedGlobals)
+; ZFH-NEXT:    nds.fcvt.s.bf16 fa5, fa5
+; ZFH-NEXT:    fsw fa5, 4(a1)
+; ZFH-NEXT:    flw fa5, 4(a1)
+; ZFH-NEXT:    nds.fcvt.bf16.s fa5, fa5
+; ZFH-NEXT:    fsh fa5, %lo(.L_MergedGlobals)(a0)
+; ZFH-NEXT:    ret
+entry:
+  %0 = load bfloat, bfloat* @bf, align 2
+  %1 = fpext bfloat %0 to float
+  store volatile float %1, float* @sf, align 4
+
+  %2 = load float, float* @sf, align 4
+  %3 = fptrunc float %2 to bfloat
+  store volatile bfloat %3, bfloat* @bf, align 2
+
+  ret void
+}

From 11dc0c3a1058e637dedec97e87cf1c44a5cd6f78 Mon Sep 17 00:00:00 2001
From: Jim Lin
Date: Thu, 24 Jul 2025 14:36:44 +0800
Subject: [PATCH 2/2] Load/store bf16 from/to an address passed as an argument
 instead of from global variables

---
 llvm/test/CodeGen/RISCV/xandesbfhcvt.ll | 39 ++++++++++---------------
 1 file changed, 16 insertions(+), 23 deletions(-)

diff --git a/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll b/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
index c0c15172676fd..72242f1dd312d 100644
--- a/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
+++ b/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
@@ -26,46 +26,39 @@ define bfloat @fcvt_bf16_s(float %a) nounwind {
   ret bfloat %1
 }
 
-@sf = dso_local global float 0.000000e+00, align 4
-@bf = dso_local global bfloat 0xR0000, align 2
-
 ; Check load and store to bf16.
-define void @loadstorebf16() nounwind {
+define void @loadstorebf16(ptr %bf, ptr %sf) nounwind {
 ; XANDESBFHCVT-LABEL: loadstorebf16:
 ; XANDESBFHCVT:       # %bb.0: # %entry
-; XANDESBFHCVT-NEXT:    lui a0, %hi(.L_MergedGlobals)
-; XANDESBFHCVT-NEXT:    lhu a1, %lo(.L_MergedGlobals)(a0)
-; XANDESBFHCVT-NEXT:    lui a2, 1048560
-; XANDESBFHCVT-NEXT:    or a1, a1, a2
-; XANDESBFHCVT-NEXT:    fmv.w.x fa5, a1
-; XANDESBFHCVT-NEXT:    addi a1, a0, %lo(.L_MergedGlobals)
+; XANDESBFHCVT-NEXT:    lhu a2, 0(a0)
+; XANDESBFHCVT-NEXT:    lui a3, 1048560
+; XANDESBFHCVT-NEXT:    or a2, a2, a3
+; XANDESBFHCVT-NEXT:    fmv.w.x fa5, a2
 ; XANDESBFHCVT-NEXT:    nds.fcvt.s.bf16 fa5, fa5
-; XANDESBFHCVT-NEXT:    fsw fa5, 4(a1)
-; XANDESBFHCVT-NEXT:    flw fa5, 4(a1)
+; XANDESBFHCVT-NEXT:    fsw fa5, 0(a1)
+; XANDESBFHCVT-NEXT:    flw fa5, 0(a1)
 ; XANDESBFHCVT-NEXT:    nds.fcvt.bf16.s fa5, fa5
 ; XANDESBFHCVT-NEXT:    fmv.x.w a1, fa5
-; XANDESBFHCVT-NEXT:    sh a1, %lo(.L_MergedGlobals)(a0)
+; XANDESBFHCVT-NEXT:    sh a1, 0(a0)
 ; XANDESBFHCVT-NEXT:    ret
 ;
 ; ZFH-LABEL: loadstorebf16:
 ; ZFH:       # %bb.0: # %entry
-; ZFH-NEXT:    lui a0, %hi(.L_MergedGlobals)
-; ZFH-NEXT:    flh fa5, %lo(.L_MergedGlobals)(a0)
-; ZFH-NEXT:    addi a1, a0, %lo(.L_MergedGlobals)
+; ZFH-NEXT:    flh fa5, 0(a0)
 ; ZFH-NEXT:    nds.fcvt.s.bf16 fa5, fa5
-; ZFH-NEXT:    fsw fa5, 4(a1)
-; ZFH-NEXT:    flw fa5, 4(a1)
+; ZFH-NEXT:    fsw fa5, 0(a1)
+; ZFH-NEXT:    flw fa5, 0(a1)
 ; ZFH-NEXT:    nds.fcvt.bf16.s fa5, fa5
-; ZFH-NEXT:    fsh fa5, %lo(.L_MergedGlobals)(a0)
+; ZFH-NEXT:    fsh fa5, 0(a0)
 ; ZFH-NEXT:    ret
 entry:
-  %0 = load bfloat, bfloat* @bf, align 2
+  %0 = load bfloat, bfloat* %bf, align 2
   %1 = fpext bfloat %0 to float
-  store volatile float %1, float* @sf, align 4
+  store volatile float %1, float* %sf, align 4
 
-  %2 = load float, float* @sf, align 4
+  %2 = load float, float* %sf, align 4
   %3 = fptrunc float %2 to bfloat
-  store volatile bfloat %3, bfloat* @bf, align 2
+  store volatile bfloat %3, bfloat* %bf, align 2
 
   ret void
 }
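
Note on the NaN-boxing emulation (not part of the patches themselves): the
lhu / or / fmv.w.x sequence on the load path and fmv.x.w / sh on the store
path simply build and strip the NaN box in an integer register. Below is a
minimal standalone C++ sketch of the same bit manipulation; the helper names
are hypothetical and only illustrate what the custom lowering emits, where
`lui a2, 1048560` materializes the 0xffff0000 mask (i.e. -65536).

    #include <cstdint>

    // Load path: zero-extend the 16-bit bf16 pattern (lhu), then set the
    // upper 16 bits to all ones (or with -65536 == 0xffff0000) so the value
    // is a valid NaN-boxed bf16 once moved into an FPR (fmv.w.x).
    uint32_t nanBoxBF16(uint16_t Bf16Bits) {
      return static_cast<uint32_t>(Bf16Bits) | 0xffff0000u;
    }

    // Store path: move the register pattern back to a GPR (fmv.x.w) and keep
    // only the low 16 bits, which sh then writes back to memory.
    uint16_t unboxBF16(uint32_t NanBoxedBits) {
      return static_cast<uint16_t>(NanBoxedBits & 0xffffu);
    }

    int main() {
      uint16_t InMemory = 0x3f80; // bf16 bit pattern for 1.0
      uint32_t InRegister = nanBoxBF16(InMemory);
      return (InRegister == 0xffff3f80u &&
              unboxBF16(InRegister) == InMemory) ? 0 : 1;
    }

When Zfh is also enabled none of this is needed: flh already NaN-boxes the
loaded 16 bits and fsh stores the low 16 bits directly, which is why the ZFH
check lines above contain no lui/or/fmv instructions.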