Skip to content

Commit fc2efd4

Browse files
committed
impl TensorDtypeVariable.make_faster_guard
1 parent 7c9ba5f commit fc2efd4

File tree

1 file changed

+13
-1
lines changed
  • python/paddle/jit/sot/opcode_translator/executor/variables

1 file changed

+13
-1
lines changed

python/paddle/jit/sot/opcode_translator/executor/variables/basic.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,19 @@ def __init__(self, value, graph, tracker):
335335

336336
@check_faster_guard
337337
def make_faster_guard(self) -> list[paddle.framework.core.GuardNodeBase]:
338-
raise NotImplementedError
338+
if isinstance(self.tracker, GetAttrTracker) and isinstance(
339+
self.tracker.obj, TensorVariable
340+
):
341+
expr_node = self.tracker.obj.tracker.guard_tree_expr_node()
342+
assert paddle.framework.use_pir_api(), "Only support PIR"
343+
return [
344+
paddle.framework.core.GuardNode(
345+
paddle.framework.core.DtypeMatchGuard(self.value),
346+
[expr_node],
347+
)
348+
]
349+
else:
350+
return object_equal_faster_guard(self)
339351

340352
@check_guard
341353
def make_stringified_guard(self) -> list[StringifiedExpression]:

0 commit comments

Comments
 (0)