Skip to content

[RISCV] Implement load/store support for XAndesBFHCvt #150350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 25, 2025

Conversation

tclin914
Copy link
Contributor

We use lh to load 2 bytes from memory into a gpr, then mask this gpr with -65536 to emulate nan-boxing behavior, and then the value in gpr is moved to fpr using fmv.w.x.
To move the value back from fpr to gpr, we use fmv.x.w and finally, sh is used to store the lower 2 bytes back to memory.

If zfh is enabled at the same time, we can just use flh/fsw to load/store bf16 directly.

We use lh to load 2 bytes from memory into a gpr, then mask this gpr with
-65536 to emulate nan-boxing behavior, and then the value in gpr is moved to
fpr using `fmv.w.x`.
To move the value back from fpr to gpr, we use `fmv.x.w` and finally, `sh`
is used to store the lower 2 bytes back to memory.

If zfh is enabled at the same time, we can just use flh/fsw to
load/store bf16 directly.
@llvmbot
Copy link
Member

llvmbot commented Jul 24, 2025

@llvm/pr-subscribers-backend-risc-v

Author: Jim Lin (tclin914)

Changes

We use lh to load 2 bytes from memory into a gpr, then mask this gpr with -65536 to emulate nan-boxing behavior, and then the value in gpr is moved to fpr using fmv.w.x.
To move the value back from fpr to gpr, we use fmv.x.w and finally, sh is used to store the lower 2 bytes back to memory.

If zfh is enabled at the same time, we can just use flh/fsw to load/store bf16 directly.


Full diff: https://github.com/llvm/llvm-project/pull/150350.diff

4 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+54)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+3)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td (+33)
  • (modified) llvm/test/CodeGen/RISCV/xandesbfhcvt.ll (+50-2)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 3918dd21bc09d..8b5ae01282293 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1618,6 +1618,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     }
   }
 
+  // Customize load and store operation for bf16 if zfh isn't enabled.
+  if (Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh()) {
+    setOperationAction(ISD::LOAD, MVT::bf16, Custom);
+    setOperationAction(ISD::STORE, MVT::bf16, Custom);
+  }
+
   // Function alignments.
   const Align FunctionAlignment(Subtarget.hasStdExtZca() ? 2 : 4);
   setMinFunctionAlignment(FunctionAlignment);
@@ -7216,6 +7222,47 @@ static SDValue SplitStrictFPVectorOp(SDValue Op, SelectionDAG &DAG) {
   return DAG.getMergeValues({V, HiRes.getValue(1)}, DL);
 }
 
+SDValue
+RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Load(SDValue Op,
+                                                   SelectionDAG &DAG) const {
+  assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
+         "Unexpected bfloat16 load lowering");
+
+  SDLoc DL(Op);
+  LoadSDNode *LD = cast<LoadSDNode>(Op.getNode());
+  EVT MemVT = LD->getMemoryVT();
+  SDValue Load = DAG.getExtLoad(
+      ISD::ZEXTLOAD, DL, Subtarget.getXLenVT(), LD->getChain(),
+      LD->getBasePtr(),
+      EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits()),
+      LD->getMemOperand());
+  // Using mask to make bf16 nan-boxing valid when we don't have flh
+  // instruction. -65536 would be treat as a small number and thus it can be
+  // directly used lui to get the constant.
+  SDValue mask = DAG.getSignedConstant(-65536, DL, Subtarget.getXLenVT());
+  SDValue OrSixteenOne =
+      DAG.getNode(ISD::OR, DL, Load.getValueType(), {Load, mask});
+  SDValue ConvertedResult =
+      DAG.getNode(RISCVISD::NDS_FMV_BF16_X, DL, MVT::bf16, OrSixteenOne);
+  return DAG.getMergeValues({ConvertedResult, Load.getValue(1)}, DL);
+}
+
+SDValue
+RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Store(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
+         "Unexpected bfloat16 store lowering");
+
+  StoreSDNode *ST = cast<StoreSDNode>(Op.getNode());
+  SDLoc DL(Op);
+  SDValue FMV = DAG.getNode(RISCVISD::NDS_FMV_X_ANYEXTBF16, DL,
+                            Subtarget.getXLenVT(), ST->getValue());
+  return DAG.getTruncStore(
+      ST->getChain(), DL, FMV, ST->getBasePtr(),
+      EVT::getIntegerVT(*DAG.getContext(), ST->getMemoryVT().getSizeInBits()),
+      ST->getMemOperand());
+}
+
 SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
                                             SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -7914,6 +7961,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
       return DAG.getMergeValues({Pair, Chain}, DL);
     }
 
+    if (VT == MVT::bf16)
+      return lowerXAndesBfHCvtBFloat16Load(Op, DAG);
+
     // Handle normal vector tuple load.
     if (VT.isRISCVVectorTuple()) {
       SDLoc DL(Op);
@@ -7998,6 +8048,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
           {Store->getChain(), Lo, Hi, Store->getBasePtr()}, MVT::i64,
           Store->getMemOperand());
     }
+
+    if (VT == MVT::bf16)
+      return lowerXAndesBfHCvtBFloat16Store(Op, DAG);
+
     // Handle normal vector tuple store.
     if (VT.isRISCVVectorTuple()) {
       SDLoc DL(Op);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index f0447e02191ae..ca70c46988b4e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -578,6 +578,9 @@ class RISCVTargetLowering : public TargetLowering {
   SDValue lowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue lowerXAndesBfHCvtBFloat16Load(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerXAndesBfHCvtBFloat16Store(SDValue Op, SelectionDAG &DAG) const;
+
   bool isEligibleForTailCallOptimization(
       CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
       const SmallVector<CCValAssign, 16> &ArgLocs) const;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td
index 5220815336441..5d0b66aab5320 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td
@@ -10,6 +10,20 @@
 //
 //===----------------------------------------------------------------------===//
 
+//===----------------------------------------------------------------------===//
+// RISC-V specific DAG Nodes.
+//===----------------------------------------------------------------------===//
+
+def SDT_NDS_FMV_BF16_X
+    : SDTypeProfile<1, 1, [SDTCisVT<0, bf16>, SDTCisVT<1, XLenVT>]>;
+def SDT_NDS_FMV_X_ANYEXTBF16
+    : SDTypeProfile<1, 1, [SDTCisVT<0, XLenVT>, SDTCisVT<1, bf16>]>;
+
+def riscv_nds_fmv_bf16_x
+    : SDNode<"RISCVISD::NDS_FMV_BF16_X", SDT_NDS_FMV_BF16_X>;
+def riscv_nds_fmv_x_anyextbf16
+    : SDNode<"RISCVISD::NDS_FMV_X_ANYEXTBF16", SDT_NDS_FMV_X_ANYEXTBF16>;
+
 //===----------------------------------------------------------------------===//
 // Operand and SDNode transformation definitions.
 //===----------------------------------------------------------------------===//
@@ -774,6 +788,25 @@ def : Pat<(bf16 (fpround FPR32:$rs)),
           (NDS_FCVT_BF16_S FPR32:$rs)>;
 } // Predicates = [HasVendorXAndesBFHCvt]
 
+let isCodeGenOnly = 1 in {
+def NDS_FMV_BF16_X : FPUnaryOp_r<0b1111000, 0b00000, 0b000, FPR16, GPR, "fmv.w.x">,
+                     Sched<[WriteFMovI32ToF32, ReadFMovI32ToF32]>;
+def NDS_FMV_X_BF16 : FPUnaryOp_r<0b1110000, 0b00000, 0b000, GPR, FPR16, "fmv.x.w">,
+                     Sched<[WriteFMovF32ToI32, ReadFMovF32ToI32]>;
+}
+
+let Predicates = [HasVendorXAndesBFHCvt] in {
+def : Pat<(riscv_nds_fmv_bf16_x GPR:$src), (NDS_FMV_BF16_X GPR:$src)>;
+def : Pat<(riscv_nds_fmv_x_anyextbf16 (bf16 FPR16:$src)),
+          (NDS_FMV_X_BF16 (bf16 FPR16:$src))>;
+} // Predicates = [HasVendorXAndesBFHCvt]
+
+// Use flh/fsh to load/store bf16 if zfh is enabled.
+let Predicates = [HasStdExtZfh, HasVendorXAndesBFHCvt] in {
+def : LdPat<load, FLH, bf16>;
+def : StPat<store, FSH, FPR16, bf16>;
+} // Predicates = [HasStdExtZfh, HasVendorXAndesBFHCvt]
+
 let Predicates = [HasVendorXAndesVBFHCvt] in {
 defm PseudoNDS_VFWCVT_S_BF16 : VPseudoVWCVT_S_BF16;
 defm PseudoNDS_VFNCVT_BF16_S : VPseudoVNCVT_BF16_S;
diff --git a/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll b/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
index 854d0b659ea73..c0c15172676fd 100644
--- a/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
+++ b/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
@@ -1,8 +1,12 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=riscv32 -mattr=+xandesbfhcvt -target-abi ilp32f \
-; RUN:   -verify-machineinstrs < %s | FileCheck %s
+; RUN:   -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,XANDESBFHCVT %s
+; RUN: llc -mtriple=riscv32 -mattr=+zfh,+xandesbfhcvt -target-abi ilp32f \
+; RUN:   -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,ZFH %s
 ; RUN: llc -mtriple=riscv64 -mattr=+xandesbfhcvt -target-abi lp64f \
-; RUN:   -verify-machineinstrs < %s | FileCheck %s
+; RUN:   -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,XANDESBFHCVT %s
+; RUN: llc -mtriple=riscv64 -mattr=+zfh,+xandesbfhcvt -target-abi lp64f \
+; RUN:   -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,ZFH %s
 
 define float @fcvt_s_bf16(bfloat %a) nounwind {
 ; CHECK-LABEL: fcvt_s_bf16:
@@ -21,3 +25,47 @@ define bfloat @fcvt_bf16_s(float %a) nounwind {
   %1 = fptrunc float %a to bfloat
   ret bfloat %1
 }
+
+@sf = dso_local global float 0.000000e+00, align 4
+@bf = dso_local global bfloat 0xR0000, align 2
+
+; Check load and store to bf16.
+define void @loadstorebf16() nounwind {
+; XANDESBFHCVT-LABEL: loadstorebf16:
+; XANDESBFHCVT:       # %bb.0: # %entry
+; XANDESBFHCVT-NEXT:    lui a0, %hi(.L_MergedGlobals)
+; XANDESBFHCVT-NEXT:    lhu a1, %lo(.L_MergedGlobals)(a0)
+; XANDESBFHCVT-NEXT:    lui a2, 1048560
+; XANDESBFHCVT-NEXT:    or a1, a1, a2
+; XANDESBFHCVT-NEXT:    fmv.w.x fa5, a1
+; XANDESBFHCVT-NEXT:    addi a1, a0, %lo(.L_MergedGlobals)
+; XANDESBFHCVT-NEXT:    nds.fcvt.s.bf16 fa5, fa5
+; XANDESBFHCVT-NEXT:    fsw fa5, 4(a1)
+; XANDESBFHCVT-NEXT:    flw fa5, 4(a1)
+; XANDESBFHCVT-NEXT:    nds.fcvt.bf16.s fa5, fa5
+; XANDESBFHCVT-NEXT:    fmv.x.w a1, fa5
+; XANDESBFHCVT-NEXT:    sh a1, %lo(.L_MergedGlobals)(a0)
+; XANDESBFHCVT-NEXT:    ret
+;
+; ZFH-LABEL: loadstorebf16:
+; ZFH:       # %bb.0: # %entry
+; ZFH-NEXT:    lui a0, %hi(.L_MergedGlobals)
+; ZFH-NEXT:    flh fa5, %lo(.L_MergedGlobals)(a0)
+; ZFH-NEXT:    addi a1, a0, %lo(.L_MergedGlobals)
+; ZFH-NEXT:    nds.fcvt.s.bf16 fa5, fa5
+; ZFH-NEXT:    fsw fa5, 4(a1)
+; ZFH-NEXT:    flw fa5, 4(a1)
+; ZFH-NEXT:    nds.fcvt.bf16.s fa5, fa5
+; ZFH-NEXT:    fsh fa5, %lo(.L_MergedGlobals)(a0)
+; ZFH-NEXT:    ret
+entry:
+  %0 = load bfloat, bfloat* @bf, align 2
+  %1 = fpext bfloat %0 to float
+  store volatile float %1, float* @sf, align 4
+
+  %2 = load float, float* @sf, align 4
+  %3 = fptrunc float %2 to bfloat
+  store volatile bfloat %3, bfloat* @bf, align 2
+
+  ret void
+}

Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@tclin914 tclin914 merged commit 4e3266f into llvm:main Jul 25, 2025
9 checks passed
@tclin914 tclin914 deleted the andes-xandesbfhcvt-loadstore branch July 25, 2025 03:29
@llvm-ci
Copy link
Collaborator

llvm-ci commented Jul 25, 2025

LLVM Buildbot has detected a new failure on builder llvm-clang-x86_64-expensive-checks-ubuntu running on as-builder-4 while building llvm at step 7 "test-check-all".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/187/builds/8489

Here is the relevant piece of the build log for the reference
Step 7 (test-check-all) failure: Test just built components: check-all completed (failure)
******************** TEST 'LLVM :: TableGen/RuntimeLibcallEmitter.td' FAILED ********************
Exit Code: 1

Command Output (stderr):
--
/home/buildbot/worker/as-builder-4/ramdisk/expensive-checks/build/bin/llvm-tblgen -gen-runtime-libcalls -I /home/buildbot/worker/as-builder-4/ramdisk/expensive-checks/llvm-project/llvm/test/TableGen/../../include /home/buildbot/worker/as-builder-4/ramdisk/expensive-checks/llvm-project/llvm/test/TableGen/RuntimeLibcallEmitter.td | /home/buildbot/worker/as-builder-4/ramdisk/expensive-checks/build/bin/FileCheck /home/buildbot/worker/as-builder-4/ramdisk/expensive-checks/llvm-project/llvm/test/TableGen/RuntimeLibcallEmitter.td # RUN: at line 1
+ /home/buildbot/worker/as-builder-4/ramdisk/expensive-checks/build/bin/FileCheck /home/buildbot/worker/as-builder-4/ramdisk/expensive-checks/llvm-project/llvm/test/TableGen/RuntimeLibcallEmitter.td
+ /home/buildbot/worker/as-builder-4/ramdisk/expensive-checks/build/bin/llvm-tblgen -gen-runtime-libcalls -I /home/buildbot/worker/as-builder-4/ramdisk/expensive-checks/llvm-project/llvm/test/TableGen/../../include /home/buildbot/worker/as-builder-4/ramdisk/expensive-checks/llvm-project/llvm/test/TableGen/RuntimeLibcallEmitter.td
/home/buildbot/worker/as-builder-4/ramdisk/expensive-checks/llvm-project/llvm/test/TableGen/RuntimeLibcallEmitter.td:98:16: error: CHECK-NEXT: expected string not found in input
// CHECK-NEXT: sqrtl_f80 = 7, // sqrtl
               ^
<stdin>:32:23: note: scanning from here
 calloc = 6, // calloc
                      ^
<stdin>:34:2: note: possible intended match here
 sqrtl_f80 = 8, // sqrtl
 ^

Input file: <stdin>
Check file: /home/buildbot/worker/as-builder-4/ramdisk/expensive-checks/llvm-project/llvm/test/TableGen/RuntimeLibcallEmitter.td

-dump-input=help explains the following input dump.

Input was:
<<<<<<
           .
           .
           .
          27:  ___memcpy = 1, // ___memcpy 
          28:  ___memset = 2, // ___memset 
          29:  __ashlsi3 = 3, // __ashlsi3 
          30:  __lshrdi3 = 4, // __lshrdi3 
          31:  bzero = 5, // bzero 
          32:  calloc = 6, // calloc 
next:98'0                           X error: no match found
          33:  sqrtl_f128 = 7, // sqrtl 
next:98'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~
          34:  sqrtl_f80 = 8, // sqrtl 
next:98'0     ~~~~~~~~~~~~~~~~~~~~~~~~~
next:98'1      ?                        possible intended match
          35:  NumLibcallImpls = 9 
next:98'0     ~~~~~~~~~~~~~~~~~~~~~
          36: }; 
next:98'0     ~~~
          37: } // End namespace RTLIB 
next:98'0     ~~~~~~~~~~~~~~~~~~~~~~~~~
          38: } // End namespace llvm 
next:98'0     ~~~~~~~~~~~~~~~~~~~~~~~~
          39: #endif 
next:98'0     ~~~~~~~
...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants