20
20
)
21
21
22
22
import paddle
23
- from paddle .jit .sot import symbolic_translate
24
23
from paddle .jit .sot .psdb import check_no_breakgraph
25
24
26
25
@@ -78,27 +77,35 @@ def test_super_self_name(self):
78
77
def test_guard_run (self ): # test guard
79
78
with test_instruction_translator_cache_context () as ctx :
80
79
self .assertEqual (ctx .translate_count , 0 )
81
- symbolic_translate (B ().test_super_no_args_add2 )(paddle .to_tensor (1 ))
82
- symbolic_translate (B ().test_super_no_args_add2 )(paddle .to_tensor (2 ))
80
+ self .assert_results (
81
+ B ().test_super_no_args_add2 , paddle .to_tensor (1 )
82
+ )
83
+ self .assert_results (
84
+ B ().test_super_no_args_add2 , paddle .to_tensor (2 )
85
+ )
83
86
self .assertEqual (ctx .translate_count , 1 )
84
87
85
88
with test_instruction_translator_cache_context () as ctx :
86
89
self .assertEqual (ctx .translate_count , 0 )
87
- symbolic_translate (B ().test_super_no_args_add2 )(paddle .to_tensor (1 ))
88
- symbolic_translate (B ().test_super_no_args_add2 )(paddle .to_tensor (2 ))
90
+ self .assert_results (
91
+ B ().test_super_no_args_add2 , paddle .to_tensor (1 )
92
+ )
93
+ self .assert_results (
94
+ B ().test_super_no_args_add2 , paddle .to_tensor (2 )
95
+ )
89
96
self .assertEqual (ctx .translate_count , 1 )
90
- symbolic_translate ( B (). test_super_with_args_add3 ) (
91
- paddle .to_tensor (3 )
97
+ self . assert_results (
98
+ B (). test_super_with_args_add3 , paddle .to_tensor (3 )
92
99
)
93
- symbolic_translate ( B (). test_super_with_args_add3 ) (
94
- paddle .to_tensor (4 )
100
+ self . assert_results (
101
+ B (). test_super_with_args_add3 , paddle .to_tensor (4 )
95
102
)
96
103
self .assertEqual (ctx .translate_count , 2 )
97
104
98
105
with test_instruction_translator_cache_context () as ctx :
99
106
self .assertEqual (ctx .translate_count , 0 )
100
- symbolic_translate (B ().test_super_both_add5 )( paddle .to_tensor (5 ))
101
- symbolic_translate (B ().test_super_both_add5 )( paddle .to_tensor (6 ))
107
+ self . assert_results (B ().test_super_both_add5 , paddle .to_tensor (5 ))
108
+ self . assert_results (B ().test_super_both_add5 , paddle .to_tensor (6 ))
102
109
self .assertEqual (ctx .translate_count , 1 )
103
110
104
111
@@ -159,27 +166,27 @@ def test_with_args(self):
159
166
def test_guard_run (self ): # test guard
160
167
with test_instruction_translator_cache_context () as ctx :
161
168
self .assertEqual (ctx .translate_count , 0 )
162
- symbolic_translate (Q ().addx )( paddle .to_tensor (1 ))
163
- symbolic_translate (Q ().addx )( paddle .to_tensor (2 ))
164
- symbolic_translate (Q ().addx )( paddle .to_tensor (3 ))
169
+ self . assert_results (Q ().addx , paddle .to_tensor (1 ))
170
+ self . assert_results (Q ().addx , paddle .to_tensor (2 ))
171
+ self . assert_results (Q ().addx , paddle .to_tensor (3 ))
165
172
self .assertEqual (ctx .translate_count , 1 )
166
173
167
174
with test_instruction_translator_cache_context () as ctx :
168
175
self .assertEqual (ctx .translate_count , 0 )
169
- symbolic_translate (Q ().addxP )( paddle .to_tensor (4 ))
170
- symbolic_translate (Q ().addxP )( paddle .to_tensor (5 ))
171
- symbolic_translate (Q ().addxP )( paddle .to_tensor (6 ))
176
+ self . assert_results (Q ().addxP , paddle .to_tensor (4 ))
177
+ self . assert_results (Q ().addxP , paddle .to_tensor (5 ))
178
+ self . assert_results (Q ().addxP , paddle .to_tensor (6 ))
172
179
self .assertEqual (ctx .translate_count , 1 )
173
180
174
181
with test_instruction_translator_cache_context () as ctx :
175
182
self .assertEqual (ctx .translate_count , 0 )
176
- symbolic_translate (Q ().addxZ )( paddle .to_tensor (7 ))
177
- symbolic_translate (Q ().addxZ )( paddle .to_tensor (8 ))
178
- symbolic_translate (Q ().addxZ )( paddle .to_tensor (9 ))
183
+ self . assert_results (Q ().addxZ , paddle .to_tensor (7 ))
184
+ self . assert_results (Q ().addxZ , paddle .to_tensor (8 ))
185
+ self . assert_results (Q ().addxZ , paddle .to_tensor (9 ))
179
186
self .assertEqual (ctx .translate_count , 1 )
180
- symbolic_translate (Q ().addxP )( paddle .to_tensor (4 ))
181
- symbolic_translate (Q ().addxP )( paddle .to_tensor (5 ))
182
- symbolic_translate (Q ().addxP )( paddle .to_tensor (6 ))
187
+ self . assert_results (Q ().addxP , paddle .to_tensor (4 ))
188
+ self . assert_results (Q ().addxP , paddle .to_tensor (5 ))
189
+ self . assert_results (Q ().addxP , paddle .to_tensor (6 ))
183
190
self .assertEqual (ctx .translate_count , 2 )
184
191
185
192
@@ -202,9 +209,9 @@ def test_super_as_input(self):
202
209
def test_guard_run (self ): # test guard
203
210
with test_instruction_translator_cache_context () as ctx :
204
211
self .assertEqual (ctx .translate_count , 0 )
205
- symbolic_translate (super_as_input )( super ())
206
- symbolic_translate (super_as_input )( super ())
207
- symbolic_translate (super_as_input )( super ())
212
+ self . assert_results (super_as_input , super ())
213
+ self . assert_results (super_as_input , super ())
214
+ self . assert_results (super_as_input , super ())
208
215
self .assertEqual (ctx .translate_count , 1 )
209
216
210
217
@@ -244,19 +251,19 @@ def test_guard_run(self): # test guard
244
251
with test_instruction_translator_cache_context () as ctx :
245
252
self .assertEqual (ctx .translate_count , 0 )
246
253
x = paddle .to_tensor (3 )
247
- symbolic_translate ( ClassSuperAsInputC (). test_super_as_input ) (
248
- x , ClassSuperAsInputC
254
+ self . assert_results (
255
+ ClassSuperAsInputC (). test_super_as_input , x , ClassSuperAsInputC
249
256
)
250
- symbolic_translate ( ClassSuperAsInputC (). test_super_as_input ) (
251
- x , ClassSuperAsInputC
257
+ self . assert_results (
258
+ ClassSuperAsInputC (). test_super_as_input , x , ClassSuperAsInputC
252
259
)
253
260
self .assertEqual (ctx .translate_count , 1 )
254
261
x = paddle .to_tensor (4 )
255
- symbolic_translate ( ClassSuperAsInputC (). test_super_as_input ) (
256
- x , ClassSuperAsInputB
262
+ self . assert_results (
263
+ ClassSuperAsInputC (). test_super_as_input , x , ClassSuperAsInputB
257
264
)
258
- symbolic_translate ( ClassSuperAsInputC (). test_super_as_input ) (
259
- x , ClassSuperAsInputB
265
+ self . assert_results (
266
+ ClassSuperAsInputC (). test_super_as_input , x , ClassSuperAsInputB
260
267
)
261
268
self .assertEqual (ctx .translate_count , 2 )
262
269
@@ -289,9 +296,9 @@ def test_guard_run(self): # test guard
289
296
x = paddle .to_tensor ([4.0 ])
290
297
with test_instruction_translator_cache_context () as ctx :
291
298
self .assertEqual (ctx .translate_count , 0 )
292
- symbolic_translate (ClassWithAttributionC ().foo )( x )
293
- symbolic_translate (ClassWithAttributionC ().foo )( x )
294
- symbolic_translate (ClassWithAttributionC ().foo )( x )
299
+ self . assert_results (ClassWithAttributionC ().foo , x )
300
+ self . assert_results (ClassWithAttributionC ().foo , x )
301
+ self . assert_results (ClassWithAttributionC ().foo , x )
295
302
self .assertEqual (ctx .translate_count , 1 )
296
303
297
304
@@ -318,21 +325,23 @@ def super_function_as_input(self, fn, x):
318
325
319
326
320
327
# We create a fake `super` and inject it to `__globals__` of the function
321
- new_globals = FakeSuperClass .fake_super .__globals__ .copy ()
328
+ new_globals = FakeSuperClass .fake_super_function .__globals__ .copy ()
322
329
new_globals ["super" ] = lambda x , y : Toy ()
323
330
324
- FakeSuperClass .fake_super = types .FunctionType (
325
- FakeSuperClass .fake_super .__code__ ,
331
+ FakeSuperClass .fake_super_function = types .FunctionType (
332
+ FakeSuperClass .fake_super_function .__code__ ,
326
333
new_globals ,
327
- name = FakeSuperClass .fake_super .__name__ ,
328
- argdefs = FakeSuperClass .fake_super .__defaults__ ,
329
- closure = FakeSuperClass .fake_super .__closure__ ,
334
+ name = FakeSuperClass .fake_super_function .__name__ ,
335
+ argdefs = FakeSuperClass .fake_super_function .__defaults__ ,
336
+ closure = FakeSuperClass .fake_super_function .__closure__ ,
330
337
)
331
338
332
339
333
340
class TestCustomSuper (TestCaseBase ):
334
341
def test_fake_super (self ):
335
- self .assert_results (FakeSuperClass ().fake_super , paddle .to_tensor (3.0 ))
342
+ self .assert_results (
343
+ FakeSuperClass ().fake_super_function , paddle .to_tensor (3.0 )
344
+ )
336
345
337
346
def test_super_function_as_input (self ):
338
347
self .assert_exceptions (
0 commit comments