Skip to content

Commit d1c120a

Browse files
committed
Disallow the ragged offsets for decode for now
1 parent 15de811 commit d1c120a

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

flashinfer/cudnn/decode.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ def _create_cudnn_handle(stream: torch.cuda.Stream):
2222
global _cudnn_handle
2323
if _cudnn_handle is None:
2424
_cudnn_handle = cudnn.create_handle()
25-
cudnn.set_stream(_cudnn_handle, stream.cuda_stream)
25+
# cudnn.set_stream(_cudnn_handle, stream.cuda_stream) # TODO: Will fix this in the future
2626
return _cudnn_handle
2727

2828

@@ -89,6 +89,10 @@ def _build_decode_graph(
8989
):
9090
handle = _create_cudnn_handle(torch.cuda.current_stream())
9191

92+
# WAR (workaround): override batch offsets for now, as they lead to poor performance
93+
batch_offsets_q = None
94+
batch_offsets_o = None
95+
9296
with cudnn.graph(handle) as (g, _):
9397

9498
if q.dim() == 3:

0 commit comments

Comments
 (0)