| 1 | +"""Feature Pyramid Attention blocks |
| 2 | +
|
| 3 | +See: |
| 4 | +- https://arxiv.org/abs/1805.10180 - Pyramid Attention Network for Semantic Segmentation |
| 5 | +
|
| 6 | +""" |
| 7 | + |
| 8 | +import torch.nn as nn |
| 9 | + |
| 10 | + |
class FeaturePyramidAttention(nn.Module):
    """Feature Pyramid Attention (FPA) block
    See https://arxiv.org/abs/1805.10180 Figure 3 b
    """

    def __init__(self, num_in, num_out):
        super().__init__()

        # global pooling pathway; no batch norm for tensors of shape NxCx1x1
        self.top1x1 = nn.Sequential(nn.Conv2d(num_in, num_out, 1, bias=False), nn.ReLU(inplace=True))

        # middle 1x1 pathway, multiplied element-wise by the pyramid output in forward()
        self.mid1x1 = ConvBnRelu(num_in, num_out, 1)

        # downsampling pyramid: 5x5 and 3x3 convolutions, each halving the spatial size
        self.bot5x5 = ConvBnRelu(num_in, num_in, 5, stride=2, padding=2)
        self.bot3x3 = ConvBnRelu(num_in, num_in, 3, stride=2, padding=1)

        # lateral convolutions projecting each pyramid level to num_out channels
        self.lat5x5 = ConvBnRelu(num_in, num_out, 5, stride=1, padding=2)
        self.lat3x3 = ConvBnRelu(num_in, num_out, 3, stride=1, padding=1)

    def forward(self, x):
        assert x.size()[-1] % 8 == 0 and x.size()[-2] % 8 == 0, "size has to be divisible by 8 for fpa"

        # global pooling top pathway
        top = self.top1x1(nn.functional.adaptive_avg_pool2d(x, 1))
        top = nn.functional.interpolate(top, size=x.size()[-2:], mode="bilinear", align_corners=False)

        # conv middle pathway
        mid = self.mid1x1(x)

        # multi-scale bottom and lateral pathways
        bot0 = self.bot5x5(x)
        bot1 = self.bot3x3(bot0)

        lat0 = self.lat5x5(bot0)
        lat1 = self.lat3x3(bot1)

        # upward accumulation pathways
        up = lat0 + nn.functional.interpolate(lat1, scale_factor=2, mode="bilinear", align_corners=False)
        up = nn.functional.interpolate(up, scale_factor=2, mode="bilinear", align_corners=False)

        # the accumulated pyramid re-weights the 1x1 features, then the global context is added
        return up * mid + top


def ConvBnRelu(num_in, num_out, kernel_size, stride=1, padding=0, bias=False):
    return nn.Sequential(nn.Conv2d(num_in, num_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
                         nn.BatchNorm2d(num_out),
                         nn.ReLU(inplace=True))
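

# Minimal usage sketch (not part of the original module): assumes PyTorch is installed.
# The channel counts (256 -> 128) and the 32x32 input size are arbitrary example values;
# the spatial size only has to satisfy the divisibility assert in forward().
if __name__ == "__main__":
    import torch

    fpa = FeaturePyramidAttention(num_in=256, num_out=128)
    x = torch.randn(1, 256, 32, 32)
    y = fpa(x)
    # FPA keeps the spatial resolution and maps num_in to num_out channels
    print(y.shape)  # torch.Size([1, 128, 32, 32])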