Skip to content

Commit 25ee526

Browse files
authored
Fix keyerror (#794)
* fix key_error in pruner * add unit test for get_ratios_by_loss
1 parent 6947a1e commit 25ee526

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

paddleslim/prune/sensitive.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,4 +209,5 @@ def get_ratios_by_loss(sensitivities, loss):
209209
_logger.info(losses, ratio, (r1 - r0) / (l1 - l0), i)
210210

211211
break
212+
if i == 0: ratios[param] = 0.0
212213
return ratios

tests/test_sensitivity.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import paddle
1919
import paddle.fluid as fluid
2020
from static_case import StaticCase
21-
from paddleslim.prune import sensitivity, merge_sensitive, load_sensitivities
21+
from paddleslim.prune import sensitivity, merge_sensitive, load_sensitivities, get_ratios_by_loss
2222
from layers import conv_bn_layer
2323

2424

@@ -107,9 +107,16 @@ def eval_func_for_args(args):
107107
sensitivities_file="./sensitivities_file_2",
108108
pruned_ratios=[0.1, 0.2, 0.3, 0.4])
109109
self.assertTrue(params_sens == origin_sens)
110-
111110
self.assertTrue(sens == origin_sens)
112111

112+
loss = 0.0
113+
ratios = get_ratios_by_loss(sens, loss)
114+
self.assertTrue(len(ratios) == len(sens))
115+
116+
loss = min(list(sens.get('conv4_weights').values())) - 0.01
117+
ratios = get_ratios_by_loss(sens, loss)
118+
self.assertTrue(len(ratios) == len(sens))
119+
113120

114121
if __name__ == '__main__':
115122
unittest.main()

0 commit comments

Comments
 (0)