Skip to content

Commit a1dad92

Browse files
committed
fix DummyTracker && symbolic_translate -> self.assert_results
1 parent 7cc9c5c commit a1dad92

File tree

2 files changed

+54
-45
lines changed

2 files changed

+54
-45
lines changed

python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def inner(*args, **kwargs):
245245
cls=cls,
246246
obj=obj,
247247
graph=Dispatcher.graph,
248-
tracker=DummyTracker([]),
248+
tracker=DummyTracker([cls, obj]),
249249
),
250250
)
251251

test/sot/test_23_super.py

+53-44
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
)
2121

2222
import paddle
23-
from paddle.jit.sot import symbolic_translate
2423
from paddle.jit.sot.psdb import check_no_breakgraph
2524

2625

@@ -78,27 +77,35 @@ def test_super_self_name(self):
7877
def test_guard_run(self): # test guard
7978
with test_instruction_translator_cache_context() as ctx:
8079
self.assertEqual(ctx.translate_count, 0)
81-
symbolic_translate(B().test_super_no_args_add2)(paddle.to_tensor(1))
82-
symbolic_translate(B().test_super_no_args_add2)(paddle.to_tensor(2))
80+
self.assert_results(
81+
B().test_super_no_args_add2, paddle.to_tensor(1)
82+
)
83+
self.assert_results(
84+
B().test_super_no_args_add2, paddle.to_tensor(2)
85+
)
8386
self.assertEqual(ctx.translate_count, 1)
8487

8588
with test_instruction_translator_cache_context() as ctx:
8689
self.assertEqual(ctx.translate_count, 0)
87-
symbolic_translate(B().test_super_no_args_add2)(paddle.to_tensor(1))
88-
symbolic_translate(B().test_super_no_args_add2)(paddle.to_tensor(2))
90+
self.assert_results(
91+
B().test_super_no_args_add2, paddle.to_tensor(1)
92+
)
93+
self.assert_results(
94+
B().test_super_no_args_add2, paddle.to_tensor(2)
95+
)
8996
self.assertEqual(ctx.translate_count, 1)
90-
symbolic_translate(B().test_super_with_args_add3)(
91-
paddle.to_tensor(3)
97+
self.assert_results(
98+
B().test_super_with_args_add3, paddle.to_tensor(3)
9299
)
93-
symbolic_translate(B().test_super_with_args_add3)(
94-
paddle.to_tensor(4)
100+
self.assert_results(
101+
B().test_super_with_args_add3, paddle.to_tensor(4)
95102
)
96103
self.assertEqual(ctx.translate_count, 2)
97104

98105
with test_instruction_translator_cache_context() as ctx:
99106
self.assertEqual(ctx.translate_count, 0)
100-
symbolic_translate(B().test_super_both_add5)(paddle.to_tensor(5))
101-
symbolic_translate(B().test_super_both_add5)(paddle.to_tensor(6))
107+
self.assert_results(B().test_super_both_add5, paddle.to_tensor(5))
108+
self.assert_results(B().test_super_both_add5, paddle.to_tensor(6))
102109
self.assertEqual(ctx.translate_count, 1)
103110

104111

@@ -159,27 +166,27 @@ def test_with_args(self):
159166
def test_guard_run(self): # test guard
160167
with test_instruction_translator_cache_context() as ctx:
161168
self.assertEqual(ctx.translate_count, 0)
162-
symbolic_translate(Q().addx)(paddle.to_tensor(1))
163-
symbolic_translate(Q().addx)(paddle.to_tensor(2))
164-
symbolic_translate(Q().addx)(paddle.to_tensor(3))
169+
self.assert_results(Q().addx, paddle.to_tensor(1))
170+
self.assert_results(Q().addx, paddle.to_tensor(2))
171+
self.assert_results(Q().addx, paddle.to_tensor(3))
165172
self.assertEqual(ctx.translate_count, 1)
166173

167174
with test_instruction_translator_cache_context() as ctx:
168175
self.assertEqual(ctx.translate_count, 0)
169-
symbolic_translate(Q().addxP)(paddle.to_tensor(4))
170-
symbolic_translate(Q().addxP)(paddle.to_tensor(5))
171-
symbolic_translate(Q().addxP)(paddle.to_tensor(6))
176+
self.assert_results(Q().addxP, paddle.to_tensor(4))
177+
self.assert_results(Q().addxP, paddle.to_tensor(5))
178+
self.assert_results(Q().addxP, paddle.to_tensor(6))
172179
self.assertEqual(ctx.translate_count, 1)
173180

174181
with test_instruction_translator_cache_context() as ctx:
175182
self.assertEqual(ctx.translate_count, 0)
176-
symbolic_translate(Q().addxZ)(paddle.to_tensor(7))
177-
symbolic_translate(Q().addxZ)(paddle.to_tensor(8))
178-
symbolic_translate(Q().addxZ)(paddle.to_tensor(9))
183+
self.assert_results(Q().addxZ, paddle.to_tensor(7))
184+
self.assert_results(Q().addxZ, paddle.to_tensor(8))
185+
self.assert_results(Q().addxZ, paddle.to_tensor(9))
179186
self.assertEqual(ctx.translate_count, 1)
180-
symbolic_translate(Q().addxP)(paddle.to_tensor(4))
181-
symbolic_translate(Q().addxP)(paddle.to_tensor(5))
182-
symbolic_translate(Q().addxP)(paddle.to_tensor(6))
187+
self.assert_results(Q().addxP, paddle.to_tensor(4))
188+
self.assert_results(Q().addxP, paddle.to_tensor(5))
189+
self.assert_results(Q().addxP, paddle.to_tensor(6))
183190
self.assertEqual(ctx.translate_count, 2)
184191

185192

@@ -202,9 +209,9 @@ def test_super_as_input(self):
202209
def test_guard_run(self): # test guard
203210
with test_instruction_translator_cache_context() as ctx:
204211
self.assertEqual(ctx.translate_count, 0)
205-
symbolic_translate(super_as_input)(super())
206-
symbolic_translate(super_as_input)(super())
207-
symbolic_translate(super_as_input)(super())
212+
self.assert_results(super_as_input, super())
213+
self.assert_results(super_as_input, super())
214+
self.assert_results(super_as_input, super())
208215
self.assertEqual(ctx.translate_count, 1)
209216

210217

@@ -244,19 +251,19 @@ def test_guard_run(self): # test guard
244251
with test_instruction_translator_cache_context() as ctx:
245252
self.assertEqual(ctx.translate_count, 0)
246253
x = paddle.to_tensor(3)
247-
symbolic_translate(ClassSuperAsInputC().test_super_as_input)(
248-
x, ClassSuperAsInputC
254+
self.assert_results(
255+
ClassSuperAsInputC().test_super_as_input, x, ClassSuperAsInputC
249256
)
250-
symbolic_translate(ClassSuperAsInputC().test_super_as_input)(
251-
x, ClassSuperAsInputC
257+
self.assert_results(
258+
ClassSuperAsInputC().test_super_as_input, x, ClassSuperAsInputC
252259
)
253260
self.assertEqual(ctx.translate_count, 1)
254261
x = paddle.to_tensor(4)
255-
symbolic_translate(ClassSuperAsInputC().test_super_as_input)(
256-
x, ClassSuperAsInputB
262+
self.assert_results(
263+
ClassSuperAsInputC().test_super_as_input, x, ClassSuperAsInputB
257264
)
258-
symbolic_translate(ClassSuperAsInputC().test_super_as_input)(
259-
x, ClassSuperAsInputB
265+
self.assert_results(
266+
ClassSuperAsInputC().test_super_as_input, x, ClassSuperAsInputB
260267
)
261268
self.assertEqual(ctx.translate_count, 2)
262269

@@ -289,9 +296,9 @@ def test_guard_run(self): # test guard
289296
x = paddle.to_tensor([4.0])
290297
with test_instruction_translator_cache_context() as ctx:
291298
self.assertEqual(ctx.translate_count, 0)
292-
symbolic_translate(ClassWithAttributionC().foo)(x)
293-
symbolic_translate(ClassWithAttributionC().foo)(x)
294-
symbolic_translate(ClassWithAttributionC().foo)(x)
299+
self.assert_results(ClassWithAttributionC().foo, x)
300+
self.assert_results(ClassWithAttributionC().foo, x)
301+
self.assert_results(ClassWithAttributionC().foo, x)
295302
self.assertEqual(ctx.translate_count, 1)
296303

297304

@@ -318,21 +325,23 @@ def super_function_as_input(self, fn, x):
318325

319326

320327
# We create a fake `super` and inject it to `__globals__` of the function
321-
new_globals = FakeSuperClass.fake_super.__globals__.copy()
328+
new_globals = FakeSuperClass.fake_super_function.__globals__.copy()
322329
new_globals["super"] = lambda x, y: Toy()
323330

324-
FakeSuperClass.fake_super = types.FunctionType(
325-
FakeSuperClass.fake_super.__code__,
331+
FakeSuperClass.fake_super_function = types.FunctionType(
332+
FakeSuperClass.fake_super_function.__code__,
326333
new_globals,
327-
name=FakeSuperClass.fake_super.__name__,
328-
argdefs=FakeSuperClass.fake_super.__defaults__,
329-
closure=FakeSuperClass.fake_super.__closure__,
334+
name=FakeSuperClass.fake_super_function.__name__,
335+
argdefs=FakeSuperClass.fake_super_function.__defaults__,
336+
closure=FakeSuperClass.fake_super_function.__closure__,
330337
)
331338

332339

333340
class TestCustomSuper(TestCaseBase):
334341
def test_fake_super(self):
335-
self.assert_results(FakeSuperClass().fake_super, paddle.to_tensor(3.0))
342+
self.assert_results(
343+
FakeSuperClass().fake_super_function, paddle.to_tensor(3.0)
344+
)
336345

337346
def test_super_function_as_input(self):
338347
self.assert_exceptions(

0 commit comments

Comments
 (0)