Skip to content

Commit 1d1bebf

Browse files
authored
[cherry-pick 2.0-beta] Raise RuntimeError if run the callable object decorated by '@paddle.jit.to_static' not in dynamic mode. (#26750) (#27053)
Change-Id: I21a07cc2bc39acb753ab8fc00c72e269ddef0df1
1 parent a068168 commit 1d1bebf

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import gast
2626
from paddle.fluid import framework
27+
from paddle.fluid import in_dygraph_mode
2728
from paddle.fluid.dygraph import layers
2829
from paddle.fluid.data_feeder import check_type
2930
from paddle.fluid.layers.utils import flatten
@@ -32,6 +33,7 @@
3233
from paddle.fluid.dygraph.dygraph_to_static import DygraphToStaticAst
3334
from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA
3435
from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data
36+
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
3537
from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info
3638
from paddle.fluid.dygraph.dygraph_to_static.origin_info import create_and_update_origin_info_map
3739
from paddle.fluid.dygraph.dygraph_to_static.origin_info import update_op_callstack_with_origin_info
@@ -283,13 +285,21 @@ def __call__(self, *args, **kwargs):
283285
Return:
284286
Outputs of decorated function.
285287
"""
288+
286289
# 1. call dygraph function directly if not enable `declarative`
287290
if not self._program_trans.enable_declarative:
288-
warnings.warn(
289-
"The decorator '@paddle.jit.to_static' doesn't work when setting ProgramTranslator.enable=False. "
291+
logging_utils.warn(
292+
"The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable=False. "
290293
"We will just return dygraph output.")
291294
return self._call_dygraph_function(*args, **kwargs)
292295

296+
if not in_dygraph_mode() and self._program_trans.enable_declarative:
297+
raise RuntimeError(
298+
"Failed to run the callable object {} decorated by '@paddle.jit.to_static', "
299+
"because it does NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the "
300+
"following API: paddle.disable_static().".format(
301+
self.dygraph_function))
302+
293303
# 2. trace ops from dygraph layers and cache the generated program.
294304
args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs)
295305
try:

python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
# limitations under the License.
1414

1515
import numpy as np
16+
import unittest
17+
1618
import paddle
17-
from paddle.static import InputSpec
1819
import paddle.fluid as fluid
20+
from paddle.static import InputSpec
1921
from paddle.fluid.dygraph import to_variable, declarative, ProgramTranslator, Layer, jit
2022
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ConcreteProgram
2123

22-
import unittest
24+
from test_basic_api_transformation import dyfunc_to_variable
2325

2426
program_trans = ProgramTranslator()
2527

@@ -181,6 +183,9 @@ def foo_func(a, b, c=1, d=2):
181183

182184

183185
class TestDifferentInputSpecCacheProgram(unittest.TestCase):
186+
def setUp(self):
187+
program_trans.enable(True)
188+
184189
def test_with_different_input(self):
185190
with fluid.dygraph.guard(fluid.CPUPlace()):
186191
x_data = np.ones([16, 10]).astype('float32')
@@ -272,5 +277,23 @@ def test_concrete_program(self):
272277
foo_3.concrete_program
273278

274279

280+
class TestDeclarativeAPI(unittest.TestCase):
281+
def test_error(self):
282+
func = declarative(dyfunc_to_variable)
283+
284+
paddle.enable_static()
285+
286+
# Failed to run the callable object decorated by '@paddle.jit.to_static'
287+
# if it does NOT in dynamic mode.
288+
with self.assertRaises(RuntimeError):
289+
func(np.ones(5).astype("int32"))
290+
291+
program_trans.enable(False)
292+
with self.assertRaises(AssertionError):
293+
# AssertionError: We Only support to_variable in imperative mode,
294+
# please use fluid.dygraph.guard() as context to run it in imperative Mode
295+
func(np.ones(5).astype("int32"))
296+
297+
275298
if __name__ == '__main__':
276299
unittest.main()

0 commit comments

Comments
 (0)