Skip to content

Commit cae678a

Browse files
committed
Fix
1 parent 80edc35 commit cae678a

File tree

10 files changed

+33
-34
lines changed

10 files changed

+33
-34
lines changed

paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ bool IsCommunicationOp(const ::pir::Operation* op) {
169169
}
170170
const std::set<std::string> special_comm_op_set = {
171171
paddle::dialect::SendV2Op::name(),
172-
paddle::dialect::RecvV2Op::name(),
172+
paddle::dialect::PRecvOp::name(),
173173
};
174174
const std::string communication_op_prefix = "c_";
175175
if (op_name.find(communication_op_prefix) != std::string::npos ||

paddle/fluid/pir/dialect/operator/utils/utils.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ namespace dialect {
3939
const std::unordered_set<std::string> LegacyOpList = {
4040
DistributedPushSparseOp::name(),
4141
SendV2Op::name(),
42-
RecvV2Op::name(),
42+
PRecvOp::name(),
4343
CAllreduceSumOp::name(),
4444
CAllreduceSum_Op::name(),
4545
};

paddle/phi/ops/yaml/inconsistent/static_ops.yaml

+12-12
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,18 @@
597597
interfaces : paddle::dialect::InferSymbolicShapeInterface
598598
traits : paddle::dialect::ForwardOnlyTrait
599599

600+
- op : p_recv
601+
args : (int[] out_shape = {}, DataType dtype = DataType::FLOAT32, int peer = 0, int ring_id = 0, bool use_calc_stream = false, bool dynamic_shape = false)
602+
output : Tensor(out)
603+
infer_meta:
604+
func: RecvV2InferMeta
605+
param: [ring_id, dynamic_shape, peer, out_shape, dtype]
606+
kernel :
607+
func : recv_v2
608+
param : [ring_id, dynamic_shape, peer, out_shape, dtype, use_calc_stream]
609+
data_type : dtype
610+
interfaces : paddle::dialect::InferSymbolicShapeInterface
611+
600612
- op : partial_recv
601613
args : (int ring_id = 0, int peer = 0, DataType dtype=DataType::FLOAT32, int[] out_shape= {}, int num = 1, int id = 0)
602614
output : Tensor(out)
@@ -699,18 +711,6 @@
699711
inplace : (scale -> out_scale, in_accum -> out_accum, in_state -> out_state)
700712
interfaces : paddle::dialect::InferSymbolicShapeInterface
701713

702-
- op : recv_v2
703-
args : (int[] out_shape = {}, DataType dtype = DataType::FLOAT32, int peer = 0, int ring_id = 0, bool use_calc_stream = false, bool dynamic_shape = false)
704-
output : Tensor(out)
705-
infer_meta:
706-
func: RecvV2InferMeta
707-
param: [ring_id, dynamic_shape, peer, out_shape, dtype]
708-
kernel :
709-
func : recv_v2
710-
param : [ring_id, dynamic_shape, peer, out_shape, dtype, use_calc_stream]
711-
data_type : dtype
712-
interfaces : paddle::dialect::InferSymbolicShapeInterface
713-
714714
- op : remainder
715715
args : (Tensor x, Tensor y)
716716
output : Tensor (out)

paddle/phi/ops/yaml/op_compat.yaml

+8
Original file line numberDiff line numberDiff line change
@@ -4469,6 +4469,14 @@
44694469
outputs :
44704470
out : Out
44714471

4472+
- op: p_recv
4473+
outputs :
4474+
out : Out
4475+
4476+
- op: p_send
4477+
inputs :
4478+
x : X
4479+
44724480
- op: partial_send
44734481
inputs :
44744482
x : X

python/paddle/distributed/auto_parallel/static/reshard.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def insert_send_op(block, idx, tensor, src, dst, op_role, sync=True):
387387
@staticmethod
388388
def insert_recv_op(block, idx, tensor, src, dst, op_role, sync=True):
389389
"""Insert recv op into block at the given index."""
390-
op_type = 'recv_v2'
390+
op_type = 'p_recv'
391391
insert_operation = (
392392
block._insert_op if sync else block._insert_op_without_sync
393393
)

test/collective/collective_sendrecv_api.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def send_new(tensor, dst, group=None, sync_op=True):
4747
helper = framework.LayerHelper(op_type, **locals())
4848
helper.append_op(
4949
type=op_type,
50-
inputs={'x': [tensor]},
50+
inputs={'X': [tensor]},
5151
attrs={
5252
'ring_id': ring_id,
5353
'peer': dst,
@@ -78,7 +78,7 @@ def recv_new(tensor, src, group=None, sync_op=True, dtype='float32'):
7878
helper = framework.LayerHelper(op_type, **locals())
7979
helper.append_op(
8080
type=op_type,
81-
outputs={'out': [tensor]},
81+
outputs={'Out': [tensor]},
8282
attrs={
8383
'ring_id': ring_id,
8484
'peer': src,

test/collective/collective_sendrecv_op.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def get_model(self, main_prog, startup_program):
3838
tindata.desc.set_need_check_feed(False)
3939
if self.rank == 0:
4040
main_prog.global_block().append_op(
41-
type="send_v2",
41+
type="p_send",
4242
inputs={'X': tindata},
4343
attrs={
4444
'ring_id': ring_id,
@@ -48,7 +48,7 @@ def get_model(self, main_prog, startup_program):
4848
)
4949
else:
5050
main_prog.global_block().append_op(
51-
type="recv_v2",
51+
type="p_recv",
5252
outputs={'Out': tindata},
5353
attrs={
5454
'peer': 0,

test/deprecated/legacy_test/test_auto_parallel_reshard_deprecated.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def check_send_recv_result(dist_main_prog, rank_id):
223223
if op.type == "send_v2" and "gelu_0.tmp_0" in op.input_arg_names:
224224
send_result = True
225225
if (
226-
op.type == "recv_v2"
226+
op.type == "p_recv"
227227
and "gelu_0.tmp_0@GRAD" in op.output_arg_names[0]
228228
):
229229
recv_result = True
@@ -234,10 +234,7 @@ def check_send_recv_result(dist_main_prog, rank_id):
234234
and "gelu_0.tmp_0@GRAD" in op.input_arg_names
235235
):
236236
send_result = True
237-
if (
238-
op.type == "recv_v2"
239-
and "gelu_0.tmp_0" in op.output_arg_names[0]
240-
):
237+
if op.type == "p_recv" and "gelu_0.tmp_0" in op.output_arg_names[0]:
241238
recv_result = True
242239

243240
return send_result and recv_result

test/deprecated/legacy_test/test_auto_parallel_reshard_dpmppp_deprecated.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def check_send_recv_result(dist_main_prog, rank_id):
163163
if op.type == "send_v2" and "gelu_0.tmp_0" in op.input_arg_names:
164164
send_result = True
165165
if (
166-
op.type == "recv_v2"
166+
op.type == "p_recv"
167167
and "gelu_0.tmp_0@GRAD" in op.output_arg_names[0]
168168
):
169169
recv_result = True
@@ -174,10 +174,7 @@ def check_send_recv_result(dist_main_prog, rank_id):
174174
and "gelu_0.tmp_0@GRAD" in op.input_arg_names
175175
):
176176
send_result = True
177-
if (
178-
op.type == "recv_v2"
179-
and "gelu_0.tmp_0" in op.output_arg_names[0]
180-
):
177+
if op.type == "p_recv" and "gelu_0.tmp_0" in op.output_arg_names[0]:
181178
recv_result = True
182179

183180
return send_result and recv_result

test/deprecated/legacy_test/test_auto_parallel_reshard_mppp_deprecated.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def check_send_recv_result(dist_main_prog, rank_id):
176176
if op.type == "send_v2" and "gelu_0.tmp_0" in op.input_arg_names:
177177
send_result = True
178178
if (
179-
op.type == "recv_v2"
179+
op.type == "p_recv"
180180
and "gelu_0.tmp_0@GRAD" in op.output_arg_names[0]
181181
):
182182
recv_result = True
@@ -187,10 +187,7 @@ def check_send_recv_result(dist_main_prog, rank_id):
187187
and "gelu_0.tmp_0@GRAD" in op.input_arg_names[0]
188188
):
189189
send_result = True
190-
if (
191-
op.type == "recv_v2"
192-
and "gelu_0.tmp_0" in op.output_arg_names[0]
193-
):
190+
if op.type == "p_recv" and "gelu_0.tmp_0" in op.output_arg_names[0]:
194191
recv_result = True
195192

196193
return send_result and recv_result

0 commit comments

Comments (0)