From 06c6954e7bb995fe2ed1364a4d11a041aa33f1c0 Mon Sep 17 00:00:00 2001 From: DrRyanHuang Date: Fri, 30 May 2025 08:54:38 +0000 Subject: [PATCH 1/4] fix zerodiv not raising --- .../executor/variable_dispatch.py | 25 +++++++++++++++---- .../jit/sot/symbolic_shape/operators.py | 1 - python/paddle/jit/sot/utils/magic_methods.py | 11 ++++++++ test/sot/test_24_exceptions.py | 24 ++++++++++++++++++ 4 files changed, 55 insertions(+), 6 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py b/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py index cb623141dc3daa..c011c3856fbcd3 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py @@ -43,6 +43,7 @@ from ...utils.exceptions import InnerError from ...utils.magic_methods import ( BINARY_OPS, + NEED_GUARD_ZERO_DIVISION_ERROR_OPS, UNARY_OPS, magic_method_builtin_dispatch, non_inplace_op_to_inplace_op, @@ -85,6 +86,7 @@ ) if TYPE_CHECKING: + from ...utils.magic_methods import BinaryOp from .variables import DataVariable, TensorVariable @@ -1036,6 +1038,23 @@ def is_not_func(var: VariableBase, other: VariableBase): ) +def apply_op_with_zero_division_check( + op: BinaryOp, lhs: VariableBase, rhs: VariableBase +): + + graph = lhs.graph + if op in NEED_GUARD_ZERO_DIVISION_ERROR_OPS: + call_eq = BuiltinVariable(operator.eq, graph, DanglingTracker()) + zero = ConstantVariable.wrap_literal(0, graph) + rhs_eq_to_zero = call_eq(rhs, zero) + add_guard(rhs_eq_to_zero) + return VariableFactory.from_value( + op(lhs.get_py_value(), rhs.get_py_value()), + graph, + DummyTracker([lhs, rhs]), + ) + + # Constant for unary_fn in UNARY_OPS: for magic_method in magic_method_builtin_dispatch(unary_fn): @@ -1060,11 +1079,7 @@ def is_not_func(var: VariableBase, other: VariableBase): "ConstantVariable | NumPyNumberVariable", ), partial( - lambda fn, var, other: VariableFactory.from_value( - fn(var.get_py_value(), other.get_py_value()), - var.graph, - tracker=DummyTracker([var, other]), - ), + apply_op_with_zero_division_check, binary_fn, ), ) diff --git a/python/paddle/jit/sot/symbolic_shape/operators.py b/python/paddle/jit/sot/symbolic_shape/operators.py index 99bdb8fbebf400..cf9d0e30432fae 100644 --- a/python/paddle/jit/sot/symbolic_shape/operators.py +++ b/python/paddle/jit/sot/symbolic_shape/operators.py @@ -18,7 +18,6 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from ..utils.magic_methods import BinaryOp, UnaryOp diff --git a/python/paddle/jit/sot/utils/magic_methods.py b/python/paddle/jit/sot/utils/magic_methods.py index f835470d04e1ff..0260f49d05aee7 100644 --- a/python/paddle/jit/sot/utils/magic_methods.py +++ b/python/paddle/jit/sot/utils/magic_methods.py @@ -95,6 +95,17 @@ UNARY_OPS = set(UNARY_OPS_TO_MAGIC_NAMES.keys()) +# operator.pow operator.ipow, 当底数为0且指数为负数时,会引发ZeroDivisionError +NEED_GUARD_ZERO_DIVISION_ERROR_OPS: list[BinaryOp] = [ + operator.floordiv, + operator.truediv, + operator.mod, + operator.ifloordiv, + operator.itruediv, + operator.imod, +] + + @dataclass class MagicMethod: name: str diff --git a/test/sot/test_24_exceptions.py b/test/sot/test_24_exceptions.py index 42647978c65ffa..3e02819d8cf7be 100644 --- a/test/sot/test_24_exceptions.py +++ b/test/sot/test_24_exceptions.py @@ -944,5 +944,29 @@ def fn(): self.assert_results(fn) +class TestBuiltinFunctionRaiseExceptionGuard(TestCaseBase): + def test_guard_run(self): + def foo_floordiv(x): + 1 / x + + def foo_mod(x): + 2 % x + + self.assert_results(foo_floordiv, 1) + self.assert_exceptions( + ZeroDivisionError, + "division by zero", + foo_floordiv, + 0, + ) + self.assert_results(foo_mod, 10) + self.assert_exceptions( + ZeroDivisionError, + "integer division or modulo by zero", + foo_mod, + 0, + ) + + if __name__ == "__main__": unittest.main() From efc0d3ad215f6e56a5389cd65dd1c51e1d55ac1a Mon Sep 17 00:00:00 2001 From: DrRyanHuang Date: Fri, 30 May 2025 09:07:23 +0000 Subject: [PATCH 2/4] zh->en --- python/paddle/jit/sot/utils/magic_methods.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/jit/sot/utils/magic_methods.py b/python/paddle/jit/sot/utils/magic_methods.py index 0260f49d05aee7..3ae1f01d992324 100644 --- a/python/paddle/jit/sot/utils/magic_methods.py +++ b/python/paddle/jit/sot/utils/magic_methods.py @@ -95,7 +95,8 @@ UNARY_OPS = set(UNARY_OPS_TO_MAGIC_NAMES.keys()) -# operator.pow operator.ipow, 当底数为0且指数为负数时,会引发ZeroDivisionError +# NOTE: Both operator.pow and operator.ipow should be considered for inclusion in this list, +# as they raise ZeroDivisionError when evaluating 0^n where n < 0 (division by zero). NEED_GUARD_ZERO_DIVISION_ERROR_OPS: list[BinaryOp] = [ operator.floordiv, operator.truediv, From 252474e4a4be494cb1f86cb46dc330809604a87f Mon Sep 17 00:00:00 2001 From: DrRyanHuang Date: Fri, 30 May 2025 14:30:10 +0000 Subject: [PATCH 3/4] adaptation for 3.11+ --- test/sot/test_24_exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/sot/test_24_exceptions.py b/test/sot/test_24_exceptions.py index 3e02819d8cf7be..0b2aaaa015b4b9 100644 --- a/test/sot/test_24_exceptions.py +++ b/test/sot/test_24_exceptions.py @@ -962,7 +962,7 @@ def foo_mod(x): self.assert_results(foo_mod, 10) self.assert_exceptions( ZeroDivisionError, - "integer division or modulo by zero", + "integer (.)* modulo by zero", foo_mod, 0, ) From 0bfc83a68ef161d47ef9ca7de8a19f5bb2957ff8 Mon Sep 17 00:00:00 2001 From: DrRyanHuang Date: Tue, 3 Jun 2025 02:50:43 +0000 Subject: [PATCH 4/4] fix re --- test/sot/test_24_exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/sot/test_24_exceptions.py b/test/sot/test_24_exceptions.py index 0b2aaaa015b4b9..6a1de1b0c3380e 100644 --- a/test/sot/test_24_exceptions.py +++ b/test/sot/test_24_exceptions.py @@ -962,7 +962,7 @@ def foo_mod(x): self.assert_results(foo_mod, 10) self.assert_exceptions( ZeroDivisionError, - "integer (.)* modulo by zero", + "integer (.)*modulo by zero", foo_mod, 0, )