@@ -12,13 +12,13 @@

torch.library.define(
"blackwell_fmha::fmha_fwd",
"(Tensor q, Tensor k, Tensor v, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, Tensor? seqlen_kv) -> (Tensor, Tensor)",
"(Tensor q, Tensor k, Tensor v, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, Tensor? seqlen_kv, Tensor? page_table, int seqlen_k=-1, int window_size_left=-1, int window_size_right=-1, bool bottom_right=True) -> (Tensor, Tensor)",
tags=torch.Tag.pt2_compliant_tag,
)

torch.library.define(
"blackwell_fmha::fmha_bwd",
"(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, bool? causal) -> (Tensor, Tensor, Tensor)",
"(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, int window_size_left=-1, int window_size_right=-1, bool bottom_right=True, bool deterministic=False) -> (Tensor, Tensor, Tensor)",
tags=torch.Tag.pt2_compliant_tag,
)

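The forward's new trailing arguments all carry schema-level defaults (seqlen_k=-1, window_size_left=-1, window_size_right=-1, bottom_right=True), so existing forward call sites keep dispatching unchanged; the backward additionally gains a required softmax_scale argument, so its callers must be updated, as the _backward hook later in this diff is. A minimal sketch of a sliding-window call through the updated forward op; the shapes and dtype are illustrative assumptions, not values taken from this diff:

import torch

# Illustrative layout assumption: (batch, seqlen, heads, head_dim).
q = torch.randn(2, 128, 8, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Optional-typed args without schema defaults must still be passed explicitly;
# the new paging/window args may be omitted because the schema defaults them.
out, softmax_lse = torch.ops.blackwell_fmha.fmha_fwd(
    q, k, v,
    cu_seqlens_q=None, cu_seqlens_k=None,
    max_seq_len_q=None, max_seq_len_k=None,
    softmax_scale=None, causal=True, seqlen_kv=None,
    page_table=None,
    window_size_left=256, window_size_right=0,  # sliding-window attention
)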
@@ -35,13 +35,19 @@ def custom_op_fmha(
softmax_scale: Optional[float] = None,
causal: bool = False,
seqlen_kv: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
    seqlen_k: int = -1,  # match the schema's non-optional "int seqlen_k=-1"
window_size_left: int = -1,
window_size_right: int = -1,
bottom_right: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
assert q.is_contiguous(), "q is not contiguous"
assert k.is_contiguous(), "k is not contiguous"
assert v.is_contiguous(), "v is not contiguous"
assert q.is_cuda, "q must be on GPU"
assert k.is_cuda, "k must be on GPU"
assert v.is_cuda, "v must be on GPU"

return torch.ops.fbgemm.fmha_fwd(
q,
k,
@@ -53,6 +59,11 @@ def custom_op_fmha(
softmax_scale=softmax_scale,
causal=causal,
seqlen_kv=seqlen_kv,
page_table=page_table,
seqlen_k=seqlen_k,
window_size_left=window_size_left,
window_size_right=window_size_right,
bottom_right=bottom_right,
)


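custom_op_fmha asserts contiguous CUDA inputs rather than repairing them, so a caller that may hold non-contiguous views (for example after a transpose) is expected to normalize layout and device first. A hedged caller-side sketch of that guard, reusing q, k, v from the sketch above; the helper name is hypothetical and not part of this diff:

import torch

def _to_kernel_layout(t: torch.Tensor) -> torch.Tensor:
    # The op asserts t.is_cuda and t.is_contiguous() instead of copying,
    # so materialize a contiguous CUDA tensor up front.
    if not t.is_cuda:
        t = t.cuda()
    return t if t.is_contiguous() else t.contiguous()

q, k, v = (_to_kernel_layout(t) for t in (q, k, v))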
@@ -68,6 +79,11 @@ def fmha_fwd_meta(
softmax_scale: Optional[float] = None,
causal: bool = False,
seqlen_kv: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
    seqlen_k: int = -1,  # match the schema's non-optional "int seqlen_k=-1"
window_size_left: int = -1,
window_size_right: int = -1,
bottom_right: bool = True,
):
if q.dtype == torch.float16:
out_dtype = torch.float16
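fmha_fwd_meta is the implementation torch.compile traces with fake tensors, so it must derive output dtype and shape from input metadata alone, never launching the kernel. Most of the body is collapsed in this view; the following is an assumed sketch of the dtype-selection pattern the visible lines begin, for illustration only, not the actual collapsed code:

# Assumed continuation pattern (illustrative, not the collapsed body):
if q.dtype == torch.float16:
    out_dtype = torch.float16
elif q.dtype == torch.bfloat16:
    out_dtype = torch.bfloat16
else:
    raise ValueError(f"Unsupported dtype {q.dtype}")
out = q.new_empty(q.shape, dtype=out_dtype)  # metadata-only allocation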
@@ -122,8 +138,14 @@ def custom_op_fmha_bwd(
cu_seqlens_k: Optional[torch.Tensor] = None,
max_seq_len_q: Optional[int] = None,
max_seq_len_k: Optional[int] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size_left: int = -1,
window_size_right: int = -1,
bottom_right: bool = True,
deterministic: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

return torch.ops.fbgemm.fmha_bwd(
dOutput,
query,
@@ -135,7 +157,12 @@ def custom_op_fmha_bwd(
cu_seqlens_k=cu_seqlens_k,
max_seq_len_q=max_seq_len_q,
max_seq_len_k=max_seq_len_k,
softmax_scale=softmax_scale,
causal=causal,
window_size_left=window_size_left,
window_size_right=window_size_right,
bottom_right=bottom_right,
deterministic=deterministic,
)


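Like the forward wrapper, this simply forwards to the fbgemm kernel binding. The new deterministic flag presumably follows the FlashAttention convention of trading backward throughput for bitwise-reproducible dq/dk/dv accumulation; that is an assumption, not stated in this diff. A hedged sketch of a direct call, reusing q, k, v, out, and softmax_lse from the forward sketch above (normally autograd invokes this path instead):

dout = torch.randn_like(out)  # upstream gradient, illustrative
dq, dk, dv = torch.ops.blackwell_fmha.fmha_bwd(
    dout, q, k, v, out, softmax_lse,
    cu_seqlens_q=None, cu_seqlens_k=None,
    max_seq_len_q=None, max_seq_len_k=None,
    softmax_scale=None, causal=True,
    deterministic=True,
)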
@@ -151,7 +178,12 @@ def fmha_bwd_meta(
cu_seqlens_k: Optional[torch.Tensor] = None,
max_seq_len_q: Optional[int] = None,
max_seq_len_k: Optional[int] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size_left: int = -1,
window_size_right: int = -1,
bottom_right: bool = True,
deterministic: bool = False,
):
return (
torch.empty_like(query),
@@ -198,9 +230,30 @@ def _backward(ctx, *grad):
ctx.cu_seqlens_k,
ctx.max_seq_len_q,
ctx.max_seq_len_k,
ctx.softmax_scale,
ctx.causal,
ctx.window_size_left,
ctx.window_size_right,
ctx.bottom_right,
ctx.deterministic,
)
return (
dq,
dk,
dv,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
return dq, dk, dv, None, None, None, None, None, None, None
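Autograd expects _backward to hand back exactly one gradient per forward input, so widening fmha_fwd from ten arguments to fifteen also widens this tuple: dq, dk, dv plus twelve Nones for the non-differentiable arguments. A hedged sketch of a helper, hypothetical and not in this diff, that keeps the arity from drifting if the schema grows again:

def _pad_grads(dq, dk, dv, num_fwd_inputs=15):
    # One entry per forward input: three real gradients, then None for the
    # non-differentiable args (seqlens, scale, flags, page_table, windows).
    return (dq, dk, dv) + (None,) * (num_fwd_inputs - 3)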


def _setup_context(ctx, inputs, output):
@@ -215,6 +268,11 @@ def _setup_context(ctx, inputs, output):
softmax_scale,
causal,
seqlen_kv,
page_table,
seqlen_k,
window_size_left,
window_size_right,
bottom_right,
) = inputs
(out, softmax_lse) = output
ctx.save_for_backward(q, k, v, out, softmax_lse)
@@ -224,6 +282,10 @@ def _setup_context(ctx, inputs, output):
ctx.max_seq_len_k = max_seq_len_k
ctx.cu_seqlens_q = cu_seqlens_q
ctx.cu_seqlens_k = cu_seqlens_k
ctx.window_size_left = window_size_left
ctx.window_size_right = window_size_right
ctx.bottom_right = bottom_right
    ctx.deterministic = False  # backward-only knob, never a forward input, so pin the default here
ctx.is_gen = False


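deterministic is pinned to False in _setup_context because it never appears among the forward inputs, so there is nothing to capture for it. The _setup_context(ctx, inputs, output) and _backward(ctx, *grad) signatures match PyTorch's custom-op autograd hooks, so the wiring presumably looks like the sketch below; the registration call itself is not shown in this diff:

torch.library.register_autograd(
    "blackwell_fmha::fmha_fwd",
    _backward,
    setup_context=_setup_context,
)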
@@ -246,6 +308,11 @@ def cutlass_blackwell_fmha_custom_op(
max_seq_len_q: int | None = None,
max_seq_len_k: int | None = None,
seqlen_kv: torch.Tensor | None = None,
page_table: torch.Tensor | None = None,
    seqlen_k: int = -1,  # schema types these as non-optional int/bool with defaults, so None is never valid
    window_size_left: int = -1,
    window_size_right: int = -1,
    bottom_right: bool = True,
):
return torch.ops.blackwell_fmha.fmha_fwd(
q=q,
@@ -258,4 +325,9 @@ def cutlass_blackwell_fmha_custom_op(
softmax_scale=softmax_scale,
causal=causal,
seqlen_kv=seqlen_kv,
page_table=page_table,
seqlen_k=seqlen_k,
window_size_left=window_size_left,
window_size_right=window_size_right,
bottom_right=bottom_right,
)[0]
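Taken together, a hedged end-to-end sketch of the new surface; the shapes, dtype, and window values are illustrative assumptions:

import torch

q = torch.randn(2, 128, 8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# Causal attention restricted to a 256-token lookback window.
out = cutlass_blackwell_fmha_custom_op(
    q, k, v, causal=True, window_size_left=256, window_size_right=0
)
out.sum().backward()  # dispatches to fmha_bwd through the registered autograd rule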