Commit 54bef3d

Support ofa export for bert (PaddlePaddle#1326)
* fix ofa export bug
* support bert export
1 parent 168f058

1 file changed (+33, -15 lines)

examples/model_compression/ofa/export_model.py

@@ -27,35 +27,53 @@
 import paddle.nn.functional as F
 
 from paddlenlp.transformers import BertModel, BertForSequenceClassification, BertTokenizer
-from paddlenlp.transformers import TinyBertModel, TinyBertForSequenceClassification, TinyBertTokenizer
-from paddlenlp.transformers import TinyBertForSequenceClassification, TinyBertTokenizer
-from paddlenlp.transformers import RobertaForSequenceClassification, RobertaTokenizer
 from paddlenlp.utils.log import logger
 from paddleslim.nas.ofa import OFA, utils
 from paddleslim.nas.ofa.convert_super import Convert, supernet
 from paddleslim.nas.ofa.layers import BaseBlock
 
-MODEL_CLASSES = {
-    "bert": (BertForSequenceClassification, BertTokenizer),
-    "roberta": (RobertaForSequenceClassification, RobertaTokenizer),
-    "tinybert": (TinyBertForSequenceClassification, TinyBertTokenizer),
-}
+MODEL_CLASSES = {"bert": (BertForSequenceClassification, BertTokenizer), }
 
 
-def tinybert_forward(self, input_ids, token_type_ids=None, attention_mask=None):
+def bert_forward(self,
+                 input_ids,
+                 token_type_ids=None,
+                 position_ids=None,
+                 attention_mask=None,
+                 output_hidden_states=False):
     wtype = self.pooler.dense.fn.weight.dtype if hasattr(
         self.pooler.dense, 'fn') else self.pooler.dense.weight.dtype
     if attention_mask is None:
         attention_mask = paddle.unsqueeze(
             (input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2])
-    embedding_output = self.embeddings(input_ids, token_type_ids)
-    encoded_layer = self.encoder(embedding_output, attention_mask)
-    pooled_output = self.pooler(encoded_layer)
-
-    return encoded_layer, pooled_output
+    else:
+        if attention_mask.ndim == 2:
+            # attention_mask [batch_size, sequence_length] -> [batch_size, 1, 1, sequence_length]
+            attention_mask = attention_mask.unsqueeze(axis=[1, 2])
+
+    embedding_output = self.embeddings(
+        input_ids=input_ids,
+        position_ids=position_ids,
+        token_type_ids=token_type_ids)
+    if output_hidden_states:
+        output = embedding_output
+        encoder_outputs = []
+        for mod in self.encoder.layers:
+            output = mod(output, src_mask=attention_mask)
+            encoder_outputs.append(output)
+        if self.encoder.norm is not None:
+            encoder_outputs[-1] = self.encoder.norm(encoder_outputs[-1])
+        pooled_output = self.pooler(encoder_outputs[-1])
+    else:
+        sequence_output = self.encoder(embedding_output, attention_mask)
+        pooled_output = self.pooler(sequence_output)
+    if output_hidden_states:
+        return encoder_outputs, pooled_output
+    else:
+        return sequence_output, pooled_output
 
 
-TinyBertModel.forward = tinybert_forward
+BertModel.forward = bert_forward
 
 
 def parse_args():
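
Note on the patch: the re-bound forward returns plain tensors, plus an optional Python list holding every encoder layer's output when output_hidden_states=True, which keeps the backbone straightforward to trace for a static-graph export. Below is a minimal sketch of how the patched forward and an export step could be exercised; it assumes BertModel.forward has already been replaced as in the diff, and the model name "bert-base-uncased", the dummy batch, and the save path "./export/inference" are illustrative placeholders rather than values taken from the commit.

import paddle
from paddle.static import InputSpec
from paddlenlp.transformers import BertForSequenceClassification

# Assumes BertModel.forward has been re-bound to bert_forward as in the diff above.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")  # placeholder weights
model.eval()

# Dummy batch; shapes and vocabulary size are illustrative only.
input_ids = paddle.randint(low=0, high=30000, shape=[1, 16], dtype="int64")
token_type_ids = paddle.zeros([1, 16], dtype="int64")

with paddle.no_grad():
    # The patched backbone can now expose every encoder layer's output.
    hidden_states, pooled_output = model.bert(
        input_ids, token_type_ids=token_type_ids, output_hidden_states=True)
    # hidden_states is a list with one [batch, seq_len, hidden] tensor per layer.

# Dygraph-to-static conversion and save, the usual paddle.jit export pattern
# for a PaddleNLP classification model.
static_model = paddle.jit.to_static(
    model,
    input_spec=[
        InputSpec(shape=[None, None], dtype="int64", name="input_ids"),
        InputSpec(shape=[None, None], dtype="int64", name="token_type_ids"),
    ])
paddle.jit.save(static_model, "./export/inference")  # placeholder output path

In practice the weights would come from a fine-tuned, OFA-pruned checkpoint rather than the stock pretrained model; the pretrained name above is only a stand-in to keep the sketch self-contained.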
