Skip to content

Commit da36481

Browse files
authored
[SOT] Add tracker for generator in Python 3.11+ (#71777)
1 parent 4830713 commit da36481

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from .dispatch_functions import generator_send
2929
from .guard import StringifiedExpression, union_free_vars
3030
from .opcode_executor import OpcodeExecutorBase, Stop
31-
from .tracker import DanglingTracker, DummyTracker, Tracker
31+
from .tracker import DanglingTracker, Tracker
3232
from .variables import (
3333
BuiltinVariable,
3434
ConstantVariable,
@@ -245,8 +245,9 @@ def inline_call(self) -> VariableBase:
245245
def RETURN_GENERATOR(self, instr: Instruction):
246246
vframe = self.vframe
247247
code_var = self._code_var
248+
# NOTE: we set the real tracker in calling function
248249
self.return_value = GeneratorVariable(
249-
code_var, vframe, self._graph, DummyTracker([]) # TODO: Add tracker
250+
code_var, vframe, self._graph, DanglingTracker()
250251
)
251252
return Stop(state="Return")
252253

python/paddle/jit/sot/opcode_translator/executor/variables/callable.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,12 @@ def call_function(self, /, *args, **kwargs):
891891
inline_gen_executor = OpcodeInlineGeneratorExecutor(
892892
vframe, code_var, self.graph
893893
)
894-
return inline_gen_executor.inline_call()
894+
gen = inline_gen_executor.inline_call()
895+
assert isinstance(
896+
gen, GeneratorVariable
897+
), f"GeneratorFunction calling result should be GeneratorVariable, but got {type(gen)}"
898+
gen.tracker = DummyTracker([self, *args, *kwargs.values()])
899+
return gen
895900
return GeneratorVariable(
896901
code_var,
897902
vframe,

0 commit comments

Comments
 (0)