3
3
# Adapted from https://github.com/tmp-iclr/convmixer
4
4
# Home for convmixer: https://github.com/locuslab/convmixer
5
5
from collections import OrderedDict
6
- from typing import Callable , Optional , Union
6
+ from typing import Callable , List , Optional , Union
7
+
7
8
import torch .nn as nn
9
+ from torch import TensorType
8
10
9
11
10
12
class Residual(nn.Module):
    """Wrap a callable and add its input back onto its output.

    ``forward(x)`` evaluates ``fn(x) + x``, i.e. a skip connection around
    ``fn``.

    Args:
        fn: Callable applied to the input. Its result is added to the
            input, so the two must be compatible under ``+``.
    """

    # NOTE(review): `TensorType` as imported from `torch` appears to be the
    # TorchScript type descriptor, not the tensor class itself; `torch.Tensor`
    # is presumably the intended annotation -- confirm, then switch the
    # file-level import and these annotations together.
    def __init__(self, fn: Callable[[TensorType], TensorType]):
        super().__init__()
        self.fn = fn

    def forward(self, x: TensorType) -> TensorType:
        """Return ``self.fn(x) + x``."""
        return self.fn(x) + x
17
19
18
20
19
21
# Same as the original version, but with act_fn exposed as an argument.
20
22
def ConvMixerOriginal (
21
- dim , depth , kernel_size = 9 , patch_size = 7 , n_classes = 1000 , act_fn = nn .GELU ()
23
+ dim : int ,
24
+ depth : int ,
25
+ kernel_size : int = 9 ,
26
+ patch_size : int = 7 ,
27
+ n_classes : int = 1000 ,
28
+ act_fn : nn .Module = nn .GELU (),
22
29
):
23
30
return nn .Sequential (
24
31
nn .Conv2d (3 , dim , kernel_size = patch_size , stride = patch_size ),
@@ -61,7 +68,7 @@ def __init__(
61
68
pre_act : bool = False ,
62
69
):
63
70
64
- conv_layer = [
71
+ conv_layer : List [ tuple [ str , nn . Module ]] = [
65
72
(
66
73
"conv" ,
67
74
nn .Conv2d (
@@ -74,7 +81,10 @@ def __init__(
74
81
),
75
82
)
76
83
]
77
- act_bn = [("act_fn" , act_fn ), ("bn" , nn .BatchNorm2d (out_channels ))]
84
+ act_bn : List [tuple [str , nn .Module ]] = [
85
+ ("act_fn" , act_fn ),
86
+ ("bn" , nn .BatchNorm2d (out_channels )),
87
+ ]
78
88
if bn_1st :
79
89
act_bn .reverse ()
80
90
if pre_act :
0 commit comments