
Commit 138b066

Fix

1 parent f03a429 commit 138b066


45 files changed: +259 -93 lines changed

paddle/fluid/framework/new_executor/interpreter/execution_config.cc (+2 -2)
@@ -27,8 +27,8 @@
 // FLAGS_force_sync_ops is used to finer control the op-sync in executor.
 // The format is: "micro_batch_id, job_name, op_id, op_name | micro_batch_id,
 // job_name, op_id, op_name | ...". Keep spaces to syncs all name/id. Example:
-// 1. sync the recv_v2 op in the second backward-job of 1F1B scheduling:
-// FLAGS_force_sync_ops="1, backward, , recv_v2"
+// 1. sync the p_recv op in the second backward-job of 1F1B scheduling:
+// FLAGS_force_sync_ops="1, backward, , p_recv"
 // 2. sync the full op with op_id=5: FLAGS_force_sync_ops=" , , 5, full"
 // 3. sync all ops in the first default-job: FLAGS_force_sync_ops="0,default,,
 // 4. sync all ops in the forward-job and backward-job: FLAGS_force_sync_ops=" ,
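The comment block above documents the FLAGS_force_sync_ops format: rules are separated by '|', fields by ',', and a blank field matches any value. As a standalone illustration of that documented format (a hypothetical parser sketch, not the executor's actual implementation):

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// One rule from the flag: "micro_batch_id, job_name, op_id, op_name".
struct SyncRule {
  std::string micro_batch_id, job_name, op_id, op_name;
};

static std::string Trim(const std::string& s) {
  const auto begin = s.find_first_not_of(' ');
  if (begin == std::string::npos) return "";
  return s.substr(begin, s.find_last_not_of(' ') - begin + 1);
}

std::vector<SyncRule> ParseForceSyncOps(const std::string& flag) {
  std::vector<SyncRule> rules;
  std::stringstream entries(flag);
  std::string entry;
  while (std::getline(entries, entry, '|')) {
    std::stringstream fields(entry);
    std::string f[4];
    for (auto& field : f) std::getline(fields, field, ',');
    // A blank field acts as a wildcard that matches any id/name.
    rules.push_back({Trim(f[0]), Trim(f[1]), Trim(f[2]), Trim(f[3])});
  }
  return rules;
}

int main() {
  // Example from the comment above: sync p_recv in the second backward job.
  for (const auto& r : ParseForceSyncOps("1, backward, , p_recv")) {
    std::cout << "micro_batch=" << r.micro_batch_id << " job=" << r.job_name
              << " op_id=" << r.op_id << " op_name=" << r.op_name << '\n';
  }
}
```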

paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc (+2 -1)
@@ -142,6 +142,7 @@ bool IsCommunicationOp(const OperatorBase* op) {
       "recv",
       "send_v2",
       "recv_v2",
+      "p_recv",
   };
   const std::string communication_op_prefix = "c_";
   if (op_name.find(communication_op_prefix) != std::string::npos ||
@@ -169,7 +170,7 @@ bool IsCommunicationOp(const ::pir::Operation* op) {
   }
   const std::set<std::string> special_comm_op_set = {
       paddle::dialect::SendV2Op::name(),
-      paddle::dialect::RecvV2Op::name(),
+      paddle::dialect::PRecvOp::name(),
   };
   const std::string communication_op_prefix = "c_";
   if (op_name.find(communication_op_prefix) != std::string::npos ||
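Both hunks extend the same name-based test: an op counts as a communication op when its name contains the "c_" prefix or appears in a small set of special names, which now includes p_recv. A minimal standalone sketch of that check (simplified; the real sets are larger, and the PIR path compares dialect op names):

```cpp
#include <iostream>
#include <set>
#include <string>

// Simplified stand-in for IsCommunicationOp in interpreter_util.cc: an op
// is treated as a communication op if its name contains the "c_" prefix or
// is one of a few special names, which now include "p_recv".
bool IsCommunicationOpName(const std::string& op_name) {
  static const std::set<std::string> special_comm_op_set = {
      "send", "recv", "send_v2", "recv_v2", "p_recv"};
  const std::string communication_op_prefix = "c_";
  return op_name.find(communication_op_prefix) != std::string::npos ||
         special_comm_op_set.count(op_name) > 0;
}

int main() {
  std::cout << IsCommunicationOpName("p_recv") << ' '       // 1
            << IsCommunicationOpName("c_broadcast") << ' '  // 1
            << IsCommunicationOpName("matmul") << '\n';     // 0
}
```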

paddle/fluid/framework/new_executor/pir_interpreter.cc (+4 -3)
@@ -545,6 +545,7 @@ void PirInterpreter::UpdateNcclOpNum() {
       "pd_op.partial_recv",
       "pd_op.partial_allgather",
       "pd_op.recv_v2",
+      "pd_op.p_recv",
       "pd_op.send_v2",
       "pd_op.mp_allreduce_sum",
       "pd_op.barrier",
@@ -575,7 +576,7 @@
       "pd_op.partial_send_grad",
       "pd_op.partial_recv_grad",
       "pd_op.partial_allgather_grad",
-      "pd_op.recv_v2_grad",
+      "pd_op.p_recv_grad",
       "pd_op.send_v2_grad",
       "pd_op.mp_allreduce_sum_grad",
       "pd_op.barrier_grad",
@@ -608,7 +609,7 @@
       "pd_op.partial_send_",
       "pd_op.partial_recv_",
       "pd_op.partial_allgather_",
-      "pd_op.recv_v2_",
+      "pd_op.p_recv_",
       "pd_op.send_v2_",
       "pd_op.mp_allreduce_sum_",
       "pd_op.barrier_",
@@ -639,7 +640,7 @@
       "pd_op.partial_send_grad_",
       "pd_op.partial_recv_grad_",
       "pd_op.partial_allgather_grad_",
-      "pd_op.recv_v2_grad_",
+      "pd_op.p_recv_grad_",
       "pd_op.send_v2_grad_",
       "pd_op.mp_allreduce_sum_grad_",
       "pd_op.barrier_grad_",

paddle/fluid/pir/dialect/op_generator/ops_api_gen.py (+1)
@@ -140,6 +140,7 @@
     'coalesce_tensor_',
     'send_v2',
     'recv_v2',
+    'p_recv',
     'sequence_expand',
     'sequence_softmax',
     'qkv_unpack_mha',

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc (+52)
@@ -565,6 +565,58 @@ bool RecvV2OpInferSymbolicShape(pir::Operation *op,
   return true;
 }
 
+bool PRecvOpInferSymbolicShape(pir::Operation *op,
+                               pir::InferSymbolicShapeContext *infer_context) {
+  const int ring_id = op->attribute<pir::Int32Attribute>("ring_id").data();
+  const bool dynamic_shape =
+      op->attribute<pir::BoolAttribute>("dynamic_shape").data();
+  const int peer = op->attribute<pir::Int32Attribute>("peer").data();
+
+  PADDLE_ENFORCE_GE(
+      peer,
+      0,
+      common::errors::InvalidArgument(
+          "The peer (%d) for p_recv op must be non-negative.", peer));
+
+  PADDLE_ENFORCE_GE(
+      ring_id,
+      0,
+      common::errors::InvalidArgument(
+          "The ring_id (%d) for p_recv op must be non-negative.", ring_id));
+
+  const std::vector<int> out_shape =
+      paddle::dialect::details::GetVectorAttr<int>(op, "out_shape");
+  if (!dynamic_shape) {
+    PADDLE_ENFORCE_GE(out_shape.size(),
+                      1,
+                      common::errors::InvalidArgument(
+                          "The size of the output shape must be greater than 0 "
+                          "but the value given is %d.",
+                          out_shape.size()));
+
+    std::vector<symbol::DimExpr> output_shape;
+    for (size_t i = 0; i < out_shape.size(); ++i) {
+      PADDLE_ENFORCE_GE(out_shape[i],
+                        1,
+                        common::errors::InvalidArgument(
+                            "The shape attribute for p_recv must be set "
+                            "explicitly, but the %dth element is %d which "
+                            "is less than 1. Or dynamic_shape should be set to "
+                            "True for both send_v2 and p_recv.",
+                            i,
+                            out_shape[i]));
+      output_shape.push_back(symbol::DimExpr(out_shape[i]));
+    }
+
+    infer_context->SetShapeOrDataForValue(
+        op->result(0),
+        symbol::ShapeOrDataDimExprs{
+            symbol::TensorShapeOrDataDimExprs(output_shape)});
+  }
+
+  return true;
+}
+
 bool SeedOpInferSymbolicShape(pir::Operation *op,
                               pir::InferSymbolicShapeContext *infer_context) {
   std::vector<symbol::DimExpr> dims = {symbol::DimExpr(1)};
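Judging from the hunk above, the added PRecvOpInferSymbolicShape follows the same rule as its RecvV2 counterpart: when dynamic_shape is false, each validated out_shape entry becomes a concrete symbolic dimension on the op's result; when dynamic_shape is true, no static shape is recorded. A toy sketch of that rule, using a hypothetical DimExpr stand-in rather than Paddle's symbol::DimExpr:

```cpp
#include <iostream>
#include <vector>

// Hypothetical stand-in for symbol::DimExpr, just to show the rule.
struct DimExpr {
  long value;
};

// Mirrors the branch structure above: static shapes are only recorded when
// dynamic_shape is false, and every entry must already be a positive extent.
std::vector<DimExpr> InferPRecvShape(const std::vector<int>& out_shape,
                                     bool dynamic_shape) {
  std::vector<DimExpr> result;
  if (dynamic_shape) return result;  // shape is only known at runtime
  for (int d : out_shape) result.push_back(DimExpr{d});
  return result;
}

int main() {
  for (const auto& dim : InferPRecvShape({8, 1024}, false)) {
    std::cout << dim.value << ' ';  // 8 1024
  }
  std::cout << '\n';
}
```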

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h (+1)
@@ -34,6 +34,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Randperm)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReadFile)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Seed)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(RecvV2)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(PRecv)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(TrilIndices)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(TriuIndices)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(TruncatedGaussianRandom)

paddle/fluid/pir/dialect/operator/utils/utils.cc (-1)
@@ -39,7 +39,6 @@ namespace dialect {
 const std::unordered_set<std::string> LegacyOpList = {
     DistributedPushSparseOp::name(),
     SendV2Op::name(),
-    RecvV2Op::name(),
     CAllreduceSumOp::name(),
     CAllreduceSum_Op::name(),
 };

paddle/fluid/pybind/eager_generator.cc (+1)
@@ -3395,6 +3395,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
     {"accuracy", {"Correct", "Total"}},
    {"fill_constant", {"Out"}},
     {"recv_v2", {"Out"}},
+    {"p_recv", {"Out"}},
     {"partial_recv", {"Out"}},
     {"matmul", {"Out"}},
     {"c_broadcast", {"Out"}},

paddle/phi/infermeta/nullary.cc (+5 -5)
@@ -274,7 +274,7 @@ void PRecvArrayInferMeta(int peer,
         errors::InvalidArgument("The shape attribute for recv must be set "
                                 "explicitly, but the %dth element is %d which "
                                 "is less than 1. Or dynamic_shape should be "
-                                "set to True for both send_v2 and recv_v2.",
+                                "set to True for both send_v2 and p_recv.",
                                 i,
                                 out_shape[i]));
   }
@@ -291,13 +291,13 @@ void RecvV2InferMeta(const int ring_id,
       peer,
       0,
       errors::InvalidArgument(
-          "The peer (%d) for recv_v2 op must be non-negative.", peer));
+          "The peer (%d) for p_recv op must be non-negative.", peer));
 
   PADDLE_ENFORCE_GE(
       ring_id,
       0,
       errors::InvalidArgument(
-          "The ring_id (%d) for recv_v2 op must be non-negative.", ring_id));
+          "The ring_id (%d) for p_recv op must be non-negative.", ring_id));
 
   if (!dynamic_shape) {
     PADDLE_ENFORCE_GE(out_shape.size(),
@@ -310,10 +310,10 @@
     PADDLE_ENFORCE_GE(out_shape[i],
                       1,
                       errors::InvalidArgument(
-                          "The shape attribute for recv_v2 must be set "
+                          "The shape attribute for p_recv must be set "
                           "explicitly, but the %dth element is %d which "
                           "is less than 1. Or dynamic_shape should be "
-                          "set to True for both send_v2 and recv_v2.",
+                          "set to True for both send_v2 and p_recv.",
                           i,
                           out_shape[i]));
   }

paddle/phi/kernels/cpu/p_recv_kernel.cc (+26)
@@ -33,6 +33,18 @@ void PRecvKernel(const Context& dev_ctx UNUSED,
   PADDLE_THROW(errors::Unavailable("Do not support recv for cpu kernel now."));
 }
 
+template <typename T, typename Context>
+void PRecv2Kernel(const Context& dev_ctx UNUSED,
+                  int ring_id UNUSED,
+                  bool dynamic_shape UNUSED,
+                  int peer UNUSED,
+                  const std::vector<int>& out_shape UNUSED,
+                  DataType dtype UNUSED,
+                  bool use_calc_stream UNUSED,
+                  DenseTensor* out UNUSED) {
+  PADDLE_THROW(errors::Unavailable("Do not support recv for cpu kernel now."));
+}
+
 template <typename T, typename Context>
 void PRecvArrayKernel(const Context& dev_ctx UNUSED,
                       int peer UNUSED,
@@ -59,6 +71,20 @@ PD_REGISTER_KERNEL(p_recv,
                    int64_t,
                    phi::dtype::float16) {}
 
+PD_REGISTER_KERNEL(p_recv2,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::PRecv2Kernel,
+                   float,
+                   double,
+                   int,
+                   bool,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int64_t,
+                   phi::dtype::float16) {}
+
 PD_REGISTER_KERNEL(p_recv_array,
                    CPU,
                    ALL_LAYOUT,

paddle/phi/kernels/gpu/p_recv_kernel.cu (+39)
@@ -58,6 +58,18 @@ void PRecvKernel(const Context& dev_ctx,
 #endif
 }
 
+template <typename T, typename Context>
+void PRecv2Kernel(const Context& dev_ctx UNUSED,
+                  int ring_id UNUSED,
+                  bool dynamic_shape UNUSED,
+                  int peer UNUSED,
+                  const std::vector<int>& out_shape UNUSED,
+                  DataType dtype UNUSED,
+                  bool use_calc_stream UNUSED,
+                  DenseTensor* out UNUSED) {
+  PRecvKernel<T, Context>(dev_ctx, peer, dtype, out_shape, dynamic_shape, out);
+}
+
 template <typename T, typename Context>
 void PRecvArrayKernel(const Context& dev_ctx,
                       int peer,
@@ -103,6 +115,20 @@ PD_REGISTER_KERNEL(p_recv,
                    int64_t,
                    phi::dtype::bfloat16,
                    phi::dtype::float16) {}
+PD_REGISTER_KERNEL(p_recv2,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::PRecv2Kernel,
+                   float,
+                   double,
+                   int,
+                   bool,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int64_t,
+                   phi::dtype::bfloat16,
+                   phi::dtype::float16) {}
 
 PD_REGISTER_KERNEL(p_recv_array,
                    GPU,
@@ -131,6 +157,19 @@ PD_REGISTER_KERNEL(p_recv,
                    int16_t,
                    int64_t,
                    phi::dtype::float16) {}
+PD_REGISTER_KERNEL(p_recv2,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::PRecv2Kernel,
+                   float,
+                   double,
+                   int,
+                   bool,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int64_t,
+                   phi::dtype::float16) {}
 
 PD_REGISTER_KERNEL(p_recv_array,
                    GPU,
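Judging from the signatures, p_recv2 carries the legacy recv_v2-style attribute list (ring_id, use_calc_stream) and, on the device backends, simply forwards to PRecvKernel while ignoring those extra attributes; the CPU variant throws like p_recv does. Each PD_REGISTER_KERNEL block instantiates the kernel template once per listed dtype for one backend. As a loose analogy of what such a registration amounts to (a hypothetical toy registry, not phi's real machinery):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <typeinfo>
#include <utility>

// Hypothetical toy registry; phi's actual PD_REGISTER_KERNEL machinery is
// considerably more involved (it also records dtype/layout metadata).
using KernelFn = std::function<void()>;

std::map<std::string, KernelFn>& Registry() {
  static std::map<std::string, KernelFn> registry;
  return registry;
}

// Stand-in for phi::PRecv2Kernel: one instantiation per registered dtype.
template <typename T>
void PRecv2KernelStub() {
  std::cout << "p_recv2 stub invoked for dtype " << typeid(T).name() << '\n';
}

struct KernelRegistrar {
  KernelRegistrar(const std::string& key, KernelFn fn) {
    Registry()[key] = std::move(fn);
  }
};

// Roughly what a registration block amounts to: one static registrar per
// (kernel, backend, dtype) combination listed in the macro call.
static KernelRegistrar reg_f32("p_recv2/GPU/float", PRecv2KernelStub<float>);
static KernelRegistrar reg_f64("p_recv2/GPU/double", PRecv2KernelStub<double>);

int main() { Registry().at("p_recv2/GPU/float")(); }
```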

paddle/phi/kernels/gpu/p_send_kernel.cu (+1 -1)
@@ -38,7 +38,7 @@ void PSendKernel(const Context& dev_ctx,
     defined(PADDLE_WITH_RCCL) && NCCL_VERSION_CODE >= 2703
   auto comm_ctx =
       GetCommContext<Context, distributed::NCCLCommContext>(dev_ctx, peer);
-  gpuStream_t stream = dev_ctx.stream();
+  gpuStream_t stream = comm_ctx->GetStream();
   if (dynamic_shape) {
     send_shape_info<Context, distributed::NCCLCommContext, gpuStream_t>(
         dev_ctx, x, comm_ctx, peer, stream);

paddle/phi/kernels/xpu/p_recv_kernel.cc (+24)
@@ -56,6 +56,18 @@ void PRecvKernel(const Context& dev_ctx,
 #endif
 }
 
+template <typename T, typename Context>
+void PRecv2Kernel(const Context& dev_ctx UNUSED,
+                  int ring_id UNUSED,
+                  bool dynamic_shape UNUSED,
+                  int peer UNUSED,
+                  const std::vector<int>& out_shape UNUSED,
+                  DataType dtype UNUSED,
+                  bool use_calc_stream UNUSED,
+                  DenseTensor* out UNUSED) {
+  PRecvKernel<T, Context>(dev_ctx, peer, dtype, out_shape, dynamic_shape, out);
+}
+
 template <typename T, typename Context>
 void PRecvArrayKernel(const Context& dev_ctx,
                       int peer,
@@ -96,6 +108,18 @@ PD_REGISTER_KERNEL(p_recv,
                    phi::dtype::bfloat16,
                    phi::dtype::float16) {}
 
+PD_REGISTER_KERNEL(p_recv2,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::PRecv2Kernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int,
+                   int64_t,
+                   phi::dtype::bfloat16,
+                   phi::dtype::float16) {}
+
 PD_REGISTER_KERNEL(p_recv_array,
                    XPU,
                    ALL_LAYOUT,

paddle/phi/kernels/xpu/p_send_kernel.cc (+1 -1)
@@ -36,7 +36,7 @@ void PSendKernel(const Context& dev_ctx,
 #if defined(PADDLE_WITH_XPU_BKCL)
   auto comm_ctx =
       GetCommContext<Context, distributed::BKCLCommContext>(dev_ctx, peer);
-  XPUStream stream = dev_ctx.stream();
+  XPUStream stream = comm_ctx->GetStream();
   if (dynamic_shape) {
     send_shape_info<Context, distributed::BKCLCommContext, XPUStream>(
         dev_ctx, x, comm_ctx, peer, stream);
