Skip to content

Commit 83baab9

Browse files
authored
[cherry-pick 2.0-beta][Dy2Stat] Transforme api 'to_tensor' to 'assign'. (#26873) (#27055)
Change-Id: Ic5b211f1bab42067715297fe58a78646e13e048d
1 parent 873da75 commit 83baab9

File tree

4 files changed

+63
-43
lines changed

4 files changed

+63
-43
lines changed

python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py

+43-8
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
import gast
1717

1818
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
19-
from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api, is_to_variable
20-
from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
21-
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api
19+
from paddle.fluid.dygraph.dygraph_to_static import utils
2220

2321

2422
class BasicApiTransformer(gast.NodeTransformer):
@@ -56,7 +54,7 @@ def visit_Expr(self, node):
5654
if isinstance(child_node, gast.Call):
5755
# TODO(liym27):
5856
# Considers that a dygraph api which modifies the input or has a output.
59-
if is_dygraph_api(child_node):
57+
if utils.is_dygraph_api(child_node):
6058
return
6159
else:
6260
self._visit_Call(child_node)
@@ -73,7 +71,7 @@ def _visit_Call(self, node):
7371

7472
if self._is_dygraph_forward(func_name):
7573
class_node = self._get_class_node(func_name)
76-
static_node = to_static_ast(node, class_node)
74+
static_node = utils.to_static_ast(node, class_node)
7775
return static_node
7876
else:
7977
return node
@@ -91,14 +89,51 @@ def _update_class_node_dict(self, node):
9189
if is_to_variable(node_value):
9290
return False
9391

94-
if is_dygraph_api(node_value):
92+
if utils.is_dygraph_api(node_value):
9593
dygraph_api = node_value.func.attr
96-
if not dygraph_class_to_static_api.get(dygraph_api):
94+
if not utils.dygraph_class_to_static_api.get(dygraph_api):
9795
return False
9896

99-
update_args_of_func(node_value, node_value, "__init__")
97+
utils.update_args_of_func(node_value, node_value, "__init__")
10098
target_str = astor.to_source(gast.gast_to_ast(node.targets[0]))
10199
self.class_node_dict[target_str] = node_value
102100
return True
103101
# TODO: node.value is not dygraph class
104102
return False
103+
104+
105+
def is_to_variable(node):
106+
assert isinstance(node, gast.Call)
107+
api_name = utils.ast_to_source_code(node.func).strip()
108+
109+
if utils.is_dygraph_api(node):
110+
return api_name.endswith("to_variable")
111+
112+
if utils.is_paddle_api(node):
113+
return api_name.endswith("to_tensor")
114+
115+
return False
116+
117+
118+
def to_assign_node(node):
119+
# Transform dygraph api `fluid.dygraph.to_variable` alias `paddle.to_tensor` to static api `fluid.layers.assign`.
120+
# NOTE:
121+
# 1. Api `to_variable` supports data type {float16, float32, float64, int16, int32, int64, uint8, uint16},
122+
# but api `assign` only supports {float32, float64, int32, int64, bool};
123+
# 2. If the input of api `assign` is numpy.ndarray, its size cannot be greater than 1024 * 1024.
124+
125+
assert isinstance(node, gast.Call)
126+
assign_api = gast.parse('fluid.layers.assign').body[0].value
127+
node.func = assign_api
128+
129+
if node.args:
130+
node.args = [node.args[0]]
131+
node.keywords = []
132+
else:
133+
for idx, kw in enumerate(node.keywords):
134+
if kw.arg == 'value' or kw.arg == 'data':
135+
node.keywords[idx].arg = 'input'
136+
node.keywords = [node.keywords[idx]]
137+
node.args = []
138+
break
139+
return node

python/paddle/fluid/dygraph/dygraph_to_static/utils.py

+8-33
Original file line numberDiff line numberDiff line change
@@ -136,25 +136,31 @@ def is_api_in_module(node, module_prefix):
136136
# import_str = "".join(import_statements)
137137
import paddle
138138
import paddle.fluid as fluid
139+
import paddle.fluid.dygraph as dygraph
139140
import paddle.fluid.layers as layers
141+
140142
from paddle.fluid.dygraph import to_variable
141-
import paddle.fluid.dygraph as dygraph
143+
from paddle import to_tensor
144+
142145
return eval("_is_api_in_module_helper({}, '{}')".format(func_str,
143146
module_prefix))
144147
except NameError:
145148
return False
146149

147150

148151
def is_dygraph_api(node):
152+
149153
# Note: A api in module dygraph_to_static is not a real dygraph api.
150154
if is_api_in_module(node, "paddle.fluid.dygraph.dygraph_to_static"):
151155
return False
152156

157+
# TODO(liym27): A better way to determine whether it is a dygraph api.
158+
# Consider the decorator @dygraph_only
153159
return is_api_in_module(node, "paddle.fluid.dygraph")
154160

155161

156162
def is_paddle_api(node):
157-
return is_api_in_module(node, "paddle.fluid")
163+
return is_api_in_module(node, "paddle")
158164

159165

160166
# Is numpy_api cannot reuse is_api_in_module because of numpy module problem
@@ -233,14 +239,6 @@ def _add_keywords_to(node, dygraph_api_name):
233239
return
234240

235241

236-
def is_to_variable(node):
237-
assert isinstance(node, gast.Call)
238-
if is_dygraph_api(node):
239-
api_name = ast_to_source_code(node.func).strip()
240-
return api_name.endswith("to_variable")
241-
return False
242-
243-
244242
def to_static_ast(node, class_node):
245243
assert isinstance(node, gast.Call)
246244
assert isinstance(class_node, gast.Call)
@@ -268,29 +266,6 @@ def to_static_ast(node, class_node):
268266
return node
269267

270268

271-
def to_assign_node(node):
272-
# Transform dygraph api `fluid.dygraph.to_variable` to static api `fluid.layers.assign`.
273-
# NOTE:
274-
# 1. Api `to_variable` supports data type {float16, float32, float64, int16, int32, int64, uint8, uint16},
275-
# but api `assign` only supports {float32, float64, int32, int64, bool};
276-
# 2. If the input of api `assign` is numpy.ndarray, its size cannot be greater than 1024 * 1024.
277-
assert isinstance(node, gast.Call)
278-
assign_api = gast.parse('fluid.layers.assign').body[0].value
279-
node.func = assign_api
280-
281-
if node.args:
282-
node.args = [node.args[0]]
283-
node.keywords = []
284-
else:
285-
for idx, kw in enumerate(node.keywords):
286-
if kw.arg == 'value':
287-
node.keywords[idx].arg = 'input'
288-
node.keywords = [node.keywords[idx]]
289-
node.args = []
290-
break
291-
return node
292-
293-
294269
def update_args_of_func(node, dygraph_node, method_name):
295270
assert isinstance(node, gast.Call)
296271
if method_name not in ["__init__", "forward"]:

python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
import inspect
2020
import gast
2121

22+
import paddle
2223
import paddle.fluid as fluid
2324
import paddle.fluid.dygraph as dygraph
2425

26+
from paddle import to_tensor
2527
from paddle.fluid.dygraph import to_variable
2628
from paddle.fluid.dygraph.jit import dygraph_to_static_func
2729
from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api
@@ -45,11 +47,19 @@ def dyfunc_to_variable_3(x):
4547
return res
4648

4749

50+
def dyfunc_to_tensor(x):
51+
res1 = paddle.to_tensor(x, dtype=None, place=None, stop_gradient=True)
52+
res2 = paddle.tensor.to_tensor(data=res1)
53+
res3 = to_tensor(data=res2)
54+
return res3
55+
56+
4857
class TestDygraphBasicApi_ToVariable(unittest.TestCase):
4958
def setUp(self):
5059
self.input = np.ones(5).astype("int32")
5160
self.test_funcs = [
52-
dyfunc_to_variable, dyfunc_to_variable_2, dyfunc_to_variable_3
61+
dyfunc_to_tensor, dyfunc_to_variable, dyfunc_to_variable_2,
62+
dyfunc_to_variable_3
5363
]
5464
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
5565
) else fluid.CPUPlace()

python/paddle/fluid/tests/unittests/dygraph_to_static/test_logging_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

0 commit comments

Comments
 (0)