Skip to content

Commit 7f00fc7

Browse files
committed
Adds squeeze and excitation (scSE) modules, resolves mapbox#157
1 parent 54e20dc commit 7f00fc7

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

robosat/scse.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""Squeeze and Excitation blocks - attention for classification and segmentation
2+
3+
See:
4+
- https://arxiv.org/abs/1709.01507 - Squeeze-and-Excitation Networks
5+
- https://arxiv.org/abs/1803.02579 - Concurrent Spatial and Channel 'Squeeze & Excitation' in Fully Convolutional Networks
6+
7+
"""
8+
9+
import torch
10+
import torch.nn as nn
11+
12+
13+
class SpatialSqChannelEx:
14+
"""Spatial Squeeze and Channel Excitation (cSE) block
15+
See https://arxiv.org/abs/1803.02579 Figure 1 b
16+
"""
17+
18+
def __init__(self, num_in, r):
19+
super().__init__()
20+
self.fc0 = Conv1x1(num_in, num_in // r)
21+
self.fc1 = Conv1x1(num_in // r, num_in)
22+
23+
def forward(self, x):
24+
xx = nn.functional.adaptive_avg_pool2d(x, 1)
25+
xx = self.fc0(xx)
26+
xx = nn.functional.relu(xx, inplace=True)
27+
xx = self.fc1(xx)
28+
xx = nn.functional.sigmoid(xx)
29+
return x * xx
30+
31+
32+
class ChannelSqSpatialEx:
33+
"""Channel Squeeze and Spatial Excitation (sSE) block
34+
See https://arxiv.org/abs/1803.02579 Figure 1 c
35+
"""
36+
37+
def __init__(self, num_in):
38+
super().__init__()
39+
self.conv = Conv1x1(num_in, 1)
40+
41+
def forward(self, x):
42+
xx = self.conv(x)
43+
xx = nn.functional.sigmoid(xx)
44+
return x * xx
45+
46+
47+
class SpatialChannelSqChannelEx:
48+
"""Concurrent Spatial and Channel Squeeze and Channel Excitation (csSE) block
49+
See https://arxiv.org/abs/1803.02579 Figure 1 d
50+
"""
51+
52+
def __init__(self, num_in, r=16):
53+
super().__init__()
54+
55+
self.cse = SpatialSqChannelEx(num_in, r)
56+
self.sse = ChannelSqSpatialEx(num_in)
57+
58+
def forward(self, x):
59+
return self.cse(x) + self.sse(x)
60+
61+
62+
def Conv1x1(num_in, num_out):
63+
return nn.Conv2d(num_in, num_out, kernel_size=1, bias=False)

robosat/unet.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from torchvision.models import resnet50
1616

17+
from robosat.scse import SpatialChannelSqChannelEx
18+
1719

1820
class ConvRelu(nn.Module):
1921
"""3x3 convolution followed by ReLU activation building block.

0 commit comments

Comments
 (0)