Skip to content

Commit d5de88a

Browse files
authored
[SOT][CodeStyle] Replace Numpy with NumPy (#72580)
1 parent da42c1c commit d5de88a

File tree

11 files changed

+59
-59
lines changed

11 files changed

+59
-59
lines changed

python/paddle/jit/sot/opcode_translator/executor/function_graph.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
GlobalVariable,
8989
ListVariable,
9090
NullVariable,
91-
NumpyArrayVariable,
91+
NumPyArrayVariable,
9292
PaddleLayerVariable,
9393
ParameterVariable,
9494
SymbolicVariable,
@@ -103,7 +103,7 @@
103103
import types
104104

105105
GraphNodeVariableType: TypeAlias = Union[
106-
TensorVariable, SymbolicVariable, NumpyArrayVariable
106+
TensorVariable, SymbolicVariable, NumPyArrayVariable
107107
]
108108

109109

@@ -118,7 +118,7 @@
118118
GraphNodeVariableClasses = (
119119
TensorVariable,
120120
SymbolicVariable,
121-
NumpyArrayVariable,
121+
NumPyArrayVariable,
122122
)
123123

124124

@@ -571,7 +571,7 @@ def call_numpy_api(
571571
**kwargs: VariableBase,
572572
):
573573
"""
574-
Record Numpy API to SIR
574+
Record NumPy API to SIR
575575
576576
Args:
577577
func: numpy api
@@ -814,7 +814,7 @@ def try_infer_meta_fn(args, kwargs) -> Any:
814814
list(args) + list(kwargs.values()), func
815815
)
816816
elif api_type == APIType.NUMPY:
817-
var_cls = NumpyArrayVariable
817+
var_cls = NumPyArrayVariable
818818
tracker = DummyTracker(list(args) + list(kwargs.values()))
819819
else:
820820
var_cls = TensorVariable
@@ -969,8 +969,8 @@ def gen_load_inputs(self, inputs: OrderedSet[GraphNodeVariableType]):
969969
input_var.tracker.gen_instructions(self.pycode_gen)
970970
self.pycode_gen.gen_load_const("int64")
971971
self.pycode_gen.gen_call_function(3)
972-
elif isinstance(input_var, NumpyArrayVariable):
973-
# For NumpyArrayVariable, we use paddle.to_tensor(value) to convert it to a Tensor
972+
elif isinstance(input_var, NumPyArrayVariable):
973+
# For NumPyArrayVariable, we use paddle.to_tensor(value) to convert it to a Tensor
974974
self.pycode_gen.gen_load_object(
975975
paddle.to_tensor,
976976
"___paddle_to_tensor",

python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@
9595
ListVariable,
9696
MethodVariable,
9797
NullVariable,
98-
NumpyArrayVariable,
98+
NumPyArrayVariable,
9999
RangeVariable,
100100
SequenceIterVariable,
101101
SliceVariable,
@@ -270,7 +270,7 @@ def if_break_graph_decorator(normal_jump: Callable):
270270

271271
def inner(self: OpcodeExecutor, instr: Instruction):
272272
result = self.stack.top
273-
if isinstance(result, (TensorVariable, NumpyArrayVariable)):
273+
if isinstance(result, (TensorVariable, NumPyArrayVariable)):
274274
# fallback when in OpcodeExecutor
275275
# raise error in OpcodeInlineExecutor
276276
log(3, "[BreakGraph] break graph for if jump tensor\n")

python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@
6969
IterVariable,
7070
ListVariable,
7171
MapVariable,
72-
NumpyArrayVariable,
73-
NumpyVariable,
72+
NumPyArrayVariable,
73+
NumPyVariable,
7474
RangeVariable,
7575
SliceVariable,
7676
SuperVariable,
@@ -1018,7 +1018,7 @@ def is_not_func(var: VariableBase, other: VariableBase):
10181018
for magic_method in magic_method_builtin_dispatch(unary_fn):
10191019
Dispatcher.register(
10201020
unary_fn,
1021-
("ConstantVariable | NumpyNumberVariable",),
1021+
("ConstantVariable | NumPyNumberVariable",),
10221022
partial(
10231023
lambda fn, var: VariableFactory.from_value(
10241024
fn(var.get_py_value()),
@@ -1033,8 +1033,8 @@ def is_not_func(var: VariableBase, other: VariableBase):
10331033
Dispatcher.register(
10341034
binary_fn,
10351035
(
1036-
"ConstantVariable | NumpyNumberVariable",
1037-
"ConstantVariable | NumpyNumberVariable",
1036+
"ConstantVariable | NumPyNumberVariable",
1037+
"ConstantVariable | NumPyNumberVariable",
10381038
),
10391039
partial(
10401040
lambda fn, var, other: VariableFactory.from_value(
@@ -1097,7 +1097,7 @@ def is_not_func(var: VariableBase, other: VariableBase):
10971097
binary_fn,
10981098
(
10991099
"TensorVariable",
1100-
"TensorVariable | SymbolicVariable | ConstantVariable | NumpyNumberVariable",
1100+
"TensorVariable | SymbolicVariable | ConstantVariable | NumPyNumberVariable",
11011101
),
11021102
partial(
11031103
lambda magic_name, var, other: var.graph.call_tensor_method(
@@ -1129,7 +1129,7 @@ def tensor_mod_dispatcher(
11291129
Dispatcher.register(
11301130
binary_fn,
11311131
(
1132-
"SymbolicVariable | ConstantVariable | NumpyNumberVariable",
1132+
"SymbolicVariable | ConstantVariable | NumPyNumberVariable",
11331133
"TensorVariable",
11341134
),
11351135
partial(
@@ -1440,7 +1440,7 @@ def get_math_unary_functions():
14401440
for fn in get_math_unary_functions():
14411441
Dispatcher.register(
14421442
fn,
1443-
("ConstantVariable | NumpyNumberVariable",),
1443+
("ConstantVariable | NumPyNumberVariable",),
14441444
partial(
14451445
lambda fn, var: ConstantVariable(
14461446
fn(var.get_py_value()),
@@ -1452,7 +1452,7 @@ def get_math_unary_functions():
14521452
)
14531453
Dispatcher.register(
14541454
math.log,
1455-
("ConstantVariable | NumpyNumberVariable",),
1455+
("ConstantVariable | NumPyNumberVariable",),
14561456
lambda var: ConstantVariable(
14571457
math.log(var.get_py_value()),
14581458
var.graph,
@@ -1461,7 +1461,7 @@ def get_math_unary_functions():
14611461
)
14621462

14631463

1464-
# NumpyVariable dispatch
1464+
# NumPyVariable dispatch
14651465
def constant_numpy_equal(left, right):
14661466
return left.graph.call_numpy_api(
14671467
NUMPY_API_SUPPORTED_DICT[np.equal], left, right
@@ -1474,13 +1474,13 @@ def constant_numpy_equal(left, right):
14741474
for magic_method in magic_method_builtin_dispatch(unary_fn):
14751475

14761476
@Dispatcher.register_decorator(unary_fn)
1477-
def numpy_unary_dispatcher(var: NumpyArrayVariable):
1478-
raise FallbackError("Numpy operator need fallback to dygraph")
1477+
def numpy_unary_dispatcher(var: NumPyArrayVariable):
1478+
raise FallbackError("NumPy operator need fallback to dygraph")
14791479

14801480

14811481
Dispatcher.register(
14821482
operator.eq,
1483-
("NumpyVariable", "ConstantVariable | NumpyVariable"),
1483+
("NumPyVariable", "ConstantVariable | NumPyVariable"),
14841484
lambda left, right: constant_numpy_equal(right, left),
14851485
)
14861486

@@ -1489,19 +1489,19 @@ def numpy_unary_dispatcher(var: NumpyArrayVariable):
14891489
for magic_method in magic_method_builtin_dispatch(binary_fn):
14901490

14911491
@Dispatcher.register_decorator(binary_fn)
1492-
def numpy_binary_dispatcher(var: NumpyVariable, other: NumpyVariable):
1493-
raise FallbackError("Numpy operator need fallback to dygraph")
1492+
def numpy_binary_dispatcher(var: NumPyVariable, other: NumPyVariable):
1493+
raise FallbackError("NumPy operator need fallback to dygraph")
14941494

14951495

14961496
Dispatcher.register(
14971497
operator.eq,
1498-
("ConstantVariable", "NumpyVariable"),
1498+
("ConstantVariable", "NumPyVariable"),
14991499
lambda left, right: constant_numpy_equal(left, right),
15001500
)
15011501

15021502
Dispatcher.register(
15031503
bool,
1504-
("NumpyVariable",),
1504+
("NumPyVariable",),
15051505
lambda x: ConstantVariable(
15061506
bool(x.get_py_value()),
15071507
x.graph,
@@ -1548,7 +1548,7 @@ def dispatch_all(var: ContainerVariable | IterVariable):
15481548

15491549
Dispatcher.register(
15501550
np.number.item,
1551-
("NumpyNumberVariable",),
1551+
("NumPyNumberVariable",),
15521552
lambda x: ConstantVariable(
15531553
x.get_py_value().item(),
15541554
x.graph,

python/paddle/jit/sot/opcode_translator/executor/variables/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
GlobalVariable,
2828
ModuleVariable,
2929
NullVariable,
30-
NumpyArrayVariable,
31-
NumpyNumberVariable,
32-
NumpyVariable,
30+
NumPyArrayVariable,
31+
NumPyNumberVariable,
32+
NumPyVariable,
3333
ObjectVariable,
3434
ParameterVariable,
3535
PlaceVariable,
@@ -46,7 +46,7 @@
4646
FunctionVariable,
4747
LayerVariable,
4848
MethodVariable,
49-
NumpyApiVariable,
49+
NumPyApiVariable,
5050
PaddleApiVariable,
5151
PaddleLayerVariable,
5252
UserCodeVariable,

python/paddle/jit/sot/opcode_translator/executor/variables/basic.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,9 +1689,9 @@ def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
16891689
return None
16901690

16911691

1692-
class NumpyVariable(VariableBase):
1692+
class NumPyVariable(VariableBase):
16931693
"""
1694-
NumpyVariable is a subclass of VariableBase used to wrap a Variable of the numpy type.
1694+
NumPyVariable is a subclass of VariableBase used to wrap a Variable of the numpy type.
16951695
16961696
Args:
16971697
value: The numpy value to be wrapped.
@@ -1721,7 +1721,7 @@ def format_dtype(dtype: np.dtype):
17211721

17221722
@staticmethod
17231723
def format_number(number: np.number):
1724-
return f"{NumpyVariable.format_dtype(number.dtype)}({number.item()})"
1724+
return f"{NumPyVariable.format_dtype(number.dtype)}({number.item()})"
17251725

17261726
@check_faster_guard
17271727
def make_faster_guard(self) -> list[paddle.framework.core.GuardNodeBase]:
@@ -1733,7 +1733,7 @@ def make_stringified_guard(self) -> None:
17331733
raise NotImplementedError
17341734

17351735

1736-
class NumpyNumberVariable(NumpyVariable):
1736+
class NumPyNumberVariable(NumPyVariable):
17371737
def _reconstruct(self, codegen: PyCodeGen):
17381738
np_type = self.get_py_type()
17391739
type_id = f"___np_{np_type.__name__}"
@@ -1773,7 +1773,7 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:
17731773
frame_value_tracer = self.tracker.trace_value_from_frame()
17741774

17751775
dtype_guard = FasterStringifiedExpression(
1776-
f"{{}}.dtype == {NumpyVariable.format_dtype(self.get_py_value().dtype)}",
1776+
f"{{}}.dtype == {NumPyVariable.format_dtype(self.get_py_value().dtype)}",
17771777
paddle.framework.core.NumPyDtypeMatchGuard(
17781778
self.get_py_value().dtype
17791779
),
@@ -1784,7 +1784,7 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:
17841784
return [
17851785
dtype_guard,
17861786
FasterStringifiedExpression(
1787-
f"{{}} == {NumpyVariable.format_number(self.get_py_value())}",
1787+
f"{{}} == {NumPyVariable.format_number(self.get_py_value())}",
17881788
paddle.framework.core.ValueMatchGuard(self.get_py_value()),
17891789
[frame_value_tracer],
17901790
union_free_vars(frame_value_tracer.free_vars, {"np": np}),
@@ -1794,19 +1794,19 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:
17941794
@VariableFactory.register_from_value()
17951795
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
17961796
if isinstance(value, np.number):
1797-
return NumpyNumberVariable(value, graph, tracker)
1797+
return NumPyNumberVariable(value, graph, tracker)
17981798
return None
17991799

18001800

1801-
class NumpyBoolVariable(NumpyNumberVariable):
1801+
class NumPyBoolVariable(NumPyNumberVariable):
18021802
@VariableFactory.register_from_value()
18031803
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
18041804
if isinstance(value, np.bool_):
1805-
return NumpyBoolVariable(value, graph, tracker)
1805+
return NumPyBoolVariable(value, graph, tracker)
18061806
return None
18071807

18081808

1809-
class NumpyArrayVariable(NumpyVariable):
1809+
class NumPyArrayVariable(NumPyVariable):
18101810
var_name_generator = NameGenerator("np_var_")
18111811
value: npt.NDArray[Any]
18121812
mutable_attrs: list[str] = ["meta"]
@@ -1844,7 +1844,7 @@ def get_py_type(self):
18441844
def get_py_value(self, allow_tensor=False) -> Any:
18451845
raise BreakGraphError(
18461846
UnsupportedOperationBreak(
1847-
reason_str="NumpyArrayVariable doesn't support get_py_value operation."
1847+
reason_str="NumPyArrayVariable doesn't support get_py_value operation."
18481848
)
18491849
)
18501850

@@ -1859,7 +1859,7 @@ def get_iter(self):
18591859
@VariableFactory.register_from_value()
18601860
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
18611861
if isinstance(value, np.ndarray):
1862-
return NumpyArrayVariable(value, graph, tracker)
1862+
return NumPyArrayVariable(value, graph, tracker)
18631863
return None
18641864

18651865
@property
@@ -1896,7 +1896,7 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:
18961896
meta = self.meta
18971897

18981898
dtype_guard = FasterStringifiedExpression(
1899-
f"{{}}.dtype == {NumpyVariable.format_dtype(np.dtype(_PADDLE_PIR_DTYPE_2_NUMPY_DTYPE[self.meta.dtype]))}",
1899+
f"{{}}.dtype == {NumPyVariable.format_dtype(np.dtype(_PADDLE_PIR_DTYPE_2_NUMPY_DTYPE[self.meta.dtype]))}",
19001900
paddle.framework.core.NumPyDtypeMatchGuard(
19011901
np.dtype(_PADDLE_PIR_DTYPE_2_NUMPY_DTYPE[self.meta.dtype])
19021902
),

python/paddle/jit/sot/opcode_translator/executor/variables/callable.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
OtherInlineCallBreak,
6666
PsdbBreakReason,
6767
SotErrorBase,
68-
UnsupportedNumpyAPIBreak,
68+
UnsupportedNumPyAPIBreak,
6969
UnsupportedOperationBreak,
7070
UnsupportedPaddleAPIBreak,
7171
)
@@ -97,7 +97,7 @@
9797
)
9898
from .basic import (
9999
ConstantVariable,
100-
NumpyNumberVariable,
100+
NumPyNumberVariable,
101101
ObjectVariable,
102102
PrintStmtVariable,
103103
SliceVariable,
@@ -371,9 +371,9 @@ def main_info(self) -> dict[str, Any]:
371371
make_faster_guard = object_equal_faster_guard
372372

373373

374-
class NumpyApiVariable(FunctionVariable):
374+
class NumPyApiVariable(FunctionVariable):
375375
"""
376-
NumpyApiVariable is a subclass of FunctionVariable used to wrap a numpy API function.
376+
NumPyApiVariable is a subclass of FunctionVariable used to wrap a numpy API function.
377377
378378
Args:
379379
fn (Callable[..., Any]): The numpy API to be wrapped.
@@ -391,12 +391,12 @@ def __init__(
391391
def call_function(self, /, *args, **kwargs):
392392
# TODO(wangmingkai02): judge whether this is a break api
393393
if all(
394-
isinstance(arg, (ConstantVariable, NumpyNumberVariable))
394+
isinstance(arg, (ConstantVariable, NumPyNumberVariable))
395395
for arg in args
396396
):
397397
if any(
398398
self.value in ufuncs
399-
for ufuncs in NumpyApiVariable._get_numpy_ufuncs()
399+
for ufuncs in NumPyApiVariable._get_numpy_ufuncs()
400400
):
401401
vars = list(args)
402402
var_py_values = [var.get_py_value() for var in vars]
@@ -410,7 +410,7 @@ def call_function(self, /, *args, **kwargs):
410410
NUMPY_API_SUPPORTED_DICT[self.value], *args, **kwargs
411411
)
412412
raise BreakGraphError(
413-
UnsupportedNumpyAPIBreak(fn_name=self.value.__name__)
413+
UnsupportedNumPyAPIBreak(fn_name=self.value.__name__)
414414
)
415415

416416
@classmethod
@@ -429,11 +429,11 @@ def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
429429
value in NUMPY_API_SUPPORTED_DICT
430430
or any(
431431
value in ufuncs
432-
for ufuncs in NumpyApiVariable._get_numpy_ufuncs()
432+
for ufuncs in NumPyApiVariable._get_numpy_ufuncs()
433433
)
434434
)
435435
):
436-
return NumpyApiVariable(value, graph, tracker)
436+
return NumPyApiVariable(value, graph, tracker)
437437
return None
438438

439439
@property

0 commit comments

Comments
 (0)