Skip to content

Commit 71389a5

Browse files
committed
fp tests
1 parent 6288778 commit 71389a5

20 files changed

+739
-621
lines changed

pyreason/scripts/interpretation/interpretation_fp.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ def _start_fp(self, rules, max_facts_time, verbose, again, restart):
235235
@numba.njit(cache=True, parallel=False)
236236
def reason(interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, tmax, prev_reasoning_data, rules, nodes, edges, neighbors, reverse_neighbors, rules_to_be_applied_node, rules_to_be_applied_edge, edges_to_be_added_node_rule, edges_to_be_added_edge_rule, rules_to_be_applied_node_trace, rules_to_be_applied_edge_trace, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, ipl, rule_trace_node, rule_trace_edge, rule_trace_node_atoms, rule_trace_edge_atoms, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, persistent, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules, max_facts_time, annotation_functions, convergence_mode, convergence_delta, verbose, again):
237237
t = prev_reasoning_data[0]
238+
max_t = t # Keeps track of the max time in each fp operation
238239
fp_cnt = prev_reasoning_data[1]
239240
max_rules_time = 0
240241
fp_loop = True
@@ -578,6 +579,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
578579

579580
# Increment t, update number of ground atoms
580581
t += 1
582+
max_t = max(max_t, t)
581583

582584
# Now apply the rules and go back through all timesteps to see if there are more
583585
# Apply the rules that need to be applied at this timestep
@@ -738,7 +740,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
738740
# t += 1
739741
fp_cnt += 1
740742

741-
return fp_cnt, t
743+
return fp_cnt, max_t
742744

743745
def add_edge(self, edge, l):
744746
# This function is useful for pyreason gym, called externally
@@ -810,7 +812,7 @@ def get_final_num_ground_atoms(self):
810812

811813
return ga_cnt
812814

813-
def query(self, query, return_bool=True) -> Union[bool, Tuple[float, float]]:
815+
def query(self, query, t=0, return_bool=True) -> Union[bool, Tuple[float, float]]:
814816
"""
815817
This function is used to query the graph after reasoning
816818
:param query: A PyReason query object
@@ -833,21 +835,21 @@ def query(self, query, return_bool=True) -> Union[bool, Tuple[float, float]]:
833835

834836
# Check if the predicate exists
835837
if comp_type == 'node':
836-
if pred not in self.interpretations_node[component].world:
838+
if pred not in self.interpretations_node[t][component].world:
837839
return False if return_bool else (0, 0)
838840
else:
839-
if pred not in self.interpretations_edge[component].world:
841+
if pred not in self.interpretations_edge[t][component].world:
840842
return False if return_bool else (0, 0)
841843

842844
# Check if the bounds are satisfied
843845
if comp_type == 'node':
844-
if self.interpretations_node[component].world[pred] in bnd:
845-
return True if return_bool else (self.interpretations_node[component].world[pred].lower, self.interpretations_node[component].world[pred].upper)
846+
if self.interpretations_node[t][component].world[pred] in bnd:
847+
return True if return_bool else (self.interpretations_node[t][component].world[pred].lower, self.interpretations_node[t][component].world[pred].upper)
846848
else:
847849
return False if return_bool else (0, 0)
848850
else:
849-
if self.interpretations_edge[component].world[pred] in bnd:
850-
return True if return_bool else (self.interpretations_edge[component].world[pred].lower, self.interpretations_edge[component].world[pred].upper)
851+
if self.interpretations_edge[t][component].world[pred] in bnd:
852+
return True if return_bool else (self.interpretations_edge[t][component].world[pred].lower, self.interpretations_edge[t][component].world[pred].upper)
851853
else:
852854
return False if return_bool else (0, 0)
853855

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
<?xml version='1.0' encoding='utf-8'?>
2+
<graphml xmlns="http://graphml.graphdrawing.org/xmlns" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd">
3+
<key id="owns" for="edge" attr.name="owns" attr.type="long" />
4+
<key id="Friends" for="edge" attr.name="Friends" attr.type="long" />
5+
<graph edgedefault="directed">
6+
<node id="John" />
7+
<node id="Mary" />
8+
<node id="Justin" />
9+
<node id="Dog" />
10+
<node id="Cat" />
11+
<edge source="John" target="Mary">
12+
<data key="Friends">1</data>
13+
</edge>
14+
<edge source="John" target="Justin">
15+
<data key="Friends">1</data>
16+
</edge>
17+
<edge source="John" target="Dog">
18+
<data key="owns">1</data>
19+
</edge>
20+
<edge source="Mary" target="Cat">
21+
<data key="owns">1</data>
22+
</edge>
23+
<edge source="Justin" target="Mary">
24+
<data key="Friends">1</data>
25+
</edge>
26+
<edge source="Justin" target="Cat">
27+
<data key="owns">1</data>
28+
</edge>
29+
<edge source="Justin" target="Dog">
30+
<data key="owns">1</data>
31+
</edge>
32+
</graph>
33+
</graphml>
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
<?xml version='1.0' encoding='utf-8'?>
2+
<graphml xmlns="http://graphml.graphdrawing.org/xmlns" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd">
3+
<key id="d0" for="edge" attr.name="HaveAccess" attr.type="long" />
4+
<graph edgedefault="directed">
5+
<node id="TextMessage" />
6+
<node id="Zach" />
7+
<node id="Justin" />
8+
<node id="Michelle" />
9+
<node id="Amy" />
10+
<edge source="Zach" target="TextMessage">
11+
<data key="d0">1</data>
12+
</edge>
13+
<edge source="Justin" target="TextMessage">
14+
<data key="d0">1</data>
15+
</edge>
16+
<edge source="Michelle" target="TextMessage">
17+
<data key="d0">1</data>
18+
</edge>
19+
<edge source="Amy" target="TextMessage">
20+
<data key="d0">1</data>
21+
</edge>
22+
</graph>
23+
</graphml>
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
<?xml version='1.0' encoding='utf-8'?>
2+
<graphml xmlns="http://graphml.graphdrawing.org/xmlns" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd">
3+
<key id="isConnectedTo" for="edge" attr.name="isConnectedTo" attr.type="long" />
4+
<key id="Amsterdam_Airport_Schiphol" for="node" attr.name="Amsterdam_Airport_Schiphol" attr.type="long" />
5+
<key id="Riga_International_Airport" for="node" attr.name="Riga_International_Airport" attr.type="long" />
6+
<key id="Chișinău_International_Airport" for="node" attr.name="Chișinău_International_Airport" attr.type="long" />
7+
<key id="Düsseldorf_Airport" for="node" attr.name="Düsseldorf_Airport" attr.type="long" />
8+
<key id="Dubrovnik_Airport" for="node" attr.name="Dubrovnik_Airport" attr.type="long" />
9+
<key id="Athens_International_Airport" for="node" attr.name="Athens_International_Airport" attr.type="long" />
10+
<key id="Yali" for="node" attr.name="Yali" attr.type="long" />
11+
<key id="Vnukovo_International_Airport" for="node" attr.name="Vnukovo_International_Airport" attr.type="long" />
12+
<key id="Hévíz-Balaton_Airport" for="node" attr.name="Hévíz-Balaton_Airport" attr.type="long" />
13+
<key id="Pobedilovo_Airport" for="node" attr.name="Pobedilovo_Airport" attr.type="long" />
14+
<graph id="G" edgedefault="directed">
15+
<node id="Amsterdam_Airport_Schiphol">
16+
<data key="Amsterdam_Airport_Schiphol">1</data>
17+
</node>
18+
<node id="Riga_International_Airport">
19+
<data key="Riga_International_Airport">1</data>
20+
</node>
21+
<node id="Chișinău_International_Airport">
22+
<data key="Chișinău_International_Airport">1</data>
23+
</node>
24+
<node id="Yali">
25+
<data key="Yali">1</data>
26+
</node>
27+
<node id="Düsseldorf_Airport">
28+
<data key="Düsseldorf_Airport">1</data>
29+
</node>
30+
<node id="Pobedilovo_Airport">
31+
<data key="Pobedilovo_Airport">1</data>
32+
</node>
33+
<node id="Dubrovnik_Airport">
34+
<data key="Dubrovnik_Airport">1</data>
35+
</node>
36+
<node id="Hévíz-Balaton_Airport">
37+
<data key="Hévíz-Balaton_Airport">1</data>
38+
</node>
39+
<node id="Athens_International_Airport">
40+
<data key="Athens_International_Airport">1</data>
41+
</node>
42+
<node id="Vnukovo_International_Airport">
43+
<data key="Vnukovo_International_Airport">1</data>
44+
</node>
45+
<edge source="Pobedilovo_Airport" target="Vnukovo_International_Airport">
46+
<data key="isConnectedTo">1</data>
47+
</edge>
48+
<edge source="Vnukovo_International_Airport" target="Hévíz-Balaton_Airport">
49+
<data key="isConnectedTo">1</data>
50+
</edge>
51+
<edge source="Düsseldorf_Airport" target="Dubrovnik_Airport">
52+
<data key="isConnectedTo">1</data>
53+
</edge>
54+
<edge source="Dubrovnik_Airport" target="Athens_International_Airport">
55+
<data key="isConnectedTo">1</data>
56+
</edge>
57+
<edge source="Riga_International_Airport" target="Amsterdam_Airport_Schiphol">
58+
<data key="isConnectedTo">1</data>
59+
</edge>
60+
<edge source="Riga_International_Airport" target="Düsseldorf_Airport">
61+
<data key="isConnectedTo">1</data>
62+
</edge>
63+
<edge source="Chișinău_International_Airport" target="Riga_International_Airport">
64+
<data key="isConnectedTo">1</data>
65+
</edge>
66+
<edge source="Amsterdam_Airport_Schiphol" target="Yali">
67+
<data key="isConnectedTo">1</data>
68+
</edge>
69+
70+
</graph>
71+
</graphml>
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Test if annotation functions work
2+
import pyreason as pr
3+
import numba
4+
import numpy as np
5+
6+
7+
@numba.njit
8+
def probability_func(annotations, weights):
9+
print("in ann func", annotations)
10+
prob_A = annotations[0][0].lower
11+
prob_B = annotations[1][0].lower
12+
union_prob = prob_A + prob_B
13+
union_prob = np.round(union_prob, 3)
14+
return union_prob, 1
15+
16+
17+
def test_annotation_function_fp():
18+
# Reset PyReason
19+
pr.reset()
20+
pr.reset_rules()
21+
pr.reset_settings()
22+
print("fp version", pr.settings.fp_version)
23+
24+
pr.settings.allow_ground_rules = True
25+
pr.settings.fp_version = True
26+
27+
pr.add_fact(pr.Fact('P(A) : [0.01, 1]'))
28+
pr.add_fact(pr.Fact('P(B) : [0.2, 1]'))
29+
pr.add_annotation_function(probability_func)
30+
pr.add_rule(pr.Rule('union_probability(A, B):probability_func <- P(A):[0, 1], P(B):[0, 1]', infer_edges=True))
31+
32+
interpretation = pr.reason(timesteps=1)
33+
34+
dataframes = pr.filter_and_sort_edges(interpretation, ['union_probability'])
35+
for t, df in enumerate(dataframes):
36+
print(f'TIMESTEP - {t}')
37+
print(df)
38+
print()
39+
40+
assert interpretation.query(pr.Query('union_probability(A, B) : [0.21, 1]')), 'Union probability should be 0.21'
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import pyreason as pr
2+
3+
4+
def test_anyBurl_rule_1_fp():
5+
graph_path = './tests/knowledge_graph_test_subset.graphml'
6+
pr.reset()
7+
pr.reset_rules()
8+
pr.reset_settings()
9+
# Modify pyreason settings to make verbose and to save the rule trace to a file
10+
pr.settings.verbose = True
11+
pr.settings.atom_trace = True
12+
pr.settings.memory_profile = False
13+
pr.settings.canonical = True
14+
pr.settings.inconsistency_check = False
15+
pr.settings.static_graph_facts = False
16+
pr.settings.output_to_file = False
17+
pr.settings.store_interpretation_changes = True
18+
pr.settings.save_graph_attributes_to_trace = True
19+
pr.settings.fp_version = True
20+
# Load all the files into pyreason
21+
pr.load_graphml(graph_path)
22+
pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_1', infer_edges=True))
23+
24+
# Run the program for two timesteps to see the diffusion take place
25+
interpretation = pr.reason(timesteps=1)
26+
# pr.save_rule_trace(interpretation)
27+
28+
# Display the changes in the interpretation for each timestep
29+
dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo'])
30+
for t, df in enumerate(dataframes):
31+
print(f'TIMESTEP - {t}')
32+
print(df)
33+
print()
34+
assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations'
35+
assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom'
36+
assert ('Vnukovo_International_Airport', 'Riga_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Riga_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps'
37+
38+
39+
def test_anyBurl_rule_2_fp():
40+
graph_path = './tests/knowledge_graph_test_subset.graphml'
41+
pr.reset()
42+
pr.reset_rules()
43+
pr.reset_settings()
44+
# Modify pyreason settings to make verbose and to save the rule trace to a file
45+
pr.settings.verbose = True
46+
pr.settings.atom_trace = True
47+
pr.settings.memory_profile = False
48+
pr.settings.canonical = True
49+
pr.settings.inconsistency_check = False
50+
pr.settings.static_graph_facts = False
51+
pr.settings.output_to_file = False
52+
pr.settings.store_interpretation_changes = True
53+
pr.settings.save_graph_attributes_to_trace = True
54+
pr.settings.parallel_computing = False
55+
# Load all the files into pyreason
56+
pr.load_graphml(graph_path)
57+
58+
pr.add_rule(pr.Rule('isConnectedTo(Y, A) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_2', infer_edges=True))
59+
60+
# Run the program for two timesteps to see the diffusion take place
61+
interpretation = pr.reason(timesteps=1)
62+
# pr.save_rule_trace(interpretation)
63+
64+
# Display the changes in the interpretation for each timestep
65+
dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo'])
66+
for t, df in enumerate(dataframes):
67+
print(f'TIMESTEP - {t}')
68+
print(df)
69+
print()
70+
assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations'
71+
assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom'
72+
assert ('Riga_International_Airport', 'Vnukovo_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Riga_International_Airport, Vnukovo_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps'
73+
74+
75+
def test_anyBurl_rule_3_fp():
76+
graph_path = './tests/knowledge_graph_test_subset.graphml'
77+
pr.reset()
78+
pr.reset_rules()
79+
pr.reset_settings()
80+
# Modify pyreason settings to make verbose and to save the rule trace to a file
81+
pr.settings.verbose = True
82+
pr.settings.atom_trace = True
83+
pr.settings.memory_profile = False
84+
pr.settings.canonical = True
85+
pr.settings.inconsistency_check = False
86+
pr.settings.static_graph_facts = False
87+
pr.settings.output_to_file = False
88+
pr.settings.store_interpretation_changes = True
89+
pr.settings.save_graph_attributes_to_trace = True
90+
pr.settings.parallel_computing = False
91+
# Load all the files into pyreason
92+
pr.load_graphml(graph_path)
93+
94+
pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(B, Y), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_3', infer_edges=True))
95+
96+
# Run the program for two timesteps to see the diffusion take place
97+
interpretation = pr.reason(timesteps=1)
98+
# pr.save_rule_trace(interpretation)
99+
100+
# Display the changes in the interpretation for each timestep
101+
dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo'])
102+
for t, df in enumerate(dataframes):
103+
print(f'TIMESTEP - {t}')
104+
print(df)
105+
print()
106+
assert len(dataframes) == 2, 'Pyreason should run exactly 1 fixpoint operations'
107+
assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom'
108+
assert ('Vnukovo_International_Airport', 'Yali') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Yali) should have isConnectedTo bounds [1,1] for t=1 timesteps'
109+
110+
111+
def test_anyBurl_rule_4_fp():
112+
graph_path = './tests/knowledge_graph_test_subset.graphml'
113+
pr.reset()
114+
pr.reset_rules()
115+
pr.reset_settings()
116+
# Modify pyreason settings to make verbose and to save the rule trace to a file
117+
pr.settings.verbose = True
118+
pr.settings.atom_trace = True
119+
pr.settings.memory_profile = False
120+
pr.settings.canonical = True
121+
pr.settings.inconsistency_check = False
122+
pr.settings.static_graph_facts = False
123+
pr.settings.output_to_file = False
124+
pr.settings.store_interpretation_changes = True
125+
pr.settings.save_graph_attributes_to_trace = True
126+
pr.settings.parallel_computing = False
127+
# Load all the files into pyreason
128+
pr.load_graphml(graph_path)
129+
130+
pr.add_rule(pr.Rule('isConnectedTo(Y, A) <-1 isConnectedTo(B, Y), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_4', infer_edges=True))
131+
132+
# Run the program for two timesteps to see the diffusion take place
133+
interpretation = pr.reason(timesteps=1)
134+
# pr.save_rule_trace(interpretation)
135+
136+
# Display the changes in the interpretation for each timestep
137+
dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo'])
138+
for t, df in enumerate(dataframes):
139+
print(f'TIMESTEP - {t}')
140+
print(df)
141+
print()
142+
assert len(dataframes) == 2, 'Pyreason should run exactly 1 fixpoint operations'
143+
assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom'
144+
assert ('Yali', 'Vnukovo_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Yali, Vnukovo_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps'

0 commit comments

Comments
 (0)