|
37 | 37 | '''PSyIR frontend to convert a SymPy expression to PSyIR
|
38 | 38 | '''
|
39 | 39 |
|
| 40 | +from sympy.printing.printer import Printer |
40 | 41 |
|
41 | 42 | from psyclone.psyir.frontend.fortran import FortranReader
|
42 | 43 |
|
43 | 44 |
|
| 45 | +# pylint: disable=invalid-name |
| 46 | +class FortranPrinter(Printer): |
| 47 | + '''A helper class that converts Fortran logical operators back |
| 48 | + to Fortran format (while SymPy has a Fortran printer (fcode) does this |
| 49 | + as well, it does not handle e.g. Fortran Array expressions (a(2:5)), |
| 50 | + so we need to use the not-Fortran-aware output to a normal string, |
| 51 | + but handle logical operators separately. |
| 52 | + ''' |
| 53 | + def _print_And(self, expr): |
| 54 | + '''Called when converting an AND expression.''' |
| 55 | + return f"({'.AND.' .join(self._print(i) for i in expr.args)})" |
| 56 | + |
| 57 | + def _print_Or(self, expr): |
| 58 | + '''Called when converting an OR expression.''' |
| 59 | + return f"({'.OR.' .join(self._print(i) for i in expr.args)})" |
| 60 | + |
| 61 | + def _print_Equivalent(self, expr): |
| 62 | + '''Called when converting an EQUIVALENT expression.''' |
| 63 | + return f"({'.EQV.' .join(self._print(i) for i in expr.args)})" |
| 64 | + |
| 65 | + def _print_Xor(self, expr): |
| 66 | + '''Called when converting an XOR expression, which in Fortran |
| 67 | + is NEQV.''' |
| 68 | + return f"({'.NEQV.' .join(self._print(i) for i in expr.args)})" |
| 69 | + |
| 70 | + |
44 | 71 | class SymPyReader():
|
45 | 72 | '''This class converts a SymPy expression, that was created by the
|
46 | 73 | SymPyWriter, back to PSyIR. It basically allows to use SymPy to modify
|
@@ -121,7 +148,9 @@ def psyir_from_expression(self, sympy_expr, symbol_table):
|
121 | 148 | '''
|
122 | 149 | # Convert the new sympy expression to PSyIR
|
123 | 150 | reader = FortranReader()
|
124 |
| - return reader.psyir_from_expression(str(sympy_expr), symbol_table) |
| 151 | + fp = FortranPrinter() |
| 152 | + return reader.psyir_from_expression(fp.doprint(sympy_expr), |
| 153 | + symbol_table) |
125 | 154 |
|
126 | 155 | # -------------------------------------------------------------------------
|
127 | 156 | # pylint: disable=no-self-argument, too-many-branches
|
|
0 commit comments