
Commit 2b7bf33

[SOT][Dist] Add guard for dist tensor (#71666)
1 parent 05f0731 commit 2b7bf33

3 files changed: +154 −1 lines changed

paddle/fluid/pybind/eager_method.cc

Lines changed: 25 additions & 0 deletions

@@ -398,6 +398,31 @@ static PyObject* tensor_method_numpy(TensorObject* self,
                          dense_tensor->place(),
                          dense_tensor->Holder()->ptr(),
                          dense_tensor->Holder()->size());
+  } else if (self->tensor.is_dist_tensor()) {
+#ifdef PADDLE_WITH_DISTRIBUTE
+    VLOG(6) << "Getting DistTensor's numpy value";
+    auto* dist_tensor =
+        static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
+    auto dense_tensor = ReshardXToReplicated(dist_tensor);
+
+    cpu_tensor.set_meta(dense_tensor.meta());
+    auto tmp_allocation_ptr =
+        memory::Alloc(cpu_place, dense_tensor.Holder()->size());
+    cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
+        tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
+    paddle::memory::Copy(place,
+                         cpu_tensor.Holder()->ptr(),
+                         dense_tensor.place(),
+                         dense_tensor.Holder()->ptr(),
+                         dense_tensor.Holder()->size());
+#else
+    PADDLE_THROW(
+        common::errors::Unavailable("The `numpy()` method of (Dist)Tensor "
+                                    "is not supported in the current "
+                                    "PaddlePaddle, please recompile and "
+                                    "install PaddlePaddle with the option "
+                                    "of `WITH_DISTRIBUTE=ON`."));
+#endif
   } else {
     VLOG(6) << "Getting DenseTensor's numpy value";
     auto dense_tensor =
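
For context, a minimal sketch (not part of this diff) of what the new DistTensor branch enables from Python. It assumes a build with `WITH_DISTRIBUTE=ON` and at least two visible devices; the values shown are illustrative.

    # Sketch: calling .numpy() on a distributed tensor, which this commit routes
    # through ReshardXToReplicated() and a device-to-host copy.
    import paddle
    import paddle.distributed as dist

    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
    x = dist.shard_tensor(paddle.ones([2, 2]), mesh, [dist.Replicate()])

    # Without the new branch, a DistTensor would reach the generic DenseTensor
    # path here; with it, the tensor is first resharded to a replicated
    # DenseTensor, then copied into a CPU buffer before conversion.
    print(x.numpy())  # a 2x2 array of ones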

python/paddle/jit/sot/opcode_translator/executor/variables/basic.py

Lines changed: 63 additions & 1 deletion

@@ -26,6 +26,7 @@
 from paddle.framework import core

 from ....infer_meta import (
+    DistInfo,
     MetaInfo,
     SymbolicBool,
     SymbolicFloat,
@@ -482,7 +483,7 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:
         # A quick check path for PIR, we don't need dtype conversion for AMP in PIR
         meta = self.origin_meta
         dtype_str, dtype_free_vars = stringify_pyobject(meta.dtype)
-        return [
+        guards = [
             # Check rank
             StringifiedExpression(
                 f"len({{}}.shape) == {len(meta.shape)}",
@@ -511,7 +512,68 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:
                 [frame_value_tracer],
                 union_free_vars(frame_value_tracer.free_vars),
             ),
+            # Check whether this tensor is distributed
+            StringifiedExpression(
+                f"{{}}.is_dist() is {(self.meta.dist_info is not None)!r}",
+                [frame_value_tracer],
+                union_free_vars(frame_value_tracer.free_vars),
+            ),
         ]
+        if self.meta.dist_info is not None:
+            tensor_dist_info = self.meta.dist_info
+            guards.extend(
+                [
+                    # check mesh shape
+                    StringifiedExpression(
+                        f"DistInfo.from_tensor({{}}).mesh.shape == {tensor_dist_info.mesh.shape}",
+                        [frame_value_tracer],
+                        union_free_vars(
+                            frame_value_tracer.free_vars,
+                            {
+                                "paddle": paddle,
+                                "DistInfo": DistInfo,
+                            },
+                        ),
+                    ),
+                    # check mesh process ids
+                    StringifiedExpression(
+                        f"DistInfo.from_tensor({{}}).mesh.process_ids == {tensor_dist_info.mesh.process_ids}",
+                        [frame_value_tracer],
+                        union_free_vars(
+                            frame_value_tracer.free_vars,
+                            {
+                                "paddle": paddle,
+                                "DistInfo": DistInfo,
+                            },
+                        ),
+                    ),
+                    # check dims mapping
+                    StringifiedExpression(
+                        f"DistInfo.from_tensor({{}}).dims_mapping == {tensor_dist_info.dims_mapping}",
+                        [frame_value_tracer],
+                        union_free_vars(
+                            frame_value_tracer.free_vars,
+                            {
+                                "paddle": paddle,
+                                "DistInfo": DistInfo,
+                            },
+                        ),
+                    ),
+                    # check local shape
+                    StringifiedExpression(
+                        f"DistInfo.from_tensor({{}}).local_shape == {tensor_dist_info.local_shape}",
+                        [frame_value_tracer],
+                        union_free_vars(
+                            frame_value_tracer.free_vars,
+                            {
+                                "paddle": paddle,
+                                "DistInfo": DistInfo,
+                            },
+                        ),
+                    ),
+                ]
+            )
+        return guards

     def get_iter(self):
         from .iter import SequenceIterVariable
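
To make the new guards concrete, here is a rough illustration (not taken from the commit) of the stringified expressions this method would now emit for a 2x2 tensor replicated on a two-device mesh. The `{}` placeholders are filled with the frame value tracer's expression at guard-assembly time, and the exact shapes and mappings depend on the input.

    # Hypothetical guard expressions; `x` stands for the traced frame value.
    guard_exprs = [
        "x.is_dist() is True",                                 # new dist check
        "DistInfo.from_tensor(x).mesh.shape == [2]",           # mesh shape
        "DistInfo.from_tensor(x).mesh.process_ids == [0, 1]",  # mesh process ids
        "DistInfo.from_tensor(x).dims_mapping == [-1, -1]",    # dims mapping
        "DistInfo.from_tensor(x).local_shape == [2, 2]",       # local shape
    ]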

test/sot/test_sot_distribution.py (new file)

Lines changed: 66 additions & 0 deletions

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from test_case_base import (
    TestCaseBase,
    test_instruction_translator_cache_context,
)

import paddle
import paddle.distributed as dist
from paddle.jit.sot.psdb import check_no_breakgraph


@check_no_breakgraph
def fn(x, y):
    return x + y


@unittest.skipIf(
    not paddle.is_compiled_with_distribute(),
    reason='Not compiled with distribute.',
)
class TestGuardForDistInfo(TestCaseBase):
    def test_fn(self):
        x = paddle.ones([2, 2])
        x.stop_gradient = False
        y = paddle.zeros([2, 2])
        y.stop_gradient = False
        mesh1 = dist.ProcessMesh([0, 1], dim_names=['x'])
        mesh2 = dist.ProcessMesh([0, 1], dim_names=['y'])
        mesh3 = dist.ProcessMesh([0, 2], dim_names=['x'])
        dist_x1 = dist.shard_tensor(
            x, mesh1, [dist.Replicate()], stop_gradient=False
        )
        dist_y1 = dist.shard_tensor(
            y, mesh1, [dist.Replicate()], stop_gradient=False
        )
        dist_x2 = dist.shard_tensor(x, mesh2, [dist.Replicate()])
        dist_y2 = dist.shard_tensor(y, mesh2, [dist.Replicate()])
        dist_x3 = dist.shard_tensor(x, mesh3, [dist.Replicate()])
        dist_y3 = dist.shard_tensor(y, mesh3, [dist.Replicate()])
        with test_instruction_translator_cache_context() as ctx:
            self.assertEqual(ctx.translate_count, 0)
            self.assert_results(fn, dist_x1, dist_y1)
            self.assertEqual(ctx.translate_count, 1)
            self.assert_results(fn, dist_x2, dist_y2)
            self.assertEqual(ctx.translate_count, 1)
            self.assert_results(fn, dist_x3, dist_y3)
            self.assertEqual(ctx.translate_count, 2)


if __name__ == "__main__":
    unittest.main()
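
The asserted translate counts follow from the new guards: mesh1 and mesh2 differ only in dim_names, which none of the guards inspect, so the second call reuses the cached translation; mesh3 changes the process ids, so the mesh guard fails and the function is retranslated. A hedged sketch of that comparison, assuming DistInfo is importable as paddle.jit.sot.infer_meta.DistInfo (inferred from the relative import in basic.py, not verified):

    import paddle
    import paddle.distributed as dist
    from paddle.jit.sot.infer_meta import DistInfo  # import path is an assumption

    mesh1 = dist.ProcessMesh([0, 1], dim_names=['x'])
    mesh2 = dist.ProcessMesh([0, 1], dim_names=['y'])  # only dim_names differ

    a = dist.shard_tensor(paddle.ones([2, 2]), mesh1, [dist.Replicate()])
    b = dist.shard_tensor(paddle.ones([2, 2]), mesh2, [dist.Replicate()])

    info_a, info_b = DistInfo.from_tensor(a), DistInfo.from_tensor(b)
    # All four guarded attributes match, so the SOT cache entry is reusable:
    assert info_a.mesh.shape == info_b.mesh.shape                # [2]
    assert info_a.mesh.process_ids == info_b.mesh.process_ids    # [0, 1]
    assert info_a.dims_mapping == info_b.dims_mapping
    assert info_a.local_shape == info_b.local_shape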
