11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
- import re
15
14
from typing import Dict
16
15
from typing import List
17
16
@@ -30,7 +29,6 @@ def __init__(self,
30
29
self .zh_frontend = Frontend (
31
30
phone_vocab_path = phone_vocab_path , tone_vocab_path = tone_vocab_path )
32
31
self .en_frontend = English (phone_vocab_path = phone_vocab_path )
33
- self .SENTENCE_SPLITOR = re .compile (r'([:、,;。?!,;?!][”’]?)' )
34
32
self .sp_id = self .zh_frontend .vocab_phones ["sp" ]
35
33
self .sp_id_tensor = paddle .to_tensor ([self .sp_id ])
36
34
@@ -47,188 +45,56 @@ def is_alphabet(self, char):
47
45
else :
48
46
return False
49
47
50
def is_number(self, char):
    """Return True when ``char`` is an ASCII decimal digit.

    Args:
        char: The character to classify.

    Returns:
        bool: True if ``char`` compares between U+0030 ('0') and
        U+0039 ('9') inclusive, otherwise False.
    """
    # Both bounds must be exactly one character long: a trailing space inside
    # either literal makes '0' sort below the lower bound and be rejected.
    return '\u0030' <= char <= '\u0039'
55
-
56
48
def is_other(self, char):
    """Return True when ``char`` is neither Chinese nor alphabetic.

    Classification is delegated to ``self.is_chinese`` and
    ``self.is_alphabet``; every character that both of them reject
    (punctuation, whitespace, digits, ...) counts as "other".

    Args:
        char: The character to classify.

    Returns:
        bool: True if both helper predicates reject ``char``.
    """
    return not (self.is_chinese(char) or self.is_alphabet(char))
62
53
63
def is_end(self, before_char, after_char) -> bool:
    """Tell whether a period sits inside English text.

    Args:
        before_char: The character immediately before the period.
        after_char: The character immediately after the period.

    Returns:
        bool: True only when both neighbours are either accepted by
        ``self.is_alphabet`` or are a plain space.
    """
    return all(
        self.is_alphabet(char) or char == " "
        for char in (before_char, after_char))
72
-
73
def _replace(self, text: str) -> str:
    """Turn English sentence-ending periods into the Chinese full stop.

    A "." is rewritten to "。" when it is neither the first nor the last
    character of ``text`` and ``self.is_end`` accepts its two neighbours.
    All other characters, including every other ".", are copied through
    unchanged, so the result always has the same length as the input.

    Args:
        text: The text to scan.

    Returns:
        str: ``text`` with the qualifying periods replaced by "。".
    """
    last_index = len(text) - 1
    chars = list(text)
    for index, char in enumerate(text):
        if char != ".":
            continue
        # A period at either edge has only one neighbour and is kept as is.
        if index == 0 or index == last_index:
            continue
        # Neighbours are read from the untouched input, never from `chars`,
        # so an earlier replacement cannot influence a later decision.
        if self.is_end(text[index - 1], text[index + 1]):
            chars[index] = "。"
    return "".join(chars)
155
-
156
def _split(self, text: str) -> List[str]:
    """Split ``text`` into sentences at punctuation marks.

    Steps, in order: drop bracket-like and symbol characters, convert
    English sentence-ending periods via ``self._replace``, insert a newline
    after every match of ``self.SENTENCE_SPLITOR``, then split on the
    newlines.

    Args:
        text: The raw input text.

    Returns:
        List[str]: The whitespace-stripped sentences, each keeping its
        trailing punctuation. An empty input yields ``[""]``.
    """
    text = re.sub(r'[《》【】<=>{}()()#&@“”^_|…\\]', '', text)
    # Replace the English sentence-ending "." with "。" so that the splitter
    # pattern below also breaks on English sentences.
    text = self._replace(text)
    text = self.SENTENCE_SPLITOR.sub(r'\1\n', text)
    text = text.strip()
    return [sentence.strip() for sentence in re.split(r'\n+', text)]
164
-
165
- def _distinguish (self , text : str ) -> List [str ]:
54
+ def get_segment (self , text : str ) -> List [str ]:
166
55
# sentence --> [ch_part, en_part, ch_part, ...]
167
-
168
56
segments = []
169
57
types = []
170
-
171
58
flag = 0
172
59
temp_seg = ""
173
60
temp_lang = ""
174
61
175
62
# Determine the type of each character. type: blank, chinese, alphabet, number, unk and point.
176
63
for ch in text :
177
- if ch == "." :
178
- types .append ("point" )
179
- elif self .is_chinese (ch ):
64
+ if self .is_chinese (ch ):
180
65
types .append ("zh" )
181
66
elif self .is_alphabet (ch ):
182
67
types .append ("en" )
183
- elif ch == " " :
184
- types .append ("blank" )
185
- elif self .is_number (ch ):
186
- types .append ("num" )
187
68
else :
188
- types .append ("unk " )
69
+ types .append ("other " )
189
70
190
71
assert len (types ) == len (text )
191
72
192
73
for i in range (len (types )):
193
-
194
74
# find the first char of the seg
195
75
if flag == 0 :
196
- # 首个字符是中文,英文或者数字
197
- if types [i ] == "zh" or types [i ] == "en" or types [i ] == "num" :
198
- temp_seg += text [i ]
199
- temp_lang = types [i ]
200
- flag = 1
76
+ temp_seg += text [i ]
77
+ temp_lang = types [i ]
78
+ flag = 1
201
79
202
80
else :
203
- # 数字和小数点均与前面的字符合并,类型属于前面一个字符的类型
204
- if types [i ] == temp_lang or types [i ] == "num" or types [
205
- i ] == "point" :
206
- temp_seg += text [i ]
207
-
208
- # 数字与后面的任意字符都拼接
209
- elif temp_lang == "num" :
210
- temp_seg += text [i ]
211
- if types [i ] == "zh" or types [i ] == "en" :
81
+ if temp_lang == "other" :
82
+ if types [i ] == temp_lang :
83
+ temp_seg += text [i ]
84
+ else :
85
+ temp_seg += text [i ]
212
86
temp_lang = types [i ]
213
87
214
- # 如果是空格则与前面字符拼接
215
- elif types [i ] == "blank" :
216
- temp_seg += text [i ]
217
-
218
- elif types [i ] == "unk" :
219
- pass
220
-
221
88
else :
222
- segments .append ((temp_seg , temp_lang ))
223
-
224
- if types [i ] == "zh" or types [i ] == "en" :
89
+ if types [i ] == temp_lang :
90
+ temp_seg += text [i ]
91
+ elif types [i ] == "other" :
92
+ temp_seg += text [i ]
93
+ else :
94
+ segments .append ((temp_seg , temp_lang ))
225
95
temp_seg = text [i ]
226
96
temp_lang = types [i ]
227
97
flag = 1
228
- else :
229
- flag = 0
230
- temp_seg = ""
231
- temp_lang = ""
232
98
233
99
segments .append ((temp_seg , temp_lang ))
234
100
@@ -241,34 +107,30 @@ def get_input_ids(self,
241
107
add_sp : bool = True ,
242
108
to_tensor : bool = True ) -> Dict [str , List [paddle .Tensor ]]:
243
109
244
- sentences = self ._split (sentence )
110
+ segments = self .get_segment (sentence )
111
+
245
112
phones_list = []
246
113
result = {}
247
- for text in sentences :
248
- phones_seg = []
249
- segments = self ._distinguish (text )
250
- for seg in segments :
251
- content = seg [0 ]
252
- lang = seg [1 ]
253
- if content != '' :
254
- if lang == "en" :
255
- input_ids = self .en_frontend .get_input_ids (
256
- content , merge_sentences = True , to_tensor = to_tensor )
257
- else :
258
- input_ids = self .zh_frontend .get_input_ids (
259
- content ,
260
- merge_sentences = True ,
261
- get_tone_ids = get_tone_ids ,
262
- to_tensor = to_tensor )
263
114
264
- phones_seg .append (input_ids ["phone_ids" ][0 ])
265
- if add_sp :
266
- phones_seg .append (self .sp_id_tensor )
267
-
268
- if phones_seg == []:
269
- phones_seg .append (self .sp_id_tensor )
270
- phones = paddle .concat (phones_seg )
271
- phones_list .append (phones )
115
+ for seg in segments :
116
+ content = seg [0 ]
117
+ lang = seg [1 ]
118
+ if content != '' :
119
+ if lang == "en" :
120
+ input_ids = self .en_frontend .get_input_ids (
121
+ content , merge_sentences = False , to_tensor = to_tensor )
122
+ else :
123
+ input_ids = self .zh_frontend .get_input_ids (
124
+ content ,
125
+ merge_sentences = False ,
126
+ get_tone_ids = get_tone_ids ,
127
+ to_tensor = to_tensor )
128
+ if add_sp :
129
+ input_ids ["phone_ids" ][- 1 ] = paddle .concat (
130
+ [input_ids ["phone_ids" ][- 1 ], self .sp_id_tensor ])
131
+
132
+ for phones in input_ids ["phone_ids" ]:
133
+ phones_list .append (phones )
272
134
273
135
if merge_sentences :
274
136
merge_list = paddle .concat (phones_list )
0 commit comments