Skip to content

Commit 399da2a

Browse files
authored
[SOT] Split NumpyVariable for separate array/number handling (#71216)
1 parent f0deb95 commit 399da2a

File tree

3 files changed

+57
-33
lines changed

3 files changed

+57
-33
lines changed

python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1064,7 +1064,7 @@ def is_not_func(var: VariableBase, other: VariableBase):
10641064
binary_fn,
10651065
(
10661066
"TensorVariable",
1067-
"TensorVariable | SymbolicVariable | ConstantVariable | NumpyVariable",
1067+
"TensorVariable | SymbolicVariable | ConstantVariable | NumpyNumberVariable",
10681068
),
10691069
partial(
10701070
lambda magic_name, var, other: var.graph.call_tensor_method(
@@ -1092,7 +1092,7 @@ def tensor_mod_dispatcher(
10921092
Dispatcher.register(
10931093
binary_fn,
10941094
(
1095-
"SymbolicVariable | ConstantVariable | NumpyVariable",
1095+
"SymbolicVariable | ConstantVariable | NumpyNumberVariable",
10961096
"TensorVariable",
10971097
),
10981098
partial(

python/paddle/jit/sot/opcode_translator/executor/variables/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
GlobalVariable,
2828
ModuleVariable,
2929
NullVariable,
30+
NumpyArrayVariable,
31+
NumpyNumberVariable,
3032
NumpyVariable,
3133
ObjectVariable,
3234
ParameterVariable,

python/paddle/jit/sot/opcode_translator/executor/variables/basic.py

+53-31
Original file line numberDiff line numberDiff line change
@@ -1202,49 +1202,71 @@ def main_info(self) -> dict[str, Any]:
12021202
def get_py_value(self, allow_tensor=False) -> Any:
12031203
return self.value
12041204

1205+
@staticmethod
1206+
def format_dtype(dtype: np.dtype):
1207+
return f"np.{dtype}"
1208+
1209+
@staticmethod
1210+
def format_number(number: np.number):
1211+
return f"{NumpyVariable.format_dtype(number.dtype)}({number.item()})"
1212+
1213+
def make_stringified_guard(self) -> None:
1214+
raise NotImplementedError
1215+
1216+
@VariableFactory.register_from_value()
1217+
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
1218+
if isinstance(value, (np.number)):
1219+
return NumpyNumberVariable(value, graph, tracker)
1220+
if isinstance(value, (np.ndarray)):
1221+
return NumpyArrayVariable(value, graph, tracker)
1222+
return None
1223+
1224+
1225+
class NumpyNumberVariable(NumpyVariable):
12051226
@check_guard
12061227
def make_stringified_guard(self) -> list[StringifiedExpression]:
12071228
frame_value_tracer = self.tracker.trace_value_from_frame()
12081229
obj_free_var_name = f"__{self.id}"
12091230

1210-
def format_dtype(dtype: np.dtype):
1211-
return f"np.{dtype}"
1231+
dtype_guard = StringifiedExpression(
1232+
f"{{}}.dtype == {NumpyVariable.format_dtype(self.get_py_value().dtype)}",
1233+
[frame_value_tracer],
1234+
union_free_vars(frame_value_tracer.free_vars, {"np": np}),
1235+
)
1236+
1237+
return [
1238+
dtype_guard,
1239+
StringifiedExpression(
1240+
f"{{}} == {NumpyVariable.format_number(self.get_py_value())}",
1241+
[frame_value_tracer],
1242+
union_free_vars(frame_value_tracer.free_vars, {"np": np}),
1243+
),
1244+
]
1245+
12121246

1213-
def format_number(number: np.number):
1214-
return f"{format_dtype(number.dtype)}({number.item()})"
1247+
class NumpyArrayVariable(NumpyVariable):
1248+
@check_guard
1249+
def make_stringified_guard(self) -> list[StringifiedExpression]:
1250+
frame_value_tracer = self.tracker.trace_value_from_frame()
1251+
obj_free_var_name = f"__{self.id}"
12151252

12161253
dtype_guard = StringifiedExpression(
1217-
f"{{}}.dtype == {format_dtype(self.get_py_value().dtype)}",
1254+
f"{{}}.dtype == {NumpyVariable.format_dtype(self.get_py_value().dtype)}",
12181255
[frame_value_tracer],
12191256
union_free_vars(frame_value_tracer.free_vars, {"np": np}),
12201257
)
1221-
if isinstance(self.get_py_value(), np.number):
1222-
return [
1223-
dtype_guard,
1224-
StringifiedExpression(
1225-
f"{{}} == {format_number(self.get_py_value())}",
1226-
[frame_value_tracer],
1227-
union_free_vars(frame_value_tracer.free_vars, {"np": np}),
1228-
),
1229-
]
1230-
else:
1231-
return [
1232-
dtype_guard,
1233-
StringifiedExpression(
1234-
f"({{}} == {obj_free_var_name}).all()",
1235-
[frame_value_tracer],
1236-
union_free_vars(
1237-
frame_value_tracer.free_vars,
1238-
{obj_free_var_name: self.get_py_value()},
1239-
),
1240-
),
1241-
]
12421258

1243-
@VariableFactory.register_from_value()
1244-
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
1245-
if isinstance(value, (np.ndarray, np.number)):
1246-
return NumpyVariable(value, graph, tracker)
1247-
return None
1259+
return [
1260+
dtype_guard,
1261+
StringifiedExpression(
1262+
f"({{}} == {obj_free_var_name}).all()",
1263+
[frame_value_tracer],
1264+
union_free_vars(
1265+
frame_value_tracer.free_vars,
1266+
{obj_free_var_name: self.get_py_value()},
1267+
),
1268+
),
1269+
]
12481270

12491271

12501272
class NullVariable(VariableBase):

0 commit comments

Comments
 (0)