Skip to content

Commit 7da34a9

Browse files
lizhuoq authored and rwightman committed
add type annotations in the code of swin_transformer_v2
1 parent bbe7983 commit 7da34a9

File tree

1 file changed

+51
-49
lines changed

1 file changed

+51
-49
lines changed

timm/models/swin_transformer_v2.py

Lines changed: 51 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
# Written by Ze Liu
1414
# --------------------------------------------------------
1515
import math
16-
from typing import Callable, Optional, Tuple, Union
16+
from typing import Callable, Optional, Tuple, Union, Set, Dict
1717

1818
import torch
1919
import torch.nn as nn
@@ -32,7 +32,7 @@
3232
_int_or_tuple_2_t = Union[int, Tuple[int, int]]
3333

3434

35-
def window_partition(x, window_size: Tuple[int, int]):
35+
def window_partition(x: torch.Tensor, window_size: Tuple[int, int]) -> torch.Tensor:
3636
"""
3737
Args:
3838
x: (B, H, W, C)
@@ -48,7 +48,7 @@ def window_partition(x, window_size: Tuple[int, int]):
4848

4949

5050
@register_notrace_function # reason: int argument is a Proxy
51-
def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
51+
def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], img_size: Tuple[int, int]) -> torch.Tensor:
5252
"""
5353
Args:
5454
windows: (num_windows * B, window_size[0], window_size[1], C)
@@ -81,14 +81,14 @@ class WindowAttention(nn.Module):
8181

8282
def __init__(
8383
self,
84-
dim,
85-
window_size,
86-
num_heads,
87-
qkv_bias=True,
88-
attn_drop=0.,
89-
proj_drop=0.,
90-
pretrained_window_size=[0, 0],
91-
):
84+
dim: int,
85+
window_size: Tuple[int, int],
86+
num_heads: int,
87+
qkv_bias: bool = True,
88+
attn_drop: float = 0.,
89+
proj_drop: float = 0.,
90+
pretrained_window_size: Tuple[int, int] = (0, 0),
91+
) -> None:
9292
super().__init__()
9393
self.dim = dim
9494
self.window_size = window_size # Wh, Ww
@@ -149,7 +149,7 @@ def __init__(
149149
self.proj_drop = nn.Dropout(proj_drop)
150150
self.softmax = nn.Softmax(dim=-1)
151151

152-
def forward(self, x, mask: Optional[torch.Tensor] = None):
152+
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
153153
"""
154154
Args:
155155
x: input features with shape of (num_windows*B, N, C)
@@ -197,20 +197,20 @@ class SwinTransformerV2Block(nn.Module):
197197

198198
def __init__(
199199
self,
200-
dim,
201-
input_resolution,
202-
num_heads,
203-
window_size=7,
204-
shift_size=0,
205-
mlp_ratio=4.,
206-
qkv_bias=True,
207-
proj_drop=0.,
208-
attn_drop=0.,
209-
drop_path=0.,
210-
act_layer=nn.GELU,
211-
norm_layer=nn.LayerNorm,
212-
pretrained_window_size=0,
213-
):
200+
dim: int,
201+
input_resolution: _int_or_tuple_2_t,
202+
num_heads: int,
203+
window_size: _int_or_tuple_2_t = 7,
204+
shift_size: _int_or_tuple_2_t = 0,
205+
mlp_ratio: float = 4.,
206+
qkv_bias: bool = True,
207+
proj_drop: float = 0.,
208+
attn_drop: float = 0.,
209+
drop_path: float = 0.,
210+
act_layer: nn.Module = nn.GELU,
211+
norm_layer: nn.Module = nn.LayerNorm,
212+
pretrained_window_size: _int_or_tuple_2_t = 0,
213+
) -> None:
214214
"""
215215
Args:
216216
dim: Number of input channels.
@@ -282,14 +282,16 @@ def __init__(
282282

283283
self.register_buffer("attn_mask", attn_mask, persistent=False)
284284

285-
def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
285+
def _calc_window_shift(self,
286+
target_window_size: _int_or_tuple_2_t,
287+
target_shift_size: _int_or_tuple_2_t) -> Tuple[Tuple[int, int], Tuple[int, int]]:
286288
target_window_size = to_2tuple(target_window_size)
287289
target_shift_size = to_2tuple(target_shift_size)
288290
window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
289291
shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
290292
return tuple(window_size), tuple(shift_size)
291293

292-
def _attn(self, x):
294+
def _attn(self, x: torch.Tensor) -> torch.Tensor:
293295
B, H, W, C = x.shape
294296

295297
# cyclic shift
@@ -317,7 +319,7 @@ def _attn(self, x):
317319
x = shifted_x
318320
return x
319321

320-
def forward(self, x):
322+
def forward(self, x: torch.Tensor) -> torch.Tensor:
321323
B, H, W, C = x.shape
322324
x = x + self.drop_path1(self.norm1(self._attn(x)))
323325
x = x.reshape(B, -1, C)
@@ -330,7 +332,7 @@ class PatchMerging(nn.Module):
330332
""" Patch Merging Layer.
331333
"""
332334

333-
def __init__(self, dim, out_dim=None, norm_layer=nn.LayerNorm):
335+
def __init__(self, dim: int, out_dim: Optional[int] = None, norm_layer: nn.Module = nn.LayerNorm) -> None:
334336
"""
335337
Args:
336338
dim (int): Number of input channels.
@@ -343,7 +345,7 @@ def __init__(self, dim, out_dim=None, norm_layer=nn.LayerNorm):
343345
self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)
344346
self.norm = norm_layer(self.out_dim)
345347

346-
def forward(self, x):
348+
def forward(self, x: torch.Tensor) -> torch.Tensor:
347349
B, H, W, C = x.shape
348350
_assert(H % 2 == 0, f"x height ({H}) is not even.")
349351
_assert(W % 2 == 0, f"x width ({W}) is not even.")
@@ -359,22 +361,22 @@ class SwinTransformerV2Stage(nn.Module):
359361

360362
def __init__(
361363
self,
362-
dim,
363-
out_dim,
364-
input_resolution,
365-
depth,
366-
num_heads,
367-
window_size,
368-
downsample=False,
369-
mlp_ratio=4.,
370-
qkv_bias=True,
371-
proj_drop=0.,
372-
attn_drop=0.,
373-
drop_path=0.,
374-
norm_layer=nn.LayerNorm,
375-
pretrained_window_size=0,
376-
output_nchw=False,
377-
):
364+
dim: int,
365+
out_dim: int,
366+
input_resolution: _int_or_tuple_2_t,
367+
depth: int,
368+
num_heads: int,
369+
window_size: _int_or_tuple_2_t,
370+
downsample: bool = False,
371+
mlp_ratio: float = 4.,
372+
qkv_bias: bool = True,
373+
proj_drop: float = 0.,
374+
attn_drop: float = 0.,
375+
drop_path: float = 0.,
376+
norm_layer: nn.Module = nn.LayerNorm,
377+
pretrained_window_size: _int_or_tuple_2_t = 0,
378+
output_nchw: bool = False,
379+
) -> None:
378380
"""
379381
Args:
380382
dim: Number of input channels.
@@ -428,7 +430,7 @@ def __init__(
428430
)
429431
for i in range(depth)])
430432

431-
def forward(self, x):
433+
def forward(self, x: torch.Tensor) -> torch.Tensor:
432434
x = self.downsample(x)
433435

434436
for blk in self.blocks:
@@ -438,7 +440,7 @@ def forward(self, x):
438440
x = blk(x)
439441
return x
440442

441-
def _init_respostnorm(self):
443+
def _init_respostnorm(self) -> None:
442444
for blk in self.blocks:
443445
nn.init.constant_(blk.norm1.bias, 0)
444446
nn.init.constant_(blk.norm1.weight, 0)

0 commit comments

Comments (0)