Commit 0cfccd7

add weighted soft labels loss function
1 parent 104428d commit 0cfccd7

5 files changed: +317 -1 lines changed

docs/zh_CN/advanced_tutorials/knowledge_distillation.md

Lines changed: 69 additions & 1 deletion
@@ -16,6 +16,7 @@
         - [1.2.5 DKD](#1.2.5)
         - [1.2.6 DIST](#1.2.6)
         - [1.2.7 MGD](#1.2.7)
+        - [1.2.8 WSL](#1.2.8)
 - [2. Usage](#2)
     - [2.1 Environment setup](#2.1)
     - [2.2 Data preparation](#2.2)
@@ -399,7 +400,7 @@ DKD decouples the KD loss commonly used in distillation into Target Class Knowledge Dis
 | Strategy | Backbone | Config file | Top-1 acc | Download link |
 | --- | --- | --- | --- | --- |
 | baseline | ResNet18 | [ResNet18.yaml](../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
-| AFD | ResNet18 | [resnet34_distill_resnet18_dkd.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml) | 72.59%(**+1.79%**) | - |
+| DKD | ResNet18 | [resnet34_distill_resnet18_dkd.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml) | 72.59%(**+1.79%**) | - |


 ##### 1.2.5.2 DKD configuration
@@ -583,6 +584,73 @@ Loss:
         weight: 1.0
 ```

+<a name='1.2.8'></a>
+
+#### 1.2.8 WSL
+
+##### 1.2.8.1 Introduction to the WSL algorithm
+
+Paper information:
+
+
+> [Rethinking Soft Labels For Knowledge Distillation: A Bias-variance Tradeoff Perspective](https://arxiv.org/abs/2102.00650)
+>
+> Helong Zhou, Liangchen Song, Jiajie Chen, Ye Zhou, Guoli Wang, Junsong Yuan, Qian Zhang
+>
+> ICLR, 2021
+
+The WSL (Weighted Soft Labels) loss weights each sample's KD loss by the ratio of the student's CE loss to the teacher's CE loss with respect to the ground-truth label. If the student predicts a sample better than the teacher does, that sample receives a smaller weight. The method is simple and effective: per-sample weights adapt automatically, which improves distillation accuracy.
+
+Results on the public ImageNet1k dataset are shown below.
+
+| Strategy | Backbone | Config file | Top-1 acc | Download link |
+| --- | --- | --- | --- | --- |
+| baseline | ResNet18 | [ResNet18.yaml](../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
+| WSL | ResNet18 | [resnet34_distill_resnet18_wsl.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_wsl.yaml) | 72.23%(**+1.43%**) | - |
+
+
+##### 1.2.8.2 WSL configuration
+
+The WSL configuration is shown below. The `Arch` field must define both the student and the teacher model; the teacher's parameters are frozen, and it must load pretrained weights. The `Loss` field must define `DistillationGTCELoss` (the CE loss between the student and the ground-truth labels) and `DistillationWSLLoss` (the WSL loss between the student and the teacher) as the training losses.
+
+
+```yaml
+# model architecture
+Arch:
+  name: "DistillationModel"
+  # if not null, its lengths should be same as models
+  pretrained_list:
+  # if not null, its lengths should be same as models
+  freeze_params_list:
+  - True
+  - False
+  models:
+    - Teacher:
+        name: ResNet34
+        pretrained: True
+
+    - Student:
+        name: ResNet18
+        pretrained: False
+
+  infer_model_name: "Student"
+
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - DistillationGTCELoss:
+        weight: 1.0
+        model_names: ["Student"]
+    - DistillationWSLLoss:
+        weight: 2.5
+        model_name_pairs: [["Student", "Teacher"]]
+        temperature: 2
+  Eval:
+    - CELoss:
+        weight: 1.0
+```
+
 <a name="2"></a>

 ## 2. Model training, evaluation, and prediction
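
Read together with the `WSLLoss` implementation added in this commit (`ppcls/loss/wslloss.py`), the weighting described in section 1.2.8.1 can be written out roughly as follows. This is a sketch of what the committed code computes, not a formula quoted from the paper. The symbols are notation introduced here: $z^s_i$ and $z^t_i$ are the student and teacher logits for sample $i$, $y_i$ is its label, $T$ is `temperature`, $N$ is the batch size, and $\epsilon = 10^{-7}$ is the constant in the code.

```latex
\begin{aligned}
\mathrm{CE}^s_i &= -\log \operatorname{softmax}(z^s_i)_{y_i}, \qquad
\mathrm{CE}^t_i = -\log \operatorname{softmax}(z^t_i)_{y_i}, \\
w_i &= 1 - \exp\!\left(-\frac{\mathrm{CE}^s_i}{\mathrm{CE}^t_i + \epsilon}\right), \\
\mathrm{KD}_i &= -\sum_{c} \operatorname{softmax}(z^t_i / T)_c \,
                 \log \operatorname{softmax}(z^s_i / T)_c, \\
\mathcal{L}_{\mathrm{WSL}} &= T^2 \cdot \frac{1}{N} \sum_{i=1}^{N} w_i \, \mathrm{KD}_i .
\end{aligned}
```

When the student's CE is much smaller than the teacher's, the ratio is close to 0 and $w_i$ approaches 0; when it is much larger, $w_i$ approaches 1. In the code both CE terms are computed on detached logits, so $w_i$ acts as a constant weight during backpropagation.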
Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: ./output/r34_r18_wsl
+  device: "gpu"
+  save_interval: 1
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 100
+  print_batch_step: 10
+  use_visualdl: False
+  # used for static mode and model export
+  image_shape: [3, 224, 224]
+  save_inference_dir: "./inference"
+
+# model architecture
+Arch:
+  name: "DistillationModel"
+  # if not null, its lengths should be same as models
+  pretrained_list:
+  # if not null, its lengths should be same as models
+  freeze_params_list:
+  - True
+  - False
+  models:
+    - Teacher:
+        name: ResNet34
+        pretrained: True
+
+    - Student:
+        name: ResNet18
+        pretrained: False
+
+  infer_model_name: "Student"
+
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - DistillationGTCELoss:
+        weight: 1.0
+        model_names: ["Student"]
+    - DistillationWSLLoss:
+        weight: 2.5
+        model_name_pairs: [["Student", "Teacher"]]
+        temperature: 2
+  Eval:
+    - CELoss:
+        weight: 1.0
+
+
+Optimizer:
+  name: Momentum
+  momentum: 0.9
+  weight_decay: 1e-4
+  lr:
+    name: MultiStepDecay
+    learning_rate: 0.1
+    milestones: [30, 60, 90]
+    step_each_epoch: 1
+    gamma: 0.1
+
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: ImageNetDataset
+      image_root: "./dataset/ILSVRC2012/"
+      cls_label_path: "./dataset/ILSVRC2012/train_list.txt"
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - NormalizeImage:
+            scale: 0.00392157
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      drop_last: False
+      shuffle: True
+    loader:
+      num_workers: 8
+      use_shared_memory: True
+
+  Eval:
+    dataset:
+      name: ImageNetDataset
+      image_root: "./dataset/ILSVRC2012/"
+      cls_label_path: "./dataset/ILSVRC2012/val_list.txt"
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - ResizeImage:
+            resize_short: 256
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 0.00392157
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      drop_last: False
+      shuffle: False
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+Infer:
+  infer_imgs: "docs/images/inference_deployment/whl_demo.jpg"
+  batch_size: 10
+  transforms:
+    - DecodeImage:
+        to_rgb: True
+        channel_first: False
+    - ResizeImage:
+        resize_short: 256
+    - CropImage:
+        size: 224
+    - NormalizeImage:
+        scale: 1.0/255.0
+        mean: [0.485, 0.456, 0.406]
+        std: [0.229, 0.224, 0.225]
+        order: ''
+    - ToCHWImage:
+  PostProcess:
+    name: Topk
+    topk: 5
+    class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"
+
+Metric:
+  Train:
+    - DistillationTopkAcc:
+        model_key: "Student"
+        topk: [1, 5]
+  Eval:
+    - DistillationTopkAcc:
+        model_key: "Student"
+        topk: [1, 5]
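
The `Optimizer.lr` block above configures a multi-step decay. As a rough illustration of the schedule those numbers imply, here is a minimal sketch in plain Python. It assumes that milestones are counted in epochs (0-based) and that the decay takes effect at the start of each milestone epoch; `multistep_lr` is a hypothetical helper written for this note, not a PaddleClas function.

```python
def multistep_lr(epoch, base_lr=0.1, milestones=(30, 60, 90), gamma=0.1):
    """Learning rate at a 0-based epoch under a multi-step decay schedule.

    Defaults mirror the config above (learning_rate, milestones, gamma).
    """
    # Each milestone that has been reached multiplies the rate by gamma once.
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed


# Print the rate just before and just after each milestone.
for epoch in (0, 29, 30, 59, 60, 89, 90, 99):
    print(f"epoch {epoch:3d}: lr = {multistep_lr(epoch):.6f}")
```

Under these assumptions the 100-epoch run spends 30 epochs at each of the first three rates and the last 10 epochs at the smallest one.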

ppcls/loss/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
 from .distillationloss import DistillationRKDLoss
 from .distillationloss import DistillationKLDivLoss
 from .distillationloss import DistillationDKDLoss
+from .distillationloss import DistillationWSLLoss
 from .distillationloss import DistillationMultiLabelLoss
 from .distillationloss import DistillationDISTLoss
 from .distillationloss import DistillationPairLoss

ppcls/loss/distillationloss.py

Lines changed: 29 additions & 0 deletions
@@ -22,6 +22,7 @@
 from .rkdloss import RKdAngle, RkdDistance
 from .kldivloss import KLDivLoss
 from .dkdloss import DKDLoss
+from .wslloss import WSLLoss
 from .dist_loss import DISTLoss
 from .multilabelloss import MultiLabelLoss
 from .mgd_loss import MGDLoss
@@ -262,6 +263,34 @@ def forward(self, predicts, batch):
         return loss_dict


+class DistillationWSLLoss(WSLLoss):
+    """
+    DistillationWSLLoss
+    """
+
+    def __init__(self,
+                 model_name_pairs=[],
+                 key=None,
+                 temperature=2.0,
+                 name="wsl_loss"):
+        super().__init__(temperature)
+        self.model_name_pairs = model_name_pairs
+        self.key = key
+        self.name = name
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx, pair in enumerate(self.model_name_pairs):
+            out1 = predicts[pair[0]]
+            out2 = predicts[pair[1]]
+            if self.key is not None:
+                out1 = out1[self.key]
+                out2 = out2[self.key]
+            loss = super().forward(out1, out2, batch)
+            loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss
+        return loss_dict
+
+
 class DistillationMultiLabelLoss(MultiLabelLoss):
     """
     DistillationMultiLabelLoss
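
To see how `DistillationWSLLoss.forward` routes model outputs and names its results, the pair loop can be reproduced without Paddle. The sketch below is only an illustration: `StubPairLoss` and `DistillationStubLoss` are hypothetical classes, and the stub computes a mean absolute difference rather than the WSL loss. Only the handling of `model_name_pairs`, `key` and `name` mirrors the diff above.

```python
class StubPairLoss:
    """Hypothetical stand-in for WSLLoss: maps two outputs to one scalar."""

    def forward(self, out_student, out_teacher, target=None):
        # Placeholder metric (mean absolute difference), NOT the WSL formula.
        diffs = [abs(s - t) for s, t in zip(out_student, out_teacher)]
        return sum(diffs) / len(diffs)


class DistillationStubLoss(StubPairLoss):
    """Mirrors the pair-routing logic of DistillationWSLLoss."""

    def __init__(self, model_name_pairs=None, key=None, name="wsl_loss"):
        self.model_name_pairs = model_name_pairs or []
        self.key = key
        self.name = name

    def forward(self, predicts, batch):
        loss_dict = {}
        for first, second in self.model_name_pairs:
            out1, out2 = predicts[first], predicts[second]
            if self.key is not None:
                # Each model returned a dict of named outputs; select one.
                out1, out2 = out1[self.key], out2[self.key]
            loss = super().forward(out1, out2, batch)
            # One entry per pair, keyed as "<name>_<first>_<second>".
            loss_dict[f"{self.name}_{first}_{second}"] = loss
        return loss_dict


# Toy outputs keyed by the model names used in the config.
predicts = {"Student": [0.2, 0.8], "Teacher": [0.1, 0.9]}
loss_fn = DistillationStubLoss(model_name_pairs=[["Student", "Teacher"]])
result = loss_fn.forward(predicts, batch=[1])
print(result)
```

With the config from this commit, the single pair `["Student", "Teacher"]` therefore yields one loss entry whose key is `wsl_loss_Student_Teacher`.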

ppcls/loss/wslloss.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class WSLLoss(nn.Layer):
+    """
+    Weighted Soft Labels Loss
+    paper: https://arxiv.org/pdf/2102.00650.pdf
+    code reference: https://github.com/bellymonster/Weighted-Soft-Label-Distillation
+    """
+
+    def __init__(self, temperature=2.0, use_target_as_gt=False):
+        super().__init__()
+        self.temperature = temperature
+        self.use_target_as_gt = use_target_as_gt
+
+    def forward(self, logits_student, logits_teacher, target=None):
+        """Compute weighted soft labels loss.
+        Args:
+            logits_student: student's logits with shape (batch_size, num_classes)
+            logits_teacher: teacher's logits with shape (batch_size, num_classes)
+            target: ground truth labels with shape (batch_size)
+        """
+        if target is None or self.use_target_as_gt:
+            target = logits_teacher.argmax(axis=-1)
+
+        target = F.one_hot(
+            target.reshape([-1]), num_classes=logits_student[0].shape[0])
+
+        s_input_for_softmax = logits_student / self.temperature
+        t_input_for_softmax = logits_teacher / self.temperature
+
+        ce_loss_s = -paddle.sum(target *
+                                F.log_softmax(logits_student.detach()),
+                                axis=1)
+        ce_loss_t = -paddle.sum(target *
+                                F.log_softmax(logits_teacher.detach()),
+                                axis=1)
+
+        ratio = ce_loss_s / (ce_loss_t + 1e-7)
+        ratio = paddle.maximum(ratio, paddle.zeros_like(ratio))
+
+        kd_loss = -paddle.sum(F.softmax(t_input_for_softmax) *
+                              F.log_softmax(s_input_for_softmax),
+                              axis=1)
+        weight = 1 - paddle.exp(-ratio)
+
+        weighted_kd_loss = (self.temperature**2) * paddle.mean(kd_loss *
+                                                               weight)
+
+        return weighted_kd_loss
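
For checking the numbers by hand, the computation in `WSLLoss.forward` can be mirrored in plain Python. The following is a minimal reference sketch, not part of the commit. It works on lists of logit rows instead of tensors, it has no autograd (so the `.detach()` calls have no counterpart), and the helper names `log_softmax`, `wsl_weight` and `wsl_loss` are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def wsl_weight(ce_student, ce_teacher, eps=1e-7):
    """Per-sample weight 1 - exp(-ratio), with the ratio clipped at zero."""
    ratio = max(ce_student / (ce_teacher + eps), 0.0)
    return 1.0 - math.exp(-ratio)


def wsl_loss(student_logits, teacher_logits, targets, temperature=2.0):
    """Weighted soft-label KD loss averaged over a batch of logit rows."""
    total = 0.0
    for z_s, z_t, y in zip(student_logits, teacher_logits, targets):
        # Hard-label cross entropy of each model on the untempered logits.
        ce_s = -log_softmax(z_s)[y]
        ce_t = -log_softmax(z_t)[y]
        # Soft-label cross entropy between the tempered distributions.
        log_p_s = log_softmax([x / temperature for x in z_s])
        p_t = [math.exp(v) for v in log_softmax([x / temperature for x in z_t])]
        kd = -sum(pt * lps for pt, lps in zip(p_t, log_p_s))
        total += wsl_weight(ce_s, ce_t) * kd
    return temperature ** 2 * total / len(targets)


# Toy batch: two samples, three classes, labels 1 and 0.
student = [[1.0, 2.0, 0.5], [3.0, 0.1, 0.2]]
teacher = [[0.5, 3.0, 0.2], [1.0, 0.9, 0.8]]
print(wsl_loss(student, teacher, targets=[1, 0]))
```

In the second toy sample the student is more confident in the correct class than the teacher, so its CE ratio is below 1 and its KD term is down-weighted relative to the first sample.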
