@@ -112,7 +112,7 @@ def rep_func(layer: nn.Layer, pattern: str):
112
112
if not isinstance (layer_name_pattern , list ):
113
113
layer_name_pattern = [layer_name_pattern ]
114
114
115
- handle_res_dict = {}
115
+ hit_layer_pattern_list = []
116
116
for pattern in layer_name_pattern :
117
117
# parse pattern to find target layer and its parent
118
118
layer_list = parse_pattern_str (pattern = pattern , parent_layer = self )
@@ -133,8 +133,8 @@ def rep_func(layer: nn.Layer, pattern: str):
133
133
else :
134
134
setattr (sub_layer_parent , sub_layer_name , new_sub_layer )
135
135
136
- handle_res_dict [ pattern ] = new_sub_layer
137
- return handle_res_dict
136
+ hit_layer_pattern_list . append ( pattern )
137
+ return hit_layer_pattern_list
138
138
139
139
def stop_after (self , stop_layer_name : str ) -> bool :
140
140
"""stop forward and backward after 'stop_layer_name'.
@@ -192,15 +192,15 @@ def __call__(self, layer, pattern):
192
192
193
193
handle_func = Handler (self .res_dict )
194
194
195
- res_dict = self .upgrade_sublayer (
195
+ hit_layer_pattern_list = self .upgrade_sublayer (
196
196
return_patterns , handle_func = handle_func )
197
197
198
198
if hasattr (self , "hook_remove_helper" ):
199
199
self .hook_remove_helper .remove ()
200
200
self .hook_remove_helper = self .register_forward_post_hook (
201
201
self ._return_dict_hook )
202
202
203
- return res_dict
203
+ return hit_layer_pattern_list
204
204
205
205
206
206
def set_identity (parent_layer : nn .Layer ,
0 commit comments