@@ -56,6 +56,7 @@ def __init__(self):
56
56
self .__parallel_computing = None
57
57
self .__update_mode = None
58
58
self .__allow_ground_rules = None
59
+ self .__fp_version = None
59
60
self .reset ()
60
61
61
62
def reset (self ):
@@ -76,6 +77,7 @@ def reset(self):
76
77
self .__parallel_computing = False
77
78
self .__update_mode = 'intersection'
78
79
self .__allow_ground_rules = False
80
+ self .__fp_version = False
79
81
80
82
@property
81
83
def verbose (self ) -> bool :
@@ -219,6 +221,14 @@ def allow_ground_rules(self) -> bool:
219
221
"""
220
222
return self .__allow_ground_rules
221
223
224
+ @property
225
+ def fp_version (self ) -> bool :
226
+ """Returns whether we are using the fixed point version or the optimized version. Default is false
227
+
228
+ :return: bool
229
+ """
230
+ return self .__fp_version
231
+
222
232
@verbose .setter
223
233
def verbose (self , value : bool ) -> None :
224
234
"""Set verbose mode. Default is True
@@ -430,6 +440,18 @@ def allow_ground_rules(self, value: bool) -> None:
430
440
else :
431
441
self .__allow_ground_rules = value
432
442
443
+ @fp_version .setter
444
+ def fp_version (self , value : bool ) -> None :
445
+ """Set the fixed point or optimized version. Default is False
446
+
447
+ :param value: Whether to use the fixed point version or the optimized version
448
+ :raises TypeError: If not bool raise error
449
+ """
450
+ if not isinstance (value , bool ):
451
+ raise TypeError ('value has to be a bool' )
452
+ else :
453
+ self .__fp_version = value
454
+
433
455
434
456
# VARIABLES
435
457
__graph : Optional [nx .DiGraph ] = None
@@ -506,7 +528,7 @@ def load_graphml(path: str) -> None:
506
528
507
529
:param path: Path for the GraphMl file
508
530
"""
509
- global __graph , __graphml_parser , __non_fluent_graph_facts_node , __non_fluent_graph_facts_edge , __specific_graph_node_labels , __specific_graph_edge_labels , settings
531
+ global __graph , __non_fluent_graph_facts_node , __non_fluent_graph_facts_edge , __specific_graph_node_labels , __specific_graph_edge_labels
510
532
511
533
# Parse graph
512
534
__graph = __graphml_parser .parse_graph (path , settings .reverse_digraph )
@@ -528,7 +550,7 @@ def load_graph(graph: nx.DiGraph) -> None:
528
550
:type graph: nx.DiGraph
529
551
:return: None
530
552
"""
531
- global __graph , __graphml_parser , __non_fluent_graph_facts_node , __non_fluent_graph_facts_edge , __specific_graph_node_labels , __specific_graph_edge_labels , settings
553
+ global __graph , __non_fluent_graph_facts_node , __non_fluent_graph_facts_edge , __specific_graph_node_labels , __specific_graph_edge_labels
532
554
533
555
# Load graph
534
556
__graph = __graphml_parser .load_graph (graph )
@@ -629,7 +651,6 @@ def add_annotation_function(function: Callable) -> None:
629
651
:type function: Callable
630
652
:return: None
631
653
"""
632
- global __annotation_functions
633
654
# Make sure that the functions are jitted so that they can be passed around in other jitted functions
634
655
# TODO: Remove if necessary
635
656
# assert hasattr(function, 'nopython_signatures'), 'The function to be added has to be under a `numba.njit` decorator'
@@ -648,7 +669,7 @@ def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bou
648
669
:param restart: Whether to restart the program time from 0 when reasoning again, defaults to True
649
670
:return: The final interpretation after reasoning.
650
671
"""
651
- global settings , __timestamp
672
+ global __timestamp
652
673
653
674
# Timestamp for saving files
654
675
__timestamp = time .strftime ('%Y%m%d-%H%M%S' )
@@ -676,8 +697,8 @@ def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bou
676
697
677
698
def _reason (timesteps , convergence_threshold , convergence_bound_threshold , queries ):
678
699
# Globals
679
- global __graph , __rules , __clause_maps , __node_facts , __edge_facts , __ipl , __specific_node_labels , __specific_edge_labels , __graphml_parser
680
- global settings , __timestamp , __program
700
+ global __rules , __clause_maps , __node_facts , __edge_facts , __ipl , __specific_node_labels , __specific_edge_labels
701
+ global __program
681
702
682
703
# Assert variables are of correct type
683
704
@@ -748,7 +769,7 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold, queri
748
769
__rules .append (r )
749
770
750
771
# Setup logical program
751
- __program = Program (__graph , all_node_facts , all_edge_facts , __rules , __ipl , annotation_functions , settings .reverse_digraph , settings .atom_trace , settings .save_graph_attributes_to_trace , settings .persistent , settings .inconsistency_check , settings .store_interpretation_changes , settings .parallel_computing , settings .update_mode , settings .allow_ground_rules )
772
+ __program = Program (__graph , all_node_facts , all_edge_facts , __rules , __ipl , annotation_functions , settings .reverse_digraph , settings .atom_trace , settings .save_graph_attributes_to_trace , settings .persistent , settings .inconsistency_check , settings .store_interpretation_changes , settings .parallel_computing , settings .update_mode , settings .allow_ground_rules , settings . fp_version )
752
773
__program .specific_node_labels = __specific_node_labels
753
774
__program .specific_edge_labels = __specific_edge_labels
754
775
@@ -764,9 +785,6 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold, queri
764
785
765
786
def _reason_again (timesteps , restart , convergence_threshold , convergence_bound_threshold ):
766
787
# Globals
767
- global __graph , __rules , __node_facts , __edge_facts , __ipl , __specific_node_labels , __specific_edge_labels , __graphml_parser
768
- global settings , __timestamp , __program
769
-
770
788
assert __program is not None , 'To run `reason_again` you need to have reasoned once before'
771
789
772
790
# Extend facts
@@ -788,8 +806,6 @@ def save_rule_trace(interpretation, folder: str='./'):
788
806
:param interpretation: the output of `pyreason.reason()`, the final interpretation
789
807
:param folder: the folder in which to save the result, defaults to './'
790
808
"""
791
- global __timestamp , __clause_maps , settings
792
-
793
809
assert settings .store_interpretation_changes , 'store interpretation changes setting is off, turn on to save rule trace'
794
810
795
811
output = Output (__timestamp , __clause_maps )
@@ -804,8 +820,6 @@ def get_rule_trace(interpretation) -> Tuple[pd.DataFrame, pd.DataFrame]:
804
820
:param interpretation: the output of `pyreason.reason()`, the final interpretation
805
821
:returns two pandas dataframes (nodes, edges) representing the changes that occurred during reasoning
806
822
"""
807
- global __timestamp , __clause_maps , settings
808
-
809
823
assert settings .store_interpretation_changes , 'store interpretation changes setting is off, turn on to save rule trace'
810
824
811
825
output = Output (__timestamp , __clause_maps )
0 commit comments