@@ -547,25 +547,95 @@ std::vector<phi::distributed::DistMetaTensor> MakeDistMetaTensor(
547
547
}
548
548
549
549
// Returns the DistTensor backing `out`, creating an empty one stamped with
// the inferred dist_attr if `out` has no impl yet. `dist_attr` must hold a
// single TensorDistAttr (enforced below). Returns nullptr when `out` is null.
phi::distributed::DistTensor* SetKernelDistOutput(
    Tensor* out, const phi::distributed::ArgDistAttr& dist_attr) {
  PADDLE_ENFORCE_EQ(
      paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
      true,
      phi::errors::PreconditionNotMet("Arg must be a single TensorDistAttr"));
  if (!out) {
    return nullptr;
  }
  if (out->impl() == nullptr) {
    // Output not yet initialized: allocate a shape-less DistTensor carrying
    // the inferred attr; the kernel fills in shape/data later.
    auto dist_tensor = std::make_shared<phi::distributed::DistTensor>(
        phi::DDim(), paddle::get<0>(dist_attr));
    out->set_impl(dist_tensor);
  }
  return static_cast<phi::distributed::DistTensor*>(out->impl().get());
}
561
565
562
- phi::distributed::DistTensor* SetKernelDistOutput (
563
- Tensor* out, const phi::distributed::ArgDistAttr& dist_attr) {
566
+ std::vector<phi::distributed::DistTensor*> SetKernelDistOutput (
567
+ size_t out_size, std::vector<Tensor>* out) {
568
+ std::vector<phi::distributed::DistTensor*> results (out_size);
569
+ if (out->size () != out_size) {
570
+ // Empty out vector
571
+ out->reserve (out_size);
572
+ }
573
+ for (size_t i = 0 ; i < out_size; ++i) {
574
+ if (out->size () != out_size) {
575
+ auto dist_t = std::make_shared<phi::distributed::DistTensor>();
576
+ out->emplace_back ();
577
+ out->back ().set_impl (dist_t );
578
+ }
579
+ results[i] =
580
+ static_cast <phi::distributed::DistTensor*>(out->at (i).impl ().get ());
581
+ }
582
+ return results;
583
+ }
584
+
585
+ std::vector<phi::distributed::DistTensor*> SetKernelDistOutput (
586
+ const phi::distributed::ArgDistAttr& dist_attr, std::vector<Tensor>* out) {
564
587
PADDLE_ENFORCE_EQ (
565
- paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
588
+ paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
589
+ dist_attr),
566
590
true ,
567
- phi::errors::PreconditionNotMet (" Arg must be a single TensorDistAttr" ));
568
- return SetKernelDistOutput (out, paddle::get<0 >(dist_attr));
591
+ phi::errors::PreconditionNotMet (
592
+ " Arg must be a vector of TensorDistAttr" ));
593
+ const std::vector<phi::distributed::TensorDistAttr>& dist_attrs =
594
+ PADDLE_GET_CONST (std::vector<phi::distributed::TensorDistAttr>,
595
+ dist_attr);
596
+ auto out_size = dist_attrs.size ();
597
+ std::vector<phi::distributed::DistTensor*> results (out_size);
598
+ // TODO(GhostScreaming): Inplace outputs are initialized, just set their
599
+ // dist_attr.
600
+ if (out->size () == out_size) {
601
+ VLOG (3 ) << " Outputs are inplace vector Tensors, just set their dist_attrs "
602
+ << " according to InferSPMD output result." ;
603
+ for (size_t i = 0 ; i < out_size; ++i) {
604
+ results[i] =
605
+ static_cast <phi::distributed::DistTensor*>(out->at (i).impl ().get ());
606
+ results[i]->unsafe_set_dist_attr (dist_attrs[i]);
607
+ }
608
+ } else {
609
+ out->reserve (out_size);
610
+ for (size_t i = 0 ; i < out_size; ++i) {
611
+ auto dist_t = std::make_shared<phi::distributed::DistTensor>(
612
+ phi::DDim (), dist_attrs[i]);
613
+ results[i] = dist_t .get ();
614
+ out->emplace_back ();
615
+ out->back ().set_impl (dist_t );
616
+ }
617
+ }
618
+ return results;
619
+ }
620
+
621
+ // For backward
622
+ std::vector<phi::distributed::DistTensor*> SetKernelDistOutput (
623
+ std::vector<Tensor*> out) {
624
+ std::vector<phi::distributed::DistTensor*> result;
625
+ for (auto tmp : out) {
626
+ if (tmp) {
627
+ // TODO(GhostScreaming): now all dist case are nullptr
628
+ if (tmp->impl () == nullptr ) {
629
+ auto dist_t = std::make_shared<phi::distributed::DistTensor>();
630
+ tmp->set_impl (dist_t );
631
+ }
632
+ result.emplace_back (
633
+ static_cast <phi::distributed::DistTensor*>(tmp->impl ().get ()));
634
+ } else {
635
+ result.emplace_back (nullptr );
636
+ }
637
+ }
638
+ return result;
569
639
}
570
640
571
641
std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput (
@@ -609,84 +679,6 @@ std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
609
679
return nullptr ;
610
680
}
611
681
612
- std::vector<phi::distributed::DistTensor*> SetKernelDistOutput (
613
- std::vector<Tensor*> out) {
614
- std::vector<phi::distributed::DistTensor*> result;
615
- for (auto tmp : out) {
616
- if (tmp) {
617
- // TODO(GhostScreaming): now all dist case are nullptr
618
- if (tmp->impl () == nullptr ) {
619
- auto dist_t = std::make_shared<phi::distributed::DistTensor>();
620
- tmp->set_impl (dist_t );
621
- }
622
- result.emplace_back (
623
- static_cast <phi::distributed::DistTensor*>(tmp->impl ().get ()));
624
- } else {
625
- result.emplace_back (nullptr );
626
- }
627
- }
628
- return result;
629
- }
630
-
631
- std::vector<phi::distributed::DistTensor*> SetKernelDistOutput (
632
- const phi::distributed::ArgDistAttr& dist_attr, std::vector<Tensor>* out) {
633
- PADDLE_ENFORCE_EQ (
634
- paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
635
- dist_attr),
636
- true ,
637
- phi::errors::PreconditionNotMet (
638
- " Arg must be a vector of TensorDistAttr" ));
639
- const std::vector<phi::distributed::TensorDistAttr>& dist_attrs =
640
- PADDLE_GET_CONST (std::vector<phi::distributed::TensorDistAttr>,
641
- dist_attr);
642
- auto out_size = dist_attrs.size ();
643
- out->reserve (out_size);
644
- std::vector<phi::distributed::DistTensor*> results (out_size);
645
- for (size_t i = 0 ; i < out_size; ++i) {
646
- auto dist_t = std::make_shared<phi::distributed::DistTensor>(phi::DDim (),
647
- dist_attrs[i]);
648
- results[i] = dist_t .get ();
649
- out->emplace_back ();
650
- out->back ().set_impl (dist_t );
651
- }
652
- return results;
653
- }
654
-
655
- std::vector<phi::distributed::DistTensor*> SetKernelDistOutput (
656
- size_t out_size, std::vector<Tensor>* out) {
657
- out->reserve (out_size);
658
- std::vector<phi::distributed::DistTensor*> results (out_size);
659
- for (size_t i = 0 ; i < out_size; ++i) {
660
- auto dist_t = std::make_shared<phi::distributed::DistTensor>();
661
- results[i] = dist_t .get ();
662
- out->emplace_back ();
663
- out->back ().set_impl (dist_t );
664
- }
665
- return results;
666
- }
667
-
668
- std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOutput (
669
- size_t out_size, std::vector<Tensor>* out) {
670
- std::vector<phi::distributed::DistTensor*> results (out->size (), nullptr );
671
- for (size_t i = 0 ; i < out->size (); ++i) {
672
- results[i] =
673
- static_cast <phi::distributed::DistTensor*>(out->at (i).impl ().get ());
674
- }
675
- return results;
676
- }
677
-
678
- std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOptionalOutput (
679
- size_t out_size, paddle::optional<std::vector<Tensor>> out) {
680
- std::vector<phi::distributed::DistTensor*> results;
681
- if (out) {
682
- results = std::vector<phi::distributed::DistTensor*>(out->size (), nullptr );
683
- for (size_t i = 0 ; i < out->size (); ++i) {
684
- results[i] =
685
- static_cast <phi::distributed::DistTensor*>(out->at (i).impl ().get ());
686
- }
687
- }
688
- return results;
689
- }
690
682
void SetReplicatedDistAttrForOutput (
691
683
phi::distributed::DistTensor* out,
692
684
const phi::distributed::ProcessMesh& process_mesh) {
0 commit comments