import paddle.nn.functional as F

from paddlenlp.transformers import BertModel, BertForSequenceClassification, BertTokenizer
+from paddlenlp.transformers import TinyBertModel, TinyBertForSequenceClassification, TinyBertTokenizer
+from paddlenlp.transformers import RobertaForSequenceClassification, RobertaTokenizer
from paddlenlp.utils.log import logger
from paddleslim.nas.ofa import OFA, utils
from paddleslim.nas.ofa.convert_super import Convert, supernet
from paddleslim.nas.ofa.layers import BaseBlock

-MODEL_CLASSES = {"bert": (BertForSequenceClassification, BertTokenizer), }
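+# Supported model families, mapped to their (model class, tokenizer class).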
+MODEL_CLASSES = {
+    "bert": (BertForSequenceClassification, BertTokenizer),
+    "roberta": (RobertaForSequenceClassification, RobertaTokenizer),
+    "tinybert": (TinyBertForSequenceClassification, TinyBertTokenizer),
+}
+
+
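+# After Convert(...) turns the model into a supernet, layers such as the
+# pooler's Linear may be wrapped in a block that keeps the original layer
+# under `.fn`, so the dtype lookup below handles both cases.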
+def tinybert_forward(self, input_ids, token_type_ids=None, attention_mask=None):
+    wtype = self.pooler.dense.fn.weight.dtype if hasattr(
+        self.pooler.dense, 'fn') else self.pooler.dense.weight.dtype
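+    # Additive attention mask: padding positions get -1e9 so that softmax
+    # assigns them near-zero weight.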
+    if attention_mask is None:
+        attention_mask = paddle.unsqueeze(
+            (input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2])
+    embedding_output = self.embeddings(input_ids, token_type_ids)
+    encoded_layer = self.encoder(embedding_output, attention_mask)
+    pooled_output = self.pooler(encoded_layer)
+
+    return encoded_layer, pooled_output
+
+
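+# Patch TinyBertModel so its forward returns a (sequence_output, pooled_output)
+# pair like BertModel does, which the OFA training code below relies on.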
+TinyBertModel.forward = tinybert_forward


def parse_args():
@@ -113,14 +136,15 @@ def do_train(args):
    config_path = os.path.join(args.model_name_or_path, 'model_config.json')
    cfg_dict = dict(json.loads(open(config_path).read()))

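+    # Maps each layer index in the depth-pruned model to the original layer
+    # it inherits; stays empty when no depth pruning is requested.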
+    kept_layers_index = {}
    if args.depth_mult < 1.0:
-        depth = round(cfg_dict["init_args"][0]['num_hidden_layers'] * args.depth_mult)
-        cfg_dict["init_args"][0]['num_hidden_layers'] = depth
-        kept_layers_index = {}
-        for idx, i in enumerate(range(1, depth + 1)):
+        depth = round(cfg_dict["init_args"][0]['num_hidden_layers'] *
+                      args.depth_mult)
+        cfg_dict["init_args"][0]['num_hidden_layers'] = depth
+        for idx, i in enumerate(range(1, depth + 1)):
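+            # e.g. depth_mult=0.75 keeps round(12 * 0.75) = 9 of 12 layers,
+            # mapped to original layers [0, 1, 3, 4, 5, 7, 8, 9, 11]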
            kept_layers_index[idx] = math.floor(i / args.depth_mult) - 1

-        os.rename(config_path, config_path + '_bak')
+    os.rename(config_path, config_path + '_bak')
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(cfg_dict, ensure_ascii=False))

@@ -132,7 +156,7 @@ def do_train(args):
    origin_model = model_class.from_pretrained(
        args.model_name_or_path, num_classes=num_labels)

-    os.rename(config_path + '_bak', config_path)
+    os.rename(config_path + '_bak', config_path)

    sp_config = supernet(expand_ratio=[1.0, args.width_mult])
    model = Convert(sp_config).convert(model)
@@ -142,15 +166,24 @@ def do_train(args):
    sd = paddle.load(
        os.path.join(args.model_name_or_path, 'model_state.pdparams'))

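+    # Without depth pruning the checkpoint loads as-is; otherwise each kept
+    # encoder layer pulls its weights from the original layer index recorded
+    # in kept_layers_index.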
-    for name, params in ofa_model.model.named_parameters():
-        if 'encoder' not in name:
-            params.set_value(sd[name])
-        else:
-            idx = int(name.strip().split('.')[3])
-            mapping_name = name.replace('.' + str(idx) + '.', '.' + str(kept_layers_index[idx]) + '.')
-            params.set_value(sd[mapping_name])
+    if len(kept_layers_index) == 0:
+        ofa_model.model.set_state_dict(sd)
+    else:
+        for name, params in ofa_model.model.named_parameters():
+            if 'encoder' not in name:
+                params.set_value(sd[name])
+            else:
+                idx = int(name.strip().split('.')[3])
+                mapping_name = name.replace(
+                    '.' + str(idx) + '.',
+                    '.' + str(kept_layers_index[idx]) + '.')
+                params.set_value(sd[mapping_name])

    best_config = utils.dynabert_config(ofa_model, args.width_mult)
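+    # Shrink each attention layer's head count to match the exported width.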
+    for name, sublayer in ofa_model.model.named_sublayers():
+        if isinstance(sublayer, paddle.nn.MultiHeadAttention):
+            sublayer.num_heads = int(args.width_mult * sublayer.num_heads)
+
    ofa_model.export(
        best_config,
        input_shapes=[[1, args.max_seq_length], [1, args.max_seq_length]],