Skip to content

[RISCV] Implement load/store support for XAndesBFHCvt #150350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1618,6 +1618,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
}
}

// Customize load and store operation for bf16 if zfh isn't enabled.
if (Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh()) {
setOperationAction(ISD::LOAD, MVT::bf16, Custom);
setOperationAction(ISD::STORE, MVT::bf16, Custom);
}

// Function alignments.
const Align FunctionAlignment(Subtarget.hasStdExtZca() ? 2 : 4);
setMinFunctionAlignment(FunctionAlignment);
Expand Down Expand Up @@ -7216,6 +7222,47 @@ static SDValue SplitStrictFPVectorOp(SDValue Op, SelectionDAG &DAG) {
return DAG.getMergeValues({V, HiRes.getValue(1)}, DL);
}

SDValue
RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Load(SDValue Op,
SelectionDAG &DAG) const {
assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
"Unexpected bfloat16 load lowering");

SDLoc DL(Op);
LoadSDNode *LD = cast<LoadSDNode>(Op.getNode());
EVT MemVT = LD->getMemoryVT();
SDValue Load = DAG.getExtLoad(
ISD::ZEXTLOAD, DL, Subtarget.getXLenVT(), LD->getChain(),
LD->getBasePtr(),
EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits()),
LD->getMemOperand());
// Using mask to make bf16 nan-boxing valid when we don't have flh
// instruction. -65536 would be treat as a small number and thus it can be
// directly used lui to get the constant.
SDValue mask = DAG.getSignedConstant(-65536, DL, Subtarget.getXLenVT());
SDValue OrSixteenOne =
DAG.getNode(ISD::OR, DL, Load.getValueType(), {Load, mask});
SDValue ConvertedResult =
DAG.getNode(RISCVISD::NDS_FMV_BF16_X, DL, MVT::bf16, OrSixteenOne);
return DAG.getMergeValues({ConvertedResult, Load.getValue(1)}, DL);
}

SDValue
RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Store(SDValue Op,
SelectionDAG &DAG) const {
assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
"Unexpected bfloat16 store lowering");

StoreSDNode *ST = cast<StoreSDNode>(Op.getNode());
SDLoc DL(Op);
SDValue FMV = DAG.getNode(RISCVISD::NDS_FMV_X_ANYEXTBF16, DL,
Subtarget.getXLenVT(), ST->getValue());
return DAG.getTruncStore(
ST->getChain(), DL, FMV, ST->getBasePtr(),
EVT::getIntegerVT(*DAG.getContext(), ST->getMemoryVT().getSizeInBits()),
ST->getMemOperand());
}

SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
Expand Down Expand Up @@ -7914,6 +7961,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return DAG.getMergeValues({Pair, Chain}, DL);
}

if (VT == MVT::bf16)
return lowerXAndesBfHCvtBFloat16Load(Op, DAG);

// Handle normal vector tuple load.
if (VT.isRISCVVectorTuple()) {
SDLoc DL(Op);
Expand Down Expand Up @@ -7998,6 +8048,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
{Store->getChain(), Lo, Hi, Store->getBasePtr()}, MVT::i64,
Store->getMemOperand());
}

if (VT == MVT::bf16)
return lowerXAndesBfHCvtBFloat16Store(Op, DAG);

// Handle normal vector tuple store.
if (VT.isRISCVVectorTuple()) {
SDLoc DL(Op);
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,9 @@ class RISCVTargetLowering : public TargetLowering {
SDValue lowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;

SDValue lowerXAndesBfHCvtBFloat16Load(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerXAndesBfHCvtBFloat16Store(SDValue Op, SelectionDAG &DAG) const;

bool isEligibleForTailCallOptimization(
CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
const SmallVector<CCValAssign, 16> &ArgLocs) const;
Expand Down
33 changes: 33 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@
//
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// RISC-V specific DAG Nodes.
//===----------------------------------------------------------------------===//

def SDT_NDS_FMV_BF16_X
: SDTypeProfile<1, 1, [SDTCisVT<0, bf16>, SDTCisVT<1, XLenVT>]>;
def SDT_NDS_FMV_X_ANYEXTBF16
: SDTypeProfile<1, 1, [SDTCisVT<0, XLenVT>, SDTCisVT<1, bf16>]>;

def riscv_nds_fmv_bf16_x
: SDNode<"RISCVISD::NDS_FMV_BF16_X", SDT_NDS_FMV_BF16_X>;
def riscv_nds_fmv_x_anyextbf16
: SDNode<"RISCVISD::NDS_FMV_X_ANYEXTBF16", SDT_NDS_FMV_X_ANYEXTBF16>;

//===----------------------------------------------------------------------===//
// Operand and SDNode transformation definitions.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -774,6 +788,25 @@ def : Pat<(bf16 (fpround FPR32:$rs)),
(NDS_FCVT_BF16_S FPR32:$rs)>;
} // Predicates = [HasVendorXAndesBFHCvt]

let isCodeGenOnly = 1 in {
def NDS_FMV_BF16_X : FPUnaryOp_r<0b1111000, 0b00000, 0b000, FPR16, GPR, "fmv.w.x">,
Sched<[WriteFMovI32ToF32, ReadFMovI32ToF32]>;
def NDS_FMV_X_BF16 : FPUnaryOp_r<0b1110000, 0b00000, 0b000, GPR, FPR16, "fmv.x.w">,
Sched<[WriteFMovF32ToI32, ReadFMovF32ToI32]>;
}

let Predicates = [HasVendorXAndesBFHCvt] in {
def : Pat<(riscv_nds_fmv_bf16_x GPR:$src), (NDS_FMV_BF16_X GPR:$src)>;
def : Pat<(riscv_nds_fmv_x_anyextbf16 (bf16 FPR16:$src)),
(NDS_FMV_X_BF16 (bf16 FPR16:$src))>;
} // Predicates = [HasVendorXAndesBFHCvt]

// Use flh/fsh to load/store bf16 if zfh is enabled.
let Predicates = [HasStdExtZfh, HasVendorXAndesBFHCvt] in {
def : LdPat<load, FLH, bf16>;
def : StPat<store, FSH, FPR16, bf16>;
} // Predicates = [HasStdExtZfh, HasVendorXAndesBFHCvt]

let Predicates = [HasVendorXAndesVBFHCvt] in {
defm PseudoNDS_VFWCVT_S_BF16 : VPseudoVWCVT_S_BF16;
defm PseudoNDS_VFNCVT_BF16_S : VPseudoVNCVT_BF16_S;
Expand Down
45 changes: 43 additions & 2 deletions llvm/test/CodeGen/RISCV/xandesbfhcvt.ll
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+xandesbfhcvt -target-abi ilp32f \
; RUN: -verify-machineinstrs < %s | FileCheck %s
; RUN: -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,XANDESBFHCVT %s
; RUN: llc -mtriple=riscv32 -mattr=+zfh,+xandesbfhcvt -target-abi ilp32f \
; RUN: -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,ZFH %s
; RUN: llc -mtriple=riscv64 -mattr=+xandesbfhcvt -target-abi lp64f \
; RUN: -verify-machineinstrs < %s | FileCheck %s
; RUN: -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,XANDESBFHCVT %s
; RUN: llc -mtriple=riscv64 -mattr=+zfh,+xandesbfhcvt -target-abi lp64f \
; RUN: -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,ZFH %s

define float @fcvt_s_bf16(bfloat %a) nounwind {
; CHECK-LABEL: fcvt_s_bf16:
Expand All @@ -21,3 +25,40 @@ define bfloat @fcvt_bf16_s(float %a) nounwind {
%1 = fptrunc float %a to bfloat
ret bfloat %1
}

; Check load and store to bf16.
define void @loadstorebf16(ptr %bf, ptr %sf) nounwind {
; XANDESBFHCVT-LABEL: loadstorebf16:
; XANDESBFHCVT: # %bb.0: # %entry
; XANDESBFHCVT-NEXT: lhu a2, 0(a0)
; XANDESBFHCVT-NEXT: lui a3, 1048560
; XANDESBFHCVT-NEXT: or a2, a2, a3
; XANDESBFHCVT-NEXT: fmv.w.x fa5, a2
; XANDESBFHCVT-NEXT: nds.fcvt.s.bf16 fa5, fa5
; XANDESBFHCVT-NEXT: fsw fa5, 0(a1)
; XANDESBFHCVT-NEXT: flw fa5, 0(a1)
; XANDESBFHCVT-NEXT: nds.fcvt.bf16.s fa5, fa5
; XANDESBFHCVT-NEXT: fmv.x.w a1, fa5
; XANDESBFHCVT-NEXT: sh a1, 0(a0)
; XANDESBFHCVT-NEXT: ret
;
; ZFH-LABEL: loadstorebf16:
; ZFH: # %bb.0: # %entry
; ZFH-NEXT: flh fa5, 0(a0)
; ZFH-NEXT: nds.fcvt.s.bf16 fa5, fa5
; ZFH-NEXT: fsw fa5, 0(a1)
; ZFH-NEXT: flw fa5, 0(a1)
; ZFH-NEXT: nds.fcvt.bf16.s fa5, fa5
; ZFH-NEXT: fsh fa5, 0(a0)
; ZFH-NEXT: ret
entry:
%0 = load bfloat, bfloat* %bf, align 2
%1 = fpext bfloat %0 to float
store volatile float %1, float* %sf, align 4

%2 = load float, float* %sf, align 4
%3 = fptrunc float %2 to bfloat
store volatile bfloat %3, bfloat* %bf, align 2

ret void
}