 import paddle

 from paddlenlp.data import DataCollatorWithPadding
-from paddlenlp.transformers import AutoTokenizer, ErnieDualEncoder
+from paddlenlp.transformers import AutoModel, AutoTokenizer, ErnieDualEncoder

 from ..utils.log import logger
 from .task import Task
@@ -315,3 +315,269 @@ def _convert_dygraph_to_static(self):
         static_model = paddle.jit.to_static(self._model.get_pooled_embedding, input_spec=self._input_spec)
         paddle.jit.save(static_model, self.inference_model_path)
         logger.info("The inference model save in the path:{}".format(self.inference_model_path))
+
+
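+# Helper used by _batchify below: it estimates the length of a raw input (dict of features,
+# plain string, list of token ids, or list of strings) so inputs can be sorted by size before batching.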
+def text_length(text):
+    # {key: value} case
+    if isinstance(text, dict):
+        return len(next(iter(text.values())))
+    # Object has no len() method
+    elif not hasattr(text, "__len__"):
+        return 1
+    # Empty string or list of ints
+    elif len(text) == 0 or isinstance(text[0], int):
+        return len(text)
+    # Sum of length of individual strings
+    else:
+        return sum([len(t) for t in text])
+
+
+class SentenceFeatureExtractionTask(Task):
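+    """
+    Feature extraction task that encodes input sentences into dense vectors with a pretrained
+    encoder, using a configurable pooling strategy (cls/mean/max) over the token embeddings.
+    """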
+
+    resource_files_names = {
+        "model_state": "model_state.pdparams",
+        "config": "config.json",
+        "vocab_file": "vocab.txt",
+        "special_tokens_map": "special_tokens_map.json",
+        "tokenizer_config": "tokenizer_config.json",
+    }
+
+    def __init__(
+        self,
+        task: str = None,
+        model: str = None,
+        batch_size: int = 1,
+        max_seq_len: int = 512,
+        _static_mode: bool = True,
+        return_tensors: str = "pd",
+        pooling_mode: str = "cls_token",
+        **kwargs
+    ):
+        super().__init__(
+            task=task,
+            model=model,
+            pooling_mode=pooling_mode,
+            **kwargs,
+        )
+        self._seed = None
+        self.export_type = "text"
+        self._batch_size = batch_size
+        self.max_seq_len = max_seq_len
+        self.model = model
+        self._static_mode = _static_mode
+        self.return_tensors = return_tensors
+        self.pooling_mode = pooling_mode
+        self._check_predictor_type()
+        self._construct_tokenizer()
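+        # In static mode, export/load a static inference model; otherwise build the dygraph model directly.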
+        if self._static_mode:
+            self._get_inference_model()
+        else:
+            self._construct_model(model)
+
+    def _construct_model(self, model):
+        """
+        Construct the dygraph model for the predictor.
+        """
+        self._model = AutoModel.from_pretrained(self.model)
+        self._model.eval()
+
+    def _construct_tokenizer(self):
+        """
+        Construct the tokenizer for the predictor.
+        """
+        self._tokenizer = AutoTokenizer.from_pretrained(self.model)
+        self.pad_token_id = self._tokenizer.convert_tokens_to_ids(self._tokenizer.pad_token)
+        # Return numpy tensors in static mode to work around a Windows dtype bug.
+        if self._static_mode:
+            self._collator = DataCollatorWithPadding(self._tokenizer, return_tensors="np")
+        else:
+            self._collator = DataCollatorWithPadding(self._tokenizer, return_tensors="pd")
+
+    def _construct_input_spec(self):
+        """
+        Construct the input spec used to convert the dygraph model to a static model.
+        """
+        self._input_spec = [
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"),
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="token_type_ids"),
+        ]
+
+    def _batchify(self, data, batch_size):
+        """
+        Generate input batches.
+        """
+
+        def _parse_batch(batch_examples, max_seq_len=None):
+            if isinstance(batch_examples[0], str):
+                to_tokenize = [batch_examples]
+            else:
+                batch1, batch2 = [], []
+                for text_tuple in batch_examples:
+                    batch1.append(text_tuple[0])
+                    batch2.append(text_tuple[1])
+                to_tokenize = [batch1, batch2]
+            to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize]
+            if max_seq_len is None:
+                max_seq_len = self.max_seq_len
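+            # Note: only the first text column is tokenized; for (text, text_pair) inputs
+            # the second column is not passed to the tokenizer.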
+            tokenized_inputs = self._tokenizer(
+                to_tokenize[0],
+                padding=True,
+                truncation="longest_first",
+                max_seq_len=max_seq_len,
+            )
+            return tokenized_inputs
+
+        # Separate the data into batches, sorting by descending length so similarly sized
+        # examples are padded together; _postprocess restores the original order.
+        one_batch = []
+        self.length_sorted_idx = np.argsort([-text_length(sen) for sen in data])
+        sentences_sorted = [data[idx] for idx in self.length_sorted_idx]
+        for example in sentences_sorted:
+            one_batch.append(example)
+            if len(one_batch) == batch_size:
+                yield _parse_batch(one_batch)
+                one_batch = []
+        if one_batch:
+            yield _parse_batch(one_batch)
+
+    def _preprocess(self, inputs):
+        """
+        Transform the raw inputs to the model inputs, two steps involved:
+           1) Transform the raw text to token ids.
+           2) Generate the other model inputs from the raw text and token ids.
+        """
+        inputs = self._check_input_text(inputs)
+        batches = self._batchify(inputs, self._batch_size)
+        outputs = {"batches": batches, "inputs": inputs}
+        return outputs
+
+    def _run_model(self, inputs):
+        """
+        Run the task model from the outputs of the `_preprocess` function.
+        """
+        all_feats = []
+        if self._static_mode:
+            with static_mode_guard():
+                for batch_inputs in inputs["batches"]:
+                    batch_inputs = self._collator(batch_inputs)
+                    if self._predictor_type == "paddle-inference":
+                        if "input_ids" in batch_inputs:
+                            self.input_handles[0].copy_from_cpu(batch_inputs["input_ids"])
+                            self.input_handles[1].copy_from_cpu(batch_inputs["token_type_ids"])
+                            self.predictor.run()
+                            token_embeddings = self.output_handle[0].copy_to_cpu()
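+                            # Pool the token embeddings into one vector per sentence according to pooling_mode:
+                            # "max_tokens"           -> element-wise max over non-padding tokens,
+                            # "mean_tokens"          -> mean over non-padding tokens,
+                            # "mean_sqrt_len_tokens" -> sum divided by sqrt of the token count,
+                            # otherwise              -> the first ([CLS]) token embedding.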
+                            if self.pooling_mode == "max_tokens":
+                                attention_mask = (batch_inputs["input_ids"] != self.pad_token_id).astype(
+                                    token_embeddings.dtype
+                                )
+                                input_mask_expanded = np.expand_dims(attention_mask, -1).repeat(
+                                    token_embeddings.shape[-1], axis=-1
+                                )
+                                token_embeddings[input_mask_expanded == 0] = -1e9
+                                max_over_time = np.max(token_embeddings, 1)
+                                all_feats.append(max_over_time)
+                            elif self.pooling_mode == "mean_tokens" or self.pooling_mode == "mean_sqrt_len_tokens":
+                                attention_mask = (batch_inputs["input_ids"] != self.pad_token_id).astype(
+                                    token_embeddings.dtype
+                                )
+                                input_mask_expanded = np.expand_dims(attention_mask, -1).repeat(
+                                    token_embeddings.shape[-1], axis=-1
+                                )
+                                sum_embeddings = np.sum(token_embeddings * input_mask_expanded, 1)
+                                sum_mask = input_mask_expanded.sum(1)
+                                sum_mask = np.clip(sum_mask, a_min=1e-9, a_max=np.max(sum_mask))
+                                if self.pooling_mode == "mean_tokens":
+                                    all_feats.append(sum_embeddings / sum_mask)
+                                elif self.pooling_mode == "mean_sqrt_len_tokens":
+                                    all_feats.append(sum_embeddings / np.sqrt(sum_mask))
+                            else:
+                                cls_token = token_embeddings[:, 0]
+                                all_feats.append(cls_token)
+                    else:
+                        # onnx mode
+                        if "input_ids" in batch_inputs:
+                            input_dict = {}
+                            input_dict["input_ids"] = batch_inputs["input_ids"]
+                            input_dict["token_type_ids"] = batch_inputs["token_type_ids"]
+                            token_embeddings = self.predictor.run(None, input_dict)[0]
+                            if self.pooling_mode == "max_tokens":
+                                attention_mask = (batch_inputs["input_ids"] != self.pad_token_id).astype(
+                                    token_embeddings.dtype
+                                )
+                                input_mask_expanded = np.expand_dims(attention_mask, -1).repeat(
+                                    token_embeddings.shape[-1], axis=-1
+                                )
+                                token_embeddings[input_mask_expanded == 0] = -1e9
+                                max_over_time = np.max(token_embeddings, 1)
+                                all_feats.append(max_over_time)
+                            elif self.pooling_mode == "mean_tokens" or self.pooling_mode == "mean_sqrt_len_tokens":
+                                attention_mask = (batch_inputs["input_ids"] != self.pad_token_id).astype(
+                                    token_embeddings.dtype
+                                )
+                                input_mask_expanded = np.expand_dims(attention_mask, -1).repeat(
+                                    token_embeddings.shape[-1], axis=-1
+                                )
+                                sum_embeddings = np.sum(token_embeddings * input_mask_expanded, 1)
+                                sum_mask = input_mask_expanded.sum(1)
+                                sum_mask = np.clip(sum_mask, a_min=1e-9, a_max=np.max(sum_mask))
+                                if self.pooling_mode == "mean_tokens":
+                                    all_feats.append(sum_embeddings / sum_mask)
+                                elif self.pooling_mode == "mean_sqrt_len_tokens":
+                                    all_feats.append(sum_embeddings / np.sqrt(sum_mask))
+                            else:
+                                cls_token = token_embeddings[:, 0]
+                                all_feats.append(cls_token)
+        else:
+            with dygraph_mode_guard():
+                for batch_inputs in inputs["batches"]:
+                    batch_inputs = self._collator(batch_inputs)
+                    token_embeddings = self._model(input_ids=batch_inputs["input_ids"])[0]
+                    if self.pooling_mode == "max_tokens":
+                        attention_mask = (batch_inputs["input_ids"] != self.pad_token_id).astype(
+                            self._model.pooler.dense.weight.dtype
+                        )
+                        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.shape)
+                        token_embeddings[input_mask_expanded == 0] = -1e9
+                        max_over_time = paddle.max(token_embeddings, 1)
+                        all_feats.append(max_over_time)
+                    elif self.pooling_mode == "mean_tokens" or self.pooling_mode == "mean_sqrt_len_tokens":
+                        attention_mask = (batch_inputs["input_ids"] != self.pad_token_id).astype(
+                            self._model.pooler.dense.weight.dtype
+                        )
+                        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.shape)
+                        sum_embeddings = paddle.sum(token_embeddings * input_mask_expanded, 1)
+                        sum_mask = input_mask_expanded.sum(1)
+                        sum_mask = paddle.clip(sum_mask, min=1e-9)
+                        if self.pooling_mode == "mean_tokens":
+                            all_feats.append(sum_embeddings / sum_mask)
+                        elif self.pooling_mode == "mean_sqrt_len_tokens":
+                            all_feats.append(sum_embeddings / paddle.sqrt(sum_mask))
+                    else:
+                        cls_token = token_embeddings[:, 0]
+                        all_feats.append(cls_token)
+        inputs.update({"features": all_feats})
+        return inputs
+
+    def _postprocess(self, inputs):
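+        # Restore the original input order that _batchify changed when sorting by length.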
+        inputs["features"] = np.concatenate(inputs["features"], axis=0)
+        inputs["features"] = [inputs["features"][idx] for idx in np.argsort(self.length_sorted_idx)]
+
+        if self.return_tensors == "pd":
+            inputs["features"] = paddle.to_tensor(inputs["features"])
+        return inputs
+
+    def _convert_dygraph_to_static(self):
+        """
+        Convert the dygraph model to a static model.
+        """
+        assert (
+            self._model is not None
+        ), "The dygraph model must be created before converting the dygraph model to static model."
+        assert (
+            self._input_spec is not None
+        ), "The input spec must be created before converting the dygraph model to static model."
+        logger.info("Converting to the inference model will cost a little time.")
+
+        static_model = paddle.jit.to_static(self._model, input_spec=self._input_spec)
+        paddle.jit.save(static_model, self.inference_model_path)
+        logger.info("The inference model is saved in the path: {}".format(self.inference_model_path))
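
Below is a minimal usage sketch for the new task. It is hypothetical and not part of this diff: the module path, the "feature_extraction" task name, and the model identifier are placeholders, and it assumes the base `Task` class makes instances callable through `_preprocess` -> `_run_model` -> `_postprocess`; in practice the task is normally reached through the `Taskflow` API.

    # Hypothetical usage sketch; names below are placeholders, not part of this diff.
    from paddlenlp.taskflow.feature_extraction import SentenceFeatureExtractionTask  # assumed module path

    extractor = SentenceFeatureExtractionTask(
        task="feature_extraction",        # assumed task name
        model="your-encoder-checkpoint",  # placeholder model identifier
        _static_mode=False,               # use the dygraph model instead of an exported inference model
        pooling_mode="mean_tokens",       # mean pooling over non-padding tokens
        batch_size=8,
    )
    outputs = extractor(["This is a sentence.", "Another example sentence."])
    print(outputs["features"].shape)      # with return_tensors="pd": a paddle.Tensor of shape [2, hidden_size]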