Commit 7bd5a2a

Merge pull request #90 from ayasyrev/convmixer
Convmixer
2 parents: 0bd3816 + 9b9b443

File tree: 1 file changed (+108, -42 lines)


src/model_constructor/convmixer.py

Lines changed: 108 additions & 42 deletions
@@ -3,37 +3,49 @@
 # Adopted from https://github.com/tmp-iclr/convmixer
 # Home for convmixer: https://github.com/locuslab/convmixer
 from collections import OrderedDict
-from typing import Callable
+from typing import Callable, List, Optional, Union
+
 import torch.nn as nn
+from torch import TensorType
 
 
 class Residual(nn.Module):
-    def __init__(self, fn):
+    def __init__(self, fn: Callable[[TensorType], TensorType]):
         super().__init__()
         self.fn = fn
 
-    def forward(self, x):
+    def forward(self, x: TensorType) -> TensorType:
         return self.fn(x) + x
 
 
 # As original version, act_fn as argument.
-def ConvMixerOriginal(dim, depth,
-                      kernel_size=9, patch_size=7, n_classes=1000,
-                      act_fn=nn.GELU()):
+def ConvMixerOriginal(
+    dim: int,
+    depth: int,
+    kernel_size: int = 9,
+    patch_size: int = 7,
+    n_classes: int = 1000,
+    act_fn: nn.Module = nn.GELU(),
+):
     return nn.Sequential(
         nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
         act_fn,
         nn.BatchNorm2d(dim),
-        *[nn.Sequential(
-            Residual(nn.Sequential(
-                nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
+        *[
+            nn.Sequential(
+                Residual(
+                    nn.Sequential(
+                        nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
+                        act_fn,
+                        nn.BatchNorm2d(dim),
+                    )
+                ),
+                nn.Conv2d(dim, dim, kernel_size=1),
                 act_fn,
-                nn.BatchNorm2d(dim)
-            )),
-            nn.Conv2d(dim, dim, kernel_size=1),
-            act_fn,
-            nn.BatchNorm2d(dim)
-        ) for i in range(depth)],
+                nn.BatchNorm2d(dim),
+            )
+            for _i in range(depth)
+        ],
         nn.AdaptiveAvgPool2d((1, 1)),
         nn.Flatten(),
         nn.Linear(dim, n_classes)
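
For reference, a quick smoke test of the refactored ConvMixerOriginal (a minimal sketch, assuming the package import path model_constructor.convmixer; the small dim/depth values are chosen only to keep the check fast, while the paper-scale configuration would be dim=1536, depth=20):

import torch

from model_constructor.convmixer import ConvMixerOriginal  # assumed import path

# Tiny configuration for a fast shape check; defaults give kernel_size=9, patch_size=7.
model = ConvMixerOriginal(dim=64, depth=2)
x = torch.randn(2, 3, 224, 224)  # batch of two 224x224 RGB images
print(model(x).shape)  # torch.Size([2, 1000]) -- n_classes defaults to 1000
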
@@ -43,15 +55,35 @@ def ConvMixerOriginal(dim, depth,
 class ConvLayer(nn.Sequential):
     """Basic conv layers block"""
 
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-                 act_fn=nn.GELU(), padding=0, groups=1,
-                 bn_1st=False, pre_act=False):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, tuple[int, int]],
+        stride: int = 1,
+        act_fn: nn.Module = nn.GELU(),
+        padding: Union[int, str] = 0,
+        groups: int = 1,
+        bn_1st: bool = False,
+        pre_act: bool = False,
+    ):
 
-        conv_layer = [('conv', nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
-                                         padding=padding, groups=groups))]
-        act_bn = [
-            ('act_fn', act_fn),
-            ('bn', nn.BatchNorm2d(out_channels))
+        conv_layer: List[tuple[str, nn.Module]] = [
+            (
+                "conv",
+                nn.Conv2d(
+                    in_channels,
+                    out_channels,
+                    kernel_size,
+                    stride=stride,
+                    padding=padding,
+                    groups=groups,
+                ),
+            )
+        ]
+        act_bn: List[tuple[str, nn.Module]] = [
+            ("act_fn", act_fn),
+            ("bn", nn.BatchNorm2d(out_channels)),
         ]
         if bn_1st:
             act_bn.reverse()
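
As a usage sketch of the reworked ConvLayer (assuming the same import path, and assuming the named layers are assembled into an OrderedDict as the tuples and the OrderedDict import suggest): with bn_1st=True the (act_fn, bn) pair is reversed, so BatchNorm is registered before the activation.

import torch

from model_constructor.convmixer import ConvLayer  # assumed import path

# bn_1st=True reverses act_bn, giving the order conv -> bn -> act_fn.
layer = ConvLayer(3, 32, kernel_size=3, stride=2, padding=1, bn_1st=True)
print([name for name, _ in layer.named_children()])  # ['conv', 'bn', 'act_fn']
print(layer(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 32, 32, 32])
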
@@ -64,45 +96,79 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
 
 
 class ConvMixer(nn.Sequential):
-
-    def __init__(self, dim: int, depth: int,
-                 kernel_size: int = 9, patch_size: int = 7, n_classes: int = 1000,
-                 act_fn: nn.Module = nn.GELU(),
-                 stem: nn.Module = None,
-                 bn_1st: bool = False, pre_act: bool = False,
-                 init_func: Callable = None):
+    def __init__(
+        self,
+        dim: int,
+        depth: int,
+        kernel_size: int = 9,
+        patch_size: int = 7,
+        n_classes: int = 1000,
+        act_fn: nn.Module = nn.GELU(),
+        stem: Optional[nn.Module] = None,
+        in_chans: int = 3,
+        bn_1st: bool = False,
+        pre_act: bool = False,
+        init_func: Optional[Callable[[nn.Module], None]] = None,
+    ):
         """ConvMixer constructor.
         Adopted from https://github.com/tmp-iclr/convmixer
 
         Args:
-            dim (int): Dimention of model.
+            dim (int): Dimension of model.
             depth (int): Depth of model.
             kernel_size (int, optional): Kernel size. Defaults to 9.
             patch_size (int, optional): Patch size. Defaults to 7.
             n_classes (int, optional): Number of classes. Defaults to 1000.
             act_fn (nn.Module, optional): Activation function. Defaults to nn.GELU().
             stem (nn.Module, optional): You can pass a different first layer.
-            stem_ks (int, optional): If stem_ch not 0 - kernel size for adittional layer. Defaults to 1.
-            bn_1st (bool, optional): If True - BatchNorm befor activation function. Defaults to False.
-            pre_act (bool, optional): If True - activatin function befor convolution layer. Defaults to False.
+            stem_ks (int, optional): If stem_ch not 0 - kernel size for additional layer. Defaults to 1.
+            bn_1st (bool, optional): If True - BatchNorm before activation function. Defaults to False.
+            pre_act (bool, optional): If True - activation function before convolution layer. Defaults to False.
             init_func (Callable, optional): External function to initialize the model.
 
         """
         if pre_act:
             bn_1st = False
         if stem is None:
-            stem = ConvLayer(3, dim, kernel_size=patch_size, stride=patch_size, act_fn=act_fn, bn_1st=bn_1st)
+            stem = ConvLayer(
+                in_chans,
+                dim,
+                kernel_size=patch_size,
+                stride=patch_size,
+                act_fn=act_fn,
+                bn_1st=bn_1st,
+            )
 
         super().__init__(
             stem,
-            *[nn.Sequential(
-                Residual(
-                    ConvLayer(dim, dim, kernel_size, act_fn=act_fn,
-                              groups=dim, padding="same", bn_1st=bn_1st, pre_act=pre_act)),
-                ConvLayer(dim, dim, kernel_size=1, act_fn=act_fn, bn_1st=bn_1st, pre_act=pre_act))
-                for i in range(depth)],
+            *[
+                nn.Sequential(
+                    Residual(
+                        ConvLayer(
+                            dim,
+                            dim,
+                            kernel_size,
+                            act_fn=act_fn,
+                            groups=dim,
+                            padding="same",
+                            bn_1st=bn_1st,
+                            pre_act=pre_act,
+                        )
+                    ),
+                    ConvLayer(
+                        dim,
+                        dim,
+                        kernel_size=1,
+                        act_fn=act_fn,
+                        bn_1st=bn_1st,
+                        pre_act=pre_act,
+                    ),
+                )
+                for _ in range(depth)
+            ],
             nn.AdaptiveAvgPool2d((1, 1)),
             nn.Flatten(),
-            nn.Linear(dim, n_classes))
+            nn.Linear(dim, n_classes)
+        )
         if init_func is not None:  # pragma: no cover
             init_func(self)
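
Finally, an end-to-end sketch of the new in_chans argument together with init_func (init_cnn below is a hypothetical initializer written for this example, not part of the library):

import torch
from torch import nn

from model_constructor.convmixer import ConvMixer  # assumed import path


def init_cnn(module: nn.Module) -> None:
    # Hypothetical init_func: Kaiming-init every Conv2d in the model.
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)


# in_chans=1 builds the default stem for grayscale input instead of hard-coded 3 channels.
model = ConvMixer(dim=64, depth=2, in_chans=1, n_classes=10, init_func=init_cnn)
print(model(torch.randn(2, 1, 224, 224)).shape)  # torch.Size([2, 10])
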
