Skip to content

Commit 4274259

Browse files
committed
add value limit
1 parent 46fd4b1 commit 4274259

File tree

3 files changed

+13
-11
lines changed

3 files changed

+13
-11
lines changed

python/paddle/jit/sot/opcode_translator/executor/variables/basic.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,7 @@ def get_py_value(self, allow_tensor: bool = False) -> bool | int | float:
837837
self.need_guard_value = True
838838
log(
839839
3,
840-
f"get_py_value from SymbolicVariable {self} caused value need guard",
840+
f"get_py_value from SymbolicVariable {self} caused value need guard\n",
841841
)
842842
if isinstance(self.value, SymbolicValue):
843843
assert isinstance(
@@ -941,6 +941,8 @@ def should_create_symbolic_variable(
941941
tracker: Tracker,
942942
symbolic_inputs: dict[str, dict[int, int] | None],
943943
):
944+
if value < 2:
945+
return False
944946
tracker_expr = tracker.trace_value_from_frame().inlined_expr
945947
symbolic_inputs.setdefault(tracker_expr, {})
946948
if tracker_expr in symbolic_inputs:

test/sot/test_sot_dynamic_shape.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def test_dynamic_int_input_cache_hit_case1(self):
100100
self.assert_results(
101101
dynamic_int_input_func1, paddle.randn([3, 4, 5]), i
102102
)
103-
self.assertEqual(ctx.translate_count, 2)
103+
self.assertEqual(ctx.translate_count, 2 if i == 2 else 3)
104104

105105
def test_dynamic_int_input_cache_hit_case2(self):
106106
with allow_dynamic_shape_guard(
@@ -114,7 +114,7 @@ def test_dynamic_int_input_cache_hit_case2(self):
114114
self.assert_results(
115115
dynamic_int_input_func2, paddle.randn([3, 4, 5]), {1: i}
116116
)
117-
self.assertEqual(ctx.translate_count, 2)
117+
self.assertEqual(ctx.translate_count, 2 if i == 2 else 3)
118118

119119
def test_dynamic_int_input_cache_hit_case3(self):
120120
with allow_dynamic_shape_guard(
@@ -138,7 +138,7 @@ def test_dynamic_shape_input_cache_hit_case1(self):
138138
self.assert_results(
139139
dynamic_shape_input_func1, paddle.randn([i, 4, 5])
140140
)
141-
self.assertEqual(ctx.translate_count, 2)
141+
self.assertEqual(ctx.translate_count, 2 if i == 2 else 3)
142142

143143
def test_dynamic_shape_input_cache_hit_case2(self):
144144
with allow_dynamic_shape_guard(
@@ -153,7 +153,7 @@ def test_dynamic_shape_input_cache_hit_case2(self):
153153
dynamic_shape_access_inner_var_shape,
154154
paddle.randn([i, 4, 5]),
155155
)
156-
self.assertEqual(ctx.translate_count, 2)
156+
self.assertEqual(ctx.translate_count, 2 if i == 2 else 3)
157157

158158
def test_dynamic_shape_cast(self):
159159
with allow_dynamic_shape_guard(
@@ -182,7 +182,7 @@ def test_dynamic_shape_in_list(self):
182182
paddle.randn([i, 4, 5]),
183183
[i * 4, 5],
184184
)
185-
self.assertEqual(ctx.translate_count, 2)
185+
self.assertEqual(ctx.translate_count, 2 if i == 2 else 3)
186186

187187
def test_conv_dynamic_shape_stride_fallback(self):
188188
with allow_dynamic_shape_guard(

test/sot/test_trace_list_arg.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,14 @@ def test_bar_dynamic_shape(self):
7070
self.assertEqual(cache.translate_count, 1)
7171
self.assert_results(bar, a, 2, 0) # Cache miss
7272
self.assertEqual(cache.translate_count, 2)
73-
self.assert_results(bar, b, 2, 0) # Cache miss
74-
self.assertEqual(cache.translate_count, 3)
7573
self.assert_results(bar, b, 2, 0) # Cache hit
76-
self.assertEqual(cache.translate_count, 3)
74+
self.assertEqual(cache.translate_count, 2)
75+
self.assert_results(bar, b, 2, 0) # Cache hit
76+
self.assertEqual(cache.translate_count, 2)
7777
self.assert_results(bar, b, 1, 1) # Cache hit
78-
self.assertEqual(cache.translate_count, 3)
78+
self.assertEqual(cache.translate_count, 2)
7979
self.assert_results(bar, b, 0, 2) # Cache miss
80-
self.assertEqual(cache.translate_count, 4)
80+
self.assertEqual(cache.translate_count, 3)
8181

8282

8383
if __name__ == "__main__":

0 commit comments

Comments
 (0)