Skip to content

Commit 5e14afc

Browse files
committed
Add two MNIST models for testing and demos
1. simple sequential linear model for testing 2. Simple 3-layer conv net, with a brief explanation of project goals
1 parent 57d0433 commit 5e14afc

File tree

2 files changed

+53
-5
lines changed

2 files changed

+53
-5
lines changed

morph/testing/models.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import torch.nn as nn
33
import torch.nn.functional as F
4-
4+
from utils import Lambda
55

66
class EasyMnist(nn.Module):
77

@@ -13,11 +13,49 @@ def __init__(self):
1313

1414
def forward(self, x_batch: torch.Tensor):
1515
"""Simple ReLU-based activations through all layers of the DNN.
16-
Simple and effectively deep neural network. No frills.
16+
Simple and sufficiently deep neural network. No frills.
1717
"""
1818
_input = x_batch.view(-1, 784) # shape for our linear1
1919
out1 = F.relu(self.linear1(x_batch))
2020
out2 = F.relu(self.linear2(out1))
2121
out3 = F.relu(self.linear3(out2))
2222

23-
return out3
23+
return out3
24+
25+
26+
# for comparison with the above
27+
def EasyMnistSeq():
28+
return nn.Sequential(
29+
Lambda(lambda x: x.reshape(-1, 784)),
30+
nn.Linear(784, 1000),
31+
nn.Relu(),
32+
nn.Linear(1000, 300),
33+
nn.Relu(),
34+
nn.Linear(300, 10),
35+
nn.Relu(),
36+
)
37+
38+
39+
class MnistConvNet(nn.Module):
40+
def __init__(self, interim_size=16):
41+
"""
42+
A simple and shallow deep CNN to show that morph will shrink this architecture,
43+
which will inherently be wasteful on the task of classifying MNIST digits with
44+
accuracy above 95%.
45+
By default produces a 1x16 -> 16x16 -> 16x10 convnet
46+
"""
47+
super().__init__()
48+
self.conv1 = nn.Conv2d(1, interim_size, kernel_size=3, stride=2, padding=1)
49+
self.conv2 = nn.Conv2d(interim_size, interim_size, kernel_size=3, stride=2, padding=1)
50+
self.conv3 = nn.Conv2d(interim_size, 10, kernel_size=3, stride=2, padding=1)
51+
52+
def forward(self, xb):
53+
xb = xb.view(-1, 1, 28, 28) # any batch_size, 1 channel, 28x28 pixels
54+
xb = F.relu(self.conv1(xb))
55+
xb = F.relu(self.conv2(xb))
56+
xb = F.relu(self.conv3(xb))
57+
xb = F.avg_pool2d(xb, 4)
58+
59+
# reshape the output to the second dimension of the pool size, and just fill the rest to whatever.
60+
return xb.view(-1, xb.size(1))
61+

morph/utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ._error import ValidationError
2-
2+
import torch.nn as nn
33

44
def check(pred: bool, message='Validation failed'):
55
if not pred: raise ValidationError(message)
@@ -9,4 +9,14 @@ def round(value: float) -> int:
99
"""Rounds a `value` up to the next integer if possible.
1010
Performs differently from the standard Python `round`
1111
"""
12-
return int(value + .5)
12+
return int(value + .5)
13+
14+
15+
# courtesy of https://pytorch.org/tutorials/beginner/nn_tutorial.html#nn-sequential
16+
class Lambda(nn.Module):
17+
def __init__(self, func):
18+
super().__init__()
19+
self.func = func
20+
21+
def forward(self, x):
22+
return self.func(x)

0 commit comments

Comments
 (0)