@@ -71,128 +71,129 @@ def _sdpa_decode_key_fn(
71
71
)
72
72
73
73
74
if CUDNN_AVAILABLE:

    @cudnn.jit(heur_modes=[cudnn.heur_mode.A])
    @cudnn.graph_cache(key_fn=_sdpa_decode_key_fn)
    def _build_decode_graph(
        q: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        scale: float,
        *,
        max_sequence_kv: int,
        block_size: Optional[int] = 1,
        actual_seq_lens_q: Optional[torch.Tensor] = None,
        actual_seq_lens_kv: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
        batch_offsets_q: Optional[torch.Tensor] = None,
        batch_offsets_o: Optional[torch.Tensor] = None,
    ):
        """Build the cuDNN SDPA graph used for decode (query length 1).

        Only defined when ``CUDNN_AVAILABLE`` is truthy. The function is
        wrapped by ``cudnn.graph_cache`` (keyed by ``_sdpa_decode_key_fn``)
        and ``cudnn.jit``; what those wrappers do with the returned pair is
        not visible here.

        Args:
            q: Query tensor, either 3-D ``(b, h_qo, d_qk)`` or 4-D
                ``(b, h_qo, s_qo, d_qk)``. Only its shape is read; the graph
                tensor is always declared BFLOAT16 with fixed strides.
            k_cache: Key cache; must be 4-D. Mirrored with ``g.tensor_like``.
            v_cache: Value cache. Mirrored with ``g.tensor_like``; its
                ``shape[3]`` is used as the output head dimension.
            scale: Forwarded to ``g.sdpa`` as ``attn_scale``.
            max_sequence_kv: Forwarded as ``paged_attention_max_seq_len_kv``
                when ``block_tables`` is given.
            block_size: Not read in this body (presumably consumed by the
                cache key function -- TODO confirm).
            actual_seq_lens_q: Optional per-batch query lengths tensor.
            actual_seq_lens_kv: Optional per-batch KV lengths tensor. When
                given, ``use_padding_mask`` is enabled.
            block_tables: Optional 2-D page table. It is reshaped to
                ``(shape[0], 1, shape[1], 1)`` and used for both K and V.
            batch_offsets_q: Currently ignored (see workaround below).
            batch_offsets_o: Currently ignored (see workaround below).

        Returns:
            ``(g, tensors)`` where ``tensors`` is
            ``[q, k_cache, v_cache, O]`` graph tensors, followed by the
            ``actual_seq_lens_q`` and/or ``actual_seq_lens_kv`` graph tensors
            when those arguments were provided (in that order).

        Raises:
            ValueError: If ``q`` is neither 3-D nor 4-D.
            AssertionError: If the query sequence length is not 1 or
                ``k_cache`` is not 4-D.
        """
        handle = _create_cudnn_handle(torch.cuda.current_stream())

        # WAR: ignore the batch offsets for now, as using them leads to poor
        # performance. The ragged-offset branches below are therefore
        # currently never taken.
        batch_offsets_q = None
        batch_offsets_o = None

        with cudnn.graph(handle) as (g, _):
            if q.dim() == 3:
                s_qo = 1
                b, h_qo, d_qk = q.shape[0], q.shape[1], q.shape[2]
            elif q.dim() == 4:
                b, h_qo, s_qo, d_qk = (
                    q.shape[0],
                    q.shape[1],
                    q.shape[2],
                    q.shape[3],
                )
            else:
                raise ValueError(f"q must have 3 or 4 dimensions, got {q.dim()}")

            assert s_qo == 1, "q must have a sequence length of 1"
            assert k_cache.dim() == 4, "k_cache must have 4 dimensions"

            d_vo = v_cache.shape[3]

            cudnn_q = g.tensor(
                name="q",
                dim=(b, h_qo, s_qo, d_qk),
                stride=(h_qo * d_qk, d_qk, d_qk * h_qo, 1),
                data_type=cudnn.data_type.BFLOAT16,
            )
            if batch_offsets_q is not None:
                ragged_q = g.tensor_like(batch_offsets_q)
                ragged_q.set_uid(UIDs.RAGGED_Q_UID.value)
                cudnn_q.set_ragged_offset(ragged_q)

            cudnn_k_cache = g.tensor_like(k_cache)
            cudnn_v_cache = g.tensor_like(v_cache)

            cudnn_q.set_uid(UIDs.Q_UID.value)
            cudnn_k_cache.set_uid(UIDs.K_UID.value)
            cudnn_v_cache.set_uid(UIDs.V_UID.value)

            # Optional graph tensors default to None so the sdpa call below
            # never reads an unbound local when an input was not supplied.
            cudnn_k_block_tables = None
            cudnn_v_block_tables = None
            cudnn_actual_seq_lens_q = None
            cudnn_actual_seq_lens_kv = None

            if block_tables is not None:
                # K and V share the same page table contents but need
                # distinct graph tensors / UIDs.
                nd_block_tables = block_tables.reshape(
                    block_tables.shape[0], 1, block_tables.shape[1], 1
                )
                cudnn_k_block_tables = g.tensor_like(nd_block_tables)
                cudnn_k_block_tables.set_uid(UIDs.BLOCK_TABLES_K_UID.value)

                cudnn_v_block_tables = g.tensor_like(nd_block_tables)
                cudnn_v_block_tables.set_uid(UIDs.BLOCK_TABLES_V_UID.value)

            if actual_seq_lens_q is not None:
                cudnn_actual_seq_lens_q = g.tensor_like(actual_seq_lens_q)
                cudnn_actual_seq_lens_q.set_uid(UIDs.ACTUAL_SEQ_LENS_Q_UID.value)

            if actual_seq_lens_kv is not None:
                cudnn_actual_seq_lens_kv = g.tensor_like(actual_seq_lens_kv)
                cudnn_actual_seq_lens_kv.set_uid(UIDs.ACTUAL_SEQ_LENS_KV_UID.value)
                cudnn_actual_seq_lens_kv.set_is_pass_by_value(False)

            padding_mask = actual_seq_lens_kv is not None

            cudnn_o, _ = g.sdpa(
                name="sdpa",
                q=cudnn_q,
                k=cudnn_k_cache,
                v=cudnn_v_cache,
                seq_len_q=cudnn_actual_seq_lens_q,
                seq_len_kv=cudnn_actual_seq_lens_kv,
                use_padding_mask=padding_mask,
                is_inference=True,
                attn_scale=scale,
                paged_attention_k_table=cudnn_k_block_tables,
                paged_attention_v_table=cudnn_v_block_tables,
                # NOTE(review): only meaningful together with the page
                # tables; left unset for the non-paged case -- confirm
                # against the cuDNN frontend sdpa documentation.
                paged_attention_max_seq_len_kv=(
                    max_sequence_kv if block_tables is not None else None
                ),
                compute_data_type=cudnn.data_type.FLOAT,
            )

            if batch_offsets_o is not None:
                ragged_o = g.tensor_like(batch_offsets_o)
                ragged_o.set_uid(UIDs.RAGGED_O_UID.value)
                cudnn_o.set_ragged_offset(ragged_o)

            cudnn_o.set_uid(UIDs.O_UID.value).set_output(True).set_dim(
                [b, h_qo, s_qo, d_vo]
            ).set_stride([d_vo * h_qo, d_vo, d_vo * h_qo, 1]).set_data_type(
                cudnn.data_type.BFLOAT16
            )

            tensors_to_return = [cudnn_q, cudnn_k_cache, cudnn_v_cache, cudnn_o]

            if actual_seq_lens_q is not None:
                tensors_to_return.append(cudnn_actual_seq_lens_q)
            if actual_seq_lens_kv is not None:
                tensors_to_return.append(cudnn_actual_seq_lens_kv)

        return g, tensors_to_return
196
197
197
198
198
199
def _batch_decode_with_kv_cache (
0 commit comments