@@ -1202,49 +1202,71 @@ def main_info(self) -> dict[str, Any]:
1202
1202
def get_py_value (self , allow_tensor = False ) -> Any :
1203
1203
return self .value
1204
1204
1205
+ @staticmethod
1206
+ def format_dtype (dtype : np .dtype ):
1207
+ return f"np.{ dtype } "
1208
+
1209
+ @staticmethod
1210
+ def format_number (number : np .number ):
1211
+ return f"{ NumpyVariable .format_dtype (number .dtype )} ({ number .item ()} )"
1212
+
1213
+ def make_stringified_guard (self ) -> None :
1214
+ raise NotImplementedError
1215
+
1216
+ @VariableFactory .register_from_value ()
1217
+ def from_value (value : Any , graph : FunctionGraph , tracker : Tracker ):
1218
+ if isinstance (value , (np .number )):
1219
+ return NumpyNumberVariable (value , graph , tracker )
1220
+ if isinstance (value , (np .ndarray )):
1221
+ return NumpyArrayVariable (value , graph , tracker )
1222
+ return None
1223
+
1224
+
1225
+ class NumpyNumberVariable (NumpyVariable ):
1205
1226
@check_guard
1206
1227
def make_stringified_guard (self ) -> list [StringifiedExpression ]:
1207
1228
frame_value_tracer = self .tracker .trace_value_from_frame ()
1208
1229
obj_free_var_name = f"__{ self .id } "
1209
1230
1210
- def format_dtype (dtype : np .dtype ):
1211
- return f"np.{ dtype } "
1231
+ dtype_guard = StringifiedExpression (
1232
+ f"{{}}.dtype == { NumpyVariable .format_dtype (self .get_py_value ().dtype )} " ,
1233
+ [frame_value_tracer ],
1234
+ union_free_vars (frame_value_tracer .free_vars , {"np" : np }),
1235
+ )
1236
+
1237
+ return [
1238
+ dtype_guard ,
1239
+ StringifiedExpression (
1240
+ f"{{}} == { NumpyVariable .format_number (self .get_py_value ())} " ,
1241
+ [frame_value_tracer ],
1242
+ union_free_vars (frame_value_tracer .free_vars , {"np" : np }),
1243
+ ),
1244
+ ]
1245
+
1212
1246
1213
- def format_number (number : np .number ):
1214
- return f"{ format_dtype (number .dtype )} ({ number .item ()} )"
1247
+ class NumpyArrayVariable (NumpyVariable ):
1248
+ @check_guard
1249
+ def make_stringified_guard (self ) -> list [StringifiedExpression ]:
1250
+ frame_value_tracer = self .tracker .trace_value_from_frame ()
1251
+ obj_free_var_name = f"__{ self .id } "
1215
1252
1216
1253
dtype_guard = StringifiedExpression (
1217
- f"{{}}.dtype == { format_dtype (self .get_py_value ().dtype )} " ,
1254
+ f"{{}}.dtype == { NumpyVariable . format_dtype (self .get_py_value ().dtype )} " ,
1218
1255
[frame_value_tracer ],
1219
1256
union_free_vars (frame_value_tracer .free_vars , {"np" : np }),
1220
1257
)
1221
- if isinstance (self .get_py_value (), np .number ):
1222
- return [
1223
- dtype_guard ,
1224
- StringifiedExpression (
1225
- f"{{}} == { format_number (self .get_py_value ())} " ,
1226
- [frame_value_tracer ],
1227
- union_free_vars (frame_value_tracer .free_vars , {"np" : np }),
1228
- ),
1229
- ]
1230
- else :
1231
- return [
1232
- dtype_guard ,
1233
- StringifiedExpression (
1234
- f"({{}} == { obj_free_var_name } ).all()" ,
1235
- [frame_value_tracer ],
1236
- union_free_vars (
1237
- frame_value_tracer .free_vars ,
1238
- {obj_free_var_name : self .get_py_value ()},
1239
- ),
1240
- ),
1241
- ]
1242
1258
1243
- @VariableFactory .register_from_value ()
1244
- def from_value (value : Any , graph : FunctionGraph , tracker : Tracker ):
1245
- if isinstance (value , (np .ndarray , np .number )):
1246
- return NumpyVariable (value , graph , tracker )
1247
- return None
1259
+ return [
1260
+ dtype_guard ,
1261
+ StringifiedExpression (
1262
+ f"({{}} == { obj_free_var_name } ).all()" ,
1263
+ [frame_value_tracer ],
1264
+ union_free_vars (
1265
+ frame_value_tracer .free_vars ,
1266
+ {obj_free_var_name : self .get_py_value ()},
1267
+ ),
1268
+ ),
1269
+ ]
1248
1270
1249
1271
1250
1272
class NullVariable (VariableBase ):
0 commit comments