Commit f14f650

Merge pull request #2487 from huggingface/eva_pe_integration
Add EVA ViT based PE (Perceptual Encoder) impl
2 parents: cabd26d + 88b7ef6, commit f14f650

6 files changed: +728, -100 lines changed

timm/layers/__init__.py (+2)

@@ -1,6 +1,7 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
+from .attention import Attention, AttentionRope
 from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
 from .attention_pool import AttentionPoolLatent
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding

@@ -41,6 +42,7 @@
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_dropout import PatchDropout
 from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
+from .pool1d import global_pool_nlc
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
 from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \

timm/layers/attention.py (+212)

@@ -0,0 +1,212 @@
+from typing import Final, Optional, Type
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from .config import use_fused_attn
+from .pos_embed_sincos import apply_rot_embed_cat
+
+
+class Attention(nn.Module):
+    """Standard Multi-head Self Attention module with QKV projection.
+
+    This module implements the standard multi-head attention mechanism used in transformers.
+    It supports both the fused attention implementation (scaled_dot_product_attention) for
+    efficiency when available, and a manual implementation otherwise. The module includes
+    options for QK normalization, attention dropout, and projection dropout.
+    """
+    fused_attn: Final[bool]
+
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = False,
+            qk_norm: bool = False,
+            proj_bias: bool = True,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+    ) -> None:
+        """Initialize the Attention module.
+
+        Args:
+            dim: Input dimension of the token embeddings
+            num_heads: Number of attention heads
+            qkv_bias: Whether to use bias in the query, key, value projections
+            qk_norm: Whether to apply normalization to query and key vectors
+            proj_bias: Whether to use bias in the output projection
+            attn_drop: Dropout rate applied to the attention weights
+            proj_drop: Dropout rate applied after the output projection
+            norm_layer: Normalization layer constructor for QK normalization if enabled
+        """
+        super().__init__()
+        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+        self.fused_attn = use_fused_attn()
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim, bias=proj_bias)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(
+            self,
+            x: torch.Tensor,
+            attn_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+        q, k = self.q_norm(q), self.k_norm(k)
+
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_mask,
+                dropout_p=self.attn_drop.p if self.training else 0.,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            if attn_mask is not None:
+                attn = attn + attn_mask
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
+
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class AttentionRope(nn.Module):
+    """ A Self Attention module with ROPE support.
+
+    Includes options for:
+     * QK normalization option
+     * Attention output (scale) normalization
+     * Fused or unfused QKV projection support
+    """
+    fused_attn: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = True,
+            qkv_fused: bool = True,
+            num_prefix_tokens: int = 1,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            attn_head_dim: Optional[int] = None,
+            norm_layer: Type[nn.Module] = None,
+            qk_norm: bool = False,
+            scale_norm: bool = False,
+    ):
+        """Initialize the Attention module.
+
+        Args:
+            dim: Input dimension of the token embeddings
+            num_heads: Number of attention heads
+            qkv_bias: Whether to add a bias term to the query, key, and value projections
+            num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that
+                should not have position embeddings applied
+            attn_drop: Dropout rate for attention weights
+            proj_drop: Dropout rate for the output projection
+            attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
+            norm_layer: Normalization layer constructor to use for QK and scale normalization
+            qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
+            scale_norm: Enable normalization (scaling) of attention output with norm_layer
+        """
+        super().__init__()
+        if scale_norm or qk_norm:
+            assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        if attn_head_dim is not None:
+            head_dim = attn_head_dim
+        attn_dim = head_dim * self.num_heads
+        self.scale = head_dim ** -0.5
+        self.num_prefix_tokens = num_prefix_tokens
+        self.fused_attn = use_fused_attn()
+
+        if qkv_fused:
+            self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
+            self.q_proj = self.k_proj = self.v_proj = None
+        else:
+            self.qkv = None
+            self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
+            self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
+            self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
+
+        self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity()
+        self.proj = nn.Linear(attn_dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(
+            self,
+            x,
+            rope: Optional[torch.Tensor] = None,
+            attn_mask: Optional[torch.Tensor] = None,
+    ):
+        """Forward pass for the attention module.
+
+        Args:
+            x: Input tensor of shape (batch_size, sequence_length, embedding_dim)
+            rope: Rotary position embeddings tensor for position-aware attention
+            attn_mask: Optional attention mask to apply during attention computation
+
+        Returns:
+            Tensor of shape (batch_size, sequence_length, embedding_dim)
+        """
+        B, N, C = x.shape
+
+        if self.qkv is not None:
+            qkv = self.qkv(x)
+            qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+            q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim
+        else:
+            q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)  # B, num_heads, N, head_dim
+            k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
+            v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
+
+        q, k = self.q_norm(q), self.k_norm(k)
+
+        if rope is not None:
+            npt = self.num_prefix_tokens
+            q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v)
+            k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope)], dim=2).type_as(v)
+
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_mask,
+                dropout_p=self.attn_drop.p if self.training else 0.,
+            )
+        else:
+            q = q * self.scale
+            attn = (q @ k.transpose(-2, -1))
+
+            if attn_mask is not None:
+                attn_mask = attn_mask.to(torch.bool)
+                attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
+            attn = attn.softmax(dim=-1)
+
+            attn = self.attn_drop(attn)
+            x = attn @ v
+
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.norm(x)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
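
For illustration only (not part of the commit): a minimal sketch exercising the two new modules, assuming this branch of timm is installed; the tensor shapes and hyperparameters below are arbitrary.

    import torch
    from timm.layers import Attention, AttentionRope

    x = torch.randn(2, 197, 384)  # (batch, 1 class token + 196 patch tokens, dim)

    # Plain multi-head self-attention with QK normalization enabled.
    attn = Attention(dim=384, num_heads=6, qkv_bias=True, qk_norm=True)
    print(attn(x).shape)  # torch.Size([2, 197, 384])

    # AttentionRope applies rotary embeddings (e.g. from RotaryEmbeddingCat) to all
    # tokens after the first num_prefix_tokens; with rope=None it reduces to plain attention.
    attn_rope = AttentionRope(dim=384, num_heads=6, num_prefix_tokens=1)
    print(attn_rope(x, rope=None).shape)  # torch.Size([2, 197, 384])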

timm/layers/attention_pool.py (+3, -3)

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Type

 import torch
 import torch.nn as nn

@@ -28,8 +28,8 @@ def __init__(
             latent_dim: int = None,
             pos_embed: str = '',
             pool_type: str = 'token',
-            norm_layer: Optional[nn.Module] = None,
-            act_layer: Optional[nn.Module] = nn.GELU,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            act_layer: Optional[Type[nn.Module]] = nn.GELU,
             drop: float = 0.0,
     ):
         super().__init__()

timm/layers/pool1d.py (+26)

@@ -0,0 +1,26 @@
+import torch
+
+
+def global_pool_nlc(
+        x: torch.Tensor,
+        pool_type: str = 'token',
+        num_prefix_tokens: int = 1,
+        reduce_include_prefix: bool = False,
+):
+    if not pool_type:
+        return x
+
+    if pool_type == 'token':
+        x = x[:, 0]  # class token
+    else:
+        x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
+        if pool_type == 'avg':
+            x = x.mean(dim=1)
+        elif pool_type == 'avgmax':
+            x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
+        elif pool_type == 'max':
+            x = x.amax(dim=1)
+        else:
+            assert not pool_type, f'Unknown pool type {pool_type}'
+
+    return x
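
Again for illustration only (not part of the commit), the pooling modes of the new global_pool_nlc helper on an NLC (batch, tokens, channels) tensor with a single class-token prefix; shapes are arbitrary:

    import torch
    from timm.layers import global_pool_nlc

    x = torch.randn(2, 197, 384)
    print(global_pool_nlc(x, pool_type='token').shape)   # torch.Size([2, 384]) - class token only
    print(global_pool_nlc(x, pool_type='avg').shape)     # torch.Size([2, 384]) - mean over patch tokens
    print(global_pool_nlc(x, pool_type='avgmax').shape)  # torch.Size([2, 384]) - 0.5 * (max + mean)
    print(global_pool_nlc(x, pool_type='').shape)        # torch.Size([2, 197, 384]) - no pooling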
