1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import glob
1
16
import json
17
+ import math
18
+ import os
2
19
3
20
import paddle
4
21
import paddle .nn as nn
22
+ import pandas as pd
5
23
from paddlenlp .datasets import MapDataset
6
24
from paddlenlp .data import Stack , Pad , Tuple
7
25
from paddlenlp .transformers import ErnieCtmWordtagModel , ErnieCtmTokenizer
8
26
27
# Maps every WordTag label to the term-tree types it may be linked to.
#
# * The list order is a priority order: when several target types match, the
#   linking step keeps the candidate whose target has the lowest index.
# * Targets are compared with ``str.startswith`` against a term type (or one of
#   its ancestors in the schema), so a target only needs to begin with the
#   type name.
# * An optional ``|E`` / ``|C`` suffix restricts the match to term ids that
#   contain ``_eb_`` / ``_cb_`` respectively (used below for the entity
#   ("实体") and concept ("概念") variants of a label).
LABEL_TO_SCHEMA = {
    "人物类_实体": ["人物|E", "虚拟角色|E", "演艺团体|E"],
    "人物类_概念": ["人物|C", "虚拟角色|C"],
    "作品类_实体": ["作品与出版物|E"],
    "作品类_概念": ["作品与出版物|C", "文化类"],
    "组织机构类": ["组织机构"],
    "组织机构类_企事业单位": ["企事业单位", "品牌", "组织机构"],
    "组织机构类_医疗卫生机构": ["医疗卫生机构", "组织机构"],
    "组织机构类_国家机关": ["国家机关", "组织机构"],
    "组织机构类_体育组织机构": ["体育组织机构", "组织机构"],
    "组织机构类_教育组织机构": ["教育组织机构", "组织机构"],
    "组织机构类_军事组织机构": ["军事组织机构", "组织机构"],
    # The original list repeated "虚拟物品"; the duplicate could never be
    # selected (a lower index always wins), so it is listed once.
    "物体类": ["物体与物品", "品牌", "虚拟物品"],
    "物体类_兵器": ["兵器"],
    "物体类_化学物质": ["物体与物品", "化学术语"],
    "其他角色类": ["角色"],
    "文化类": ["文化", "作品与出版物|C", "体育运动项目", "语言文字"],
    "文化类_语言文字": ["语言学术语"],
    "文化类_奖项赛事活动": ["奖项赛事活动", "特殊日", "事件"],
    "文化类_制度政策协议": ["制度政策协议", "法律法规"],
    "文化类_姓氏与人名": ["姓氏与人名"],
    "生物类": ["生物"],
    "生物类_植物": ["植物", "生物"],
    "生物类_动物": ["动物", "生物"],
    "品牌名": ["品牌", "企事业单位"],
    "场所类": ["区域场所", "居民服务机构", "医疗卫生机构"],
    "场所类_交通场所": ["交通场所", "设施"],
    "位置方位": ["位置方位"],
    "世界地区类": ["世界地区", "区域场所", "政权朝代"],
    "饮食类": ["饮食", "生物类", "药物"],
    "饮食类_菜品": ["饮食"],
    "饮食类_饮品": ["饮食"],
    "药物类": ["药物", "生物类"],
    "药物类_中药": ["药物", "生物类"],
    "医学术语类": ["医药学术语"],
    "术语类_生物体": ["生物学术语"],
    "疾病损伤类": ["疾病损伤", "动物疾病", "医药学术语"],
    "疾病损伤类_植物病虫害": ["植物病虫害", "医药学术语"],
    "宇宙类": ["天文学术语"],
    "事件类": ["事件", "奖项赛事活动"],
    "时间类": ["时间阶段", "政权朝代"],
    "术语类": ["术语"],
    "术语类_符号指标类": ["编码符号指标", "术语"],
    "信息资料": ["生活用语"],
    "链接地址": ["生活用语"],
    "个性特征": ["个性特点", "生活用语"],
    "感官特征": ["生活用语"],
    "场景事件": ["场景事件", "情绪", "态度", "个性特点"],
    "介词": ["介词"],
    "介词_方位介词": ["介词"],
    "助词": ["助词"],
    "代词": ["代词"],
    "连词": ["连词"],
    "副词": ["副词"],
    "疑问词": ["疑问词"],
    "肯定词": ["肯定否定词"],
    "否定词": ["肯定否定词"],
    "数量词": ["数量词", "量词"],
    "叹词": ["叹词"],
    "拟声词": ["拟声词"],
    "修饰词": ["修饰词", "生活用语"],
    "外语单词": ["日文假名", "词汇用语"],
    "汉语拼音": ["汉语拼音"],
}
91
+
9
92
10
93
class WordtagPredictor (object ):
11
94
"""Predictor of wordtag model.
12
95
"""
13
96
14
- def __init__ (self , model_dir , tag_path , linking_path = None ):
97
+ def __init__ (self ,
98
+ model_dir ,
99
+ tag_path ,
100
+ term_schema_path = None ,
101
+ term_data_path = None ):
15
102
"""Initialize method of the predictor.
16
103
17
104
Args:
18
- model_dir: The pre-trained model checkpoint dir.
19
- tag_path: The tag vocab path.
20
- linking_path:if you want to use linking mode, you should load link feature using.
105
+ model_dir (`str`):
106
+ The pre-trained model checkpoint dir.
107
+ tag_path (`str`):
108
+ The tag vocab path.
109
+ term_schema_path (`str`, optional):
110
+ if you want to use linking mode, you should load term schema. Defaults to ``None``.
111
+ term_data_path (`str`, optional):
112
+ if you want to use linking mode, you should load term data. Defaults to ``None``.
21
113
"""
22
114
self ._tags_to_index , self ._index_to_tags = self ._load_labels (tag_path )
23
115
@@ -30,28 +122,27 @@ def __init__(self, model_dir, tag_path, linking_path=None):
30
122
31
123
self ._tokenizer = ErnieCtmTokenizer .from_pretrained (model_dir )
32
124
self ._summary_num = self ._model .ernie_ctm .content_summary_index + 1
33
- self .linking = False
34
- if linking_path is not None :
35
- self .linking_dict = {}
36
- with open (linking_path , encoding = "utf-8" ) as fp :
37
- for line in fp :
38
- data = json .loads (line )
39
- if data ["label" ] not in self .linking_dict :
40
- self .linking_dict [data ["label" ]] = []
41
- self .linking_dict [data ["label" ]].append ({
42
- "sid" : data ["sid" ],
43
- "cls" : paddle .to_tensor (data ["cls1" ]).unsqueeze (0 ),
44
- "term" : paddle .to_tensor (data ["term" ]).unsqueeze (0 )
45
- })
46
- self .linking = True
47
- self .sim_fct = nn .CosineSimilarity (dim = 1 )
125
+ if term_schema_path is not None :
126
+ self ._term_schema = self ._load_schema (term_schema_path )
127
+ if term_data_path is not None :
128
+ self ._term_dict = self ._load_term_tree_data (term_data_path )
129
+ if term_data_path is not None and term_schema_path is not None :
130
+ self ._linking = True
131
+ else :
132
+ self ._linking = False
48
133
49
134
@property
def summary_num(self):
    """Read-only access to the summary token count kept in ``_summary_num``.
    """
    summary_count = self._summary_num
    return summary_count
54
139
140
@property
def linking(self):
    """Read-only flag stored in ``_linking``: ``True`` when term linking is enabled.
    """
    linking_enabled = self._linking
    return linking_enabled
145
+
55
146
@staticmethod
56
147
def _load_labels (tag_path ):
57
148
tags_to_idx = {}
@@ -64,9 +155,52 @@ def _load_labels(tag_path):
64
155
idx_to_tags = dict (zip (* (tags_to_idx .values (), tags_to_idx .keys ())))
65
156
return tags_to_idx , idx_to_tags
66
157
158
@staticmethod
def _load_schema(schema_path):
    """Load the term type hierarchy as a ``child -> parent`` mapping.

    The file is read as a tab-separated, GB2312-encoded table with the
    columns ``type-1``, ``type-2`` and ``type-3``. Every row may fill one,
    two or all three levels; empty cells are ignored.

    Args:
        schema_path (`str`):
            Path of the schema table.

    Returns:
        `dict`: Maps each ``type-1`` value to ``"root"``, each ``type-2``
        value to the ``type-1`` value of its row, and each ``type-3`` value
        to the ``type-2`` value of its row.
    """
    # A plain tab is the field separator of the schema table.
    schema_df = pd.read_csv(schema_path, sep="\t", encoding="gb2312")
    schema = {}
    levels = (schema_df["type-1"], schema_df["type-2"], schema_df["type-3"])
    for type_1, type_2, type_3 in zip(*levels):
        # pandas represents an empty cell as NaN; test for that explicitly
        # instead of relying on the cell's Python type.
        if not pd.isna(type_1):
            schema[type_1] = "root"
        if not pd.isna(type_2):
            schema[type_2] = type_1
        if not pd.isna(type_3):
            schema[type_3] = type_2
    return schema
170
+
171
@staticmethod
def _load_term_tree_data(term_tree_name_or_path):
    """Build the surface-form lookup used for term linking.

    Every non-empty line of every input file is parsed as one JSON record
    with the keys ``term``, ``termtype`` and ``termid`` plus the optional
    lists ``alias`` and ``alias_ext``. The term id is registered under the
    term itself and under each of its aliases.

    Args:
        term_tree_name_or_path (`str`):
            Either a single data file or a directory; for a directory, all
            regular files directly inside it are loaded.

    Returns:
        `dict`: ``{name: {termtype: [termid, ...]}}``, with term ids in the
        order they were read.
    """
    if os.path.isdir(term_tree_name_or_path):
        pattern = os.path.join(term_tree_name_or_path, "*")
        # Skip sub-directories (open() would fail on them) and sort so the
        # order of the collected term ids does not depend on the file system.
        fn_list = sorted(
            fn for fn in glob.glob(pattern) if os.path.isfile(fn))
    else:
        fn_list = [term_tree_name_or_path]

    term_dict = {}
    for fn in fn_list:
        with open(fn, encoding="utf-8") as fp:
            for line in fp:
                if not line.strip():
                    # Tolerate blank lines, e.g. a trailing newline.
                    continue
                data = json.loads(line)
                term_type = data["termtype"]
                term_id = data["termid"]
                # The term and all of its aliases resolve to the same id.
                names = [data["term"]]
                names.extend(data.get("alias") or [])
                names.extend(data.get("alias_ext") or [])
                for name in names:
                    type_to_ids = term_dict.setdefault(name, {})
                    type_to_ids.setdefault(term_type, []).append(term_id)
    return term_dict
201
+
67
202
def _pre_process_text (self , input_texts , max_seq_len = 128 , batch_size = 1 ):
68
203
infer_data = []
69
- max_length = 0
70
204
for text in input_texts :
71
205
tokens = ["[CLS%i]" % i
72
206
for i in range (1 , self .summary_num )] + list (text )
@@ -170,45 +304,57 @@ def run(self,
170
304
all_pred_tags += pred_tags .numpy ().tolist ()
171
305
172
306
results = self ._decode (input_texts , all_pred_tags )
307
+ if self .linking is True :
308
+ for res in results :
309
+ self ._term_linking (res )
173
310
outputs = results
174
311
if return_hidden_states is True :
175
312
outputs = (results , ) + (seq_logits , cls_logits )
176
313
return outputs
177
314
178
- def _post_linking (self , pred_res , hidden_states ):
179
- for pred in pred_res :
180
- for item in pred ["items" ]:
181
- if item ["item" ] in self .linking_dict :
182
- item_vectors = self .linking_dict [item ["item" ]]
183
- item_pred_vector = hidden_states [1 ]
184
-
185
- res = []
186
- for item_vector in item_vectors :
187
- vec = item_vector ["cls" ]
188
- similarity = self .sim_fct (vec , item_pred_vector )
189
- res .append ({
190
- "sid" : item_vector ["sid" ],
191
- "cosine" : similarity .item ()
192
- })
193
- res .sort (key = lambda d : - d ["cosine" ])
194
- item ["link" ] = res
195
-
196
- def run_with_link (self , input_text ):
197
- """Predict wordtag results with term linking.
315
def _term_linking(self, wordtag_res):
    """Link the items of one tagging result to term-tree ids, in place.

    An item is considered only when its ``wordtag_label`` has an entry in
    ``LABEL_TO_SCHEMA`` and its ``item`` text is a key of ``self._term_dict``.
    For each term type recorded for that text, the type and then its
    ancestors in ``self._term_schema`` are compared against the label's
    target types. Among all matching candidates the one whose target type
    has the lowest index in the ``LABEL_TO_SCHEMA`` list wins; ties keep
    the candidate found first.

    Args:
        wordtag_res (`dict`):
            One decoded result. Its ``"items"`` list holds dicts with at
            least the keys ``"item"`` and ``"wordtag_label"``; a
            ``"termid"`` key is added to every item that could be linked.
    """
    # Suffix after "|" in a target type -> substring the term id must contain.
    id_markers = {"C": "_cb_", "E": "_eb_"}

    for item in wordtag_res["items"]:
        if item["wordtag_label"] not in LABEL_TO_SCHEMA:
            continue
        if item["item"] not in self._term_dict:
            continue
        target_type = LABEL_TO_SCHEMA[item["wordtag_label"]]
        ids_by_type = self._term_dict[item["item"]]
        matched = False
        term_id = None
        target_idx = math.inf
        for mt, candidate_ids in ids_by_type.items():
            # Walk from the term's own type up towards the schema root.
            tmp_type = mt
            while tmp_type != "root":
                if tmp_type not in self._term_schema:
                    break
                for i, target in enumerate(target_type):
                    if not target.startswith(tmp_type):
                        continue
                    target_src = target.split("|")
                    for can_term_id in candidate_ids:
                        if len(target_src) > 1:
                            # Suffixed target: the id must carry the
                            # matching marker; unknown suffixes never match.
                            marker = id_markers.get(target_src[1])
                            if marker is None or marker not in can_term_id:
                                continue
                        matched = True
                        if i < target_idx:
                            target_idx = i
                            term_id = can_term_id
                tmp_type = self._term_schema[tmp_type]
            # NOTE(review): the reviewed text had lost its indentation, so
            # the nesting level of this early exit is assumed to be the
            # per-type loop (stop at the first term type that produced a
            # match) - confirm against the upstream file.
            if matched:
                break

        if matched:
            item["termid"] = term_id
0 commit comments