Skip to content

[spmd rule] reshard partial input to replicate #60215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions paddle/phi/infermeta/spmd_rules/elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ SpmdInfo ElementwiseUnaryInferSpmdReverse(const DistMetaTensor& x,
int x_ndim = x_shape.size();
auto out_shape = common::vectorize(out.dims());
int out_ndim = out_shape.size();
TensorDistAttr out_dist_attr = out.dist_attr();
std::vector<int64_t> out_dims_mapping = out_dist_attr.dims_mapping();
TensorDistAttr out_dist_attr_src = out.dist_attr();
std::vector<int64_t> out_dims_mapping = out_dist_attr_src.dims_mapping();
PADDLE_ENFORCE_EQ(
out_ndim,
out_dims_mapping.size(),
Expand Down Expand Up @@ -165,8 +165,10 @@ SpmdInfo ElementwiseUnaryInferSpmdReverse(const DistMetaTensor& x,
// step2.2: Infer input dims mapping from merged input dims mapping
std::vector<int64_t> x_dims_mapping =
GetDimsMappingForAxes(x_axes, axis_to_dim_map);
TensorDistAttr x_dist_attr(x.dist_attr());
auto x_dist_attr = CopyTensorDistAttrForOutput(out_dist_attr_src);
x_dist_attr.set_dims_mapping(x_dims_mapping);
auto out_dist_attr_dst = CopyTensorDistAttrForOutput(out_dist_attr_src);
out_dist_attr_dst.set_dims_mapping(x_dims_mapping);

// Step3: Handle partial
// Handle output tensor partial (TODO)
Expand All @@ -175,7 +177,7 @@ SpmdInfo ElementwiseUnaryInferSpmdReverse(const DistMetaTensor& x,
<< "dims_mapping: [" << str_join(out_dims_mapping) << "] ";
VLOG(4) << "Input0 dims_mapping: [" + str_join(x_dims_mapping) + "]\n\n";

return {{x_dist_attr}, {out_dist_attr}};
return {{x_dist_attr}, {out_dist_attr_dst}};
}

SpmdInfo ElementwiseBinaryInferSpmd(const DistMetaTensor& x,
Expand Down Expand Up @@ -287,15 +289,18 @@ SpmdInfo ElementwiseBinaryInferSpmdReverse(const DistMetaTensor& x,
ShardingMergeForTensors({{out_axes, out_dims_mapping}});

// Step2.2: Infer input dims mappings from merged output dims mapping
TensorDistAttr x_dist_attr_dst = x.dist_attr();
TensorDistAttr y_dist_attr_dst = y.dist_attr();
TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x.dist_attr());
TensorDistAttr y_dist_attr_dst = CopyTensorDistAttrForOutput(y.dist_attr());
std::vector<int64_t> x_dims_mapping =
GetDimsMappingForAxes(x_axes, axis_to_dim_map);
std::vector<int64_t> y_dims_mapping =
GetDimsMappingForAxes(y_axes, axis_to_dim_map);
x_dist_attr_dst.set_dims_mapping(x_dims_mapping);
y_dist_attr_dst.set_dims_mapping(y_dims_mapping);

auto out_dist_attr_dst = CopyTensorDistAttrForOutput(out_dist_attr);
out_dist_attr_dst.set_dims_mapping(out_dims_mapping);

// Step3: Handle partial
// Handle input tensor partial (TODO)
VLOG(4) << "ElementwiseSPMDRule InferReverse:";
Expand All @@ -306,7 +311,7 @@ SpmdInfo ElementwiseBinaryInferSpmdReverse(const DistMetaTensor& x,
VLOG(4) << "Input1 shape: [" << str_join(y_shape) << "] "
<< "dims_mapping: [" << str_join(y_dims_mapping) << "]\n\n";

return {{x_dist_attr_dst, y_dist_attr_dst}, {out_dist_attr}};
return {{x_dist_attr_dst, y_dist_attr_dst}, {out_dist_attr_dst}};
}

SpmdInfo ElementwiseUnaryGradInferSpmd(const DistMetaTensor& x,
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/infermeta/spmd_rules/embedding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ SpmdInfo EmbeddingGradInferSpmd(const DistMetaTensor& x,
// update input dims mapping with merged shardings.
t0_dist_attr.set_dims_mapping(
GetDimsMappingForAxes(t0_axes, axis_to_dim_map));
auto out_grad_dst_dist_attr = out_grad_dst.dist_attr();
auto out_grad_dst_dist_attr =
CopyTensorDistAttrForOutput(out_grad_dst.dist_attr());
out_grad_dst_dist_attr.set_dims_mapping(
GetDimsMappingForAxes(out_grad_dst_axes, axis_to_dim_map));

Expand Down
9 changes: 5 additions & 4 deletions paddle/phi/infermeta/spmd_rules/flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ SpmdInfo FlattenInferSpmd(const DistMetaTensor& x,

// Step3: Update the dist attributes of input
// and output with the inferred dims mapping.
TensorDistAttr x_dist_attr_dst(x_dist_attr_src);
TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
x_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
TensorDistAttr out_dist_attr(x_dist_attr_src);
TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
out_dist_attr.set_dims_mapping(dims_mapping_vec[1]);

VLOG(4) << "FlattenInferSpmd: X shape: [" << str_join(src_shape) << "]";
Expand Down Expand Up @@ -178,9 +178,10 @@ SpmdInfo FlattenInferSpmdReverse(const DistMetaTensor& x,

// Step3: Update the dist attributes of input
// and output with the inferred dims mapping
TensorDistAttr out_dist_attr_dst(out_dist_attr_src);
TensorDistAttr out_dist_attr_dst =
CopyTensorDistAttrForOutput(out_dist_attr_src);
out_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
TensorDistAttr x_dist_attr(x.dist_attr());
TensorDistAttr x_dist_attr = CopyTensorDistAttrForOutput(x.dist_attr());
x_dist_attr.set_dims_mapping(dims_mapping_vec[1]);

VLOG(4) << "FlattenInferSpmdReverse: Out shape: [" << str_join(out_shape)
Expand Down
7 changes: 5 additions & 2 deletions paddle/phi/infermeta/spmd_rules/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ SpmdInfo ReductionInferSpmdBase(const DistMetaTensor& x,

// initialize output dist_attr's process_mesh, batch_dim and dynamic dims with
// input dist_attr.
auto x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
x_dist_attr_dst.set_dims_mapping(x_dims_mapping);

TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
out_dist_attr.set_dims_mapping(out_dims_mapping);

Expand All @@ -126,7 +129,7 @@ SpmdInfo ReductionInferSpmdBase(const DistMetaTensor& x,
<< "partial_on_dims: [" + str_join(partial_on_dims)
<< " with reduce_type " << reduce_type << "]\n\n";

return {{x_dist_attr_src}, {out_dist_attr}};
return {{x_dist_attr_dst}, {out_dist_attr}};
}

SpmdInfo ReductionInferSpmd(const DistMetaTensor& x,
Expand Down Expand Up @@ -203,7 +206,7 @@ SpmdInfo ReductionInferSpmdReverse(const DistMetaTensor& x,

// initialize input dist_attr's process_mesh, batch_dim and dynamic dims with
// input dist_attr.
TensorDistAttr x_dist_attr_dst(x.dist_attr());
TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x.dist_attr());
x_dist_attr_dst.set_dims_mapping(x_dims_mapping);

// Step3: handle partial (TODO)
Expand Down
14 changes: 9 additions & 5 deletions paddle/phi/infermeta/spmd_rules/slice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ SpmdInfo SliceInferSpmdBase(const DistMetaTensor& input,

// Step2.3: get a new dist attribute for the input. The sliced axes
// cannot be sharded; if any of them is sharded, set it to replicated.
TensorDistAttr input_dist_attr_dst(input_dist_attr_src);
TensorDistAttr input_dist_attr_dst =
CopyTensorDistAttrForOutput(input_dist_attr_src);
for (int i = 0; i < static_cast<int>(axes.size()); i++) {
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
input_dims_mapping[axis] = -1;
Expand Down Expand Up @@ -181,7 +182,9 @@ SpmdInfo SliceInferSpmdReverseBase(const DistMetaTensor& input,
// Step2.2: infer input dims mapping from output dims mapping. The sliced
// axes cannot be sharded; if any of them is sharded, set it to replicated.
input_dims_mapping = GetDimsMappingForAxes(input_axes, axis_to_dim_map, true);
input_dist_attr.set_dims_mapping(input_dims_mapping);

auto input_dist_attr_dst = CopyTensorDistAttrForOutput(input_dist_attr);
input_dist_attr_dst.set_dims_mapping(input_dims_mapping);

// Step2.3: get a new dist attribute for the output. The sliced axes
// cannot be sharded; if any of them is sharded, set it to replicated.
Expand All @@ -190,7 +193,8 @@ SpmdInfo SliceInferSpmdReverseBase(const DistMetaTensor& input,
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
out_dims_mapping[axis] = -1;
}
out_dist_attr.set_dims_mapping(out_dims_mapping);
auto out_dist_attr_dst = CopyTensorDistAttrForOutput(out_dist_attr);
out_dist_attr_dst.set_dims_mapping(out_dims_mapping);

VLOG(4) << "SliceInferSpmdReverse:";
VLOG(4) << "Einsum Notation: " << input_axes << "-->" << out_axes;
Expand All @@ -199,12 +203,12 @@ SpmdInfo SliceInferSpmdReverseBase(const DistMetaTensor& input,
<< "axes: [" << str_join(axes) << "] "
<< "src_dims_mapping: ["
<< str_join(output.dist_attr().dims_mapping()) << "] "
<< "dst_dims_mapping: [" << str_join(out_dist_attr.dims_mapping())
<< "dst_dims_mapping: [" << str_join(out_dist_attr_dst.dims_mapping())
<< "]";
VLOG(4) << "Input shape: [" << str_join(input_shape) << "] "
<< "dims_mapping: [" << str_join(input_dims_mapping) << "]\n\n";

return {{input_dist_attr}, {out_dist_attr}};
return {{input_dist_attr_dst}, {out_dist_attr_dst}};
}

SpmdInfo SliceInferSpmdReverse(const DistMetaTensor& input,
Expand Down
5 changes: 3 additions & 2 deletions paddle/phi/infermeta/spmd_rules/softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,12 @@ SpmdInfo SoftmaxInferSpmdReverse(const DistMetaTensor& x,
// infer input's dims mapping.
std::vector<int64_t> x_dims_mapping =
GetDimsMappingForAxes(x_axes, axis_to_dim_map);
TensorDistAttr x_dist_attr(x.dist_attr());
TensorDistAttr x_dist_attr = CopyTensorDistAttrForOutput(x.dist_attr());
x_dist_attr.set_dims_mapping(x_dims_mapping);

// update output's dims mapping.
TensorDistAttr out_dist_attr_dst(out_dist_attr_src);
TensorDistAttr out_dist_attr_dst =
CopyTensorDistAttrForOutput(out_dist_attr_src);
out_dist_attr_dst.set_dims_mapping(out_dims_mapping);

VLOG(4) << "SoftmaxInferSpmdReverse:\n"
Expand Down
11 changes: 7 additions & 4 deletions paddle/phi/infermeta/spmd_rules/split.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ SpmdInfo SplitWithNumInferSpmd(const DistMetaTensor& x, int num, int axis) {

// Step2.3: get a new dist attribute for the input. The split axis
// cannot be sharded; if it is sharded, set it to replicated.
TensorDistAttr x_dist_attr_dst(x_dist_attr_src);
TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
x_dims_mapping[axis] = -1;
x_dist_attr_dst.set_dims_mapping(x_dims_mapping);

Expand Down Expand Up @@ -168,13 +168,16 @@ SpmdInfo SplitWithNumInferSpmdReverse(
// the split axis in input is set to -1.
x_dims_mapping = GetDimsMappingForAxes(x_axes, axis_to_dim_map, true);
x_dims_mapping[axis] = -1;
x_dist_attr.set_dims_mapping(x_dims_mapping);

auto x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr);
x_dist_attr_dst.set_dims_mapping(x_dims_mapping);

// Step2.3: get new dist attributes for the outputs. The split axis
// cannot be sharded; if it is sharded, set it to replicated.
std::vector<TensorDistAttr> out_dist_attrs;
for (int i = 0; i < nouts; i++) {
out_dist_attrs.emplace_back(outs[i]->dist_attr());
out_dist_attrs.emplace_back(
CopyTensorDistAttrForOutput(outs[i]->dist_attr()));
std::vector<int64_t> out_dims_mapping =
GetDimsMappingForAxes(out_axes, axis_to_dim_map, true);
out_dims_mapping[axis] = -1;
Expand All @@ -197,7 +200,7 @@ SpmdInfo SplitWithNumInferSpmdReverse(
<< "dims_mapping: [" << str_join(x_dims_mapping) << "]\n\n";
// TODO(liuzhenhai): remedy this
// return {{x_dist_attr}, {out_dist_attrs}};
return {{x_dist_attr}, ToArgDistAttr(out_dist_attrs)};
return {{x_dist_attr_dst}, ToArgDistAttr(out_dist_attrs)};
}

SpmdInfo SplitInferSpmd(const DistMetaTensor& x,
Expand Down
9 changes: 5 additions & 4 deletions paddle/phi/infermeta/spmd_rules/squeeze.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,14 @@ SpmdInfo SqueezeInferSpmd(const DistMetaTensor& x,

// Step3: Update the dist attributes of input
// and output with the inferred dims mapping.
TensorDistAttr x_dist_attr_dst(x_dist_attr_src);
TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
x_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
if (x_dist_attr_dst.dynamic_dims().size() !=
x_dist_attr_dst.dims_mapping().size()) {
VLOG(4) << "SqueezeInferSPMD change x dist attr dynamic dims";
x_dist_attr_dst.set_default_dynamic_dims(x_dist_attr_dst.dims_mapping());
}
TensorDistAttr out_dist_attr(x_dist_attr_src);
TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
out_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
if (out_dist_attr.dynamic_dims().size() !=
out_dist_attr.dims_mapping().size()) {
Expand Down Expand Up @@ -219,15 +219,16 @@ SpmdInfo SqueezeInferSpmdReverse(const DistMetaTensor& x,

// Step3: Update the dist attributes of input
// and output with the inferred dims mapping
TensorDistAttr out_dist_attr_dst(out_dist_attr_src);
TensorDistAttr out_dist_attr_dst =
CopyTensorDistAttrForOutput(out_dist_attr_src);
out_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
if (out_dist_attr_dst.dynamic_dims().size() !=
out_dist_attr_dst.dims_mapping().size()) {
VLOG(4) << "SqueezeInferSPMD change output dist attr dynamic dims";
out_dist_attr_dst.set_default_dynamic_dims(
out_dist_attr_dst.dims_mapping());
}
TensorDistAttr x_dist_attr(x.dist_attr());
TensorDistAttr x_dist_attr = CopyTensorDistAttrForOutput(x.dist_attr());
x_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
if (x_dist_attr.dynamic_dims().size() != x_dist_attr.dims_mapping().size()) {
VLOG(4) << "SqueezeInferSPMD change x dist attr dynamic dims";
Expand Down
17 changes: 13 additions & 4 deletions paddle/phi/infermeta/spmd_rules/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ SpmdInfo TransposeInferSpmd(const DistMetaTensor& x,
std::vector<int64_t> out_dims_mapping =
GetDimsMappingForAxes(out_axes, axis_to_dim_map);

auto x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
x_dist_attr_dst.set_dims_mapping(x_dims_mapping);

// initialize output dist_attr's process_mesh, batch_dim and dynamic dims with
// input dist_attr.
TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
Expand All @@ -99,7 +102,7 @@ SpmdInfo TransposeInferSpmd(const DistMetaTensor& x,
VLOG(4) << "Perm: [" << str_join(perm) << "]";
VLOG(4) << "Output dims_mapping: [" + str_join(out_dims_mapping) + "]\n\n";

return {{x_dist_attr_src}, {out_dist_attr}};
return {{x_dist_attr_dst}, {out_dist_attr}};
}

SpmdInfo TransposeInferSpmdReverse(const DistMetaTensor& x,
Expand Down Expand Up @@ -156,6 +159,9 @@ SpmdInfo TransposeInferSpmdReverse(const DistMetaTensor& x,
TensorDistAttr x_dist_attr = CopyTensorDistAttrForOutput(x.dist_attr());
x_dist_attr.set_dims_mapping(x_dims_mapping);

auto out_dist_attr_dst = CopyTensorDistAttrForOutput(out_dist_attr_src);
out_dist_attr_dst.set_dims_mapping(out_dims_mapping);

// Step3 Handle partial (TODO)

VLOG(4) << "TransposeInferSpmdReverse:";
Expand All @@ -165,7 +171,7 @@ SpmdInfo TransposeInferSpmdReverse(const DistMetaTensor& x,
VLOG(4) << "Input shape: [" << str_join(x_shape) << "] "
<< "dims_mapping: [" << str_join(x_dims_mapping) << "]\n\n";

return {{x_dist_attr}, {out_dist_attr_src}};
return {{x_dist_attr}, {out_dist_attr_dst}};
}

SpmdInfo TransposeGradInferSpmd(const DistMetaTensor& out_grad,
Expand Down Expand Up @@ -196,9 +202,12 @@ SpmdInfo TransposeGradInferSpmd(const DistMetaTensor& out_grad,
int origin_index = perm[i] >= 0 ? perm[i] : out_grad_ndim + perm[i];
x_dims_mapping[origin_index] = out_grad_dims_mapping[i];
}
TensorDistAttr x_grad_dist_attr = out_grad.dist_attr();

auto out_grad_dist_attr = CopyTensorDistAttrForOutput(out_grad.dist_attr());
out_grad_dist_attr.set_dims_mapping(out_grad_dims_mapping);
auto x_grad_dist_attr = CopyTensorDistAttrForOutput(out_grad.dist_attr());
x_grad_dist_attr.set_dims_mapping(x_dims_mapping);
return {{out_grad.dist_attr()}, {x_grad_dist_attr}};
return {{out_grad_dist_attr}, {x_grad_dist_attr}};
}

} // namespace distributed
Expand Down
9 changes: 5 additions & 4 deletions paddle/phi/infermeta/spmd_rules/unsqueeze.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,14 @@ SpmdInfo UnsqueezeInferSpmd(const DistMetaTensor& x,

// Step3: Update the dist attributes of input
// and output with the inferred dims mapping.
TensorDistAttr x_dist_attr_dst(x_dist_attr_src);
TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
x_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
if (x_dist_attr_dst.dynamic_dims().size() !=
x_dist_attr_dst.dims_mapping().size()) {
VLOG(4) << "UnSqueezeInferSPMD change output dist attr dynamic dims";
x_dist_attr_dst.set_default_dynamic_dims(x_dist_attr_dst.dims_mapping());
}
TensorDistAttr out_dist_attr(x_dist_attr_src);
TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
out_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
if (out_dist_attr.dynamic_dims().size() !=
out_dist_attr.dims_mapping().size()) {
Expand Down Expand Up @@ -199,15 +199,16 @@ SpmdInfo UnsqueezeInferSpmdReverse(const DistMetaTensor& x,

// Step3: Update the dist attributes of input
// and output with the inferred dims mapping
TensorDistAttr out_dist_attr_dst(out_dist_attr_src);
TensorDistAttr out_dist_attr_dst =
CopyTensorDistAttrForOutput(out_dist_attr_src);
out_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
if (out_dist_attr_dst.dynamic_dims().size() !=
out_dist_attr_dst.dims_mapping().size()) {
VLOG(4) << "UnSqueezeInferSPMDReverse change output dist attr dynamic dims";
out_dist_attr_dst.set_default_dynamic_dims(
out_dist_attr_dst.dims_mapping());
}
TensorDistAttr x_dist_attr(x.dist_attr());
TensorDistAttr x_dist_attr = CopyTensorDistAttrForOutput(x.dist_attr());
x_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
if (x_dist_attr.dynamic_dims().size() != x_dist_attr.dims_mapping().size()) {
VLOG(4) << "UnSqueezeInferSPMDReverse change x dist attr dynamic dims";
Expand Down