Commit 3c90aad

Adds feature pyramid attention (FPA) module, resolves #167
1 parent 27833da

2 files changed (+64, -1 lines)

robosat/fpa.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+"""Feature Pyramid Attention blocks
+
+See:
+- https://arxiv.org/abs/1805.10180 - Pyramid Attention Network for Semantic Segmentation
+
+"""
+
+import torch.nn as nn
+
+
+class FeaturePyramidAttention(nn.Module):
+    """Feature Pyramid Attention (FPA) block
+    See https://arxiv.org/abs/1805.10180, Figure 3 (b)
+    """
+
+    def __init__(self, num_in, num_out):
+        super().__init__()
+
+        # no batch norm for tensors of shape NxCx1x1
+        self.top1x1 = nn.Sequential(nn.Conv2d(num_in, num_out, 1, bias=False), nn.ReLU(inplace=True))
+
+        self.mid1x1 = ConvBnRelu(num_in, num_out, 1)
+
+        self.bot5x5 = ConvBnRelu(num_in, num_in, 5, stride=2, padding=2)
+        self.bot3x3 = ConvBnRelu(num_in, num_in, 3, stride=2, padding=1)
+
+        self.lat5x5 = ConvBnRelu(num_in, num_out, 5, stride=1, padding=2)
+        self.lat3x3 = ConvBnRelu(num_in, num_out, 3, stride=1, padding=1)
+
+    def forward(self, x):
+        assert x.size()[-1] % 8 == 0 and x.size()[-2] % 8 == 0, "size has to be divisible by 8 for fpa"
+
+        # global pooling top pathway
+        top = self.top1x1(nn.functional.adaptive_avg_pool2d(x, 1))
+        top = nn.functional.interpolate(top, size=x.size()[-2:], mode="bilinear")
+
+        # conv middle pathway
+        mid = self.mid1x1(x)
+
+        # multi-scale bottom and lateral pathways
+        bot0 = self.bot5x5(x)
+        bot1 = self.bot3x3(bot0)
+
+        lat0 = self.lat5x5(bot0)
+        lat1 = self.lat3x3(bot1)
+
+        # upward accumulation pathways
+        up = lat0 + nn.functional.interpolate(lat1, scale_factor=2, mode="bilinear")
+        up = nn.functional.interpolate(up, scale_factor=2, mode="bilinear")
+
+        return up * mid + top
+
+
+def ConvBnRelu(num_in, num_out, kernel_size, stride=1, padding=0, bias=False):
+    return nn.Sequential(nn.Conv2d(num_in, num_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
+                         nn.BatchNorm2d(num_out),
+                         nn.ReLU(inplace=True))
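
A minimal smoke test for the block (not part of this commit; the 2048 x 16 x 16 input shape is an assumption, chosen to match the ResNet-50 bottleneck it is wired to below):

    import torch

    from robosat.fpa import FeaturePyramidAttention

    # Hypothetical check: both spatial sides (16) are divisible by 8, so the
    # assert in forward() passes; FPA keeps the spatial resolution and maps
    # num_in -> num_out channels.
    x = torch.rand(1, 2048, 16, 16)
    fpa = FeaturePyramidAttention(2048, 2048)
    out = fpa(x)
    assert out.size() == (1, 2048, 16, 16)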

robosat/unet.py

Lines changed: 7 additions & 1 deletion
@@ -14,6 +14,8 @@
 
 from torchvision.models import resnet50
 
+from robosat.fpa import FeaturePyramidAttention
+
 
 class ConvRelu(nn.Module):
     """3x3 convolution followed by ReLU activation building block.
@@ -96,6 +98,8 @@ def __init__(self, num_classes, num_filters=32, pretrained=True):
         # Access resnet directly in forward pass; do not store refs here due to
         # https://github.com/pytorch/pytorch/issues/8392
 
+        self.fpa = FeaturePyramidAttention(2048, 2048)
+
         self.center = DecoderBlock(2048, num_filters * 8)
 
         self.dec0 = DecoderBlock(2048 + num_filters * 8, num_filters * 8)
@@ -129,7 +133,9 @@ def forward(self, x):
         enc3 = self.resnet.layer3(enc2)
         enc4 = self.resnet.layer4(enc3)
 
-        center = self.center(nn.functional.max_pool2d(enc4, kernel_size=2, stride=2))
+        fpa = self.fpa(enc4)
+
+        center = self.center(nn.functional.max_pool2d(fpa, kernel_size=2, stride=2))
 
         dec0 = self.dec0(torch.cat([enc4, center], dim=1))
         dec1 = self.dec1(torch.cat([enc3, dec0], dim=1))
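
A small shape-bookkeeping sketch (the 512 x 512 tile size is an assumption, not stated in this commit) showing why the divisibility assert in FeaturePyramidAttention holds at this point in the network:

    # ResNet-50's layer4 output stride is 32, so requiring the feature map
    # sides to divide by 8 amounts to requiring tile sides divisible by 256.
    tile = 512
    enc4_side = tile // 32        # 16; enc4 has shape N x 2048 x 16 x 16
    assert enc4_side % 8 == 0     # the assert in FeaturePyramidAttention passes
    center_side = enc4_side // 2  # 8, after max_pool2d(kernel_size=2, stride=2)

Since FPA preserves both the channel count (2048 -> 2048) and the spatial size of enc4, the center block and the unmodified enc4 skip connection into dec0 are unaffected by this change.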
