Commit 168f058

fix ofa export bug (PaddlePaddle#1322)
1 parent 492b040 commit 168f058

1 file changed: 47 additions (+), 14 deletions (-)

examples/model_compression/ofa/export_model.py (+47 -14)
@@ -27,12 +27,35 @@
 import paddle.nn.functional as F
 
 from paddlenlp.transformers import BertModel, BertForSequenceClassification, BertTokenizer
+from paddlenlp.transformers import TinyBertModel, TinyBertForSequenceClassification, TinyBertTokenizer
+from paddlenlp.transformers import TinyBertForSequenceClassification, TinyBertTokenizer
+from paddlenlp.transformers import RobertaForSequenceClassification, RobertaTokenizer
 from paddlenlp.utils.log import logger
 from paddleslim.nas.ofa import OFA, utils
 from paddleslim.nas.ofa.convert_super import Convert, supernet
 from paddleslim.nas.ofa.layers import BaseBlock
 
-MODEL_CLASSES = {"bert": (BertForSequenceClassification, BertTokenizer), }
+MODEL_CLASSES = {
+    "bert": (BertForSequenceClassification, BertTokenizer),
+    "roberta": (RobertaForSequenceClassification, RobertaTokenizer),
+    "tinybert": (TinyBertForSequenceClassification, TinyBertTokenizer),
+}
+
+
+def tinybert_forward(self, input_ids, token_type_ids=None, attention_mask=None):
+    wtype = self.pooler.dense.fn.weight.dtype if hasattr(
+        self.pooler.dense, 'fn') else self.pooler.dense.weight.dtype
+    if attention_mask is None:
+        attention_mask = paddle.unsqueeze(
+            (input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2])
+    embedding_output = self.embeddings(input_ids, token_type_ids)
+    encoded_layer = self.encoder(embedding_output, attention_mask)
+    pooled_output = self.pooler(encoded_layer)
+
+    return encoded_layer, pooled_output
+
+
+TinyBertModel.forward = tinybert_forward
 
 
 def parse_args():
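
Note: the patched tinybert_forward builds the additive attention mask from pad_token_id inside the model, so the exported static graph only needs input_ids and token_type_ids as feed inputs. A minimal sketch of the mask formula used above, assuming pad_token_id == 0 and float32 weights (toy values, not part of the commit):

import paddle

# Toy batch with two trailing padding positions (pad_token_id assumed to be 0).
input_ids = paddle.to_tensor([[101, 2023, 2003, 102, 0, 0]])
pad_token_id = 0

# Same formula as tinybert_forward: padded positions get a large negative
# bias, real tokens get 0; the shape broadcasts to [batch, 1, 1, seq_len].
attention_mask = paddle.unsqueeze(
    (input_ids == pad_token_id).astype('float32') * -1e9, axis=[1, 2])
print(attention_mask.shape)  # [1, 1, 1, 6]
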
@@ -113,14 +136,15 @@ def do_train(args):
     config_path = os.path.join(args.model_name_or_path, 'model_config.json')
     cfg_dict = dict(json.loads(open(config_path).read()))
 
+    kept_layers_index = {}
     if args.depth_mult < 1.0:
-        depth = round(cfg_dict["init_args"][0]['num_hidden_layers'] * args.depth_mult)
-        cfg_dict["init_args"][0]['num_hidden_layers'] = depth
-        kept_layers_index = {}
-        for idx, i in enumerate(range(1, depth+1)):
+        depth = round(cfg_dict["init_args"][0]['num_hidden_layers'] *
+                      args.depth_mult)
+        cfg_dict["init_args"][0]['num_hidden_layers'] = depth
+        for idx, i in enumerate(range(1, depth + 1)):
             kept_layers_index[idx] = math.floor(i / args.depth_mult) - 1
 
-    os.rename(config_path, config_path+'_bak')
+    os.rename(config_path, config_path + '_bak')
     with open(config_path, "w", encoding="utf-8") as f:
         f.write(json.dumps(cfg_dict, ensure_ascii=False))
 
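Worked example of the depth mapping (illustrative values, not from the commit): with a 12-layer config and --depth_mult 0.75, the rounded depth is 9, and kept_layers_index maps each kept layer to the original layer whose weights it should inherit:

import math

num_hidden_layers, depth_mult = 12, 0.75       # assumed example values
depth = round(num_hidden_layers * depth_mult)  # 9
kept_layers_index = {}
for idx, i in enumerate(range(1, depth + 1)):
    kept_layers_index[idx] = math.floor(i / depth_mult) - 1

print(kept_layers_index)
# {0: 0, 1: 1, 2: 3, 3: 4, 4: 5, 5: 7, 6: 8, 7: 9, 8: 11}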

@@ -132,7 +156,7 @@ def do_train(args):
     origin_model = model_class.from_pretrained(
         args.model_name_or_path, num_classes=num_labels)
 
-    os.rename(config_path+'_bak', config_path)
+    os.rename(config_path + '_bak', config_path)
 
     sp_config = supernet(expand_ratio=[1.0, args.width_mult])
     model = Convert(sp_config).convert(model)

@@ -142,15 +166,24 @@ def do_train(args):
     sd = paddle.load(
         os.path.join(args.model_name_or_path, 'model_state.pdparams'))
 
-    for name, params in ofa_model.model.named_parameters():
-        if 'encoder' not in name:
-            params.set_value(sd[name])
-        else:
-            idx = int(name.strip().split('.')[3])
-            mapping_name = name.replace('.'+str(idx)+'.', '.'+str(kept_layers_index[idx])+'.')
-            params.set_value(sd[mapping_name])
+    if len(kept_layers_index) == 0:
+        ofa_model.model.set_state_dict(sd)
+    else:
+        for name, params in ofa_model.model.named_parameters():
+            if 'encoder' not in name:
+                params.set_value(sd[name])
+            else:
+                idx = int(name.strip().split('.')[3])
+                mapping_name = name.replace(
+                    '.' + str(idx) + '.',
+                    '.' + str(kept_layers_index[idx]) + '.')
+                params.set_value(sd[mapping_name])
 
     best_config = utils.dynabert_config(ofa_model, args.width_mult)
+    for name, sublayer in ofa_model.model.named_sublayers():
+        if isinstance(sublayer, paddle.nn.MultiHeadAttention):
+            sublayer.num_heads = int(args.width_mult * sublayer.num_heads)
+
     ofa_model.export(
         best_config,
         input_shapes=[[1, args.max_seq_length], [1, args.max_seq_length]],
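
This hunk is the actual bug fix: kept_layers_index is now initialized unconditionally, and when it is empty (depth_mult == 1.0) the checkpoint is loaded directly via set_state_dict instead of running the per-parameter remapping loop, which previously referenced a mapping that only existed when layers were dropped. When layers are dropped, encoder parameter names are remapped by layer index before loading; a rough sketch of that rename, using a hypothetical parameter name following the '<model>.encoder.layers.<idx>.' pattern assumed by name.split('.')[3]:

kept_layers_index = {0: 0, 1: 1, 2: 3}  # e.g. a 4-layer toy model at depth_mult 0.75
name = 'tinybert.encoder.layers.2.self_attn.q_proj.weight'  # hypothetical name

idx = int(name.strip().split('.')[3])  # 2
mapping_name = name.replace(
    '.' + str(idx) + '.', '.' + str(kept_layers_index[idx]) + '.')
print(mapping_name)  # tinybert.encoder.layers.3.self_attn.q_proj.weight

The added loop over named_sublayers() shrinks each MultiHeadAttention to match the width_mult-pruned export, e.g. 12 heads become int(0.75 * 12) = 9 heads at width_mult 0.75.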
