Skip to content

Commit 9278ad5

Browse files
committed
Update test case info
1 parent 7ac4339 commit 9278ad5

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

test/sot/test_sot_distribution.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,13 @@
2424
from paddle.jit.sot.psdb import check_no_breakgraph
2525

2626

27-
def apply_fn(fn, x, y):
28-
return fn(x, y)
29-
30-
3127
@check_no_breakgraph
32-
def fn1(x, y):
28+
def fn(x, y):
3329
return x + y
3430

3531

36-
class TestApplyDifferentFunctions(TestCaseBase):
37-
def test_apply_fn(self):
32+
class TestGuardForDistInfo(TestCaseBase):
33+
def test_fn(self):
3834
x = paddle.ones([2, 2])
3935
y = paddle.zeros([2, 2])
4036
mesh1 = dist.ProcessMesh([0, 1], dim_names=['x'])
@@ -48,11 +44,11 @@ def test_apply_fn(self):
4844
dist_y3 = dist.shard_tensor(y, mesh3, [dist.Replicate()])
4945
with test_instruction_translator_cache_context() as ctx:
5046
self.assertEqual(ctx.translate_count, 0)
51-
self.assert_results(apply_fn, fn1, dist_x1, dist_y1)
47+
self.assert_results(fn, dist_x1, dist_y1)
5248
self.assertEqual(ctx.translate_count, 1)
53-
self.assert_results(apply_fn, fn1, dist_x2, dist_y2)
49+
self.assert_results(fn, dist_x2, dist_y2)
5450
self.assertEqual(ctx.translate_count, 1)
55-
self.assert_results(apply_fn, fn1, dist_x3, dist_y3)
51+
self.assert_results(fn, dist_x3, dist_y3)
5652
self.assertEqual(ctx.translate_count, 2)
5753

5854

0 commit comments

Comments
 (0)