Skip to content

Commit 496947d

Browse files
authored
add call to spmd in expand infermeta when use pir (#71214)
1 parent 399da2a commit 496947d

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

paddle/fluid/pir/dialect/operator/ir/manual_op.cc

+39
Original file line numberDiff line numberDiff line change
@@ -3695,6 +3695,45 @@ std::vector<pir::Type> ExpandOp::InferMeta(
36953695
dense_out.layout(),
36963696
dense_out.lod(),
36973697
dense_out.offset());
3698+
3699+
// Auto Parallel condition
3700+
#ifdef PADDLE_WITH_DISTRIBUTE
3701+
ProcessMeshAttribute op_mesh;
3702+
if (HasDistInput(input_values, &op_mesh)) {
3703+
CvtAllInputsToDist(input_values, op_mesh);
3704+
auto ctx = pir::IrContext::Instance();
3705+
std::vector<pir::Attribute> dist_operand_attrs, dist_result_attrs;
3706+
auto dist_meta_x =
3707+
CvtToDistMetaTensor(x_.type().dyn_cast<DistDenseTensorType>());
3708+
// Todo(jeff41404): When expand adds spmd rules, synchronous modifications
3709+
// are required here.
3710+
auto spmd_info =
3711+
phi::distributed::VariadicReplicatedInferSpmdDynamic(dist_meta_x);
3712+
PADDLE_ENFORCE_EQ(
3713+
spmd_info.first.size(),
3714+
1u,
3715+
common::errors::Unavailable(
3716+
"Size of spmd_info.first for op[ExpandOp]is unexpected."));
3717+
for (auto &arg_dist : spmd_info.first) {
3718+
dist_operand_attrs.push_back(CvtToPirAttr(arg_dist));
3719+
}
3720+
3721+
for (int i = 1; i < 2; ++i) {
3722+
dist_operand_attrs.push_back(GetTensorDistAttr(input_values[i].type()));
3723+
}
3724+
3725+
auto dist_attr_out =
3726+
CreateReplicatedDistAttr(out_dense_tensor_type, op_mesh);
3727+
3728+
dist_result_attrs.push_back(dist_attr_out);
3729+
argument_outputs.push_back(
3730+
CvtToPirDistType(out_dense_tensor_type, dist_attr_out));
3731+
3732+
(*p_attributes)[kAttrOpDistAttr] = OperationDistAttribute::get(
3733+
ctx, op_mesh, dist_operand_attrs, dist_result_attrs);
3734+
return argument_outputs;
3735+
}
3736+
#endif
36983737
argument_outputs.push_back(out_dense_tensor_type);
36993738
return argument_outputs;
37003739
}

0 commit comments

Comments
 (0)