Skip to content

Commit 512aa63

Browse files
committed
set dynamic_shape = true
1 parent b0c40a1 commit 512aa63

File tree

5 files changed

+26
-39
lines changed

5 files changed

+26
-39
lines changed

python/paddle/distributed/auto_parallel/static/reshard_funcs/same_status_reshard_func.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
101101
src_value,
102102
comm_group.id,
103103
comm_group.ranks.index(dst),
104-
False,
104+
True,
105105
)
106106
point = paddle.base.libpaddle.pir.get_current_insertion_point()
107107
point.prev()
@@ -130,7 +130,7 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
130130
comm_group.id,
131131
comm_group.ranks.index(src),
132132
dst_type.dtype,
133-
False,
133+
True,
134134
)
135135
new_op = recv_value.get_defining_op()
136136
new_op.dist_attr = (

python/paddle/distributed/fleet/utils/hybrid_parallel_inference.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -613,18 +613,17 @@ def _insert_send_recv(cur_id, prev_id):
613613
else:
614614
ring_id = self._pp_ring_map[pair_key]
615615

616-
send_op = block._insert_op_without_sync(
616+
block._insert_op_without_sync(
617617
index=index + extra_index_info['index'],
618618
type='p_send',
619619
inputs={'x': var},
620620
attrs={
621621
'peer': 1,
622622
'ring_id': ring_id,
623623
self._op_role_key: op_role,
624+
'dynamic_shape': True,
624625
},
625626
)
626-
print("---------hybrid_parallel_inference----------")
627-
print(send_op)
628627
extra_index_info['index'] += 1
629628
var_shape = list(var.shape)
630629
if var_shape[0] < 0:
@@ -635,7 +634,7 @@ def _insert_send_recv(cur_id, prev_id):
635634
else:
636635
var_shape[0] = self.micro_batch_size
637636

638-
recv_op = block._insert_op_without_sync(
637+
block._insert_op_without_sync(
639638
index=index + extra_index_info['index'],
640639
type='p_recv',
641640
outputs={'out': [var]},
@@ -644,10 +643,9 @@ def _insert_send_recv(cur_id, prev_id):
644643
'peer': 0,
645644
'ring_id': ring_id,
646645
self._op_role_key: op_role,
646+
'dynamic_shape': True,
647647
},
648648
)
649-
print("---------hybrid_parallel_inference----------")
650-
print(recv_op)
651649
extra_index_info['index'] += 1
652650

653651
_insert_send_recv(
@@ -711,18 +709,17 @@ def _insert_sendrecv_ops_in_while_block(
711709
for var_name in var_names:
712710
var = block._var_recursive(var_name)
713711
if stage == cur_id:
714-
send_op = block._insert_op_without_sync(
712+
block._insert_op_without_sync(
715713
index=index,
716714
type='p_send',
717715
inputs={'x': var},
718716
attrs={
719717
'peer': 0,
720718
'ring_id': ring_id,
721719
self._op_role_key: int(self._op_role.Forward),
720+
'dynamic_shape': True,
722721
},
723722
)
724-
print("---------hybrid_parallel_inference----------")
725-
print(send_op)
726723

727724
else:
728725
var_shape = list(var.shape)
@@ -733,7 +730,7 @@ def _insert_sendrecv_ops_in_while_block(
733730
if var_shape[0] < 0
734731
else var_shape[0]
735732
)
736-
recv_op = block._insert_op_without_sync(
733+
block._insert_op_without_sync(
737734
index=index,
738735
type='p_recv',
739736
outputs={'out': [var]},
@@ -742,10 +739,9 @@ def _insert_sendrecv_ops_in_while_block(
742739
'peer': 1,
743740
'ring_id': ring_id,
744741
self._op_role_key: int(self._op_role.Forward),
742+
'dynamic_shape': True,
745743
},
746744
)
747-
print("---------hybrid_parallel_inference----------")
748-
print(recv_op)
749745
index += 1
750746
block._sync_with_cpp()
751747

python/paddle/distributed/passes/auto_parallel_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def _insert_sync_ops_for_stream(self):
9696
for index, op in enumerate(list(block.ops)):
9797
# NOTE(review): this warning said the pipeline might hang when dynamic_shape is True,
# yet this commit now forces it to True on the line below — stale comment; confirm the
# hang no longer reproduces and update or remove this note accordingly.
9898
if op.type in ['p_send', 'p_recv']:
99-
op._set_attr("dynamic_shape", False)
99+
op._set_attr("dynamic_shape", True)
100100
# set send op on comm stream
101101
if op.type == 'p_send':
102102
op_role = op.attr('op_role')

python/paddle/distributed/passes/pass_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -476,12 +476,12 @@ def _pir_overlap_send_recv(program):
476476
for block in program.blocks:
477477
for op in block.ops:
478478
if op.name() == "pd_op.p_send":
479-
op.set_bool_attr("dynamic_shape", False)
479+
op.set_bool_attr("dynamic_shape", True)
480480
ring_id = op.attrs()["ring_id"]
481481
op.set_execution_stream(f"send_stream_{ring_id}")
482482
op.set_scheduling_priority(0)
483483
elif op.name() == "pd_op.p_recv":
484-
op.set_bool_attr("dynamic_shape", False)
484+
op.set_bool_attr("dynamic_shape", True)
485485
op.set_execution_stream("recv_stream")
486486
op.set_scheduling_priority(0)
487487

@@ -505,7 +505,7 @@ def _insert_sync_for_fthenb_1f1b(program, dist_context=None):
505505
for index, op in enumerate(list(block.ops)):
506506
# NOTE(review): this warning said the pipeline might hang when dynamic_shape is True,
# yet this commit now forces it to True on the line below — stale comment; confirm the
# hang no longer reproduces and update or remove this note accordingly.
507507
if op.type in ['p_send', 'p_recv']:
508-
op._set_attr("dynamic_shape", False)
508+
op._set_attr("dynamic_shape", True)
509509
# set send op on comm stream
510510
if op.type == 'p_send':
511511
op_role = op.attr('op_role')
@@ -606,12 +606,12 @@ def _overlap_send_recv(program):
606606
for block in program.blocks:
607607
for op in block.ops:
608608
if op.type == 'p_send':
609-
op._set_attr("dynamic_shape", False)
609+
op._set_attr("dynamic_shape", True)
610610
ring_id = op.attr("ring_id")
611611
op.dist_attr.execution_stream = "send_stream_" + str(ring_id)
612612
op.dist_attr.stream_priority = 0
613613
elif op.type == 'p_recv':
614-
op._set_attr("dynamic_shape", False)
614+
op._set_attr("dynamic_shape", True)
615615
op.dist_attr.execution_stream = "recv_stream"
616616
op.dist_attr.stream_priority = 0
617617
else:

python/paddle/incubate/optimizer/pipeline.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -798,27 +798,25 @@ def _insert_send_recv(cur_id, prev_id):
798798
ring_id = self._pp_ring_map[pair_key]
799799

800800
if self.schedule_mode == 'F-then-B': # F-then-B
801-
send_op = block._insert_op_without_sync(
801+
block._insert_op_without_sync(
802802
index=index + extra_index_info['index'],
803803
type='p_send',
804804
inputs={'x': var},
805805
attrs={
806806
'peer': 1,
807807
self._op_role_key: op_role,
808808
'ring_id': ring_id,
809+
'dynamic_shape': True,
809810
},
810811
)
811-
print("---------pipeline----------")
812-
print(send_op)
813-
# send_op.dist_attr.execution_stream = "default"
814812
extra_index_info['index'] += 1
815813
var_shape = list(var.shape)
816814
var_shape[0] = (
817815
self.micro_batch_size
818816
if var_shape[0] < 0
819817
else var_shape[0]
820818
)
821-
recv_op = block._insert_op_without_sync(
819+
block._insert_op_without_sync(
822820
index=index + extra_index_info['index'],
823821
type='p_recv',
824822
outputs={'out': [var]},
@@ -827,11 +825,9 @@ def _insert_send_recv(cur_id, prev_id):
827825
'peer': 0,
828826
'ring_id': ring_id,
829827
self._op_role_key: op_role,
828+
'dynamic_shape': True,
830829
},
831830
)
832-
print("---------pipeline----------")
833-
print(recv_op)
834-
# recv_op.dist_attr.execution_stream = "default"
835831
extra_index_info['index'] += 1
836832
elif self.schedule_mode == '1F1B': # 1F1B
837833
var_shape = list(var.shape)
@@ -891,19 +887,17 @@ def _insert_send_recv(cur_id, prev_id):
891887
True if isinstance(prefix_var, Parameter) else False
892888
)
893889
if not use_mp or is_param:
894-
send_op = block._insert_op_without_sync(
890+
block._insert_op_without_sync(
895891
index=index + extra_index_info['index'],
896892
type='p_send',
897893
inputs={'x': var},
898894
attrs={
899895
'ring_id': ring_id,
900896
'peer': 1,
901897
self._op_role_key: op_role,
898+
'dynamic_shape': True,
902899
},
903900
)
904-
print("---------pipeline----------")
905-
print(send_op)
906-
# send_op.dist_attr.execution_stream = "default"
907901
else:
908902
block._insert_op_without_sync(
909903
index=index + extra_index_info['index'],
@@ -943,7 +937,7 @@ def _insert_send_recv(cur_id, prev_id):
943937
sync_comm_op._set_attr('pipeline_flag', '')
944938
extra_index_info['index'] += 1
945939
if not use_mp or is_param:
946-
recv_op = block._insert_op_without_sync(
940+
block._insert_op_without_sync(
947941
index=index + extra_index_info['index'],
948942
type='p_recv',
949943
outputs={'out': [var]},
@@ -952,11 +946,9 @@ def _insert_send_recv(cur_id, prev_id):
952946
'peer': 0,
953947
'ring_id': ring_id,
954948
self._op_role_key: op_role,
949+
'dynamic_shape': True,
955950
},
956951
)
957-
print("---------pipeline----------")
958-
print(recv_op)
959-
# recv_op.dist_attr.execution_stream = "default"
960952
else:
961953
block._insert_op_without_sync(
962954
index=index + extra_index_info['index'],
@@ -984,7 +976,6 @@ def _insert_send_recv(cur_id, prev_id):
984976
self._op_role_key: op_role,
985977
'use_calc_stream': True,
986978
'ring_id': 0,
987-
# if p_recv, num&id attr is not in op_attrs, will not insert
988979
'nranks': self.mp_degree,
989980
'rank': self.mp_rank,
990981
},
@@ -1637,10 +1628,9 @@ def _process_persistable_vars_in_multi_sections(
16371628
'peer': read_dev_index,
16381629
'ring_id': ring_id,
16391630
self._op_role_key: self._op_role.LRSched,
1631+
'dynamic_shape': True,
16401632
},
16411633
)
1642-
print("---------pipeline----------")
1643-
# print(recv_op)
16441634
read_block._insert_op(
16451635
index=0,
16461636
type='p_recv',
@@ -1652,6 +1642,7 @@ def _process_persistable_vars_in_multi_sections(
16521642
'peer': write_dev_index,
16531643
'ring_id': ring_id,
16541644
self._op_role_key: self._op_role.LRSched,
1645+
'dynamic_shape': True,
16551646
},
16561647
)
16571648
read_block._insert_op(

0 commit comments

Comments
 (0)