@@ -27,35 +27,53 @@
 import paddle.nn.functional as F
 
 from paddlenlp.transformers import BertModel, BertForSequenceClassification, BertTokenizer
-from paddlenlp.transformers import TinyBertModel, TinyBertForSequenceClassification, TinyBertTokenizer
-from paddlenlp.transformers import TinyBertForSequenceClassification, TinyBertTokenizer
-from paddlenlp.transformers import RobertaForSequenceClassification, RobertaTokenizer
 from paddlenlp.utils.log import logger
 from paddleslim.nas.ofa import OFA, utils
 from paddleslim.nas.ofa.convert_super import Convert, supernet
 from paddleslim.nas.ofa.layers import BaseBlock
 
-MODEL_CLASSES = {
-    "bert": (BertForSequenceClassification, BertTokenizer),
-    "roberta": (RobertaForSequenceClassification, RobertaTokenizer),
-    "tinybert": (TinyBertForSequenceClassification, TinyBertTokenizer),
-}
+MODEL_CLASSES = {"bert": (BertForSequenceClassification, BertTokenizer), }
 
 
-def tinybert_forward(self, input_ids, token_type_ids=None, attention_mask=None):
+def bert_forward(self,
+                 input_ids,
+                 token_type_ids=None,
+                 position_ids=None,
+                 attention_mask=None,
+                 output_hidden_states=False):
     wtype = self.pooler.dense.fn.weight.dtype if hasattr(
         self.pooler.dense, 'fn') else self.pooler.dense.weight.dtype
     if attention_mask is None:
         attention_mask = paddle.unsqueeze(
             (input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2])
-    embedding_output = self.embeddings(input_ids, token_type_ids)
-    encoded_layer = self.encoder(embedding_output, attention_mask)
-    pooled_output = self.pooler(encoded_layer)
-
-    return encoded_layer, pooled_output
+    else:
+        if attention_mask.ndim == 2:
+            # attention_mask [batch_size, sequence_length] -> [batch_size, 1, 1, sequence_length]
+            attention_mask = attention_mask.unsqueeze(axis=[1, 2])
+
+    embedding_output = self.embeddings(
+        input_ids=input_ids,
+        position_ids=position_ids,
+        token_type_ids=token_type_ids)
+    if output_hidden_states:
+        output = embedding_output
+        encoder_outputs = []
+        for mod in self.encoder.layers:
+            output = mod(output, src_mask=attention_mask)
+            encoder_outputs.append(output)
+        if self.encoder.norm is not None:
+            encoder_outputs[-1] = self.encoder.norm(encoder_outputs[-1])
+        pooled_output = self.pooler(encoder_outputs[-1])
+    else:
+        sequence_output = self.encoder(embedding_output, attention_mask)
+        pooled_output = self.pooler(sequence_output)
+    if output_hidden_states:
+        return encoder_outputs, pooled_output
+    else:
+        return sequence_output, pooled_output
 
 
-TinyBertModel.forward = tinybert_forward
+BertModel.forward = bert_forward
 
 
 def parse_args():
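
Note for readers of this diff (not part of the patch): a minimal sketch of how the monkey-patched forward behaves once `BertModel.forward = bert_forward` from the hunk above is in effect. The checkpoint name "bert-base-uncased" and the sample sentence are placeholders; the `output_hidden_states=True` branch collects one tensor per encoder layer, presumably so later OFA distillation code in this script can compare intermediate layers between teacher and student.

import paddle
from paddlenlp.transformers import BertModel, BertTokenizer

# Assumes the patch above is active, i.e. BertModel.forward = bert_forward.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

encoded = tokenizer("A short example sentence.")
input_ids = paddle.to_tensor([encoded["input_ids"]])
token_type_ids = paddle.to_tensor([encoded["token_type_ids"]])

# Default path: same contract as the original forward,
# returning the last encoder output plus the pooled output.
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)

# New path: iterates self.encoder.layers and keeps every intermediate output.
encoder_outputs, pooled_output = model(
    input_ids, token_type_ids=token_type_ids, output_hidden_states=True)
print(len(encoder_outputs))  # one tensor per transformer layer (12 for bert-base)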