Skip to content

Commit 9e1285c

Browse files
committed
🐛 fix: treat as guard miss when strict_guard is off and guard_tree misses cache_index
1 parent 494b078 commit 9e1285c

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

python/paddle/jit/sot/opcode_translator/executor/executor_cache.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,20 @@ def lookup(
152152
if enable_strict_guard or enable_guard_tree:
153153
cache_index = guard_tree.lookup(frame)
154154

155-
if not enable_strict_guard and cache_index is not None:
156-
# TODO(zrr1999): add a mapping between custom_code and cache_index
157-
return guarded_fns[cache_index][0]
155+
if not enable_strict_guard:
156+
if cache_index is not None:
157+
# TODO(zrr1999): add a mapping between custom_code and cache_index
158+
return guarded_fns[cache_index][0]
159+
else:
160+
log(2, "[Cache]: all guards missed (guard tree mode)\n")
161+
new_custom_code, guard_fn, guard_chain = self.translate(
162+
frame, **kwargs
163+
)
164+
if guard_fn is not None:
165+
assert guard_chain is not None
166+
guarded_fns.append((new_custom_code, guard_fn))
167+
guard_tree.add_guard_chain(guard_chain)
168+
return new_custom_code
158169

159170
for index, (custom_code, guard_fn) in enumerate(guarded_fns):
160171
if enable_strict_guard:

test/sot/test_instruction_translator_cache.py

+2
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ def __init__(self, recompile):
9090
self.func = lambda frame: recompile
9191
self.mirror_guard = self.func
9292
self.expr = f"lambda frame: {recompile}"
93+
self.inlined_expr = f"lambda frame: {recompile}"
94+
self.__globals__ = {}
9395

9496
def __call__(self, *args, **kwargs):
9597
return self.func(*args, **kwargs)

0 commit comments

Comments
 (0)