Skip to content

Commit 6ccc0be

Browse files
authored
Add template postprocess module for faster_tokenizer (#2516)
* Add template postprocessing
* Add template json
* Move template.cc to faster_tokenizers
* Add template processing pybind (resolved conflicts in faster_tokenizer/faster_tokenizer/src/core/CMakeLists.txt)
* Temporary template pybind
* Fix template processor pybind __init__
* Fix template processing overflowing bug
* Fix from_json of template piece
1 parent 1f446ff commit 6ccc0be

File tree

11 files changed

+853
-7
lines changed

11 files changed

+853
-7
lines changed

faster_tokenizer/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ if(WITH_PYTHON)
5959

6060
add_subdirectory(python)
6161
add_custom_target(build_tokenizers_bdist_wheel ALL
62-
COMMAND ${PYTHON_EXECUTABLE} setup.py bdist_wheel
62+
COMMAND ${PYTHON_EXECUTABLE} setup.py bdist_wheel --plat-name=manylinux1_x86_64
6363
DEPENDS copy_python_tokenizers)
6464

6565
else(WITH_PYTHON)

faster_tokenizer/faster_tokenizer/include/core/tokenizer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ class Tokenizer {
213213
AddedVocabulary added_vocabulary_;
214214
bool use_truncation_;
215215
bool use_padding_;
216-
// TODO(zhoushunjie): Implement Decoder later.
216+
217217
friend void to_json(nlohmann::json& j, const Tokenizer& tokenizer);
218218
friend void from_json(const nlohmann::json& j, Tokenizer& tokenizer);
219219
};

faster_tokenizer/faster_tokenizer/include/postprocessors/postprocessors.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ limitations under the License. */
1616

1717
#include "postprocessors/bert.h"
1818
#include "postprocessors/postprocessor.h"
19+
#include "postprocessors/template.h"
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <unordered_map>
#include <vector>

#include "boost/variant.hpp"
#include "glog/logging.h"
#include "nlohmann/json.hpp"
#include "postprocessors/postprocessor.h"

namespace tokenizers {
namespace postprocessors {

// Identifies which input sequence a template piece refers to: "A" is the
// single/first sequence, "B" the pair sequence.
enum SequenceType { SEQ_A, SEQ_B };
NLOHMANN_JSON_SERIALIZE_ENUM(SequenceType,
                             {
                                 {SEQ_A, "A"}, {SEQ_B, "B"},
                             });
// The template indicate `${Id} : ${TypeId}`.
// NOTE(review): `uint` is a POSIX typedef, not standard C++; it is kept here
// for interface compatibility with the rest of the project, but `uint32_t`
// would be the portable spelling — TODO confirm on non-glibc toolchains.
using TemplateSequence = std::pair<SequenceType, uint>;
using TemplateSpecialToken = std::pair<std::string, uint>;

// A piece of a template is either a sequence placeholder (e.g. "$A:0") or a
// literal special token with its type id (e.g. "[CLS]:0").
using TemplatePiece = boost::variant<TemplateSequence, TemplateSpecialToken>;
void to_json(nlohmann::json& j, const TemplatePiece& template_piece);
void from_json(const nlohmann::json& j, TemplatePiece& template_piece);

// Parses the id part of a template string into `template_piece`
// (declarations only; implemented in template.cc).
void ParseIdFromString(const std::string& template_id_string,
                       TemplatePiece* template_piece);
// Overwrites the type id carried by `template_piece`.
void SetTypeId(uint type_id, TemplatePiece* template_piece);
// Converts one raw template token (e.g. "$A", "[SEP]:1") into a piece.
void GetTemplatePieceFromString(const std::string& template_string,
                                TemplatePiece* template_piece);

// A special token (such as [CLS]/[SEP]) with the vocabulary ids and surface
// tokens it expands to.
struct SpecialToken {
  std::string id_;                   // key used to look the token up
  std::vector<uint> ids_;            // vocabulary ids it expands to
  std::vector<std::string> tokens_;  // surface forms it expands to
  SpecialToken() = default;
  SpecialToken(const std::string& id,
               const std::vector<uint>& ids,
               const std::vector<std::string>& tokens)
      : id_(id), ids_(ids), tokens_(tokens) {}
  // Convenience constructor for the common one-token / one-id case.
  SpecialToken(const std::string& token, uint id)
      : id_(token), ids_{id}, tokens_{token} {}
  friend void to_json(nlohmann::json& j, const SpecialToken& special_token);
  friend void from_json(const nlohmann::json& j, SpecialToken& special_token);
};

// An ordered list of template pieces, e.g. parsed from
// "[CLS] $A [SEP] $B:1 [SEP]:1".
struct Template {
  std::vector<TemplatePiece> pieces_;
  Template() = default;
  // Parses a whitespace-separated template string. Delegates to
  // GetPiecesFromStr instead of duplicating the parsing loop.
  explicit Template(const std::string& template_str) {
    GetPiecesFromStr(template_str);
  }

  explicit Template(const std::vector<TemplatePiece>& pieces)
      : pieces_(pieces) {}
  explicit Template(const std::vector<std::string>& pieces) {
    AddStringPiece(pieces);
  }

  // Appends pieces parsed from a vector of raw template tokens.
  void GetPiecesFromVec(const std::vector<std::string>& pieces) {
    AddStringPiece(pieces);
  }

  // Splits `template_str` on runs of spaces (leading/trailing/repeated
  // spaces are ignored) and appends each token as a piece.
  // NOTE(review): this appends to any existing pieces; call Clean() first
  // if replacement semantics are intended — TODO confirm with callers.
  void GetPiecesFromStr(const std::string& template_str) {
    std::vector<std::string> pieces;
    size_t start = template_str.find_first_not_of(" ");
    size_t pos;
    while ((pos = template_str.find_first_of(" ", start)) !=
           std::string::npos) {
      pieces.push_back(template_str.substr(start, pos - start));
      start = template_str.find_first_not_of(" ", pos);
    }
    if (start != std::string::npos) {
      pieces.push_back(template_str.substr(start));
    }
    AddStringPiece(pieces);
  }

  // Removes all pieces.
  void Clean() { pieces_.clear(); }

private:
  // Converts each raw string piece into a TemplatePiece and stores it.
  // The parsed result is already the variant type, so it is pushed
  // directly instead of being unpacked and re-wrapped via boost::get.
  void AddStringPiece(const std::vector<std::string>& pieces) {
    for (auto&& piece : pieces) {
      TemplatePiece template_piece;
      GetTemplatePieceFromString(piece, &template_piece);
      pieces_.push_back(std::move(template_piece));
    }
  }

  friend void to_json(nlohmann::json& j, const Template& template_);
  friend void from_json(const nlohmann::json& j, Template& template_);
};

// Lookup table from a special token's id to its full definition.
struct SpecialTokensMap {
  std::unordered_map<std::string, SpecialToken> tokens_map_;
  SpecialTokensMap() = default;
  explicit SpecialTokensMap(const std::vector<SpecialToken>& special_tokens) {
    SetTokensMap(special_tokens);
  }
  // Replaces the current map with the given tokens, keyed by `id_`.
  void SetTokensMap(const std::vector<SpecialToken>& special_tokens) {
    tokens_map_.clear();
    for (const auto& special_token : special_tokens) {
      tokens_map_.insert({special_token.id_, special_token});
    }
  }
  friend void to_json(nlohmann::json& j, const SpecialTokensMap& tokens_map);
  friend void from_json(const nlohmann::json& j, SpecialTokensMap& tokens_map);
};

// Post-processor that decorates encodings according to a single-sequence
// template and a pair-sequence template (the "TemplateProcessing" scheme).
struct TemplatePostProcessor : public PostProcessor {
  TemplatePostProcessor();
  TemplatePostProcessor(const Template&,
                        const Template&,
                        const std::vector<SpecialToken>&);

  // Applies the appropriate template (single_ or pair_) to the encoding(s).
  virtual void operator()(core::Encoding* encoding,
                          core::Encoding* pair_encoding,
                          bool add_special_tokens,
                          core::Encoding* result_encoding) const override;
  // Number of special tokens the template adds for a single/pair input.
  virtual size_t AddedTokensNum(bool is_pair) const override;

  void UpdateSinglePieces(const std::string& template_str);
  void UpdateSinglePieces(const std::vector<std::string>& pieces);
  void UpdatePairPieces(const std::string& template_str);
  void UpdatePairPieces(const std::vector<std::string>& pieces);
  // Recomputes added_single_ / added_pair_ from the current templates.
  void UpdateAddedTokensNum();
  void SetTokensMap(const std::vector<SpecialToken>& special_tokens);
  size_t CountAdded(Template* template_,
                    const SpecialTokensMap& special_tokens_map);
  size_t DefaultAdded(bool is_single = true);
  void ApplyTemplate(const Template& pieces,
                     core::Encoding* encoding,
                     core::Encoding* pair_encoding,
                     bool add_special_tokens,
                     core::Encoding* result_encoding) const;

  friend void to_json(nlohmann::json& j,
                      const TemplatePostProcessor& template_postprocessor);
  friend void from_json(const nlohmann::json& j,
                        TemplatePostProcessor& template_postprocessor);

  Template single_;  // template applied to a single sequence
  Template pair_;    // template applied to a sequence pair
  // Cached counts of added special tokens; zero-initialized so a default
  // constructed object is well-defined before UpdateAddedTokensNum runs.
  size_t added_single_ = 0;
  size_t added_pair_ = 0;
  SpecialTokensMap special_tokens_map_;
};

}  // namespace postprocessors
}  // namespace tokenizers
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
cc_library(added_vocabulary SRCS added_vocabulary.cc DEPS normalizers pretokenizers json)
2-
cc_library(tokenizer SRCS tokenizer.cc DEPS added_vocabulary json decoders trie models)
2+
cc_library(tokenizer SRCS tokenizer.cc DEPS added_vocabulary json decoders trie models postprocessors)
33
cc_library(core SRCS encoding.cc DEPS json)
44
add_dependencies(tokenizer extern_boost)

faster_tokenizer/faster_tokenizer/src/core/tokenizer.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,11 @@ void to_json(nlohmann::json& j, const Tokenizer& tokenizer) {
517517
typeid(postprocessors::BertPostProcessor)) {
518518
j["postprocessor"] = *dynamic_cast<postprocessors::BertPostProcessor*>(
519519
tokenizer.post_processor_.get());
520+
} else if (typeid(*tokenizer.post_processor_.get()) ==
521+
typeid(postprocessors::TemplatePostProcessor)) {
522+
j["postprocessor"] =
523+
*dynamic_cast<postprocessors::TemplatePostProcessor*>(
524+
tokenizer.post_processor_.get());
520525
}
521526
}
522527

@@ -611,6 +616,10 @@ void from_json(const nlohmann::json& j, Tokenizer& tokenizer) {
611616
postprocessors::BertPostProcessor bert_postprocessor;
612617
post_processor.get_to(bert_postprocessor);
613618
tokenizer.SetPostProcessor(bert_postprocessor);
619+
} else if (post_processor.at("type") == "TemplateProcessing") {
620+
postprocessors::TemplatePostProcessor template_postprocessor;
621+
post_processor.get_to(template_postprocessor);
622+
tokenizer.SetPostProcessor(template_postprocessor);
614623
}
615624
}
616625

@@ -686,7 +695,10 @@ template void Tokenizer::SetModel(const models::FasterWordPiece&);
686695
// Instantiate processors
687696
template void Tokenizer::SetPostProcessor(
688697
const postprocessors::BertPostProcessor&);
698+
template void Tokenizer::SetPostProcessor(
699+
const postprocessors::TemplatePostProcessor&);
689700

701+
// Instantiate Decoder
690702
template void Tokenizer::SetDecoder(const decoders::WordPiece& decoder);
691703
} // core
692704
} // tokenizers
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
cc_library(models SRCS wordpiece.cc faster_wordpiece.cc DEPS core json trie failure)
1+
cc_library(models SRCS wordpiece.cc faster_wordpiece.cc DEPS core json trie failure icu)
22
add_dependencies(models extern_boost)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
cc_library(postprocessors SRCS bert.cc postprocessor.cc DEPS core json)
1+
cc_library(postprocessors SRCS bert.cc postprocessor.cc template.cc DEPS core json)

0 commit comments

Comments
 (0)