6 changes: 5 additions & 1 deletion fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu
@@ -89,7 +89,7 @@ namespace fbgemm_gpu {
 // outputs are of size float[D]
 
 #if (defined(USE_ROCM) && ROCM_VERSION >= 60200)
-#if HIP_FP8_TYPE_OCP
+#if HIP_FP8_TYPE_OCP && !HIP_FP8_TYPE_FNUZ
 using __nv_fp8x4_e4m3 = __hip_fp8x4_e4m3;
 using __nv_fp8x2_e4m3 = __hip_fp8x2_e4m3;
 using __nv_fp8_e4m3 = __hip_fp8_e4m3;
@@ -1075,7 +1075,11 @@ void invokeComputeScalesAndQuantizeMatrix(
     bool stochastic_rounding,
     cudaStream_t stream) {
   dim3 grid(numel / lda);
+#ifdef USE_ROCM
+  bool use_shmem = true;
+#else
   bool use_shmem = false;
+#endif
   auto const shmem_size = lda * sizeof(T_IN);
   if (shmem_size >= (48 << 10)) {
     cudaError_t ret;
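Note on the second hunk above: `use_shmem` feeds the code just below the truncation point, which stages each row of `lda` elements through dynamic shared memory and therefore has to opt in when the request exceeds the default 48 KB per-block limit; that is what the `shmem_size >= (48 << 10)` check and the `cudaError_t ret;` are setting up. A minimal sketch of that opt-in pattern follows, not the PR's actual kernel: the kernel name `scaleQuantizeRowKernel`, the helper `launchWithShmem`, and the launch shape are illustrative placeholders.

#include <cuda_runtime.h>

// Illustrative kernel (placeholder): stages one row of lda elements in
// dynamic shared memory before computing the per-row scale and quantizing.
template <typename T_IN>
__global__ void scaleQuantizeRowKernel(const T_IN* input, int lda) {
  extern __shared__ unsigned char smem[];  // lda * sizeof(T_IN) bytes per block
  // ... load row into smem, reduce for the scale, write quantized output ...
}

// Launch helper: kernels requesting >= 48 KB of dynamic shared memory must
// raise their per-kernel limit first, otherwise the launch fails.
template <typename T_IN>
cudaError_t launchWithShmem(
    const T_IN* input,
    int lda,
    dim3 grid,
    dim3 block,
    cudaStream_t stream) {
  size_t shmem_size = lda * sizeof(T_IN);
  if (shmem_size >= (48 << 10)) {
    cudaError_t ret = cudaFuncSetAttribute(
        scaleQuantizeRowKernel<T_IN>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        static_cast<int>(shmem_size));
    if (ret != cudaSuccess) {
      return ret;
    }
  }
  scaleQuantizeRowKernel<T_IN><<<grid, block, shmem_size, stream>>>(input, lda);
  return cudaGetLastError();
}

The hunk defaults `use_shmem` to true only under `USE_ROCM`, presumably so the HIP build takes this shared-memory path by default.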
16 changes: 7 additions & 9 deletions fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -289,7 +289,9 @@ def test_f8f8bf16(self, kernel: str, use_fast_accum: bool) -> None:
             ["rowwise", "blockwise"]
             + (["tensorwise_broadcast", "tensorwise"] if torch.version.cuda else [])
         ),
-        QType=st.sampled_from([fp8_e4m3, fp8_e5m2]),
+        QType=(
+            st.sampled_from([fp8_e4m3, fp8_e5m2] if torch.version.cuda else [fp8_e4m3])
+        ),
         Bias=st.sampled_from([True, False]),
         CudaGraph=st.sampled_from([True, False]),
         UseTriton=st.sampled_from([False] + ([True] if torch.version.cuda else [])),
@@ -406,14 +408,10 @@ def f(
         def f(
             x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor]
         ) -> torch.Tensor:
-            if torch.version.cuda:
-                xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
-                    x, output_dtype=QType
-                )
-                wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
-            else:
-                xq, x_scale = quantize_fp8_row(x)
-                wq, w_scale = quantize_fp8_row(w)
+            xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+                x, output_dtype=QType
+            )
+            wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
             if UseTriton and torch.version.cuda:
                 zq = matmul_fp8_row(xq, wq, x_scale, w_scale)
                 if bias is not None: