diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index cbd41e38f1c68b..ec91fd85dc2781 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -2400,7 +2400,9 @@ def create_after_loop_fn(): self._graph.pycode_gen.gen_store(name, self._code) # 5. compile sub graph before for-loop - update_names = list(loop_body_read_names | after_loop_read_names) + update_names = list( + OrderedSet(loop_body_inputs[:-1]) | after_loop_read_names + ) extra_store_vars = ( [iterator] if isinstance(iterator, IterVariable) diff --git a/test/sot/test_12_for_loop.py b/test/sot/test_12_for_loop.py index 1ececd87d434f0..d16a173ffda3f8 100644 --- a/test/sot/test_12_for_loop.py +++ b/test/sot/test_12_for_loop.py @@ -342,5 +342,19 @@ def test_for_break_with_load_same_consts(self): self.assert_results(for_break_with_load_same_consts, x) +def for_break_with_write_pre_defined_name(x: paddle.Tensor): + y = None + for i in [1, 2, 3]: + y = i + sot.psdb.breakgraph() + return x + 1 + + +class TestForBreakWithWritePreDefinedName(TestCaseBase): + def test_for_break_with_write_pre_defined_name(self): + x = paddle.to_tensor(1) + self.assert_results(for_break_with_write_pre_defined_name, x) + + if __name__ == "__main__": unittest.main()