[CherryPick] Support TypeHint for function decorated by @to_static #47147

Merged
merged 2 commits into from Oct 19, 2022
2 changes: 1 addition & 1 deletion paddle/phi/kernels/impl/einsum_impl.h
@@ -241,7 +241,7 @@ inline static void InferLabelShape(const std::vector<std::string>& op_labels,
      } else if (labelshape->is_default(c) || (*labelshape)[c] == -1) {
        (*labelshape)[c] = op_dim[dim_ptr];
        dim_ptr++;
      } else {
      } else if (op_dim[dim_ptr] != -1) {
        PADDLE_ENFORCE_EQ(
            (*labelshape)[c],
            op_dim[dim_ptr],
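The new guard skips the shape-equality check when the incoming dimension is -1, i.e. unknown at graph-build time, which is what einsum sees under @to_static with a dynamic axis. A minimal sketch of that situation (assuming a dynamic batch dimension via InputSpec; this snippet is illustrative and not part of the PR's tests):

import paddle

@paddle.jit.to_static(input_spec=[
    paddle.static.InputSpec(shape=[-1, 8], dtype='float32'),
    paddle.static.InputSpec(shape=[-1, 8], dtype='float32'),
])
def dot_rows(x, y):
    # Both operands share labels i and j; with a -1 batch axis the size of
    # label i is unknown at compile time, so it cannot be enforced against
    # the concrete value observed later.
    return paddle.einsum('ij,ij->i', x, y)

print(dot_rows(paddle.rand([4, 8]), paddle.rand([4, 8])).shape)  # [4]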
31 changes: 30 additions & 1 deletion python/paddle/fluid/dygraph/base.py
@@ -27,6 +27,7 @@
import warnings
from ..framework import _get_paddle_place, _in_legacy_dygraph, _in_eager_without_dygraph_check
import paddle
import warnings

__all__ = [
'no_grad', 'no_grad_', 'grad', 'guard', 'enable_dygraph', 'disable_dygraph',
@@ -45,6 +46,20 @@ def in_declarative_mode():
return _in_declarative_mode_


def declarative_unsupport_argument_warning(func_name, input_names, inputs,
                                           support_values):
    """
    Warn if inputs do not elementwise equal support_values.
    It's a utility function for dy2static, used when a dygraph interface has
    more inputs than its static counterpart, such as paddle.grad.

    """
    for name, inp, sup in zip(input_names, inputs, support_values):
        if inp != sup:
            warnings.warn(f"{func_name} has unsupported parameter in jit: " +
                          f"{name}, jit will discard it")
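A hypothetical call to illustrate the helper: only the value received for create_graph differs from what the static interface supports, so exactly one warning is emitted.

declarative_unsupport_argument_warning(
    "paddle.grad",
    ["retain_graph", "create_graph", "only_inputs", "allow_unused"],
    [None, True, True, False],   # values received from the user
    [None, False, True, False])  # values the static interface supports
# UserWarning: paddle.grad has unsupported parameter in jit: create_graph, jit will discard it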


def _switch_to_static_graph_(func):

def __impl__(*args, **kwargs):
@@ -290,6 +305,10 @@ def test_layer():
test_layer()

"""
    if in_declarative_mode():
        warnings.warn(
            "paddle.no_grad is only supported for inference model, and not supported for training under @to_static."
        )
    if func is None:
        return _switch_tracer_mode_guard_(is_train=False)
    else:
@@ -428,7 +447,7 @@ def guard(place=None):
yield


@framework.dygraph_only
@framework.non_static_only
def grad(outputs,
         inputs,
         grad_outputs=None,
@@ -563,6 +582,16 @@ def test_dygraph_grad(grad_outputs=None):
grad_y1 = paddle.to_tensor(3.0)
print(test_dygraph_grad([grad_y1, grad_value])) # [24.]
'''
    if in_declarative_mode():
        # In dy2static context, we call static interface `gradients`
        # to calculate grads.
        from paddle.static import gradients
        declarative_unsupport_argument_warning(
            "paddle.grad",
            ["retain_graph", "create_graph", "only_inputs", "allow_unused"],
            [retain_graph, create_graph, only_inputs, allow_unused],
            [None, False, True, False])
        return gradients(outputs, inputs, grad_outputs, no_grad_vars)
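A minimal sketch of what this branch enables (the pattern follows the existing dy2static grad tests; the snippet is illustrative, not part of the diff): calling paddle.grad inside a @to_static function now lowers to paddle.static.gradients, and retain_graph / create_graph / only_inputs / allow_unused are discarded with a warning when they differ from the supported values.

import paddle

@paddle.jit.to_static
def grad_of_square(x):
    x.stop_gradient = False
    y = x * x
    # While the program is traced, this call dispatches to paddle.static.gradients.
    return paddle.grad([y], [x])[0]

print(grad_of_square(paddle.to_tensor(3.0)))  # 6.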

    def check_in_out(in_out_list, name):
        assert in_out_list is not None, "{} should not be None".format(name)
@@ -30,6 +30,7 @@
from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
from paddle.fluid.dygraph.dygraph_to_static.cast_transformer import CastTransformer
from paddle.fluid.dygraph.dygraph_to_static.grad_transformer import GradTransformer
from paddle.fluid.dygraph.dygraph_to_static.typehint_transformer import TypeHintTransformer
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer
from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import LogicalTransformer
@@ -104,8 +105,9 @@ def transfer_from_node_type(self, node_wrapper):
            PrintTransformer,  # print statement
            CallTransformer,  # transform call recursively
            CastTransformer,  # type casting statement
            GradTransformer,  # transform paddle.grad to paddle.gradients
            # GradTransformer,  # transform paddle.grad to paddle.gradients
            DecoratorTransformer,  # transform decorators to function call
            TypeHintTransformer,  # remove all typehint in gast.Name
        ]

        apply_optimization(transformers)
@@ -0,0 +1,47 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.utils import gast
import warnings

from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static import utils
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer


class TypeHintTransformer(BaseTransformer):
    """
    A transformer that removes all type hints (gast.Name annotations).
    Please put it behind the other transformers, because they may rely on type hints.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of TypeHintTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node

    def transform(self):
        self.visit(self.root)

    def visit_FunctionDef(self, node):
        node.returns = None
        self.generic_visit(node)
        return node

    def visit_Name(self, node):
        node.annotation = None
        self.generic_visit(node)
        return node
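For intuition, here is a standalone sketch of the same idea using the stdlib ast module (the transformer above works on gast, where parameter annotations are attached to gast.Name nodes rather than ast.arg; this sketch is not part of the diff):

import ast

class StripTypeHints(ast.NodeTransformer):

    def visit_FunctionDef(self, node):
        node.returns = None            # drop the `-> A` return hint
        for arg in node.args.args:
            arg.annotation = None      # drop the `x: A` parameter hints
        self.generic_visit(node)
        return node

    def visit_AnnAssign(self, node):
        self.generic_visit(node)
        if node.value is None:
            return None                # bare `t: A` declaration: remove it
        # `t: A = A()` becomes `t = A()`
        return ast.copy_location(
            ast.Assign(targets=[node.target], value=node.value), node)

src = "def function(x: A) -> A:\n    t: A = A()\n    return 2 * x\n"
tree = StripTypeHints().visit(ast.parse(src))
ast.fix_missing_locations(tree)
print(ast.unparse(tree))
# def function(x):
#     t = A()
#     return 2 * x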
12 changes: 12 additions & 0 deletions python/paddle/fluid/framework.py
@@ -511,6 +511,17 @@ def __impl__(*args, **kwargs):
return __impl__


def _non_static_only_(func):

    def __impl__(*args, **kwargs):
        from .dygraph.base import in_declarative_mode
        assert _non_static_mode() or in_declarative_mode(
        ), "We only support '%s()' in dynamic graph mode, please call 'paddle.disable_static()' to enter dynamic graph mode." % func.__name__
        return func(*args, **kwargs)

    return __impl__


def _static_only_(func):

def __impl__(*args, **kwargs):
@@ -570,6 +581,7 @@ def wrapper(*args, **kwargs):
dygraph_only = wrap_decorator(_dygraph_only_)
static_only = wrap_decorator(_static_only_)
fake_interface_only = wrap_decorator(_fake_interface_only_)
non_static_only = wrap_decorator(_non_static_only_)


def _dygraph_tracer():
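A hypothetical usage sketch of the non_static_only decorator exported above (assuming this PR is merged so paddle.fluid.framework exposes it; not part of the diff): the wrapped function stays callable in dygraph mode and while a @to_static function is being traced, and fails the assertion in plain static-graph mode.

import paddle
from paddle.fluid.framework import non_static_only

@non_static_only
def my_api(x):
    # Allowed in dygraph mode and inside dy2static tracing;
    # asserts under paddle.enable_static() outside of @to_static.
    return x * 2

print(my_api(paddle.to_tensor(2.0)))  # runs fine in dygraph mode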
@@ -0,0 +1,79 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle.fluid as fluid
import unittest

from paddle.fluid.dygraph.jit import declarative

SEED = 2020
np.random.seed(SEED)


class A:
    pass


def function(x: A) -> A:
    t: A = A()
    return 2 * x


class TestTransformWhileLoop(unittest.TestCase):

    def setUp(self):
        self.place = fluid.CUDAPlace(
            0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
        self.x = np.zeros(shape=(1), dtype=np.int32)
        self._init_dyfunc()

    def _init_dyfunc(self):
        self.dyfunc = function

    def _run_static(self):
        return self._run(to_static=True)

    def _run_dygraph(self):
        return self._run(to_static=False)

    def _run(self, to_static):
        with fluid.dygraph.guard(self.place):
            # Set the input of dyfunc to VarBase
            tensor_x = fluid.dygraph.to_variable(self.x, zero_copy=False)
            if to_static:
                ret = declarative(self.dyfunc)(tensor_x)
            else:
                ret = self.dyfunc(tensor_x)
            if hasattr(ret, "numpy"):
                return ret.numpy()
            else:
                return ret

    def test_ast_to_func(self):
        static_numpy = self._run_static()
        dygraph_numpy = self._run_dygraph()
        print(static_numpy, dygraph_numpy)
        np.testing.assert_allclose(dygraph_numpy, static_numpy, rtol=1e-05)


class TestTypeHint(TestTransformWhileLoop):

    def _init_dyfunc(self):
        self.dyfunc = function


if __name__ == '__main__':
    with fluid.framework._test_eager_guard():
        unittest.main()