diff --git a/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py b/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py index cb623141dc3daa..c011c3856fbcd3 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py @@ -43,6 +43,7 @@ from ...utils.exceptions import InnerError from ...utils.magic_methods import ( BINARY_OPS, + NEED_GUARD_ZERO_DIVISION_ERROR_OPS, UNARY_OPS, magic_method_builtin_dispatch, non_inplace_op_to_inplace_op, @@ -85,6 +86,7 @@ ) if TYPE_CHECKING: + from ...utils.magic_methods import BinaryOp from .variables import DataVariable, TensorVariable @@ -1036,6 +1038,23 @@ def is_not_func(var: VariableBase, other: VariableBase): ) +def apply_op_with_zero_division_check( + op: BinaryOp, lhs: VariableBase, rhs: VariableBase +): + + graph = lhs.graph + if op in NEED_GUARD_ZERO_DIVISION_ERROR_OPS: + call_eq = BuiltinVariable(operator.eq, graph, DanglingTracker()) + zero = ConstantVariable.wrap_literal(0, graph) + rhs_eq_to_zero = call_eq(rhs, zero) + add_guard(rhs_eq_to_zero) + return VariableFactory.from_value( + op(lhs.get_py_value(), rhs.get_py_value()), + graph, + DummyTracker([lhs, rhs]), + ) + + # Constant for unary_fn in UNARY_OPS: for magic_method in magic_method_builtin_dispatch(unary_fn): @@ -1060,11 +1079,7 @@ def is_not_func(var: VariableBase, other: VariableBase): "ConstantVariable | NumPyNumberVariable", ), partial( - lambda fn, var, other: VariableFactory.from_value( - fn(var.get_py_value(), other.get_py_value()), - var.graph, - tracker=DummyTracker([var, other]), - ), + apply_op_with_zero_division_check, binary_fn, ), ) diff --git a/python/paddle/jit/sot/symbolic_shape/operators.py b/python/paddle/jit/sot/symbolic_shape/operators.py index 99bdb8fbebf400..cf9d0e30432fae 100644 --- a/python/paddle/jit/sot/symbolic_shape/operators.py +++ b/python/paddle/jit/sot/symbolic_shape/operators.py @@ -18,7 +18,6 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from ..utils.magic_methods import BinaryOp, UnaryOp diff --git a/python/paddle/jit/sot/utils/magic_methods.py b/python/paddle/jit/sot/utils/magic_methods.py index f835470d04e1ff..3ae1f01d992324 100644 --- a/python/paddle/jit/sot/utils/magic_methods.py +++ b/python/paddle/jit/sot/utils/magic_methods.py @@ -95,6 +95,18 @@ UNARY_OPS = set(UNARY_OPS_TO_MAGIC_NAMES.keys()) +# NOTE: Both operator.pow and operator.ipow should be considered for inclusion in this list, +# as they raise ZeroDivisionError when evaluating 0^n where n < 0 (division by zero). +NEED_GUARD_ZERO_DIVISION_ERROR_OPS: list[BinaryOp] = [ + operator.floordiv, + operator.truediv, + operator.mod, + operator.ifloordiv, + operator.itruediv, + operator.imod, +] + + @dataclass class MagicMethod: name: str diff --git a/test/sot/test_24_exceptions.py b/test/sot/test_24_exceptions.py index 42647978c65ffa..6a1de1b0c3380e 100644 --- a/test/sot/test_24_exceptions.py +++ b/test/sot/test_24_exceptions.py @@ -944,5 +944,29 @@ def fn(): self.assert_results(fn) +class TestBuiltinFunctionRaiseExceptionGuard(TestCaseBase): + def test_guard_run(self): + def foo_floordiv(x): + 1 / x + + def foo_mod(x): + 2 % x + + self.assert_results(foo_floordiv, 1) + self.assert_exceptions( + ZeroDivisionError, + "division by zero", + foo_floordiv, + 0, + ) + self.assert_results(foo_mod, 10) + self.assert_exceptions( + ZeroDivisionError, + "integer (.)*modulo by zero", + foo_mod, + 0, + ) + + if __name__ == "__main__": unittest.main()