File tree 3 files changed +9
-4
lines changed
3 files changed +9
-4
lines changed Original file line number Diff line number Diff line change @@ -30,7 +30,7 @@ set(XPU_XFA_LIB_NAME "libxpu_flash_attention.so")
30
30
set (XPU_XPUDNN_LIB_NAME "libxpu_dnn.so" )
31
31
32
32
if (NOT DEFINED XPU_XHPC_BASE_DATE)
33
- set (XPU_XHPC_BASE_DATE "dev/20250318 " )
33
+ set (XPU_XHPC_BASE_DATE "dev/20250405 " )
34
34
endif ()
35
35
set (XPU_XCCL_BASE_VERSION "3.0.2.5" ) # For XRE5
36
36
if (NOT DEFINED XPU_XFT_BASE_VERSION)
Original file line number Diff line number Diff line change @@ -232,6 +232,8 @@ void FlashAttnGradKernelBase(
232
232
is_flashmask ? flashmask_stream : nullptr );
233
233
PADDLE_ENFORCE_XDNN_SUCCESS (r, " mha_varlen_bwd" );
234
234
if (is_flashmask && flashmask_stream != nullptr ) {
235
+ r = xpu_wait (flashmask_stream);
236
+ PADDLE_ENFORCE_XPU_SUCCESS (r);
235
237
xpu_stream_destroy (flashmask_stream);
236
238
}
237
239
}
Original file line number Diff line number Diff line change @@ -210,12 +210,15 @@ void FlashAttnKernelBase(
210
210
(const int *)upstart_row_indices_data, // upstart_row_indices_data
211
211
(const int *)upend_row_indices_data, // upend_row_indices_data
212
212
is_flashmask ? startend_row_indices->dims ()[1 ]
213
- : 0 , // flash_mask_head_num
214
- nullptr , // flashmask_maxmin
215
- is_flashmask ? flashmask_stream : nullptr // side_stream
213
+ : 0 , // flash_mask_head_num
214
+ nullptr , // flashmask_maxmin
215
+ is_flashmask ? flashmask_stream : nullptr , // side_stream
216
+ 0 // fixlen_batch_num
216
217
);
217
218
PADDLE_ENFORCE_XDNN_SUCCESS (r, " mha_varlen_fwd" );
218
219
if (is_flashmask && flashmask_stream != nullptr ) {
220
+ r = xpu_wait (flashmask_stream);
221
+ PADDLE_ENFORCE_XPU_SUCCESS (r);
219
222
xpu_stream_destroy (flashmask_stream);
220
223
}
221
224
}
You can’t perform that action at this time.
0 commit comments