Skip to content

Commit aed02a2

Browse files
user3984littletomatodonkey
authored andcommitted
update pefd
1 parent 7ee8471 commit aed02a2

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

ppcls/loss/pefdloss.py

+17-8
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class Regressor(nn.Layer):
2424

2525
def __init__(self, dim_in=1024, dim_out=1024):
2626
super(Regressor, self).__init__()
27-
self.conv = nn.Conv2D(dim_in, dim_out, 1)
27+
self.conv = nn.Linear(dim_in, dim_out)
2828

2929
def forward(self, x):
3030
x = self.conv(x)
@@ -38,29 +38,38 @@ class PEFDLoss(nn.Layer):
3838
Code reference: https://github.com/chenyd7/PEFD
3939
"""
4040

41-
def __init__(self, student_channel, teacher_channel, num_projectors=3):
41+
def __init__(self,
42+
student_channel,
43+
teacher_channel,
44+
num_projectors=3,
45+
mode="flatten"):
4246
super().__init__()
4347

4448
if num_projectors <= 0:
4549
raise ValueError("Number of projectors must be greater than 0.")
4650

51+
if mode not in ["flatten", "gap"]:
52+
raise ValueError("Mode must be \"flatten\" or \"gap\".")
53+
54+
self.mode = mode
4755
self.projectors = nn.LayerList()
4856

4957
for _ in range(num_projectors):
5058
self.projectors.append(Regressor(student_channel, teacher_channel))
5159

5260
def forward(self, student_feature, teacher_feature):
53-
if student_feature.shape[2:] != teacher_feature.shape[2:]:
54-
raise ValueError(
55-
"Student feature must have the same H and W as teacher feature."
56-
)
61+
if self.mode == "gap":
62+
student_feature = F.adaptive_avg_pool2d(student_feature, (1, 1))
63+
teacher_feature = F.adaptive_avg_pool2d(teacher_feature, (1, 1))
64+
65+
student_feature = student_feature.flatten(1)
66+
f_t = teacher_feature.flatten(1)
5767

5868
q = len(self.projectors)
5969
f_s = 0.0
6070
for i in range(q):
6171
f_s += self.projectors[i](student_feature)
62-
f_s = (f_s / q).flatten(1)
63-
f_t = teacher_feature.flatten(1)
72+
f_s = f_s / q
6473

6574
# inner product (normalize first and inner product)
6675
normft = f_t.pow(2).sum(1, keepdim=True).pow(1. / 2)

0 commit comments

Comments
 (0)