Skip to content

Commit c254fa8

Browse files
committed
Set stop_gradient=False and Skip the test when not is_compiled_with_distribute
1 parent 9278ad5 commit c254fa8

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

test/sot/test_sot_distribution.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,25 @@ def fn(x, y):
2929
return x + y
3030

3131

32+
@unittest.skipIf(
33+
not paddle.is_compiled_with_distribute(),
34+
reason='Not compiled with distribute.',
35+
)
3236
class TestGuardForDistInfo(TestCaseBase):
3337
def test_fn(self):
3438
x = paddle.ones([2, 2])
39+
x.stop_gradient = False
3540
y = paddle.zeros([2, 2])
41+
y.stop_gradient = False
3642
mesh1 = dist.ProcessMesh([0, 1], dim_names=['x'])
3743
mesh2 = dist.ProcessMesh([0, 1], dim_names=['y'])
3844
mesh3 = dist.ProcessMesh([0, 2], dim_names=['x'])
39-
dist_x1 = dist.shard_tensor(x, mesh1, [dist.Replicate()])
40-
dist_y1 = dist.shard_tensor(y, mesh1, [dist.Replicate()])
45+
dist_x1 = dist.shard_tensor(
46+
x, mesh1, [dist.Replicate()], stop_gradient=False
47+
)
48+
dist_y1 = dist.shard_tensor(
49+
y, mesh1, [dist.Replicate()], stop_gradient=False
50+
)
4151
dist_x2 = dist.shard_tensor(x, mesh2, [dist.Replicate()])
4252
dist_y2 = dist.shard_tensor(y, mesh2, [dist.Replicate()])
4353
dist_x3 = dist.shard_tensor(x, mesh3, [dist.Replicate()])

0 commit comments

Comments
 (0)