@@ -29,15 +29,25 @@ def fn(x, y):
     return x + y


+@unittest.skipIf(
+    not paddle.is_compiled_with_distribute(),
+    reason='Not compiled with distribute.',
+)
 class TestGuardForDistInfo(TestCaseBase):
     def test_fn(self):
         x = paddle.ones([2, 2])
+        x.stop_gradient = False
         y = paddle.zeros([2, 2])
+        y.stop_gradient = False
         mesh1 = dist.ProcessMesh([0, 1], dim_names=['x'])
         mesh2 = dist.ProcessMesh([0, 1], dim_names=['y'])
         mesh3 = dist.ProcessMesh([0, 2], dim_names=['x'])
-        dist_x1 = dist.shard_tensor(x, mesh1, [dist.Replicate()])
-        dist_y1 = dist.shard_tensor(y, mesh1, [dist.Replicate()])
+        dist_x1 = dist.shard_tensor(
+            x, mesh1, [dist.Replicate()], stop_gradient=False
+        )
+        dist_y1 = dist.shard_tensor(
+            y, mesh1, [dist.Replicate()], stop_gradient=False
+        )
         dist_x2 = dist.shard_tensor(x, mesh2, [dist.Replicate()])
         dist_y2 = dist.shard_tensor(y, mesh2, [dist.Replicate()])
         dist_x3 = dist.shard_tensor(x, mesh3, [dist.Replicate()])
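For context, a minimal self-contained sketch of the pattern this diff applies: the test class is gated on paddle.is_compiled_with_distribute(), and both the dense input and its sharded view are marked differentiable via stop_gradient=False. The TestReplicatedInput class and its assertion below are illustrative, not part of the PR; only shard_tensor, ProcessMesh, Replicate, and the stop_gradient keyword are taken from the diff itself, and the sketch assumes a 2-rank launch (e.g. via python -m paddle.distributed.launch).

    # Illustrative sketch, not the PR's test file.
    import unittest

    import paddle
    import paddle.distributed as dist


    @unittest.skipIf(
        not paddle.is_compiled_with_distribute(),
        reason='Not compiled with distribute.',
    )
    class TestReplicatedInput(unittest.TestCase):
        def test_grad_flows(self):
            x = paddle.ones([2, 2])
            x.stop_gradient = False  # keep the dense tensor in the autograd graph
            mesh = dist.ProcessMesh([0, 1], dim_names=['x'])
            # stop_gradient=False keeps the sharded view differentiable as well
            dist_x = dist.shard_tensor(
                x, mesh, [dist.Replicate()], stop_gradient=False
            )
            self.assertFalse(dist_x.stop_gradient)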