Commit 1d015f1

Add enable_partial_send_recv switch in pipeline_configs (#46992) (#47083)
* Fix a bug in the reduce_sum op: when input.numel() > INT32_MAX, its result was wrong.
* Support an allow_partial switch, which can be configured in pipeline_configs. If the tensors sent from different hosts are not identical, they should not be sent partially and then concatenated into a whole tensor.
* Rename allow_partial to enable_partial_send_recv.
* Add the global variable _enable_partial_send_recv.
1 parent 69515e9 commit 1d015f1
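
As context for the change, here is a hedged sketch (not part of this commit) of how the new key is meant to be set on the user side, alongside the existing pipeline_configs keys. The hybrid_configs degrees and the other pipeline_configs values are illustrative placeholders, and the script assumes the usual launch via python -m paddle.distributed.launch.

# Sketch: opting out of partial send/recv via pipeline_configs.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 2}
strategy.pipeline_configs = {
    "micro_batch_size": 2,
    "accumulate_steps": 4,
    # New in this commit; defaults to True (see distributed_strategy.proto below).
    # Disable it when the tensors sent by different hosts are not identical.
    "enable_partial_send_recv": False,
}
fleet.init(is_collective=True, strategy=strategy)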

3 files changed (+13 −5 lines)

paddle/fluid/framework/distributed_strategy.proto (+1)

@@ -177,6 +177,7 @@ message PipelineConfig {
   optional int32 accumulate_steps = 2 [ default = 1 ];
   optional string schedule_mode = 3 [ default = '1F1B' ];
   optional bool p2p_cache_shape = 4 [ default = true ];
+  optional bool enable_partial_send_recv = 5 [ default = true ];
 }
 
 message TensorParallelConfig {

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py (+6 −2)

@@ -46,7 +46,10 @@ def __init__(self, layers, hcg, strategy):
             'micro_batch_size']
         self.accumulate_steps = self._strategy.pipeline_configs[
             'accumulate_steps']
-
+        # If sent tensor are not the same from different hosts,
+        # they shouldn't been sent partially and then concated as a whole tensor.
+        self._enable_partial_send_recv = self._strategy.pipeline_configs[
+            'enable_partial_send_recv']
         self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']
 
         self.num_stages = self._hcg.get_pipe_parallel_world_size()
@@ -58,7 +61,8 @@ def __init__(self, layers, hcg, strategy):
             self._real_pp_world_size = self.num_stages
             self._real_pp_rank = self.stage_id
 
-        p2p.initialize_p2p_groups(hcg, self._using_cache)
+        p2p.initialize_p2p_groups(hcg, self._using_cache,
+                                  self._enable_partial_send_recv)
 
         self.global_rank = self._hcg.get_global_rank()
         self.micro_batch_id = 0
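
The comment added in the hunk above is the motivation for the switch. A minimal numpy illustration (not Paddle code; mp_degree and the even-chunk split are assumptions standing in for the real partial send/recv) of why splitting and re-concatenating goes wrong when sender ranks hold different tensors:

import numpy as np

mp_degree = 2

def partial_send(tensor, rank, mp_degree):
    # Each model-parallel rank sends only its numel/mp_degree slice.
    return np.split(tensor.ravel(), mp_degree)[rank]

# Ranks hold identical tensors: concatenating the slices reproduces the original.
t = np.arange(8.0)
recv = np.concatenate([partial_send(t, r, mp_degree) for r in range(mp_degree)])
assert np.array_equal(recv, t)

# Ranks hold different tensors: the concatenated result matches neither sender,
# which is exactly the case enable_partial_send_recv=False is meant to avoid.
t0, t1 = np.arange(8.0), np.arange(8.0) * 10
mixed = np.concatenate([partial_send(t0, 0, mp_degree),
                        partial_send(t1, 1, mp_degree)])
assert not np.array_equal(mixed, t0) and not np.array_equal(mixed, t1)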

python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py (+6 −3)

@@ -22,12 +22,14 @@
 
 _hcg = None
 _use_cache = False
+_enable_partial_send_recv = True
 
 
-def initialize_p2p_groups(hcg, use_cache=True):
-    global _hcg, _use_cache
+def initialize_p2p_groups(hcg, use_cache=True, enable_partial_send_recv=True):
+    global _hcg, _use_cache, _enable_partial_send_recv
     _hcg = hcg
     _use_cache = use_cache
+    _enable_partial_send_recv = enable_partial_send_recv
     send_next_group, send_prev_group, recv_next_group, recv_prev_group = _hcg.get_p2p_groups(
     )
 
@@ -157,7 +159,8 @@ def set_send_message(self, tensor):
 
 
 def _is_valid_send_recv_partial(tensor, mp_degree):
-
+    if not _enable_partial_send_recv:
+        return False
    tensor_numel = np.prod(tensor.shape)
    assert tensor_numel != 0, "can't send/recv zero element"
    return mp_degree > 1 and tensor_numel % mp_degree == 0
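
A standalone sketch of the gating logic after this patch, reproduced outside Paddle so it can be run in isolation. The function here takes a plain shape tuple instead of a Paddle tensor, and the module-level flag name mirrors the diff; both are illustrative assumptions.

import numpy as np

_enable_partial_send_recv = True  # set by initialize_p2p_groups in the real module

def _is_valid_send_recv_partial(tensor_shape, mp_degree):
    # Partial send/recv is used only when it is globally enabled and the
    # tensor splits evenly across the mp_degree model-parallel ranks.
    if not _enable_partial_send_recv:
        return False
    tensor_numel = np.prod(tensor_shape)
    assert tensor_numel != 0, "can't send/recv zero element"
    return mp_degree > 1 and tensor_numel % mp_degree == 0

print(_is_valid_send_recv_partial((4, 8), 4))   # True: 32 elements, divisible by 4
print(_is_valid_send_recv_partial((3, 5), 4))   # False: 15 not divisible by 4
_enable_partial_send_recv = False
print(_is_valid_send_recv_partial((4, 8), 4))   # False: the new switch disables it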
