We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 15de811 commit d1c120aCopy full SHA for d1c120a
flashinfer/cudnn/decode.py
@@ -22,7 +22,7 @@ def _create_cudnn_handle(stream: torch.cuda.Stream):
22
global _cudnn_handle
23
if _cudnn_handle is None:
24
_cudnn_handle = cudnn.create_handle()
25
- cudnn.set_stream(_cudnn_handle, stream.cuda_stream)
+ # cudnn.set_stream(_cudnn_handle, stream.cuda_stream) # TODO: will fix this in the future
26
return _cudnn_handle
27
28
@@ -89,6 +89,10 @@ def _build_decode_graph(
89
):
90
handle = _create_cudnn_handle(torch.cuda.current_stream())
91
92
+ # WAR: override batch offsets for now, as they lead to poor performance
93
+ batch_offsets_q = None
94
+ batch_offsets_o = None
95
+
96
with cudnn.graph(handle) as (g, _):
97
98
if q.dim() == 3:
0 commit comments