Commit dbdcf01

[taskflow] Update SentenceExtraction-m3e (PaddlePaddle#6294)

* Add m3e-base model
* Add m3e-base model
* Add m3e-base model
* Update taskflow.md
* Update taskflow.md again
* Update piplines-semantic-search
* Update piplines-semantic-search

1 parent b34c74f commit dbdcf01

File tree

6 files changed: +375 −10 lines

docs/model_zoo/taskflow.md

Lines changed: 2 additions & 1 deletion

@@ -1746,6 +1746,7 @@ Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
 | `rocketqa-zh-dureader-para-encoder` | 12 | 768 | Chinese |
 | `rocketqa-zh-base-query-encoder` | 12 | 768 | Chinese |
 | `rocketqa-zh-base-para-encoder` | 12 | 768 | Chinese |
+| `moka-ai/m3e-base` | 12 | 768 | Chinese |
 | `rocketqa-zh-medium-query-encoder` | 6 | 768 | Chinese |
 | `rocketqa-zh-medium-para-encoder` | 6 | 768 | Chinese |
 | `rocketqa-zh-mini-query-encoder` | 6 | 384 | Chinese |
@@ -1763,7 +1764,7 @@ Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
 * `max_seq_len`: maximum length of the text sequence; defaults to 128.
 * `return_tensors`: the return type, either `pd` or `np`; defaults to `pd`.
 * `model`: the model used by the task; defaults to `PaddlePaddle/ernie_vil-2.0-base-zh`.
-
+* `pooling_mode`: how the sentence embedding is obtained, one of 'max_tokens', 'mean_tokens', 'mean_sqrt_len_tokens', 'cls_token'; defaults to 'cls_token' (for `moka-ai/m3e-base`).

 </div></details>
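For context, a minimal usage sketch of the option documented above. This is an illustration based on this diff, not part of the commit; the "features" key and the [2, 768] shape are inferred from the model table and the SentenceFeatureExtractionTask added below.

from paddlenlp import Taskflow

# Sketch: the new m3e-base sentence encoder with an explicit pooling mode.
# pooling_mode accepts the four documented values; 'cls_token' is the default.
encoder = Taskflow("feature_extraction", model="moka-ai/m3e-base", pooling_mode="mean_tokens")
outputs = encoder(["这是一条测试句子", "这是第二条测试句子"])
print(outputs["features"].shape)  # expected: [2, 768] per the table above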

paddlenlp/taskflow/taskflow.py

Lines changed: 14 additions & 1 deletion

@@ -36,7 +36,10 @@
 from .text2text_generation import ChatGLMTask
 from .text_classification import TextClassificationTask
 from .text_correction import CSCTask
-from .text_feature_extraction import TextFeatureExtractionTask
+from .text_feature_extraction import (
+    SentenceFeatureExtractionTask,
+    TextFeatureExtractionTask,
+)
 from .text_similarity import TextSimilarityTask
 from .text_summarization import TextSummarizationTask
 from .word_segmentation import SegJiebaTask, SegLACTask, SegWordTagTask
@@ -665,6 +668,16 @@
             "task_flag": "feature_extraction-tiny-random-ernievil2",
             "task_priority_path": "__internal_testing__/tiny-random-ernievil2",
         },
+        "moka-ai/m3e-base": {
+            "task_class": SentenceFeatureExtractionTask,
+            "task_flag": "feature_extraction-moka-ai/m3e-base",
+            "task_priority_path": "moka-ai/m3e-base",
+        },
+        "__internal_testing__/tiny-random-m3e": {
+            "task_class": SentenceFeatureExtractionTask,
+            "task_flag": "__internal_testing__/tiny-random-m3e",
+            "task_priority_path": "__internal_testing__/tiny-random-m3e",
+        },
     },
     "default": {"model": "PaddlePaddle/ernie_vil-2.0-base-zh"},
 },
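The registry entries above are what route the model name to the new task class. Roughly, Taskflow's lookup behaves like the following simplified sketch (hypothetical code, not the actual Taskflow internals):

from paddlenlp.taskflow.text_feature_extraction import SentenceFeatureExtractionTask

# Hypothetical distillation of the registry lookup implied by the diff above.
TASKS = {
    "feature_extraction": {
        "models": {"moka-ai/m3e-base": {"task_class": SentenceFeatureExtractionTask}},
        "default": {"model": "PaddlePaddle/ernie_vil-2.0-base-zh"},
    },
}

def resolve_task_class(task_name, model=None):
    entry = TASKS[task_name]
    model = model or entry["default"]["model"]
    return entry["models"][model]["task_class"]

assert resolve_task_class("feature_extraction", "moka-ai/m3e-base") is SentenceFeatureExtractionTask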

paddlenlp/taskflow/text_feature_extraction.py

Lines changed: 267 additions & 1 deletion

@@ -18,7 +18,7 @@
 import paddle

 from paddlenlp.data import DataCollatorWithPadding
-from paddlenlp.transformers import AutoTokenizer, ErnieDualEncoder
+from paddlenlp.transformers import AutoModel, AutoTokenizer, ErnieDualEncoder

 from ..utils.log import logger
 from .task import Task
@@ -315,3 +315,269 @@ def _convert_dygraph_to_static(self):
         static_model = paddle.jit.to_static(self._model.get_pooled_embedding, input_spec=self._input_spec)
         paddle.jit.save(static_model, self.inference_model_path)
         logger.info("The inference model save in the path:{}".format(self.inference_model_path))
+
+
+def text_length(text):
+    # {key: value} case
+    if isinstance(text, dict):
+        return len(next(iter(text.values())))
+    # Object has no len() method
+    elif not hasattr(text, "__len__"):
+        return 1
+    # Empty string or list of ints
+    elif len(text) == 0 or isinstance(text[0], int):
+        return len(text)
+    # Sum of length of individual strings
+    else:
+        return sum([len(t) for t in text])
+
+
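A quick sanity sketch (not part of the commit) of how text_length treats each input type; _batchify below relies on it to sort inputs by length:

# Illustrative checks for text_length's branches (hypothetical usage):
assert text_length({"input_ids": [1, 2, 3]}) == 3   # dict: length of the first value
assert text_length(42) == 1                         # no __len__: counted as 1
assert text_length([101, 102, 103]) == 3            # list of token ids: its length
assert text_length(["ab", "cde"]) == 5              # list of strings: summed lengths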
+class SentenceFeatureExtractionTask(Task):
+
+    resource_files_names = {
+        "model_state": "model_state.pdparams",
+        "config": "config.json",
+        "vocab_file": "vocab.txt",
+        "special_tokens_map": "special_tokens_map.json",
+        "tokenizer_config": "tokenizer_config.json",
+    }
+
+    def __init__(
+        self,
+        task: str = None,
+        model: str = None,
+        batch_size: int = 1,
+        max_seq_len: int = 512,
+        _static_mode: bool = True,
+        return_tensors: str = "pd",
+        pooling_mode: str = "cls_token",
+        **kwargs
+    ):
+        super().__init__(
+            task=task,
+            model=model,
+            pooling_mode=pooling_mode,
+            **kwargs,
+        )
+        self._seed = None
+        self.export_type = "text"
+        self._batch_size = batch_size
+        self.max_seq_len = max_seq_len
+        self.model = model
+        self._static_mode = _static_mode
+        self.return_tensors = return_tensors
+        self.pooling_mode = pooling_mode
+        self._check_predictor_type()
+        self._construct_tokenizer()
+        if self._static_mode:
+            self._get_inference_model()
+        else:
+            self._construct_model(model)
+
+    def _construct_model(self, model):
+        """
+        Construct the inference model for the predictor.
+        """
+        self._model = AutoModel.from_pretrained(self.model)
+        self._model.eval()
+
+    def _construct_tokenizer(self):
+        """
+        Construct the tokenizer for the predictor.
+        """
+        self._tokenizer = AutoTokenizer.from_pretrained(self.model)
+        self.pad_token_id = self._tokenizer.convert_tokens_to_ids(self._tokenizer.pad_token)
+        # Fix windows dtype bug
+        if self._static_mode:
+            self._collator = DataCollatorWithPadding(self._tokenizer, return_tensors="np")
+        else:
+            self._collator = DataCollatorWithPadding(self._tokenizer, return_tensors="pd")
+
+    def _construct_input_spec(self):
+        """
+        Construct the input spec for converting the dygraph model to a static model.
+        """
+        self._input_spec = [
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"),
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="token_type_ids"),
+        ]
+
+    def _batchify(self, data, batch_size):
+        """
+        Generate input batches.
+        """
+
+        def _parse_batch(batch_examples, max_seq_len=None):
+            if isinstance(batch_examples[0], str):
+                to_tokenize = [batch_examples]
+            else:
+                batch1, batch2 = [], []
+                for text_tuple in batch_examples:
+                    batch1.append(text_tuple[0])
+                    batch2.append(text_tuple[1])
+                to_tokenize = [batch1, batch2]
+            to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize]
+            if max_seq_len is None:
+                max_seq_len = self.max_seq_len
+            # Note: only the first column is tokenized here.
+            tokenized_inputs = self._tokenizer(
+                to_tokenize[0],
+                padding=True,
+                truncation="longest_first",
+                max_seq_len=max_seq_len,
+            )
+            return tokenized_inputs
+
+        # Separates the data into batches, longest inputs first.
+        one_batch = []
+        self.length_sorted_idx = np.argsort([-text_length(sen) for sen in data])
+        sentences_sorted = [data[idx] for idx in self.length_sorted_idx]
+        for example in range(len(sentences_sorted)):
+            one_batch.append(sentences_sorted[example])
+            if len(one_batch) == batch_size:
+                yield _parse_batch(one_batch)
+                one_batch = []
+        if one_batch:
+            yield _parse_batch(one_batch)
+
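Note the pattern here: inputs are sorted by descending length so each batch pads to a similar width, and self.length_sorted_idx is kept so _postprocess can restore the caller's order. A standalone sketch of the trick with made-up data (not part of the commit):

import numpy as np

data = ["short", "a considerably longer example sentence", "medium one"]
length_sorted_idx = np.argsort([-text_length(s) for s in data])
sentences_sorted = [data[i] for i in length_sorted_idx]
# ...encode sentences_sorted batch by batch (stand-in below)...
features_sorted = [f"feat({s})" for s in sentences_sorted]
# Undo the sort, exactly as _postprocess does with np.argsort:
features = [features_sorted[i] for i in np.argsort(length_sorted_idx)]
assert features == [f"feat({s})" for s in data]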
+    def _preprocess(self, inputs):
+        """
+        Transform the raw inputs to the model inputs, two steps involved:
+        1) Transform the raw text/image to token ids/pixel_values.
+        2) Generate the other model inputs from the raw text/image and token ids/pixel_values.
+        """
+        inputs = self._check_input_text(inputs)
+        batches = self._batchify(inputs, self._batch_size)
+        outputs = {"batches": batches, "inputs": inputs}
+        return outputs
+
+    def _run_model(self, inputs):
+        """
+        Run the task model from the outputs of the `_preprocess` function.
+        """
+        all_feats = []
+        if self._static_mode:
+            with static_mode_guard():
+                for batch_inputs in inputs["batches"]:
+                    batch_inputs = self._collator(batch_inputs)
+                    if self._predictor_type == "paddle-inference":
+                        if "input_ids" in batch_inputs:
+                            self.input_handles[0].copy_from_cpu(batch_inputs["input_ids"])
+                            self.input_handles[1].copy_from_cpu(batch_inputs["token_type_ids"])
+                            self.predictor.run()
+                            token_embeddings = self.output_handle[0].copy_to_cpu()
+                            if self.pooling_mode == "max_tokens":
+                                attention_mask = (batch_inputs["input_ids"] != self.pad_token_id).astype(
+                                    token_embeddings.dtype
+                                )
+                                input_mask_expanded = np.expand_dims(attention_mask, -1).repeat(
+                                    token_embeddings.shape[-1], axis=-1
+                                )
+                                token_embeddings[input_mask_expanded == 0] = -1e9
+                                max_over_time = np.max(token_embeddings, 1)
+                                all_feats.append(max_over_time)
+                            elif self.pooling_mode == "mean_tokens" or self.pooling_mode == "mean_sqrt_len_tokens":
+                                attention_mask = (batch_inputs["input_ids"] != self.pad_token_id).astype(
+                                    token_embeddings.dtype
+                                )
+                                input_mask_expanded = np.expand_dims(attention_mask, -1).repeat(
+                                    token_embeddings.shape[-1], axis=-1
+                                )
+                                sum_embeddings = np.sum(token_embeddings * input_mask_expanded, 1)
+                                sum_mask = input_mask_expanded.sum(1)
+                                sum_mask = np.clip(sum_mask, a_min=1e-9, a_max=np.max(sum_mask))
+                                if self.pooling_mode == "mean_tokens":
+                                    all_feats.append(sum_embeddings / sum_mask)
+                                elif self.pooling_mode == "mean_sqrt_len_tokens":
+                                    all_feats.append(sum_embeddings / np.sqrt(sum_mask))
+                            else:
+                                cls_token = token_embeddings[:, 0]
+                                all_feats.append(cls_token)
+                    else:
+                        # onnx mode
+                        if "input_ids" in batch_inputs:
+                            input_dict = {}
+                            input_dict["input_ids"] = batch_inputs["input_ids"]
+                            input_dict["token_type_ids"] = batch_inputs["token_type_ids"]
+                            token_embeddings = self.predictor.run(None, input_dict)[0]
+                            if self.pooling_mode == "max_tokens":
+                                attention_mask = (batch_inputs["input_ids"] != self.pad_token_id).astype(
+                                    token_embeddings.dtype
+                                )
+                                input_mask_expanded = np.expand_dims(attention_mask, -1).repeat(
+                                    token_embeddings.shape[-1], axis=-1
+                                )
+                                token_embeddings[input_mask_expanded == 0] = -1e9
+                                max_over_time = np.max(token_embeddings, 1)
+                                all_feats.append(max_over_time)
+                            elif self.pooling_mode == "mean_tokens" or self.pooling_mode == "mean_sqrt_len_tokens":
+                                attention_mask = (batch_inputs["input_ids"] != self.pad_token_id).astype(
+                                    token_embeddings.dtype
+                                )
+                                input_mask_expanded = np.expand_dims(attention_mask, -1).repeat(
+                                    token_embeddings.shape[-1], axis=-1
+                                )
+                                sum_embeddings = np.sum(token_embeddings * input_mask_expanded, 1)
+                                sum_mask = input_mask_expanded.sum(1)
+                                sum_mask = np.clip(sum_mask, a_min=1e-9, a_max=np.max(sum_mask))
+                                if self.pooling_mode == "mean_tokens":
+                                    all_feats.append(sum_embeddings / sum_mask)
+                                elif self.pooling_mode == "mean_sqrt_len_tokens":
+                                    all_feats.append(sum_embeddings / np.sqrt(sum_mask))
+                            else:
+                                cls_token = token_embeddings[:, 0]
+                                all_feats.append(cls_token)
+        else:
+            with dygraph_mode_guard():
+                for batch_inputs in inputs["batches"]:
+                    batch_inputs = self._collator(batch_inputs)
+                    token_embeddings = self._model(input_ids=batch_inputs["input_ids"])[0]
+                    if self.pooling_mode == "max_tokens":
+                        attention_mask = (batch_inputs["input_ids"] != self.pad_token_id).astype(
+                            self._model.pooler.dense.weight.dtype
+                        )
+                        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.shape)
+                        token_embeddings[input_mask_expanded == 0] = -1e9
+                        max_over_time = paddle.max(token_embeddings, 1)
+                        all_feats.append(max_over_time)
+
+                    elif self.pooling_mode == "mean_tokens" or self.pooling_mode == "mean_sqrt_len_tokens":
+                        attention_mask = (batch_inputs["input_ids"] != self.pad_token_id).astype(
+                            self._model.pooler.dense.weight.dtype
+                        )
+                        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.shape)
+                        sum_embeddings = paddle.sum(token_embeddings * input_mask_expanded, 1)
+                        sum_mask = input_mask_expanded.sum(1)
+                        sum_mask = paddle.clip(sum_mask, min=1e-9)
+                        if self.pooling_mode == "mean_tokens":
+                            all_feats.append(sum_embeddings / sum_mask)
+                        elif self.pooling_mode == "mean_sqrt_len_tokens":
+                            all_feats.append(sum_embeddings / paddle.sqrt(sum_mask))
+                    else:
+                        cls_token = token_embeddings[:, 0]
+                        all_feats.append(cls_token)
+        inputs.update({"features": all_feats})
+        return inputs
+
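The pooling logic above repeats the same masked reductions across the paddle-inference, ONNX, and dygraph branches. As a compact reference, here is the mean-pooling arithmetic restated in standalone NumPy (a sketch under assumed shapes, not code from the commit):

import numpy as np

def masked_mean_pool(token_embeddings, input_ids, pad_token_id):
    # token_embeddings: [batch, seq_len, hidden]; input_ids: [batch, seq_len]
    mask = (input_ids != pad_token_id).astype(token_embeddings.dtype)
    mask = np.expand_dims(mask, -1)                 # [batch, seq_len, 1], broadcasts over hidden
    summed = (token_embeddings * mask).sum(axis=1)  # zero out pad positions, sum over tokens
    counts = np.clip(mask.sum(axis=1), 1e-9, None)  # number of real tokens; avoid divide-by-zero
    return summed / counts  # "mean_tokens"; use summed / np.sqrt(counts) for "mean_sqrt_len_tokens"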
+    def _postprocess(self, inputs):
+        inputs["features"] = np.concatenate(inputs["features"], axis=0)
+        inputs["features"] = [inputs["features"][idx] for idx in np.argsort(self.length_sorted_idx)]
+
+        if self.return_tensors == "pd":
+            inputs["features"] = paddle.to_tensor(inputs["features"])
+        return inputs
+
+    def _convert_dygraph_to_static(self):
+        """
+        Convert the dygraph model to static model.
+        """
+        assert (
+            self._model is not None
+        ), "The dygraph model must be created before converting the dygraph model to static model."
+        assert (
+            self._input_spec is not None
+        ), "The input spec must be created before converting the dygraph model to static model."
+        logger.info("Converting to the inference model cost a little time.")
+
+        static_model = paddle.jit.to_static(self._model, input_spec=self._input_spec)
+        paddle.jit.save(static_model, self.inference_model_path)
+        logger.info("The inference model save in the path:{}".format(self.inference_model_path))

pipelines/examples/semantic-search/semantic_search_example.py

Lines changed: 17 additions & 6 deletions

@@ -39,6 +39,7 @@
 parser.add_argument('--port', type=str, default="8530", help='port of ANN search engine')
 parser.add_argument('--embed_title', default=False, type=bool, help="The title to be embedded into embedding")
 parser.add_argument('--model_type', choices=['ernie_search', 'ernie', 'bert', 'neural_search'], default="ernie", help="the ernie model types")
+parser.add_argument('--pooling_mode', choices=['max_tokens', 'mean_tokens', 'mean_sqrt_len_tokens', 'cls_token'], default='cls_token', help='the type of sentence embedding')
 args = parser.parse_args()
 # yapf: enable

@@ -59,6 +60,8 @@ def get_faiss_retriever(use_gpu):
             batch_size=args.retriever_batch_size,
             use_gpu=use_gpu,
             embed_title=args.embed_title,
+            pooling_mode=args.pooling_mode,
+            precision="fp16",
         )
     else:
         doc_dir = "data/dureader_dev"
@@ -86,6 +89,8 @@ def get_faiss_retriever(use_gpu):
             batch_size=args.retriever_batch_size,
             use_gpu=use_gpu,
             embed_title=args.embed_title,
+            pooling_mode=args.pooling_mode,
+            precision="fp16",
         )

     # update Embedding
@@ -120,6 +125,8 @@ def get_milvus_retriever(use_gpu):
             batch_size=args.retriever_batch_size,
             use_gpu=use_gpu,
             embed_title=args.embed_title,
+            pooling_mode=args.pooling_mode,
+            precision="fp16",
         )
     else:
         doc_dir = "data/dureader_dev"
@@ -146,6 +153,8 @@ def get_milvus_retriever(use_gpu):
             batch_size=args.retriever_batch_size,
             use_gpu=use_gpu,
             embed_title=args.embed_title,
+            pooling_mode=args.pooling_mode,
+            precision="fp16",
         )

     document_store.write_documents(dicts)
@@ -164,15 +173,17 @@ def semantic_search_tutorial():
     else:
         retriever = get_faiss_retriever(use_gpu)

-    # Ranker
-    ranker = ErnieRanker(model_name_or_path="rocketqa-zh-dureader-cross-encoder", use_gpu=use_gpu)
-
     # Pipeline
     from pipelines import SemanticSearchPipeline

-    pipe = SemanticSearchPipeline(retriever, ranker)
-
-    prediction = pipe.run(query="亚马逊河流的介绍", params={"Retriever": {"top_k": 50}, "Ranker": {"top_k": 5}})
+    if args.query_embedding_model == "moka-ai/m3e-base" or args.passage_embedding_model == "moka-ai/m3e-base":
+        pipe = SemanticSearchPipeline(retriever)
+        prediction = pipe.run(query="亚马逊河流的介绍", params={"Retriever": {"top_k": 50}})
+    else:
+        # Ranker
+        ranker = ErnieRanker(model_name_or_path="rocketqa-zh-dureader-cross-encoder", use_gpu=use_gpu)
+        pipe = SemanticSearchPipeline(retriever, ranker)
+        prediction = pipe.run(query="亚马逊河流的介绍", params={"Retriever": {"top_k": 50}, "Ranker": {"top_k": 5}})

     print_documents(prediction)
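Putting the pieces together, a run of the example script with the new model might look like this (the embedding-model flags already exist in the script's argument parser; the values are illustrative):

python semantic_search_example.py \
    --query_embedding_model moka-ai/m3e-base \
    --passage_embedding_model moka-ai/m3e-base \
    --pooling_mode mean_tokens

With m3e-base selected, the pipeline above deliberately skips the ErnieRanker stage and returns the retriever's top-k results directly.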
