@@ -3695,6 +3695,45 @@ std::vector<pir::Type> ExpandOp::InferMeta(
3695
3695
dense_out.layout (),
3696
3696
dense_out.lod (),
3697
3697
dense_out.offset ());
3698
+
3699
+ // Auto Parallel condition
3700
+ #ifdef PADDLE_WITH_DISTRIBUTE
3701
+ ProcessMeshAttribute op_mesh;
3702
+ if (HasDistInput (input_values, &op_mesh)) {
3703
+ CvtAllInputsToDist (input_values, op_mesh);
3704
+ auto ctx = pir::IrContext::Instance ();
3705
+ std::vector<pir::Attribute> dist_operand_attrs, dist_result_attrs;
3706
+ auto dist_meta_x =
3707
+ CvtToDistMetaTensor (x_.type ().dyn_cast <DistDenseTensorType>());
3708
+ // Todo(jeff41404): When expand adds spmd rules, synchronous modifications
3709
+ // are required here.
3710
+ auto spmd_info =
3711
+ phi::distributed::VariadicReplicatedInferSpmdDynamic (dist_meta_x);
3712
+ PADDLE_ENFORCE_EQ (
3713
+ spmd_info.first .size (),
3714
+ 1u ,
3715
+ common::errors::Unavailable (
3716
+ " Size of spmd_info.first for op[ExpandOp]is unexpected." ));
3717
+ for (auto &arg_dist : spmd_info.first ) {
3718
+ dist_operand_attrs.push_back (CvtToPirAttr (arg_dist));
3719
+ }
3720
+
3721
+ for (int i = 1 ; i < 2 ; ++i) {
3722
+ dist_operand_attrs.push_back (GetTensorDistAttr (input_values[i].type ()));
3723
+ }
3724
+
3725
+ auto dist_attr_out =
3726
+ CreateReplicatedDistAttr (out_dense_tensor_type, op_mesh);
3727
+
3728
+ dist_result_attrs.push_back (dist_attr_out);
3729
+ argument_outputs.push_back (
3730
+ CvtToPirDistType (out_dense_tensor_type, dist_attr_out));
3731
+
3732
+ (*p_attributes)[kAttrOpDistAttr ] = OperationDistAttribute::get (
3733
+ ctx, op_mesh, dist_operand_attrs, dist_result_attrs);
3734
+ return argument_outputs;
3735
+ }
3736
+ #endif
3698
3737
argument_outputs.push_back (out_dense_tensor_type);
3699
3738
return argument_outputs;
3700
3739
}
0 commit comments