@@ -257,12 +257,8 @@ void FlashAttnUnpaddedGradBaseKernel(
     kdq = &dq_tmp;
   }
 
-#ifdef PADDLE_WITH_HIP
-  std::initializer_list<int64_t> dk_dv_shape = {total_k, num_heads, head_size};
-#else
   std::initializer_list<int64_t> dk_dv_shape = {
       total_k, num_heads_k, num_heads / num_heads_k, head_size};
-#endif
 
   DenseTensor *kdk = dk, *kdv = dv;
   DenseTensor dk_tmp;
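
This hunk drops the HIP-only 3-D scratch shape and uses the grouped 4-D layout on both backends: under grouped-query attention, each of the num_heads_k key/value heads serves num_heads / num_heads_k query heads, so the backward pass first materializes one dK/dV slice per query head and only afterwards folds the group axis back down (the SumKernel(..., {2}, ...) call in the reduction hunk below). A minimal standalone sketch of that fold, with illustrative names rather than Paddle's kernels:

    #include <cstdint>
    #include <vector>

    // Sum dk_tmp [total_k, num_heads_k, group, head_size] over the group
    // axis, yielding dk [total_k, num_heads_k, head_size]. This mirrors
    // the SumKernel(..., /*axes=*/{2}, ...) call in the varlen path.
    std::vector<float> FoldGroupAxis(const std::vector<float>& dk_tmp,
                                     int64_t total_k, int64_t num_heads_k,
                                     int64_t group, int64_t head_size) {
      std::vector<float> dk(total_k * num_heads_k * head_size, 0.0f);
      for (int64_t t = 0; t < total_k; ++t)
        for (int64_t h = 0; h < num_heads_k; ++h)
          for (int64_t g = 0; g < group; ++g)
            for (int64_t d = 0; d < head_size; ++d)
              dk[(t * num_heads_k + h) * head_size + d] +=
                  dk_tmp[((t * num_heads_k + h) * group + g) * head_size + d];
      return dk;
    }

For plain MHA (num_heads == num_heads_k) the group axis has extent 1, and the kernels skip the fold entirely via the is_mha guard seen below.
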
@@ -316,43 +312,6 @@ void FlashAttnUnpaddedGradBaseKernel(
 
   VLOG(10) << "FlashAttn bwd seed: " << params.seed
            << ", offset: " << params.offset;
-#ifdef PADDLE_WITH_HIP
-  bool succ = phi::dynload::flash_attn_varlen_bwd(
-      dout.data(),
-      q.data(),
-      k.data(),
-      v.data(),
-      out.data(),
-      params.softmax_d.data(),
-      softmax_lse.data(),
-      cu_seqlens_q.data<int32_t>(),
-      cu_seqlens_k.data<int32_t>(),
-      params.rng_state.data(),
-      kdq->data(),
-      kdk->data(),
-      kdv->data(),
-      params.dq_accum.data(),
-      params.batch_size,
-      params.max_seqlen_q,
-      params.max_seqlen_k,
-      params.seqlen_q_rounded,
-      params.seqlen_k_rounded,
-      params.num_heads,
-      params.num_heads_k,
-      params.head_size,
-      params.head_size_rounded,
-      params.dropout,
-      params.softmax_scale,
-      1.0f / params.softmax_scale,
-      params.causal,
-      params.is_bf16,
-      num_splits,
-      stream,
-      params.seed,
-      params.offset,
-      params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr,
-      params.attn_mask_tensor ? params.mask_dims.data() : nullptr);
-#else
   bool succ = phi::dynload::flash_attn_varlen_bwd(
       dout.data(),
       q.data(),
@@ -413,56 +372,19 @@ void FlashAttnUnpaddedGradBaseKernel(
       max_seqlen_k * kdv->strides()[0],
       max_seqlen_q * dout.strides()[0],
       varlen_padded);
-#endif
   CheckFlashAttnStatus(succ);
   if (!is_mha) {
     if (dk) {
-#ifdef PADDLE_WITH_HIP
-      if (dk->meta().is_contiguous())
-        phi::SumKernel<T, Context>(
-            ctx,
-            dk_tmp.Resize(
-                {total_k, num_heads_k, num_heads / num_heads_k, head_size}),
-            {2},
-            dk->type(),
-            false,
-            dk);
-      else
-        kvReduceForGQA<T, Context>(
-            ctx,
-            dk_tmp.Resize(
-                {total_k, num_heads_k, num_heads / num_heads_k, head_size}),
-            dk);
-#else
       if (dk->meta().is_contiguous())
         phi::SumKernel<T, Context>(ctx, dk_tmp, {2}, dk->type(), false, dk);
       else
         kvReduceForGQA<T, Context>(ctx, dk_tmp, dk);
-#endif
     }
     if (dv) {
-#ifdef PADDLE_WITH_HIP
-      if (dv->meta().is_contiguous())
-        phi::SumKernel<T, Context>(
-            ctx,
-            dv_tmp.Resize(
-                {total_k, num_heads_k, num_heads / num_heads_k, head_size}),
-            {2},
-            dv->type(),
-            false,
-            dv);
-      else
-        kvReduceForGQA<T, Context>(
-            ctx,
-            dv_tmp.Resize(
-                {total_k, num_heads_k, num_heads / num_heads_k, head_size}),
-            dv);
-#else
       if (dv->meta().is_contiguous())
         phi::SumKernel<T, Context>(ctx, dv_tmp, {2}, dv->type(), false, dv);
       else
         kvReduceForGQA<T, Context>(ctx, dv_tmp, dv);
-#endif
     }
   }
 #else
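
With the HIP-only Resize variants removed, the post-backward reduction dispatch is identical on both backends: a contiguous dk/dv goes through the generic phi::SumKernel over the group axis ({2} here; the batched kernel further down sums over {3} because its layout carries leading batch and seqlen_k axes), while a non-contiguous output takes the dedicated kvReduceForGQA path, presumably because the fold must then write through the output's strides. A hypothetical sketch of such a strided-output fold; the stride parameters are assumptions, not Paddle's actual signature:

    #include <cstdint>

    // Fold the group axis of a contiguous dk_tmp buffer
    // [total_k, num_heads_k, group, head_size] into an output whose rows
    // and heads may be strided (strides given in elements). Illustrative
    // only -- not the actual kvReduceForGQA implementation.
    void ReduceForGQAStrided(const float* dk_tmp, float* dk,
                             int64_t total_k, int64_t num_heads_k,
                             int64_t group, int64_t head_size,
                             int64_t out_row_stride, int64_t out_head_stride) {
      for (int64_t t = 0; t < total_k; ++t)
        for (int64_t h = 0; h < num_heads_k; ++h)
          for (int64_t d = 0; d < head_size; ++d) {
            float acc = 0.0f;
            for (int64_t g = 0; g < group; ++g)
              acc += dk_tmp[((t * num_heads_k + h) * group + g) * head_size + d];
            dk[t * out_row_stride + h * out_head_stride + d] = acc;
          }
    }
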
@@ -658,13 +580,8 @@ void FlashAttnGradBaseKernel(
 
   bool is_mha = (num_heads == num_heads_k);
 
-#ifdef PADDLE_WITH_HIP
-  std::initializer_list<int64_t> dk_dv_shape = {
-      batch_size, seqlen_k, num_heads, head_size};
-#else
   std::initializer_list<int64_t> dk_dv_shape = {
       batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size};
-#endif
 
   DenseTensor* kdq = dq;
   DenseTensor dq_tmp;
@@ -825,7 +742,37 @@ void FlashAttnGradBaseKernel(
       params.seed,
       params.offset,
       params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr,
-      params.attn_mask_tensor ? params.mask_dims.data() : nullptr);
+      params.attn_mask_tensor ? params.mask_dims.data() : nullptr,
+      is_flashmask ? downstart_row_indices_data : nullptr,
+      is_flashmask ? params.startend_row_indices_dims.data() : nullptr,
+      is_flashmask ? upend_row_indices_data : nullptr,
+      is_flashmask ? downend_row_indices_data : nullptr,
+      is_flashmask ? upstart_row_indices_data : nullptr,
+      is_flashmask ? flashmask_maxmin.data() : nullptr,
+      q.strides()[1],
+      k.strides()[1],
+      v.strides()[1],
+      q.strides()[2],
+      k.strides()[2],
+      v.strides()[2],
+      out.strides()[1],
+      out.strides()[2],
+      q.strides()[0],
+      k.strides()[0],
+      v.strides()[0],
+      out.strides()[0],
+      kdq->strides()[1],
+      kdk->strides()[1],
+      kdv->strides()[1],
+      kdq->strides()[2],
+      kdk->strides()[kdk->strides().size() - 2],
+      kdv->strides()[kdv->strides().size() - 2],
+      dout.strides()[1],
+      dout.strides()[2],
+      kdq->strides()[0],
+      kdk->strides()[0],
+      kdv->strides()[0],
+      dout.strides()[0]);
 #else
   bool succ;
   int arch =
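
The arguments appended to the call fall into two groups: nullptr-guarded pointers for the optional flashmask row-index tensors (a null pointer signals the feature is absent), and per-axis element strides for q/k/v/out and the gradient buffers, which let the kernel address non-contiguous tensors in place. A minimal sketch of that stride-based addressing, assuming a [batch, seqlen, num_heads, head_size] layout matching the stride order passed above:

    #include <cstdint>

    // Locate element (b, s, h, d) of a possibly strided 4-D tensor given
    // its batch/seq/head strides in elements; the head_size axis is
    // assumed dense. Illustrative only -- not the kernel's actual helper.
    inline const float* ElementAt(const float* base, int64_t b, int64_t s,
                                  int64_t h, int64_t d, int64_t batch_stride,
                                  int64_t seq_stride, int64_t head_stride) {
      return base + b * batch_stride + s * seq_stride + h * head_stride + d;
    }
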
@@ -981,63 +928,17 @@ void FlashAttnGradBaseKernel(
   CheckFlashAttnStatus(succ);
   if (!is_mha) {
     if (dk) {
-#ifdef PADDLE_WITH_HIP
-      if (dk->meta().is_contiguous())
-        phi::SumKernel<T, Context>(ctx,
-                                   dk_tmp.Resize({batch_size,
-                                                  seqlen_k,
-                                                  num_heads_k,
-                                                  num_heads / num_heads_k,
-                                                  head_size}),
-                                   {3},
-                                   dk->type(),
-                                   false,
-                                   dk);
-      else
-        kvReduceBatchedForGQA<T, Context>(
-            ctx,
-            dk_tmp.Resize({batch_size,
-                           seqlen_k,
-                           num_heads_k,
-                           num_heads / num_heads_k,
-                           head_size}),
-            dk);
-#else
       if (dk->meta().is_contiguous())
         phi::SumKernel<T, Context>(ctx, dk_tmp, {3}, dk->type(), false, dk);
       else
         kvReduceBatchedForGQA<T, Context>(ctx, dk_tmp, dk);
-#endif
     }
 
     if (dv) {
-#ifdef PADDLE_WITH_HIP
-      if (dv->meta().is_contiguous())
-        phi::SumKernel<T, Context>(ctx,
-                                   dv_tmp.Resize({batch_size,
-                                                  seqlen_k,
-                                                  num_heads_k,
-                                                  num_heads / num_heads_k,
-                                                  head_size}),
-                                   {3},
-                                   dv->type(),
-                                   false,
-                                   dv);
-      else
-        kvReduceBatchedForGQA<T, Context>(
-            ctx,
-            dv_tmp.Resize({batch_size,
-                           seqlen_k,
-                           num_heads_k,
-                           num_heads / num_heads_k,
-                           head_size}),
-            dv);
-#else
       if (dv->meta().is_contiguous())
         phi::SumKernel<T, Context>(ctx, dv_tmp, {3}, dv->type(), false, dv);
       else
         kvReduceBatchedForGQA<T, Context>(ctx, dv_tmp, dv);
-#endif
     }
   }
 #else