Commit a5778ea

jwfromm authored and facebook-github-bot committed
Fix quantize kernels on rocm 6.4 (#4708)
Summary:

X-link: facebookresearch/FBGEMM#1731

Interestingly, ROCm 6.4 technically allows both the OCP and FNUZ floating-point formats. Our quantize kernels check whether the OCP formats are allowed and use them if so; however, for pretty much any integration FNUZ is still the expected format. This small diff fixes the behavior by checking the HIP FP8 format macros more carefully, and it exposes rowwise quantization on AMD to the unit tests.

Reviewed By: q10

Differential Revision: D80309166
1 parent 635ffe7 commit a5778ea
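
For context, here is a minimal sketch of the guard pattern this commit tightens. It is not the FBGEMM source: the fp8_e4m3_t alias is made up, and it assumes HIP's fp8 headers provide both the OCP and FNUZ e4m3 types; only the HIP_FP8_TYPE_* macros, ROCM_VERSION, and __hip_fp8_e4m3 appear in the actual diff below.

    // Illustrative sketch: on ROCm 6.4 both HIP_FP8_TYPE_OCP and HIP_FP8_TYPE_FNUZ
    // can be set, so the OCP aliases should only be chosen when FNUZ is absent.
    #if defined(USE_ROCM) && ROCM_VERSION >= 60200
    #if HIP_FP8_TYPE_OCP && !HIP_FP8_TYPE_FNUZ
    // OCP-only toolchain: map the portable alias to the OCP e4m3 type.
    using fp8_e4m3_t = __hip_fp8_e4m3;
    #else
    // FNUZ is available, which is what most integrations still expect: prefer it.
    using fp8_e4m3_t = __hip_fp8_e4m3_fnuz;  // assumed FNUZ type name from the HIP headers
    #endif
    #endif

The actual change in quantize.cu below simply adds && !HIP_FP8_TYPE_FNUZ to the existing #if, so the FNUZ path keeps winning whenever the toolchain exposes both formats.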

File tree

2 files changed: +12 -10 lines


fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu

Lines changed: 5 additions & 1 deletion
@@ -89,7 +89,7 @@ namespace fbgemm_gpu {
 // outputs are of size float[D]
 
 #if (defined(USE_ROCM) && ROCM_VERSION >= 60200)
-#if HIP_FP8_TYPE_OCP
+#if HIP_FP8_TYPE_OCP && !HIP_FP8_TYPE_FNUZ
 using __nv_fp8x4_e4m3 = __hip_fp8x4_e4m3;
 using __nv_fp8x2_e4m3 = __hip_fp8x2_e4m3;
 using __nv_fp8_e4m3 = __hip_fp8_e4m3;
@@ -1075,7 +1075,11 @@ void invokeComputeScalesAndQuantizeMatrix(
     bool stochastic_rounding,
     cudaStream_t stream) {
   dim3 grid(numel / lda);
+#ifdef USE_ROCM
   bool use_shmem = true;
+#else
+  bool use_shmem = false;
+#endif
   auto const shmem_size = lda * sizeof(T_IN);
   if (shmem_size >= (48 << 10)) {
     cudaError_t ret;
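
The second hunk turns the shared-memory path on only for ROCm. The surrounding shmem_size >= (48 << 10) check is the usual opt-in point for large dynamic shared memory; below is a self-contained sketch of that pattern, assumed rather than taken from invokeComputeScalesAndQuantizeMatrix, with made-up kernel and launcher names.

    #include <cuda_runtime.h>

    // Hypothetical row-quantization kernel that stages one row in dynamic shared
    // memory; the body is elided because only the launch-side logic matters here.
    template <typename T_IN>
    __global__ void quantize_rows_shmem(const T_IN* in, int lda) {
      extern __shared__ unsigned char smem[];
      // ... compute a per-row scale and quantize the row using smem ...
    }

    template <typename T_IN>
    cudaError_t launch_quantize_rows(
        const T_IN* in, int numel, int lda, cudaStream_t stream) {
      dim3 grid(numel / lda);
    #ifdef USE_ROCM
      bool use_shmem = true;   // ROCm keeps the shared-memory variant enabled
    #else
      bool use_shmem = false;  // mirrors the change in the hunk above
    #endif
      size_t shmem_size = lda * sizeof(T_IN);
      if (use_shmem && shmem_size >= (48 << 10)) {
        // Kernels must opt in explicitly to use more than 48 KiB of dynamic
        // shared memory per block.
        cudaError_t ret = cudaFuncSetAttribute(
            quantize_rows_shmem<T_IN>,
            cudaFuncAttributeMaxDynamicSharedMemorySize,
            static_cast<int>(shmem_size));
        if (ret != cudaSuccess) {
          return ret;
        }
      }
      // A real implementation would dispatch to a non-shared-memory kernel when
      // use_shmem is false; that fallback is elided in this sketch.
      quantize_rows_shmem<T_IN>
          <<<grid, 256, use_shmem ? shmem_size : 0, stream>>>(in, lda);
      return cudaGetLastError();
    }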

fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py

Lines changed: 7 additions & 9 deletions
@@ -289,7 +289,9 @@ def test_f8f8bf16(self, kernel: str, use_fast_accum: bool) -> None:
             ["rowwise", "blockwise"]
             + (["tensorwise_broadcast", "tensorwise"] if torch.version.cuda else [])
         ),
-        QType=st.sampled_from([fp8_e4m3, fp8_e5m2]),
+        QType=(
+            st.sampled_from([fp8_e4m3, fp8_e5m2] if torch.version.cuda else [fp8_e4m3])
+        ),
         Bias=st.sampled_from([True, False]),
         CudaGraph=st.sampled_from([True, False]),
         UseTriton=st.sampled_from([False] + ([True] if torch.version.cuda else [])),
@@ -406,14 +408,10 @@ def f(
         def f(
             x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor]
         ) -> torch.Tensor:
-            if torch.version.cuda:
-                xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
-                    x, output_dtype=QType
-                )
-                wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
-            else:
-                xq, x_scale = quantize_fp8_row(x)
-                wq, w_scale = quantize_fp8_row(w)
+            xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+                x, output_dtype=QType
+            )
+            wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
             if UseTriton and torch.version.cuda:
                 zq = matmul_fp8_row(xq, wq, x_scale, w_scale)
                 if bias is not None:
