13
13
# Written by Ze Liu
14
14
# --------------------------------------------------------
15
15
import math
16
- from typing import Callable , Optional , Tuple , Union
16
+ from typing import Callable , Optional , Tuple , Union , Set , Dict
17
17
18
18
import torch
19
19
import torch .nn as nn
32
32
_int_or_tuple_2_t = Union [int , Tuple [int , int ]]
33
33
34
34
35
- def window_partition (x , window_size : Tuple [int , int ]):
35
+ def window_partition (x : torch . Tensor , window_size : Tuple [int , int ]) -> torch . Tensor :
36
36
"""
37
37
Args:
38
38
x: (B, H, W, C)
@@ -48,7 +48,7 @@ def window_partition(x, window_size: Tuple[int, int]):
48
48
49
49
50
50
@register_notrace_function # reason: int argument is a Proxy
51
- def window_reverse (windows , window_size : Tuple [int , int ], img_size : Tuple [int , int ]):
51
+ def window_reverse (windows : torch . Tensor , window_size : Tuple [int , int ], img_size : Tuple [int , int ]) -> torch . Tensor :
52
52
"""
53
53
Args:
54
54
windows: (num_windows * B, window_size[0], window_size[1], C)
@@ -81,14 +81,14 @@ class WindowAttention(nn.Module):
81
81
82
82
def __init__ (
83
83
self ,
84
- dim ,
85
- window_size ,
86
- num_heads ,
87
- qkv_bias = True ,
88
- attn_drop = 0. ,
89
- proj_drop = 0. ,
90
- pretrained_window_size = [ 0 , 0 ] ,
91
- ):
84
+ dim : int ,
85
+ window_size : Tuple [ int , int ] ,
86
+ num_heads : int ,
87
+ qkv_bias : bool = True ,
88
+ attn_drop : float = 0. ,
89
+ proj_drop : float = 0. ,
90
+ pretrained_window_size : Tuple [ int , int ] = ( 0 , 0 ) ,
91
+ ) -> None :
92
92
super ().__init__ ()
93
93
self .dim = dim
94
94
self .window_size = window_size # Wh, Ww
@@ -149,7 +149,7 @@ def __init__(
149
149
self .proj_drop = nn .Dropout (proj_drop )
150
150
self .softmax = nn .Softmax (dim = - 1 )
151
151
152
- def forward (self , x , mask : Optional [torch .Tensor ] = None ):
152
+ def forward (self , x : torch . Tensor , mask : Optional [torch .Tensor ] = None ) -> torch . Tensor :
153
153
"""
154
154
Args:
155
155
x: input features with shape of (num_windows*B, N, C)
@@ -197,20 +197,20 @@ class SwinTransformerV2Block(nn.Module):
197
197
198
198
def __init__ (
199
199
self ,
200
- dim ,
201
- input_resolution ,
202
- num_heads ,
203
- window_size = 7 ,
204
- shift_size = 0 ,
205
- mlp_ratio = 4. ,
206
- qkv_bias = True ,
207
- proj_drop = 0. ,
208
- attn_drop = 0. ,
209
- drop_path = 0. ,
210
- act_layer = nn .GELU ,
211
- norm_layer = nn .LayerNorm ,
212
- pretrained_window_size = 0 ,
213
- ):
200
+ dim : int ,
201
+ input_resolution : _int_or_tuple_2_t ,
202
+ num_heads : int ,
203
+ window_size : _int_or_tuple_2_t = 7 ,
204
+ shift_size : _int_or_tuple_2_t = 0 ,
205
+ mlp_ratio : float = 4. ,
206
+ qkv_bias : bool = True ,
207
+ proj_drop : float = 0. ,
208
+ attn_drop : float = 0. ,
209
+ drop_path : float = 0. ,
210
+ act_layer : nn . Module = nn .GELU ,
211
+ norm_layer : nn . Module = nn .LayerNorm ,
212
+ pretrained_window_size : _int_or_tuple_2_t = 0 ,
213
+ ) -> None :
214
214
"""
215
215
Args:
216
216
dim: Number of input channels.
@@ -282,14 +282,16 @@ def __init__(
282
282
283
283
self .register_buffer ("attn_mask" , attn_mask , persistent = False )
284
284
285
- def _calc_window_shift (self , target_window_size , target_shift_size ) -> Tuple [Tuple [int , int ], Tuple [int , int ]]:
285
def _calc_window_shift(
        self,
        target_window_size: _int_or_tuple_2_t,
        target_shift_size: _int_or_tuple_2_t,
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Fit the requested window and shift sizes to ``self.input_resolution``.

    Both arguments are first passed through ``to_2tuple``. Then, per axis:

    * the window size is the requested size, unless the input resolution
      on that axis is smaller or equal, in which case the resolution is
      used instead;
    * the shift size is the requested shift, unless the input resolution
      on that axis is no larger than the (already fitted) window size, in
      which case the shift is 0.

    Args:
        target_window_size: Requested window size.
        target_shift_size: Requested shift size.

    Returns:
        A ``(window_size, shift_size)`` pair of tuples.
    """
    requested_window = to_2tuple(target_window_size)
    requested_shift = to_2tuple(target_shift_size)

    # First pass: cap each window dimension at the matching input resolution.
    fitted_window = tuple(
        res if res <= win else win
        for res, win in zip(self.input_resolution, requested_window)
    )

    # Second pass: compare against the fitted window, not the requested one.
    fitted_shift = tuple(
        0 if res <= win else shift
        for res, win, shift in zip(self.input_resolution, fitted_window, requested_shift)
    )

    return fitted_window, fitted_shift
291
293
292
- def _attn (self , x ) :
294
+ def _attn (self , x : torch . Tensor ) -> torch . Tensor :
293
295
B , H , W , C = x .shape
294
296
295
297
# cyclic shift
@@ -317,7 +319,7 @@ def _attn(self, x):
317
319
x = shifted_x
318
320
return x
319
321
320
- def forward (self , x ) :
322
+ def forward (self , x : torch . Tensor ) -> torch . Tensor :
321
323
B , H , W , C = x .shape
322
324
x = x + self .drop_path1 (self .norm1 (self ._attn (x )))
323
325
x = x .reshape (B , - 1 , C )
@@ -330,7 +332,7 @@ class PatchMerging(nn.Module):
330
332
""" Patch Merging Layer.
331
333
"""
332
334
333
- def __init__ (self , dim , out_dim = None , norm_layer = nn .LayerNorm ):
335
+ def __init__ (self , dim : int , out_dim : Optional [ int ] = None , norm_layer : nn . Module = nn .LayerNorm ) -> None :
334
336
"""
335
337
Args:
336
338
dim (int): Number of input channels.
@@ -343,7 +345,7 @@ def __init__(self, dim, out_dim=None, norm_layer=nn.LayerNorm):
343
345
self .reduction = nn .Linear (4 * dim , self .out_dim , bias = False )
344
346
self .norm = norm_layer (self .out_dim )
345
347
346
- def forward (self , x ) :
348
+ def forward (self , x : torch . Tensor ) -> torch . Tensor :
347
349
B , H , W , C = x .shape
348
350
_assert (H % 2 == 0 , f"x height ({ H } ) is not even." )
349
351
_assert (W % 2 == 0 , f"x width ({ W } ) is not even." )
@@ -359,22 +361,22 @@ class SwinTransformerV2Stage(nn.Module):
359
361
360
362
def __init__ (
361
363
self ,
362
- dim ,
363
- out_dim ,
364
- input_resolution ,
365
- depth ,
366
- num_heads ,
367
- window_size ,
368
- downsample = False ,
369
- mlp_ratio = 4. ,
370
- qkv_bias = True ,
371
- proj_drop = 0. ,
372
- attn_drop = 0. ,
373
- drop_path = 0. ,
374
- norm_layer = nn .LayerNorm ,
375
- pretrained_window_size = 0 ,
376
- output_nchw = False ,
377
- ):
364
+ dim : int ,
365
+ out_dim : int ,
366
+ input_resolution : _int_or_tuple_2_t ,
367
+ depth : int ,
368
+ num_heads : int ,
369
+ window_size : _int_or_tuple_2_t ,
370
+ downsample : bool = False ,
371
+ mlp_ratio : float = 4. ,
372
+ qkv_bias : bool = True ,
373
+ proj_drop : float = 0. ,
374
+ attn_drop : float = 0. ,
375
+ drop_path : float = 0. ,
376
+ norm_layer : nn . Module = nn .LayerNorm ,
377
+ pretrained_window_size : _int_or_tuple_2_t = 0 ,
378
+ output_nchw : bool = False ,
379
+ ) -> None :
378
380
"""
379
381
Args:
380
382
dim: Number of input channels.
@@ -428,7 +430,7 @@ def __init__(
428
430
)
429
431
for i in range (depth )])
430
432
431
- def forward (self , x ) :
433
+ def forward (self , x : torch . Tensor ) -> torch . Tensor :
432
434
x = self .downsample (x )
433
435
434
436
for blk in self .blocks :
@@ -438,7 +440,7 @@ def forward(self, x):
438
440
x = blk (x )
439
441
return x
440
442
441
- def _init_respostnorm (self ):
443
+ def _init_respostnorm (self ) -> None :
442
444
for blk in self .blocks :
443
445
nn .init .constant_ (blk .norm1 .bias , 0 )
444
446
nn .init .constant_ (blk .norm1 .weight , 0 )
0 commit comments