Commit b68d21b

Fix

1 parent 8bd6a82 commit b68d21b
38 files changed: +143 −76 lines

paddle/fluid/framework/new_executor/interpreter/execution_config.cc (+2 −2)

@@ -27,8 +27,8 @@
 // FLAGS_force_sync_ops is used to finer control the op-sync in executor.
 // The format is: "micro_batch_id, job_name, op_id, op_name | micro_batch_id,
 // job_name, op_id, op_name | ...". Keep spaces to syncs all name/id. Example:
-// 1. sync the recv_v2 op in the second backward-job of 1F1B scheduling:
-// FLAGS_force_sync_ops="1, backward, , recv_v2"
+// 1. sync the p_recv op in the second backward-job of 1F1B scheduling:
+// FLAGS_force_sync_ops="1, backward, , p_recv"
 // 2. sync the full op with op_id=5: FLAGS_force_sync_ops=" , , 5, full"
 // 3. sync all ops in the first default-job: FLAGS_force_sync_ops="0,default,,
 // 4. sync all ops in the forward-job and backward-job: FLAGS_force_sync_ops=" ,
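For orientation, a minimal sketch of driving this flag from Python (assumption: like other Paddle GFlags, FLAGS_force_sync_ops can be supplied through the environment before paddle is imported; the exact plumbing may differ):

    import os

    # Force-sync the p_recv op in the second backward-job of 1F1B
    # scheduling, matching example 1 in the comment above.
    os.environ["FLAGS_force_sync_ops"] = "1, backward, , p_recv"

    import paddle  # the executor reads the flag when it is constructed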

paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc (+2)

@@ -142,6 +142,7 @@ bool IsCommunicationOp(const OperatorBase* op) {
       "recv",
       "send_v2",
       "recv_v2",
+      "p_recv",
   };
   const std::string communication_op_prefix = "c_";
   if (op_name.find(communication_op_prefix) != std::string::npos ||
@@ -170,6 +171,7 @@ bool IsCommunicationOp(const ::pir::Operation* op) {
   const std::set<std::string> special_comm_op_set = {
       paddle::dialect::SendV2Op::name(),
       paddle::dialect::RecvV2Op::name(),
+      paddle::dialect::PRecvOp::name(),
   };
   const std::string communication_op_prefix = "c_";
   if (op_name.find(communication_op_prefix) != std::string::npos ||
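A hypothetical Python restatement of the rule both overloads implement, to make clear why "p_recv" needs an explicit entry (names are illustrative; the allow-list is abridged from the diff context):

    # An op counts as a communication op if its name contains "c_" or sits
    # in the explicit allow-list; "p_recv" has no "c_" substring, so it
    # must be listed by hand.
    SPECIAL_COMM_OPS = {"recv", "send_v2", "recv_v2", "p_recv"}

    def is_communication_op(op_name: str) -> bool:
        return "c_" in op_name or op_name in SPECIAL_COMM_OPS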

paddle/fluid/framework/new_executor/pir_interpreter.cc (+4 −3)

@@ -551,6 +551,7 @@ void PirInterpreter::UpdateNcclOpNum() {
       "pd_op.partial_recv",
       "pd_op.partial_allgather",
       "pd_op.recv_v2",
+      "pd_op.p_recv",
       "pd_op.send_v2",
       "pd_op.mp_allreduce_sum",
       "pd_op.barrier",
@@ -586,7 +587,7 @@ void PirInterpreter::UpdateNcclOpNum() {
       "pd_op.partial_send_grad",
       "pd_op.partial_recv_grad",
       "pd_op.partial_allgather_grad",
-      "pd_op.recv_v2_grad",
+      "pd_op.p_recv_grad",
       "pd_op.send_v2_grad",
       "pd_op.mp_allreduce_sum_grad",
       "pd_op.barrier_grad",
@@ -625,7 +626,7 @@ void PirInterpreter::UpdateNcclOpNum() {
       "pd_op.partial_send_",
       "pd_op.partial_recv_",
       "pd_op.partial_allgather_",
-      "pd_op.recv_v2_",
+      "pd_op.p_recv_",
       "pd_op.send_v2_",
       "pd_op.mp_allreduce_sum_",
       "pd_op.barrier_",
@@ -661,7 +662,7 @@ void PirInterpreter::UpdateNcclOpNum() {
       "pd_op.partial_send_grad_",
       "pd_op.partial_recv_grad_",
       "pd_op.partial_allgather_grad_",
-      "pd_op.recv_v2_grad_",
+      "pd_op.p_recv_grad_",
       "pd_op.send_v2_grad_",
       "pd_op.mp_allreduce_sum_grad_",
       "pd_op.barrier_grad_",

paddle/fluid/pir/dialect/op_generator/ops_api_gen.py (+1)

@@ -140,6 +140,7 @@
     'coalesce_tensor_',
     'send_v2',
     'recv_v2',
+    'p_recv',
     'sequence_expand',
     'sequence_softmax',
     'qkv_unpack_mha',

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc (+52)

@@ -565,6 +565,58 @@ bool RecvV2OpInferSymbolicShape(pir::Operation *op,
   return true;
 }

+bool PRecvOpInferSymbolicShape(pir::Operation *op,
+                               pir::InferSymbolicShapeContext *infer_context) {
+  const int ring_id = op->attribute<pir::Int32Attribute>("ring_id").data();
+  const bool dynamic_shape =
+      op->attribute<pir::BoolAttribute>("dynamic_shape").data();
+  const int peer = op->attribute<pir::Int32Attribute>("peer").data();
+
+  PADDLE_ENFORCE_GE(
+      peer,
+      0,
+      common::errors::InvalidArgument(
+          "The peer (%d) for p_recv op must be non-negative.", peer));
+
+  PADDLE_ENFORCE_GE(
+      ring_id,
+      0,
+      common::errors::InvalidArgument(
+          "The ring_id (%d) for p_recv op must be non-negative.", ring_id));
+
+  const std::vector<int> out_shape =
+      paddle::dialect::details::GetVectorAttr<int>(op, "out_shape");
+  if (!dynamic_shape) {
+    PADDLE_ENFORCE_GE(out_shape.size(),
+                      1,
+                      common::errors::InvalidArgument(
+                          "The size of the output shape must be greater than 0 "
+                          "but the value given is %d.",
+                          out_shape.size()));
+
+    std::vector<symbol::DimExpr> output_shape;
+    for (size_t i = 0; i < out_shape.size(); ++i) {
+      PADDLE_ENFORCE_GE(out_shape[i],
+                        1,
+                        common::errors::InvalidArgument(
+                            "The shape attribute for p_recv must be set "
+                            "explicitly, but the %dth element is %d which "
+                            "is less than 1. Or dynamic_shape should be set to "
+                            "True for both send_v2 and p_recv.",
+                            i,
+                            out_shape[i]));
+      output_shape.push_back(symbol::DimExpr(out_shape[i]));
+    }
+
+    infer_context->SetShapeOrDataForValue(
+        op->result(0),
+        symbol::ShapeOrDataDimExprs{
+            symbol::TensorShapeOrDataDimExprs(output_shape)});
+  }
+
+  return true;
+}
+
 bool SeedOpInferSymbolicShape(pir::Operation *op,
                               pir::InferSymbolicShapeContext *infer_context) {
   std::vector<symbol::DimExpr> dims = {symbol::DimExpr(1)};
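The static-shape branch above requires every out_shape extent to be explicit and positive before it publishes the symbolic dims for result 0. A hypothetical restatement of that contract (illustrative only):

    # Static-shape branch: either all extents are concrete positives, or
    # dynamic_shape must be enabled on both send_v2 and p_recv.
    def infer_p_recv_shape(out_shape, dynamic_shape):
        if dynamic_shape:
            return None  # shape is carried at run time instead
        assert len(out_shape) >= 1, "out_shape must be non-empty"
        assert all(d >= 1 for d in out_shape), "set every extent explicitly"
        return list(out_shape)  # becomes the symbolic dims of result 0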

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h (+1)

@@ -34,6 +34,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Randperm)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReadFile)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Seed)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(RecvV2)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(PRecv)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(TrilIndices)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(TriuIndices)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(TruncatedGaussianRandom)

paddle/fluid/pybind/eager_generator.cc (+1)

@@ -3395,6 +3395,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
     {"accuracy", {"Correct", "Total"}},
     {"fill_constant", {"Out"}},
     {"recv_v2", {"Out"}},
+    {"p_recv", {"Out"}},
     {"partial_recv", {"Out"}},
     {"matmul", {"Out"}},
     {"c_broadcast", {"Out"}},

paddle/phi/infermeta/nullary.cc (+5 −5)

@@ -274,7 +274,7 @@ void PRecvArrayInferMeta(int peer,
         errors::InvalidArgument("The shape attribute for recv must be set "
                                 "explicitly, but the %dth element is %d which "
                                 "is less than 1. Or dynamic_shape should be "
-                                "set to True for both send_v2 and recv_v2.",
+                                "set to True for both send_v2 and p_recv.",
                                 i,
                                 out_shape[i]));
   }
@@ -291,13 +291,13 @@ void RecvV2InferMeta(const int ring_id,
       peer,
       0,
       errors::InvalidArgument(
-          "The peer (%d) for recv_v2 op must be non-negative.", peer));
+          "The peer (%d) for p_recv op must be non-negative.", peer));

   PADDLE_ENFORCE_GE(
       ring_id,
       0,
       errors::InvalidArgument(
-          "The ring_id (%d) for recv_v2 op must be non-negative.", ring_id));
+          "The ring_id (%d) for p_recv op must be non-negative.", ring_id));

   if (!dynamic_shape) {
     PADDLE_ENFORCE_GE(out_shape.size(),
@@ -310,10 +310,10 @@ void RecvV2InferMeta(const int ring_id,
     PADDLE_ENFORCE_GE(out_shape[i],
                       1,
                       errors::InvalidArgument(
-                          "The shape attribute for recv_v2 must be set "
+                          "The shape attribute for p_recv must be set "
                           "explicitly, but the %dth element is %d which "
                           "is less than 1. Or dynamic_shape should be "
-                          "set to True for both send_v2 and recv_v2.",
+                          "set to True for both send_v2 and p_recv.",
                           i,
                           out_shape[i]));
   }

paddle/phi/ops/yaml/inconsistent/static_ops.yaml (+13)

@@ -597,6 +597,19 @@
   interfaces : paddle::dialect::InferSymbolicShapeInterface
   traits : paddle::dialect::ForwardOnlyTrait

+- op : p_recv
+  args : (int ring_id = 0, int peer = 0, DataType dtype = DataType::FLOAT32, int[] out_shape = {}, bool dynamic_shape = false)
+  output : Tensor(out)
+  infer_meta :
+    func : PRecvInferMeta
+    param : [peer, dtype, out_shape, dynamic_shape]
+  kernel :
+    func : p_recv
+    param : [peer, dtype, out_shape, dynamic_shape]
+    data_type : dtype
+  interfaces : paddle::dialect::InferSymbolicShapeInterface
+  traits : paddle::dialect::ForwardOnlyTrait
+
 - op : partial_recv
   args : (int ring_id = 0, int peer = 0, DataType dtype=DataType::FLOAT32, int[] out_shape= {}, int num = 1, int id = 0)
   output : Tensor(out)
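With this entry (together with the ops_api_gen.py change above), the op generators should emit a Python binding whose signature mirrors the YAML args. A hypothetical sketch (the paddle._C_ops module path and exact argument handling are assumptions, not confirmed by the diff):

    import paddle

    # ring_id=0, peer=0, float32, static out_shape [2, 3], dynamic_shape off
    out = paddle._C_ops.p_recv(0, 0, paddle.float32, [2, 3], False)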

paddle/phi/ops/yaml/op_compat.yaml (+4)

@@ -4469,6 +4469,10 @@
   outputs :
     out : Out

+- op: p_recv
+  outputs :
+    out : Out
+
 - op: partial_send
   inputs :
     x : X

python/paddle/distributed/auto_parallel/static/cost/base_cost.py (+1)

@@ -27,6 +27,7 @@
 COMM_OP_TYPE = [
     "send_v2",
     "recv_v2",
+    "p_recv",
     "broadcast",
     "all_gather",
     "c_allreduce_sum",

python/paddle/distributed/auto_parallel/static/cost/comm_op_cost.py (+1 −1)

@@ -181,7 +181,7 @@ def calc_time(self):

 @register_op_cost
 class RecvOpCost(CommOpCost):
-    OP_TYPE = "recv_v2"
+    OP_TYPE = "p_recv"

     def __init__(self, op=None, op_desc=None, comm_context=None):
         super().__init__(op=op, op_desc=op_desc, comm_context=comm_context)
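The registration pattern above appears to key each cost class by its OP_TYPE, so retargeting RecvOpCost at "p_recv" is all the cost model needs. A hypothetical sketch of registering a cost for another op, following the same pattern ("my_comm_op" is a made-up op type):

    @register_op_cost
    class MyCommOpCost(CommOpCost):
        OP_TYPE = "my_comm_op"  # hypothetical op type for illustration

        def __init__(self, op=None, op_desc=None, comm_context=None):
            super().__init__(op=op, op_desc=op_desc, comm_context=comm_context)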

python/paddle/distributed/auto_parallel/static/dist_saver.py (+1 −1)

@@ -174,7 +174,7 @@ def save_inference_model(self, path, feed_vars, fetch_vars, exe, **kwargs):
         for idx, op in enumerate(ops):
             if op.attr(op_role_key) != op_role_forward:
                 continue
-            if op.type == "read" or op.type == "feed" or op.type == 'recv_v2':
+            if op.type == "read" or op.type == "feed" or op.type == 'p_recv':
                 feed_vars_names += op.output("Out")
             if op.type == "send_v2":
                 fetch_vars_names += op.input("X")

python/paddle/distributed/auto_parallel/static/mapper.py (+3 −3)

@@ -50,7 +50,7 @@ def is_collective_comm_op(op):


 def is_p2p_comm_op(op):
-    comm_list = ["send_v2", "recv_v2"]
+    comm_list = ["send_v2", "p_recv"]
     if op.type in comm_list:
         return True
     else:
@@ -87,7 +87,7 @@ def get_comm_volume(comm_op, src_rank, tgt_rank):
     if src_rank == tgt_rank:
         return comm_volume
     comm_op_type = comm_op.type
-    if comm_op_type != "recv_v2":
+    if comm_op_type != "p_recv":
         tensor_name = comm_op.input_arg_names[0]
     else:
         tensor_name = comm_op.output_arg_names[0]
@@ -128,7 +128,7 @@ def get_comm_volume(comm_op, src_rank, tgt_rank):
             comm_volume = tensor_bytes
         else:
             comm_volume = None
-    elif "recv_v2" in comm_op_type:
+    elif "p_recv" in comm_op_type:
         comm_volume = None
     else:
         raise ValueError("Unrecognized communication operator.")

python/paddle/distributed/auto_parallel/static/parallelizer.py (+1 −1)

@@ -490,7 +490,7 @@ def parallelize(
         if self._dist_strategy.auto_search:
             is_pipeline = False
             for op in dist_main_prog.global_block().ops:
-                if op.type == "send_v2" or op.type == "recv_v2":
+                if op.type == "send_v2" or op.type == "p_recv":
                     is_pipeline = True
                     break
             if is_pipeline:

python/paddle/distributed/auto_parallel/static/reshard.py (+1 −1)

@@ -387,7 +387,7 @@ def insert_send_op(block, idx, tensor, src, dst, op_role, sync=True):
     @staticmethod
     def insert_recv_op(block, idx, tensor, src, dst, op_role, sync=True):
         """Insert recv op into block at the given index."""
-        op_type = 'recv_v2'
+        op_type = 'p_recv'
         insert_operation = (
             block._insert_op if sync else block._insert_op_without_sync
         )

python/paddle/distributed/communication/stream/recv.py (+1 −1)

@@ -48,7 +48,7 @@ def _recv_in_dygraph(
 def _recv_in_static_mode(
     tensor, src_rank_in_group, group, sync_op, use_calc_stream
 ):
-    op_type = 'recv_v2'
+    op_type = 'p_recv'
     data_feeder.check_variable_and_dtype(
         tensor,
         'tensor',
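At the user level the rename is transparent: paddle.distributed.recv still routes through _recv_in_static_mode under static graph mode, which now emits a 'p_recv' op instead of 'recv_v2'. A minimal sketch (two-rank setup assumed; program construction and distributed-environment initialization elided):

    import paddle
    import paddle.distributed as dist

    paddle.enable_static()
    # ... build program, initialize the distributed env for 2 ranks ...

    buf = paddle.zeros([2, 3], dtype="float32")
    if dist.get_rank() == 1:
        dist.recv(buf, src=0)  # inserts a p_recv op into the program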

python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py (+1 −1)

@@ -1025,7 +1025,7 @@ def _prune_main_program(self, block, shard, rings):
                 "c_gen_xccl_id",
                 "c_comm_init",
                 'send_v2',
-                'recv_v2',
+                'p_recv',
             ]:
                 pass
             elif op.type == "conditional_block":

python/paddle/distributed/fleet/utils/hybrid_parallel_inference.py (+2 −2)

@@ -637,7 +637,7 @@ def _insert_send_recv(cur_id, prev_id):

                 block._insert_op_without_sync(
                     index=index + extra_index_info['index'],
-                    type='recv_v2',
+                    type='p_recv',
                     outputs={'Out': [var]},
                     attrs={
                         'out_shape': var_shape,
@@ -737,7 +737,7 @@ def _insert_sendrecv_ops_in_while_block(
         )
         block._insert_op_without_sync(
             index=index,
-            type='recv_v2',
+            type='p_recv',
             outputs={'Out': [var]},
             attrs={
                 'out_shape': var_shape,

python/paddle/distributed/passes/auto_parallel_grad_clip.py (+1 −1)

@@ -265,7 +265,7 @@ def _is_pure_data_parallel(self):
                 and not is_data_parallel_reduce_op(op)
             ):
                 return False
-            if op.type in ["send_v2", "recv_v2"]:
+            if op.type in ["send_v2", "p_recv"]:
                 return False

         return True

python/paddle/distributed/passes/pass_utils.py (+6 −6)

@@ -447,7 +447,7 @@ def _pir_overlap_send_recv(program):
     The finally target of this function is as follows:
     1. no need to insert the 'c_sync_calc' and 'c_sync_calc' operators
     2. 'send_v2' operator uses 'dist_attr.execution_stream' to set stream of its own.
-    3. 'recv_v2' operator uses 'dist_attr.execution_stream' to set stream of its own.
+    3. 'p_recv' operator uses 'dist_attr.execution_stream' to set stream of its own.
     """
     for block in program.blocks:
         for op in block.ops:
@@ -457,7 +457,7 @@ def _pir_overlap_send_recv(program):
                 ring_id = op.attrs()["ring_id"]
                 op.set_execution_stream(f"send_stream_{ring_id}")
                 op.set_scheduling_priority(0)
-            elif op.name() == "pd_op.recv_v2":
+            elif op.name() == "pd_op.p_recv":
                 op.set_bool_attr("dynamic_shape", False)
                 op.set_bool_attr("use_calc_stream", True)
                 op.set_execution_stream("recv_stream")
@@ -468,7 +468,7 @@ def _insert_sync_for_fthenb_1f1b(program, dist_context=None):
     """
     This implementation refers to lots of Paddle/python/paddle/base/optimizer.py.
     The difference between this function with 'PipelineOptimizer' is that
-    'send_v2' op and 'recv_v2' op have been inserted in program by 'reshard'.
+    'send_v2' op and 'p_recv' op have been inserted in program by 'reshard'.
     """

     for block in program.blocks:
@@ -482,7 +482,7 @@ def _insert_sync_for_fthenb_1f1b(program, dist_context=None):
         # insert sync ops
         for index, op in enumerate(list(block.ops)):
             # NOTE: pipeline might hang when dynamic_shape is True
-            if op.type in ['send_v2', 'recv_v2']:
+            if op.type in ['send_v2', 'p_recv']:
                 op._set_attr("dynamic_shape", False)
             # set send op on comm stream
             if op.type == 'send_v2':
@@ -556,7 +556,7 @@ def _insert_sync_for_fthenb_1f1b(program, dist_context=None):
         offset = 0
         backward_recv_index = None
         for index, op in enumerate(block.ops):
-            if op.type == "recv_v2" and is_backward_op(op):
+            if op.type == "p_recv" and is_backward_op(op):
                 backward_recv_index = index
                 break
         if backward_recv_index is None:
@@ -917,7 +917,7 @@ def _add_event_dependency(recorder_op, waiter_op):
     '''
     Add the extra event dependency of the two operators.
     This function mainly aims for the cross-programs in pipeline parallelism,
-    especial for the 'send_v2' 'recv_v2' etc.
+    especial for the 'send_v2' 'p_recv' etc.
     '''
     if not recorder_op.dist_attr.force_record_event:
         recorder_op.dist_attr.force_record_event = True
