Skip to content

Commit 853e1d0

Browse files
authored
Cherry pick uniform sparsity (#956)
* [cherry-pick][Unstructured_prune] add local_sparsity (#916) * [cherry-pick][unstructured_pruner] add local_sparsity args in demo (#920) * [Unstructured_prune] add local_sparsity demo
1 parent 7c96fd4 commit 853e1d0

File tree

8 files changed

+81
-26
lines changed

8 files changed

+81
-26
lines changed

demo/dygraph/unstructured_pruning/train.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
add_arg('initial_ratio', float, 0.15, "The initial pruning ratio used at the start of pruning stage. Default: 0.15")
5151
add_arg('pruning_strategy', str, 'base', "Which training strategy to use in pruning, we only support base and gmp for now. Default: base")
5252
add_arg('prune_params_type', str, None, "Which kind of params should be pruned, we only support None (all but norms) and conv1x1_only for now. Default: None")
53+
add_arg('local_sparsity', bool, False, "Whether to prune all the parameter matrix at the same ratio or not. Default: False")
5354
# yapf: enable
5455

5556

@@ -96,12 +97,14 @@ def create_unstructured_pruner(model, args, configs=None):
9697
mode=args.pruning_mode,
9798
ratio=args.ratio,
9899
threshold=args.threshold,
99-
prune_params_type=args.prune_params_type)
100+
prune_params_type=args.prune_params_type,
101+
local_sparsity=args.local_sparsity)
100102
else:
101103
return GMPUnstructuredPruner(
102104
model,
103105
ratio=args.ratio,
104106
prune_params_type=args.prune_params_type,
107+
local_sparsity=args.local_sparsity,
105108
configs=configs)
106109

107110

@@ -270,7 +273,6 @@ def train(epoch):
270273
train_reader_cost = 0.0
271274
train_run_cost = 0.0
272275
total_samples = 0
273-
274276
reader_start = time.time()
275277

276278
for i in range(args.last_epoch + 1, args.num_epochs):

demo/unstructured_prune/train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
add_arg('pruning_steps', int, 120, "How many times you want to increase your ratio during training. Default: 120")
5050
add_arg('initial_ratio', float, 0.15, "The initial pruning ratio used at the start of pruning stage. Default: 0.15")
5151
add_arg('prune_params_type', str, None, "Which kind of params should be pruned, we only support None (all but norms) and conv1x1_only for now. Default: None")
52+
add_arg('local_sparsity', bool, False, "Whether to prune all the parameter matrix at the same ratio or not. Default: False")
5253
# yapf: enable
5354

5455
model_list = models.__all__
@@ -96,13 +97,15 @@ def create_unstructured_pruner(train_program, args, place, configs):
9697
ratio=args.ratio,
9798
threshold=args.threshold,
9899
prune_params_type=args.prune_params_type,
99-
place=place)
100+
place=place,
101+
local_sparsity=args.local_sparsity)
100102
else:
101103
return GMPUnstructuredPruner(
102104
train_program,
103105
ratio=args.ratio,
104106
prune_params_type=args.prune_params_type,
105107
place=place,
108+
local_sparsity=args.local_sparsity,
106109
configs=configs)
107110

108111

docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
UnstructuredPruner
55
----------
66

7-
.. py:class:: paddleslim.UnstructuredPruner(model, mode, threshold=0.01, ratio=0.55, prune_params_type=None, skip_params_func=None)
7+
.. py:class:: paddleslim.UnstructuredPruner(model, mode, threshold=0.01, ratio=0.55, prune_params_type=None, skip_params_func=None, local_sparsity=False)
88
99
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dygraph/prune/unstructured_pruner.py>`_
1010

@@ -44,6 +44,8 @@ UnstructuredPruner
4444
4545
..
4646
47+
- **local_sparsity(bool)** - 剪裁比例(ratio)应用的范围:local_sparsity 开启时意味着每个参与剪裁的参数矩阵稀疏度均为 'ratio', 关闭时表示只保证模型整体稀疏度达到'ratio',但是每个参数矩阵的稀疏度可能存在差异。
48+
4749
**返回:** 一个UnstructuredPruner类的实例。
4850

4951
**示例代码:**
@@ -203,7 +205,7 @@ GMPUnstructuredPruner
203205

204206
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dygraph/prune/unstructured_pruner.py>`_
205207

206-
.. py:class:: paddleslim.GMPUnstructuredPruner(model, ratio=0.55, prune_params_type=None, skip_params_func=None, configs=None)
208+
.. py:class:: paddleslim.GMPUnstructuredPruner(model, ratio=0.55, prune_params_type=None, skip_params_func=None, local_sparsity=False, configs=None)
207209
208210
该类是UnstructuredPruner的一个子类,通过覆盖step()方法,优化了训练策略,使稀疏化训练更易恢复到稠密模型精度。其他方法均继承自父类。
209211

@@ -213,6 +215,7 @@ GMPUnstructuredPruner
213215
- **ratio(float)** - 稀疏化比例期望,只有在 mode=='ratio' 时才会生效。
214216
- **prune_params_type(str)** - 用以指定哪些类型的参数参与稀疏。目前只支持None和"conv1x1_only"两个选项,后者表示只稀疏化1x1卷积。而前者表示稀疏化除了归一化层的参数。
215217
- **skip_params_func(function)** - 一个指向function的指针,该function定义了哪些参数不应该被剪裁,默认(None)时代表所有归一化层参数不参与剪裁。
218+
- **local_sparsity(bool)** - 剪裁比例(ratio)应用的范围:local_sparsity 开启时意味着每个参与剪裁的参数矩阵稀疏度均为 'ratio', 关闭时表示只保证模型整体稀疏度达到'ratio',但是每个参数矩阵的稀疏度可能存在差异。
216219
- **configs(Dict)** - 传入额外的训练超参用以指导GMP训练过程。各参数介绍如下:
217220

218221
.. code-block:: python

docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
UnstrucuturedPruner
55
----------
66

7-
.. py:class:: paddleslim.prune.UnstructuredPruner(program, mode, ratio=0.55, threshold=1e-2, scope=None, place=None, prune_params_type, skip_params_func=None)
7+
.. py:class:: paddleslim.prune.UnstructuredPruner(program, mode, ratio=0.55, threshold=1e-2, scope=None, place=None, prune_params_type, skip_params_func=None, local_sparsity=False)
88
99
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/unstructured_pruner.py>`_
1010

@@ -43,6 +43,7 @@ UnstrucuturedPruner
4343
4444
..
4545
46+
- **local_sparsity(bool)** - 剪裁比例(ratio)应用的范围:local_sparsity 开启时意味着每个参与剪裁的参数矩阵稀疏度均为 'ratio', 关闭时表示只保证模型整体稀疏度达到'ratio',但是每个参数矩阵的稀疏度可能存在差异。
4647

4748
**返回:** 一个UnstructuredPruner类的实例
4849

@@ -280,7 +281,7 @@ UnstrucuturedPruner
280281
GMPUnstrucuturedPruner
281282
----------
282283

283-
.. py:class:: paddleslim.prune.GMPUnstructuredPruner(program, ratio=0.55, scope=None, place=None, prune_params_type=None, skip_params_func=None, configs=None)
284+
.. py:class:: paddleslim.prune.GMPUnstructuredPruner(program, ratio=0.55, scope=None, place=None, prune_params_type=None, skip_params_func=None, local_sparsity=False, configs=None)
284285
285286
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/unstructured_pruner.py>`_
286287

@@ -294,6 +295,7 @@ GMPUnstrucuturedPruner
294295
- **place(CPUPlace|CUDAPlace)** - 模型执行的设备,类型为CPUPlace或者CUDAPlace,默认(None)时代表CPUPlace。
295296
- **prune_params_type(String)** - 用以指定哪些类型的参数参与稀疏。目前只支持None和"conv1x1_only"两个选项,后者表示只稀疏化1x1卷积。而前者表示稀疏化除了归一化的参数。
296297
- **skip_params_func(function)** - 一个指向function的指针,该function定义了哪些参数不应该被剪裁,默认(None)时代表所有归一化层参数不参与剪裁。
298+
- **local_sparsity(bool)** - 剪裁比例(ratio)应用的范围:local_sparsity 开启时意味着每个参与剪裁的参数矩阵稀疏度均为 'ratio', 关闭时表示只保证模型整体稀疏度达到'ratio',但是每个参数矩阵的稀疏度可能存在差异。
297299
- **configs(Dict)** - 传入额外的训练超参用以指导GMP训练过程。具体描述如下:
298300

299301
.. code-block:: python

paddleslim/dygraph/prune/unstructured_pruner.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class UnstructuredPruner():
2424
- ratio(float): The parameters whose absolute values are in the smaller part decided by the ratio will be zeros. Default: 0.55
2525
- prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None
2626
- skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params.
27+
- local_sparsity(bool): Whether to enable local sparsity. Local sparsity means all the weight matrices have the same sparsity. And the global sparsity only ensures the whole model's sparsity is equal to the passed-in 'ratio'. Default: False
2728
"""
2829

2930
def __init__(self,
@@ -32,14 +33,19 @@ def __init__(self,
3233
threshold=0.01,
3334
ratio=0.55,
3435
prune_params_type=None,
35-
skip_params_func=None):
36+
skip_params_func=None,
37+
local_sparsity=False):
3638
assert mode in ('ratio', 'threshold'
3739
), "mode must be selected from 'ratio' and 'threshold'"
3840
assert prune_params_type is None or prune_params_type == 'conv1x1_only', "prune_params_type only supports None or conv1x1_only for now."
41+
if local_sparsity:
42+
assert mode == 'ratio', "We don't support local_sparsity==True and mode=='threshold' at the same time, please change the inputs accordingly."
3943
self.model = model
4044
self.mode = mode
4145
self.threshold = threshold
4246
self.ratio = ratio
47+
self.local_sparsity = local_sparsity
48+
self.thresholds = {}
4349

4450
# Prority: passed-in skip_params_func > prune_params_type (conv1x1_only) > built-in _get_skip_params
4551
if skip_params_func is not None:
@@ -91,11 +97,19 @@ def update_threshold(self):
9197
continue
9298
t_param = param.value().get_tensor()
9399
v_param = np.array(t_param)
94-
params_flatten.append(v_param.flatten())
95-
params_flatten = np.concatenate(params_flatten, axis=0)
96-
total_length = params_flatten.size
97-
self.threshold = np.sort(np.abs(params_flatten))[max(
98-
0, round(self.ratio * total_length) - 1)].item()
100+
if self.local_sparsity:
101+
flatten_v_param = v_param.flatten()
102+
cur_length = flatten_v_param.size
103+
cur_threshold = np.sort(np.abs(flatten_v_param))[max(
104+
0, round(self.ratio * cur_length) - 1)].item()
105+
self.thresholds[param.name] = cur_threshold
106+
else:
107+
params_flatten.append(v_param.flatten())
108+
if not self.local_sparsity:
109+
params_flatten = np.concatenate(params_flatten, axis=0)
110+
total_length = params_flatten.size
111+
self.threshold = np.sort(np.abs(params_flatten))[max(
112+
0, round(self.ratio * total_length) - 1)].item()
99113

100114
def _update_masks(self):
101115
for name, sub_layer in self.model.named_sublayers():
@@ -105,7 +119,11 @@ def _update_masks(self):
105119
if param.name in self.skip_params:
106120
continue
107121
mask = self.masks.get(param.name)
108-
bool_tmp = (paddle.abs(param) >= self.threshold)
122+
if self.local_sparsity:
123+
bool_tmp = (
124+
paddle.abs(param) >= self.thresholds[param.name])
125+
else:
126+
bool_tmp = (paddle.abs(param) >= self.threshold)
109127
paddle.assign(bool_tmp, output=mask)
110128

111129
def summarize_weights(self, model, ratio=0.1):
@@ -248,6 +266,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
248266
- ratio(float): The parameters whose absolute values are in the smaller part decided by the ratio will be zeros. Default: 0.55
249267
- prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None
250268
- skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params.
269+
- local_sparsity(bool): Whether to enable local sparsity. Local sparsity means all the weight matrices have the same sparsity. And the global sparsity only ensures the whole model's sparsity is equal to the passed-in 'ratio'. Default: False
251270
- configs(Dict): The dictionary contains all the configs for GMP pruner. Default: None
252271
253272
.. code-block:: python
@@ -268,11 +287,13 @@ def __init__(self,
268287
ratio=0.55,
269288
prune_params_type=None,
270289
skip_params_func=None,
290+
local_sparsity=False,
271291
configs=None):
272292

273293
assert configs is not None, "Configs must be passed in for GMP pruner."
274294
super(GMPUnstructuredPruner, self).__init__(
275-
model, 'ratio', 0.0, ratio, prune_params_type, skip_params_func)
295+
model, 'ratio', 0.0, ratio, prune_params_type, skip_params_func,
296+
local_sparsity)
276297
self.stable_iterations = configs.get('stable_iterations')
277298
self.pruning_iterations = configs.get('pruning_iterations')
278299
self.tunning_iterations = configs.get('tunning_iterations')

paddleslim/prune/unstructured_pruner.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class UnstructuredPruner():
2020
- place(CPUPlace | CUDAPlace): The device place used to execute model. None means CPUPlace. Default: None.
2121
- prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None
2222
- skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params. Default: None
23+
- local_sparsity(bool): Whether to enable local sparsity. Local sparsity means all the weight matrices have the same sparsity. And the global sparsity only ensures the whole model's sparsity is equal to the passed-in 'ratio'. Default: False
2324
"""
2425

2526
def __init__(self,
@@ -30,15 +31,19 @@ def __init__(self,
3031
scope=None,
3132
place=None,
3233
prune_params_type=None,
33-
skip_params_func=None):
34+
skip_params_func=None,
35+
local_sparsity=False):
3436
self.mode = mode
3537
self.ratio = ratio
3638
self.threshold = threshold
39+
self.local_sparsity = local_sparsity
40+
self.thresholds = {}
3741
assert self.mode in [
3842
'ratio', 'threshold'
3943
], "mode must be selected from 'ratio' and 'threshold'"
4044
assert prune_params_type is None or prune_params_type == 'conv1x1_only', "prune_params_type only supports None or conv1x1_only for now."
41-
45+
if self.local_sparsity:
46+
assert self.mode == 'ratio', "We don't support local_sparsity==True and mode=='threshold' at the same time, please change the inputs accordingly."
4247
self.scope = paddle.static.global_scope() if scope == None else scope
4348
self.place = paddle.static.cpu_places()[0] if place is None else place
4449

@@ -156,14 +161,20 @@ def update_threshold(self):
156161
continue
157162
t_param = self.scope.find_var(param).get_tensor()
158163
v_param = np.array(t_param)
159-
params_flatten.append(v_param.flatten())
160-
params_flatten = np.concatenate(params_flatten, axis=0)
161-
self.threshold = self._partition_sort(params_flatten)
164+
if self.local_sparsity:
165+
cur_threshold = self._partition_sort(v_param.flatten())
166+
self.thresholds[param] = cur_threshold
167+
else:
168+
params_flatten.append(v_param.flatten())
169+
if not self.local_sparsity:
170+
params_flatten = np.concatenate(params_flatten, axis=0)
171+
self.threshold = self._partition_sort(params_flatten)
162172

163173
def _partition_sort(self, params):
164174
total_len = len(params)
165175
params_zeros = params[params == 0]
166176
params_nonzeros = params[params != 0]
177+
if len(params_nonzeros) == 0: return 0
167178
new_ratio = max((self.ratio * total_len - len(params_zeros)),
168179
0) / len(params_nonzeros)
169180
return np.sort(np.abs(params_nonzeros))[max(
@@ -177,7 +188,10 @@ def _update_masks(self):
177188
t_param = self.scope.find_var(param).get_tensor()
178189
t_mask = self.scope.find_var(mask_name).get_tensor()
179190
v_param = np.array(t_param)
180-
v_param[np.abs(v_param) < self.threshold] = 0
191+
if self.local_sparsity:
192+
v_param[np.abs(v_param) < self.thresholds[param]] = 0
193+
else:
194+
v_param[np.abs(v_param) < self.threshold] = 0
181195
v_mask = (v_param != 0).astype(v_param.dtype)
182196
t_mask.set(v_mask, self.place)
183197

@@ -240,6 +254,7 @@ def _get_skip_params(self, program):
240254
if 'norm' in op.type() and 'grad' not in op.type():
241255
for input in op.all_inputs():
242256
skip_params.add(input.name())
257+
print(skip_params)
243258
return skip_params
244259

245260
def _get_skip_params_conv1x1(self, program):
@@ -296,6 +311,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
296311
- place(CPUPlace | CUDAPlace): The device place used to execute model. None means CPUPlace. Default: None.
297312
- prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None
298313
- skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params. Default: None
314+
- local_sparsity(bool): Whether to enable local sparsity. Local sparsity means all the weight matrices have the same sparsity. And the global sparsity only ensures the whole model's sparsity is equal to the passed-in 'ratio'. Default: False
299315
- configs(Dict): The dictionary contains all the configs for GMP pruner. Default: None. The detailed description is as below:
300316
301317
.. code-block:: python
@@ -317,12 +333,13 @@ def __init__(self,
317333
place=None,
318334
prune_params_type=None,
319335
skip_params_func=None,
336+
local_sparsity=False,
320337
configs=None):
321338
assert configs is not None, "Please pass in a valid config dictionary."
322339

323340
super(GMPUnstructuredPruner, self).__init__(
324341
program, 'ratio', ratio, 0.0, scope, place, prune_params_type,
325-
skip_params_func)
342+
skip_params_func, local_sparsity)
326343
self.stable_iterations = configs.get('stable_iterations')
327344
self.pruning_iterations = configs.get('pruning_iterations')
328345
self.tunning_iterations = configs.get('tunning_iterations')

tests/dygraph/test_unstructured_prune.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@ def __init__(self, *args, **kwargs):
1616
def _gen_model(self):
1717
self.net = mobilenet_v1(num_classes=10, pretrained=False)
1818
self.net_conv1x1 = mobilenet_v1(num_classes=10, pretrained=False)
19-
self.pruner = UnstructuredPruner(self.net, mode='ratio', ratio=0.55)
19+
self.pruner = UnstructuredPruner(
20+
self.net, mode='ratio', ratio=0.55, local_sparsity=True)
2021
self.pruner_conv1x1 = UnstructuredPruner(
2122
self.net_conv1x1,
2223
mode='ratio',
2324
ratio=0.55,
24-
prune_params_type='conv1x1_only')
25+
prune_params_type='conv1x1_only',
26+
local_sparsity=False)
2527

2628
def test_prune(self):
2729
ori_sparsity = UnstructuredPruner.total_sparse(self.net)

tests/test_unstructured_pruner.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,18 @@ def _gen_model(self):
4242
exe.run(self.startup_program, scope=self.scope)
4343

4444
self.pruner = UnstructuredPruner(
45-
self.main_program, 'ratio', scope=self.scope, place=place)
45+
self.main_program,
46+
'ratio',
47+
scope=self.scope,
48+
place=place,
49+
local_sparsity=True)
4650
self.pruner_conv1x1 = UnstructuredPruner(
4751
self.main_program,
4852
'ratio',
4953
scope=self.scope,
5054
place=place,
51-
prune_params_type='conv1x1_only')
55+
prune_params_type='conv1x1_only',
56+
local_sparsity=False)
5257

5358
def test_unstructured_prune(self):
5459
for param in self.main_program.global_block().all_parameters():

0 commit comments

Comments
 (0)