Skip to content

Commit f4c1f48

Browse files
authored
[cherry-pick][unstructured_prune]Resume training (#958) (#960)
1 parent 853e1d0 commit f4c1f48

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

paddleslim/dygraph/prune/unstructured_pruner.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ def __init__(self,
304304
self.cur_iteration = configs.get('resume_iteration')
305305

306306
assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin."
307+
self._need_prune_once = False
307308
self._prepare_training_hyper_parameters()
308309

309310
def _prepare_training_hyper_parameters(self):
@@ -330,6 +331,7 @@ def _prepare_training_hyper_parameters(self):
330331

331332
# pop out used ratios to resume training
332333
for i in range(self.cur_iteration):
334+
self._need_prune_once = True
333335
if len(self.
334336
ratios_stack) > 0 and i % self.ratio_increment_period == 0:
335337
self.ratio = self.ratios_stack.pop()
@@ -344,7 +346,8 @@ def step(self):
344346

345347
# Update the threshold and masks only when a new ratio has been set.
346348
# This condition check would save training time dramatically since we only update the threshold by the trigger of self.ratio_increment_period.
347-
if ori_ratio != self.ratio:
349+
if ori_ratio != self.ratio or self._need_prune_once:
348350
self.update_threshold()
349351
self._update_masks()
352+
self._need_prune_once = False
350353
self.cur_iteration += 1

paddleslim/prune/unstructured_pruner.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def __init__(self,
350350
self.cur_iteration = configs.get('resume_iteration')
351351

352352
assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin."
353+
self._need_prune_once = False
353354
self._prepare_training_hyper_parameters()
354355

355356
def _prepare_training_hyper_parameters(self):
@@ -376,6 +377,7 @@ def _prepare_training_hyper_parameters(self):
376377

377378
# pop out used ratios to resume training
378379
for i in range(self.cur_iteration):
380+
self._need_prune_once = True
379381
if len(self.
380382
ratios_stack) > 0 and i % self.ratio_increment_period == 0:
381383
self.ratio = self.ratios_stack.pop()
@@ -393,7 +395,8 @@ def step(self):
393395

394396
# Update the threshold and masks only when a new ratio has been set.
395397
# This condition check would save training time dramatically since we only update the threshold by the trigger of self.ratio_increment_period.
396-
if ori_ratio != self.ratio:
398+
if ori_ratio != self.ratio or self._need_prune_once:
397399
self.update_threshold()
398400
self._update_masks()
401+
self._need_prune_once = False
399402
self.cur_iteration += 1

0 commit comments

Comments (0)