From 2d63858cc28cf4f812ae81c2ee047a274037e43f Mon Sep 17 00:00:00 2001 From: SigureMo Date: Wed, 19 Mar 2025 14:50:36 +0800 Subject: [PATCH] [SOT] Add tracker for generator in Python 3.11+ --- .../opcode_translator/executor/opcode_inline_executor.py | 5 +++-- .../sot/opcode_translator/executor/variables/callable.py | 7 ++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py index 710be3a2249c0..074b6327d578a 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py @@ -28,7 +28,7 @@ from .dispatch_functions import generator_send from .guard import StringifiedExpression, union_free_vars from .opcode_executor import OpcodeExecutorBase, Stop -from .tracker import DanglingTracker, DummyTracker, Tracker +from .tracker import DanglingTracker, Tracker from .variables import ( BuiltinVariable, ConstantVariable, @@ -245,8 +245,9 @@ def inline_call(self) -> VariableBase: def RETURN_GENERATOR(self, instr: Instruction): vframe = self.vframe code_var = self._code_var + # NOTE: we set the real tracker in calling function self.return_value = GeneratorVariable( - code_var, vframe, self._graph, DummyTracker([]) # TODO: Add tracker + code_var, vframe, self._graph, DanglingTracker() ) return Stop(state="Return") diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py b/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py index c5b26973d6f6f..4015055d1cd7c 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py @@ -889,7 +889,12 @@ def call_function(self, /, *args, **kwargs): inline_gen_executor = OpcodeInlineGeneratorExecutor( vframe, code_var, self.graph ) - return inline_gen_executor.inline_call() + gen = inline_gen_executor.inline_call() + assert isinstance( + gen, GeneratorVariable + ), f"GeneratorFunction calling result should be GeneratorVariable, but got {type(gen)}" + gen.tracker = DummyTracker([self, *args, *kwargs.values()]) + return gen return GeneratorVariable( code_var, vframe,