WIP Blackwell fp4 #85

Draft: wants to merge 6 commits into base: blackwell

100 changes: 99 additions & 1 deletion include/common/base_types.cuh
@@ -10,6 +10,10 @@

#pragma once

#ifdef KITTENS_BLACKWELL
#include <cuda_fp4.h>
#endif

#ifdef KITTENS_HOPPER
#include <cuda_fp8.h>
#endif
@@ -55,6 +59,20 @@ using fp8e5m2_2 = __nv_fp8x2_e5m2;
using fp8e4m3_4 = __nv_fp8x4_e4m3;
using fp8e5m2_4 = __nv_fp8x4_e5m2;
#endif
#ifdef KITTENS_BLACKWELL
/**
* @brief 4-bit (e2m1) floating-point type.
*/
using fp4e2m1 = __nv_fp4_e2m1;
/**
* @brief 2-packed 4-bit (e2m1) floating-point type.
*/
using fp4e2m1_2 = __nv_fp4x2_e2m1;
/**
* @brief 4-packed 4-bit (e2m1) floating-point type.
*/
using fp4e2m1_4 = __nv_fp4x4_e2m1;
#endif
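
A quick compile-time sanity check of the expected storage layout (a hedged sketch: it assumes the cuda_fp4.h types store one byte per scalar and pack the x2/x4 variants into one and two bytes, analogous to the packed fp8 types):

static_assert(sizeof(kittens::fp4e2m1) == 1, "one e2m1 value per byte, low nibble used");
static_assert(sizeof(kittens::fp4e2m1_2) == 1, "two e2m1 values packed into one byte");
static_assert(sizeof(kittens::fp4e2m1_4) == 2, "four e2m1 values packed into two bytes");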

namespace ducks {
/**
@@ -64,7 +82,12 @@
*/
namespace base_types {

#ifdef KITTENS_HOPPER
#if defined(KITTENS_BLACKWELL)
template<typename T>
concept T2 = std::is_same_v<T, float2> || std::is_same_v<T, bf16_2> || std::is_same_v<T, half_2> || std::is_same_v<T, fp8e4m3_4> || std::is_same_v<T, fp8e5m2_4> || std::is_same_v<T, fp4e2m1_4>;
template<typename T>
concept T1 = std::is_same_v<T, float> || std::is_same_v<T, bf16 > || std::is_same_v<T, half> || std::is_same_v<T, fp8e4m3> || std::is_same_v<T, fp8e5m2> || std::is_same_v<T, fp4e2m1>;
#elif defined(KITTENS_HOPPER)
template<typename T>
concept T2 = std::is_same_v<T, float2> || std::is_same_v<T, bf16_2> || std::is_same_v<T, half_2> || std::is_same_v<T, fp8e4m3_4> || std::is_same_v<T, fp8e5m2_4>; // could add half_2 later if implemented.
template<typename T>
@@ -143,6 +166,20 @@ template<> struct constants<half_2> {
static __device__ inline constexpr half_2 pos_infty() { return half_2{constants<half>::pos_infty(), constants<half>::pos_infty()}; }
static __device__ inline constexpr half_2 neg_infty() { return half_2{constants<half>::neg_infty(), constants<half>::neg_infty()}; }
};
#ifdef KITTENS_BLACKWELL
template<> struct constants<fp4e2m1> {
static __device__ inline constexpr fp4e2m1 zero() { return std::bit_cast<fp4e2m1>(uint8_t(0x00)); }
static __device__ inline constexpr fp4e2m1 one() { return std::bit_cast<fp4e2m1>(uint8_t(0x02)); } // e2m1 encodes 1.0 as 0b0010
};
template<> struct constants<fp4e2m1_2> {
static __device__ inline constexpr fp4e2m1_2 zero() { return std::bit_cast<fp4e2m1_2>(uint8_t(0x00)); }
static __device__ inline constexpr fp4e2m1_2 one() { return std::bit_cast<fp4e2m1_2>(uint8_t(0x22)); } // two packed 0b0010 nibbles
};
template<> struct constants<fp4e2m1_4> {
static __device__ inline constexpr fp4e2m1_4 zero() { return std::bit_cast<fp4e2m1_4>(uint16_t(0x0000)); }
static __device__ inline constexpr fp4e2m1_4 one() { return std::bit_cast<fp4e2m1_4>(uint16_t(0x2222)); } // four packed 0b0010 nibbles
};
#endif
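
In e2m1 (1 sign bit, 2 exponent bits with bias 1, 1 mantissa bit), 1.0 encodes as 0b0010, hence the 0x2 nibbles above. A minimal host-side round-trip check, assuming the convertor specializations defined later in this file and --expt-relaxed-constexpr (which the FP4 Makefile below passes) so the constexpr __device__ accessor is callable from host code:

#include <cassert>
inline void check_fp4_one() {
    // one() is 0b0010 in the low nibble; decoding it should give exactly 1.0f.
    float f = kittens::base_types::convertor<float, kittens::fp4e2m1>::convert(
        kittens::base_types::constants<kittens::fp4e2m1>::one());
    assert(f == 1.0f);
}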
#ifdef KITTENS_HOPPER
template<> struct constants<fp8e4m3> {
static __device__ inline constexpr fp8e4m3 zero() { return std::bit_cast<__nv_fp8_e4m3>(uint8_t(0x00)); }
@@ -253,6 +290,18 @@ template<> struct packing<float4> {
template<> struct packing<int4> {
static __device__ inline constexpr int num() { return 4; }
};
#ifdef KITTENS_BLACKWELL
template<> struct packing<fp4e2m1> {
static __device__ inline constexpr int num() { return 1; }
using unpacked_type = fp4e2m1;
using packed_type = fp4e2m1_4;
};
template<> struct packing<fp4e2m1_4> {
static __device__ inline constexpr int num() { return 4; }
using unpacked_type = fp4e2m1;
using packed_type = fp4e2m1_4;
};
#endif
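
The packing metadata can be verified the same way (a sketch, under the same relaxed-constexpr assumption as above):

static_assert(kittens::base_types::packing<kittens::fp4e2m1>::num() == 1);
static_assert(kittens::base_types::packing<kittens::fp4e2m1_4>::num() == 4);
static_assert(std::is_same_v<kittens::base_types::packing<kittens::fp4e2m1>::packed_type, kittens::fp4e2m1_4>);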
#ifdef KITTENS_HOPPER
template<> struct packing<fp8e4m3> {
static __device__ inline constexpr int num() { return 1; }
@@ -354,6 +403,55 @@ template<> struct convertor<half_2, bf16_2> {
return __float22half2_rn(__bfloat1622float2(u));
}
};
#ifdef KITTENS_BLACKWELL
// fp4e2m1
template<> struct convertor<fp4e2m1_4, float4> {
static __host__ __device__ inline fp4e2m1_4 convert(const float4& u) {
return __nv_fp4x4_e2m1(u);
}
};
template<> struct convertor<float4, fp4e2m1_4> {
static __host__ __device__ inline float4 convert(const fp4e2m1_4& u) {
// __nv_fp4x4_e2m1 packs four 4-bit values into 16 bits, so unpack nibble by
// nibble (x assumed in the low nibble, matching the packed-fp8 convention);
// reinterpreting the 2-byte storage as four 1-byte scalars would read past
// the object and decode the wrong bits.
uint16_t bits = std::bit_cast<uint16_t>(u);
return make_float4(float(std::bit_cast<fp4e2m1>(uint8_t((bits >> 0) & 0xF))),
float(std::bit_cast<fp4e2m1>(uint8_t((bits >> 4) & 0xF))),
float(std::bit_cast<fp4e2m1>(uint8_t((bits >> 8) & 0xF))),
float(std::bit_cast<fp4e2m1>(uint8_t((bits >> 12) & 0xF))));
}
};
template<> struct convertor<fp4e2m1_2, float2> {
static __host__ __device__ inline fp4e2m1_2 convert(const float2& u) {
return __nv_fp4x2_e2m1(u);
}
};
template<> struct convertor<float2, fp4e2m1_2> {
static __host__ __device__ inline float2 convert(const fp4e2m1_2& u) {
// __nv_fp4x2_e2m1 packs two 4-bit values into one byte; x is assumed to
// live in the low nibble, matching the packed-fp8 convention.
uint8_t bits = std::bit_cast<uint8_t>(u);
return make_float2(float(std::bit_cast<fp4e2m1>(uint8_t(bits & 0xF))),
float(std::bit_cast<fp4e2m1>(uint8_t(bits >> 4))));
}
};
template<> struct convertor<fp4e2m1, float> {
static __host__ __device__ inline fp4e2m1 convert(const float & u) {
return __nv_fp4_e2m1(u);
}
};
template<> struct convertor<float, fp4e2m1> {
static __host__ __device__ inline float convert(const fp4e2m1 & u) {
return float(u);
}
};
template<> struct convertor<bf16_2, fp4e2m1_4> {
static __host__ __device__ inline bf16_2 convert(const fp4e2m1_4 & u) {
// Only the low two lanes (x, y) of the packed fp4 quad are consumed here.
float4 f4 = convertor<float4, fp4e2m1_4>::convert(u);
float2 f2 = make_float2(f4.x, f4.y);
return __float22bfloat162_rn(f2);
}
};
template<> struct convertor<fp4e2m1_4, bf16_2> {
static __host__ __device__ inline fp4e2m1_4 convert(const bf16_2 & u) {
// The two bf16 lanes fill x and y; the z and w nibbles are zero-padded.
float2 f2 = __bfloat1622float2(u);
float4 f4 = make_float4(f2.x, f2.y, 0.0f, 0.0f);
return __nv_fp4x4_e2m1(f4);
}
};
#endif
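
Usage sketch for the new convertors. The inputs are drawn from the e2m1-representable set {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}, so the round trip is exact; other values would be rounded to the nearest representable one:

__device__ inline void fp4_roundtrip_example() {
    float4 in = make_float4(0.5f, 1.0f, 2.0f, 6.0f);
    kittens::fp4e2m1_4 q = kittens::base_types::convertor<kittens::fp4e2m1_4, float4>::convert(in);
    float4 back = kittens::base_types::convertor<float4, kittens::fp4e2m1_4>::convert(q);
    // back == in here, since every lane is exactly representable in e2m1.
}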
#ifdef KITTENS_HOPPER
// fp8e4m3
template<> struct convertor<fp8e4m3_4, float4> {
5 changes: 4 additions & 1 deletion include/ops/warp/memory/util/util.cuh
@@ -351,7 +351,10 @@ template<typename T> struct size_info {
};
template<ducks::st::all ST> struct size_info<ST> {
static constexpr uint32_t elements = ST::num_elements;
static constexpr uint32_t bytes = ST::num_elements * sizeof(typename ST::dtype);
static constexpr uint32_t bytes =
std::is_same_v<typename ST::dtype, fp4e2m1>
? ST::num_elements * sizeof(fp4e2m1) / 2 // two e2m1 values pack into each byte
: ST::num_elements * sizeof(typename ST::dtype);
};
template<ducks::sv::all SV> struct size_info<SV> {
static constexpr uint32_t elements = SV::length;
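Worked example of the fp4 special case, with a hypothetical tile:

// For a shared tile ST with ST::num_elements == 4096:
//   dtype fp4e2m1: bytes = 4096 * sizeof(fp4e2m1) / 2 = 2048 (two values per byte)
//   dtype bf16:    bytes = 4096 * sizeof(bf16)        = 8192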
24 changes: 16 additions & 8 deletions include/ops/warp/tensor/mma.cuh
@@ -36,8 +36,9 @@ __device__ static inline uint32_t instruction_descriptor() {
desc |= 0b001 << 7; // 16-bit A input type as BF16
desc |= 0b001 << 10; // 16-bit B input type as BF16
} else if constexpr (std::is_same_v<AB, fp8e4m3>) {
desc |= 0b000 << 7; // 8-bit A input type as FP8 e4m3
desc |= 0b000 << 10; // 8-bit B input type as FP8 e4m3
// !!! Temp hack to get fp4 working, needs removal.
desc |= 0b101 << 7; // 4-bit A input type as FP4 e2m1
desc |= 0b101 << 10; // 4-bit B input type as FP4 e2m1
} else if constexpr (std::is_same_v<AB, fp8e5m2>) {
desc |= 0b001 << 7; // 8-bit A input type as FP8 e5m2
desc |= 0b001 << 10; // 8-bit B input type as FP8 e5m2
@@ -93,11 +94,14 @@
}
desc |= 0b0 << 6; // reserved
if constexpr (std::is_same_v<AB, fp8e4m3>) {
desc |= 0b000 << 7; // 8-bit A input type as FP8 e4m3
desc |= 0b000 << 10; // 8-bit B input type as FP8 e4m3
desc |= 0b101 << 7; // 4-bit A input type as FP4 e2m1
desc |= 0b101 << 10; // 4-bit B input type as FP4 e2m1
} else if constexpr (std::is_same_v<AB, fp8e5m2>) {
desc |= 0b001 << 7; // 8-bit A input type as FP8 e5m2
desc |= 0b001 << 10; // 8-bit B input type as FP8 e5m2
} else if constexpr (std::is_same_v<AB, fp4e2m1>) {
desc |= 0b101 << 7; // 4-bit A input type as FP4 e2m1
desc |= 0b101 << 10; // 4-bit B input type as FP4 e2m1
}
/* fp6 and fp4
else if constexpr (std::is_same_v<AB, fp6e2m3>) {
@@ -146,7 +150,7 @@ __device__ static inline uint32_t instruction_descriptor() {

template<typename T_AB, int acc, int ncta=1>
__device__ static inline void tmem_st(uint32_t d_tmem_addr, uint32_t a_tmem_addr, uint64_t b_desc, uint32_t idesc) {
if constexpr (std::is_same_v<T_AB, fp8e4m3> || std::is_same_v<T_AB, fp8e5m2>) {
if constexpr (std::is_same_v<T_AB, fp8e4m3> || std::is_same_v<T_AB, fp8e5m2> || std::is_same_v<T_AB, fp4e2m1>) {
// TODO(danfu): is there a better way to do this with string manipulation that the compiler likes?
if constexpr (ncta == 1) {
asm volatile(
@@ -186,7 +190,7 @@ __device__ static inline void tmem_st(uint32_t d_tmem_addr, uint32_t a_tmem_addr

template<typename T_AB, int acc, int ncta=1>
__device__ static inline void st_st(uint32_t d_tmem_addr, uint64_t a_desc, uint64_t b_desc, uint32_t idesc) {
if constexpr (std::is_same_v<T_AB, fp8e4m3> || std::is_same_v<T_AB, fp8e5m2>) {
if constexpr (std::is_same_v<T_AB, fp8e4m3> || std::is_same_v<T_AB, fp8e5m2> || std::is_same_v<T_AB, fp4e2m1>) {
// TODO(danfu): is there a better way to do this with string manipulation that the compiler likes?
if constexpr (ncta == 1) {
asm volatile(
@@ -269,10 +273,12 @@ __device__ static inline void mma(D &d, const A &a, const B &b, semaphore &sem)
(std::is_same_v<T_D, half> && !std::is_same_v<T_AB, half>) ||
(std::is_same_v<T_D, half> && !std::is_same_v<T_AB, fp8e4m3>) ||
(std::is_same_v<T_D, half> && !std::is_same_v<T_AB, fp8e5m2>) ||
(std::is_same_v<T_D, half> && !std::is_same_v<T_AB, fp4e2m1>) ||
(std::is_same_v<T_D, float> && !std::is_same_v<T_AB, bf16>) ||
(std::is_same_v<T_D, float> && !std::is_same_v<T_AB, half>) ||
(std::is_same_v<T_D, float> && !std::is_same_v<T_AB, fp8e4m3>) ||
(std::is_same_v<T_D, float> && !std::is_same_v<T_AB, fp8e5m2>),
(std::is_same_v<T_D, float> && !std::is_same_v<T_AB, fp8e5m2>) ||
(std::is_same_v<T_D, float> && !std::is_same_v<T_AB, fp4e2m1>),
"Currently unsupported type combination for matrix multiply."
);
uint32_t idesc = detail::instruction_descriptor<T_D, T_AB, M, N, trans_a, trans_b, false>();
@@ -326,10 +332,12 @@ __device__ static inline void mma(D &d, const A &a, const B &b, semaphore &sem)
(std::is_same_v<T_D, half> && !std::is_same_v<T_AB, half>) ||
(std::is_same_v<T_D, half> && !std::is_same_v<T_AB, fp8e4m3>) ||
(std::is_same_v<T_D, half> && !std::is_same_v<T_AB, fp8e5m2>) ||
(std::is_same_v<T_D, half> && !std::is_same_v<T_AB, fp4e2m1>) ||
(std::is_same_v<T_D, float> && !std::is_same_v<T_AB, bf16>) ||
(std::is_same_v<T_D, float> && !std::is_same_v<T_AB, half>) ||
(std::is_same_v<T_D, float> && !std::is_same_v<T_AB, fp8e4m3>) ||
(std::is_same_v<T_D, float> && !std::is_same_v<T_AB, fp8e5m2>),
(std::is_same_v<T_D, float> && !std::is_same_v<T_AB, fp8e5m2>) ||
(std::is_same_v<T_D, float> && !std::is_same_v<T_AB, fp4e2m1>),
"Currently unsupported type combination for matrix multiply."
);
uint32_t idesc = detail::instruction_descriptor<T_D, T_AB, M, N, trans_a, trans_b, false>();
4 changes: 4 additions & 0 deletions include/types/global/tma.cuh
@@ -41,6 +41,9 @@ __host__ static inline void create_tensor_map(CUtensorMap *tma_map, const typena
std::is_same_v<dtype, float> ? CU_TENSOR_MAP_DATA_TYPE_FLOAT32 :
std::is_same_v<dtype, fp8e4m3> ? CU_TENSOR_MAP_DATA_TYPE_UINT8 :
std::is_same_v<dtype, fp8e5m2> ? CU_TENSOR_MAP_DATA_TYPE_UINT8 :
std::is_same_v<dtype, fp4e2m1> ? CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B :
std::is_same_v<dtype, fp4e2m1_2> ? CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B :
std::is_same_v<dtype, fp4e2m1_4> ? CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B :
CUtensorMapDataType(-1)
);
constexpr CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE;
@@ -224,6 +227,7 @@ __host__ static inline void create_tensor_map(CUtensorMap *tma_map, const typena
std::is_same_v<dtype, float> ? CU_TENSOR_MAP_DATA_TYPE_FLOAT32 :
std::is_same_v<dtype, fp8e4m3> ? CU_TENSOR_MAP_DATA_TYPE_UINT8 :
std::is_same_v<dtype, fp8e5m2> ? CU_TENSOR_MAP_DATA_TYPE_UINT8 :
std::is_same_v<dtype, fp4e2m1> ? CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B :
CUtensorMapDataType(-1)
);
constexpr CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE;
3 changes: 3 additions & 0 deletions include/types/shared/st.cuh
@@ -249,4 +249,7 @@ template<int _height, int _width> using st_fl = st<float, _height, _width>;
template<int _height, int _width> using st_fl8_e4m3 = st<fp8e4m3, _height, _width>;
template<int _height, int _width> using st_fl8_e5m2 = st<fp8e5m2, _height, _width>;
#endif
#ifdef KITTENS_BLACKWELL
template<int _height, int _width> using st_fl4_e2m1 = st<fp4e2m1, _height, _width>;
#endif
}
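
Declaring an fp4 shared tile then mirrors the fp8 aliases (a usage sketch; the 64x64 shape is illustrative, not a tested configuration):

using tile_fp4 = kittens::st_fl4_e2m1<64, 64>; // e2m1 tile, stored two elements per byte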
20 changes: 20 additions & 0 deletions kernels/matmul/FP4_B200/Makefile
@@ -0,0 +1,20 @@
# GPU Selection: 4090, A100, H100, B200
GPU_TARGET=B200

# Compiler
NVCC=nvcc

NVCCFLAGS=-DNDEBUG -Xcompiler=-fPIE -Xcompiler -fopenmp --expt-extended-lambda --expt-relaxed-constexpr -Xcompiler=-Wno-psabi -Xcompiler=-fno-strict-aliasing --use_fast_math -forward-unknown-to-host-compiler -O3 -Xnvlink=--verbose -Xptxas=--verbose -Xptxas=--warn-on-spills -std=c++20 -MD -MT -MF -x cu -lrt -lpthread -ldl -DKITTENS_BLACKWELL -DKITTENS_HOPPER -arch=sm_100a -lcuda -lcudadevrt -lcudart_static -lgomp -I${THUNDERKITTENS_ROOT}/include -I${THUNDERKITTENS_ROOT}/prototype
TARGET=matmul
SRC=matmul.cu
# SRC=tmp.cu

# Default target
all: $(TARGET)

$(TARGET): $(SRC)
$(NVCC) $(SRC) $(NVCCFLAGS) -o $(TARGET)

# Clean target
clean:
rm -f $(TARGET)
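
With THUNDERKITTENS_ROOT pointing at the repository checkout, running make in this directory builds the matmul binary for sm_100a with both KITTENS_BLACKWELL and KITTENS_HOPPER defined.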