Skip to content

Commit 1b6fddc

Browse files
committed
support stgcn trainer when classes < 5
1 parent f98c757 commit 1b6fddc

File tree

5 files changed

+34
-11
lines changed

5 files changed

+34
-11
lines changed

paddlevideo/metrics/skeleton_metric.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,29 @@ class SkeletonMetric(BaseMetric):
3131
Args:
3232
out_file: str, file to save test results.
3333
"""
34+
3435
def __init__(self,
3536
data_size,
3637
batch_size,
3738
out_file='submission.csv',
38-
log_interval=1):
39+
log_interval=1,
40+
top_k=5):
3941
"""prepare for metrics
4042
"""
4143
super().__init__(data_size, batch_size, log_interval)
4244
self.top1 = []
4345
self.top5 = []
4446
self.values = []
4547
self.out_file = out_file
48+
self.k = top_k
4649

4750
def update(self, batch_id, data, outputs):
4851
"""update metrics during each iter
4952
"""
5053
if len(data) == 2: # data with label
5154
labels = data[1]
5255
top1 = paddle.metric.accuracy(input=outputs, label=labels, k=1)
53-
top5 = paddle.metric.accuracy(input=outputs, label=labels, k=2)
56+
top5 = paddle.metric.accuracy(input=outputs, label=labels, k=self.k)
5457
if self.world_size > 1:
5558
top1 = paddle.distributed.all_reduce(
5659
top1, op=paddle.distributed.ReduceOp.SUM) / self.world_size

paddlevideo/modeling/backbones/stgcn.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def normalize_digraph(A):
6868

6969

7070
class Graph():
71+
7172
def __init__(self,
7273
layout='openpose',
7374
strategy='uniform',
@@ -109,16 +110,15 @@ def get_edge(self, layout):
109110
neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]
110111
self.edge = self_link + neighbor_link
111112
self.center = 21 - 1
112-
elif layout == 'ntu-rgb+d_fall':
113+
elif layout == 'coco_keypoint':
113114
self.num_node = 17
114115
self_link = [(i, i) for i in range(self.num_node)]
115-
neighbor_1base = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7), (6, 8),
116-
(7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14),
117-
(13, 15), (14, 16), (11, 12)]
116+
neighbor_1base = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6),
117+
(5, 7), (6, 8), (7, 9), (8, 10), (5, 11), (6, 12),
118+
(11, 13), (12, 14), (13, 15), (14, 16), (11, 12)]
118119
neighbor_link = [(i, j) for (i, j) in neighbor_1base]
119120
self.edge = self_link + neighbor_link
120121
self.center = 11
121-
122122
else:
123123
raise ValueError("Do Not Exist This Layout.")
124124

@@ -158,6 +158,7 @@ def get_adjacency(self, strategy):
158158

159159

160160
class ConvTemporalGraphical(nn.Layer):
161+
161162
def __init__(self,
162163
in_channels,
163164
out_channels,
@@ -188,6 +189,7 @@ def forward(self, x, A):
188189

189190

190191
class st_gcn_block(nn.Layer):
192+
191193
def __init__(self,
192194
in_channels,
193195
out_channels,
@@ -252,6 +254,7 @@ class STGCN(nn.Layer):
252254
edge_importance_weighting: bool, whether to use edge attention. Default True.
253255
data_bn: bool, whether to use data BatchNorm. Default True.
254256
"""
257+
255258
def __init__(self,
256259
in_channels=2,
257260
edge_importance_weighting=True,

paddlevideo/modeling/framework/recognizers/recognizer_gcn.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,21 @@
2121
class RecognizerGCN(BaseRecognizer):
2222
"""GCN Recognizer model framework.
2323
"""
24+
25+
def __init__(self,
26+
backbone=None,
27+
head=None,
28+
runtime_cfg=None,
29+
if_top5=True):
30+
"""
31+
Args:
32+
backbone (dict): Backbone modules to extract feature.
33+
head (dict): Classification head to process feature.
34+
is_top5 (bool): Whether to display top-5 accuracy during training/validation steps.
35+
"""
36+
super(RecognizerGCN, self).__init__(backbone, head, runtime_cfg)
37+
self.if_top5 = if_top5
38+
2439
def forward_net(self, data):
2540
"""Define how the model is going to run, from input to output.
2641
"""
@@ -36,7 +51,7 @@ def train_step(self, data_batch):
3651

3752
# call forward
3853
cls_score = self.forward_net(data)
39-
loss_metrics = self.head.loss(cls_score, label)
54+
loss_metrics = self.head.loss(cls_score, label, if_top5=self.if_top5)
4055
return loss_metrics
4156

4257
def val_step(self, data_batch):
@@ -47,7 +62,10 @@ def val_step(self, data_batch):
4762

4863
# call forward
4964
cls_score = self.forward_net(data)
50-
loss_metrics = self.head.loss(cls_score, label, valid_mode=True)
65+
loss_metrics = self.head.loss(cls_score,
66+
label,
67+
valid_mode=True,
68+
if_top5=self.if_top5)
5169
return loss_metrics
5270

5371
def test_step(self, data_batch):

paddlevideo/modeling/heads/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def label_smooth_loss(self, scores, labels, **kwargs):
155155
def get_acc(self, scores, labels, valid_mode, if_top5=True):
156156
if if_top5:
157157
top1 = paddle.metric.accuracy(input=scores, label=labels, k=1)
158-
top5 = paddle.metric.accuracy(input=scores, label=labels, k=2)
158+
top5 = paddle.metric.accuracy(input=scores, label=labels, k=5)
159159
_, world_size = get_dist_info()
160160
#NOTE(shipping): deal with multi cards validate
161161
if world_size > 1 and valid_mode: #reduce sum when valid

paddlevideo/tasks/test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,4 @@ def test_model(cfg, weights, parallel=True):
8484
else:
8585
outputs = model(data, mode='test')
8686
Metric.update(batch_id, data, outputs)
87-
print(outputs)
8887
Metric.accumulate()

0 commit comments

Comments
 (0)