Skip to content

Commit 8b0005b

Browse files
Fix all the unittest of pruning. (#346)
1 parent 39ee8eb commit 8b0005b

7 files changed

+29
-114
lines changed

paddleslim/prune/group_param.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ def collect_convs(params, graph, visited={}):
5858
walker = conv2d_walker(
5959
conv_op, pruned_params=pruned_params, visited=visited)
6060
walker.prune(param, pruned_axis=0, pruned_idx=[0])
61-
groups.append(pruned_params)
61+
if len(pruned_params) > 0:
62+
groups.append(pruned_params)
6263
visited = set()
6364
uniq_groups = []
6465
for group in groups:

tests/test_auto_prune.py

Lines changed: 0 additions & 84 deletions
This file was deleted.

tests/test_fpgm_prune.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@ def test_prune(self):
6363
param_shape_backup=None)
6464

6565
shapes = {
66-
"conv1_weights": (4L, 3L, 3L, 3L),
67-
"conv2_weights": (4L, 4L, 3L, 3L),
68-
"conv3_weights": (8L, 4L, 3L, 3L),
69-
"conv4_weights": (4L, 8L, 3L, 3L),
70-
"conv5_weights": (8L, 4L, 3L, 3L),
71-
"conv6_weights": (8L, 8L, 3L, 3L)
66+
"conv1_weights": (4, 3, 3, 3),
67+
"conv2_weights": (4, 4, 3, 3),
68+
"conv3_weights": (8, 4, 3, 3),
69+
"conv4_weights": (4, 8, 3, 3),
70+
"conv5_weights": (8, 4, 3, 3),
71+
"conv6_weights": (8, 8, 3, 3)
7272
}
7373

7474
for param in main_program.global_block().all_parameters():

tests/test_optimal_threshold.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,12 @@ def test_prune(self):
6464
param_shape_backup=None)
6565

6666
shapes = {
67-
"conv1_weights": (4L, 3L, 3L, 3L),
68-
"conv2_weights": (4L, 4L, 3L, 3L),
69-
"conv3_weights": (8L, 4L, 3L, 3L),
70-
"conv4_weights": (4L, 8L, 3L, 3L),
71-
"conv5_weights": (8L, 4L, 3L, 3L),
72-
"conv6_weights": (8L, 8L, 3L, 3L)
67+
"conv1_weights": (4, 3, 3, 3),
68+
"conv2_weights": (4, 4, 3, 3),
69+
"conv3_weights": (8, 4, 3, 3),
70+
"conv4_weights": (4, 8, 3, 3),
71+
"conv5_weights": (8, 4, 3, 3),
72+
"conv6_weights": (8, 8, 3, 3)
7373
}
7474

7575
for param in main_program.global_block().all_parameters():

tests/test_prune.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,12 @@ def test_prune(self):
6262
param_shape_backup=None)
6363

6464
shapes = {
65-
"conv1_weights": (4L, 3L, 3L, 3L),
66-
"conv2_weights": (4L, 4L, 3L, 3L),
67-
"conv3_weights": (8L, 4L, 3L, 3L),
68-
"conv4_weights": (4L, 8L, 3L, 3L),
69-
"conv5_weights": (8L, 4L, 3L, 3L),
70-
"conv6_weights": (8L, 8L, 3L, 3L)
65+
"conv1_weights": (4, 3, 3, 3),
66+
"conv2_weights": (4, 4, 3, 3),
67+
"conv3_weights": (8, 4, 3, 3),
68+
"conv4_weights": (4, 8, 3, 3),
69+
"conv5_weights": (8, 4, 3, 3),
70+
"conv6_weights": (8, 8, 3, 3)
7171
}
7272

7373
for param in main_program.global_block().all_parameters():

tests/test_sensitivity.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import numpy
1818
import paddle
1919
import paddle.fluid as fluid
20-
from paddleslim.analysis import sensitivity
20+
from paddleslim.prune import sensitivity
2121
from layers import conv_bn_layer
2222

2323

@@ -47,22 +47,20 @@ def test_sensitivity(self):
4747
val_reader = paddle.fluid.io.batch(
4848
paddle.dataset.mnist.test(), batch_size=128)
4949

50-
def eval_func(program, scope):
50+
def eval_func(program):
5151
feeder = fluid.DataFeeder(
5252
feed_list=['image', 'label'], place=place, program=program)
5353
acc_set = []
5454
for data in val_reader():
5555
acc_np = exe.run(program=program,
56-
scope=scope,
5756
feed=feeder.feed(data),
5857
fetch_list=[acc_top1])
5958
acc_set.append(float(acc_np[0]))
6059
acc_val_mean = numpy.array(acc_set).mean()
6160
print("acc_val_mean: {}".format(acc_val_mean))
6261
return acc_val_mean
6362

64-
sensitivity(eval_program,
65-
fluid.global_scope(), place, ["conv4_weights"], eval_func,
63+
sensitivity(eval_program, place, ["conv4_weights"], eval_func,
6664
"./sensitivities_file")
6765

6866

tests/test_slim_prune.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@ def test_prune(self):
6363
param_shape_backup=None)
6464

6565
shapes = {
66-
"conv1_weights": (4L, 3L, 3L, 3L),
67-
"conv2_weights": (4L, 4L, 3L, 3L),
68-
"conv3_weights": (8L, 4L, 3L, 3L),
69-
"conv4_weights": (4L, 8L, 3L, 3L),
70-
"conv5_weights": (8L, 4L, 3L, 3L),
71-
"conv6_weights": (8L, 8L, 3L, 3L)
66+
"conv1_weights": (4, 3, 3, 3),
67+
"conv2_weights": (4, 4, 3, 3),
68+
"conv3_weights": (8, 4, 3, 3),
69+
"conv4_weights": (4, 8, 3, 3),
70+
"conv5_weights": (8, 4, 3, 3),
71+
"conv6_weights": (8, 8, 3, 3)
7272
}
7373

7474
for param in main_program.global_block().all_parameters():

0 commit comments

Comments
 (0)