39
39
'''
40
40
41
41
import keyword
42
- from typing import Dict , List , Optional , Set , Union
42
+ from typing import Dict , Iterable , List , Optional , Set , Union
43
43
44
44
import sympy
45
45
from sympy .parsing .sympy_parser import parse_expr
@@ -102,6 +102,14 @@ class SymPyWriter(FortranWriter):
102
102
# This class attribute will get initialised in __init__:
103
103
_RESERVED_NAMES : Set [str ] = set ()
104
104
105
+ # A mapping of PSyIR's logical binary operations to the required
106
+ # SymPy format:
107
+ _BINARY_OP_MAPPING : Dict [BinaryOperation .Operator , str ] = \
108
+ {BinaryOperation .Operator .AND : "And({lhs}, {rhs})" ,
109
+ BinaryOperation .Operator .OR : "Or({lhs}, {rhs})" ,
110
+ BinaryOperation .Operator .EQV : "Equivalent({lhs}, {rhs})" ,
111
+ BinaryOperation .Operator .NEQV : "Xor({lhs}, {rhs})" }
112
+
105
113
def __init__ (self ):
106
114
super ().__init__ ()
107
115
@@ -255,7 +263,7 @@ def _create_sympy_array_function(
255
263
256
264
# -------------------------------------------------------------------------
257
265
def _create_type_map (self ,
258
- list_of_expressions : List [Node ],
266
+ list_of_expressions : Iterable [Node ],
259
267
identical_variables : Optional [Dict [str , str ]] = None ,
260
268
all_variables_positive : Optional [bool ] = None ):
261
269
'''This function creates a dictionary mapping each Reference in any
@@ -399,7 +407,7 @@ def type_map(self) -> Dict[str, Union[sympy.core.symbol.Symbol,
399
407
# -------------------------------------------------------------------------
400
408
def _to_str (
401
409
self ,
402
- list_of_expressions : Union [Node , List [Node ]],
410
+ list_of_expressions : Union [Node , Iterable [Node ]],
403
411
identical_variables : Optional [Dict [str , str ]] = None ,
404
412
all_variables_positive : Optional [bool ] = False ) -> Union [str ,
405
413
List [str ]]:
@@ -423,14 +431,14 @@ def _to_str(
423
431
:returns: the converted strings(s).
424
432
425
433
'''
426
- is_list = isinstance (list_of_expressions , (tuple , list ))
434
+ is_list = isinstance (list_of_expressions , (Iterable ))
427
435
if not is_list :
428
436
# Make mypy happy:
429
437
assert isinstance (list_of_expressions , Node )
430
438
list_of_expressions = [list_of_expressions ]
431
439
432
440
# Make mypy happy:
433
- assert isinstance (list_of_expressions , List )
441
+ assert isinstance (list_of_expressions , Iterable )
434
442
# Create the type map in `self._sympy_type_map`, which is required
435
443
# when converting these strings to SymPy expressions
436
444
self ._create_type_map (list_of_expressions ,
@@ -711,18 +719,11 @@ def binaryoperation_node(self, node: BinaryOperation) -> str:
711
719
:param node: a Reference PSyIR BinaryOperation.
712
720
713
721
'''
714
- for psy_op , sympy_op in [(BinaryOperation .Operator .AND ,
715
- "{lhs} & {rhs}" ),
716
- (BinaryOperation .Operator .OR ,
717
- "{lhs} | {rhs}" ),
718
- (BinaryOperation .Operator .EQV ,
719
- "Equivalent({lhs}, {rhs})" ),
720
- (BinaryOperation .Operator .NEQV ,
721
- "~Equivalent({lhs}, {rhs})" )]:
722
- if node .operator == psy_op :
723
- lhs = self ._visit (node .children [0 ])
724
- rhs = self ._visit (node .children [1 ])
725
- return sympy_op .format (rhs = rhs , lhs = lhs )
722
+ if node .operator in self ._BINARY_OP_MAPPING :
723
+ lhs = self ._visit (node .children [0 ])
724
+ rhs = self ._visit (node .children [1 ])
725
+ return self ._BINARY_OP_MAPPING [node .operator ].format (rhs = rhs ,
726
+ lhs = lhs )
726
727
727
728
return super ().binaryoperation_node (node )
728
729
0 commit comments