Skip to content

Commit 21a4207

Browse files
SigureMoCopilot
andauthored
[SOT] Add new internal API paddle.jit.marker.unified to mark an API as unified in dynamic and static mode and support custom op (#72466)
--------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent b94ea21 commit 21a4207

File tree

12 files changed

+361
-117
lines changed

12 files changed

+361
-117
lines changed

python/paddle/jit/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,19 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
from . import marker as marker
1617
from .api import (
1718
ignore_module,
1819
json_to_pdmodel, # noqa: F401
1920
load,
20-
not_to_static,
2121
save,
2222
to_static,
2323
)
2424
from .dy2static.logging_utils import set_code_level, set_verbosity
2525
from .dy2static.program_translator import enable_to_static
26+
from .marker import (
27+
not_to_static,
28+
)
2629
from .translated_layer import TranslatedLayer
2730

2831
__all__ = [

python/paddle/jit/api.py

Lines changed: 1 addition & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
)
7070

7171
from .dy2static import logging_utils
72-
from .dy2static.convert_call_func import ConversionOptions, add_ignore_module
72+
from .dy2static.convert_call_func import add_ignore_module
7373
from .dy2static.program_translator import (
7474
ASTStaticFunction,
7575
ProgramTranslator,
@@ -346,69 +346,6 @@ def decorated(python_func):
346346
return decorated
347347

348348

349-
class _NotToStaticDecorator(Protocol):
350-
@overload
351-
def __call__(
352-
self, func: Callable[_InputT, _RetT]
353-
) -> Callable[_InputT, _RetT]: ...
354-
355-
@overload
356-
def __call__(self, func: None = ...) -> _NotToStaticDecorator: ...
357-
358-
359-
@overload
360-
def not_to_static(
361-
func: Callable[_InputT, _RetT],
362-
) -> Callable[_InputT, _RetT]: ...
363-
364-
365-
@overload
366-
def not_to_static(func: None = ...) -> _NotToStaticDecorator: ...
367-
368-
369-
def not_to_static(func=None):
370-
"""
371-
A Decorator to suppresses the convention of a function.
372-
373-
Args:
374-
func(callable): The function to decorate.
375-
376-
Returns:
377-
callable: A function which won't be converted in Dynamic-to-Static.
378-
379-
Examples:
380-
.. code-block:: python
381-
382-
>>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest')
383-
>>> import paddle
384-
385-
>>> @paddle.jit.not_to_static
386-
... def func_not_to_static(x):
387-
... res = x - 1
388-
... return res
389-
390-
>>> @paddle.jit.to_static
391-
... def func(x):
392-
... if paddle.mean(x) < 0:
393-
... out = func_not_to_static(x)
394-
... else:
395-
... out = x + 1
396-
... return out
397-
...
398-
>>> x = paddle.ones([1, 2], dtype='float32')
399-
>>> out = func(x)
400-
>>> print(out)
401-
Tensor(shape=[1, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
402-
[[2., 2.]])
403-
"""
404-
if func is None:
405-
return not_to_static
406-
407-
options = ConversionOptions(not_convert=True)
408-
options.attach(func)
409-
return func
410-
411-
412349
class _SaveLoadConfig:
413350
def __init__(self):
414351
self._output_spec = None

python/paddle/jit/dy2static/convert_call_func.py

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,16 @@
3737
)
3838
from .logging_utils import TranslatorLogger
3939
from .program_translator import (
40-
CONVERSION_OPTIONS,
4140
StaticFunction,
4241
convert_to_static,
4342
unwrap_decorators,
4443
)
45-
from .utils import is_builtin, is_paddle_func, patch_method_guard
44+
from .utils import (
45+
TransformOptions,
46+
is_builtin,
47+
is_paddle_func,
48+
patch_method_guard,
49+
)
4650

4751
if TYPE_CHECKING:
4852
from types import ModuleType
@@ -53,31 +57,6 @@
5357
translator_logger = TranslatorLogger()
5458

5559

56-
class ConversionOptions:
57-
"""
58-
A container for conversion flags of a function in dynamic-to-static.
59-
60-
Attributes:
61-
not_convert(bool): An attribute indicates that the function won't be converted in dynamic-to-static.
62-
63-
NOTE(liym27): More attributes and methods can be added in this class.
64-
"""
65-
66-
def __init__(self, not_convert=False):
67-
self.not_convert = not_convert
68-
69-
def attach(self, func):
70-
if inspect.ismethod(func):
71-
func = func.__func__
72-
73-
if inspect.isfunction(func):
74-
setattr(func, CONVERSION_OPTIONS, self)
75-
else:
76-
translator_logger.warn(
77-
f"Only support @not_to_static to type(function) or type(method), but received {type(func)}"
78-
)
79-
80-
8160
def builtin_modules():
8261
"""
8362
Return builtin modules.
@@ -259,8 +238,9 @@ def convert_call(func):
259238
# in this case, unwraps it into a raw method or function.
260239
_, func = unwrap_decorators(func)
261240

262-
options = getattr(func, CONVERSION_OPTIONS, None)
263-
if options is not None and options.not_convert:
241+
if not TransformOptions.check_fn_need_transform(
242+
func, TransformOptions.ToStaticMode.AST
243+
):
264244
translator_logger.log(
265245
2,
266246
"%s is not converted when it is decorated by 'paddle.jit.not_to_static'.",

python/paddle/jit/dy2static/program_translator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from .utils import (
6262
ALREADY_D2S,
6363
NO_SHAPE_VAR_TYPE,
64+
TransformOptions,
6465
ast_to_func,
6566
backend_guard,
6667
cuda_pinned_tensors_move_to_excepted_place,
@@ -86,8 +87,6 @@
8687
# Once exceeding the threshold, we will raise warning to users to make sure the conversion is as expected.
8788
MAX_TRACED_PROGRAM_COUNT = 10
8889

89-
CONVERSION_OPTIONS = "__jst_not_to_static"
90-
9190

9291
def synchronized(func):
9392
func.__lock__ = threading.Lock()
@@ -252,12 +251,13 @@ def convert_to_static(function):
252251
if getattr(function, ALREADY_D2S, None):
253252
return function
254253

255-
# Return directly if decorated with @not_to_static and DO NOT Cache it
256-
options = getattr(function, CONVERSION_OPTIONS, None)
254+
# Return directly if decorated with @jit.marker.unified and DO NOT Cache it
257255
# or ignore paddle api
258-
need_skip = (options is not None and options.not_convert) or is_paddle_func(
259-
function
260-
)
256+
need_skip = (
257+
not TransformOptions.check_fn_need_transform(
258+
function, TransformOptions.ToStaticMode.AST
259+
)
260+
) or is_paddle_func(function)
261261
if need_skip:
262262
return function.__func__ if inspect.ismethod(function) else function
263263

python/paddle/jit/dy2static/utils.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import types
3030
import warnings
3131
from contextlib import contextmanager
32-
from enum import Enum, auto
32+
from enum import Enum, Flag, auto
3333
from importlib.machinery import SourceFileLoader
3434
from typing import TYPE_CHECKING, Any
3535

@@ -115,6 +115,44 @@ def is_phi(self):
115115
return self == Backend.PHI
116116

117117

118+
class TransformOptions:
119+
120+
class ToStaticMode(Flag):
121+
SOT = auto()
122+
AST = auto()
123+
124+
@classmethod
125+
def Nil(cls):
126+
return cls(0)
127+
128+
TRANSFORM_OPTIONS_ATTR_NAME = "___jit_transform_options___"
129+
130+
def __init__(self, skip_transform_mode: ToStaticMode = ToStaticMode.Nil()):
131+
self.skip_transform_mode = skip_transform_mode
132+
133+
def attach(self, fn):
134+
if inspect.ismethod(fn):
135+
fn = fn.__func__
136+
137+
if inspect.isfunction(fn):
138+
setattr(fn, TransformOptions.TRANSFORM_OPTIONS_ATTR_NAME, self)
139+
else:
140+
warnings.warn(
141+
f"Only support @jit.marker.unified to type(function) or type(method), but received {type(fn)}"
142+
)
143+
144+
def need_transform(self, mode: ToStaticMode):
145+
return not (self.skip_transform_mode & mode)
146+
147+
@staticmethod
148+
def check_fn_need_transform(fn, mode: ToStaticMode):
149+
if not hasattr(fn, TransformOptions.TRANSFORM_OPTIONS_ATTR_NAME):
150+
return True
151+
return getattr(
152+
fn, TransformOptions.TRANSFORM_OPTIONS_ATTR_NAME
153+
).need_transform(mode)
154+
155+
118156
class TimeCounter:
119157
def __init__(self):
120158
self._time_history: list[float] = []

python/paddle/jit/marker.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import (
18+
Callable,
19+
Protocol,
20+
TypeVar,
21+
overload,
22+
)
23+
24+
from typing_extensions import (
25+
ParamSpec,
26+
)
27+
28+
from .dy2static.utils import (
29+
TransformOptions,
30+
)
31+
32+
_RetT = TypeVar("_RetT")
33+
_InputT = ParamSpec("_InputT")
34+
35+
36+
class _NotToStaticDecorator(Protocol):
37+
@overload
38+
def __call__(
39+
self, func: Callable[_InputT, _RetT]
40+
) -> Callable[_InputT, _RetT]: ...
41+
42+
@overload
43+
def __call__(self, func: None = ...) -> _NotToStaticDecorator: ...
44+
45+
46+
@overload
47+
def not_to_static(
48+
func: Callable[_InputT, _RetT],
49+
) -> Callable[_InputT, _RetT]: ...
50+
51+
52+
@overload
53+
def not_to_static(func: None = ...) -> _NotToStaticDecorator: ...
54+
55+
56+
# Legacy decorator only for AST
57+
def not_to_static(func=None):
58+
"""
59+
A Decorator to suppresses the convention of a function.
60+
61+
Args:
62+
func(callable): The function to decorate.
63+
64+
Returns:
65+
callable: A function which won't be converted in Dynamic-to-Static.
66+
67+
Examples:
68+
.. code-block:: python
69+
70+
>>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest')
71+
>>> import paddle
72+
73+
>>> @paddle.jit.not_to_static
74+
... def func_not_to_static(x):
75+
... res = x - 1
76+
... return res
77+
78+
>>> @paddle.jit.to_static
79+
... def func(x):
80+
... if paddle.mean(x) < 0:
81+
... out = func_not_to_static(x)
82+
... else:
83+
... out = x + 1
84+
... return out
85+
...
86+
>>> x = paddle.ones([1, 2], dtype='float32')
87+
>>> out = func(x)
88+
>>> print(out)
89+
Tensor(shape=[1, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
90+
[[2., 2.]])
91+
"""
92+
return unified(func, for_sot=False, for_ast=True)
93+
94+
95+
def unified(
96+
fn: Callable[_InputT, _RetT] | None = None,
97+
*,
98+
for_sot: bool = True,
99+
for_ast: bool = True,
100+
) -> Callable[_InputT, _RetT]:
101+
"""
102+
Mark a function already unified in dygraph and static mode. So
103+
that it won't be transformed again in SOT or AST mode.
104+
105+
Args:
106+
fn(callable): The function to decorate.
107+
for_sot(bool): Whether to mark the function as unified in SOT mode.
108+
for_ast(bool): Whether to mark the function as unified in AST mode.
109+
"""
110+
111+
def _mark_as_unified(fn, *, for_sot: bool, for_ast: bool):
112+
mode = TransformOptions.ToStaticMode.Nil()
113+
if for_sot:
114+
mode |= TransformOptions.ToStaticMode.SOT
115+
if for_ast:
116+
mode |= TransformOptions.ToStaticMode.AST
117+
options = TransformOptions(
118+
skip_transform_mode=mode,
119+
)
120+
options.attach(fn)
121+
return fn
122+
123+
if fn is None:
124+
return lambda fn: _mark_as_unified(fn, for_sot=for_sot, for_ast=for_ast)
125+
return _mark_as_unified(fn, for_sot=for_sot, for_ast=for_ast)

python/paddle/jit/sot/opcode_translator/executor/function_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@
5454
NameGenerator,
5555
SIRToCodeMap,
5656
SotUndefinedVar,
57+
already_unified_in_dynamic_and_static_graph,
5758
inner_error_default_handler,
5859
is_inplace_api,
59-
is_paddle_api,
6060
log,
6161
log_do,
6262
map_if,
@@ -549,7 +549,7 @@ def call_paddle_api(
549549
Args:
550550
func: paddle api
551551
"""
552-
assert is_paddle_api(func)
552+
assert already_unified_in_dynamic_and_static_graph(func)
553553
log(3, f"call paddle.api : {func.__name__}", "\n")
554554

555555
def message_handler(*args, **kwargs):

0 commit comments

Comments
 (0)