
Commit 15b3dad

tianxin and ZeyuChen authored
Add SimCSE (PaddlePaddle#854)
* add FewCLUE 9 datasets
* fix a bug for tnews
* Add CI for Ernie text matching
* Add CI for Ernie text matching
* Add CI for Ernie text matching
* fix encoding problem for windows
* update ernie_text_matching ci
* standard GPU id for CI
* standard GPU id for CI
* add simcse
* update train.py
* support multi-card training
* add README.md

Co-authored-by: Zeyu Chen <chenzeyu01@baidu.com>
1 parent 9d95922 commit 15b3dad

5 files changed: +705 −0 lines changed

examples/text_matching/simcse/README.md

+112
@@ -0,0 +1,112 @@
# Unsupervised Semantic Matching Model [SimCSE](https://arxiv.org/abs/2104.08821)

We implement the SimCSE model and evaluate its unsupervised matching performance on 4 widely used Chinese semantic matching datasets. SimCSE is well suited to matching and retrieval scenarios that lack labeled data but have plenty of unlabeled text.

## Evaluation

This project uses the training sets of 4 semantic matching datasets, LCQMC, BQ_Corpus, STS-B and ATEC, as unsupervised training data (only the text is used, not the labels), and evaluates on the dev set of each dataset. Following the SimCSE paper, the evaluation metric is the Spearman correlation coefficient; a higher Spearman correlation indicates a better model.

| Model | Infer_with_fc | LCQMC | BQ_Corpus | STS-B | ATEC |
| ------- | ------- | ------- | ----- | ------ | ----- |
| ERNIE-1.0 | yes | 52.33 | 43.75 | 66.66 | 29.78 |
| ERNIE-1.0 | no | 57.01 | 51.72 | 74.76 | 33.56 |

**Note**: Infer_with_fc indicates whether the forward pass goes through the final fc layer from training when computing text embeddings at inference time. As the table shows, skipping this fc layer at inference significantly improves unsupervised matching performance.
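
As a reference for the metric above, here is a minimal sketch of how a Spearman score can be computed from the model's predicted similarities and the dev-set labels. This is not the exact evaluation code in train.py (which is not shown in this excerpt); scipy is assumed to be available:

```python
import numpy as np
from scipy.stats import spearmanr

def eval_spearman(pred_sims, gold_labels):
    """Spearman rank correlation between predicted cosine similarities
    and human-annotated labels/scores for the same sentence pairs."""
    corr, _ = spearmanr(np.asarray(pred_sims), np.asarray(gold_labels))
    return corr

# e.g. eval_spearman([0.72, 0.90, 0.54], [1, 1, 0]) returns a value in [-1, 1];
# higher means the predicted ranking agrees better with the gold ranking.
print(eval_spearman([0.72, 0.90, 0.54], [1, 1, 0]))
```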

## Getting Started

### Code Structure

The main files in this project are:

```
simcse/
├── model.py   # SimCSE network definition
├── data.py    # Reading logic for unsupervised training data and test data
├── predict.py # Computes text-pair similarity with a trained unsupervised model
└── train.py   # SimCSE training and evaluation logic
```

### Model Training

We use the public Chinese text matching dataset LCQMC as the example dataset and build unsupervised training data from its text only. Run the following command to start training and evaluate the Spearman correlation on the LCQMC dev set:

```shell
unset CUDA_VISIBLE_DEVICES
python -u -m paddle.distributed.launch --gpus '0' \
    train.py \
    --device gpu \
    --save_dir ./checkpoints/ \
    --batch_size 64 \
    --learning_rate 5E-5 \
    --epochs 1 \
    --save_steps 100 \
    --eval_steps 100 \
    --max_seq_length 64 \
    --infer_with_fc_pooler \
    --dropout 0.3 \
    --train_set_file "./senteval_cn/LCQMC/train.txt" \
    --test_set_file "./senteval_cn/LCQMC/dev.tsv"
```

Configurable parameters:

* `infer_with_fc_pooler`: optional; whether the forward pass goes through the final fc layer from training when computing text embeddings during evaluation. Enabling it is recommended for the best results.
* `scale`: optional; factor by which the cosine similarity is scaled before computing the cross-entropy loss. Default: 20.
* `dropout`: optional; dropout rate used in the SimCSE forward pass. Default: 0.1.
* `save_dir`: optional; directory where checkpoints are saved. Default: the `checkpoints` folder under the current directory.
* `max_seq_length`: optional; maximum sequence length used by the ERNIE model, at most 512. Lower this value if you run out of GPU memory. Default: 128.
* `batch_size`: optional; batch size. Adjust it to your GPU memory and lower it if you run out of memory. Default: 32.
* `learning_rate`: optional; maximum learning rate for fine-tuning. Default: 5e-5.
* `weight_decay`: optional; weight of the regularization term, used to reduce overfitting. Default: 0.0.
* `epochs`: number of training epochs. Default: 3.
* `warmup_proption`: optional; proportion of training steps used for learning-rate warmup. With 0.1, the learning rate grows from 0 to `learning_rate` over the first 10% of training steps and then slowly decays. Default: 0.0.
* `init_from_ckpt`: optional; path to model parameters for warm-starting training. Default: None.
* `seed`: optional; random seed. Default: 1000.
* `device`: device to train on, either cpu or gpu. When using gpu, the `gpus` argument selects the GPU card.

Training and evaluation run automatically, and checkpoints are saved to the specified `save_dir` during training, e.g.:

```text
checkpoints/
├── model_100
│   ├── model_state.pdparams
│   ├── tokenizer_config.json
│   └── vocab.txt
└── ...
```

**NOTE:**
* To resume training from a checkpoint, set `init_from_ckpt`, e.g. `init_from_ckpt=checkpoints/model_100/model_state.pdparams`.
### Prediction with the Dynamic-Graph Model

We use the LCQMC test set as the prediction data. Sample pairs look like this:

```text
谁有狂三这张高清的 这张高清图,谁有
英雄联盟什么英雄最好 英雄联盟最好英雄是什么
这是什么意思,被蹭网吗 我也是醉了,这是什么意思
现在有什么动画片好看呢? 现在有什么好看的动画片吗?
请问晶达电子厂现在的工资待遇怎么样要求有哪些 三星电子厂工资待遇怎么样啊
```

Run the following command to start prediction:

```shell
python -u -m paddle.distributed.launch --gpus "0" \
    predict.py \
    --device gpu \
    --params_path "./checkpoints/model_4400/model_state.pdparams" \
    --batch_size 64 \
    --max_seq_length 64 \
    --input_file 'test.tsv'
```

The predicted similarity scores are:

```text
0.7201147675514221
0.9010907411575317
0.5393891334533691
0.9698929786682129
0.6056119203567505
```
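
predict.py itself is not included in this commit excerpt; the snippet below is only a rough sketch of how a trained checkpoint could be used to score a single pair (the `ernie-1.0` model name, tokenizer usage, and `output_emb_size=0` are assumptions for illustration):

```python
import paddle
from paddlenlp.transformers import ErnieModel, ErnieTokenizer

from model import SimCSE  # the SimCSE class added by this commit

tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
model = SimCSE(ErnieModel.from_pretrained("ernie-1.0"), output_emb_size=0)
model.set_state_dict(paddle.load("./checkpoints/model_4400/model_state.pdparams"))
model.eval()

def encode(text):
    # Tokenize one sentence and wrap it in a batch dimension of 1.
    inputs = tokenizer(text=text, max_seq_len=64)
    ids = paddle.to_tensor([inputs["input_ids"]])
    types = paddle.to_tensor([inputs["token_type_ids"]])
    return ids, types

q_ids, q_types = encode("英雄联盟什么英雄最好")
t_ids, t_types = encode("英雄联盟最好英雄是什么")
with paddle.no_grad():
    # with_pooler=False skips the training-time fc, per the evaluation note above.
    sim = model.cosine_sim(q_ids, t_ids,
                           query_token_type_ids=q_types,
                           title_token_type_ids=t_types,
                           with_pooler=False)
print(float(sim))  # one cosine similarity per input pair
```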
## Reference
[1] Gao, Tianyu, Xingcheng Yao, and Danqi Chen. “SimCSE: Simple Contrastive Learning of Sentence Embeddings.” ArXiv:2104.08821 [Cs], April 18, 2021. http://arxiv.org/abs/2104.08821.

examples/text_matching/simcse/data.py

+102
@@ -0,0 +1,102 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle

from paddlenlp.utils.log import logger


def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):
    if trans_fn:
        dataset = dataset.map(trans_fn)

    shuffle = True if mode == 'train' else False
    if mode == 'train':
        batch_sampler = paddle.io.DistributedBatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)

    return paddle.io.DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True)


def convert_example(example, tokenizer, max_seq_length=512, do_evalute=False):
    """
    Builds model inputs from a sequence.

    A BERT sequence has the following format:

    - single sequence: ``[CLS] X [SEP]``

    Args:
        example(obj:`dict`): Dict mapping field names ('text_a', 'text_b' and
            optionally 'label') to raw text.
        tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from
            :class:`~paddlenlp.transformers.PretrainedTokenizer` which contains most
            of the methods. Users should refer to the superclass for more
            information regarding methods.
        max_seq_length(obj:`int`): The maximum total input sequence length after
            tokenization. Sequences longer than this will be truncated,
            sequences shorter will be padded.
        do_evalute(obj:`bool`, defaults to `False`): Whether the example contains a label.

    Returns:
        result(obj:`list`): input_ids and token_type_ids for each text field,
            followed by the label if present.
    """

    result = []

    for key, text in example.items():
        if 'label' in key:
            # Evaluation data: keep the label as-is.
            result += [example['label']]
        else:
            # Training data: tokenize the text field.
            encoded_inputs = tokenizer(text=text, max_seq_len=max_seq_length)
            input_ids = encoded_inputs["input_ids"]
            token_type_ids = encoded_inputs["token_type_ids"]
            result += [input_ids, token_type_ids]

    return result


def read_simcse_text(data_path):
    """Reads unsupervised training data: each line yields the same sentence twice,
    so the two forward passes (with different dropout masks) form a positive pair."""
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = line.rstrip()
            yield {'text_a': data, 'text_b': data}


def read_text_pair(data_path, is_test=False):
    """Reads tab-separated text pairs, with a label column unless is_test is True."""
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = line.rstrip().split("\t")
            if not is_test:
                if len(data) != 3:
                    continue
                yield {'text_a': data[0], 'text_b': data[1], 'label': data[2]}
            else:
                if len(data) != 2:
                    continue
                yield {'text_a': data[0], 'text_b': data[1]}
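
train.py is not part of this excerpt; as a rough illustration of how these helpers are typically wired together in PaddleNLP examples (the `ernie-1.0` tokenizer name is an assumption, and the exact arguments in train.py may differ), one might build the unsupervised dataloader like this:

```python
from functools import partial

from paddlenlp.data import Pad, Tuple
from paddlenlp.datasets import load_dataset
from paddlenlp.transformers import ErnieTokenizer

from data import read_simcse_text, convert_example, create_dataloader

tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")

# Build an unsupervised dataset: every line becomes a (text_a, text_b) pair of
# identical sentences; dropout provides the "augmentation" at training time.
train_ds = load_dataset(
    read_simcse_text, data_path="./senteval_cn/LCQMC/train.txt", lazy=False)

trans_func = partial(convert_example, tokenizer=tokenizer, max_seq_length=64)

# One Pad per field produced by convert_example:
# query input_ids, query token_type_ids, title input_ids, title token_type_ids.
batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id),
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
    Pad(axis=0, pad_val=tokenizer.pad_token_id),
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
): [data for data in fn(samples)]

train_data_loader = create_dataloader(
    train_ds, mode="train", batch_size=64,
    batchify_fn=batchify_fn, trans_fn=trans_func)
```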

examples/text_matching/simcse/model.py

+140
@@ -0,0 +1,140 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import sys

import numpy as np

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class SimCSE(nn.Layer):
    def __init__(self,
                 pretrained_model,
                 dropout=None,
                 margin=0.0,
                 scale=20,
                 output_emb_size=None):

        super().__init__()

        self.ptm = pretrained_model
        self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)

        # If output_emb_size is greater than 0, add a linear layer to reduce the
        # embedding size. We recommend output_emb_size = 256 as a trade-off between
        # recall performance and efficiency.
        self.output_emb_size = output_emb_size
        if output_emb_size is not None and output_emb_size > 0:
            weight_attr = paddle.ParamAttr(
                initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))
            self.emb_reduce_linear = paddle.nn.Linear(
                768, output_emb_size, weight_attr=weight_attr)

        self.margin = margin
        # Cosine similarity is scaled to ease convergence.
        self.scale = scale

    def get_pooled_embedding(self,
                             input_ids,
                             token_type_ids=None,
                             position_ids=None,
                             attention_mask=None,
                             with_pooler=True):

        # Note: cls_embedding is the pooled embedding with a tanh activation.
        sequence_output, cls_embedding = self.ptm(input_ids, token_type_ids,
                                                  position_ids, attention_mask)

        # with_pooler=False skips the training-time fc (pooler) and uses the raw
        # [CLS] hidden state, which gives better unsupervised matching results.
        if not with_pooler:
            cls_embedding = sequence_output[:, 0, :]

        if self.output_emb_size is not None and self.output_emb_size > 0:
            cls_embedding = self.emb_reduce_linear(cls_embedding)

        cls_embedding = self.dropout(cls_embedding)
        cls_embedding = F.normalize(cls_embedding, p=2, axis=-1)

        return cls_embedding

    def cosine_sim(self,
                   query_input_ids,
                   title_input_ids,
                   query_token_type_ids=None,
                   query_position_ids=None,
                   query_attention_mask=None,
                   title_token_type_ids=None,
                   title_position_ids=None,
                   title_attention_mask=None,
                   with_pooler=True):

        query_cls_embedding = self.get_pooled_embedding(
            query_input_ids,
            query_token_type_ids,
            query_position_ids,
            query_attention_mask,
            with_pooler=with_pooler)

        title_cls_embedding = self.get_pooled_embedding(
            title_input_ids,
            title_token_type_ids,
            title_position_ids,
            title_attention_mask,
            with_pooler=with_pooler)

        # Embeddings are L2-normalized, so the dot product equals the cosine similarity.
        cosine_sim = paddle.sum(query_cls_embedding * title_cls_embedding,
                                axis=-1)
        return cosine_sim

    def forward(self,
                query_input_ids,
                title_input_ids,
                query_token_type_ids=None,
                query_position_ids=None,
                query_attention_mask=None,
                title_token_type_ids=None,
                title_position_ids=None,
                title_attention_mask=None):

        query_cls_embedding = self.get_pooled_embedding(
            query_input_ids, query_token_type_ids, query_position_ids,
            query_attention_mask)

        title_cls_embedding = self.get_pooled_embedding(
            title_input_ids, title_token_type_ids, title_position_ids,
            title_attention_mask)

        # In-batch negatives: every other title in the batch serves as a negative
        # for a given query, so the similarity matrix is [N, N].
        cosine_sim = paddle.matmul(
            query_cls_embedding, title_cls_embedding, transpose_y=True)

        # Subtract the margin from the positive (diagonal) similarities.
        margin_diag = paddle.full(
            shape=[query_cls_embedding.shape[0]],
            fill_value=self.margin,
            dtype=paddle.get_default_dtype())

        cosine_sim = cosine_sim - paddle.diag(margin_diag)

        # Scale cosine similarities to ease training convergence.
        cosine_sim *= self.scale

        # The correct "class" for the i-th query is the i-th title.
        labels = paddle.arange(0, query_cls_embedding.shape[0], dtype='int64')
        labels = paddle.reshape(labels, shape=[-1, 1])

        loss = F.cross_entropy(input=cosine_sim, label=labels)

        return loss
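
To make the in-batch-negatives objective concrete, here is a rough sketch of a single unsupervised training step. The real training loop lives in train.py, which is not shown in this excerpt; the `ernie-1.0` model name and tokenizer arguments are assumptions:

```python
import paddle
from paddlenlp.transformers import ErnieModel, ErnieTokenizer

from model import SimCSE

tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
model = SimCSE(ErnieModel.from_pretrained("ernie-1.0"), dropout=0.3, output_emb_size=0)
optimizer = paddle.optimizer.AdamW(learning_rate=5e-5, parameters=model.parameters())

# A toy "unsupervised" batch: each sentence is paired with itself, exactly as
# read_simcse_text in data.py produces it.
sentences = ["英雄联盟什么英雄最好", "现在有什么动画片好看呢?"]
examples = [tokenizer(text=s, max_seq_len=64, pad_to_max_seq_len=True) for s in sentences]
input_ids = paddle.to_tensor([e["input_ids"] for e in examples])
token_type_ids = paddle.to_tensor([e["token_type_ids"] for e in examples])

# The same token ids are fed twice; the two forward passes see different dropout
# masks, so the diagonal pairs become positives and all other in-batch pairs
# become negatives.
loss = model(query_input_ids=input_ids,
             title_input_ids=input_ids,
             query_token_type_ids=token_type_ids,
             title_token_type_ids=token_type_ids)
loss.backward()
optimizer.step()
optimizer.clear_grad()
```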
