
Commit 0bd3816

Merge pull request #89 from ayasyrev/mc_refactor
Mc refactor
2 parents 609540c + 57c89ca commit 0bd3816

17 files changed: +1178 -357 lines

src/model_constructor/__init__.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -1,7 +1,6 @@
 from model_constructor.convmixer import ConvMixer  # noqa F401
 from model_constructor.model_constructor import (
     ModelConstructor,
-    ResBlock,
     ModelCfg,
 )  # noqa F401
```
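
After this change the package root re-exports only ConvMixer, ModelConstructor, and ModelCfg. A quick check of the surviving import surface, based solely on the diff above:

```python
# package-level imports that remain valid after this commit
from model_constructor import ConvMixer, ModelCfg, ModelConstructor

# ResBlock is no longer re-exported at package level; if it still exists,
# it must now be imported from its defining submodule.
```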

src/model_constructor/base_constructor.py

Lines changed: 80 additions & 4 deletions
```diff
@@ -1,10 +1,31 @@
-import torch.nn as nn
+"""First version of constructor.
+"""
+# Used in examples.
+# first implementation of xresnet - inspired by fastai version.
 from collections import OrderedDict
-from .layers import ConvLayer, Noop, Flatten
+from functools import partial
 
+import torch.nn as nn
 
-__all__ = ['act_fn', 'Stem', 'DownsampleBlock', 'BasicBlock', 'Bottleneck', 'BasicLayer', 'Body', 'Head', 'init_model',
-           'Net']
+from .layers import ConvLayer, Flatten, Noop
+
+__all__ = [
+    "act_fn",
+    "Stem",
+    "DownsampleBlock",
+    "BasicBlock",
+    "Bottleneck",
+    "BasicLayer",
+    "Body",
+    "Head",
+    "init_model",
+    "Net",
+    "DownsampleLayer",
+    "XResBlock",
+    "xresnet18",
+    "xresnet34",
+    "xresnet50",
+]
 
 
 act_fn = nn.ReLU(inplace=True)
@@ -162,3 +183,58 @@ def __init__(self, stem=Stem,
             ('head', head(body_out * expansion, num_classes, **kwargs))
         ]))
         self.init_model(self)
+
+
+# xresnet from fastai
+
+
+class DownsampleLayer(nn.Sequential):
+    """Downsample layer for Xresnet Resblock"""
+
+    def __init__(self, conv_layer, ni, nf, stride, act,
+                 pool=nn.AvgPool2d(2, ceil_mode=True), pool_1st=True,
+                 **kwargs):
+        layers = [] if stride == 1 else [('pool', pool)]
+        layers += [] if ni == nf else [('idconv', conv_layer(ni, nf, 1, act=act, **kwargs))]
+        if not pool_1st:
+            layers.reverse()
+        super().__init__(OrderedDict(layers))
+
+
+class XResBlock(nn.Module):
+    '''XResnet block'''
+
+    def __init__(self, ni, nh, expansion=1, stride=1, zero_bn=True,
+                 conv_layer=ConvLayer, act_fn=act_fn, **kwargs):
+        super().__init__()
+        nf, ni = nh * expansion, ni * expansion
+        layers = [('conv_0', conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, **kwargs)),
+                  ('conv_1', conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, act_fn=act_fn, **kwargs))
+                  ] if expansion == 1 else [
+                  ('conv_0', conv_layer(ni, nh, 1, act_fn=act_fn, **kwargs)),
+                  ('conv_1', conv_layer(nh, nh, 3, stride=stride, act_fn=act_fn, **kwargs)),
+                  ('conv_2', conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, act_fn=act_fn, **kwargs))
+                  ]
+        self.convs = nn.Sequential(OrderedDict(layers))
+        self.identity = DownsampleLayer(conv_layer, ni, nf, stride,
+                                        act=False, act_fn=act_fn, **kwargs) if ni != nf or stride == 2 else Noop()
+        self.merge = Noop()
+        self.act_fn = act_fn
+
+    def forward(self, x):
+        return self.act_fn(self.merge(self.convs(x) + self.identity(x)))
+
+
+def xresnet18(**kwargs):
+    """Constructs xresnet18 model. """
+    return Net(stem_sizes=[32, 32], block=XResBlock, blocks=[2, 2, 2, 2], expansion=1, **kwargs)
+
+
+def xresnet34(**kwargs):
+    """Constructs xresnet34 model. """
+    return Net(stem_sizes=[32, 32], block=XResBlock, blocks=[3, 4, 6, 3], expansion=1, **kwargs)
+
+
+def xresnet50(**kwargs):
+    """Constructs xresnet50 model. """
+    return Net(stem_sizes=[32, 32], block=XResBlock, blocks=[3, 4, 6, 3], expansion=4, **kwargs)
```
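
A hedged usage sketch for the new constructors. It assumes `num_classes` flows through `**kwargs` into `Net`'s head, as the `('head', head(body_out * expansion, num_classes, **kwargs))` context line above suggests:

```python
import torch

from model_constructor.base_constructor import xresnet34

model = xresnet34(num_classes=10)  # assumption: num_classes reaches Head via **kwargs
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 10])
```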

src/model_constructor/helpers.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -0,0 +1,9 @@
+from collections import OrderedDict
+from typing import Iterable
+
+from torch import nn
+
+
+def nn_seq(list_of_tuples: Iterable[tuple[str, nn.Module]]) -> nn.Sequential:
+    """return nn.Sequential from OrderedDict from list of tuples"""
+    return nn.Sequential(OrderedDict(list_of_tuples))  #
```
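
The new helper appears intended to collapse the `nn.Sequential(OrderedDict([...]))` pattern that recurs in the constructors. A minimal sketch of how it is called, grounded in the signature shown above:

```python
from torch import nn

from model_constructor.helpers import nn_seq

block = nn_seq([
    ("conv", nn.Conv2d(3, 16, 3, padding=1)),
    ("bn", nn.BatchNorm2d(16)),
    ("act", nn.ReLU(inplace=True)),
])
print(block.bn)  # submodules stay addressable by the names given in the tuples
```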

src/model_constructor/layers.py

Lines changed: 15 additions & 13 deletions
```diff
@@ -2,7 +2,7 @@
 from typing import List, Optional, Type, Union
 
 import torch
-import torch.nn as nn
+from torch import nn
 from torch.nn.utils.spectral_norm import spectral_norm
 
 __all__ = [
@@ -21,19 +21,19 @@
 class Flatten(nn.Module):
     """flat x to vector"""
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x.view(x.size(0), -1)
 
 
-def noop(x):
+def noop(x: torch.Tensor) -> torch.Tensor:
     """Dummy func. Return input"""
     return x
 
 
 class Noop(nn.Module):
     """Dummy module"""
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
@@ -114,15 +114,15 @@ def __init__(
         nf,
         ks=3,
         stride=1,
-        act=True,
+        act=True,  # pylint: disable=redefined-outer-name
         act_fn=act,
         bn_layer=True,
         bn_1st=True,
         zero_bn=False,
         padding=None,
         bias=False,
         groups=1,
-        **kwargs
+        **kwargs  # pylint: disable=unused-argument
     ):
 
         if padding is None:
```
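
For context, the `redefined-outer-name` suppression is needed because the `act` parameter shadows the module-level `act` that serves as the default for `act_fn`. A hedged usage sketch follows; the leading input-channel parameter (`ni`) is an assumption, since this hunk starts at `nf`:

```python
import torch

from model_constructor.layers import ConvLayer

# assumed signature: ConvLayer(ni, nf, ks=3, stride=1, ...); ni is cut off above this hunk
layer = ConvLayer(3, 32, ks=3, stride=2)
out = layer(torch.randn(1, 3, 64, 64))
print(out.shape)  # expected: torch.Size([1, 32, 32, 32]), assuming ks // 2 padding when padding=None
```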
```diff
@@ -176,7 +176,7 @@ def __init__(self, n_in: int, ks=1, sym=False, use_bias=False):
         self.sym = sym
         self.n_in = n_in
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.sym:  # check ks=3
             # symmetry hack by https://github.com/mgrankin
             c = self.conv.weight.view(self.n_in, self.n_in)
@@ -195,13 +195,14 @@ def forward(self, x):
         return o.view(*size).contiguous()
 
 
-class SEBlock(nn.Module):  # todo: deprecation warning.
-    "se block"
+class SEBlock(nn.Module):
+    """se block"""
+    # first version
     se_layer = nn.Linear
     act_fn = nn.ReLU(inplace=True)
     use_bias = True
 
-    def __init__(self, c, r=16):
+    def __init__(self, c: int, r: int = 16):
         super().__init__()
         ch = max(c // r, 1)
         self.squeeze = nn.AdaptiveAvgPool2d(1)
@@ -216,15 +217,16 @@ def __init__(self, c, r=16):
             )
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         bs, c, _, _ = x.shape
         y = self.squeeze(x).view(bs, c)
         y = self.excitation(y).view(bs, c, 1, 1)
         return x * y.expand_as(x)
 
 
-class SEBlockConv(nn.Module):  # todo: deprecation warning.
-    "se block with conv on excitation"
+class SEBlockConv(nn.Module):
+    """se block with conv on excitation"""
+    # first version
     se_layer = nn.Conv2d
     act_fn = nn.ReLU(inplace=True)
     use_bias = True
```
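
A short sanity-check sketch for the retyped `SEBlock`, grounded in the `__init__` and `forward` shown above:

```python
import torch

from model_constructor.layers import SEBlock

se = SEBlock(c=64, r=16)       # squeeze-and-excitation over 64 channels, reduction 16
x = torch.randn(2, 64, 8, 8)
y = se(x)                      # channels reweighted by the excitation branch
assert y.shape == x.shape      # forward multiplies x by an expanded (bs, c, 1, 1) gate
```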
