Skip to content

Commit f24d463

Browse files
[AutoParallel] Fix PHI API inplace output code generation. (#59133)
1 parent 893235e commit f24d463

File tree

7 files changed

+393
-168
lines changed

7 files changed

+393
-168
lines changed

paddle/phi/api/lib/api_gen_utils.cc

+78-86
Original file line numberDiff line numberDiff line change
@@ -547,25 +547,95 @@ std::vector<phi::distributed::DistMetaTensor> MakeDistMetaTensor(
547547
}
548548

549549
phi::distributed::DistTensor* SetKernelDistOutput(
550-
Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) {
550+
Tensor* out, const phi::distributed::ArgDistAttr& dist_attr) {
551+
PADDLE_ENFORCE_EQ(
552+
paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
553+
true,
554+
phi::errors::PreconditionNotMet("Arg must be a single TensorDistAttr"));
551555
if (out) {
552556
if (out->impl() == nullptr) {
553-
auto dist_t = std::make_shared<phi::distributed::DistTensor>(phi::DDim(),
554-
dist_attr);
557+
auto dist_t = std::make_shared<phi::distributed::DistTensor>(
558+
phi::DDim(), paddle::get<0>(dist_attr));
555559
out->set_impl(dist_t);
556560
}
557561
return static_cast<phi::distributed::DistTensor*>(out->impl().get());
558562
}
559563
return nullptr;
560564
}
561565

562-
phi::distributed::DistTensor* SetKernelDistOutput(
563-
Tensor* out, const phi::distributed::ArgDistAttr& dist_attr) {
566+
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
567+
size_t out_size, std::vector<Tensor>* out) {
568+
std::vector<phi::distributed::DistTensor*> results(out_size);
569+
if (out->size() != out_size) {
570+
// Empty out vector
571+
out->reserve(out_size);
572+
}
573+
for (size_t i = 0; i < out_size; ++i) {
574+
if (out->size() != out_size) {
575+
auto dist_t = std::make_shared<phi::distributed::DistTensor>();
576+
out->emplace_back();
577+
out->back().set_impl(dist_t);
578+
}
579+
results[i] =
580+
static_cast<phi::distributed::DistTensor*>(out->at(i).impl().get());
581+
}
582+
return results;
583+
}
584+
585+
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
586+
const phi::distributed::ArgDistAttr& dist_attr, std::vector<Tensor>* out) {
564587
PADDLE_ENFORCE_EQ(
565-
paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
588+
paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
589+
dist_attr),
566590
true,
567-
phi::errors::PreconditionNotMet("Arg must be a single TensorDistAttr"));
568-
return SetKernelDistOutput(out, paddle::get<0>(dist_attr));
591+
phi::errors::PreconditionNotMet(
592+
"Arg must be a vector of TensorDistAttr"));
593+
const std::vector<phi::distributed::TensorDistAttr>& dist_attrs =
594+
PADDLE_GET_CONST(std::vector<phi::distributed::TensorDistAttr>,
595+
dist_attr);
596+
auto out_size = dist_attrs.size();
597+
std::vector<phi::distributed::DistTensor*> results(out_size);
598+
// TODO(GhostScreaming): Inplace outputs are initialized, just set their
599+
// dist_attr.
600+
if (out->size() == out_size) {
601+
VLOG(3) << "Outputs are inplace vector Tensors, just set their dist_attrs "
602+
<< "according to InferSPMD output result.";
603+
for (size_t i = 0; i < out_size; ++i) {
604+
results[i] =
605+
static_cast<phi::distributed::DistTensor*>(out->at(i).impl().get());
606+
results[i]->unsafe_set_dist_attr(dist_attrs[i]);
607+
}
608+
} else {
609+
out->reserve(out_size);
610+
for (size_t i = 0; i < out_size; ++i) {
611+
auto dist_t = std::make_shared<phi::distributed::DistTensor>(
612+
phi::DDim(), dist_attrs[i]);
613+
results[i] = dist_t.get();
614+
out->emplace_back();
615+
out->back().set_impl(dist_t);
616+
}
617+
}
618+
return results;
619+
}
620+
621+
// For backward
622+
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
623+
std::vector<Tensor*> out) {
624+
std::vector<phi::distributed::DistTensor*> result;
625+
for (auto tmp : out) {
626+
if (tmp) {
627+
// TODO(GhostScreaming): now all dist case are nullptr
628+
if (tmp->impl() == nullptr) {
629+
auto dist_t = std::make_shared<phi::distributed::DistTensor>();
630+
tmp->set_impl(dist_t);
631+
}
632+
result.emplace_back(
633+
static_cast<phi::distributed::DistTensor*>(tmp->impl().get()));
634+
} else {
635+
result.emplace_back(nullptr);
636+
}
637+
}
638+
return result;
569639
}
570640

571641
std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
@@ -609,84 +679,6 @@ std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
609679
return nullptr;
610680
}
611681

612-
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
613-
std::vector<Tensor*> out) {
614-
std::vector<phi::distributed::DistTensor*> result;
615-
for (auto tmp : out) {
616-
if (tmp) {
617-
// TODO(GhostScreaming): now all dist case are nullptr
618-
if (tmp->impl() == nullptr) {
619-
auto dist_t = std::make_shared<phi::distributed::DistTensor>();
620-
tmp->set_impl(dist_t);
621-
}
622-
result.emplace_back(
623-
static_cast<phi::distributed::DistTensor*>(tmp->impl().get()));
624-
} else {
625-
result.emplace_back(nullptr);
626-
}
627-
}
628-
return result;
629-
}
630-
631-
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
632-
const phi::distributed::ArgDistAttr& dist_attr, std::vector<Tensor>* out) {
633-
PADDLE_ENFORCE_EQ(
634-
paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
635-
dist_attr),
636-
true,
637-
phi::errors::PreconditionNotMet(
638-
"Arg must be a vector of TensorDistAttr"));
639-
const std::vector<phi::distributed::TensorDistAttr>& dist_attrs =
640-
PADDLE_GET_CONST(std::vector<phi::distributed::TensorDistAttr>,
641-
dist_attr);
642-
auto out_size = dist_attrs.size();
643-
out->reserve(out_size);
644-
std::vector<phi::distributed::DistTensor*> results(out_size);
645-
for (size_t i = 0; i < out_size; ++i) {
646-
auto dist_t = std::make_shared<phi::distributed::DistTensor>(phi::DDim(),
647-
dist_attrs[i]);
648-
results[i] = dist_t.get();
649-
out->emplace_back();
650-
out->back().set_impl(dist_t);
651-
}
652-
return results;
653-
}
654-
655-
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
656-
size_t out_size, std::vector<Tensor>* out) {
657-
out->reserve(out_size);
658-
std::vector<phi::distributed::DistTensor*> results(out_size);
659-
for (size_t i = 0; i < out_size; ++i) {
660-
auto dist_t = std::make_shared<phi::distributed::DistTensor>();
661-
results[i] = dist_t.get();
662-
out->emplace_back();
663-
out->back().set_impl(dist_t);
664-
}
665-
return results;
666-
}
667-
668-
std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOutput(
669-
size_t out_size, std::vector<Tensor>* out) {
670-
std::vector<phi::distributed::DistTensor*> results(out->size(), nullptr);
671-
for (size_t i = 0; i < out->size(); ++i) {
672-
results[i] =
673-
static_cast<phi::distributed::DistTensor*>(out->at(i).impl().get());
674-
}
675-
return results;
676-
}
677-
678-
std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOptionalOutput(
679-
size_t out_size, paddle::optional<std::vector<Tensor>> out) {
680-
std::vector<phi::distributed::DistTensor*> results;
681-
if (out) {
682-
results = std::vector<phi::distributed::DistTensor*>(out->size(), nullptr);
683-
for (size_t i = 0; i < out->size(); ++i) {
684-
results[i] =
685-
static_cast<phi::distributed::DistTensor*>(out->at(i).impl().get());
686-
}
687-
}
688-
return results;
689-
}
690682
void SetReplicatedDistAttrForOutput(
691683
phi::distributed::DistTensor* out,
692684
const phi::distributed::ProcessMesh& process_mesh) {

paddle/phi/api/lib/api_gen_utils.h

+8-16
Original file line numberDiff line numberDiff line change
@@ -145,21 +145,10 @@ std::vector<phi::distributed::DistMetaTensor> MakeDistMetaTensor(
145145

146146
phi::distributed::DistTensor* SetKernelDistOutput(
147147
Tensor* out,
148-
const phi::distributed::TensorDistAttr& dist_attr =
149-
phi::distributed::TensorDistAttr());
150-
151-
phi::distributed::DistTensor* SetKernelDistOutput(
152-
Tensor* out, const phi::distributed::ArgDistAttr& dist_attr);
153-
154-
std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
155-
Tensor* out,
156-
bool set_dist_output_as_tensor_impl,
157148
const phi::distributed::ArgDistAttr& dist_attr =
158149
phi::distributed::TensorDistAttr());
159150

160-
std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
161-
Tensor* out, const phi::distributed::ArgDistAttr& dist_attr);
162-
151+
// For backward
163152
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
164153
std::vector<Tensor*> out);
165154

@@ -169,11 +158,14 @@ std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
169158
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
170159
const phi::distributed::ArgDistAttr& dist_attr, std::vector<Tensor>* out);
171160

172-
std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOutput(
173-
size_t out_size, std::vector<Tensor>* out);
161+
std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
162+
Tensor* out,
163+
bool set_dist_output_as_tensor_impl,
164+
const phi::distributed::ArgDistAttr& dist_attr =
165+
phi::distributed::TensorDistAttr());
174166

175-
std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOptionalOutput(
176-
size_t out_size, paddle::optional<std::vector<Tensor>> out);
167+
std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
168+
Tensor* out, const phi::distributed::ArgDistAttr& dist_attr);
177169

178170
// DistTensor need to set initial dist attr after the dims setted, it is
179171
// constructed based dims and current process mesh, beforce calling this

paddle/phi/api/lib/data_transform.cc

+102
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,108 @@ ReshardApiInputToKernelInput(
722722
return paddle::none;
723723
}
724724

725+
void SetInplaceOutputCorrectDistAttr(
726+
phi::DeviceContext* dev_ctx,
727+
Tensor& tensor, // NOLINT
728+
const phi::distributed::TensorDistAttr& dist_attr,
729+
bool use_general_spmd_rule) {
730+
auto tensor_in = tensor.impl();
731+
if (tensor_in) {
732+
phi::distributed::DistTensor* dist_tensor =
733+
static_cast<phi::distributed::DistTensor*>(tensor_in.get());
734+
if (dist_tensor->initialized()) {
735+
if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) {
736+
if (use_general_spmd_rule) {
737+
VLOG(6) << "SetInplaceOutputCorrectDistAttr Reshard inplace output"
738+
<< " to origin dist_attr "
739+
<< ReshardDebugInfo(*dist_tensor, dist_attr);
740+
auto* func = phi::distributed::ChooseProperReshardFunction(
741+
*dist_tensor, dist_attr);
742+
func->Eval(dev_ctx, *dist_tensor, dist_attr, dist_tensor);
743+
} else {
744+
// just set correct SPMD dist_attrs
745+
VLOG(6) << "SetInplaceOutputCorrectDistAttr input " << tensor.name()
746+
<< " set its dist_attr from " << dist_tensor->dist_attr()
747+
<< " to " << dist_attr;
748+
dist_tensor->unsafe_set_dist_attr(dist_attr);
749+
}
750+
}
751+
} else {
752+
VLOG(6) << "SetInplaceOutputCorrectDistAttr has"
753+
<< " uninitialized DistTensor input " << tensor.name()
754+
<< ", just set its dist_attr from " << dist_tensor->dist_attr()
755+
<< " to " << dist_attr;
756+
dist_tensor->unsafe_set_dist_attr(dist_attr);
757+
}
758+
}
759+
}
760+
761+
void SetInplaceOutputCorrectDistAttr(
762+
phi::DeviceContext* dev_ctx,
763+
Tensor& tensor, // NOLINT
764+
const phi::distributed::ArgDistAttr& dist_attr,
765+
bool use_general_spmd_rule) {
766+
PADDLE_ENFORCE_EQ(
767+
paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
768+
true,
769+
phi::errors::PreconditionNotMet("Arg must be a TensorDistAttr"));
770+
SetInplaceOutputCorrectDistAttr(
771+
dev_ctx, tensor, paddle::get<0>(dist_attr), use_general_spmd_rule);
772+
}
773+
774+
void SetInplaceOutputCorrectDistAttr(
775+
phi::DeviceContext* dev_ctx,
776+
std::vector<Tensor>& tensors, // NOLINT
777+
const std::vector<phi::distributed::TensorDistAttr>& dist_attr,
778+
bool use_general_spmd_rule) {
779+
for (size_t i = 0; i < tensors.size(); i++) {
780+
auto tensor_in = tensors[i].impl();
781+
if (tensor_in) {
782+
phi::distributed::DistTensor* dist_tensor =
783+
static_cast<phi::distributed::DistTensor*>(tensor_in.get());
784+
if (dist_tensor->initialized()) {
785+
if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr[i])) {
786+
if (use_general_spmd_rule) {
787+
VLOG(6) << "SetInplaceOutputCorrectDistAttr Reshard inplace output"
788+
<< " to origin dist_attr "
789+
<< ReshardDebugInfo(*dist_tensor, dist_attr[i]);
790+
auto* func = phi::distributed::ChooseProperReshardFunction(
791+
*dist_tensor, dist_attr[i]);
792+
func->Eval(dev_ctx, *dist_tensor, dist_attr[i], dist_tensor);
793+
} else {
794+
// just set correct SPMD dist_attrs
795+
VLOG(6) << "SetInplaceOutputCorrectDistAttr input "
796+
<< tensors[i].name() << " set its dist_attr from "
797+
<< dist_tensor->dist_attr() << " to " << dist_attr[i];
798+
dist_tensor->unsafe_set_dist_attr(dist_attr[i]);
799+
}
800+
}
801+
} else {
802+
VLOG(6) << "SetInplaceOutputCorrectDistAttr has"
803+
<< " uninitialized DistTensor input " << tensors[i].name()
804+
<< ", just set its dist_attr from " << dist_tensor->dist_attr()
805+
<< " to " << dist_attr[i];
806+
dist_tensor->unsafe_set_dist_attr(dist_attr[i]);
807+
}
808+
}
809+
}
810+
}
811+
812+
void SetInplaceOutputCorrectDistAttr(
813+
phi::DeviceContext* dev_ctx,
814+
std::vector<Tensor>& tensors, // NOLINT
815+
const phi::distributed::ArgDistAttr& dist_attr,
816+
bool use_general_spmd_rule) {
817+
PADDLE_ENFORCE_EQ(
818+
paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
819+
dist_attr),
820+
true,
821+
phi::errors::PreconditionNotMet(
822+
"Arg must be a vector of TensorDistAttr"));
823+
SetInplaceOutputCorrectDistAttr(
824+
dev_ctx, tensors, paddle::get<1>(dist_attr), use_general_spmd_rule);
825+
}
826+
725827
void ReshardOutputPartialAxisToReplicated(
726828
phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor) {
727829
if (out_tensor->dist_attr().is_partial()) {

paddle/phi/api/lib/data_transform.h

+24
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,30 @@ ReshardApiInputToKernelInput(
197197
const paddle::optional<std::vector<Tensor>>& tensors,
198198
const phi::distributed::ArgDistAttr& dist_attr);
199199

200+
void SetInplaceOutputCorrectDistAttr(
201+
phi::DeviceContext* dev_ctx,
202+
Tensor& tensor, // NOLINT
203+
const phi::distributed::TensorDistAttr& dist_attr,
204+
bool use_general_spmd_rule = true);
205+
206+
void SetInplaceOutputCorrectDistAttr(
207+
phi::DeviceContext* dev_ctx,
208+
Tensor& tensor, // NOLINT
209+
const phi::distributed::ArgDistAttr& dist_attr,
210+
bool use_general_spmd_rule = true);
211+
212+
void SetInplaceOutputCorrectDistAttr(
213+
phi::DeviceContext* dev_ctx,
214+
std::vector<Tensor>& tensors, // NOLINT
215+
const std::vector<phi::distributed::TensorDistAttr>& dist_attr,
216+
bool use_general_spmd_rule = true);
217+
218+
void SetInplaceOutputCorrectDistAttr(
219+
phi::DeviceContext* dev_ctx,
220+
std::vector<Tensor>& tensors, // NOLINT
221+
const phi::distributed::ArgDistAttr& dist_attr,
222+
bool use_general_spmd_rule = true);
223+
200224
void ReshardOutputPartialAxisToReplicated(
201225
phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor);
202226

0 commit comments

Comments
 (0)