From 32549781fb08bb34d4cf887032ae4424436f9036 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Mon, 25 Nov 2024 13:51:25 +0000 Subject: [PATCH] [SOT] Store all loop body inputs when for breakgraph --- .../opcode_translator/executor/opcode_executor.py | 4 +++- test/sot/test_12_for_loop.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index cbd41e38f1c68..ec91fd85dc278 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -2400,7 +2400,9 @@ def create_after_loop_fn(): self._graph.pycode_gen.gen_store(name, self._code) # 5. compile sub graph before for-loop - update_names = list(loop_body_read_names | after_loop_read_names) + update_names = list( + OrderedSet(loop_body_inputs[:-1]) | after_loop_read_names + ) extra_store_vars = ( [iterator] if isinstance(iterator, IterVariable) diff --git a/test/sot/test_12_for_loop.py b/test/sot/test_12_for_loop.py index 1ececd87d434f..d16a173ffda3f 100644 --- a/test/sot/test_12_for_loop.py +++ b/test/sot/test_12_for_loop.py @@ -342,5 +342,19 @@ def test_for_break_with_load_same_consts(self): self.assert_results(for_break_with_load_same_consts, x) +def for_break_with_write_pre_defined_name(x: paddle.Tensor): + y = None + for i in [1, 2, 3]: + y = i + sot.psdb.breakgraph() + return x + 1 + + +class TestForBreakWithWritePreDefinedName(TestCaseBase): + def test_for_break_with_write_pre_defined_name(self): + x = paddle.to_tensor(1) + self.assert_results(for_break_with_write_pre_defined_name, x) + + if __name__ == "__main__": unittest.main()