-
Notifications
You must be signed in to change notification settings - Fork 14.6k
[RISCV] Implement load/store support for XAndesBFHCvt #150350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
We use `lhu` to load 2 bytes from memory into a GPR, then OR this GPR with -65536 (setting every bit above bit 15) to emulate NaN-boxing, and then move the value from the GPR to an FPR using `fmv.w.x`. To move the value back from the FPR to a GPR, we use `fmv.x.w`, and finally `sh` stores the lower 2 bytes back to memory. If zfh is enabled at the same time, we can just use flh/fsh to load/store bf16 directly.
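@llvm/pr-subscribers-backend-risc-v Author: Jim Lin (tclin914) Changes: We use `lhu` to load 2 bytes from memory into a GPR, then OR this GPR with -65536 to emulate NaN-boxing, and then move the value from the GPR to an FPR using `fmv.w.x`. To move the value back from the FPR to a GPR, we use `fmv.x.w`, and finally `sh` stores the lower 2 bytes back to memory. If zfh is enabled at the same time, we can just use flh/fsh to load/store bf16 directly. Full diff: https://github.com/llvm/llvm-project/pull/150350.diff 4 Files Affected: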
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 3918dd21bc09d..8b5ae01282293 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1618,6 +1618,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
}
}
+ // Customize load and store operation for bf16 if zfh isn't enabled.
+ if (Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh()) {
+ setOperationAction(ISD::LOAD, MVT::bf16, Custom);
+ setOperationAction(ISD::STORE, MVT::bf16, Custom);
+ }
+
// Function alignments.
const Align FunctionAlignment(Subtarget.hasStdExtZca() ? 2 : 4);
setMinFunctionAlignment(FunctionAlignment);
@@ -7216,6 +7222,47 @@ static SDValue SplitStrictFPVectorOp(SDValue Op, SelectionDAG &DAG) {
return DAG.getMergeValues({V, HiRes.getValue(1)}, DL);
}
+SDValue
+RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Load(SDValue Op,
+ SelectionDAG &DAG) const {
+ assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
+ "Unexpected bfloat16 load lowering");
+ // bf16 load: zero-extending integer load of the memory width, NaN-box, FMV.
+ SDLoc DL(Op);
+ LoadSDNode *LD = cast<LoadSDNode>(Op.getNode());
+ EVT MemVT = LD->getMemoryVT();
+ SDValue Load = DAG.getExtLoad(
+ ISD::ZEXTLOAD, DL, Subtarget.getXLenVT(), LD->getChain(),
+ LD->getBasePtr(),
+ EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits()),
+ LD->getMemOperand());
+ // There is no flh here, so NaN-box by hand: OR the zero-extended value with
+ // -65536 to set every bit above bit 15. The constant is treated as a small
+ // immediate, so it can be materialized directly with a single lui.
+ SDValue mask = DAG.getSignedConstant(-65536, DL, Subtarget.getXLenVT());
+ SDValue OrSixteenOne =
+ DAG.getNode(ISD::OR, DL, Load.getValueType(), {Load, mask});
+ SDValue ConvertedResult =
+ DAG.getNode(RISCVISD::NDS_FMV_BF16_X, DL, MVT::bf16, OrSixteenOne);
+ return DAG.getMergeValues({ConvertedResult, Load.getValue(1)}, DL); // {bf16 value, second result of the new load}
+}
+
+SDValue
+RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Store(SDValue Op,
+ SelectionDAG &DAG) const {
+ assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
+ "Unexpected bfloat16 store lowering");
+ // bf16 store: move the value into an XLenVT integer, then truncating store.
+ StoreSDNode *ST = cast<StoreSDNode>(Op.getNode());
+ SDLoc DL(Op);
+ SDValue FMV = DAG.getNode(RISCVISD::NDS_FMV_X_ANYEXTBF16, DL,
+ Subtarget.getXLenVT(), ST->getValue()); // bf16 operand -> XLenVT result
+ return DAG.getTruncStore(
+ ST->getChain(), DL, FMV, ST->getBasePtr(),
+ EVT::getIntegerVT(*DAG.getContext(), ST->getMemoryVT().getSizeInBits()),
+ ST->getMemOperand());
+}
+
SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
@@ -7914,6 +7961,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return DAG.getMergeValues({Pair, Chain}, DL);
}
+ if (VT == MVT::bf16)
+ return lowerXAndesBfHCvtBFloat16Load(Op, DAG);
+
// Handle normal vector tuple load.
if (VT.isRISCVVectorTuple()) {
SDLoc DL(Op);
@@ -7998,6 +8048,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
{Store->getChain(), Lo, Hi, Store->getBasePtr()}, MVT::i64,
Store->getMemOperand());
}
+
+ if (VT == MVT::bf16)
+ return lowerXAndesBfHCvtBFloat16Store(Op, DAG);
+
// Handle normal vector tuple store.
if (VT.isRISCVVectorTuple()) {
SDLoc DL(Op);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index f0447e02191ae..ca70c46988b4e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -578,6 +578,9 @@ class RISCVTargetLowering : public TargetLowering {
SDValue lowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
+ SDValue lowerXAndesBfHCvtBFloat16Load(SDValue Op, SelectionDAG &DAG) const;
+ SDValue lowerXAndesBfHCvtBFloat16Store(SDValue Op, SelectionDAG &DAG) const;
+
bool isEligibleForTailCallOptimization(
CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
const SmallVector<CCValAssign, 16> &ArgLocs) const;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td
index 5220815336441..5d0b66aab5320 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td
@@ -10,6 +10,20 @@
//
//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+// RISC-V specific DAG Nodes.
+//===----------------------------------------------------------------------===//
+
+def SDT_NDS_FMV_BF16_X
+ : SDTypeProfile<1, 1, [SDTCisVT<0, bf16>, SDTCisVT<1, XLenVT>]>;
+def SDT_NDS_FMV_X_ANYEXTBF16
+ : SDTypeProfile<1, 1, [SDTCisVT<0, XLenVT>, SDTCisVT<1, bf16>]>;
+
+def riscv_nds_fmv_bf16_x
+ : SDNode<"RISCVISD::NDS_FMV_BF16_X", SDT_NDS_FMV_BF16_X>;
+def riscv_nds_fmv_x_anyextbf16
+ : SDNode<"RISCVISD::NDS_FMV_X_ANYEXTBF16", SDT_NDS_FMV_X_ANYEXTBF16>;
+
//===----------------------------------------------------------------------===//
// Operand and SDNode transformation definitions.
//===----------------------------------------------------------------------===//
@@ -774,6 +788,25 @@ def : Pat<(bf16 (fpround FPR32:$rs)),
(NDS_FCVT_BF16_S FPR32:$rs)>;
} // Predicates = [HasVendorXAndesBFHCvt]
+let isCodeGenOnly = 1 in { // FPR16 (bf16) variants that print as fmv.w.x / fmv.x.w.
+def NDS_FMV_BF16_X : FPUnaryOp_r<0b1111000, 0b00000, 0b000, FPR16, GPR, "fmv.w.x">,
+ Sched<[WriteFMovI32ToF32, ReadFMovI32ToF32]>;
+def NDS_FMV_X_BF16 : FPUnaryOp_r<0b1110000, 0b00000, 0b000, GPR, FPR16, "fmv.x.w">,
+ Sched<[WriteFMovF32ToI32, ReadFMovF32ToI32]>;
+}
+
+let Predicates = [HasVendorXAndesBFHCvt] in {
+def : Pat<(riscv_nds_fmv_bf16_x GPR:$src), (NDS_FMV_BF16_X GPR:$src)>;
+def : Pat<(riscv_nds_fmv_x_anyextbf16 (bf16 FPR16:$src)),
+ (NDS_FMV_X_BF16 (bf16 FPR16:$src))>;
+} // Predicates = [HasVendorXAndesBFHCvt]
+
+// Use flh/fsh to load/store bf16 if zfh is enabled.
+let Predicates = [HasStdExtZfh, HasVendorXAndesBFHCvt] in {
+def : LdPat<load, FLH, bf16>;
+def : StPat<store, FSH, FPR16, bf16>;
+} // Predicates = [HasStdExtZfh, HasVendorXAndesBFHCvt]
+
let Predicates = [HasVendorXAndesVBFHCvt] in {
defm PseudoNDS_VFWCVT_S_BF16 : VPseudoVWCVT_S_BF16;
defm PseudoNDS_VFNCVT_BF16_S : VPseudoVNCVT_BF16_S;
diff --git a/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll b/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
index 854d0b659ea73..c0c15172676fd 100644
--- a/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
+++ b/llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
@@ -1,8 +1,12 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+xandesbfhcvt -target-abi ilp32f \
-; RUN: -verify-machineinstrs < %s | FileCheck %s
+; RUN: -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,XANDESBFHCVT %s
+; RUN: llc -mtriple=riscv32 -mattr=+zfh,+xandesbfhcvt -target-abi ilp32f \
+; RUN: -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,ZFH %s
; RUN: llc -mtriple=riscv64 -mattr=+xandesbfhcvt -target-abi lp64f \
-; RUN: -verify-machineinstrs < %s | FileCheck %s
+; RUN: -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,XANDESBFHCVT %s
+; RUN: llc -mtriple=riscv64 -mattr=+zfh,+xandesbfhcvt -target-abi lp64f \
+; RUN: -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,ZFH %s
define float @fcvt_s_bf16(bfloat %a) nounwind {
; CHECK-LABEL: fcvt_s_bf16:
@@ -21,3 +25,47 @@ define bfloat @fcvt_bf16_s(float %a) nounwind {
%1 = fptrunc float %a to bfloat
ret bfloat %1
}
+
+@sf = dso_local global float 0.000000e+00, align 4
+@bf = dso_local global bfloat 0xR0000, align 2
+
+; Check bf16 load and store lowering, both with and without zfh.
+define void @loadstorebf16() nounwind {
+; XANDESBFHCVT-LABEL: loadstorebf16:
+; XANDESBFHCVT: # %bb.0: # %entry
+; XANDESBFHCVT-NEXT: lui a0, %hi(.L_MergedGlobals)
+; XANDESBFHCVT-NEXT: lhu a1, %lo(.L_MergedGlobals)(a0)
+; XANDESBFHCVT-NEXT: lui a2, 1048560
+; XANDESBFHCVT-NEXT: or a1, a1, a2
+; XANDESBFHCVT-NEXT: fmv.w.x fa5, a1
+; XANDESBFHCVT-NEXT: addi a1, a0, %lo(.L_MergedGlobals)
+; XANDESBFHCVT-NEXT: nds.fcvt.s.bf16 fa5, fa5
+; XANDESBFHCVT-NEXT: fsw fa5, 4(a1)
+; XANDESBFHCVT-NEXT: flw fa5, 4(a1)
+; XANDESBFHCVT-NEXT: nds.fcvt.bf16.s fa5, fa5
+; XANDESBFHCVT-NEXT: fmv.x.w a1, fa5
+; XANDESBFHCVT-NEXT: sh a1, %lo(.L_MergedGlobals)(a0)
+; XANDESBFHCVT-NEXT: ret
+;
+; ZFH-LABEL: loadstorebf16:
+; ZFH: # %bb.0: # %entry
+; ZFH-NEXT: lui a0, %hi(.L_MergedGlobals)
+; ZFH-NEXT: flh fa5, %lo(.L_MergedGlobals)(a0)
+; ZFH-NEXT: addi a1, a0, %lo(.L_MergedGlobals)
+; ZFH-NEXT: nds.fcvt.s.bf16 fa5, fa5
+; ZFH-NEXT: fsw fa5, 4(a1)
+; ZFH-NEXT: flw fa5, 4(a1)
+; ZFH-NEXT: nds.fcvt.bf16.s fa5, fa5
+; ZFH-NEXT: fsh fa5, %lo(.L_MergedGlobals)(a0)
+; ZFH-NEXT: ret
+entry:
+ %0 = load bfloat, bfloat* @bf, align 2
+ %1 = fpext bfloat %0 to float
+ store volatile float %1, float* @sf, align 4
+
+ %2 = load float, float* @sf, align 4
+ %3 = fptrunc float %2 to bfloat
+ store volatile bfloat %3, bfloat* @bf, align 2
+
+ ret void
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/187/builds/8489 Here is the relevant piece of the build log for reference:
|
We use `lhu` to load 2 bytes from memory into a GPR, then OR this GPR with -65536 (setting every bit above bit 15) to emulate NaN-boxing, and then move the value from the GPR to an FPR using `fmv.w.x`.
To move the value back from the FPR to a GPR, we use `fmv.x.w`, and finally `sh` stores the lower 2 bytes back to memory.
If zfh is enabled at the same time, we can just use flh/fsh to load/store bf16 directly.