from typing import Final, Optional, Type

import torch
from torch import nn as nn
from torch.nn import functional as F

from .config import use_fused_attn
from .pos_embed_sincos import apply_rot_embed_cat


class Attention(nn.Module):
    """Standard Multi-head Self Attention module with QKV projection.

    This module implements the standard multi-head attention mechanism used in transformers.
    It supports both the fused attention implementation (scaled_dot_product_attention) for
    efficiency when available, and a manual implementation otherwise. The module includes
    options for QK normalization, attention dropout, and projection dropout.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            proj_bias: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        """Initialize the Attention module.

        Args:
            dim: Input dimension of the token embeddings
            num_heads: Number of attention heads
            qkv_bias: Whether to use bias in the query, key, value projections
            qk_norm: Whether to apply normalization to query and key vectors
            proj_bias: Whether to use bias in the output projection
            attn_drop: Dropout rate applied to the attention weights
            proj_drop: Dropout rate applied after the output projection
            norm_layer: Normalization layer constructor for QK normalization if enabled
        """
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            if attn_mask is not None:
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


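# Example usage of Attention (an illustrative sketch; the ViT-Base style shapes
# below are assumptions, not taken from this file):
#
#   attn = Attention(dim=768, num_heads=12, qkv_bias=True)
#   x = torch.randn(2, 197, 768)   # (batch, tokens, dim)
#   y = attn(x)                    # output keeps the input shape: (2, 197, 768)
#
# In the non-fused path above, `attn_mask` is added to the attention logits
# before softmax, so it is treated there as an additive float mask.

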
class AttentionRope(nn.Module):
    """A Self Attention module with rotary position embedding (RoPE) support.

    Includes options for:
     * QK normalization
     * Attention output (scale) normalization
     * Fused or unfused QKV projection
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = True,
            qkv_fused: bool = True,
            num_prefix_tokens: int = 1,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            attn_head_dim: Optional[int] = None,
            norm_layer: Optional[Type[nn.Module]] = None,
            qk_norm: bool = False,
            scale_norm: bool = False,
    ):
        """Initialize the AttentionRope module.

        Args:
            dim: Input dimension of the token embeddings
            num_heads: Number of attention heads
            qkv_bias: Whether to add a bias term to the query, key, and value projections
            qkv_fused: Whether to use a single fused linear layer for the QKV projection
                instead of separate Q, K, and V projections
            num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that
                should not have position embeddings applied
            attn_drop: Dropout rate for attention weights
            proj_drop: Dropout rate for the output projection
            attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
            norm_layer: Normalization layer constructor to use for QK and scale normalization
            qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
            scale_norm: Enable normalization (scaling) of attention output with norm_layer
        """
        super().__init__()
        if scale_norm or qk_norm:
            assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        attn_dim = head_dim * self.num_heads
        self.scale = head_dim ** -0.5
        self.num_prefix_tokens = num_prefix_tokens
        self.fused_attn = use_fused_attn()

        if qkv_fused:
            self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
            self.q_proj = self.k_proj = self.v_proj = None
        else:
            self.qkv = None
            self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
            self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
            self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)

        self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity()
        self.proj = nn.Linear(attn_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
            self,
            x: torch.Tensor,
            rope: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        """Forward pass for the attention module.

        Args:
            x: Input tensor of shape (batch_size, sequence_length, embedding_dim)
            rope: Rotary position embeddings tensor for position-aware attention
            attn_mask: Optional attention mask to apply during attention computation

        Returns:
            Tensor of shape (batch_size, sequence_length, embedding_dim)
        """
        B, N, C = x.shape

        if self.qkv is not None:
            qkv = self.qkv(x)
            qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim
        else:
            q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)  # B, num_heads, N, head_dim
            k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
            v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)

        q, k = self.q_norm(q), self.k_norm(k)

        if rope is not None:
            # Rotate all tokens except the leading prefix (cls/reg) tokens.
            npt = self.num_prefix_tokens
            q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v)
            k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope)], dim=2).type_as(v)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))

            if attn_mask is not None:
                attn_mask = attn_mask.to(torch.bool)
                attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
            attn = attn.softmax(dim=-1)

            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, -1)
        x = self.norm(x)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
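

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the library).
    # Because of the relative imports above, run this as a module from inside
    # its package rather than as a standalone script. The shapes and the rope
    # layout below are assumptions: apply_rot_embed_cat is fed sin and cos
    # halves concatenated on the last dim, i.e. a tensor of shape
    # (seq_len - num_prefix_tokens, 2 * head_dim); zeros/ones give an
    # identity rotation.
    torch.manual_seed(0)
    B, N, dim, num_heads = 2, 8, 64, 4
    head_dim = dim // num_heads
    x = torch.randn(B, N, dim)

    attn = Attention(dim, num_heads=num_heads, qkv_bias=True, qk_norm=True)
    print(attn(x).shape)  # -> torch.Size([2, 8, 64])

    rope = torch.cat([
        torch.zeros(N - 1, head_dim),  # sin half (zero rotation angle)
        torch.ones(N - 1, head_dim),   # cos half
    ], dim=-1)
    attn_rope = AttentionRope(
        dim,
        num_heads=num_heads,
        qk_norm=True,
        scale_norm=True,
        norm_layer=nn.LayerNorm,
    )
    print(attn_rope(x, rope=rope).shape)  # -> torch.Size([2, 8, 64])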