Skip to content

Commit cfb2c00

Browse files
author
ayasyrev
committed
added groups
1 parent 8214241 commit cfb2c00

File tree

5 files changed

+439
-217
lines changed

5 files changed

+439
-217
lines changed

docs/Net.html

Lines changed: 77 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,60 @@ <h2 id="ResBlock" class="doc_header"><code>class</code> <code>ResBlock</code><a
190190
<div class="cell border-box-sizing code_cell rendered">
191191
<div class="input">
192192

193+
<div class="inner_cell">
194+
<div class="input_area">
195+
<div class=" highlight hl-ipython3"><pre><span></span><span class="n">ResBlock</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="mi">64</span><span class="p">,</span><span class="mi">64</span><span class="p">,</span><span class="n">sa</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">groups</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span>
196+
</pre></div>
197+
198+
</div>
199+
</div>
200+
</div>
201+
202+
<div class="output_wrapper">
203+
<div class="output">
204+
205+
<div class="output_area">
206+
207+
208+
209+
<div class="output_text output_subarea output_execute_result">
210+
<pre>ResBlock(
211+
(convs): Sequential(
212+
(conv_0): ConvLayer(
213+
(conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
214+
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
215+
(act_fn): ReLU(inplace=True)
216+
)
217+
(conv_1): ConvLayer(
218+
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)
219+
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
220+
(act_fn): ReLU(inplace=True)
221+
)
222+
(conv_2): ConvLayer(
223+
(conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
224+
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
225+
)
226+
(sa): SimpleSelfAttention(
227+
(conv): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
228+
)
229+
)
230+
(act_fn): ReLU(inplace=True)
231+
)</pre>
232+
</div>
233+
234+
</div>
235+
236+
</div>
237+
</div>
238+
239+
</div>
240+
{% endraw %}
241+
242+
{% raw %}
243+
244+
<div class="cell border-box-sizing code_cell rendered">
245+
<div class="input">
246+
193247
<div class="inner_cell">
194248
<div class="input_area">
195249
<div class=" highlight hl-ipython3"><pre><span></span><span class="n">ResBlock</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">64</span><span class="p">,</span><span class="mi">64</span><span class="p">,</span><span class="n">act_fn</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(),</span> <span class="n">bn_1st</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
@@ -260,7 +314,7 @@ <h1 id="NewResBlock">NewResBlock<a class="anchor-link" href="#NewResBlock"> </a>
260314

261315

262316
<div class="output_markdown rendered_html output_subarea ">
263-
<h2 id="NewResBlock" class="doc_header"><code>class</code> <code>NewResBlock</code><a href="https://github.com/ayasyrev/model_constructor/tree/master/model_constructor/net.py#L46" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>NewResBlock</code>(<strong><code>expansion</code></strong>, <strong><code>ni</code></strong>, <strong><code>nh</code></strong>, <strong><code>stride</code></strong>=<em><code>1</code></em>, <strong><code>conv_layer</code></strong>=<em><code>'ConvLayer'</code></em>, <strong><code>act_fn</code></strong>=<em><code>ReLU(inplace=True)</code></em>, <strong><code>zero_bn</code></strong>=<em><code>True</code></em>, <strong><code>bn_1st</code></strong>=<em><code>True</code></em>, <strong><code>pool</code></strong>=<em><code>AvgPool2d(kernel_size=2, stride=2, padding=0)</code></em>, <strong><code>sa</code></strong>=<em><code>False</code></em>, <strong><code>sym</code></strong>=<em><code>False</code></em>) :: <code>Module</code></p>
317+
<h2 id="NewResBlock" class="doc_header"><code>class</code> <code>NewResBlock</code><a href="https://github.com/ayasyrev/model_constructor/tree/master/model_constructor/net.py#L46" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>NewResBlock</code>(<strong><code>expansion</code></strong>, <strong><code>ni</code></strong>, <strong><code>nh</code></strong>, <strong><code>stride</code></strong>=<em><code>1</code></em>, <strong><code>conv_layer</code></strong>=<em><code>'ConvLayer'</code></em>, <strong><code>act_fn</code></strong>=<em><code>ReLU(inplace=True)</code></em>, <strong><code>zero_bn</code></strong>=<em><code>True</code></em>, <strong><code>bn_1st</code></strong>=<em><code>True</code></em>, <strong><code>pool</code></strong>=<em><code>AvgPool2d(kernel_size=2, stride=2, padding=0)</code></em>, <strong><code>sa</code></strong>=<em><code>False</code></em>, <strong><code>sym</code></strong>=<em><code>False</code></em>, <strong><code>groups</code></strong>=<em><code>1</code></em>) :: <code>Module</code></p>
264318
</blockquote>
265319
<p>Base class for all neural network modules.</p>
266320
<p>Your models should also subclass this class.</p>
@@ -404,7 +458,10 @@ <h2 id="Net" class="doc_header"><code>class</code> <code>Net</code><a href="http
404458

405459

406460
<div class="output_text output_subarea output_execute_result">
407-
<pre> constr Net</pre>
461+
<pre> constr Net
462+
expansion: 1, sa: 0, groups: 1
463+
stem sizes: [3, 32, 32, 64]
464+
body sizes [64, 64, 128, 256, 512]</pre>
408465
</div>
409466

410467
</div>
@@ -461,7 +518,10 @@ <h2 id="Net" class="doc_header"><code>class</code> <code>Net</code><a href="http
461518
</div>
462519
</div>
463520
</div>
464-
521+
<details class="description">
522+
<summary data-open="Hide Output" data-close="Show Output"></summary>
523+
<summary></summary>
524+
465525
<div class="output_wrapper">
466526
<div class="output">
467527

@@ -473,18 +533,18 @@ <h2 id="Net" class="doc_header"><code>class</code> <code>Net</code><a href="http
473533
<pre>Sequential(
474534
(conv_0): ConvLayer(
475535
(conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
536+
(act_fn): LeakyReLU(negative_slope=0.01, inplace=True)
476537
(bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
477-
(act_fn): ReLU(inplace=True)
478538
)
479539
(conv_1): ConvLayer(
480540
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
541+
(act_fn): LeakyReLU(negative_slope=0.01, inplace=True)
481542
(bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
482-
(act_fn): ReLU(inplace=True)
483543
)
484544
(conv_2): ConvLayer(
485545
(conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
546+
(act_fn): LeakyReLU(negative_slope=0.01, inplace=True)
486547
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
487-
(act_fn): ReLU(inplace=True)
488548
)
489549
(stem_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
490550
)</pre>
@@ -495,6 +555,7 @@ <h2 id="Net" class="doc_header"><code>class</code> <code>Net</code><a href="http
495555
</div>
496556
</div>
497557

558+
</details>
498559
</div>
499560
{% endraw %}
500561

@@ -562,7 +623,10 @@ <h2 id="Net" class="doc_header"><code>class</code> <code>Net</code><a href="http
562623
</div>
563624
</div>
564625
</div>
565-
626+
<details class="description">
627+
<summary data-open="Hide Output" data-close="Show Output"></summary>
628+
<summary></summary>
629+
566630
<div class="output_wrapper">
567631
<div class="output">
568632

@@ -613,6 +677,7 @@ <h2 id="Net" class="doc_header"><code>class</code> <code>Net</code><a href="http
613677
</div>
614678
</div>
615679

680+
</details>
616681
</div>
617682
{% endraw %}
618683

@@ -647,73 +712,6 @@ <h2 id="Net" class="doc_header"><code>class</code> <code>Net</code><a href="http
647712
</div>
648713
</div>
649714

650-
</div>
651-
{% endraw %}
652-
653-
{% raw %}
654-
655-
<div class="cell border-box-sizing code_cell rendered">
656-
<div class="input">
657-
658-
<div class="inner_cell">
659-
<div class="input_area">
660-
<div class=" highlight hl-ipython3"><pre><span></span><span class="n">model</span><span class="o">.</span><span class="n">stem_bn_end</span> <span class="o">=</span> <span class="kc">True</span>
661-
</pre></div>
662-
663-
</div>
664-
</div>
665-
</div>
666-
667-
</div>
668-
{% endraw %}
669-
670-
{% raw %}
671-
672-
<div class="cell border-box-sizing code_cell rendered">
673-
<div class="input">
674-
675-
<div class="inner_cell">
676-
<div class="input_area">
677-
<div class=" highlight hl-ipython3"><pre><span></span><span class="n">model</span><span class="o">.</span><span class="n">stem</span>
678-
</pre></div>
679-
680-
</div>
681-
</div>
682-
</div>
683-
684-
<div class="output_wrapper">
685-
<div class="output">
686-
687-
<div class="output_area">
688-
689-
690-
691-
<div class="output_text output_subarea output_execute_result">
692-
<pre>Sequential(
693-
(conv_0): ConvLayer(
694-
(conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
695-
(act_fn): LeakyReLU(negative_slope=0.01, inplace=True)
696-
(bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
697-
)
698-
(conv_1): ConvLayer(
699-
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
700-
(act_fn): LeakyReLU(negative_slope=0.01, inplace=True)
701-
(bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
702-
)
703-
(conv_2): ConvLayer(
704-
(conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
705-
(act_fn): LeakyReLU(negative_slope=0.01, inplace=True)
706-
)
707-
(stem_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
708-
(norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
709-
)</pre>
710-
</div>
711-
712-
</div>
713-
714-
</div>
715-
</div>
716-
717715
</div>
718716
{% endraw %}
719717

@@ -804,9 +802,6 @@ <h2 id="Net" class="doc_header"><code>class</code> <code>Net</code><a href="http
804802
(conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
805803
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
806804
)
807-
(sa): SimpleSelfAttention(
808-
(conv): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
809-
)
810805
)
811806
(merge): LeakyReLU(negative_slope=0.01, inplace=True)
812807
)
@@ -955,23 +950,6 @@ <h2 id="Net" class="doc_header"><code>class</code> <code>Net</code><a href="http
955950
</div>
956951

957952
</details>
958-
</div>
959-
{% endraw %}
960-
961-
{% raw %}
962-
963-
<div class="cell border-box-sizing code_cell rendered">
964-
<div class="input">
965-
966-
<div class="inner_cell">
967-
<div class="input_area">
968-
<div class=" highlight hl-ipython3"><pre><span></span><span class="n">model</span><span class="o">.</span><span class="n">stem_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">64</span><span class="p">]</span>
969-
</pre></div>
970-
971-
</div>
972-
</div>
973-
</div>
974-
975953
</div>
976954
{% endraw %}
977955

@@ -1027,7 +1005,11 @@ <h2 id="xresnet-constructor">xresnet constructor<a class="anchor-link" href="#xr
10271005

10281006

10291007
<div class="output_text output_subarea output_execute_result">
1030-
<pre>( constr xresnet50, 10)</pre>
1008+
<pre>( constr xresnet50
1009+
expansion: 4, sa: 0, groups: 1
1010+
stem sizes: [3, 32, 32, 64]
1011+
body sizes [16, 64, 128, 256, 512],
1012+
10)</pre>
10311013
</div>
10321014

10331015
</div>

model_constructor/net.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,15 @@ def init_cnn(m):
2323
class ResBlock(nn.Module):
2424
def __init__(self, expansion, ni, nh, stride=1,
2525
conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,
26-
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False,sym=False):
26+
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False,sym=False, groups=1):
2727
super().__init__()
2828
nf,ni = nh*expansion,ni*expansion
2929
layers = [(f"conv_0", conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st)),
3030
(f"conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
3131
] if expansion == 1 else [
3232
(f"conv_0",conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),
33-
(f"conv_1",conv_layer(nh, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st)),
33+
(f"conv_1",conv_layer(nh, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st,
34+
groups=int(nh/groups))),
3435
(f"conv_2",conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
3536
]
3637
if sa: layers.append(('sa', SimpleSelfAttention(nf,ks=1,sym=sym)))
@@ -46,15 +47,15 @@ def forward(self, x): return self.act_fn(self.convs(x) + self.idconv(self.pool(x
4647
class NewResBlock(nn.Module):
4748
def __init__(self, expansion, ni, nh, stride=1,
4849
conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,
49-
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False,sym=False):
50+
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False,sym=False, groups=1):
5051
super().__init__()
5152
nf,ni = nh*expansion,ni*expansion
5253
self.reduce = noop if stride==1 else pool
5354
layers = [(f"conv_0", conv_layer(ni, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st)), # stride 1 !!!
5455
(f"conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
5556
] if expansion == 1 else [
5657
(f"conv_0",conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),
57-
(f"conv_1",conv_layer(nh, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st)), # stride 1 !!!
58+
(f"conv_1",conv_layer(nh, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st, groups=int(nh/groups))), # stride 1 !!!
5859
(f"conv_2",conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
5960
]
6061
if sa: layers.append(('sa', SimpleSelfAttention(nf,ks=1,sym=sym)))
@@ -83,7 +84,7 @@ def _make_layer(self,expansion,ni,nf,blocks,stride,sa):
8384
[(f"bl_{i}", self.block(expansion, ni if i==0 else nf, nf,
8485
stride if i==0 else 1, sa=sa if i==blocks-1 else False,
8586
conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,
86-
zero_bn=self.zero_bn, bn_1st=self.bn_1st))
87+
zero_bn=self.zero_bn, bn_1st=self.bn_1st, groups=self.groups))
8788
for i in range(blocks)]))
8889

8990
# Cell
@@ -110,7 +111,7 @@ def __init__(self, expansion=1, layers=[2,2,2,2], c_in=3, c_out=1000, name='Net'
110111
self.name = name
111112
self.c_in, self.c_out,self.expansion,self.layers = c_in,c_out,expansion,layers # todo setter for expansion
112113
self.act_fn, self.pool, self.sa = act_fn, pool, sa
113-
114+
self.groups = 1
114115

115116
self.stem_sizes = [c_in,32,32,64]
116117
self.stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -129,7 +130,7 @@ def __init__(self, expansion=1, layers=[2,2,2,2], c_in=3, c_out=1000, name='Net'
129130

130131
@property
131132
def block_szs(self):
132-
return [64//self.expansion,64,128,256,512] +[256]*(len(self.layers)-4)
133+
return [self.stem_sizes[-1]//self.expansion,64,128,256,512] +[256]*(len(self.layers)-4)
133134

134135
@property
135136
def stem(self):
@@ -154,7 +155,7 @@ def __call__(self):
154155
model.extra_repr = lambda : f"model {self.name}"
155156
return model
156157
def __repr__(self):
157-
return f" constr {self.name}"
158+
return f" constr {self.name}\n expansion: {self.expansion}, sa: {self.sa}, groups: {self.groups}\n stem sizes: {self.stem_sizes}\n body sizes {self.block_szs}"
158159

159160
# Cell
160161
# me = sys.modules[__name__]

model_constructor/twist.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class ConvTwist(nn.Module):
2222
groups_ch = 8
2323
def __init__(self, ni, nf,
2424
ks=3, stride=1, padding=1, bias=False,
25-
groups=1, iters=1, init_max=0.7):
25+
groups=1, iters=1, init_max=0.7, **kvargs):
2626
super().__init__()
2727
self.same = ni==nf and stride==1
2828
self.groups = ni//self.groups_ch if self.use_groups else 1
@@ -108,7 +108,7 @@ class ConvLayerTwist(ConvLayer): # replace Conv2d by Twist
108108
class NewResBlockTwist(nn.Module):
109109
def __init__(self, expansion, ni, nh, stride=1,
110110
conv_layer=ConvLayer, act_fn=act_fn, bn_1st=True,
111-
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, zero_bn=True):
111+
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, zero_bn=True, **kvargs):
112112
super().__init__()
113113
nf,ni = nh*expansion,ni*expansion
114114
# conv_layer = ConvLayerTwist
@@ -134,7 +134,7 @@ def forward(self, x):
134134
class ResBlockTwist(nn.Module):
135135
def __init__(self, expansion, ni, nh, stride=1,
136136
conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,
137-
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False,sym=False):
137+
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False,sym=False, **kvargs):
138138
super().__init__()
139139
nf,ni = nh*expansion,ni*expansion
140140
# conv_layer = ConvLayerTwist

0 commit comments

Comments
 (0)