Skip to content

Commit 04eefb7

Browse files
authored
Merge pull request #70 from lab-v2/fp-fix
Fp fix
2 parents e9f3700 + e5245eb commit 04eefb7

22 files changed

+2728
-22
lines changed

pyreason/pyreason.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(self):
5656
self.__parallel_computing = None
5757
self.__update_mode = None
5858
self.__allow_ground_rules = None
59+
self.__fp_version = None
5960
self.reset()
6061

6162
def reset(self):
@@ -76,6 +77,7 @@ def reset(self):
7677
self.__parallel_computing = False
7778
self.__update_mode = 'intersection'
7879
self.__allow_ground_rules = False
80+
self.__fp_version = False
7981

8082
@property
8183
def verbose(self) -> bool:
@@ -219,6 +221,14 @@ def allow_ground_rules(self) -> bool:
219221
"""
220222
return self.__allow_ground_rules
221223

224+
@property
225+
def fp_version(self) -> bool:
226+
"""Returns whether we are using the fixed point version or the optimized version. Default is false
227+
228+
:return: bool
229+
"""
230+
return self.__fp_version
231+
222232
@verbose.setter
223233
def verbose(self, value: bool) -> None:
224234
"""Set verbose mode. Default is True
@@ -430,6 +440,18 @@ def allow_ground_rules(self, value: bool) -> None:
430440
else:
431441
self.__allow_ground_rules = value
432442

443+
@fp_version.setter
444+
def fp_version(self, value: bool) -> None:
445+
"""Set the fixed point or optimized version. Default is False
446+
447+
:param value: Whether to use the fixed point version or the optimized version
448+
:raises TypeError: If not bool raise error
449+
"""
450+
if not isinstance(value, bool):
451+
raise TypeError('value has to be a bool')
452+
else:
453+
self.__fp_version = value
454+
433455

434456
# VARIABLES
435457
__graph: Optional[nx.DiGraph] = None
@@ -506,7 +528,7 @@ def load_graphml(path: str) -> None:
506528
507529
:param path: Path for the GraphMl file
508530
"""
509-
global __graph, __graphml_parser, __non_fluent_graph_facts_node, __non_fluent_graph_facts_edge, __specific_graph_node_labels, __specific_graph_edge_labels, settings
531+
global __graph, __non_fluent_graph_facts_node, __non_fluent_graph_facts_edge, __specific_graph_node_labels, __specific_graph_edge_labels
510532

511533
# Parse graph
512534
__graph = __graphml_parser.parse_graph(path, settings.reverse_digraph)
@@ -528,7 +550,7 @@ def load_graph(graph: nx.DiGraph) -> None:
528550
:type graph: nx.DiGraph
529551
:return: None
530552
"""
531-
global __graph, __graphml_parser, __non_fluent_graph_facts_node, __non_fluent_graph_facts_edge, __specific_graph_node_labels, __specific_graph_edge_labels, settings
553+
global __graph, __non_fluent_graph_facts_node, __non_fluent_graph_facts_edge, __specific_graph_node_labels, __specific_graph_edge_labels
532554

533555
# Load graph
534556
__graph = __graphml_parser.load_graph(graph)
@@ -629,7 +651,6 @@ def add_annotation_function(function: Callable) -> None:
629651
:type function: Callable
630652
:return: None
631653
"""
632-
global __annotation_functions
633654
# Make sure that the functions are jitted so that they can be passed around in other jitted functions
634655
# TODO: Remove if necessary
635656
# assert hasattr(function, 'nopython_signatures'), 'The function to be added has to be under a `numba.njit` decorator'
@@ -648,7 +669,7 @@ def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bou
648669
:param restart: Whether to restart the program time from 0 when reasoning again, defaults to True
649670
:return: The final interpretation after reasoning.
650671
"""
651-
global settings, __timestamp
672+
global __timestamp
652673

653674
# Timestamp for saving files
654675
__timestamp = time.strftime('%Y%m%d-%H%M%S')
@@ -676,8 +697,8 @@ def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bou
676697

677698
def _reason(timesteps, convergence_threshold, convergence_bound_threshold, queries):
678699
# Globals
679-
global __graph, __rules, __clause_maps, __node_facts, __edge_facts, __ipl, __specific_node_labels, __specific_edge_labels, __graphml_parser
680-
global settings, __timestamp, __program
700+
global __rules, __clause_maps, __node_facts, __edge_facts, __ipl, __specific_node_labels, __specific_edge_labels
701+
global __program
681702

682703
# Assert variables are of correct type
683704

@@ -748,7 +769,7 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold, queri
748769
__rules.append(r)
749770

750771
# Setup logical program
751-
__program = Program(__graph, all_node_facts, all_edge_facts, __rules, __ipl, annotation_functions, settings.reverse_digraph, settings.atom_trace, settings.save_graph_attributes_to_trace, settings.persistent, settings.inconsistency_check, settings.store_interpretation_changes, settings.parallel_computing, settings.update_mode, settings.allow_ground_rules)
772+
__program = Program(__graph, all_node_facts, all_edge_facts, __rules, __ipl, annotation_functions, settings.reverse_digraph, settings.atom_trace, settings.save_graph_attributes_to_trace, settings.persistent, settings.inconsistency_check, settings.store_interpretation_changes, settings.parallel_computing, settings.update_mode, settings.allow_ground_rules, settings.fp_version)
752773
__program.specific_node_labels = __specific_node_labels
753774
__program.specific_edge_labels = __specific_edge_labels
754775

@@ -764,9 +785,6 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold, queri
764785

765786
def _reason_again(timesteps, restart, convergence_threshold, convergence_bound_threshold):
766787
# Globals
767-
global __graph, __rules, __node_facts, __edge_facts, __ipl, __specific_node_labels, __specific_edge_labels, __graphml_parser
768-
global settings, __timestamp, __program
769-
770788
assert __program is not None, 'To run `reason_again` you need to have reasoned once before'
771789

772790
# Extend facts
@@ -788,8 +806,6 @@ def save_rule_trace(interpretation, folder: str='./'):
788806
:param interpretation: the output of `pyreason.reason()`, the final interpretation
789807
:param folder: the folder in which to save the result, defaults to './'
790808
"""
791-
global __timestamp, __clause_maps, settings
792-
793809
assert settings.store_interpretation_changes, 'store interpretation changes setting is off, turn on to save rule trace'
794810

795811
output = Output(__timestamp, __clause_maps)
@@ -804,8 +820,6 @@ def get_rule_trace(interpretation) -> Tuple[pd.DataFrame, pd.DataFrame]:
804820
:param interpretation: the output of `pyreason.reason()`, the final interpretation
805821
:returns two pandas dataframes (nodes, edges) representing the changes that occurred during reasoning
806822
"""
807-
global __timestamp, __clause_maps, settings
808-
809823
assert settings.store_interpretation_changes, 'store interpretation changes setting is off, turn on to save rule trace'
810824

811825
output = Output(__timestamp, __clause_maps)

0 commit comments

Comments
 (0)