From 1e5d4fb6fcd9feebe87b2b6d0eec09c5b241186d Mon Sep 17 00:00:00 2001 From: DrRyanHuang Date: Tue, 29 Apr 2025 11:45:40 +0000 Subject: [PATCH 1/4] add exception for 3.10- --- .../executor/exception_stack.py | 119 ++++++++ .../executor/opcode_executor.py | 277 +++++++++++++++++- .../executor/variable_dispatch.py | 82 ++++-- .../executor/variables/__init__.py | 1 + .../executor/variables/basic.py | 198 +++++++++++++ .../executor/variables/callable.py | 13 +- test/sot/test_24_exceptions.py | 217 ++++++++++++++ test/sot/test_sot_exception.py | 2 +- 8 files changed, 876 insertions(+), 33 deletions(-) create mode 100644 python/paddle/jit/sot/opcode_translator/executor/exception_stack.py create mode 100644 test/sot/test_24_exceptions.py diff --git a/python/paddle/jit/sot/opcode_translator/executor/exception_stack.py b/python/paddle/jit/sot/opcode_translator/executor/exception_stack.py new file mode 100644 index 00000000000000..b0f5d01757307a --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/exception_stack.py @@ -0,0 +1,119 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import dataclasses + +from ...utils import InnerError +from .variables import ConstantVariable, ExceptionVariable + + +@dataclasses.dataclass +class ExceptionStack: + + # This data structure manages exceptions as in CPython, primarily handling + # the __context__ attribute of SotCapturedException. + + _exception_stack: list[ExceptionVariable | None] = dataclasses.field( + default_factory=list + ) + _current_exception: ExceptionVariable | None = dataclasses.field( + default=None + ) + + def clear_current_exception(self): + self._current_exception = None + + def set_current_exception(self, val: ExceptionVariable, graph): + self._set_context_and_break_context_reference_cycle(val, graph) + self._current_exception = val + + def move_current_exception_to_stack(self): + self.push(self.get_current_exception()) + self.clear_current_exception() + + def get_current_exception(self): + if self._current_exception is None: + raise InnerError("Current exception should not be None") + return self._current_exception + + def _set_context_recursive(self, val: ExceptionVariable, prev_idx): + # Recursively sets the __context__ attribute for ExceptionVariable objects + # in self._exception_stack. Ensures that __context__ is properly linked + # to the previous exception in the stack. + if (ctx := val.__context__) and type(ctx) is not ConstantVariable: + return val + if ( + len(self._exception_stack) + prev_idx > 0 + ): # Prevent invalid negative indexing + prev = self._exception_stack[prev_idx] + self._set_context_recursive(prev, prev_idx - 1) + val.setattr("__context__", prev) + return val + + def _break_context_reference_cycle(self, val: ExceptionVariable, graph): + # Detects and breaks cycles in exception __context__ chains using Floyd's algorithm, + # following CPython's implementation. + + fast = slow = val + slow_update_toggle = False + while True: + context = fast.__context__ + if ( + type(context) is ConstantVariable + ): # End of the chain; no context set + break + + if context is val: + # The chain loops back to the original exception; break the cycle. + fast.setattr( + "__context__", ConstantVariable.wrap_literal(None, graph) + ) + break + + fast = context + if fast is slow: + # Cycle detected; all exceptions on the path have been visited and checked. + break + + if slow_update_toggle: + slow = slow.__context__ + slow_update_toggle = not slow_update_toggle + + def _set_context_and_break_context_reference_cycle( + self, val: ExceptionVariable, graph + ): + # set Exception.__context__ + self._set_context_recursive(val, len(self._exception_stack) - 1) + self._break_context_reference_cycle(val, graph) + + def pop(self): + return self._exception_stack.pop() + + def push(self, val): + self._exception_stack.append(val) + + def __len__(self): + return len(self._exception_stack) + + def __str__(self): + return f"{self._exception_stack}" + + def __getitem__(self, idx): + return self._exception_stack[idx] + + def cleanup(self): + self._exception_stack[:] = [] + self._current_exception = None diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index a75f8fcc46dd06..d9c4dcd10b0ce1 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -39,6 +39,8 @@ DataDependencyDynamicShapeBreak, FallbackError, InnerError, + SotCapturedException, + SotCapturedExceptionFactory, SotUndefinedVar, UnsupportedIteratorBreak, UnsupportedOperationBreak, @@ -85,12 +87,14 @@ ConstTracker, DanglingTracker, DummyTracker, + GetAttrTracker, ) from .variables import ( BuiltinVariable, ConstantVariable, ContainerVariable, DictVariable, + ExceptionVariable, IterVariable, ListVariable, MethodVariable, @@ -114,6 +118,8 @@ from .variable_stack import VariableStack from .virtual_frame import VirtualFrame +from .exception_stack import ExceptionStack +from .virtual_frame import BlockStackItem COMPARE_OP_NAME_TO_FN = { ">": operator.gt, @@ -341,6 +347,18 @@ def inner(*args, **kwargs): return inner +def fallback_if_python_version_unsupported(fn: Callable): + def inner(*args, **kwargs): + if sys.version_info >= (3, 11): + raise FallbackError( + "SOT currently only partially supports exception handling (Python 3.10 and below). " + "Unsupported exception bytecode will fall back to dynamic graph mode." + ) + return fn(*args, **kwargs) + + return inner + + def parse_force_fallback_sir_ids() -> set[int]: ids_string = ENV_SOT_FORCE_FALLBACK_SIR_IDS.get() if not ids_string: @@ -404,6 +422,7 @@ class EmptyCode: call_stack: list[OpcodeExecutorBase] = [] empty_code = EmptyCode() + exception_stack = ExceptionStack() def __init__(self, vframe: VirtualFrame, graph: FunctionGraph): OpcodeExecutorBase.call_stack.append(self) @@ -646,7 +665,101 @@ def step(self, instr: Instruction): opname = opname if opname != "PRECALL" else "PRECALL__CALL" assert opname != "CALL", "CALL should fused with PRECALL" with EventGuard(f"{opname}", event_level=2): - return getattr(self, opname)(instr) # run single step. + try: + return getattr(self, opname)(instr) # run single step. + except SotCapturedException as e: + self.handle_exception(e) + + def handle_exception(self, e: SotCapturedException): + # TODO(DrRyanHuang): The newly created ExceptionVariable might differ from the previous one + e_var = VariableFactory.from_value(e, self._graph, DummyTracker([])) + + # The exception is not raised by `raise Exception` + if ( + len(self.exception_stack) == 0 + or self.exception_stack.get_current_exception() != e_var + ): + self.exception_stack.set_current_exception(e_var, self._graph) + + if len(self.vframe.block_stack): + # The implementation is referenced from the exception_unwind section + # of CPython's main_loop. + block_stack_entry = self.vframe.block_stack.pop() + while block_stack_entry.inst.opname == "EXCEPT_HANDLER": + # Remove previous EXCEPT_HANDLER entries, which indicate that the + # exception has already been handled. Continue until a SETUP_FINALLY + # block is encountered, which signifies an active exception handler. + self.stack.pop_n(3) + self.exception_stack.pop() + if len(self.vframe.block_stack) == 0: + # Since the block stack is empty, no handler was found in this + # frame, so the exception is propagated to the outer function for handling. + self.stack.pop_n( + len(self.stack) + ) # clear stack to prevent memory leaks + raise e + block_stack_entry = self.vframe.block_stack.pop() + + exception_var = self.exception_stack.get_current_exception() + self.exception_stack.move_current_exception_to_stack() + + # Pop elements from the stack to restore it to the depth recorded by the block, + # ensuring the stack state matches that prior to exception handling. + while len(self.stack) > block_stack_entry.level: + self.stack.pop() + + # Push a dummy EXCEPT_HANDLER block onto the stack to indicate that exception + # handling has begun and to record the current stack level. + EXCEPT_HANDLER_INSTRUCTION = Instruction( + 257, "EXCEPT_HANDLER", None, 0 + ) + self.vframe.block_stack.append( + BlockStackItem( + EXCEPT_HANDLER_INSTRUCTION.opname, + EXCEPT_HANDLER_INSTRUCTION, + None, + len(self.stack), + ) + ) + + # Push the old exception variables (tb, value, type) onto stack + if len(self.exception_stack) >= 2: + old_exception = self.exception_stack[-2] + + # Current SOT implementation does not track traceback information, + # so Traceback is represented as ConstantVariable(None) + self.stack.push( + ConstantVariable.wrap_literal(None, self._graph) + ) + self.stack.push(old_exception) + self.stack.push( + BuiltinVariable( + old_exception.exc_type, + self._graph, + DummyTracker([]), + ) + ) + else: + for _ in range(3): + self.stack.push( + ConstantVariable.wrap_literal(None, self._graph) + ) + + # Push current exception - tb, val, type + self.stack.push(ConstantVariable.wrap_literal(None, self._graph)) + self.stack.push(exception_var) + self.stack.push( + BuiltinVariable( + exception_var.exc_type, + self._graph, + GetAttrTracker(exception_var, "__class__"), + ) + ) + + self.jump_to(block_stack_entry.handler) + else: + self.stack.pop_n(len(self.stack)) + raise e def indexof(self, instr: Instruction): """ @@ -1926,6 +2039,154 @@ def CALL_INTRINSIC_1(self, instr: Instruction): else: raise FallbackError(f"No support Intrinsics, {intrinsic_func.name}") + @fallback_if_python_version_unsupported + def SETUP_FINALLY(self, instr: Instruction): + self.vframe.block_stack.append( + BlockStackItem(instr.opname, instr, instr.jump_to, len(self.stack)) + ) + + @fallback_if_python_version_unsupported + def POP_BLOCK(self, instr: Instruction): + self.vframe.block_stack.pop() + + @fallback_if_python_version_unsupported + def LOAD_ASSERTION_ERROR(self, instr: Instruction): + value = self.vframe.builtins["AssertionError"] + self.stack.push(value) + + @fallback_if_python_version_unsupported + def POP_EXCEPT(self, instr: Instruction): + assert len(self.vframe.block_stack) > 0 + + if self.vframe.block_stack[-1].inst.opname != "EXCEPT_HANDLER": + raise FallbackError( + "Bug in SOT tracing of exception handling." + "Top of the block stack is not EXCEPT_HANDLER." + ) + + self.vframe.block_stack.pop() + self.stack.pop_n(3) + + assert len(self.exception_stack) + self.exception_stack.pop() + + @staticmethod + def _create_exception_instance(val): + if isinstance(val, BuiltinVariable): + val = val.call_function() + return val + + @staticmethod + def _is_exception_isinstance(val): + return isinstance(val, ExceptionVariable) + + def _raise_exception_instance( + self, val: ExceptionVariable | BuiltinVariable + ): + # TODO(DrRyanHuang): need to support user-defined Exception + + val = self._create_exception_instance(val) + self.exception_stack.set_current_exception(val, self._graph) + + if self._is_exception_isinstance(val): + raise SotCapturedExceptionFactory.create( + origin_exc=val.get_py_value() + ) + + raise FallbackError("Attempted to raise a non-Exception type/value.") + + @fallback_if_python_version_unsupported + def RAISE_VARARGS(self, instr: Instruction): + if instr.arg == 0: + if not len(self.exception_stack): + msg = ConstantVariable.wrap_literal( + "No active exception to reraise", self._graph + ) + self.raise_sot_captured_exception(RuntimeError, msg) + + assert len(self.exception_stack) + val = self.exception_stack[-1] + assert self._is_exception_isinstance(val), val + self._raise_exception_instance(val) + elif instr.arg == 1: + val = self.stack.top + self._raise_exception_instance(val) + else: + # raise .. from ... + from_exc = self.stack.pop() + val = self.stack.pop() + + # type -> instance + val = self._create_exception_instance(val) + self.exception_stack.set_current_exception(val, self._graph) + + # Update __cause__/__suppress_context__ in the raised exception + cause = self._create_exception_instance(from_exc) + val.setattr("__cause__", cause) + + raise SotCapturedExceptionFactory.create( + origin_exc=val.get_py_value() + ) + + def check_if_exception_matches(self): + assert len(self.stack) >= 2 + expected_exc_types = self.stack.pop() + exc_instance = self.stack.pop() + + if isinstance(expected_exc_types, TupleVariable): + expected_types = expected_exc_types.get_wrapped_items() + else: + expected_types = [ + expected_exc_types, + ] + + for expected_type in expected_types: + if not isinstance(expected_type, BuiltinVariable): + raise FallbackError( + f"`except ...` requires a BuiltinVariable as the exception type, but received: {expected_type}." + ) + + # Exception -> SotCapturedException + expected_type_exception = SotCapturedExceptionFactory.get( + expected_type.get_py_value() + ) + + if self._is_exception_isinstance(exc_instance) and issubclass( + exc_instance.exc_type, + expected_type_exception, + ): + return True + elif isinstance(exc_instance, BuiltinVariable) and issubclass( + exc_instance.get_py_value(), expected_type_exception + ): + return True + + return False + + @fallback_if_python_version_unsupported + def JUMP_IF_NOT_EXC_MATCH(self, instr: Instruction): + if not self.check_if_exception_matches(): + self.jump_to(instr.jump_to) + + @fallback_if_python_version_unsupported + def RERAISE(self, instr: Instruction): + _exc_type = self.stack.pop() + _exc_instance = self.stack.pop() + _traceback = self.stack.pop() + self._raise_exception_instance(_exc_instance) + + def raise_sot_captured_exception( + self, + exc_type: type[Exception], + *args, + **kwargs, + ): + exc = BuiltinVariable( + exc_type, self._graph, DummyTracker(list(args)) + ).call_function(*args, **kwargs) + self.exception_stack.set_current_exception(exc, self._graph) + raise SotCapturedExceptionFactory.get(exc_type) + class OpcodeExecutor(OpcodeExecutorBase): """ @@ -1971,6 +2232,7 @@ def cleanup(self): self._graph.pycode_gen = None Dispatcher.graph = None self.call_stack[:] = [] + self.exception_stack.cleanup() def FOR_ITER(self, instr): iterator = self.stack.pop() @@ -2082,6 +2344,18 @@ def get_compute_fn_and_update_changed_vars( compile_graph_result, store_vars, store_var_info ) + def fallback_when_block_stack_is_empty(self): + """ + SOT currently doesn't support a non-empty block stack (related to exception handling), + triggering a fallback. + """ + + if len(self.vframe.block_stack): + raise FallbackError( + 'SOT currently does not support a non-empty block stack, ' + 'triggering a fallback\n' + ) + @fallback_when_occur_error def _break_graph_when_if(self, result: TensorVariable, instr: Instruction): """ @@ -2092,6 +2366,7 @@ def _break_graph_when_if(self, result: TensorVariable, instr: Instruction): instr: The jump instruction. """ + self.fallback_when_block_stack_is_empty() self._graph.add_global_guarded_variable(result) # 1. analyse info diff --git a/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py b/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py index 1ea7a1dd7dc2ed..f45a8ab9daae47 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py @@ -14,6 +14,7 @@ from __future__ import annotations +import builtins import inspect import math import operator @@ -66,6 +67,7 @@ ContainerVariable, DictVariable, EnumerateVariable, + ExceptionVariable, IterVariable, ListVariable, MapVariable, @@ -85,6 +87,24 @@ from .variables import DataVariable, TensorVariable +# NOTE(SigureMo): Don't directly capture free var inside for-loop, use partial instead. +# ```python +# lambdas = [] +# for i in range(10): +# lambdas.append(lambda: i) +# for fn in lambdas: +# print(fn()) # result is 9, 9, 9, 9, 9, 9, 9, 9, 9, 9 +# ``` +# Rewrite by partial: +# ```python +# lambdas = [] +# for i in range(10): +# lambdas.append(partial(lambda i: i, i)) +# for fn in lambdas: +# print(fn()) # result is 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 +# ``` + + def add_guard(var: VariableBase): var.graph.add_global_guarded_variable(var) return var @@ -245,19 +265,6 @@ def inner(*args, **kwargs): ) -# super -Dispatcher.register( - super, - ("ClassVariable", "VariableBase"), - lambda cls, obj: SuperVariable( - cls=cls, - obj=obj, - graph=Dispatcher.graph, - tracker=DummyTracker([cls, obj]), - ), -) - - @Dispatcher.register_decorator(dict) def dispatch_dict_kwargs(**kwargs: VariableBase): res_dict = {} @@ -610,6 +617,38 @@ def dispatch_list_ne(lhs: ListVariable, rhs: ListVariable): lambda var: var.len(), ) +# super +Dispatcher.register( + super, + ("ClassVariable", "VariableBase"), + lambda cls, obj: SuperVariable( + cls=cls, + obj=obj, + graph=Dispatcher.graph, + tracker=DummyTracker([cls, obj]), + ), +) + + +def register_exception(exc): + @Dispatcher.register_decorator(exc) + def builtin_exception_dispatcher(*args) -> int: + return ExceptionVariable( + exc, + *args, + graph=Dispatcher.graph, + tracker=DummyTracker([]), + ) + + +# builtin Exception +for name, obj in builtins.__dict__.items(): + if not (isinstance(obj, type) and issubclass(obj, Exception)): + continue + + register_exception(obj) + + # range # stop Dispatcher.register( @@ -996,23 +1035,6 @@ def is_not_func(var: VariableBase, other: VariableBase): ) -# NOTE(SigureMo): Don't directly capture free var inside for-loop, use partial instead. -# ```python -# lambdas = [] -# for i in range(10): -# lambdas.append(lambda: i) -# for fn in lambdas: -# print(fn()) # result is 9, 9, 9, 9, 9, 9, 9, 9, 9, 9 -# ``` -# Rewrite by partial: -# ```python -# lambdas = [] -# for i in range(10): -# lambdas.append(partial(lambda i: i, i)) -# for fn in lambdas: -# print(fn()) # result is 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 -# ``` - # Constant for unary_fn in UNARY_OPS: for magic_method in magic_method_builtin_dispatch(unary_fn): diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/__init__.py b/python/paddle/jit/sot/opcode_translator/executor/variables/__init__.py index 5aedaaac262eda..90a926e037354e 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variables/__init__.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/__init__.py @@ -23,6 +23,7 @@ ConstantVariable, DataVariable, DygraphTracerVariable, + ExceptionVariable, FunctionGlobalVariable, GlobalVariable, ModuleVariable, diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py index 43c09acbdd9010..b35e5248959995 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py @@ -2028,3 +2028,201 @@ def proxy_getter(self, proxy: MutableDictLikeData, key: Any): self.graph, tracker=FunctionGlobalTracker(self.fn, key), ) + + +class ExceptionVariable(VariableBase): + # The ExceptionVariable corresponds to the Exception class in Python + mutable_attrs = [ + "__context__", + "__cause__", + "__suppress_context__", + "__traceback__", + ] + + def __init__( + self, + exc: Exception | type[Exception], + *args, + graph: FunctionGraph = None, + tracker: Tracker = None, + ) -> None: + super().__init__(graph=graph, tracker=tracker) + + self.record_exception = False + if isinstance(exc, Exception): + exc_type = exc.__class__ + self.record_exception = True + + elif isinstance(exc, type) and issubclass(exc, Exception): + exc_type = exc + + else: + # TODO(DrRyanHuang): Should `exc_type` be a `BuiltinVariable`? + raise InnerError( + f"ExceptionVariable parameter `exc` should be an Exception class or instance, but got `{type(exc)}`:`{exc}`" + ) + py_args = [] + for arg in args: + # TODO(DrRyanHuang): Should `args` be a tuple containing exclusively `VariableBase`? + if not isinstance(arg, VariableBase): + raise InnerError( + f"ExceptionVariable parameter `args` be a tuple containing exclusively `VariableBase`, but got `{type(arg)}`:`{arg}`" + ) + py_args.append(arg.get_py_value()) + + self.exc = exc if self.record_exception else exc(*py_args) + self.exc_type = exc_type + self.args = args + + self.__context__ = VariableFactory.from_value( + self.exc.__context__, graph=graph, tracker=tracker + ) + + # raise ... from ... + self.__cause__ = VariableFactory.from_value( + self.exc.__cause__, graph=graph, tracker=tracker + ) + + self.__suppress_context__ = VariableFactory.from_value( + self.exc.__suppress_context__, graph=graph, tracker=tracker + ) + + # NOTE: Currently, since our primary goal is to trace the network structure of variables, + # __traceback__ is always set to None. + self.__traceback__ = ConstantVariable.wrap_literal(None, self.graph) + + self.graph.side_effects.record_mutable_variable(self) + + def get_py_type(self): + return self.exc_type + + def get_py_value(self): + if self.record_exception: + exception = self.exc + else: + exception = self.exc_type( + *[arg.get_py_value() for arg in self.args] + ) + + exception.__context__ = ( + exception.__context__ or self.__context__.get_py_value() + ) + exception.__cause__ = ( + exception.__cause__ or self.__cause__.get_py_value() + ) + exception.__suppress_context__ = exception.__suppress_context__ or ( + self.__suppress_context__.get_py_value() + ) + return exception + + @property + def main_info(self) -> dict[str, Any]: + return { + "exception_cls": self.exc_type, + } + + def setattr(self, key: str, value): + # TODO(DrRyanHuang): Add UserDefinedException to __context__ and __cause__ + # TODO(DrRyanHuang): Do users also manually set exception attributes, and should we change FallbackError/InnerError to TypeError? + if key == "__context__": + if ( + isinstance(value, ConstantVariable) + and value.get_py_value() is None + ) or isinstance( + value, + (ExceptionVariable), + ): + self.__context__ = value + else: + raise FallbackError( + f"`__context__` must be an ExceptionVariable, bug got {type(value)}:{value}" + ) + elif key == "__cause__": + if ( + isinstance(value, ConstantVariable) + and value.get_py_value() is None + ) or isinstance( + value, + (ExceptionVariable), + ): + self.__cause__ = value + self.__suppress_context__ = ConstantVariable.wrap_literal( + True, self.graph + ) + else: + raise FallbackError( + "exception cause must be None or derive from BaseException" + ) + elif key == "__suppress_context__": + if isinstance(value, ConstantVariable) and value.get_py_value() in ( + True, + False, + ): + self.__suppress_context__ = value + else: + raise FallbackError("Type of __suppress_context__ must be bool") + elif key == "__traceback__": + if ( + isinstance(value, ConstantVariable) + and value.get_py_value() is None + ): + self.__traceback__ = value + else: + raise FallbackError( + "Currently, SOT doesn't record information of __traceback__" + ) + else: + raise InnerError(f"ExceptionVariable don't need attribute {key}") + + def getattr(self, name: str, default=None) -> VariableBase: + + if name == "__traceback__": + return ConstantVariable.wrap_literal(None, self.graph) + + if name == "args": + from .container import ListVariable + + return ListVariable( + self.args, self.graph, GetAttrTracker(self, "args") + ) + + return super().getattr(name, default) + + def __str__(self): + return f"{self.__class__.__name__}({self.exc_type})" + + def __repr__(self): + return self.__str__() + + @VariableFactory.register_from_value() + def from_value(value: Exception, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, Exception): + args = [ + ConstantVariable.wrap_literal(arg, graph) for arg in value.args + ] + exception_var = ExceptionVariable( + value.__class__, *args, graph=graph, tracker=tracker + ) + if value.__context__ is not None: + exception_var.setattr( + "__context__", + VariableFactory.from_value( + value.__context__, graph=graph, tracker=tracker + ), + ) + if value.__cause__ is not None: + exception_var.setattr( + "__cause__", + VariableFactory.from_value( + value.__cause__, graph=graph, tracker=tracker + ), + ) + exception_var.setattr( + "__suppress_context__", + ConstantVariable.wrap_literal( + value.__suppress_context__, graph + ), + ) + + return exception_var + return None diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py b/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py index edc7af65e272e0..63dd8ecec03c2e 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py @@ -59,6 +59,8 @@ InnerError, OtherInlineCallBreak, PsdbBreakReason, + SotCapturedException, + SotCapturedExceptionFactory, SotErrorBase, UnsupportedOperationBreak, UnsupportedPaddleAPIBreak, @@ -269,6 +271,8 @@ def call_function(self, /, *args, **kwargs) -> VariableBase: f"Inline Call: {inline_executor.vframe.code.co_name}, file {inline_executor.vframe.code.co_filename}, line {int(inline_executor.vframe.code.co_firstlineno)}" ): output = inline_executor.inline_call() + except (SotCapturedException, InnerError) as e: + raise e except SotErrorBase as error: self.graph.restore_memo(checkpoint) filename = self.value.__code__.co_filename @@ -783,7 +787,14 @@ def call_function(self, /, *args, **kwargs): handler = Dispatcher.dispatch(self.value, *args, **kwargs) if handler is not None: - return handler(*args, **kwargs) + try: + return handler(*args, **kwargs) + except SotErrorBase as e: + # NOTE: BuiltinVariable.call_function cat not raise SotCapturedException, + # so we can directly raise SotErrorBase. + raise + except Exception as e: + raise SotCapturedExceptionFactory.create(origin_exc=e) from e if ENV_SOT_ALLOW_DYNAMIC_SHAPE.get() and any( isinstance(var, SymbolicVariable) diff --git a/test/sot/test_24_exceptions.py b/test/sot/test_24_exceptions.py new file mode 100644 index 00000000000000..5ffb75cadad081 --- /dev/null +++ b/test/sot/test_24_exceptions.py @@ -0,0 +1,217 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import unittest + +from test_case_base import ( + TestCaseBase, +) + +import paddle +from paddle.jit.sot.psdb import check_no_breakgraph +from paddle.jit.sot.utils import strict_mode_guard + +NOT_ALLOW_FALLBACK = sys.version_info < (3, 11) + + +class TestNestingCase(TestCaseBase): + @strict_mode_guard(NOT_ALLOW_FALLBACK) + @check_no_breakgraph + def test_try_nesting(self): + def try_nesting_wo_error(x): + try: + try: + try: + try: + try: + try: + x -= 1 + raise ValueError( + "TESTING" + ) # RAISE_VARARGS(1) + x += 2 + except NotImplementedError: + x /= 3 + raise # RAISE_VARARGS(0) + except (KeyError, IndexError): + x *= 4 + pass + except ValueError: + x += 5 + raise NameError # RAISE_VARARGS(1) + except SyntaxError: + x /= 6 + pass + except (TypeError, FileNotFoundError, NameError) as e: + x -= 7 + raise TimeoutError( + "TESTING" + ) from e # RAISE_VARARGS(2) + except: + x /= 8 + raise AssertionError + except IndentationError as e: + x *= 9 + pass + except AssertionError as e: + x += 10 + raise # RAISE_VARARGS(0) + except Exception as e: + x /= 11 + + return x + 12 + + self.assert_results(try_nesting_wo_error, paddle.to_tensor(0.5)) + + @strict_mode_guard(NOT_ALLOW_FALLBACK) + @check_no_breakgraph + def test_function_nesting(self): + def raise_value_error_obj(x): + x += 1 + raise ValueError("") + + def raise_value_error_cls(x): + x += 2 + raise ValueError + + def raise_zero_div_error(x): + x += 3 + return 19.0 / 0 + + def raise_assert_error(x): + x += 4 + assert [] + + def raise_not_implemented_error(x): + x += 5 + raise NotImplementedError + + def one_nesting(x, func): + x *= 6 + func(x) + + def two_nesting(x, func): + x /= 7 + one_nesting(x, func) + + def three_nesting(x, func): + x -= 8 + two_nesting(x, func) + + def get_test_func(x, func=None): + try: + x += 1 + try: + x /= 2 + three_nesting(x, func) + x -= 3 + except ValueError: + x *= 4 + except: + x += 5 + return x # / 6 + + self.assert_results( + get_test_func, paddle.to_tensor(0.3), raise_value_error_obj + ) + self.assert_results( + get_test_func, paddle.to_tensor(0.4), raise_value_error_cls + ) + self.assert_results( + get_test_func, paddle.to_tensor(0.5), raise_zero_div_error + ) + self.assert_results( + get_test_func, paddle.to_tensor(0.6), raise_assert_error + ) + self.assert_results( + get_test_func, paddle.to_tensor(0.7), raise_not_implemented_error + ) + + +class TestAssertException(TestCaseBase): + @staticmethod + def try_assert(x, condition): + # test py value or paddle tensor value as condition + try: + x += 1 + try: + x /= 2 + raise TimeoutError("TESTING") + except: + x -= 3 + assert condition + except: + x *= 4 + return x / 5 + + @strict_mode_guard(NOT_ALLOW_FALLBACK) + def test_assert_with_py_var_as_condition(self): + # Test the case where `condition` is Python variable + self.assert_results(self.try_assert, paddle.to_tensor(1), False) + self.assert_results(self.try_assert, paddle.to_tensor(1), True) + self.assert_results(self.try_assert, paddle.to_tensor(1), []) + self.assert_results(self.try_assert, paddle.to_tensor(1), [1]) + self.assert_results(self.try_assert, paddle.to_tensor(1), "") + self.assert_results(self.try_assert, paddle.to_tensor(1), "QAQ") + self.assert_results(self.try_assert, paddle.to_tensor(1), ValueError) + self.assert_results(self.try_assert, paddle.to_tensor(1), ValueError()) + + # Currently, since the assert statement is essentially an if statement and can cause breakgraph, + # using a Tensor as a condition is not supported. Therefore, fallback is allowed. + @strict_mode_guard(False) + def test_assert_with_tensor_as_condition(self): + # Test the case where `condition` is Paddle Tensor + self.assert_results( + self.try_assert, paddle.to_tensor(2), paddle.to_tensor(1) + ) + self.assert_results( + self.try_assert, paddle.to_tensor(2), paddle.to_tensor(0) + ) + self.assert_results( + self.try_assert, paddle.to_tensor(2), paddle.to_tensor(-1) + ) + + @strict_mode_guard(False) + def test_assert_true(self): + @check_no_breakgraph + def try_assert_except(x): + x += 1 + try: + x += 2 + assert x > -10000 + x += 3 + except: + x += 4 + pass + + self.assert_results(try_assert_except, paddle.to_tensor(10)) + + @strict_mode_guard(False) + def test_assert_false(self): + @check_no_breakgraph + def try_assert_except(x): + try: + x += 5 + assert x < -10000 + except AssertionError: + x += 6 + pass + + return x + + self.assert_results(try_assert_except, paddle.to_tensor(10)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_sot_exception.py b/test/sot/test_sot_exception.py index 21594ea3bab188..d9407df2c0a621 100644 --- a/test/sot/test_sot_exception.py +++ b/test/sot/test_sot_exception.py @@ -55,7 +55,7 @@ def case5_inner3(x): def case5_inner2(x): x += 1 - z = case5_inner3(1 / 0) + z = case5_inner3(y) # noqa: F821 return z + 1 From 0aa572d4f74ab9e8a3fe4f8f9d4a6341e5a64bd5 Mon Sep 17 00:00:00 2001 From: DrRyanHuang Date: Tue, 6 May 2025 11:44:11 +0000 Subject: [PATCH 2/4] add py3.8 adaptation && add 2 dispatcher --- .../executor/opcode_executor.py | 47 +++-------- .../executor/variable_dispatch.py | 30 +++++++ .../executor/variables/basic.py | 78 ++++++++++++++++--- test/sot/test_24_exceptions.py | 28 +++---- 4 files changed, 121 insertions(+), 62 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index d9c4dcd10b0ce1..73afcc4cf2fe8a 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -138,6 +138,13 @@ # In Python 3.13, the method layout is changed, and a NULL will be pushed after the value. CALL_METHOD_LAYOUT_NULL_AFTER_VALUE = sys.version_info >= (3, 13) +ALREADY_SUPPORTED_EXCEPTION = sys.version_info >= ( + 3, + 9, +) and sys.version_info < ( + 3, + 11, +) @dataclass @@ -2128,44 +2135,14 @@ def RAISE_VARARGS(self, instr: Instruction): origin_exc=val.get_py_value() ) - def check_if_exception_matches(self): + @fallback_if_python_version_unsupported + def JUMP_IF_NOT_EXC_MATCH(self, instr: Instruction): assert len(self.stack) >= 2 expected_exc_types = self.stack.pop() exc_instance = self.stack.pop() - - if isinstance(expected_exc_types, TupleVariable): - expected_types = expected_exc_types.get_wrapped_items() - else: - expected_types = [ - expected_exc_types, - ] - - for expected_type in expected_types: - if not isinstance(expected_type, BuiltinVariable): - raise FallbackError( - f"`except ...` requires a BuiltinVariable as the exception type, but received: {expected_type}." - ) - - # Exception -> SotCapturedException - expected_type_exception = SotCapturedExceptionFactory.get( - expected_type.get_py_value() - ) - - if self._is_exception_isinstance(exc_instance) and issubclass( - exc_instance.exc_type, - expected_type_exception, - ): - return True - elif isinstance(exc_instance, BuiltinVariable) and issubclass( - exc_instance.get_py_value(), expected_type_exception - ): - return True - - return False - - @fallback_if_python_version_unsupported - def JUMP_IF_NOT_EXC_MATCH(self, instr: Instruction): - if not self.check_if_exception_matches(): + if not ExceptionVariable.check_if_exception_matches( + exc_instance, expected_exc_types + ): self.jump_to(instr.jump_to) @fallback_if_python_version_unsupported diff --git a/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py b/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py index f45a8ab9daae47..378b3d07c51284 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py @@ -50,6 +50,7 @@ from .dispatch_functions import ( create_raise_break_graph_handler, generator_send, + operator_exception_match, operator_in, operator_is_none, operator_is_not_none, @@ -1524,6 +1525,23 @@ def numpy_binary_dispatcher(var: NumpyVariable, other: NumpyVariable): lambda left, right: constant_numpy_equal(left, right), ) + +# `operator.eq` of `ExceptionVariable` dispatch +def exception_variable_equal(left, right): + result = (left is right) or (left.get_py_value() == right.get_py_value()) + return VariableFactory.from_value( + result, + left.graph, + tracker=DummyTracker([left, right]), + ) + + +Dispatcher.register( + operator.eq, + ("ExceptionVariable", "ExceptionVariable"), + lambda left, right: exception_variable_equal(left, right), +) + Dispatcher.register( bool, ("NumpyVariable",), @@ -1633,3 +1651,15 @@ def dispatch_all(var: ContainerVariable | IterVariable): not x.get_py_value(allow_tensor=False), x.graph, DummyTracker([x]) ), ) + + +Dispatcher.register( + operator_exception_match, + ("BuiltinVariable | ExceptionVariable", "BuiltinVariable | TupleVariable"), + lambda exc_instance, expected_exc_types: ConstantVariable.wrap_literal( + ExceptionVariable.check_if_exception_matches( + exc_instance, expected_exc_types + ), + exc_instance.graph, + ), +) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py index b35e5248959995..1b379132fb7cca 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py @@ -78,6 +78,7 @@ DataDependencyOperationBreak, FallbackError, NameGenerator, + SotCapturedExceptionFactory, UnsupportedOperationBreak, get_tensor_methods, log, @@ -124,7 +125,8 @@ from ..function_graph import FunctionGraph from ..pycode_generator import PyCodeGen - from .callable import ClassVariable, FunctionVariable + from .callable import BuiltinVariable, ClassVariable, FunctionVariable + from .container import TupleVariable SymbolicConstraint: TypeAlias = tuple[ ConstraintNode, dict[str, "SymbolicVariable"] @@ -2104,15 +2106,15 @@ def get_py_value(self): *[arg.get_py_value() for arg in self.args] ) - exception.__context__ = ( - exception.__context__ or self.__context__.get_py_value() - ) - exception.__cause__ = ( - exception.__cause__ or self.__cause__.get_py_value() - ) - exception.__suppress_context__ = exception.__suppress_context__ or ( - self.__suppress_context__.get_py_value() - ) + exception.__context__ = ( + exception.__context__ or self.__context__.get_py_value() + ) + exception.__cause__ = ( + exception.__cause__ or self.__cause__.get_py_value() + ) + exception.__suppress_context__ = exception.__suppress_context__ or ( + self.__suppress_context__.get_py_value() + ) return exception @property @@ -2175,7 +2177,6 @@ def setattr(self, key: str, value): raise InnerError(f"ExceptionVariable don't need attribute {key}") def getattr(self, name: str, default=None) -> VariableBase: - if name == "__traceback__": return ConstantVariable.wrap_literal(None, self.graph) @@ -2194,6 +2195,44 @@ def __str__(self): def __repr__(self): return self.__str__() + @classmethod + def check_if_exception_matches( + cls, + exc_instance: BuiltinVariable | ExceptionVariable, + expected_exc_types: BuiltinVariable | TupleVariable, + ): + """ + try: exc_instance except: expected_exc_types + """ + from .callable import BuiltinVariable + from .container import TupleVariable + + if isinstance(expected_exc_types, TupleVariable): + expected_types = expected_exc_types.get_wrapped_items() + else: + expected_types = [ + expected_exc_types, + ] + for expected_type in expected_types: + if not isinstance(expected_type, BuiltinVariable): + raise FallbackError( + f"`except ...` requires a BuiltinVariable as the exception type, but received: {expected_type}." + ) + # Exception -> SotCapturedException + expected_type_exception = SotCapturedExceptionFactory.get( + expected_type.get_py_value() + ) + if isinstance(exc_instance, ExceptionVariable) and issubclass( + exc_instance.exc_type, + expected_type_exception, + ): + return True + elif isinstance(exc_instance, BuiltinVariable) and issubclass( + exc_instance.get_py_value(), expected_type_exception + ): + return True + return False + @VariableFactory.register_from_value() def from_value(value: Exception, graph: FunctionGraph, tracker: Tracker): if isinstance(value, Exception): @@ -2226,3 +2265,20 @@ def from_value(value: Exception, graph: FunctionGraph, tracker: Tracker): return exception_var return None + + # def __eq__(self, other: ExceptionVariable) -> bool: + # if sys.version_info >= (3, 8) and sys.version_info < (3, 9): + # raise FallbackError("Python version >= 3.8 but < 3.9") + + # # `operator.eq` of `ExceptionVariable` dispatch + # def exception_variable_equal(left, right): + # result = (left is right) or ( + # left.get_py_value() == right.get_py_value() + # ) + # return VariableFactory.from_value( + # result, + # left.graph, + # tracker=DummyTracker([left, right]), + # ) + + # return exception_variable_equal(self, other) diff --git a/test/sot/test_24_exceptions.py b/test/sot/test_24_exceptions.py index 5ffb75cadad081..7da81cc1b58a32 100644 --- a/test/sot/test_24_exceptions.py +++ b/test/sot/test_24_exceptions.py @@ -22,7 +22,7 @@ from paddle.jit.sot.psdb import check_no_breakgraph from paddle.jit.sot.utils import strict_mode_guard -NOT_ALLOW_FALLBACK = sys.version_info < (3, 11) +NOT_ALLOW_FALLBACK = sys.version_info < (3, 11) and sys.version_info >= (3, 9) class TestNestingCase(TestCaseBase): @@ -46,13 +46,11 @@ def try_nesting_wo_error(x): raise # RAISE_VARARGS(0) except (KeyError, IndexError): x *= 4 - pass except ValueError: x += 5 raise NameError # RAISE_VARARGS(1) except SyntaxError: x /= 6 - pass except (TypeError, FileNotFoundError, NameError) as e: x -= 7 raise TimeoutError( @@ -63,7 +61,6 @@ def try_nesting_wo_error(x): raise AssertionError except IndentationError as e: x *= 9 - pass except AssertionError as e: x += 10 raise # RAISE_VARARGS(0) @@ -159,13 +156,14 @@ def try_assert(x, condition): def test_assert_with_py_var_as_condition(self): # Test the case where `condition` is Python variable self.assert_results(self.try_assert, paddle.to_tensor(1), False) - self.assert_results(self.try_assert, paddle.to_tensor(1), True) - self.assert_results(self.try_assert, paddle.to_tensor(1), []) - self.assert_results(self.try_assert, paddle.to_tensor(1), [1]) - self.assert_results(self.try_assert, paddle.to_tensor(1), "") - self.assert_results(self.try_assert, paddle.to_tensor(1), "QAQ") - self.assert_results(self.try_assert, paddle.to_tensor(1), ValueError) - self.assert_results(self.try_assert, paddle.to_tensor(1), ValueError()) + self.assert_results(self.try_assert, paddle.to_tensor(2), True) + self.assert_results(self.try_assert, paddle.to_tensor(3), []) + self.assert_results(self.try_assert, paddle.to_tensor(4), [1]) + self.assert_results(self.try_assert, paddle.to_tensor(5), "") + self.assert_results(self.try_assert, paddle.to_tensor(6), "QAQ") + # TODO(DrRyanHuang): The following two cases are not supported yet. + # self.assert_results(self.try_assert, paddle.to_tensor(7), ValueError) + # self.assert_results(self.try_assert, paddle.to_tensor(8), ValueError()) # Currently, since the assert statement is essentially an if statement and can cause breakgraph, # using a Tensor as a condition is not supported. Therefore, fallback is allowed. @@ -173,13 +171,13 @@ def test_assert_with_py_var_as_condition(self): def test_assert_with_tensor_as_condition(self): # Test the case where `condition` is Paddle Tensor self.assert_results( - self.try_assert, paddle.to_tensor(2), paddle.to_tensor(1) + self.try_assert, paddle.to_tensor(8), paddle.to_tensor(1) ) self.assert_results( - self.try_assert, paddle.to_tensor(2), paddle.to_tensor(0) + self.try_assert, paddle.to_tensor(9), paddle.to_tensor(0) ) self.assert_results( - self.try_assert, paddle.to_tensor(2), paddle.to_tensor(-1) + self.try_assert, paddle.to_tensor(10), paddle.to_tensor(-1) ) @strict_mode_guard(False) @@ -193,7 +191,6 @@ def try_assert_except(x): x += 3 except: x += 4 - pass self.assert_results(try_assert_except, paddle.to_tensor(10)) @@ -206,7 +203,6 @@ def try_assert_except(x): assert x < -10000 except AssertionError: x += 6 - pass return x From 63ca2ac062ed68eb2933b3d1278c4745da126d8c Mon Sep 17 00:00:00 2001 From: DrRyanHuang Date: Wed, 7 May 2025 06:08:42 +0000 Subject: [PATCH 3/4] fix HasNoAttributeError --- .../executor/variables/basic.py | 5 +- test/sot/test_23_super.py | 58 +++++++++++++++++++ 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py index fb7dbd2aedd40b..5d7739b1a2c97e 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py @@ -87,7 +87,6 @@ ) from ....utils.envs import ENV_SOT_BREAK_GRAPH_ON_GET_SYMBOLIC_VALUE from ....utils.exceptions import ( - HasNoAttributeError, InnerError, UnsupportedPaddleAPIBreak, ) @@ -1511,9 +1510,7 @@ def getattr(self, name: str, default=None) -> VariableBase: attr = attr.bind(self.obj, name) return attr - raise HasNoAttributeError( - f"{self.obj.__class__.__name__} {self} has no attribute {name}" - ) + return super().getattr(name) @VariableFactory.register_from_value() def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): diff --git a/test/sot/test_23_super.py b/test/sot/test_23_super.py index 5991a4004ee8f5..943e42ffed814a 100644 --- a/test/sot/test_23_super.py +++ b/test/sot/test_23_super.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import sys import types import unittest +import numpy as np from test_case_base import ( TestCaseBase, test_instruction_translator_cache_context, @@ -24,6 +26,11 @@ from paddle.jit.sot.psdb import check_no_breakgraph from paddle.jit.sot.utils.exceptions import InnerError +sot_test_dir = os.path.dirname(__file__) +sys.path.insert(0, os.path.abspath(f'{sot_test_dir}/../dygraph_to_static')) + +from dygraph_to_static_utils import Dy2StTestBase, test_default_mode_only + # ---------------------- test single inheritance case ---------------------- class A: @@ -364,5 +371,56 @@ def test_super_function_as_input(self): ) +# ------ test SuperVariable setattr + getattr ------ +class LrClassBase: + def __init__(self, last_lr, last_epoch): + self.last_lr = last_lr + self.last_epoch = last_epoch + self.step() + + def step(self): + self.last_epoch += 1 + self.last_lr = self.get_lr() + + def get_lr(self): + return self.last_lr + 0.0001 + + def __call__(self) -> float: + return self.last_lr + + +class LrClassSub(LrClassBase): + def __init__(self, **kwargs): + super().__init__( + kwargs.get("last_lr", 0.01), kwargs.get("last_epoch", 0) + ) + + +class TestSuperSetattrGetattr(Dy2StTestBase): + def setUp(self): + def dyfunc(lr_decay): + lr = lr_decay() + return paddle.to_tensor(lr) + + lr_decay = LrClassSub( + base_lr=0.1, verbose=True, last_lr=0.3, last_epoch=3 + ) + self.dygraph_func = lambda: dyfunc(lr_decay) + + def get_dygraph_output(self): + res = self.dygraph_func() + return res + + def get_static_output(self): + static_res = paddle.jit.to_static(self.dygraph_func)() + return static_res + + @test_default_mode_only + def test_transformed_static_result(self): + dygraph_res = self.get_dygraph_output() + static_res = self.get_static_output() + np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05) + + if __name__ == "__main__": unittest.main() From 71c4bcc389aceb5e4a87f28511498c2a34abfa1a Mon Sep 17 00:00:00 2001 From: DrRyanHuang Date: Wed, 7 May 2025 12:50:38 +0000 Subject: [PATCH 4/4] 2023 -> 2025 --- .../jit/sot/opcode_translator/executor/exception_stack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/jit/sot/opcode_translator/executor/exception_stack.py b/python/paddle/jit/sot/opcode_translator/executor/exception_stack.py index b0f5d01757307a..14c4f866261453 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/exception_stack.py +++ b/python/paddle/jit/sot/opcode_translator/executor/exception_stack.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.