
Commit 4a02850

Commit message: tmp
1 parent 182c72b, commit 4a02850

13 files changed, +68 -23 lines changed

paddle/fluid/framework/new_executor/interpreter/execution_config.cc (+2 -2)

@@ -27,8 +27,8 @@
 // FLAGS_force_sync_ops is used to finer control the op-sync in executor.
 // The format is: "micro_batch_id, job_name, op_id, op_name | micro_batch_id,
 // job_name, op_id, op_name | ...". Keep spaces to syncs all name/id. Example:
-// 1. sync the recv_v2 op in the second backward-job of 1F1B scheduling:
-//    FLAGS_force_sync_ops="1, backward, , recv_v2"
+// 1. sync the p_recv op in the second backward-job of 1F1B scheduling:
+//    FLAGS_force_sync_ops="1, backward, , p_recv"
 // 2. sync the full op with op_id=5: FLAGS_force_sync_ops=" , , 5, full"
 // 3. sync all ops in the first default-job: FLAGS_force_sync_ops="0,default,,
 // 4. sync all ops in the forward-job and backward-job: FLAGS_force_sync_ops=" ,
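The comment above documents the FLAGS_force_sync_ops format (micro_batch_id, job_name, op_id, op_name, with blanks acting as wildcards). A minimal sketch of setting the flag for a run, assuming the usual convention that Paddle picks up FLAGS_* environment variables at import time; the value shown is only illustrative:

import os

# Illustrative only: force a sync on the p_recv op in the backward job of
# micro-batch 1, matching example 1 in the comment above. Set the variable
# before importing paddle so the flag is read at import time.
os.environ["FLAGS_force_sync_ops"] = "1, backward, , p_recv"

import paddle  # noqa: E402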

paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc (-2)

@@ -154,8 +154,6 @@ bool IsCommunicationOp(const ::pir::Operation* op) {
       op->attributes().at("op_name").dyn_cast<pir::StrAttribute>().AsString();
 }
 const std::set<std::string> special_comm_op_set = {
-    paddle::dialect::SendV2Op::name(),
-    paddle::dialect::RecvV2Op::name(),
     paddle::dialect::PSendOp::name(),
     paddle::dialect::PRecvOp::name(),
 };

paddle/fluid/framework/new_executor/interpreter/static_build.cc (+1)

@@ -54,6 +54,7 @@ std::set<std::string> OpsCanSkipedFakeAllocInStaticBuild = {
     "fetch_v2",
     "print",
     "send_v2",
+    "p_send",
     "nop"};

 std::set<std::string> StaticBuildBlackList = {

paddle/fluid/framework/new_executor/new_executor_defs.cc (+11 -4)

@@ -328,20 +328,27 @@ void Instruction::UpdateRecordStreamForGcInfo() {
   need_record_stream_for_gc_ = true;

   stream_ = reinterpret_cast<const phi::GPUContext&>(DeviceContext()).stream();
-  // TODO(lizhiyu): Only analyse the 'send_v2' for GPT pp strategy right now.
-  // To support all the operators for communicating in the future.
+  // TODO(lizhiyu): Only analyse the 'p_send' for GPT pp strategy right now.
+  // To support all the operators for communicating in the future.
+  VLOG(0) << "enter new_executor_defs ";
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   auto operator_base_ptr = OpBase();
-  if ((operator_base_ptr->Type() == "p_send") &&
-      (operator_base_ptr->Attr<bool>("use_calc_stream") == false)) {
+  if ((operator_base_ptr->Type() == "p_send")) {
+    VLOG(0) << "enter new_executor_defs func";
     int ring_id = operator_base_ptr->Attr<int>("ring_id");
     if (FLAGS_dynamic_static_unified_comm) {
       const auto& comm_context_manager =
           phi::distributed::CommContextManager::GetInstance();
+      VLOG(0) << "xxx: std::to_string(ring_id): " << std::to_string(ring_id);
+      VLOG(0) << "xxx: distributed::comm_context_manager has: "
+              << comm_context_manager.Has(std::to_string(ring_id));
+
       stream_ = static_cast<phi::distributed::NCCLCommContext*>(
                     comm_context_manager.Get(std::to_string(ring_id)))
                     ->GetStream();
     } else {
+      VLOG(0) << "xxx: std::to_string(ring_id): " << std::to_string(ring_id);
+      VLOG(0) << "xxx: platform::NCCLCommContext has: ";
       stream_ = platform::NCCLCommContext::Instance()
                     .Get(ring_id, DeviceContext().GetPlace())
                     ->stream();

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h (+1)

@@ -33,6 +33,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Randperm)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReadFile)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Seed)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(RecvV2)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(PRecv)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(TrilIndices)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(TriuIndices)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(TruncatedGaussianRandom)

python/paddle/distributed/auto_parallel/static/cost/base_cost.py (+3)

@@ -781,6 +781,7 @@ def comm_count(self):
         from ..reshard import get_var_with_recursion

         if self._comm_count is None:
+            print("---------self._comm_count is None-----------")
            dtype = None
            shape = None
            if self.op is not None:
@@ -802,8 +803,10 @@ def comm_count(self):
                dtype = var.dtype
                shape = var.shape
            elif self.op_desc is not None:
+                print("---------self.op_desc is not None-----------")
                dtype = self.op_desc["inputs"]["X"][0][0]
                shape = self.op_desc["inputs"]["X"][0][1]
+            print(dtype, shape)

            factor = None
            if dtype == paddle.float32 or dtype == paddle.int32:
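comm_count derives a communication volume from the tensor's dtype and shape, scaling the element count by a per-dtype factor. A standalone sketch of that calculation, assuming common byte widths; the helper name and factor table are illustrative and not the actual CommOpCost code:

from functools import reduce

import paddle

# Illustrative bytes-per-element table (assumption, not the real factor table).
_BYTES_PER_ELEMENT = {
    paddle.float32: 4,
    paddle.int32: 4,
    paddle.int64: 8,
    paddle.float16: 2,
}

def estimated_comm_bytes(dtype, shape):
    # Hypothetical helper: element count times the per-dtype byte width.
    numel = reduce(lambda a, b: a * b, shape, 1)
    return numel * _BYTES_PER_ELEMENT.get(dtype, 4)

# Example: a [1024, 4096] float16 activation exchanged between pipeline stages.
print(estimated_comm_bytes(paddle.float16, [1024, 4096]))  # 8388608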

python/paddle/distributed/auto_parallel/static/reshard.py (+6)

@@ -377,9 +377,12 @@ def insert_send_op(block, idx, tensor, src, dst, op_role, sync=True):
            attrs={
                'ring_id': process_group.id,
                'peer': process_group.ranks.index(dst),
+                'op_role': op_role,
                'dynamic_shape': True,
            },
        )
+        print("------reshard------------")
+        print(send_op)
        send_op._set_attr('op_namescope', "/auto_parallel/reshard")

    @staticmethod
@@ -400,10 +403,13 @@ def insert_recv_op(block, idx, tensor, src, dst, op_role, sync=True):
                'ring_id': process_group.id,
                'peer': process_group.ranks.index(src),
                'dtype': tensor.dtype,
+                'op_role': op_role,
                'dynamic_shape': True,
            },
        )
        recv_op._set_attr('op_namescope', "/auto_parallel/reshard")
+        print("------reshard------------")
+        print(recv_op)

    @staticmethod
    def insert_reset_lod_op(block, idx, X, Y, op_role, sync=True):

python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py (+1 -1)

@@ -142,7 +142,7 @@ def should_remove_op(self, op_idx):
        op = self._block.ops[op_idx]

        # NOTE: At present, it is found that the OP without output is
-        # only send_v2 and partial_send op, which will be used in
+        # only p_send and partial_send op, which will be used in
        # all device
        if len(op.desc.output_arg_names()) == 0:
            return False

python/paddle/distributed/fleet/utils/hybrid_parallel_inference.py (+13)

@@ -620,8 +620,11 @@ def _insert_send_recv(cur_id, prev_id):
                attrs={
                    'peer': 1,
                    'ring_id': ring_id,
+                    self._op_role_key: op_role,
                },
            )
+            print("---------hybrid_parallel_inference----------")
+            print(send_op)
            extra_index_info['index'] += 1
            var_shape = list(var.shape)
            if var_shape[0] < 0:
@@ -640,8 +643,11 @@ def _insert_send_recv(cur_id, prev_id):
                    'dtype': var.dtype,
                    'peer': 0,
                    'ring_id': ring_id,
+                    self._op_role_key: op_role,
                },
            )
+            print("---------hybrid_parallel_inference----------")
+            print(recv_op)
            extra_index_info['index'] += 1

        _insert_send_recv(
@@ -712,8 +718,12 @@ def _insert_sendrecv_ops_in_while_block(
                attrs={
                    'peer': 0,
                    'ring_id': ring_id,
+                    self._op_role_key: int(self._op_role.Forward),
                },
            )
+            print("---------hybrid_parallel_inference----------")
+            print(send_op)
+
        else:
            var_shape = list(var.shape)
            print(var_name)
@@ -731,8 +741,11 @@ def _insert_sendrecv_ops_in_while_block(
                    'dtype': var.dtype,
                    'peer': 1,
                    'ring_id': ring_id,
+                    self._op_role_key: int(self._op_role.Forward),
                },
            )
+            print("---------hybrid_parallel_inference----------")
+            print(recv_op)
            index += 1
        block._sync_with_cpp()

python/paddle/distributed/passes/pass_utils.py (+4 -6)

@@ -471,20 +471,18 @@ def _pir_overlap_send_recv(program):
    This function is used to replace the function '_insert_sync_for_fthenb_1f1b'.
    The finally target of this function is as follows:
    1. no need to insert the 'c_sync_calc' and 'c_sync_calc' operators
-    2. 'send_v2' operator uses 'dist_attr.execution_stream' to set stream of its own.
-    3. 'recv_v2' operator uses 'dist_attr.execution_stream' to set stream of its own.
+    2. 'p_send' operator uses 'dist_attr.execution_stream' to set stream of its own.
+    3. 'p_recv' operator uses 'dist_attr.execution_stream' to set stream of its own.
    """
    for block in program.blocks:
        for op in block.ops:
-            if op.name() == "pd_op.send_v2":
+            if op.name() == "pd_op.p_send":
                op.set_bool_attr("dynamic_shape", False)
-                op.set_bool_attr("use_calc_stream", True)
                ring_id = op.attrs()["ring_id"]
                op.set_execution_stream(f"send_stream_{ring_id}")
                op.set_scheduling_priority(0)
-            elif op.name() == "pd_op.recv_v2":
+            elif op.name() == "pd_op.p_recv":
                op.set_bool_attr("dynamic_shape", False)
-                op.set_bool_attr("use_calc_stream", True)
                op.set_execution_stream("recv_stream")
                op.set_scheduling_priority(0)

python/paddle/incubate/optimizer/pipeline.py (+22 -4)

@@ -804,10 +804,13 @@ def _insert_send_recv(cur_id, prev_id):
                    inputs={'x': var},
                    attrs={
                        'peer': 1,
+                        self._op_role_key: op_role,
                        'ring_id': ring_id,
                    },
                )
-                send_op.dist_attr.execution_stream = "default"
+                print("---------pipeline----------")
+                print(send_op)
+                # send_op.dist_attr.execution_stream = "default"
                extra_index_info['index'] += 1
                var_shape = list(var.shape)
                var_shape[0] = (
@@ -823,9 +826,12 @@ def _insert_send_recv(cur_id, prev_id):
                        'dtype': var.dtype,
                        'peer': 0,
                        'ring_id': ring_id,
+                        self._op_role_key: op_role,
                    },
                )
-                recv_op.dist_attr.execution_stream = "default"
+                print("---------pipeline----------")
+                print(recv_op)
+                # recv_op.dist_attr.execution_stream = "default"
                extra_index_info['index'] += 1
            elif self.schedule_mode == '1F1B':  # 1F1B
                var_shape = list(var.shape)
@@ -892,9 +898,12 @@ def _insert_send_recv(cur_id, prev_id):
                    attrs={
                        'ring_id': ring_id,
                        'peer': 1,
+                        self._op_role_key: op_role,
                    },
                )
-                send_op.dist_attr.execution_stream = "default"
+                print("---------pipeline----------")
+                print(send_op)
+                # send_op.dist_attr.execution_stream = "default"
            else:
                block._insert_op_without_sync(
                    index=index + extra_index_info['index'],
@@ -906,6 +915,7 @@ def _insert_send_recv(cur_id, prev_id):
                        'use_calc_stream': False,
                        'num': self.mp_degree,
                        'id': self.mp_rank,
+                        self._op_role_key: op_role,
                    },
                )
                extra_index_info['index'] += 1
@@ -941,9 +951,12 @@ def _insert_send_recv(cur_id, prev_id):
                        'dtype': var.dtype,
                        'peer': 0,
                        'ring_id': ring_id,
+                        self._op_role_key: op_role,
                    },
                )
-                recv_op.dist_attr.execution_stream = "default"
+                print("---------pipeline----------")
+                print(recv_op)
+                # recv_op.dist_attr.execution_stream = "default"
            else:
                block._insert_op_without_sync(
                    index=index + extra_index_info['index'],
@@ -956,6 +969,7 @@ def _insert_send_recv(cur_id, prev_id):
                        'out_shape': var_shape,
                        'num': self.mp_degree,
                        'id': self.mp_rank,
+                        self._op_role_key: op_role,
                    },
                )
                extra_index_info['index'] += 1
@@ -1622,8 +1636,11 @@ def _process_persistable_vars_in_multi_sections(
                    # microbatch
                    'peer': read_dev_index,
                    'ring_id': ring_id,
+                    self._op_role_key: self._op_role.LRSched,
                },
            )
+            print("---------pipeline----------")
+            # print(recv_op)
            read_block._insert_op(
                index=0,
                type='p_recv',
@@ -1634,6 +1651,7 @@ def _process_persistable_vars_in_multi_sections(
                    # microbatch
                    'peer': write_dev_index,
                    'ring_id': ring_id,
+                    self._op_role_key: self._op_role.LRSched,
                },
            )
            read_block._insert_op(

test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_llama_model_vpp.py (+1 -1)

@@ -42,7 +42,7 @@ def setUp(self):
            "use_param_group": ["true"],
            "recompute": ["true"],
            "recompute_granularity": ["full"],
-            "virtual_pp_degree": ["2"],
+            "virtual_pp_degree": ["1"],
        }

    def test_simple_net_hybrid_strategy(self):

tools/enforce/grep_invalid_enforce.sh (+3 -3)

@@ -17,14 +17,14 @@
 # This script is used to grep invalid PADDLE checks by directory or file in the paddle/fluid/,
 # the result show all invalid PADDLE checks in specified directory or file.

-# Usage:
+# Usage:
 #     - bash grep_invalid_enforce.sh [target directory or file] (run in tools directory)
 #     - The default check path is paddle/fluid/operators

 # Result Examples:
 #     1. grep invalid PADDLE checks in directory

-#         - Command: /work/paddle/tools {develop} bash grep_invalid_enforce.sh ../paddle/fluid/imperative
+#         - Command: /work/paddle/tools {develop} bash grep_invalid_enforce.sh ../paddle/fluid/imperative
 #         - Results:
 #             - paddle/fluid/imperative/gradient_accumulator.cc
 #                   PADDLE_ENFORCE_EQ(dst_tensor->numel() == numel, true,
@@ -60,7 +60,7 @@
 #                   "Place cannot be CUDAPlace when use_double_buffer is False");
 #             PADDLE_ENFORCE_NOT_NULL(exceptions_[i]);
 #             PADDLE_ENFORCE_EQ(status, Status::kException);
-#             PADDLE_ENFORCE_EQ(status, Status::kSuccess);
+#             PADDLE_ENFORCE_EQ(status, Status::kSuccess);

 . ./count_enforce_by_file.sh --source-only