diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc
index f3cb62afe32d53..738db7ff24d7fd 100644
--- a/paddle/fluid/pybind/eager_method.cc
+++ b/paddle/fluid/pybind/eager_method.cc
@@ -398,6 +398,31 @@ static PyObject* tensor_method_numpy(TensorObject* self,
                          dense_tensor->place(),
                          dense_tensor->Holder()->ptr(),
                          dense_tensor->Holder()->size());
+  } else if (self->tensor.is_dist_tensor()) {
+#ifdef PADDLE_WITH_DISTRIBUTE
+    VLOG(6) << "Getting DistTensor's numpy value";
+    auto* dist_tensor =
+        static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
+    auto dense_tensor = ReshardXToReplicated(dist_tensor);
+
+    cpu_tensor.set_meta(dense_tensor.meta());
+    auto tmp_allocation_ptr =
+        memory::Alloc(cpu_place, dense_tensor.Holder()->size());
+    cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
+        tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
+    paddle::memory::Copy(place,
+                         cpu_tensor.Holder()->ptr(),
+                         dense_tensor.place(),
+                         dense_tensor.Holder()->ptr(),
+                         dense_tensor.Holder()->size());
+#else
+    PADDLE_THROW(
+        common::errors::Unavailable("The `numpy()` method of (Dist)Tensor "
+                                    "is not supported in the current "
+                                    "PaddlePaddle, please recompile and "
+                                    "install PaddlePaddle with the option "
+                                    "of `WITH_DISTRIBUTE=ON`."));
+#endif
   } else {
     VLOG(6) << "Getting DenseTensor's numpy value";
     auto dense_tensor =
diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py
index 42b1d53e8a1657..09b6a7014464cb 100644
--- a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py
+++ b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py
@@ -26,6 +26,7 @@
 from paddle.framework import core
 
 from ....infer_meta import (
+    DistInfo,
     MetaInfo,
     SymbolicBool,
     SymbolicFloat,
@@ -482,7 +483,7 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:
         # A quick check path for PIR, we don't need dtype conversion for AMP in PIR
         meta = self.origin_meta
         dtype_str, dtype_free_vars = stringify_pyobject(meta.dtype)
-        return [
+        guards = [
             # Check rank
             StringifiedExpression(
                 f"len({{}}.shape) == {len(meta.shape)}",
@@ -511,7 +512,68 @@
                 [frame_value_tracer],
                 union_free_vars(frame_value_tracer.free_vars),
             ),
+            # Check whether this tensor is distributed
+            StringifiedExpression(
+                f"{{}}.is_dist() is {(self.meta.dist_info is not None)!r}",
+                [frame_value_tracer],
+                union_free_vars(frame_value_tracer.free_vars),
+            ),
         ]
+        if self.meta.dist_info is not None:
+            tensor_dist_info = self.meta.dist_info
+            guards.extend(
+                [
+                    # check mesh shape
+                    StringifiedExpression(
+                        f"DistInfo.from_tensor({{}}).mesh.shape == {tensor_dist_info.mesh.shape}",
+                        [frame_value_tracer],
+                        union_free_vars(
+                            frame_value_tracer.free_vars,
+                            {
+                                "paddle": paddle,
+                                "DistInfo": DistInfo,
+                            },
+                        ),
+                    ),
+                    # check mesh process ids
+                    StringifiedExpression(
+                        f"DistInfo.from_tensor({{}}).mesh.process_ids == {tensor_dist_info.mesh.process_ids}",
+                        [frame_value_tracer],
+                        union_free_vars(
+                            frame_value_tracer.free_vars,
+                            {
+                                "paddle": paddle,
+                                "DistInfo": DistInfo,
+                            },
+                        ),
+                    ),
+                    # check dims mapping
+                    StringifiedExpression(
+                        f"DistInfo.from_tensor({{}}).dims_mapping == {tensor_dist_info.dims_mapping}",
+                        [frame_value_tracer],
+                        union_free_vars(
+                            frame_value_tracer.free_vars,
+                            {
+                                "paddle": paddle,
+                                "DistInfo": DistInfo,
+                            },
+                        ),
+                    ),
+                    # check local shape
+                    StringifiedExpression(
+                        f"DistInfo.from_tensor({{}}).local_shape == {tensor_dist_info.local_shape}",
+                        [frame_value_tracer],
+                        union_free_vars(
+                            frame_value_tracer.free_vars,
+                            {
+                                "paddle": paddle,
+                                "DistInfo": DistInfo,
+                            },
+                        ),
+                    ),
+                ]
+            )
+        return guards
 
     def get_iter(self):
         from .iter import SequenceIterVariable
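For context on the new guards in basic.py: the four `DistInfo` checks compare a dist tensor's runtime distribution metadata against the values recorded at trace time. Below is a minimal sketch of what `DistInfo.from_tensor` exposes, assuming two ranks launched via `paddle.distributed.launch`; the tensor setup and the printed values are illustrative, not taken from this patch:

    import paddle
    import paddle.distributed as dist
    from paddle.jit.sot.infer_meta import DistInfo

    # Illustrative setup: a 4x4 tensor sharded along dim 0 on a two-rank 1-D mesh.
    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
    t = dist.shard_tensor(paddle.ones([4, 4]), mesh, [dist.Shard(0)])

    info = DistInfo.from_tensor(t)
    print(info.mesh.shape)        # e.g. [2]     -> checked by the mesh-shape guard
    print(info.mesh.process_ids)  # e.g. [0, 1]  -> checked by the process-ids guard
    print(info.dims_mapping)      # e.g. [0, -1] -> dim 0 is sharded over mesh dim 0
    print(info.local_shape)       # e.g. [2, 4]  -> each rank holds half of dim 0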
diff --git a/test/sot/test_sot_distribution.py b/test/sot/test_sot_distribution.py
new file mode 100644
index 00000000000000..285dbfcb141036
--- /dev/null
+++ b/test/sot/test_sot_distribution.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from test_case_base import (
+    TestCaseBase,
+    test_instruction_translator_cache_context,
+)
+
+import paddle
+import paddle.distributed as dist
+from paddle.jit.sot.psdb import check_no_breakgraph
+
+
+@check_no_breakgraph
+def fn(x, y):
+    return x + y
+
+
+@unittest.skipIf(
+    not paddle.is_compiled_with_distribute(),
+    reason='Not compiled with distribute.',
+)
+class TestGuardForDistInfo(TestCaseBase):
+    def test_fn(self):
+        x = paddle.ones([2, 2])
+        x.stop_gradient = False
+        y = paddle.zeros([2, 2])
+        y.stop_gradient = False
+        mesh1 = dist.ProcessMesh([0, 1], dim_names=['x'])
+        mesh2 = dist.ProcessMesh([0, 1], dim_names=['y'])
+        mesh3 = dist.ProcessMesh([0, 2], dim_names=['x'])
+        dist_x1 = dist.shard_tensor(
+            x, mesh1, [dist.Replicate()], stop_gradient=False
+        )
+        dist_y1 = dist.shard_tensor(
+            y, mesh1, [dist.Replicate()], stop_gradient=False
+        )
+        dist_x2 = dist.shard_tensor(x, mesh2, [dist.Replicate()])
+        dist_y2 = dist.shard_tensor(y, mesh2, [dist.Replicate()])
+        dist_x3 = dist.shard_tensor(x, mesh3, [dist.Replicate()])
+        dist_y3 = dist.shard_tensor(y, mesh3, [dist.Replicate()])
+        with test_instruction_translator_cache_context() as ctx:
+            self.assertEqual(ctx.translate_count, 0)
+            self.assert_results(fn, dist_x1, dist_y1)
+            self.assertEqual(ctx.translate_count, 1)
+            self.assert_results(fn, dist_x2, dist_y2)
+            self.assertEqual(ctx.translate_count, 1)
+            self.assert_results(fn, dist_x3, dist_y3)
+            self.assertEqual(ctx.translate_count, 2)
+
+
+if __name__ == "__main__":
+    unittest.main()
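A note on the expected cache counts in the test: `mesh2` differs from `mesh1` only in `dim_names`, which none of the new guards inspect, so the second `assert_results` call is a cache hit and `translate_count` stays at 1; `mesh3` changes the process ids to `[0, 2]`, which fails the `mesh.process_ids` guard and forces a second translation. A quick single-process sketch of why the mesh1/mesh2 guards agree (`shape` and `process_ids` are public `ProcessMesh` attributes):

    import paddle.distributed as dist

    mesh1 = dist.ProcessMesh([0, 1], dim_names=['x'])
    mesh2 = dist.ProcessMesh([0, 1], dim_names=['y'])
    mesh3 = dist.ProcessMesh([0, 2], dim_names=['x'])

    # The guards compare only mesh.shape and mesh.process_ids, never dim_names:
    print(mesh1.shape == mesh2.shape)              # True
    print(mesh1.process_ids == mesh2.process_ids)  # True  -> cache hit
    print(mesh1.process_ids == mesh3.process_ids)  # False -> re-translation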