24
24
from paddle .jit .sot .psdb import check_no_breakgraph
25
25
26
26
27
- def apply_fn (fn , x , y ):
28
- return fn (x , y )
29
-
30
-
31
27
@check_no_breakgraph
32
- def fn1 (x , y ):
28
+ def fn (x , y ):
33
29
return x + y
34
30
35
31
36
- class TestApplyDifferentFunctions (TestCaseBase ):
37
- def test_apply_fn (self ):
32
+ class TestGuardForDistInfo (TestCaseBase ):
33
+ def test_fn (self ):
38
34
x = paddle .ones ([2 , 2 ])
39
35
y = paddle .zeros ([2 , 2 ])
40
36
mesh1 = dist .ProcessMesh ([0 , 1 ], dim_names = ['x' ])
@@ -48,11 +44,11 @@ def test_apply_fn(self):
48
44
dist_y3 = dist .shard_tensor (y , mesh3 , [dist .Replicate ()])
49
45
with test_instruction_translator_cache_context () as ctx :
50
46
self .assertEqual (ctx .translate_count , 0 )
51
- self .assert_results (apply_fn , fn1 , dist_x1 , dist_y1 )
47
+ self .assert_results (fn , dist_x1 , dist_y1 )
52
48
self .assertEqual (ctx .translate_count , 1 )
53
- self .assert_results (apply_fn , fn1 , dist_x2 , dist_y2 )
49
+ self .assert_results (fn , dist_x2 , dist_y2 )
54
50
self .assertEqual (ctx .translate_count , 1 )
55
- self .assert_results (apply_fn , fn1 , dist_x3 , dist_y3 )
51
+ self .assert_results (fn , dist_x3 , dist_y3 )
56
52
self .assertEqual (ctx .translate_count , 2 )
57
53
58
54
0 commit comments