Skip to content

Commit ed2bf1e

Browse files
committed
set dynamic_shape = true
1 parent 4a02850 commit ed2bf1e

File tree

5 files changed

+26
-39
lines changed

5 files changed

+26
-39
lines changed

python/paddle/distributed/auto_parallel/static/reshard_funcs/same_status_reshard_func.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
6262
src_value,
6363
comm_group.id,
6464
comm_group.ranks.index(dst),
65-
False,
65+
True,
6666
)
6767
point = paddle.base.libpaddle.pir.get_current_insertion_point()
6868
point.prev()
@@ -91,7 +91,7 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
9191
comm_group.id,
9292
comm_group.ranks.index(src),
9393
dst_type.dtype,
94-
False,
94+
True,
9595
)
9696
new_op = recv_value.get_defining_op()
9797
new_op.dist_attr = (

python/paddle/distributed/fleet/utils/hybrid_parallel_inference.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -613,18 +613,17 @@ def _insert_send_recv(cur_id, prev_id):
613613
else:
614614
ring_id = self._pp_ring_map[pair_key]
615615

616-
send_op = block._insert_op_without_sync(
616+
block._insert_op_without_sync(
617617
index=index + extra_index_info['index'],
618618
type='p_send',
619619
inputs={'x': var},
620620
attrs={
621621
'peer': 1,
622622
'ring_id': ring_id,
623623
self._op_role_key: op_role,
624+
'dynamic_shape': True,
624625
},
625626
)
626-
print("---------hybrid_parallel_inference----------")
627-
print(send_op)
628627
extra_index_info['index'] += 1
629628
var_shape = list(var.shape)
630629
if var_shape[0] < 0:
@@ -635,7 +634,7 @@ def _insert_send_recv(cur_id, prev_id):
635634
else:
636635
var_shape[0] = self.micro_batch_size
637636

638-
recv_op = block._insert_op_without_sync(
637+
block._insert_op_without_sync(
639638
index=index + extra_index_info['index'],
640639
type='p_recv',
641640
outputs={'out': [var]},
@@ -644,10 +643,9 @@ def _insert_send_recv(cur_id, prev_id):
644643
'peer': 0,
645644
'ring_id': ring_id,
646645
self._op_role_key: op_role,
646+
'dynamic_shape': True,
647647
},
648648
)
649-
print("---------hybrid_parallel_inference----------")
650-
print(recv_op)
651649
extra_index_info['index'] += 1
652650

653651
_insert_send_recv(
@@ -711,18 +709,17 @@ def _insert_sendrecv_ops_in_while_block(
711709
for var_name in var_names:
712710
var = block._var_recursive(var_name)
713711
if stage == cur_id:
714-
send_op = block._insert_op_without_sync(
712+
block._insert_op_without_sync(
715713
index=index,
716714
type='p_send',
717715
inputs={'x': var},
718716
attrs={
719717
'peer': 0,
720718
'ring_id': ring_id,
721719
self._op_role_key: int(self._op_role.Forward),
720+
'dynamic_shape': True,
722721
},
723722
)
724-
print("---------hybrid_parallel_inference----------")
725-
print(send_op)
726723

727724
else:
728725
var_shape = list(var.shape)
@@ -733,7 +730,7 @@ def _insert_sendrecv_ops_in_while_block(
733730
if var_shape[0] < 0
734731
else var_shape[0]
735732
)
736-
recv_op = block._insert_op_without_sync(
733+
block._insert_op_without_sync(
737734
index=index,
738735
type='p_recv',
739736
outputs={'out': [var]},
@@ -742,10 +739,9 @@ def _insert_sendrecv_ops_in_while_block(
742739
'peer': 1,
743740
'ring_id': ring_id,
744741
self._op_role_key: int(self._op_role.Forward),
742+
'dynamic_shape': True,
745743
},
746744
)
747-
print("---------hybrid_parallel_inference----------")
748-
print(recv_op)
749745
index += 1
750746
block._sync_with_cpp()
751747

python/paddle/distributed/passes/auto_parallel_pipeline.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def _insert_sync_ops_for_stream(self):
9696
for index, op in enumerate(list(block.ops)):
9797
# NOTE(review): dynamic_shape is now forced to True below, but this note previously warned the pipeline might hang when dynamic_shape is True — confirm the hang no longer occurs before keeping this change
9898
if op.type in ['p_send', 'p_recv']:
99-
op._set_attr("dynamic_shape", False)
99+
op._set_attr("dynamic_shape", True)
100100
# set send op on comm stream
101101
if op.type == 'p_send':
102102
op_role = op.attr('op_role')

python/paddle/distributed/passes/pass_utils.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -477,12 +477,12 @@ def _pir_overlap_send_recv(program):
477477
for block in program.blocks:
478478
for op in block.ops:
479479
if op.name() == "pd_op.p_send":
480-
op.set_bool_attr("dynamic_shape", False)
480+
op.set_bool_attr("dynamic_shape", True)
481481
ring_id = op.attrs()["ring_id"]
482482
op.set_execution_stream(f"send_stream_{ring_id}")
483483
op.set_scheduling_priority(0)
484484
elif op.name() == "pd_op.p_recv":
485-
op.set_bool_attr("dynamic_shape", False)
485+
op.set_bool_attr("dynamic_shape", True)
486486
op.set_execution_stream("recv_stream")
487487
op.set_scheduling_priority(0)
488488

@@ -506,7 +506,7 @@ def _insert_sync_for_fthenb_1f1b(program, dist_context=None):
506506
for index, op in enumerate(list(block.ops)):
507507
# NOTE(review): dynamic_shape is now forced to True below, but this note previously warned the pipeline might hang when dynamic_shape is True — confirm the hang no longer occurs before keeping this change
508508
if op.type in ['p_send', 'p_recv']:
509-
op._set_attr("dynamic_shape", False)
509+
op._set_attr("dynamic_shape", True)
510510
# set send op on comm stream
511511
if op.type == 'p_send':
512512
op_role = op.attr('op_role')
@@ -616,12 +616,12 @@ def _overlap_send_recv(program):
616616
for block in program.blocks:
617617
for op in block.ops:
618618
if op.type == 'p_send':
619-
op._set_attr("dynamic_shape", False)
619+
op._set_attr("dynamic_shape", True)
620620
ring_id = op.attr("ring_id")
621621
op.dist_attr.execution_stream = "send_stream_" + str(ring_id)
622622
op.dist_attr.stream_priority = 0
623623
elif op.type == 'p_recv':
624-
op._set_attr("dynamic_shape", False)
624+
op._set_attr("dynamic_shape", True)
625625
op.dist_attr.execution_stream = "recv_stream"
626626
op.dist_attr.stream_priority = 0
627627
else:

python/paddle/incubate/optimizer/pipeline.py

+10-19
Original file line numberDiff line numberDiff line change
@@ -798,27 +798,25 @@ def _insert_send_recv(cur_id, prev_id):
798798
ring_id = self._pp_ring_map[pair_key]
799799

800800
if self.schedule_mode == 'F-then-B': # F-then-B
801-
send_op = block._insert_op_without_sync(
801+
block._insert_op_without_sync(
802802
index=index + extra_index_info['index'],
803803
type='p_send',
804804
inputs={'x': var},
805805
attrs={
806806
'peer': 1,
807807
self._op_role_key: op_role,
808808
'ring_id': ring_id,
809+
'dynamic_shape': True,
809810
},
810811
)
811-
print("---------pipeline----------")
812-
print(send_op)
813-
# send_op.dist_attr.execution_stream = "default"
814812
extra_index_info['index'] += 1
815813
var_shape = list(var.shape)
816814
var_shape[0] = (
817815
self.micro_batch_size
818816
if var_shape[0] < 0
819817
else var_shape[0]
820818
)
821-
recv_op = block._insert_op_without_sync(
819+
block._insert_op_without_sync(
822820
index=index + extra_index_info['index'],
823821
type='p_recv',
824822
outputs={'out': [var]},
@@ -827,11 +825,9 @@ def _insert_send_recv(cur_id, prev_id):
827825
'peer': 0,
828826
'ring_id': ring_id,
829827
self._op_role_key: op_role,
828+
'dynamic_shape': True,
830829
},
831830
)
832-
print("---------pipeline----------")
833-
print(recv_op)
834-
# recv_op.dist_attr.execution_stream = "default"
835831
extra_index_info['index'] += 1
836832
elif self.schedule_mode == '1F1B': # 1F1B
837833
var_shape = list(var.shape)
@@ -891,19 +887,17 @@ def _insert_send_recv(cur_id, prev_id):
891887
True if isinstance(prefix_var, Parameter) else False
892888
)
893889
if not use_mp or is_param:
894-
send_op = block._insert_op_without_sync(
890+
block._insert_op_without_sync(
895891
index=index + extra_index_info['index'],
896892
type='p_send',
897893
inputs={'x': var},
898894
attrs={
899895
'ring_id': ring_id,
900896
'peer': 1,
901897
self._op_role_key: op_role,
898+
'dynamic_shape': True,
902899
},
903900
)
904-
print("---------pipeline----------")
905-
print(send_op)
906-
# send_op.dist_attr.execution_stream = "default"
907901
else:
908902
block._insert_op_without_sync(
909903
index=index + extra_index_info['index'],
@@ -943,7 +937,7 @@ def _insert_send_recv(cur_id, prev_id):
943937
sync_comm_op._set_attr('pipeline_flag', '')
944938
extra_index_info['index'] += 1
945939
if not use_mp or is_param:
946-
recv_op = block._insert_op_without_sync(
940+
block._insert_op_without_sync(
947941
index=index + extra_index_info['index'],
948942
type='p_recv',
949943
outputs={'out': [var]},
@@ -952,11 +946,9 @@ def _insert_send_recv(cur_id, prev_id):
952946
'peer': 0,
953947
'ring_id': ring_id,
954948
self._op_role_key: op_role,
949+
'dynamic_shape': True,
955950
},
956951
)
957-
print("---------pipeline----------")
958-
print(recv_op)
959-
# recv_op.dist_attr.execution_stream = "default"
960952
else:
961953
block._insert_op_without_sync(
962954
index=index + extra_index_info['index'],
@@ -984,7 +976,6 @@ def _insert_send_recv(cur_id, prev_id):
984976
self._op_role_key: op_role,
985977
'use_calc_stream': True,
986978
'ring_id': 0,
987-
# if p_recv, num&id attr is not in op_attrs, will not insert
988979
'nranks': self.mp_degree,
989980
'rank': self.mp_rank,
990981
},
@@ -1637,10 +1628,9 @@ def _process_persistable_vars_in_multi_sections(
16371628
'peer': read_dev_index,
16381629
'ring_id': ring_id,
16391630
self._op_role_key: self._op_role.LRSched,
1631+
'dynamic_shape': True,
16401632
},
16411633
)
1642-
print("---------pipeline----------")
1643-
# print(recv_op)
16441634
read_block._insert_op(
16451635
index=0,
16461636
type='p_recv',
@@ -1652,6 +1642,7 @@ def _process_persistable_vars_in_multi_sections(
16521642
'peer': write_dev_index,
16531643
'ring_id': ring_id,
16541644
self._op_role_key: self._op_role.LRSched,
1645+
'dynamic_shape': True,
16551646
},
16561647
)
16571648
read_block._insert_op(

0 commit comments

Comments
 (0)