Skip to content

fix #72107

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open

fix #72107

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -88,150 +88,172 @@ bool SameNdMeshReshardFunction::IsSuitable(
}

void SameNdMeshReshardFunction::Eval(phi::DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) {
const DistTensor& src_tensor,
const TensorDistAttr& dst_dist_attr,
DistTensor* dst_tensor) {
VLOG(3) << "Call " << Name();
const auto& in_dist_attr = in.dist_attr();
const auto& process_mesh = out_dist_attr.process_mesh();
const auto& src_dist_attr = src_tensor.dist_attr();
int64_t first_diff_axis =
FindFirstDiffShardAxis(src_dist_attr, dst_dist_attr);

int64_t first_diff_axis = FindFirstDiffShardAxis(in_dist_attr, out_dist_attr);
VLOG(3) << "\nsrc_dist_attr: [" << src_dist_attr.to_string() << "] "
<< "\ndst_dist_attr: [" << dst_dist_attr.to_string() << "] "
<< "\nfirst_diff_axis: " << first_diff_axis;

// Backup out_dist_attr to to avoid overwriting the out's dist attr
auto out_dist_attr_orig = out_dist_attr;
auto dst_dist_attr_orig = dst_dist_attr;

SetValue(out, in.value());
SetDistProps(out, in.dims(), in_dist_attr);
SetValue(dst_tensor, src_tensor.value());
SetDistProps(dst_tensor, src_tensor.dims(), src_dist_attr);
const auto& process_mesh = dst_dist_attr.process_mesh();

// 1. change all the partial status to replicated status if needed
if (in_dist_attr.is_partial()) {
// Copy in_dist_attr.partial_status to avoid overwriting the value of
// input when the output and input are the same value
const auto in_partial_status = in_dist_attr.partial_status();
const auto& out_partial_status = out_dist_attr_orig.partial_status();
for (const auto& kv : in_partial_status) {
if (out_partial_status.count(kv.first) != 0 ||
out_dist_attr_orig.is_shard(kv.first)) {
continue;
}
VLOG(3) << "Step1: partial axis " << kv.first;
// 1. change shard status to replicated status
for (int64_t i = first_diff_axis; i >= 0; --i) {
int64_t in_mesh_axis = src_dist_attr.dims_mapping()[i];
int64_t out_mesh_axis = dst_dist_attr_orig.dims_mapping()[i];
if (in_mesh_axis == -1 || in_mesh_axis == out_mesh_axis) {
continue;
} else {
VLOG(3) << "Step1: shard to replicated on axis " << i;
// 1.1 Calculate the dist_attr after this transform
TensorDistAttr real_out_dist_attr(out->dist_attr());
real_out_dist_attr.clean_partial_dims({kv.first});
TensorDistAttr real_out_dist_attr(dst_tensor->dist_attr());
std::vector<int64_t> real_dims_mapping =
real_out_dist_attr.dims_mapping();
real_dims_mapping[i] = -1;
real_out_dist_attr.set_dims_mapping(real_dims_mapping);

// 1.2 Calculate the process_mesh on specific axis
ProcessMesh sub_mesh = GetSubProcessMesh(process_mesh, kv.first);
ProcessMesh sub_mesh = GetSubProcessMesh(process_mesh, in_mesh_axis);

// 1.3 Calculate the input one dim dist attr
TensorDistAttr in_one_dim_dist_attr(common::vectorize(in.dims()));
TensorDistAttr in_one_dim_dist_attr(common::vectorize(src_tensor.dims()));
in_one_dim_dist_attr.set_process_mesh(sub_mesh);
in_one_dim_dist_attr.set_partial_status(std::vector<int64_t>{0},
kv.second);
std::vector<int64_t> in_one_dims_mapping =
in_one_dim_dist_attr.dims_mapping();
in_one_dims_mapping[i] = 0;
in_one_dim_dist_attr.set_dims_mapping(in_one_dims_mapping);

// 1.4 Calculate the output one dim dist attr
TensorDistAttr out_one_dim_dist_attr(common::vectorize(in.dims()));
TensorDistAttr out_one_dim_dist_attr(
common::vectorize(src_tensor.dims()));
out_one_dim_dist_attr.set_process_mesh(sub_mesh);

// 1.5 Change from partial to replicated
SetDistProps(out, in_one_dim_dist_attr);

// 1.5 Change from shard to replicated
SetDistProps(dst_tensor, in_one_dim_dist_attr);
DistTensor tmp_result;
PToRReshardFunction func;
func.Eval(dev_ctx, *out, out_one_dim_dist_attr, &tmp_result);
SToRReshardFunction func;
func.Eval(dev_ctx, *dst_tensor, out_one_dim_dist_attr, &tmp_result);

// 1.6 Reset to the right dist attr
SetValue(out, tmp_result.value());
SetDistProps(out, real_out_dist_attr);
SetValue(dst_tensor, tmp_result.value());
SetDistProps(dst_tensor, real_out_dist_attr);
}
}

// 2. change all the shard status to replicated status
for (int64_t i = first_diff_axis; i >= 0; --i) {
int64_t in_mesh_axis = out->dist_attr().dims_mapping()[i];
if (in_mesh_axis != -1) {
VLOG(3) << "Step2: in_mesh axis " << in_mesh_axis;
VLOG(6) << "After step1 shard to replicated: dist_attr of dst_tensor: "
<< dst_tensor->dist_attr().to_string();

// 2. change all the partial status to replicated status if needed
if (src_dist_attr.is_partial()) {
// Copy in_dist_attr.partial_status to avoid overwriting the value of
// input when the output and input are the same value
const auto in_partial_status = src_dist_attr.partial_status();
const auto& out_partial_status = dst_dist_attr_orig.partial_status();
for (const auto& kv : in_partial_status) {
auto partial_dim = kv.first;
auto partial_type = kv.second;
if (out_partial_status.count(partial_dim) != 0 ||
dst_dist_attr_orig.is_shard(partial_dim)) {
continue;
}
VLOG(3) << "Step2: partial to replicated on axis " << partial_dim;
// 2.1 Calculate the dist_attr after this transform
TensorDistAttr real_out_dist_attr(out->dist_attr());
std::vector<int64_t> real_dims_mapping =
real_out_dist_attr.dims_mapping();
real_dims_mapping[i] = -1;
real_out_dist_attr.set_dims_mapping(real_dims_mapping);
TensorDistAttr real_out_dist_attr(dst_tensor->dist_attr());
real_out_dist_attr.clean_partial_dims({partial_dim});

// 2.2 Calculate the process_mesh on specific axis
ProcessMesh sub_mesh = GetSubProcessMesh(process_mesh, in_mesh_axis);
ProcessMesh sub_mesh = GetSubProcessMesh(process_mesh, partial_dim);

// 2.3 Calculate the input one dim dist attr
TensorDistAttr in_one_dim_dist_attr(common::vectorize(in.dims()));
TensorDistAttr in_one_dim_dist_attr(common::vectorize(src_tensor.dims()));
in_one_dim_dist_attr.set_process_mesh(sub_mesh);
std::vector<int64_t> in_one_dims_mapping =
in_one_dim_dist_attr.dims_mapping();
in_one_dims_mapping[i] = 0;
in_one_dim_dist_attr.set_dims_mapping(in_one_dims_mapping);
in_one_dim_dist_attr.set_partial_status(std::vector<int64_t>{0},
partial_type);

// 2.4 Calculate the output one dim dist attr
TensorDistAttr out_one_dim_dist_attr(common::vectorize(in.dims()));
TensorDistAttr out_one_dim_dist_attr(
common::vectorize(src_tensor.dims()));
out_one_dim_dist_attr.set_process_mesh(sub_mesh);

// 2.5 Change from shard to replicated
SetDistProps(out, in_one_dim_dist_attr);
// 2.5 Change from partial to replicated
SetDistProps(dst_tensor, in_one_dim_dist_attr);

DistTensor tmp_result;
SToRReshardFunction func;
func.Eval(dev_ctx, *out, out_one_dim_dist_attr, &tmp_result);
PToRReshardFunction func;
func.Eval(dev_ctx, *dst_tensor, out_one_dim_dist_attr, &tmp_result);

// 2.6 Reset to the right dist attr
SetValue(out, tmp_result.value());
SetDistProps(out, real_out_dist_attr);
SetValue(dst_tensor, tmp_result.value());
SetDistProps(dst_tensor, real_out_dist_attr);
}
}
VLOG(6) << "After step2 partial to replicated: dist_attr of dst_tensor: "
<< dst_tensor->dist_attr().to_string();

// 3. Change replicated to partial
if (out_dist_attr_orig.is_partial()) {
const auto& in_partial_status = out->dist_attr().partial_status();
const auto& out_partial_status = out_dist_attr_orig.partial_status();
if (dst_dist_attr_orig.is_partial()) {
const auto& in_partial_status = dst_tensor->dist_attr().partial_status();
const auto& out_partial_status = dst_dist_attr_orig.partial_status();
for (const auto& kv : out_partial_status) {
if (in_partial_status.count(kv.first) != 0) {
continue;
}
VLOG(3) << "Step3: Partial status mesh axis " << kv.first;
VLOG(3) << "Step3: replicated to partial on axis " << kv.first;
// 3.1 Calculate the dist_attr after this transform
TensorDistAttr real_out_dist_attr(out->dist_attr());
TensorDistAttr real_out_dist_attr(dst_tensor->dist_attr());
real_out_dist_attr.set_partial_status(std::vector<int64_t>{kv.first});

// 3.2 Calculate the process_mesh on specific axis
ProcessMesh sub_mesh = GetSubProcessMesh(process_mesh, kv.first);

// 3.3 Calculate the input one dim dist attr
TensorDistAttr in_one_dim_dist_attr(common::vectorize(in.dims()));
TensorDistAttr in_one_dim_dist_attr(common::vectorize(src_tensor.dims()));
in_one_dim_dist_attr.set_process_mesh(sub_mesh);

// 3.4 Calculate the output one dim dist attr
TensorDistAttr out_one_dim_dist_attr(common::vectorize(in.dims()));
TensorDistAttr out_one_dim_dist_attr(
common::vectorize(src_tensor.dims()));
out_one_dim_dist_attr.set_process_mesh(sub_mesh);
out_one_dim_dist_attr.set_partial_status(std::vector<int64_t>{0});

// 3.5 Change from partial to replicated
DistTensor tmp_result;
SetDistProps(out, in_one_dim_dist_attr);
SetDistProps(dst_tensor, in_one_dim_dist_attr);
RToPReshardFunction func;
func.Eval(dev_ctx, *out, out_one_dim_dist_attr, &tmp_result);
func.Eval(dev_ctx, *dst_tensor, out_one_dim_dist_attr, &tmp_result);

// 3.6 Reset to the right dist attr
SetValue(out, tmp_result.value());
SetDistProps(out, real_out_dist_attr);
SetValue(dst_tensor, tmp_result.value());
SetDistProps(dst_tensor, real_out_dist_attr);
}
}
VLOG(6) << "After step3 replicated to partial: dist_attr of dst_tensor: "
<< dst_tensor->dist_attr().to_string();

// 4. Change replicated/partial to shard
for (int64_t i = first_diff_axis; i >= 0; --i) {
int64_t out_mesh_axis = out_dist_attr_orig.dims_mapping()[i];
if (out_mesh_axis != -1) {
const auto& in_partial_status = out->dist_attr().partial_status();
int64_t in_mesh_axis = src_dist_attr.dims_mapping()[i];
int64_t out_mesh_axis = dst_dist_attr_orig.dims_mapping()[i];
if (out_mesh_axis == -1 || in_mesh_axis == out_mesh_axis) {
continue;
} else {
const auto& in_partial_status = dst_tensor->dist_attr().partial_status();
bool is_partial = in_partial_status.count(out_mesh_axis) != 0;

VLOG(3) << "Step4: out_mesh axis : " << out_mesh_axis
VLOG(3) << "Step4: replicated/partial to shard on axis " << out_mesh_axis
<< "; partial state :" << is_partial;

// 4.1 Calculate the dist_attr after this transform
TensorDistAttr real_out_dist_attr(out->dist_attr());
TensorDistAttr real_out_dist_attr(dst_tensor->dist_attr());
std::vector<int64_t> real_dims_mapping =
real_out_dist_attr.dims_mapping();
real_dims_mapping[i] = out_mesh_axis;
Expand All @@ -244,11 +266,12 @@ void SameNdMeshReshardFunction::Eval(phi::DeviceContext* dev_ctx,
ProcessMesh sub_mesh = GetSubProcessMesh(process_mesh, out_mesh_axis);

// 4.3 Calculate the input one dim dist attr
TensorDistAttr in_one_dim_dist_attr(common::vectorize(in.dims()));
TensorDistAttr in_one_dim_dist_attr(common::vectorize(src_tensor.dims()));
in_one_dim_dist_attr.set_process_mesh(sub_mesh);

// 4.4 Calculate the output one dim dist attr
TensorDistAttr out_one_dim_dist_attr(common::vectorize(in.dims()));
TensorDistAttr out_one_dim_dist_attr(
common::vectorize(src_tensor.dims()));
out_one_dim_dist_attr.set_process_mesh(sub_mesh);
std::vector<int64_t> out_one_dims_mapping =
out_one_dim_dist_attr.dims_mapping();
Expand All @@ -257,19 +280,22 @@ void SameNdMeshReshardFunction::Eval(phi::DeviceContext* dev_ctx,

// 4.5 Change from replicated to shard
DistTensor tmp_result;
SetDistProps(out, in_one_dim_dist_attr);
SetDistProps(dst_tensor, in_one_dim_dist_attr);
if (is_partial) {
PToSReshardFunction func;
func.Eval(dev_ctx, *out, out_one_dim_dist_attr, &tmp_result);
func.Eval(dev_ctx, *dst_tensor, out_one_dim_dist_attr, &tmp_result);
} else {
RToSReshardFunction func;
func.Eval(dev_ctx, *out, out_one_dim_dist_attr, &tmp_result);
func.Eval(dev_ctx, *dst_tensor, out_one_dim_dist_attr, &tmp_result);
}
// 4.6 Reset to the right dist attr
SetValue(out, tmp_result.value());
SetDistProps(out, real_out_dist_attr);
SetValue(dst_tensor, tmp_result.value());
SetDistProps(dst_tensor, real_out_dist_attr);
}
}
VLOG(6)
<< "After step4 replicated/partial to shard: dist_attr of dst_tensor: "
<< dst_tensor->dist_attr().to_string();
}

bool CrossNdMeshReshardFunction::IsSuitable(
Expand Down
Loading