Skip to content

[SOT] Support recursive fallback in psdb and add ut for psdb #72381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def __init__(
def handle_psdb_function(self, /, *args, **kwargs):
# special function for inner debug.
if self.value is psdb.assert_true:
self.graph.add_global_guarded_variable(args[0])
return ConstantVariable.wrap_literal(
self.value(args[0].value), self.graph
)
Expand All @@ -220,7 +221,19 @@ def handle_psdb_function(self, /, *args, **kwargs):
PsdbBreakReason("breakgraph by psdb.breakgraph")
)
elif self.value is psdb.fallback:
raise FallbackError("fallback by psdb.fallback")
fallback_sig = inspect.signature(psdb.fallback)
bound_args = fallback_sig.bind(*args, **kwargs)
bound_args.apply_defaults()
recursive_var = VariableFactory.from_value(
bound_args.arguments["recursive"],
graph=self.graph,
tracker=DanglingTracker(),
)
assert isinstance(recursive_var, ConstantVariable)
raise FallbackError(
f"Fallback by psdb.fallback (recursive={recursive_var.get_py_value()})",
disable_eval_frame=recursive_var.get_py_value(),
)
elif self.value is psdb.in_sot:
return ConstantVariable.wrap_literal(True, self.graph)
return None
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/jit/sot/psdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def check_no_fallback(fn: Callable[P, T]) -> Callable[P, T]:
return fn


def fallback():
def fallback(recursive=False):
pass


Expand Down
120 changes: 120 additions & 0 deletions test/sot/test_psdb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

from test_case_base import test_instruction_translator_cache_context

import paddle
from paddle.jit.sot import (
psdb,
symbolic_translate,
)
from paddle.jit.sot.utils.envs import strict_mode_guard
from paddle.jit.sot.utils.exceptions import InnerError


def assert_true_case(input: bool):
psdb.assert_true(input)


def breakgraph_case(x):
x = x + 1
psdb.breakgraph()
x = x + 1
return x


def fallback_not_recursive_inner(x):
x = x + 1
x = x + 1
return x


def fallback_not_recursive_case(x):
psdb.fallback(recursive=False)
x = fallback_not_recursive_inner(x)
return x


def fallback_recursive_case(x):
psdb.fallback(recursive=True)
x = fallback_not_recursive_inner(x)
return x


@psdb.check_no_breakgraph
def check_no_breakgraph_case(x):
x = x + 1
psdb.breakgraph()
x = x + 1
return x


@psdb.check_no_fallback
def check_no_fallback_case(x):
x = x + 1
psdb.fallback(recursive=False)
x = x + 1
return x


class TestPsdb(unittest.TestCase):
def test_assert_true(self):
# Test with True
symbolic_translate(assert_true_case)(True)

# Test with False
with self.assertRaises(InnerError):
symbolic_translate(assert_true_case)(False)

def test_breakgraph(self):
x = paddle.to_tensor([1.0])
with test_instruction_translator_cache_context() as ctx:
self.assertEqual(ctx.translate_count, 0)
symbolic_translate(breakgraph_case)(x)
self.assertEqual(ctx.translate_count, 2)

@strict_mode_guard(False)
def test_fallback_not_recursive(self):
x = paddle.to_tensor([1.0])
with test_instruction_translator_cache_context() as ctx:
self.assertEqual(ctx.translate_count, 0)
symbolic_translate(fallback_not_recursive_case)(x)
self.assertEqual(ctx.translate_count, 2)

@strict_mode_guard(False)
def test_fallback_recursive(self):
x = paddle.to_tensor([1.0])
with test_instruction_translator_cache_context() as ctx:
self.assertEqual(ctx.translate_count, 0)
symbolic_translate(fallback_recursive_case)(x)
self.assertEqual(ctx.translate_count, 1)

def test_check_no_breakgraph(self):
x = paddle.to_tensor([1.0])
with self.assertRaises(InnerError):
symbolic_translate(check_no_breakgraph_case)(x)

@strict_mode_guard(False)
def test_check_no_fallback(self):
x = paddle.to_tensor([1.0])
with self.assertRaises(InnerError):
symbolic_translate(check_no_fallback_case)(x)


if __name__ == "__main__":
unittest.main()
Loading