@@ -565,6 +565,58 @@ bool RecvV2OpInferSymbolicShape(pir::Operation *op,
565
565
return true ;
566
566
}
567
567
568
+ bool PRecvOpInferSymbolicShape (pir::Operation *op,
569
+ pir::InferSymbolicShapeContext *infer_context) {
570
+ const int ring_id = op->attribute <pir::Int32Attribute>(" ring_id" ).data ();
571
+ const bool dynamic_shape =
572
+ op->attribute <pir::BoolAttribute>(" dynamic_shape" ).data ();
573
+ const int peer = op->attribute <pir::Int32Attribute>(" peer" ).data ();
574
+
575
+ PADDLE_ENFORCE_GE (
576
+ peer,
577
+ 0 ,
578
+ common::errors::InvalidArgument (
579
+ " The peer (%d) for p_recv op must be non-negative." , peer));
580
+
581
+ PADDLE_ENFORCE_GE (
582
+ ring_id,
583
+ 0 ,
584
+ common::errors::InvalidArgument (
585
+ " The ring_id (%d) for p_recv op must be non-negative." , ring_id));
586
+
587
+ const std::vector<int > out_shape =
588
+ paddle::dialect::details::GetVectorAttr<int >(op, " out_shape" );
589
+ if (!dynamic_shape) {
590
+ PADDLE_ENFORCE_GE (out_shape.size (),
591
+ 1 ,
592
+ common::errors::InvalidArgument (
593
+ " The size of the output shape must be greater than 0 "
594
+ " but the value given is %d." ,
595
+ out_shape.size ()));
596
+
597
+ std::vector<symbol::DimExpr> output_shape;
598
+ for (size_t i = 0 ; i < out_shape.size (); ++i) {
599
+ PADDLE_ENFORCE_GE (out_shape[i],
600
+ 1 ,
601
+ common::errors::InvalidArgument (
602
+ " The shape attribute for p_recv must be set "
603
+ " explicitly, but the %dth element is %d which "
604
+ " is less than 1. Or dynamic_shape should be set to "
605
+ " True for both send_v2 and p_recv." ,
606
+ i,
607
+ out_shape[i]));
608
+ output_shape.push_back (symbol::DimExpr (out_shape[i]));
609
+ }
610
+
611
+ infer_context->SetShapeOrDataForValue (
612
+ op->result (0 ),
613
+ symbol::ShapeOrDataDimExprs{
614
+ symbol::TensorShapeOrDataDimExprs (output_shape)});
615
+ }
616
+
617
+ return true ;
618
+ }
619
+
568
620
bool SeedOpInferSymbolicShape (pir::Operation *op,
569
621
pir::InferSymbolicShapeContext *infer_context) {
570
622
std::vector<symbol::DimExpr> dims = {symbol::DimExpr (1 )};
0 commit comments