Commit 045e4e2

Add embedding finetune demo (#1204)
* Add embedding seq-cls finetune demo and update api
* Update docs of pad_sequence and trunc_sequence
1 parent 5832b1a commit 045e4e2
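The second bullet refers to the pad_sequence and trunc_sequence helpers from paddlehub.utils.utils, which the demo below uses to bring every encoded sample to a fixed length. A minimal sketch of that pattern, using the same call signatures as the demo code (trunc_sequence(ids, max_seq_len) and pad_sequence(ids, max_seq_len, pad_token_id)); the example ids and pad id are made up, and the padding/truncation side is assumed:

from paddlehub.utils.utils import pad_sequence, trunc_sequence

max_seq_len = 8
pad_token_id = 0      # assumed pad id; the demo looks the real id up from the tokenizer's vocab
ids = [13, 7, 42, 5]  # token ids produced by a tokenizer

if len(ids) > max_seq_len:
    ids = trunc_sequence(ids, max_seq_len)              # drop ids beyond max_seq_len (side assumed)
else:
    ids = pad_sequence(ids, max_seq_len, pad_token_id)  # fill with pad_token_id up to max_seq_len (side assumed)
print(ids)  # expected: a list of exactly max_seq_len ids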

File tree: 7 files changed (+484, -52 lines)
Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import List

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

import paddlenlp as nlp
from paddlenlp.embeddings import TokenEmbedding
from paddlenlp.data import JiebaTokenizer

from paddlehub.utils.log import logger
from paddlehub.utils.utils import pad_sequence, trunc_sequence


class BoWModel(nn.Layer):
    """
    This class implements the Bag of Words Classification Network model to classify texts.
    At a high level, the model embeds the tokens with a word embedding, then encodes these
    representations with a `BoWEncoder`. Lastly, it takes the output of the encoder to create
    a final representation, which is passed through some feed-forward layers to output
    logits (`output_layer`).

    Args:
        num_classes (obj:`int`, optional, defaults to 2): The number of classification labels.
        embedder (obj:`TokenEmbedding`): The word embedding layer used to embed the input token ids.
        tokenizer (obj:`JiebaTokenizer`, optional): The tokenizer used to segment and encode raw texts in `predict`.
        hidden_size (obj:`int`, optional, defaults to 128): The first fully-connected layer hidden size.
        fc_hidden_size (obj:`int`, optional, defaults to 96): The second fully-connected layer hidden size.
        load_checkpoint (obj:`str`, optional): The path of a checkpoint to initialize the parameters from.
        label_map (obj:`dict`, optional): The mapping from label ids to label names used in `predict`.
    """

    def __init__(self,
                 num_classes: int = 2,
                 embedder: TokenEmbedding = None,
                 tokenizer: JiebaTokenizer = None,
                 hidden_size: int = 128,
                 fc_hidden_size: int = 96,
                 load_checkpoint: str = None,
                 label_map: dict = None):
        super().__init__()
        self.embedder = embedder
        self.tokenizer = tokenizer
        self.label_map = label_map

        emb_dim = self.embedder.embedding_dim
        self.bow_encoder = nlp.seq2vec.BoWEncoder(emb_dim)
        self.fc1 = nn.Linear(self.bow_encoder.get_output_dim(), hidden_size)
        self.fc2 = nn.Linear(hidden_size, fc_hidden_size)
        self.dropout = nn.Dropout(p=0.3, axis=1)
        self.output_layer = nn.Linear(fc_hidden_size, num_classes)
        self.criterion = nn.loss.CrossEntropyLoss()
        self.metric = paddle.metric.Accuracy()

        if load_checkpoint is not None and os.path.isfile(load_checkpoint):
            state_dict = paddle.load(load_checkpoint)
            self.set_state_dict(state_dict)
            logger.info('Loaded parameters from %s' % os.path.abspath(load_checkpoint))

    def training_step(self, batch: List[paddle.Tensor], batch_idx: int):
        """
        One step for training, which should be called as forward computation.

        Args:
            batch(:obj:`List[paddle.Tensor]`): The one batch data, which contains the token ids and the labels.
            batch_idx(int): The index of batch.

        Returns:
            results(:obj:`Dict`): The model outputs, such as loss and metrics.
        """
        _, avg_loss, metric = self(ids=batch[0], labels=batch[1])
        self.metric.reset()
        return {'loss': avg_loss, 'metrics': metric}

    def validation_step(self, batch: List[paddle.Tensor], batch_idx: int):
        """
        One step for validation, which should be called as forward computation.

        Args:
            batch(:obj:`List[paddle.Tensor]`): The one batch data, which contains the token ids and the labels.
            batch_idx(int): The index of batch.

        Returns:
            results(:obj:`Dict`): The model outputs, such as metrics.
        """
        _, _, metric = self(ids=batch[0], labels=batch[1])
        self.metric.reset()
        return {'metrics': metric}

    def forward(self, ids: paddle.Tensor, labels: paddle.Tensor = None):
        # Shape: (batch_size, num_tokens, embedding_dim)
        embedded_text = self.embedder(ids)

        # Shape: (batch_size, embedding_dim)
        summed = self.bow_encoder(embedded_text)
        summed = self.dropout(summed)
        encoded_text = paddle.tanh(summed)

        # Shape: (batch_size, hidden_size)
        fc1_out = paddle.tanh(self.fc1(encoded_text))
        # Shape: (batch_size, fc_hidden_size)
        fc2_out = paddle.tanh(self.fc2(fc1_out))
        # Shape: (batch_size, num_classes)
        logits = self.output_layer(fc2_out)

        probs = F.softmax(logits, axis=1)
        if labels is not None:
            loss = self.criterion(logits, labels)
            correct = self.metric.compute(probs, labels)
            acc = self.metric.update(correct)
            return probs, loss, {'acc': acc}
        else:
            return probs

    def _batchify(self, data: List[List[str]], max_seq_len: int, batch_size: int):
        examples = []
        for item in data:
            ids = self.tokenizer.encode(sentence=item[0])

            if len(ids) > max_seq_len:
                ids = trunc_sequence(ids, max_seq_len)
            else:
                pad_token = self.tokenizer.vocab.pad_token
                pad_token_id = self.tokenizer.vocab.to_indices(pad_token)
                ids = pad_sequence(ids, max_seq_len, pad_token_id)

            examples.append(ids)

        # Separates data into some batches.
        one_batch = []
        for example in examples:
            one_batch.append(example)
            if len(one_batch) == batch_size:
                yield one_batch
                one_batch = []
        if one_batch:
            # The last batch whose size is less than the config batch_size setting.
            yield one_batch

    def predict(
            self,
            data: List[List[str]],
            max_seq_len: int = 128,
            batch_size: int = 1,
            use_gpu: bool = False,
            return_result: bool = True,
    ):
        paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')

        batches = self._batchify(data, max_seq_len, batch_size)
        results = []
        self.eval()
        for batch in batches:
            ids = paddle.to_tensor(batch)
            probs = self(ids)
            idx = paddle.argmax(probs, axis=1).numpy()

            if return_result:
                idx = idx.tolist()
                labels = [self.label_map[i] for i in idx]
                results.extend(labels)
            else:
                results.extend(probs.numpy())

        return results
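For orientation, here is a minimal, hedged sketch of this class's forward contract (not part of the commit): with labels the model returns (probs, loss, metrics); without labels it returns probabilities only. The scripts below import the class as `from model import BoWModel`; this sketch additionally assumes the pretrained embedding has already been downloaded.

import paddle
from paddlenlp.embeddings import TokenEmbedding
from model import BoWModel  # the class defined above

embedder = TokenEmbedding(embedding_name='w2v.baidu_encyclopedia.target.word-word.dim300')
model = BoWModel(embedder=embedder)

# A fake batch: 4 sequences of 128 token ids and one label per sequence.
ids = paddle.randint(low=0, high=1000, shape=[4, 128], dtype='int64')
labels = paddle.to_tensor([0, 1, 0, 1])

probs, loss, metrics = model(ids=ids, labels=labels)  # training/validation path
probs_only = model(ids=ids)                           # inference path: probabilities of shape (4, 2)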
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddlehub as hub
from paddlenlp.data import JiebaTokenizer
from model import BoWModel

import ast
import argparse


parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--hub_embedding_name", type=str, default='w2v_baidu_encyclopedia_target_word-word_dim300', help="The PaddleHub embedding module to load.")
parser.add_argument("--max_seq_len", type=int, default=128, help="Maximum number of words in a sequence.")
parser.add_argument("--batch_size", type=int, default=64, help="Number of examples per batch for prediction.")
parser.add_argument("--checkpoint", type=str, default='./checkpoint/best_model/model.pdparams', help="Path to the model checkpoint to load.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether to use GPU for prediction; input should be True or False.")

args = parser.parse_args()


if __name__ == '__main__':
    # Chinese review texts to be predicted (the demo model is fine-tuned on the ChnSentiCorp sentiment dataset).
    data = [
        ["这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般"],
        ["交通方便;环境很好;服务态度很好 房间较小"],
        ["还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。"],
        ["前台接待太差,酒店有A B楼之分,本人check-in后,前台未告诉B楼在何处,并且B楼无明显指示;房间太小,根本不像4星级设施,下次不会再选择入住此店啦"],
        ["19天硬盘就罢工了~~~算上运来的一周都没用上15天~~~可就是不能换了~~~唉~~~~你说这算什么事呀~~~"],
    ]

    label_map = {0: 'negative', 1: 'positive'}

    embedder = hub.Module(name=args.hub_embedding_name)
    tokenizer = embedder.get_tokenizer()
    model = BoWModel(
        embedder=embedder,
        tokenizer=tokenizer,
        load_checkpoint=args.checkpoint,
        label_map=label_map)

    results = model.predict(data, max_seq_len=args.max_seq_len, batch_size=args.batch_size, use_gpu=args.use_gpu, return_result=False)
    for idx, text in enumerate(data):
        print('Data: {} \t Label: {}'.format(text[0], results[idx]))
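Note that this script passes return_result=False, so each entry of results is the raw probability vector over the two classes rather than a label name (the "Label" column printed above is really a probability array in that case). A variant of the last two statements that prints human-readable labels instead, relying on the label_map handling in BoWModel.predict:

results = model.predict(data, max_seq_len=args.max_seq_len, batch_size=args.batch_size,
                        use_gpu=args.use_gpu, return_result=True)
for idx, text in enumerate(data):
    # With return_result=True, predict() maps each argmax index through label_map,
    # so results[idx] is 'negative' or 'positive'.
    print('Data: {} \t Label: {}'.format(text[0], results[idx]))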
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddlehub as hub
from paddlehub.datasets import ChnSentiCorp
from paddlenlp.data import JiebaTokenizer
from model import BoWModel

import ast
import argparse


parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--hub_embedding_name", type=str, default='w2v_baidu_encyclopedia_target_word-word_dim300', help="The PaddleHub embedding module to fine-tune.")
parser.add_argument("--num_epoch", type=int, default=10, help="Number of epochs for fine-tuning.")
parser.add_argument("--learning_rate", type=float, default=5e-4, help="Learning rate used for fine-tuning.")
parser.add_argument("--max_seq_len", type=int, default=128, help="Maximum number of words in a sequence.")
parser.add_argument("--batch_size", type=int, default=64, help="Number of examples per batch for training.")
parser.add_argument("--checkpoint_dir", type=str, default='./checkpoint', help="Directory to save model checkpoints.")
parser.add_argument("--save_interval", type=int, default=5, help="Save a checkpoint every n epochs.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether to use GPU for fine-tuning; input should be True or False.")

args = parser.parse_args()


if __name__ == '__main__':
    embedder = hub.Module(name=args.hub_embedding_name)
    tokenizer = embedder.get_tokenizer()

    train_dataset = ChnSentiCorp(tokenizer=tokenizer, max_seq_len=args.max_seq_len, mode='train')
    dev_dataset = ChnSentiCorp(tokenizer=tokenizer, max_seq_len=args.max_seq_len, mode='dev')
    test_dataset = ChnSentiCorp(tokenizer=tokenizer, max_seq_len=args.max_seq_len, mode='test')

    model = BoWModel(embedder=embedder)
    optimizer = paddle.optimizer.AdamW(
        learning_rate=args.learning_rate, parameters=model.parameters())
    trainer = hub.Trainer(model, optimizer, checkpoint_dir=args.checkpoint_dir, use_gpu=args.use_gpu)
    trainer.train(
        train_dataset,
        epochs=args.num_epoch,
        batch_size=args.batch_size,
        eval_dataset=dev_dataset,
        save_interval=args.save_interval,
    )
    trainer.evaluate(test_dataset, batch_size=args.batch_size)
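Here the model is built with only the embedder, since hub.Trainer feeds training_step/validation_step with already-encoded ChnSentiCorp batches; the tokenizer and label_map are only needed for predict. hub.Trainer writes checkpoints under --checkpoint_dir ('./checkpoint' by default), and the prediction script above defaults its --checkpoint to './checkpoint/best_model/model.pdparams' inside that directory, which BoWModel loads through its load_checkpoint argument. A minimal sketch of wiring the two together manually, assuming training has produced that file:

import paddlehub as hub
from model import BoWModel  # the demo model defined earlier

embedder = hub.Module(name='w2v_baidu_encyclopedia_target_word-word_dim300')
model = BoWModel(
    embedder=embedder,
    tokenizer=embedder.get_tokenizer(),
    label_map={0: 'negative', 1: 'positive'},
    load_checkpoint='./checkpoint/best_model/model.pdparams')  # written by the trainer above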

modules/text/embedding/w2v_baidu_encyclopedia_target_word-word_dim300/module.py

Lines changed: 6 additions & 25 deletions
@@ -15,6 +15,7 @@
 from typing import List
 from paddlenlp.embeddings import TokenEmbedding
 from paddlehub.module.module import moduleinfo, serving
+from paddlehub.module.nlp_module import EmbeddingModule


 @moduleinfo(
@@ -23,33 +24,13 @@
     summary="",
     author="paddlepaddle",
     author_email="",
-    type="nlp/semantic_model")
+    type="nlp/semantic_model",
+    meta=EmbeddingModule)
 class Embedding(TokenEmbedding):
     """
     Embedding model
     """
-    def __init__(self, *args, **kwargs):
-        super(Embedding, self).__init__(embedding_name="w2v.baidu_encyclopedia.target.word-word.dim300", *args, **kwargs)
-
-    @serving
-    def calc_similarity(self, data: List[List[str]]):
-        """
-        Calculate similarities of giving word pairs.
-        """
-        results = []
-        for word_pair in data:
-            if len(word_pair) != 2:
-                raise RuntimeError(
-                    f'The input must have two words, but got {len(word_pair)}. Please check your inputs.')
-            if not isinstance(word_pair[0], str) or not isinstance(word_pair[1], str):
-                raise RuntimeError(
-                    f'The types of text pair must be (str, str), but got'
-                    f' ({type(word_pair[0]).__name__}, {type(word_pair[1]).__name__}). Please check your inputs.')
+    embedding_name = 'w2v.baidu_encyclopedia.target.word-word.dim300'

-            for word in word_pair:
-                if self.get_idx_from_word(word) == \
-                        self.get_idx_from_word(self.vocab.unk_token):
-                    raise RuntimeError(
-                        f'Word "{word}" is not in vocab. Please check your inputs.')
-            results.append(str(self.cosine_sim(*word_pair)))
-        return results
+    def __init__(self, *args, **kwargs):
+        super(Embedding, self).__init__(embedding_name=self.embedding_name, *args, **kwargs)
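The word-pair similarity serving logic removed here is evidently handed off to the shared EmbeddingModule meta class instead of being re-implemented in every embedding module; the simplified class now only declares its embedding_name. The underlying TokenEmbedding API that the removed code relied on (cosine_sim, get_idx_from_word) remains available on the module itself. A minimal sketch, assuming the module is installed via PaddleHub and the example words are in its vocabulary:

import paddlehub as hub

embedding = hub.Module(name='w2v_baidu_encyclopedia_target_word-word_dim300')

# cosine_sim comes from paddlenlp's TokenEmbedding, which this module subclasses.
print(embedding.cosine_sim('苹果', '香蕉'))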
