Skip to content

[SOT] Add support for numpy ufunc with numpy number #71295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 79 additions & 31 deletions python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
from functools import partial, reduce
from typing import TYPE_CHECKING

import numpy as np

import paddle

from ...utils import BreakGraphError, FallbackError
from ...utils import BreakGraphError, FallbackError, get_numpy_ufuncs
from ...utils.magic_methods import (
BINARY_OPS,
UNARY_OPS,
Expand All @@ -47,6 +49,7 @@
EnumerateVariable,
ListVariable,
MapVariable,
NumpyArrayVariable,
NumpyVariable,
RangeVariable,
SliceVariable,
Expand Down Expand Up @@ -980,7 +983,7 @@ def is_not_func(var: VariableBase, other: VariableBase):
for magic_method in magic_method_builtin_dispatch(unary_fn):
Dispatcher.register(
unary_fn,
("ConstantVariable",),
("ConstantVariable | NumpyNumberVariable",),
partial(
lambda fn, var: VariableFactory.from_value(
fn(var.get_py_value()),
Expand All @@ -994,7 +997,10 @@ def is_not_func(var: VariableBase, other: VariableBase):
for magic_method in magic_method_builtin_dispatch(binary_fn):
Dispatcher.register(
binary_fn,
("ConstantVariable", "ConstantVariable"),
(
"ConstantVariable | NumpyNumberVariable",
"ConstantVariable | NumpyNumberVariable",
),
partial(
lambda fn, var, other: VariableFactory.from_value(
fn(var.get_py_value(), other.get_py_value()),
Expand Down Expand Up @@ -1138,31 +1144,6 @@ def tensor_mod_dispatcher(
),
)

# Register dispatch for NumpyVariable: fallback !
for unary_fn in UNARY_OPS:
if unary_fn in [bool]:
continue
for magic_method in magic_method_builtin_dispatch(unary_fn):

@Dispatcher.register_decorator(unary_fn)
def numpy_unary_dispatcher(var: NumpyVariable):
raise FallbackError('Numpy operator need fallback to dygraph')


Dispatcher.register(
operator.eq,
("NumpyVariable", "ConstantVariable | NumpyVariable"),
lambda left, right: constant_numpy_equal(right, left),
)


for binary_fn in BINARY_OPS:
for magic_method in magic_method_builtin_dispatch(binary_fn):

@Dispatcher.register_decorator(binary_fn)
def numpy_binary_dispatcher(var: NumpyVariable, other: NumpyVariable):
raise FallbackError('Numpy operator need fallback to dygraph')


# Register dispatch for DataVariable: directly call and return a wrapped variable.
def data_variable_binary_dispatcher(var, other, operator):
Expand Down Expand Up @@ -1344,7 +1325,7 @@ def get_math_unary_functions():
for fn in get_math_unary_functions():
Dispatcher.register(
fn,
("ConstantVariable",),
("ConstantVariable | NumpyNumberVariable",),
partial(
lambda fn, var: ConstantVariable(
fn(var.get_py_value()),
Expand All @@ -1356,7 +1337,7 @@ def get_math_unary_functions():
)
Dispatcher.register(
math.log,
("ConstantVariable",),
("ConstantVariable | NumpyNumberVariable",),
lambda var: ConstantVariable(
math.log(var.get_py_value()),
var.graph,
Expand All @@ -1365,15 +1346,41 @@ def get_math_unary_functions():
)


# NumpyVariable dispatch
def constant_numpy_equal(left, right):
numpy_ans = left.get_py_value() == right.get_py_value()
return NumpyVariable(
return VariableFactory.from_value(
numpy_ans,
left.graph,
tracker=DummyTracker([left, right]),
)


for unary_fn in UNARY_OPS:
if unary_fn is bool:
continue
for magic_method in magic_method_builtin_dispatch(unary_fn):

@Dispatcher.register_decorator(unary_fn)
def numpy_unary_dispatcher(var: NumpyArrayVariable):
raise FallbackError('Numpy operator need fallback to dygraph')


Dispatcher.register(
operator.eq,
("NumpyVariable", "ConstantVariable | NumpyVariable"),
lambda left, right: constant_numpy_equal(right, left),
)


for binary_fn in BINARY_OPS:
for magic_method in magic_method_builtin_dispatch(binary_fn):

@Dispatcher.register_decorator(binary_fn)
def numpy_binary_dispatcher(var: NumpyVariable, other: NumpyVariable):
raise FallbackError('Numpy operator need fallback to dygraph')


Dispatcher.register(
operator.eq,
("ConstantVariable", "NumpyVariable"),
Expand All @@ -1389,3 +1396,44 @@ def constant_numpy_equal(left, right):
tracker=DummyTracker([x]),
),
)

Dispatcher.register(
np.number.item,
("NumpyNumberVariable",),
lambda x: ConstantVariable(
x.get_py_value().item(),
x.graph,
tracker=DummyTracker([x]),
),
)

unary_ufuncs, binary_ufuncs = get_numpy_ufuncs()
for ufunc in unary_ufuncs:
Dispatcher.register(
ufunc,
("ConstantVariable | NumpyNumberVariable",),
partial(
lambda ufunc, var: VariableFactory.from_value(
ufunc(var.get_py_value()),
var.graph,
tracker=DummyTracker([var]),
),
ufunc,
),
)
for ufunc in binary_ufuncs:
Dispatcher.register(
ufunc,
(
"ConstantVariable | NumpyNumberVariable",
"ConstantVariable | NumpyNumberVariable",
),
partial(
lambda ufunc, var, other: VariableFactory.from_value(
ufunc(var.get_py_value(), other.get_py_value()),
var.graph,
tracker=DummyTracker([var, other]),
),
ufunc,
),
)
Original file line number Diff line number Diff line change
Expand Up @@ -1227,10 +1227,25 @@ def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):


class NumpyNumberVariable(NumpyVariable):
def _reconstruct(self, codegen: PyCodeGen):
np_type = self.get_py_type()
type_id = f"___np_{np_type.__name__}"
codegen.gen_load_object(np_type, type_id)
codegen.gen_load_const(self.value.item())
codegen.gen_call_function(1)

def getattr(self, name: str, default=None):
from .callable import BuiltinVariable

if name != "item":
return super().getattr(name, default)
return BuiltinVariable(
np.number.item, self.graph, GetAttrTracker(self, name)
).bind(self, name)

@check_guard
def make_stringified_guard(self) -> list[StringifiedExpression]:
frame_value_tracer = self.tracker.trace_value_from_frame()
obj_free_var_name = f"__{self.id}"

dtype_guard = StringifiedExpression(
f"{{}}.dtype == {NumpyVariable.format_dtype(self.get_py_value().dtype)}",
Expand Down
1 change: 1 addition & 0 deletions python/paddle/jit/sot/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
flatten,
flatten_extend,
get_api_fullname,
get_numpy_ufuncs,
get_unbound_method,
hashable,
in_paddle_module,
Expand Down
12 changes: 12 additions & 0 deletions python/paddle/jit/sot/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,3 +546,15 @@ def get_api_fullname(api):
return module_str + "." + api_name
module_str = module_str.rpartition(".")[0]
return None


def get_numpy_ufuncs():
ufuncs = [
ufunc
for _, ufunc in inspect.getmembers(
np, lambda member: isinstance(member, np.ufunc)
)
]
unary_ufuncs = filter(lambda ufunc: ufunc.nin == 1, ufuncs)
binary_ufuncs = filter(lambda ufunc: ufunc.nin == 2, ufuncs)
return list(unary_ufuncs), list(binary_ufuncs)
22 changes: 22 additions & 0 deletions test/sot/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)

import paddle
from paddle.jit.sot.psdb import check_no_breakgraph
from paddle.jit.sot.utils import strict_mode_guard


Expand All @@ -42,6 +43,23 @@ def normal_numpy_array_to_tensor(x):
return paddle.to_tensor(x)


@check_no_breakgraph
def numpy_api_with_number_calculation(t):
a = np.log(2)
b = np.exp(3)
c = np.sqrt(4)
d = np.ceil(5.1)
e = np.add(1, 2)
f = a + 1
g = 1 - b
h = c * 2
i = int(a)
j = float(b)
k = c.item()
l = t + d
return a, b, c, d, e, f, g, h, i, j, k, l


class TestNumpy(TestCaseBase):
@strict_mode_guard(False)
def test_numpy_add(self):
Expand Down Expand Up @@ -77,6 +95,10 @@ def test_numpy_array_guard(self):
self.assert_results(normal_numpy_array_to_tensor, x)
self.assertEqual(ctx.translate_count, 1)

def test_numpy_api_with_number_calculation(self):
t = paddle.to_tensor([1.0])
self.assert_results(numpy_api_with_number_calculation, t)


if __name__ == "__main__":
unittest.main()