[SOT][CodeStyle] Replace Numpy with NumPy #72580

Merged 2 commits on May 6, 2025
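
At a glance the change is purely cosmetic: the SOT-internal spelling "Numpy" becomes "NumPy" in class names, error messages, and comments, with no behavior change. A minimal before/after sketch of how downstream imports would look; the import path is inferred from the file paths in this diff and the package re-exports are assumed, so treat it as illustrative only:

    # Before this PR (old spelling):
    #     from paddle.jit.sot.opcode_translator.executor.variables import (
    #         NumpyApiVariable, NumpyArrayVariable, NumpyNumberVariable, NumpyVariable,
    #     )
    # After this PR (new spelling):
    from paddle.jit.sot.opcode_translator.executor.variables import (  # assumed re-exports
        NumPyApiVariable,     # wraps supported numpy API functions
        NumPyArrayVariable,   # wraps np.ndarray values
        NumPyNumberVariable,  # wraps np.number scalars
        NumPyVariable,        # common base class of the NumPy-backed variables
    )
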
@@ -88,7 +88,7 @@
GlobalVariable,
ListVariable,
NullVariable,
NumpyArrayVariable,
NumPyArrayVariable,
PaddleLayerVariable,
ParameterVariable,
SymbolicVariable,
@@ -103,7 +103,7 @@
import types

GraphNodeVariableType: TypeAlias = Union[
TensorVariable, SymbolicVariable, NumpyArrayVariable
TensorVariable, SymbolicVariable, NumPyArrayVariable
]


@@ -118,7 +118,7 @@
GraphNodeVariableClasses = (
TensorVariable,
SymbolicVariable,
NumpyArrayVariable,
NumPyArrayVariable,
)


@@ -571,7 +571,7 @@ def call_numpy_api(
**kwargs: VariableBase,
):
"""
Record Numpy API to SIR
Record NumPy API to SIR

Args:
func: numpy api
@@ -814,7 +814,7 @@ def try_infer_meta_fn(args, kwargs) -> Any:
list(args) + list(kwargs.values()), func
)
elif api_type == APIType.NUMPY:
var_cls = NumpyArrayVariable
var_cls = NumPyArrayVariable
tracker = DummyTracker(list(args) + list(kwargs.values()))
else:
var_cls = TensorVariable
@@ -969,8 +969,8 @@ def gen_load_inputs(self, inputs: OrderedSet[GraphNodeVariableType]):
input_var.tracker.gen_instructions(self.pycode_gen)
self.pycode_gen.gen_load_const("int64")
self.pycode_gen.gen_call_function(3)
elif isinstance(input_var, NumpyArrayVariable):
# For NumpyArrayVariable, we use paddle.to_tensor(value) to convert it to a Tensor
elif isinstance(input_var, NumPyArrayVariable):
# For NumPyArrayVariable, we use paddle.to_tensor(value) to convert it to a Tensor
self.pycode_gen.gen_load_object(
paddle.to_tensor,
"___paddle_to_tensor",
@@ -95,7 +95,7 @@
ListVariable,
MethodVariable,
NullVariable,
NumpyArrayVariable,
NumPyArrayVariable,
RangeVariable,
SequenceIterVariable,
SliceVariable,
@@ -270,7 +270,7 @@ def if_break_graph_decorator(normal_jump: Callable):

def inner(self: OpcodeExecutor, instr: Instruction):
result = self.stack.top
if isinstance(result, (TensorVariable, NumpyArrayVariable)):
if isinstance(result, (TensorVariable, NumPyArrayVariable)):
# fallback when in OpcodeExecutor
# raise error in OpcodeInlineExecutor
log(3, "[BreakGraph] break graph for if jump tensor\n")
@@ -69,8 +69,8 @@
IterVariable,
ListVariable,
MapVariable,
NumpyArrayVariable,
NumpyVariable,
NumPyArrayVariable,
NumPyVariable,
RangeVariable,
SliceVariable,
SuperVariable,
@@ -1018,7 +1018,7 @@ def is_not_func(var: VariableBase, other: VariableBase):
for magic_method in magic_method_builtin_dispatch(unary_fn):
Dispatcher.register(
unary_fn,
("ConstantVariable | NumpyNumberVariable",),
("ConstantVariable | NumPyNumberVariable",),
partial(
lambda fn, var: VariableFactory.from_value(
fn(var.get_py_value()),
@@ -1033,8 +1033,8 @@ def is_not_func(var: VariableBase, other: VariableBase):
Dispatcher.register(
binary_fn,
(
"ConstantVariable | NumpyNumberVariable",
"ConstantVariable | NumpyNumberVariable",
"ConstantVariable | NumPyNumberVariable",
"ConstantVariable | NumPyNumberVariable",
),
partial(
lambda fn, var, other: VariableFactory.from_value(
@@ -1097,7 +1097,7 @@ def is_not_func(var: VariableBase, other: VariableBase):
binary_fn,
(
"TensorVariable",
"TensorVariable | SymbolicVariable | ConstantVariable | NumpyNumberVariable",
"TensorVariable | SymbolicVariable | ConstantVariable | NumPyNumberVariable",
),
partial(
lambda magic_name, var, other: var.graph.call_tensor_method(
@@ -1129,7 +1129,7 @@ def tensor_mod_dispatcher(
Dispatcher.register(
binary_fn,
(
"SymbolicVariable | ConstantVariable | NumpyNumberVariable",
"SymbolicVariable | ConstantVariable | NumPyNumberVariable",
"TensorVariable",
),
partial(
@@ -1440,7 +1440,7 @@ def get_math_unary_functions():
for fn in get_math_unary_functions():
Dispatcher.register(
fn,
("ConstantVariable | NumpyNumberVariable",),
("ConstantVariable | NumPyNumberVariable",),
partial(
lambda fn, var: ConstantVariable(
fn(var.get_py_value()),
@@ -1452,7 +1452,7 @@ def get_math_unary_functions():
)
Dispatcher.register(
math.log,
("ConstantVariable | NumpyNumberVariable",),
("ConstantVariable | NumPyNumberVariable",),
lambda var: ConstantVariable(
math.log(var.get_py_value()),
var.graph,
@@ -1461,7 +1461,7 @@ def get_math_unary_functions():
)


# NumpyVariable dispatch
# NumPyVariable dispatch
def constant_numpy_equal(left, right):
return left.graph.call_numpy_api(
NUMPY_API_SUPPORTED_DICT[np.equal], left, right
@@ -1474,13 +1474,13 @@ def constant_numpy_equal(left, right):
for magic_method in magic_method_builtin_dispatch(unary_fn):

@Dispatcher.register_decorator(unary_fn)
def numpy_unary_dispatcher(var: NumpyArrayVariable):
raise FallbackError("Numpy operator need fallback to dygraph")
def numpy_unary_dispatcher(var: NumPyArrayVariable):
raise FallbackError("NumPy operator need fallback to dygraph")


Dispatcher.register(
operator.eq,
("NumpyVariable", "ConstantVariable | NumpyVariable"),
("NumPyVariable", "ConstantVariable | NumPyVariable"),
lambda left, right: constant_numpy_equal(right, left),
)

@@ -1489,19 +1489,19 @@ def numpy_unary_dispatcher(var: NumpyArrayVariable):
for magic_method in magic_method_builtin_dispatch(binary_fn):

@Dispatcher.register_decorator(binary_fn)
def numpy_binary_dispatcher(var: NumpyVariable, other: NumpyVariable):
raise FallbackError("Numpy operator need fallback to dygraph")
def numpy_binary_dispatcher(var: NumPyVariable, other: NumPyVariable):
raise FallbackError("NumPy operator need fallback to dygraph")


Dispatcher.register(
operator.eq,
("ConstantVariable", "NumpyVariable"),
("ConstantVariable", "NumPyVariable"),
lambda left, right: constant_numpy_equal(left, right),
)

Dispatcher.register(
bool,
("NumpyVariable",),
("NumPyVariable",),
lambda x: ConstantVariable(
bool(x.get_py_value()),
x.graph,
@@ -1548,7 +1548,7 @@ def dispatch_all(var: ContainerVariable | IterVariable):

Dispatcher.register(
np.number.item,
("NumpyNumberVariable",),
("NumPyNumberVariable",),
lambda x: ConstantVariable(
x.get_py_value().item(),
x.graph,
@@ -27,9 +27,9 @@
GlobalVariable,
ModuleVariable,
NullVariable,
NumpyArrayVariable,
NumpyNumberVariable,
NumpyVariable,
NumPyArrayVariable,
NumPyNumberVariable,
NumPyVariable,
ObjectVariable,
ParameterVariable,
PlaceVariable,
@@ -46,7 +46,7 @@
FunctionVariable,
LayerVariable,
MethodVariable,
NumpyApiVariable,
NumPyApiVariable,
PaddleApiVariable,
PaddleLayerVariable,
UserCodeVariable,
26 changes: 13 additions & 13 deletions python/paddle/jit/sot/opcode_translator/executor/variables/basic.py
@@ -1689,9 +1689,9 @@ def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
return None


class NumpyVariable(VariableBase):
class NumPyVariable(VariableBase):
"""
NumpyVariable is a subclass of VariableBase used to wrap a Variable of the numpy type.
NumPyVariable is a subclass of VariableBase used to wrap a Variable of the numpy type.

Args:
value: The numpy value to be wrapped.
@@ -1721,7 +1721,7 @@ def format_dtype(dtype: np.dtype):

@staticmethod
def format_number(number: np.number):
return f"{NumpyVariable.format_dtype(number.dtype)}({number.item()})"
return f"{NumPyVariable.format_dtype(number.dtype)}({number.item()})"

@check_faster_guard
def make_faster_guard(self) -> list[paddle.framework.core.GuardNodeBase]:
@@ -1733,7 +1733,7 @@ def make_stringified_guard(self) -> None:
raise NotImplementedError


class NumpyNumberVariable(NumpyVariable):
class NumPyNumberVariable(NumPyVariable):
def _reconstruct(self, codegen: PyCodeGen):
np_type = self.get_py_type()
type_id = f"___np_{np_type.__name__}"
@@ -1773,7 +1773,7 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:
frame_value_tracer = self.tracker.trace_value_from_frame()

dtype_guard = FasterStringifiedExpression(
f"{{}}.dtype == {NumpyVariable.format_dtype(self.get_py_value().dtype)}",
f"{{}}.dtype == {NumPyVariable.format_dtype(self.get_py_value().dtype)}",
paddle.framework.core.NumPyDtypeMatchGuard(
self.get_py_value().dtype
),
@@ -1784,7 +1784,7 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:
return [
dtype_guard,
FasterStringifiedExpression(
f"{{}} == {NumpyVariable.format_number(self.get_py_value())}",
f"{{}} == {NumPyVariable.format_number(self.get_py_value())}",
paddle.framework.core.ValueMatchGuard(self.get_py_value()),
[frame_value_tracer],
union_free_vars(frame_value_tracer.free_vars, {"np": np}),
@@ -1794,19 +1794,19 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:
@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
if isinstance(value, np.number):
return NumpyNumberVariable(value, graph, tracker)
return NumPyNumberVariable(value, graph, tracker)
return None


class NumpyBoolVariable(NumpyNumberVariable):
class NumPyBoolVariable(NumPyNumberVariable):
@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
if isinstance(value, np.bool_):
return NumpyBoolVariable(value, graph, tracker)
return NumPyBoolVariable(value, graph, tracker)
return None


class NumpyArrayVariable(NumpyVariable):
class NumPyArrayVariable(NumPyVariable):
var_name_generator = NameGenerator("np_var_")
value: npt.NDArray[Any]
mutable_attrs: list[str] = ["meta"]
@@ -1844,7 +1844,7 @@ def get_py_type(self):
def get_py_value(self, allow_tensor=False) -> Any:
raise BreakGraphError(
UnsupportedOperationBreak(
reason_str="NumpyArrayVariable doesn't support get_py_value operation."
reason_str="NumPyArrayVariable doesn't support get_py_value operation."
)
)

@@ -1859,7 +1859,7 @@ def get_iter(self):
@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
if isinstance(value, np.ndarray):
return NumpyArrayVariable(value, graph, tracker)
return NumPyArrayVariable(value, graph, tracker)
return None

@property
@@ -1896,7 +1896,7 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:
meta = self.meta

dtype_guard = FasterStringifiedExpression(
f"{{}}.dtype == {NumpyVariable.format_dtype(np.dtype(_PADDLE_PIR_DTYPE_2_NUMPY_DTYPE[self.meta.dtype]))}",
f"{{}}.dtype == {NumPyVariable.format_dtype(np.dtype(_PADDLE_PIR_DTYPE_2_NUMPY_DTYPE[self.meta.dtype]))}",
paddle.framework.core.NumPyDtypeMatchGuard(
np.dtype(_PADDLE_PIR_DTYPE_2_NUMPY_DTYPE[self.meta.dtype])
),
@@ -65,7 +65,7 @@
OtherInlineCallBreak,
PsdbBreakReason,
SotErrorBase,
UnsupportedNumpyAPIBreak,
UnsupportedNumPyAPIBreak,
UnsupportedOperationBreak,
UnsupportedPaddleAPIBreak,
)
@@ -97,7 +97,7 @@
)
from .basic import (
ConstantVariable,
NumpyNumberVariable,
NumPyNumberVariable,
ObjectVariable,
PrintStmtVariable,
SliceVariable,
@@ -371,9 +371,9 @@ def main_info(self) -> dict[str, Any]:
make_faster_guard = object_equal_faster_guard


class NumpyApiVariable(FunctionVariable):
class NumPyApiVariable(FunctionVariable):
"""
NumpyApiVariable is a subclass of FunctionVariable used to wrap a numpy API function.
NumPyApiVariable is a subclass of FunctionVariable used to wrap a numpy API function.

Args:
fn (Callable[..., Any]): The numpy API to be wrapped.
@@ -391,12 +391,12 @@ def __init__(
def call_function(self, /, *args, **kwargs):
# TODO(wangmingkai02): judge whether this is a break api
if all(
isinstance(arg, (ConstantVariable, NumpyNumberVariable))
isinstance(arg, (ConstantVariable, NumPyNumberVariable))
for arg in args
):
if any(
self.value in ufuncs
for ufuncs in NumpyApiVariable._get_numpy_ufuncs()
for ufuncs in NumPyApiVariable._get_numpy_ufuncs()
):
vars = list(args)
var_py_values = [var.get_py_value() for var in vars]
@@ -410,7 +410,7 @@ def call_function(self, /, *args, **kwargs):
NUMPY_API_SUPPORTED_DICT[self.value], *args, **kwargs
)
raise BreakGraphError(
UnsupportedNumpyAPIBreak(fn_name=self.value.__name__)
UnsupportedNumPyAPIBreak(fn_name=self.value.__name__)
)

@classmethod
@@ -429,11 +429,11 @@ def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
value in NUMPY_API_SUPPORTED_DICT
or any(
value in ufuncs
for ufuncs in NumpyApiVariable._get_numpy_ufuncs()
for ufuncs in NumPyApiVariable._get_numpy_ufuncs()
)
)
):
return NumpyApiVariable(value, graph, tracker)
return NumPyApiVariable(value, graph, tracker)
return None

@property
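
As a closing note on the last file above, the hunks suggest a three-way decision in NumPyApiVariable.call_function: constant-fold NumPy ufuncs when every input is a constant or NumPy number, record supported APIs into the graph otherwise, and break the graph for anything else. A minimal standalone sketch of that flow follows; it uses hypothetical stand-ins for the SOT internals (only np.equal is known from this diff to be in NUMPY_API_SUPPORTED_DICT, numpy_ufuncs() approximates _get_numpy_ufuncs(), and graph recording is reduced to a tag), so it is not the real implementation:

    import numpy as np

    # Hypothetical stand-in: the real NUMPY_API_SUPPORTED_DICT maps numpy callables
    # to their SIR-recordable counterparts; only np.equal is known from this diff.
    NUMPY_API_SUPPORTED_DICT = {np.equal: "np.equal"}


    def numpy_ufuncs():
        # Rough stand-in for NumPyApiVariable._get_numpy_ufuncs(): every ufunc
        # exposed in the top-level numpy namespace.
        return [obj for obj in vars(np).values() if isinstance(obj, np.ufunc)]


    def call_numpy_api_sketch(fn, *args, all_inputs_constant=True):
        """Illustrative only: mirrors the three branches visible in call_function."""
        if all_inputs_constant and fn in numpy_ufuncs():
            # 1) All inputs are constants / numpy numbers and fn is a ufunc:
            #    evaluate eagerly and return the concrete result.
            return fn(*args)
        if fn in NUMPY_API_SUPPORTED_DICT:
            # 2) Supported API: the real code records it into the graph via
            #    graph.call_numpy_api; here we just tag the call.
            return ("recorded", NUMPY_API_SUPPORTED_DICT[fn], args)
        # 3) Anything else: break the graph and fall back to dygraph.
        raise RuntimeError(f"break graph: unsupported NumPy API {fn.__name__}")


    print(call_numpy_api_sketch(np.add, 1, 2))                               # -> 3
    print(call_numpy_api_sketch(np.equal, 1, 1, all_inputs_constant=False))  # -> ('recorded', 'np.equal', (1, 1))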