Skip to content

Commit fd65164

Browse files
authored
add wordtag linking (PaddlePaddle#300)
* add wordtag linking * optimize wordtag open * add termtree_type_csv.csv * optimize code
1 parent 1553b1a commit fd65164

File tree

5 files changed

+225
-66
lines changed

5 files changed

+225
-66
lines changed

examples/information_extraction/wordtag/data.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ def read(data_path):
2222
def load_dict(dict_path):
    """Build a vocabulary that maps each line of a file to its line index.

    Args:
        dict_path (`str`):
            Path to a UTF-8 text file holding one entry per line.

    Returns:
        `dict`: Maps every whitespace-stripped line to its 0-based line
        number. If an entry occurs more than once, the index of its last
        occurrence is kept.
    """
    # The context manager guarantees the handle is closed even when decoding
    # fails part-way through; enumerate replaces the manual counter.
    with open(dict_path, 'r', encoding='utf-8') as fin:
        return {line.strip(): i for i, line in enumerate(fin)}
2930

3031

examples/information_extraction/wordtag/download.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919

2020
from paddle.utils.download import get_path_from_url
2121

22-
URL = "https://paddlenlp.bj.bcebos.com/paddlenlp/datasets/wordtag_dataset.tar.gz"
22+
URLS = [
23+
"https://paddlenlp.bj.bcebos.com/paddlenlp/datasets/wordtag_dataset.tar.gz",
24+
"https://paddlenlp.bj.bcebos.com/paddlenlp/resource/termtree.rawbase",
25+
"https://paddlenlp.bj.bcebos.com/paddlenlp/resource/termtree_type.csv"
26+
]
2327

2428

2529
def main(arguments):
@@ -31,7 +35,8 @@ def main(arguments):
3135
type=str,
3236
default='./')
3337
args = parser.parse_args(arguments)
34-
get_path_from_url(URL, args.data_dir)
38+
for url in URLS:
39+
get_path_from_url(url, args.data_dir)
3540

3641

3742
if __name__ == '__main__':

examples/information_extraction/wordtag/predict.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import io
16-
import sys
15+
import os
1716
import argparse
1817

1918
import paddle
@@ -38,8 +37,16 @@ def parse_args():
3837

3938
def do_predict(args):
4039
paddle.set_device(args.device)
41-
predictor = WordtagPredictor(args.init_ckpt_dir, "./data/tags.txt")
42-
txts = ['《孤女》是2010年九州出版社出版的小说,作者是余兼羽。', '4分40秒至10分钟只有歌声。']
40+
predictor = WordtagPredictor(
41+
model_dir=args.init_ckpt_dir,
42+
tag_path=os.path.join(args.data_dir, "tags.txt"),
43+
term_schema_path="termtree_type.csv",
44+
term_data_path="termtree.rawbase")
45+
txts = [
46+
"美人鱼是周星驰导演的电影", "小米别熬粥了,加1个苹果,瞬间变小米蛋糕,太香了",
47+
"618不要只知道小米、苹果,这三款产品一样是超级爆款", "天鸿美和院地处黄公望国家森林公园山麓", "你好百度"
48+
]
49+
4350
res = predictor.run(txts)
4451
print(res)
4552

examples/information_extraction/wordtag/predictor.py

+202-56
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,115 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import glob
116
import json
17+
import math
18+
import os
219

320
import paddle
421
import paddle.nn as nn
22+
import pandas as pd
523
from paddlenlp.datasets import MapDataset
624
from paddlenlp.data import Stack, Pad, Tuple
725
from paddlenlp.transformers import ErnieCtmWordtagModel, ErnieCtmTokenizer
826

27+
LABEL_TO_SCHEMA = {
28+
"人物类_实体": ["人物|E", "虚拟角色|E", "演艺团体|E"],
29+
"人物类_概念": ["人物|C", "虚拟角色|C"],
30+
"作品类_实体": ["作品与出版物|E"],
31+
"作品类_概念": ["作品与出版物|C", "文化类"],
32+
"组织机构类": ["组织机构"],
33+
"组织机构类_企事业单位": ["企事业单位", "品牌", "组织机构"],
34+
"组织机构类_医疗卫生机构": ["医疗卫生机构", "组织机构"],
35+
"组织机构类_国家机关": ["国家机关", "组织机构"],
36+
"组织机构类_体育组织机构": ["体育组织机构", "组织机构"],
37+
"组织机构类_教育组织机构": ["教育组织机构", "组织机构"],
38+
"组织机构类_军事组织机构": ["军事组织机构", "组织机构"],
39+
"物体类": ["物体与物品", "品牌", "虚拟物品", "虚拟物品"],
40+
"物体类_兵器": ["兵器"],
41+
"物体类_化学物质": ["物体与物品", "化学术语"],
42+
"其他角色类": ["角色"],
43+
"文化类": ["文化", "作品与出版物|C", "体育运动项目", "语言文字"],
44+
"文化类_语言文字": ["语言学术语"],
45+
"文化类_奖项赛事活动": ["奖项赛事活动", "特殊日", "事件"],
46+
"文化类_制度政策协议": ["制度政策协议", "法律法规"],
47+
"文化类_姓氏与人名": ["姓氏与人名"],
48+
"生物类": ["生物"],
49+
"生物类_植物": ["植物", "生物"],
50+
"生物类_动物": ["动物", "生物"],
51+
"品牌名": ["品牌", "企事业单位"],
52+
"场所类": ["区域场所", "居民服务机构", "医疗卫生机构"],
53+
"场所类_交通场所": ["交通场所", "设施"],
54+
"位置方位": ["位置方位"],
55+
"世界地区类": ["世界地区", "区域场所", "政权朝代"],
56+
"饮食类": ["饮食", "生物类", "药物"],
57+
"饮食类_菜品": ["饮食"],
58+
"饮食类_饮品": ["饮食"],
59+
"药物类": ["药物", "生物类"],
60+
"药物类_中药": ["药物", "生物类"],
61+
"医学术语类": ["医药学术语"],
62+
"术语类_生物体": ["生物学术语"],
63+
"疾病损伤类": ["疾病损伤", "动物疾病", "医药学术语"],
64+
"疾病损伤类_植物病虫害": ["植物病虫害", "医药学术语"],
65+
"宇宙类": ["天文学术语"],
66+
"事件类": ["事件", "奖项赛事活动"],
67+
"时间类": ["时间阶段", "政权朝代"],
68+
"术语类": ["术语"],
69+
"术语类_符号指标类": ["编码符号指标", "术语"],
70+
"信息资料": ["生活用语"],
71+
"链接地址": ["生活用语"],
72+
"个性特征": ["个性特点", "生活用语"],
73+
"感官特征": ["生活用语"],
74+
"场景事件": ["场景事件", "情绪", "态度", "个性特点"],
75+
"介词": ["介词"],
76+
"介词_方位介词": ["介词"],
77+
"助词": ["助词"],
78+
"代词": ["代词"],
79+
"连词": ["连词"],
80+
"副词": ["副词"],
81+
"疑问词": ["疑问词"],
82+
"肯定词": ["肯定否定词"],
83+
"否定词": ["肯定否定词"],
84+
"数量词": ["数量词", "量词"],
85+
"叹词": ["叹词"],
86+
"拟声词": ["拟声词"],
87+
"修饰词": ["修饰词", "生活用语"],
88+
"外语单词": ["日文假名", "词汇用语"],
89+
"汉语拼音": ["汉语拼音"],
90+
}
91+
992

1093
class WordtagPredictor(object):
1194
"""Predictor of wordtag model.
1295
"""
1396

14-
def __init__(self, model_dir, tag_path, linking_path=None):
97+
def __init__(self,
98+
model_dir,
99+
tag_path,
100+
term_schema_path=None,
101+
term_data_path=None):
15102
"""Initialize method of the predictor.
16103
17104
Args:
18-
model_dir: The pre-trained model checkpoint dir.
19-
tag_path: The tag vocab path.
20-
linking_path:if you want to use linking mode, you should load link feature using.
105+
model_dir (`str`):
106+
The pre-trained model checkpoint dir.
107+
tag_path (`str`):
108+
The tag vocab path.
109+
term_schema_path (`str`, optional):
110+
Path to the term schema file. Term linking is enabled only when both this and `term_data_path` are given. Defaults to ``None``.
111+
term_data_path (`str`, optional):
112+
Path to the term data file or directory. Term linking is enabled only when both this and `term_schema_path` are given. Defaults to ``None``.
21113
"""
22114
self._tags_to_index, self._index_to_tags = self._load_labels(tag_path)
23115

@@ -30,28 +122,27 @@ def __init__(self, model_dir, tag_path, linking_path=None):
30122

31123
self._tokenizer = ErnieCtmTokenizer.from_pretrained(model_dir)
32124
self._summary_num = self._model.ernie_ctm.content_summary_index + 1
33-
self.linking = False
34-
if linking_path is not None:
35-
self.linking_dict = {}
36-
with open(linking_path, encoding="utf-8") as fp:
37-
for line in fp:
38-
data = json.loads(line)
39-
if data["label"] not in self.linking_dict:
40-
self.linking_dict[data["label"]] = []
41-
self.linking_dict[data["label"]].append({
42-
"sid": data["sid"],
43-
"cls": paddle.to_tensor(data["cls1"]).unsqueeze(0),
44-
"term": paddle.to_tensor(data["term"]).unsqueeze(0)
45-
})
46-
self.linking = True
47-
self.sim_fct = nn.CosineSimilarity(dim=1)
125+
if term_schema_path is not None:
126+
self._term_schema = self._load_schema(term_schema_path)
127+
if term_data_path is not None:
128+
self._term_dict = self._load_term_tree_data(term_data_path)
129+
if term_data_path is not None and term_schema_path is not None:
130+
self._linking = True
131+
else:
132+
self._linking = False
48133

49134
@property
50135
def summary_num(self):
51136
"""Number of model summary token
52137
"""
53138
return self._summary_num
54139

140+
@property
141+
def linking(self):
142+
"""Whether to do term linking.
143+
"""
144+
return self._linking
145+
55146
@staticmethod
56147
def _load_labels(tag_path):
57148
tags_to_idx = {}
@@ -64,9 +155,52 @@ def _load_labels(tag_path):
64155
idx_to_tags = dict(zip(*(tags_to_idx.values(), tags_to_idx.keys())))
65156
return tags_to_idx, idx_to_tags
66157

158+
@staticmethod
def _load_schema(schema_path):
    """Load the term type hierarchy as a child -> parent mapping.

    The file is read as a tab-separated, GB2312-encoded table with the
    columns ``type-1``, ``type-2`` and ``type-3``. Each row describes one
    path through the hierarchy; empty cells are ignored.

    Args:
        schema_path (`str`):
            Path to the schema file.

    Returns:
        `dict`: Maps every type name to its parent type name. Values in
        ``type-1`` are mapped to the literal ``"root"``.
    """
    schema_df = pd.read_csv(schema_path, sep="\t", encoding="gb2312")
    schema = {}
    levels = (schema_df["type-1"], schema_df["type-2"], schema_df["type-3"])
    # Walk the three columns in lockstep instead of doing a chained
    # per-cell lookup for every row.
    for type_1, type_2, type_3 in zip(*levels):
        # pandas represents empty cells as NaN; pd.isna states that intent
        # directly rather than relying on NaN happening to be a float.
        if not pd.isna(type_1):
            schema[type_1] = "root"
        if not pd.isna(type_2):
            schema[type_2] = type_1
        if not pd.isna(type_3):
            schema[type_3] = type_2
    return schema
170+
171+
@staticmethod
def _load_term_tree_data(term_tree_name_or_path):
    """Load term records into a surface-form lookup table.

    Every non-blank line of each input file is parsed as one JSON object
    with the keys ``term``, ``termtype`` and ``termid`` and, optionally,
    the lists ``alias`` and ``alias_ext``. The term id is registered under
    the term itself and under each of its aliases.

    Args:
        term_tree_name_or_path (`str`):
            Path to a single data file, or to a directory whose regular
            files (direct children only) are all loaded.

    Returns:
        `dict`: ``{surface_form: {termtype: [termid, ...]}}``.
    """
    if os.path.isdir(term_tree_name_or_path):
        pattern = os.path.join(glob.escape(term_tree_name_or_path), "*")
        # "*" also matches sub-directories, which cannot be opened as
        # files; sorting makes the order of term ids reproducible.
        fn_list = sorted(fn for fn in glob.glob(pattern)
                         if os.path.isfile(fn))
    else:
        fn_list = [term_tree_name_or_path]
    term_dict = {}
    for fn in fn_list:
        with open(fn, encoding="utf-8") as fp:
            for line in fp:
                # Tolerate blank or trailing empty lines in the data file.
                if not line.strip():
                    continue
                data = json.loads(line)
                term_type = data["termtype"]
                term_id = data["termid"]
                # The canonical term and all of its aliases are indexed
                # the same way, so handle them in a single pass.
                surface_forms = [data["term"]]
                surface_forms.extend(data.get("alias") or [])
                surface_forms.extend(data.get("alias_ext") or [])
                for name in surface_forms:
                    ids_by_type = term_dict.setdefault(name, {})
                    ids_by_type.setdefault(term_type, []).append(term_id)
    return term_dict
201+
67202
def _pre_process_text(self, input_texts, max_seq_len=128, batch_size=1):
68203
infer_data = []
69-
max_length = 0
70204
for text in input_texts:
71205
tokens = ["[CLS%i]" % i
72206
for i in range(1, self.summary_num)] + list(text)
@@ -170,45 +304,57 @@ def run(self,
170304
all_pred_tags += pred_tags.numpy().tolist()
171305

172306
results = self._decode(input_texts, all_pred_tags)
307+
if self.linking is True:
308+
for res in results:
309+
self._term_linking(res)
173310
outputs = results
174311
if return_hidden_states is True:
175312
outputs = (results, ) + (seq_logits, cls_logits)
176313
return outputs
177314

178-
def _post_linking(self, pred_res, hidden_states):
179-
for pred in pred_res:
180-
for item in pred["items"]:
181-
if item["item"] in self.linking_dict:
182-
item_vectors = self.linking_dict[item["item"]]
183-
item_pred_vector = hidden_states[1]
184-
185-
res = []
186-
for item_vector in item_vectors:
187-
vec = item_vector["cls"]
188-
similarity = self.sim_fct(vec, item_pred_vector)
189-
res.append({
190-
"sid": item_vector["sid"],
191-
"cosine": similarity.item()
192-
})
193-
res.sort(key=lambda d: -d["cosine"])
194-
item["link"] = res
195-
196-
def run_with_link(self, input_text):
197-
"""Predict wordtag results with term linking.
315+
def _term_linking(self, wordtag_res):
316+
for item in wordtag_res["items"]:
317+
if item["wordtag_label"] not in LABEL_TO_SCHEMA:
318+
continue
319+
if item["item"] not in self._term_dict:
320+
continue
321+
target_type = LABEL_TO_SCHEMA[item["wordtag_label"]]
322+
matched_type = list(self._term_dict[item["item"]].keys())
323+
matched = False
324+
term_id = None
325+
target_idx = math.inf
326+
for mt in matched_type:
327+
tmp_type = mt
328+
while tmp_type != "root":
329+
if tmp_type not in self._term_schema:
330+
break
331+
for i, target in enumerate(target_type):
332+
if target.startswith(tmp_type):
333+
target_src = target.split("|")
334+
for can_term_id in self._term_dict[item["item"]][
335+
mt]:
336+
tmp_term_id = can_term_id
337+
if len(target_src) == 1:
338+
matched = True
339+
if i < target_idx:
340+
target_idx = i
341+
term_id = tmp_term_id
342+
else:
343+
if target_src[
344+
1] == "C" and "_cb_" in tmp_term_id:
345+
matched = True
346+
if i < target_idx:
347+
target_idx = i
348+
term_id = tmp_term_id
349+
if target_src[
350+
1] == "E" and "_eb_" in tmp_term_id:
351+
matched = True
352+
if i < target_idx:
353+
target_idx = i
354+
term_id = tmp_term_id
355+
tmp_type = self._term_schema[tmp_type]
356+
if matched is True:
357+
break
198358

199-
Args:
200-
input_text: input text
201-
202-
Raises:
203-
ValueError: raise ValueError if is not linking mode.
204-
205-
Returns:
206-
pred_res: result with linking.
207-
"""
208-
if self.linking is False:
209-
raise ValueError(
210-
"Not linking mode, you should initialize object by ``WordtagPredictor(model_dir, linking_path)``."
211-
)
212-
pred_res = self.run(input_text, return_hidden_states=True)
213-
self._post_linking(pred_res[0], pred_res[1:])
214-
return pred_res[0]
359+
if matched is True:
360+
item["termid"] = term_id

paddlenlp/transformers/ernie_ctm/modeling.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class ErnieCtmPretrainedModel(PretrainedModel):
110110
pretrained_resource_files_map = {
111111
"model_state": {
112112
"ernie-ctm":
113-
"https://bj.bcebos.com/paddlenlp/models/transformers/ernie_ctm_base.pdparams"
113+
"https://paddlenlp.bj.bcebos.com/paddlenlp/models/transformers/ernie-ctm-base.pdparams"
114114
}
115115
}
116116
base_model_prefix = "ernie_ctm"

0 commit comments

Comments
 (0)