
Commit e2c6bad

Bai Yifan authored and wanghaoshuang committed
Support dispensable student_loss in PaddleSlim distillation (#19824)
* support_dispensable_student_loss, test=develop
* add distillation test, test=develop
* fix distillation test non convergence problem, test=develop
* fix test_distillation fail problem, test=develop
1 parent 3f87464 commit e2c6bad

File tree: 7 files changed (+32 −17 lines)

python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py

Lines changed: 2 additions & 1 deletion
@@ -64,7 +64,8 @@ def _create_distillation_graph(self, context):
             var.stop_gradient = True
         graph = context.train_graph.clone()
         graph.merge(teacher)
-        graph.out_nodes['student_loss'] = graph.out_nodes['loss']
+        if 'loss' in graph.out_nodes:
+            graph.out_nodes['student_loss'] = graph.out_nodes['loss']
 
         # step 2
         for distiller in self.distillers:
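The guard above makes the student loss dispensable: 'student_loss' is only
exported when the student program actually registers a 'loss' output node.
A minimal standalone sketch of the pattern, using a plain dict as a
hypothetical stand-in for GraphWrapper.out_nodes:

    # out_nodes maps output names to variable names; here there is no 'loss'.
    out_nodes = {'cost': 'mean_0.tmp_0'}

    # Mirrors the patched logic: alias 'student_loss' only when 'loss' exists.
    if 'loss' in out_nodes:
        out_nodes['student_loss'] = out_nodes['loss']

    print(out_nodes)  # {'cost': 'mean_0.tmp_0'} -- unchanged, and no KeyError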

python/paddle/fluid/contrib/slim/distillation/distiller.py

Lines changed: 13 additions & 7 deletions
@@ -88,13 +88,15 @@ def apply(self, graph):
                 layers.square(student_feature_map - teacher_feature_map))
 
         distillation_loss = l2loss * self.distillation_loss_weight
-        student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
+        student_loss = 0
+        if 'loss' in ret_graph.out_nodes:
+            student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
         loss = distillation_loss + student_loss
 
+        ret_graph.out_nodes['loss'] = loss.name
         ret_graph.out_nodes[
             'l2loss_' + self.student_feature_map + "_" +
             self.teacher_feature_map] = distillation_loss.name
-        ret_graph.out_nodes['loss'] = loss.name
         return ret_graph

@@ -176,12 +178,14 @@ def apply(self, graph):
             losses.append(l2_loss)
         distillation_loss = layers.sum(
             losses) * self.distillation_loss_weight
-        student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
+        student_loss = 0
+        if 'loss' in ret_graph.out_nodes:
+            student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
         loss = distillation_loss + student_loss
 
+        ret_graph.out_nodes['loss'] = loss.name
         ret_graph.out_nodes[
             'fsp_distillation_loss'] = distillation_loss.name
-        ret_graph.out_nodes['loss'] = loss.name
         return ret_graph
 
     def _fsp_matrix(self, fea_map_0, fea_map_1):

@@ -261,16 +265,18 @@ def apply(self, graph):
         student_feature_map = ret_graph.var(self.student_feature_map)._var
         teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var
         s_fea = student_feature_map / self.student_temperature
-        t_fea = teacher_feature_map / self.distillation_loss_weight
+        t_fea = teacher_feature_map / self.teacher_temperature
         t_fea.stop_gradient = True
         ce_loss = layers.softmax_with_cross_entropy(
             s_fea, t_fea, soft_label=True)
         distillation_loss = ce_loss * self.distillation_loss_weight
-        student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
+        student_loss = 0
+        if 'loss' in ret_graph.out_nodes:
+            student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
         loss = distillation_loss + student_loss
 
+        ret_graph.out_nodes['loss'] = loss.name
         ret_graph.out_nodes[
             'soft_label_loss_' + self.student_feature_map + "_" +
             self.teacher_feature_map] = distillation_loss.name
-        ret_graph.out_nodes['loss'] = loss.name
         return ret_graph
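All three distillers (L2, FSP, soft-label) now share the same pattern:
student_loss defaults to 0 so the summation still works when the graph
exposes no 'loss' node, and the combined loss is written back to
out_nodes['loss'] before the per-distiller loss is registered. The third
hunk also fixes a genuine bug: the teacher logits were previously divided
by distillation_loss_weight instead of teacher_temperature. A small numeric
sketch of the combination step, with made-up names and values:

    def combine_losses(distillation_loss, out_nodes, vars_by_name):
        """Hypothetical mirror of the dispensable-student-loss pattern."""
        student_loss = 0  # neutral default when no student loss is registered
        if 'loss' in out_nodes:
            student_loss = vars_by_name[out_nodes['loss']]
        return distillation_loss + student_loss

    # Teacher-only distillation: no student loss registered.
    print(combine_losses(0.7, {}, {}))                       # 0.7
    # Joint training: the student loss is found under its registered name.
    print(combine_losses(0.7, {'loss': 'ce'}, {'ce': 1.3}))  # 2.0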

python/paddle/fluid/contrib/slim/graph/graph_wrapper.py

Lines changed: 2 additions & 0 deletions
@@ -410,6 +410,8 @@ def get_optimize_graph(self, optimizer, place, scope, no_grad_var_names=[]):
             target_name = graph.out_nodes['loss']
         elif 'cost' in graph.out_nodes:
             target_name = graph.out_nodes['cost']
+        else:
+            return None
         target = graph.var(target_name)._var
         # The learning rate variable may be created in other program.
         # Update information in optimizer to make
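get_optimize_graph now returns None instead of failing when neither 'loss'
nor 'cost' is registered, so callers should check the result. A hypothetical
standalone mirror of the target-selection logic:

    def select_target_name(out_nodes):
        if 'loss' in out_nodes:
            return out_nodes['loss']
        elif 'cost' in out_nodes:
            return out_nodes['cost']
        return None  # new behavior: signal "nothing to optimize" instead of raising

    assert select_target_name({'loss': 'l', 'cost': 'c'}) == 'l'  # 'loss' wins
    assert select_target_name({'cost': 'c'}) == 'c'
    assert select_target_name({}) is None                         # dispensable case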

python/paddle/fluid/contrib/slim/tests/CMakeLists.txt

Lines changed: 0 additions & 5 deletions
@@ -32,11 +32,6 @@ function(inference_qat_int8_test target model_dir data_dir test_script use_mkldn
         --acc_diff_threshold 0.1)
 endfunction()
 
-# NOTE: TODOOOOOOOOOOO
-# temporarily disable test_distillation_strategy since it always failed on a specified machine with 4 GPUs
-# Need to figure out the root cause and then add it back
-list(REMOVE_ITEM TEST_OPS test_distillation_strategy)
-
 if(WIN32)
     list(REMOVE_ITEM TEST_OPS test_light_nas)
 endif()

python/paddle/fluid/contrib/slim/tests/distillation/compress.yaml

Lines changed: 4 additions & 4 deletions
@@ -30,15 +30,15 @@ distillers:
         distillation_loss_weight: 1
     l2_distiller:
         class: 'L2Distiller'
-        teacher_feature_map: 'teacher.tmp_2'
-        student_feature_map: 'student.tmp_2'
+        teacher_feature_map: 'teacher.tmp_1'
+        student_feature_map: 'student.tmp_1'
         distillation_loss_weight: 1
     soft_label_distiller:
         class: 'SoftLabelDistiller'
         student_temperature: 1.0
         teacher_temperature: 1.0
-        teacher_feature_map: 'teacher.tmp_1'
-        student_feature_map: 'student.tmp_1'
+        teacher_feature_map: 'teacher.tmp_2'
+        student_feature_map: 'student.tmp_2'
         distillation_loss_weight: 0.001
 strategies:
     distillation_strategy:
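The config change swaps which intermediate tensors feed each distiller: the
L2 distiller now reads the tmp_1 pair and the soft-label distiller the tmp_2
pair, likely the correction behind the "non convergence" fix mentioned in
the commit message. A quick sanity check of the corrected mapping with
PyYAML (assumed available; the snippet is trimmed to the two distillers):

    import textwrap
    import yaml  # PyYAML; assumed available

    snippet = textwrap.dedent("""\
        distillers:
            l2_distiller:
                class: 'L2Distiller'
                teacher_feature_map: 'teacher.tmp_1'
                student_feature_map: 'student.tmp_1'
            soft_label_distiller:
                class: 'SoftLabelDistiller'
                teacher_feature_map: 'teacher.tmp_2'
                student_feature_map: 'student.tmp_2'
    """)

    cfg = yaml.safe_load(snippet)['distillers']
    # After the fix, the two distillers read from different tensor pairs.
    assert cfg['l2_distiller']['student_feature_map'] != \
           cfg['soft_label_distiller']['student_feature_map']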

python/paddle/fluid/contrib/slim/tests/test_graph_wrapper.py

Lines changed: 11 additions & 0 deletions
@@ -139,6 +139,17 @@ def test_get_optimize_graph(self):
             feed={'image': image,
                   'label': label})
 
+    def test_get_optimize_graph_without_loss(self):
+        self.build_program()
+        self.eval_graph.out_nodes = {}
+        place = fluid.CPUPlace()
+        if fluid.core.is_compiled_with_cuda():
+            place = fluid.CUDAPlace(0)
+        opt = fluid.optimizer.SGD(learning_rate=0.001)
+        train_graph = self.eval_graph.get_optimize_graph(
+            opt, place, self.scope, no_grad_var_names=['image'])
+        self.assertEquals(train_graph, None)
+
     def test_flops(self):
         self.build_program()
         self.assertEquals(self.train_graph.flops(), 354624)
