|
20 | 20 | create_block_mask, |
21 | 21 | flex_attention, |
22 | 22 | ) |
| 23 | +from vllm.vllm_flash_attn import flash_attn_varlen_func |
23 | 24 |
|
24 | 25 |
|
25 | 26 | __all__ = [ |
@@ -103,6 +104,69 @@ def forward( |
103 | 104 | with sdpa_kernel(self.sdpa_backends, set_priority=True): |
104 | 105 | return F.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=True) |
105 | 106 |
|
| 107 | +class VLLMCompatibleFlashAttention(torch.nn.Module): |
| 108 | + """Wrapper around FlashAttention's varlen kernel as used by vLLM.""" |
| 109 | + def __init__(self) -> None: |
| 110 | + super().__init__() |
| 111 | + self.flash_attn_varlen_func = flash_attn_varlen_func |
| 112 | + from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant |
| 113 | + self.vllm_is_batch_invariant = vllm_is_batch_invariant |
| 114 | + |
| 115 | + def forward( |
| 116 | + self, |
| 117 | + q: torch.Tensor, |
| 118 | + k: torch.Tensor, |
| 119 | + v: torch.Tensor, |
| 120 | + *, |
| 121 | + scale: float | None = None, |
| 122 | + ) -> torch.Tensor | tuple[torch.Tensor, AuxOutput]: |
| 123 | + # Flash Attention varlen expects packed inputs of shape |
| 124 | + # (total_tokens, nheads, headdim), while the input from TorchTitan is |
| 125 | + # always (batch, num_heads, seq_len, head_dim). |
| 126 | + |
| 127 | + # Transpose to (batch, seq_len, num_heads, head_dim) before flattening |
| 128 | + q = q.transpose(1, 2) # -> (batch, seq_len, num_heads, head_dim) |
| 129 | + k = k.transpose(1, 2) |
| 130 | + v = v.transpose(1, 2) |
| 131 | + |
| 132 | + # Get dimensions |
| 133 | + batch_size, seq_len, num_heads, head_dim = q.shape |
| 134 | + |
| 135 | + # Convert to varlen format: flatten batch and sequence dimensions |
| 136 | + # (batch, seqlen, nheads, headdim) -> (total_tokens, nheads, headdim) |
| 137 | + q_varlen = q.reshape(-1, num_heads, head_dim) |
| 138 | + k_varlen = k.reshape(-1, k.shape[2], head_dim)  # k/v may have fewer heads than q (GQA) |
| 139 | + v_varlen = v.reshape(-1, v.shape[2], head_dim) |
| 140 | + |
| 141 | + # Create cumulative sequence lengths |
| 142 | + # cu_seqlens: [0, seq_len, 2*seq_len, ..., batch_size*seq_len] |
| 143 | + cu_seqlens = torch.arange( |
| 144 | + 0, (batch_size + 1) * seq_len, seq_len, |
| 145 | + dtype=torch.int32, device=q.device |
| 146 | + ) |
| 147 | + |
| 148 | + # Call Flash Attention varlen (works with both standard flash-attn and vLLM's wrapper) |
| 149 | + output_varlen = self.flash_attn_varlen_func( |
| 150 | + q_varlen, k_varlen, v_varlen, |
| 151 | + cu_seqlens_q=cu_seqlens, |
| 152 | + cu_seqlens_k=cu_seqlens, |
| 153 | + max_seqlen_q=seq_len, |
| 154 | + max_seqlen_k=seq_len, |
| 155 | + softmax_scale=scale, |
| 156 | + causal=True, |
| 157 | + num_splits=1 if self.vllm_is_batch_invariant() else 0, |
| 158 | + ) |
| 159 | + |
| 160 | + # Convert back to batch format |
| 161 | + # (total_tokens, nheads, headdim) -> (batch, seqlen, nheads, headdim) |
| 162 | + output = output_varlen.reshape(batch_size, seq_len, num_heads, head_dim) |
| 163 | + |
| 164 | + # Transpose back to (batch, num_heads, seq_len, head_dim) to match input format |
| 165 | + output = output.transpose(1, 2) |
| 166 | + |
| 167 | + return output |
| 168 | + |
| 169 | + |
106 | 170 |
|
107 | 171 | # We cannot do inner function/closure because we won't be able to cache it -- |
108 | 172 | # if we use an inner function, a new closure will be created every time
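For reference, a minimal self-contained sketch (not part of the diff) of the fixed-length-to-varlen packing that `forward` performs; the sizes below are arbitrary example values and only PyTorch is required:

```python
import torch

# Arbitrary example sizes (assumptions for illustration, not values from the diff).
batch_size, seq_len, num_heads, head_dim = 2, 3, 4, 8
q = torch.randn(batch_size, num_heads, seq_len, head_dim)

# (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, num_heads, head_dim)
q = q.transpose(1, 2)

# Flatten batch and sequence into one packed "total_tokens" dimension.
q_varlen = q.reshape(-1, num_heads, head_dim)
print(q_varlen.shape)   # torch.Size([6, 4, 8])

# cu_seqlens marks where each (equal-length) sequence starts in the packed tensor.
cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32)
print(cu_seqlens)       # tensor([0, 3, 6], dtype=torch.int32)
```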
|
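A hedged usage sketch, assuming vLLM (with its bundled flash-attn kernels) and a CUDA device are available, and that `VLLMCompatibleFlashAttention` from the diff above is in scope; it cross-checks the wrapper against PyTorch's SDPA, which computes the same causal attention over the same layout:

```python
import torch
import torch.nn.functional as F

attn = VLLMCompatibleFlashAttention()

# Arbitrary example shapes: (batch, num_heads, seq_len, head_dim).
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = attn(q, k, v)  # output keeps the (batch, num_heads, seq_len, head_dim) layout
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print((out - ref).abs().max())  # expected to be small (bf16 rounding noise)
```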