[SOT][NumPy] Complete the basic procedure #72154

Merged: 14 commits, May 4, 2025
7 changes: 7 additions & 0 deletions paddle/fluid/pybind/jit.cc
@@ -122,6 +122,13 @@ void BindGuard(pybind11::module *m) {
"NumPyArrayValueMatchGuard",
R"DOC(NumPyArrayValueMatchGuard Class.)DOC")
.def(py::init<const py::object &>(), py::arg("array"));
py::class_<NumPyArrayShapeMatchGuard,
GuardBase,
std::shared_ptr<NumPyArrayShapeMatchGuard>>(
*m,
"NumPyArrayShapeMatchGuard",
R"DOC(NumPyArrayShapeMatchGuard Class.)DOC")
.def(py::init<const std::vector<py::object> &>(), py::arg("shape"));
py::class_<WeakRefMatchGuard, GuardBase, std::shared_ptr<WeakRefMatchGuard>>(
*m, "WeakRefMatchGuard", R"DOC(WeakRefMatchGuard Class.)DOC")
.def(py::init<const py::object &>(), py::arg("func"));
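Once registered, the new guard can be constructed from the Python side. A minimal sketch, assuming the class is exposed on paddle's C++ extension module (the `paddle.base.core` import path is an assumption, and invoking the guard afterwards is left to the SOT guard chain):

```python
from paddle.base import core  # assumed import path; not shown in this PR

# Positive ints pin an axis to a concrete extent; any other entry
# (here None) marks the axis as dynamic, matching any size.
guard = core.NumPyArrayShapeMatchGuard([2, None])
```
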
18 changes: 18 additions & 0 deletions paddle/fluid/pybind/sot/guards.cc
@@ -193,6 +193,24 @@ bool NumPyArrayValueMatchGuard::check(PyObject* value) {
.cast<bool>();
}

bool NumPyArrayShapeMatchGuard::check(PyObject* value) {
py::array array = py::reinterpret_borrow<py::array>(value);
if (!array) {
return false;
}
int ndim = array.ndim();
auto shape = array.shape();
if (ndim != static_cast<int>(expected_.size())) {
return false;
}
for (int i = 0; i < ndim; ++i) {
if (expected_[i].has_value() && shape[i] != expected_[i].value()) {
return false;
}
}
return true;
}
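
For reference, the matching rule above restated as a pure-Python sketch (names are illustrative, not part of this PR):

```python
from typing import Optional, Sequence

import numpy as np


def shape_matches(expected: Sequence[Optional[int]], value) -> bool:
    """Python model of NumPyArrayShapeMatchGuard::check."""
    if not isinstance(value, np.ndarray):  # the C++ side borrows a py::array
        return False
    if value.ndim != len(expected):        # ndim must match exactly
        return False
    # A None entry marks a dynamic axis and matches any extent.
    return all(e is None or e == d for e, d in zip(expected, value.shape))


assert shape_matches([2, None], np.zeros((2, 7)))      # dynamic axis 1
assert not shape_matches([2, None], np.zeros((3, 7)))  # axis 0 mismatch
assert not shape_matches([2, None], np.zeros((2,)))    # ndim mismatch
```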

bool WeakRefMatchGuard::check(PyObject* value) {
if (value == nullptr || expected_ == nullptr || Py_IsNone(expected_)) {
return false;
24 changes: 24 additions & 0 deletions paddle/fluid/pybind/sot/guards.h
@@ -253,6 +253,30 @@ class NumPyArrayValueMatchGuard : public GuardBase {
PyObject* expected_;
};

class NumPyArrayShapeMatchGuard : public GuardBase {
public:
explicit NumPyArrayShapeMatchGuard(
const std::vector<std::optional<int64_t>>& shape)
: expected_(shape) {}

explicit NumPyArrayShapeMatchGuard(const std::vector<py::object>& shape) {
[Review comment] @gouzil (Member), Apr 29, 2025:

Shouldn't at most two initialization paths be needed?

  • Passing an np.array in directly (for compilation reasons it has to be passed as a py::object and converted).
  • Taking .shape directly on the Python side, which yields a tuple[int], i.e. std::vector<py::int_>, or initializing from a py::tuple so that both () and (2,) are covered (see the sketch after this class).

expected_.resize(shape.size());
for (size_t i = 0; i < shape.size(); ++i) {
if (py::isinstance<py::int_>(shape[i]) && shape[i].cast<int64_t>() > 0) {
expected_[i] = std::make_optional(shape[i].cast<int64_t>());
}
}
}

bool check(PyObject* value) override;
std::string get_guard_name() const override {
return "NumPyArrayShapeMatchGuard";
}

private:
std::vector<std::optional<int64_t>> expected_;
};
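
The py::object constructor's conversion rule, sketched in Python (illustrative only): only positive Python ints pin an axis, so both the `()` and `(2,)` style shape tuples from the review comment above are handled.

```python
def to_expected(shape):
    """Python model of the py::object constructor: positive ints become
    concrete dims; any other entry (None, 0, -1, a symbolic dim, ...)
    stays unset and therefore matches any extent."""
    return [d if isinstance(d, int) and d > 0 else None for d in shape]


assert to_expected((2, 3)) == [2, 3]        # fully static shape
assert to_expected((2, None)) == [2, None]  # axis 1 stays dynamic
assert to_expected(()) == []                # () from a 0-d array works too
```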

class WeakRefMatchGuard : public GuardBase {
public:
explicit WeakRefMatchGuard(const py::object& obj) {
28 changes: 27 additions & 1 deletion python/paddle/jit/sot/infer_meta.py
@@ -15,11 +15,12 @@

import copy
from functools import cached_property
from typing import TypeVar
from typing import TYPE_CHECKING, Any, TypeVar

import paddle
from paddle.amp.auto_cast import amp_state
from paddle.base.data_feeder import convert_dtype
from paddle.base.framework import convert_np_dtype_to_dtype_
from paddle.base.unique_name import (
UniqueNameGenerator,
guard as UniqueNameGuard,
@@ -46,6 +47,9 @@
meta_str,
)

if TYPE_CHECKING:
import numpy.typing as npt

DynamicSymbolT = TypeVar("DynamicSymbolT")
SOT_INFER_META_INNER_VAR = "___SOT_INFER_META_INNER_VAR"

@@ -226,6 +230,28 @@ def from_value(value) -> MetaInfo:
dist_info=dist_info,
)

@staticmethod
def from_numpy(
nparray: npt.NDArray[Any], *, dynamic_axes: list[int] | None = None
):
dtype = convert_np_dtype_to_dtype_(nparray.dtype)
dynamic_axes = dynamic_axes or []
shape = [
SymbolicInt() if i in dynamic_axes else dim
for i, dim in enumerate(nparray.shape)
]
return MetaInfo(
shape,
dtype,
True, # stop_gradient
None,
None, # persistable
None,
None,
None,
dist_info=None,
)
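
A usage sketch for the new constructor (a minimal example; it assumes numpy is imported in the caller):

```python
import numpy as np

# Axis 0 is declared dynamic, so it becomes a SymbolicInt in the meta;
# axis 1 keeps its concrete extent of 3.
meta = MetaInfo.from_numpy(np.zeros((2, 3), dtype=np.float32), dynamic_axes=[0])
# meta.shape -> [SymbolicInt(), 3]
# meta.dtype -> the paddle dtype converted from np.float32
```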

def is_inner_var(self):
return self.name == SOT_INFER_META_INNER_VAR

107 changes: 80 additions & 27 deletions python/paddle/jit/sot/opcode_translator/executor/function_graph.py
@@ -21,6 +21,7 @@
from collections import namedtuple
from contextlib import contextmanager
from copy import deepcopy
from enum import Enum
from functools import reduce
from typing import TYPE_CHECKING, Any, Callable, Tuple, Union

@@ -49,6 +50,7 @@
from ...symbolic_shape.operators import SYMBOLIC_BINARY_OPS, SYMBOLIC_UNARY_OPS
from ...utils import (
ENV_SOT_ALLOW_DYNAMIC_SHAPE,
NUMPY_API_SUPPORTED_DICT,
NameGenerator,
SIRToCodeMap,
SotUndefinedVar,
@@ -86,6 +88,7 @@
GlobalVariable,
ListVariable,
NullVariable,
NumpyArrayVariable,
PaddleLayerVariable,
ParameterVariable,
SymbolicVariable,
@@ -99,6 +102,10 @@
if TYPE_CHECKING:
import types

GraphNodeVariableType: TypeAlias = Union[
TensorVariable, SymbolicVariable, NumpyArrayVariable
]


CompileGraphResult: TypeAlias = Tuple[
Callable[..., Any],
@@ -108,6 +115,11 @@
OrderedSet[Union[TensorVariable, SymbolicVariable]],
],
]
GraphNodeVariableClasses = (
TensorVariable,
SymbolicVariable,
NumpyArrayVariable,
)


def convert_to_meta(inputs: Any):
@@ -116,7 +128,7 @@ def convert_to_meta(inputs: Any):
"""

def func(x):
if isinstance(x, (TensorVariable, SymbolicVariable)):
if isinstance(x, GraphNodeVariableClasses):
return x.meta
if isinstance(x, VariableBase):
return x.get_py_value()
@@ -131,7 +143,7 @@ def convert_to_symbol(inputs: Any):
"""

def func(x):
if isinstance(x, (TensorVariable, SymbolicVariable)):
if isinstance(x, GraphNodeVariableClasses):
return x.get_symbol()
if isinstance(x, VariableBase):
return x.get_py_value()
@@ -155,7 +167,7 @@ def record_symbols(SIR, *args, **kwargs):
non_params = set()

def fn(value):
if isinstance(value, (TensorVariable, SymbolicVariable)):
if isinstance(value, GraphNodeVariableClasses):
symbol_meta_map[value.get_symbol()] = value.meta
if isinstance(value, ParameterVariable):
params.add(value.get_symbol())
@@ -190,6 +202,12 @@ def func(x):
return map_variables(func, inputs, restore_variable=True)


class APIType(Enum):
PADDLE = 0
SYMBOLIC = 1
NUMPY = 2


class VariableLoader:
def __init__(self, store_var_info, pycode_gen):
self._store_var_info = store_var_info
@@ -541,7 +559,34 @@ def message_handler(*args, **kwargs):
InferMetaCache(),
self.sir_builder.call_API,
func,
False,
APIType.PADDLE,
*args,
**kwargs,
)

def call_numpy_api(
self,
func: Callable[..., Any],
*args: VariableBase,
**kwargs: VariableBase,
):
"""
Record Numpy API to SIR

Args:
func: numpy api
"""
assert func in NUMPY_API_SUPPORTED_DICT.values()
log(3, f"call numpy.api : {func.__name__}", "\n")

def message_handler(*args, **kwargs):
return f"Call numpy api error: {func.__name__}, may be not a operator api?"

return inner_error_default_handler(self.symbolic_call, message_handler)(
InferMetaCache(),
self.sir_builder.call_API,
func,
APIType.NUMPY,
*args,
**kwargs,
)
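
A hypothetical call site for the new method; `np.add` being registered in NUMPY_API_SUPPORTED_DICT is an assumption (the registry's contents are not shown in this diff):

```python
import numpy as np

# x_var and y_var stand for NumpyArrayVariable inputs already on the graph.
out_var = graph.call_numpy_api(np.add, x_var, y_var)
```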
@@ -562,7 +607,7 @@ def message_handler(*args, **kwargs):
InferMetaCache(),
self.sir_builder.call_API,
op,
True,
APIType.SYMBOLIC,
*args,
**kwargs,
)
@@ -584,7 +629,7 @@ def message_handler(*args, **kwargs):
InferMetaCache(),
self.sir_builder.call_METHOD,
method_name,
False,
APIType.PADDLE,
*args,
**kwargs,
)
@@ -619,7 +664,7 @@ def message_handler(*args, **kwargs):
return f"Call paddle layer error: {layer}, may be not a valid paddle layer?"

return inner_error_default_handler(self.symbolic_call, message_handler)(
infer_meta_fn, compute_fn, layer, False, *args, **kwargs
infer_meta_fn, compute_fn, layer, APIType.PADDLE, *args, **kwargs
)

def call_ast(
@@ -653,7 +698,7 @@ def message_handler(*args, **kwargs):
ast_infer_meta,
compute_fn,
static_function,
False,
APIType.PADDLE,
*args,
**kwargs,
)
@@ -662,7 +707,7 @@
return None

def symbolic_call(
self, infer_meta_fn, compute_fn, func, is_symbolic_var, *args, **kwargs
self, infer_meta_fn, compute_fn, func, api_type, *args, **kwargs
):
"""
Using infer_meta_fn and compute_fn convert func to symbolic function.
@@ -763,11 +808,14 @@ def try_infer_meta_fn(args, kwargs) -> Any:

log(3, f" inputs : {inputs_symbols}", "\n")

if is_symbolic_var:
if api_type == APIType.SYMBOLIC:
var_cls = SymbolicVariable
tracker = SymbolicOperationTracker(
list(args) + list(kwargs.values()), func
)
elif api_type == APIType.NUMPY:
var_cls = NumpyArrayVariable
tracker = DummyTracker(list(args) + list(kwargs.values()))
else:
var_cls = TensorVariable
tracker = DummyTracker(list(args) + list(kwargs.values()))
@@ -807,7 +855,7 @@ def try_infer_meta_fn(args, kwargs) -> Any:
stmt_stacks,
) # symbolic only contain symbols.
self._put_inner(outputs)
if is_symbolic_var:
if api_type == APIType.SYMBOLIC:
# compute_fn should be call_method
tracker = SymbolicOperationTracker(
list(args) + list(kwargs.values()), func
@@ -892,13 +940,13 @@ def remove_global_guarded_variable(self, variable: VariableBase):

def _find_tensor_inputs(
self, input_names: list[str]
) -> OrderedSet[TensorVariable | SymbolicVariable]:
inputs: OrderedSet[TensorVariable | SymbolicVariable] = OrderedSet()
) -> OrderedSet[GraphNodeVariableType]:
inputs: OrderedSet[GraphNodeVariableType] = OrderedSet()
for name in input_names:
found = False
for variable in self.input_variables:
if (
isinstance(variable, (TensorVariable, SymbolicVariable))
isinstance(variable, GraphNodeVariableClasses)
and variable.get_symbol().name == name
):
inputs.add(variable)
@@ -908,30 +956,37 @@ def _find_tensor_inputs(
assert len(inputs) == len(input_names), "Number of inputs not match."
return inputs

def gen_load_inputs(
self, inputs: OrderedSet[TensorVariable | SymbolicVariable]
):
def gen_load_inputs(self, inputs: OrderedSet[GraphNodeVariableType]):
for input_var in inputs:
# For SymbolicVariable, we use paddle.full([], value, "int64")
# to convert it to a Tensor
if isinstance(input_var, SymbolicVariable):
# For SymbolicVariable, we use paddle.full([], value, "int64")
# to convert it to a Tensor
self.pycode_gen.gen_load_object(
paddle.full,
"___paddle_full",
)
self.pycode_gen.gen_build_list(0)
input_var.tracker.gen_instructions(self.pycode_gen)
if isinstance(input_var, SymbolicVariable):
input_var.tracker.gen_instructions(self.pycode_gen)
self.pycode_gen.gen_load_const("int64")
self.pycode_gen.gen_call_function(3)
elif isinstance(input_var, NumpyArrayVariable):
# For NumpyArrayVariable, we use paddle.to_tensor(value) to convert it to a Tensor
self.pycode_gen.gen_load_object(
paddle.to_tensor,
"___paddle_to_tensor",
)
input_var.tracker.gen_instructions(self.pycode_gen)
self.pycode_gen.gen_call_function(1)
else:
input_var.tracker.gen_instructions(self.pycode_gen)
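
For reference, the bytecode emitted by the two conversion branches above evaluates to the equivalent of this sketch (`sym_value` and `np_value` are illustrative placeholders for what the trackers load):

```python
import numpy as np
import paddle

sym_value = 42              # stands for the SymbolicVariable's tracked value
np_value = np.ones((2, 3))  # stands for the NumpyArrayVariable's tracked value

# SymbolicVariable path: wrap the int in a 0-d int64 Tensor.
t0 = paddle.full([], sym_value, "int64")

# NumpyArrayVariable path: convert the ndarray with paddle.to_tensor.
t1 = paddle.to_tensor(np_value)
```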

@staticmethod
def _is_graph_output(
var,
) -> TypeGuard[TensorVariable | SymbolicVariable]:
) -> TypeGuard[GraphNodeVariableType]:
return isinstance(
var.tracker, (DummyTracker, SymbolicOperationTracker)
) and isinstance(var, (TensorVariable, SymbolicVariable))
) and isinstance(var, GraphNodeVariableClasses)

@staticmethod
def _collect_related_dummy_tensor(var):
@@ -949,17 +1004,15 @@ def _collect_related_dummy_tensor(var):

def _find_tensor_outputs(
self, outputs: list[VariableBase]
) -> OrderedSet[TensorVariable | SymbolicVariable]:
) -> OrderedSet[GraphNodeVariableType]:
"""
Return all TensorVariable. find TensorVariables participating in networking from the output Variables

Args:
outputs: output variables
"""

output_tensors: OrderedSet[TensorVariable | SymbolicVariable] = (
OrderedSet()
)
output_tensors: OrderedSet[GraphNodeVariableType] = OrderedSet()
# Find Tensor Variables from outputs.
for output in outputs:
if isinstance(