Skip to content

Commit cc2fb0d

Browse files
authored
Merge pull request #2 from bigdata-ustc/fix-link-vars
[BUGFIX] invalid link formulas
2 parents 15e457e + b84ffab commit cc2fb0d

File tree

11 files changed

+281
-79
lines changed

11 files changed

+281
-79
lines changed

EduNLP/Formula/Formula.py

Lines changed: 159 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,50 +2,84 @@
22
# 2021/3/8 @ tongshiwei
33

44
from pprint import pformat
5-
from typing import List
5+
from typing import List, Dict
66
import networkx as nx
77
from copy import deepcopy
88

9-
from .ast import str2ast, get_edges, ast, link_variable
9+
from .ast import str2ast, get_edges, link_variable
1010

1111
CONST_MATHORD = {r"\pi"}
1212

13-
__all__ = ["Formula", "FormulaGroup", "CONST_MATHORD"]
13+
__all__ = ["Formula", "FormulaGroup", "CONST_MATHORD", "link_formulas"]
1414

1515

1616
class Formula(object):
17-
def __init__(self, formula, is_str=True, variable_standardization=False, const_mathord=None):
18-
self._formula = formula
19-
self._ast = str2ast(formula) if is_str else formula
20-
if variable_standardization:
21-
const_mathord = CONST_MATHORD if const_mathord is None else const_mathord
22-
self.variable_standardization(inplace=True, const_mathord=const_mathord)
17+
"""
18+
Examples
19+
--------
20+
>>> f = Formula("x")
21+
>>> f
22+
<Formula: x>
23+
>>> f.ast
24+
[{'val': {'id': 0, 'type': 'mathord', 'text': 'x', 'role': None}, \
25+
'structure': {'bro': [None, None], 'child': None, 'father': None, 'forest': None}}]
26+
>>> f.elements
27+
[{'id': 0, 'type': 'mathord', 'text': 'x', 'role': None}]
28+
>>> f.variable_standardization(inplace=True)
29+
<Formula: x>
30+
>>> f.elements
31+
[{'id': 0, 'type': 'mathord', 'text': 'x', 'role': None, 'var': 0}]
32+
"""
33+
34+
def __init__(self, formula: (str, List[Dict]), variable_standardization=False, const_mathord=None,
35+
*args, **kwargs):
36+
"""
2337
24-
def variable_standardization(self, inplace=False, const_mathord=None):
38+
Parameters
39+
----------
40+
formula: str or List[Dict]
41+
latex formula string or the parsed abstracted syntax tree
42+
variable_standardization
43+
const_mathord
44+
args
45+
kwargs
46+
"""
47+
self._formula = formula
48+
self._ast = None
49+
self.reset_ast(
50+
formula_ensure_str=False,
51+
variable_standardization=variable_standardization,
52+
const_mathord=const_mathord, *args, **kwargs
53+
)
54+
55+
def variable_standardization(self, inplace=False, const_mathord=None, variable_connect_dict=None):
2556
const_mathord = const_mathord if const_mathord is not None else CONST_MATHORD
2657
ast_tree = self._ast if inplace else deepcopy(self._ast)
27-
variables = {}
28-
index = 0
58+
var_code = variable_connect_dict["var_code"] if variable_connect_dict is not None else {}
2959
for node in ast_tree:
3060
if node["val"]["type"] == "mathord":
3161
var = node["val"]["text"]
3262
if var in const_mathord:
3363
continue
3464
else:
35-
if var not in variables:
36-
variables[var], index = index, index + 1
37-
node["val"]["var"] = variables[var]
65+
if var not in var_code:
66+
var_code[var] = len(var_code)
67+
node["val"]["var"] = var_code[var]
3868
if inplace:
3969
return self
4070
else:
4171
return Formula(ast_tree, is_str=False)
4272

4373
@property
44-
def element(self):
74+
def ast(self):
4575
return self._ast
4676

4777
@property
48-
def ast(self) -> (nx.Graph, nx.DiGraph):
78+
def elements(self):
79+
return [self.ast_graph.nodes[node] for node in self.ast_graph.nodes]
80+
81+
@property
82+
def ast_graph(self) -> (nx.Graph, nx.DiGraph):
4983
edges = [(edge[0], edge[1]) for edge in get_edges(self._ast) if edge[2] == 3]
5084
tree = nx.DiGraph()
5185
for node in self._ast:
@@ -60,38 +94,85 @@ def to_str(self):
6094
return pformat(self._ast)
6195

6296
def __repr__(self):
63-
return "<Formula: %s>" % self._formula
97+
if isinstance(self._formula, str):
98+
return "<Formula: %s>" % self._formula
99+
else:
100+
return super(Formula, self).__repr__()
64101

102+
def reset_ast(self, formula_ensure_str: bool = True, variable_standardization=False, const_mathord=None, *args,
103+
**kwargs):
104+
if formula_ensure_str is True and self.resetable is False:
105+
raise TypeError("formula must be str, now is %s" % type(self._formula))
106+
self._ast = str2ast(self._formula, *args, **kwargs) if isinstance(self._formula, str) else self._formula
107+
if variable_standardization:
108+
const_mathord = CONST_MATHORD if const_mathord is None else const_mathord
109+
self.variable_standardization(inplace=True, const_mathord=const_mathord)
110+
return self._ast
65111

66-
class FormulaGroup(object):
67-
def __init__(self, formula_list: List[str], variable_standardization=False, const_mathord=None):
68-
"""
112+
@property
113+
def resetable(self):
114+
return isinstance(self._formula, str)
69115

70-
Parameters
71-
----------
72-
formula_list: List[str]
73-
"""
74-
forest_begin = 0
116+
117+
class FormulaGroup(object):
118+
"""
119+
Examples
120+
---------
121+
>>> fg = FormulaGroup(["x + y", "y + x", "z + x"])
122+
>>> fg
123+
<FormulaGroup: <Formula: x + y>;<Formula: y + x>;<Formula: z + x>>
124+
>>> fg = FormulaGroup(["x + y", Formula("y + x"), "z + x"])
125+
>>> fg
126+
<FormulaGroup: <Formula: x + y>;<Formula: y + x>;<Formula: z + x>>
127+
>>> fg = FormulaGroup(["x", Formula("y"), "x"])
128+
>>> fg.elements
129+
[{'id': 0, 'type': 'mathord', 'text': 'x', 'role': None}, {'id': 1, 'type': 'mathord', 'text': 'y', 'role': None},\
130+
{'id': 2, 'type': 'mathord', 'text': 'x', 'role': None}]
131+
>>> fg = FormulaGroup(["x", Formula("y"), "x"], variable_standardization=True)
132+
>>> fg.elements
133+
[{'id': 0, 'type': 'mathord', 'text': 'x', 'role': None, 'var': 0}, \
134+
{'id': 1, 'type': 'mathord', 'text': 'y', 'role': None, 'var': 1}, \
135+
{'id': 2, 'type': 'mathord', 'text': 'x', 'role': None, 'var': 0}]
136+
"""
137+
138+
def __init__(self,
139+
formula_list: (list, List[str], List[Formula]),
140+
variable_standardization=False,
141+
const_mathord=None,
142+
detach=True
143+
):
75144
forest = []
76-
formula_sep_index = []
77-
for index in range(0, len(formula_list)):
78-
formula_sep_index.append(forest_begin)
79-
tree = ast(
80-
formula_list[index],
81-
forest_begin=forest_begin,
82-
is_str=True
83-
)
84-
forest_begin += len(tree)
145+
self._formulas = []
146+
for formula in formula_list:
147+
if isinstance(formula, str):
148+
formula = Formula(
149+
formula,
150+
forest_begin=len(forest),
151+
)
152+
self._formulas.append(formula)
153+
tree = formula.ast
154+
elif isinstance(formula, Formula):
155+
if detach:
156+
formula = deepcopy(formula)
157+
tree = formula.reset_ast(
158+
formula_ensure_str=True,
159+
variable_standardization=False,
160+
forest_begin=len(forest),
161+
)
162+
self._formulas.append(formula)
163+
else:
164+
raise TypeError(
165+
"the element in formula_list should be either str or Formula, now is %s" % type(Formula)
166+
)
85167
forest += tree
86-
else:
87-
formula_sep_index.append(len(forest))
88-
forest = link_variable(forest)
168+
variable_connect_dict = link_variable(forest)
89169
self._forest = forest
90-
self._formulas = []
91-
for i, sep in enumerate(formula_sep_index[:-1]):
92-
self._formulas.append(Formula(forest[sep: formula_sep_index[i + 1]], is_str=False))
93170
if variable_standardization:
94-
self.variable_standardization(inplace=True, const_mathord=const_mathord)
171+
self.variable_standardization(
172+
inplace=True,
173+
const_mathord=const_mathord,
174+
variable_connect_dict=variable_connect_dict
175+
)
95176

96177
def __iter__(self):
97178
return iter(self._formulas)
@@ -102,14 +183,47 @@ def __getitem__(self, item) -> Formula:
102183
def __contains__(self, item) -> bool:
103184
return item in self._formulas
104185

105-
def variable_standardization(self, inplace=False, const_mathord=None):
186+
def variable_standardization(self, inplace=False, const_mathord=None, variable_connect_dict=None):
106187
ret = []
107188
for formula in self._formulas:
108-
ret.append(formula.variable_standardization(inplace=inplace, const_mathord=const_mathord))
189+
ret.append(formula.variable_standardization(inplace=inplace, const_mathord=const_mathord,
190+
variable_connect_dict=variable_connect_dict))
109191
return ret
110192

111193
def to_str(self):
112-
return pformat(self._formulas)
194+
return pformat(self._forest)
113195

114196
def __repr__(self):
115197
return "<FormulaGroup: %s>" % ";".join([repr(_formula) for _formula in self._formulas])
198+
199+
@property
200+
def ast(self):
201+
return self._forest
202+
203+
@property
204+
def elements(self):
205+
return [self.ast_graph.nodes[node] for node in self.ast_graph.nodes]
206+
207+
@property
208+
def ast_graph(self) -> (nx.Graph, nx.DiGraph):
209+
edges = [(edge[0], edge[1]) for edge in get_edges(self._forest) if edge[2] == 3]
210+
tree = nx.DiGraph()
211+
for node in self._forest:
212+
tree.add_node(
213+
node["val"]["id"],
214+
**node["val"]
215+
)
216+
tree.add_edges_from(edges)
217+
return tree
218+
219+
220+
def link_formulas(*formula: Formula, **kwargs):
221+
forest = []
222+
for form in formula:
223+
forest += form.reset_ast(
224+
forest_begin=len(forest),
225+
**kwargs
226+
)
227+
variable_connect_dict = link_variable(forest)
228+
for form in formula:
229+
form.variable_standardization(inplace=True, variable_connect_dict=variable_connect_dict, **kwargs)

EduNLP/Formula/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .Formula import Formula, FormulaGroup
1+
from .Formula import Formula, FormulaGroup, link_formulas
22
from .ast import link_variable
33
from .Formula import CONST_MATHORD

EduNLP/Formula/ast/ast.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
__all__ = ["str2ast", "get_edges", "ast", "link_variable"]
88

99

10-
def str2ast(formula: str):
11-
return ast(formula, is_str=True)
10+
def str2ast(formula: str, *args, **kwargs):
11+
return ast(formula, is_str=True, *args, **kwargs)
1212

1313

1414
def ast(formula: (str, List[Dict]), index=0, forest_begin=0, father_tree=None, is_str=False):
@@ -18,7 +18,7 @@ def ast(formula: (str, List[Dict]), index=0, forest_begin=0, father_tree=None, i
1818
Parameters
1919
----------
2020
formula: str or List[Dict]
21-
21+
公式字符串或通过katex解析得到的结构体
2222
index: int
2323
本子树在树上的位置
2424
forest_begin: int
@@ -224,8 +224,11 @@ def link_variable(forest):
224224
l_v = [] + v
225225
index = l_v.pop(i)
226226
forest[index]['structure']['forest'] = l_v
227-
228-
return forest
227+
variable_connect_dict = {
228+
"var2id": forest_connect_dict,
229+
"var_code": {}
230+
}
231+
return variable_connect_dict
229232

230233

231234
def get_edges(forest):

EduNLP/SIF/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
# 2021/5/16 @ tongshiwei
33

44
from .sif import is_sif, to_sif, sif4sci
5+
from .tokenization import link_formulas

EduNLP/SIF/sif.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# 2021/5/16 @ tongshiwei
33

44
from .segment import seg
5-
from .tokenization import tokenize
5+
from .tokenization import tokenize, link_formulas
66

77

88
def is_sif(item):
@@ -60,6 +60,24 @@ def sif4sci(item: str, figures: (dict, bool) = None, safe=True, symbol: str = No
6060
>>> sif4sci(test_item, symbol="gm",
6161
... tokenization_params={"formula_params": {"method": "ast", "return_type": "list"}})
6262
['如图所示', '\\\\bigtriangleup', 'A', 'B', 'C', '面积', '[MARK]', '[FIGURE]']
63+
>>> test_item_1 = {
64+
... "stem": r"若$x=2$, $y=\\sqrt{x}$,则下列说法正确的是$\\SIFChoice$",
65+
... "options": [r"$x < y$", r"$y = x$", r"$y < x$"]
66+
... }
67+
>>> tls = [
68+
... sif4sci(e, symbol="gm",
69+
... tokenization_params={
70+
... "formula_params": {
71+
... "method": "ast", "return_type": "list", "ord2token": True, "var_numbering": True,
72+
... "link_variable": False}
73+
... })
74+
... for e in ([test_item_1["stem"]] + test_item_1["options"])
75+
... ]
76+
>>> tls[1:]
77+
[['mathord_0', '<', 'mathord_1'], ['mathord_0', '=', 'mathord_1'], ['mathord_0', '<', 'mathord_1']]
78+
>>> link_formulas(*tls)
79+
>>> tls[1:]
80+
[['mathord_0', '<', 'mathord_1'], ['mathord_1', '=', 'mathord_0'], ['mathord_1', '<', 'mathord_0']]
6381
"""
6482
if safe is True and is_sif(item) is not True:
6583
item = to_sif(item)

EduNLP/SIF/tokenization/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# coding: utf-8
22
# 2021/5/18 @ tongshiwei
33

4-
from .tokenization import tokenize
4+
from .tokenization import tokenize, link_formulas

EduNLP/SIF/tokenization/formula/ast_token.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,12 @@ def ast_tokenize(formula, ord2token=False, var_numbering=False, return_type="for
7979
<Formula: {x + y}^\\frac{\\pi}{2} + 1 = x>
8080
"""
8181
if return_type == "list":
82-
ast = Formula(formula, variable_standardization=True).ast
82+
ast = Formula(formula, variable_standardization=True).ast_graph
8383
return traversal_formula(ast, ord2token=ord2token, var_numbering=var_numbering)
8484
elif return_type == "formula":
8585
return Formula(formula)
8686
elif return_type == "ast":
87-
return Formula(formula).ast
87+
return Formula(formula).ast_graph
8888
else:
8989
raise ValueError()
9090

0 commit comments

Comments
 (0)