1
1
import torch
2
2
import torch .nn as nn
3
3
import torch .nn .functional as F
4
-
4
+ from utils import Lambda
5
5
6
6
class EasyMnist (nn .Module ):
7
7
@@ -13,11 +13,49 @@ def __init__(self):
13
13
14
14
def forward (self , x_batch : torch .Tensor ):
15
15
"""Simple ReLU-based activations through all layers of the DNN.
16
- Simple and effectively deep neural network. No frills.
16
+ Simple and sufficiently deep neural network. No frills.
17
17
"""
18
18
_input = x_batch .view (- 1 , 784 ) # shape for our linear1
19
19
out1 = F .relu (self .linear1 (x_batch ))
20
20
out2 = F .relu (self .linear2 (out1 ))
21
21
out3 = F .relu (self .linear3 (out2 ))
22
22
23
- return out3
23
+ return out3
24
+
25
+
26
+ # for comparison with the above
27
+ def EasyMnistSeq ():
28
+ return nn .Sequential (
29
+ Lambda (lambda x : x .reshape (- 1 , 784 )),
30
+ nn .Linear (784 , 1000 ),
31
+ nn .Relu (),
32
+ nn .Linear (1000 , 300 ),
33
+ nn .Relu (),
34
+ nn .Linear (300 , 10 ),
35
+ nn .Relu (),
36
+ )
37
+
38
+
39
+ class MnistConvNet (nn .Module ):
40
+ def __init__ (self , interim_size = 16 ):
41
+ """
42
+ A simple and shallow deep CNN to show that morph will shrink this architecture,
43
+ which will inherently be wasteful on the task of classifying MNIST digits with
44
+ accuracy above 95%.
45
+ By default produces a 1x16 -> 16x16 -> 16x10 convnet
46
+ """
47
+ super ().__init__ ()
48
+ self .conv1 = nn .Conv2d (1 , interim_size , kernel_size = 3 , stride = 2 , padding = 1 )
49
+ self .conv2 = nn .Conv2d (interim_size , interim_size , kernel_size = 3 , stride = 2 , padding = 1 )
50
+ self .conv3 = nn .Conv2d (interim_size , 10 , kernel_size = 3 , stride = 2 , padding = 1 )
51
+
52
+ def forward (self , xb ):
53
+ xb = xb .view (- 1 , 1 , 28 , 28 ) # any batch_size, 1 channel, 28x28 pixels
54
+ xb = F .relu (self .conv1 (xb ))
55
+ xb = F .relu (self .conv2 (xb ))
56
+ xb = F .relu (self .conv3 (xb ))
57
+ xb = F .avg_pool2d (xb , 4 )
58
+
59
+ # reshape the output to the second dimension of the pool size, and just fill the rest to whatever.
60
+ return xb .view (- 1 , xb .size (1 ))
61
+
0 commit comments