Skip to content

Commit e3587f5

Browse files
authored
fix condition of layer_forward in ofa (#777)
1 parent acbac87 commit e3587f5

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

paddleslim/nas/ofa/get_sub_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def check_search_space(graph):
225225
depthwise_conv.append(inp._var.name)
226226

227227
if len(same_search_space) == 0:
228-
return None, None
228+
return None, []
229229

230230
same_search_space = sorted([sorted(x) for x in same_search_space])
231231
final_search_space = []

paddleslim/nas/ofa/ofa.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ def layers_forward(self, block, *inputs, **kwargs):
108108
if getattr(self, 'current_config', None) != None:
109109
### if block is fixed, donnot join key into candidate
110110
### concrete config as parameter in kwargs
111-
if block.fixed == False and (
112-
self._skip_layers != None and
113-
self._key2name[block.key] not in self._skip_layers) and \
111+
if block.fixed == False and (self._skip_layers == None or
112+
(self._skip_layers != None and
113+
self._key2name[block.key] not in self._skip_layers)) and \
114114
(block.fn.weight.name not in self._depthwise_conv):
115115
assert self._key2name[
116116
block.
@@ -180,6 +180,7 @@ def __init__(self,
180180
self._build_ss = False
181181
self._broadcast = False
182182
self._skip_layers = None
183+
self._depthwise_conv = []
183184

184185
### if elastic_order is none, use default order
185186
if self.elastic_order is not None:

0 commit comments

Comments
 (0)