
Commit 56d4e92

【Comm】Fix new comm_ctx and dev_ctx in dy (#68610)
1 parent d711d9e · commit 56d4e92

13 files changed: +376 −363 lines

paddle/fluid/distributed/collective/process_group_nccl.cc

+11
@@ -217,6 +217,17 @@ ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
   return iter->second->nccl_comm();
 }
 
+phi::distributed::NCCLCommContext* ProcessGroupNCCL::GetOrCreateCommContext(
+    const Place& place, CommType comm_type) {
+  const auto& key = GetKeyFromPlace(place);
+  std::string store_key;
+  GetStoreKey(key, comm_type, &store_key);
+  if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
+    CreateNCCLEnvCache(place, key, store_key, comm_type);
+  }
+  return GetCommContext(&store_key);
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
     phi::DenseTensor* out_tensor,
     const phi::DenseTensor& in_tensor,

paddle/fluid/distributed/collective/process_group_nccl.h

+3
@@ -181,6 +181,9 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
 
   const bool GetNCCLCommInitOption() { return nccl_comm_init_option_; }
 
+  phi::distributed::NCCLCommContext* GetOrCreateCommContext(
+      const Place& place, CommType comm_type = CommType::UNKNOWN);
+
  private:
   std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(const Place& place,
                                                          int rank,
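
Together, these two hunks expose lazy creation of the per-place NCCL comm context as a public API: the first call for a place builds the NCCL env cache, later calls reuse it. A minimal usage sketch, not part of the commit — the warm-up helper and device index are illustrative assumptions:

// Hypothetical helper, not in the commit: assumes `pg` is a live
// ProcessGroupNCCL for this rank and GPU 0 is the local device.
#include "paddle/fluid/distributed/collective/process_group_nccl.h"

void WarmUpCommContext(paddle::distributed::ProcessGroupNCCL* pg) {
  phi::GPUPlace place(0);  // assumed local device
  // First call creates the NCCL env cache for `place`; later calls reuse it.
  auto* comm_ctx = pg->GetOrCreateCommContext(place);  // defaults to CommType::UNKNOWN
  // The comm context carries its own device context, which the dygraph
  // change below swaps in at op-preparation time.
  auto* comm_dev_ctx = comm_ctx->GetDevContext();
  (void)comm_dev_ctx;
}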

paddle/fluid/imperative/CMakeLists.txt

+16
@@ -21,6 +21,22 @@ if(WITH_XPU)
     phi
     common
     var_helper)
+elseif((WITH_GPU OR WITH_ROCM) AND NOT WIN32)
+  cc_library(
+    prepared_operator
+    SRCS prepared_operator.cc
+    DEPS proto_desc
+         operator
+         device_context
+         lod_tensor
+         selected_rows_utils
+         var_type_traits
+         op_kernel_type
+         data_transform
+         phi
+         common
+         var_helper
+         process_group_nccl)
 else()
   cc_library(
     prepared_operator

The new process_group_nccl dependency backs the includes added to prepared_operator.cc below.

paddle/fluid/imperative/prepared_operator.cc

+41
@@ -28,6 +28,10 @@
 #ifdef PADDLE_WITH_DNNL
 #include "paddle/phi/core/platform/onednn_op_list.h"
 #endif
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/fluid/distributed/collective/process_group.h"
+#include "paddle/fluid/distributed/collective/process_group_nccl.h"
+#endif
 #include "paddle/common/flags.h"
 #include "paddle/fluid/framework/library_type.h"
 #include "paddle/fluid/platform/profiler/supplement_tracing.h"

@@ -296,6 +300,43 @@ PreparedOp PrepareImpl(
           phi::TransToPhiBackend(dev_ctx->GetPlace()))) {
     dev_ctx = pool.Get(phi::TransToPhiPlace(expected_kernel_key.backend()));
   }
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  if (attrs.find("ring_id") != attrs.end()) {
+    auto ring_id_attr = attrs.at("ring_id");
+    int ring_id = PADDLE_GET(int, ring_id_attr);
+    auto map = distributed::ProcessGroupMapFromGid::getInstance();
+    if (map->has(ring_id)) {
+      distributed::ProcessGroup* pg = map->get(ring_id);
+      auto comm_context =
+          static_cast<paddle::distributed::ProcessGroupNCCL*>(pg)
+              ->GetOrCreateCommContext(place);
+      auto original_stream =
+          static_cast<phi::GPUContext*>(dev_ctx)->cuda_stream();
+      dev_ctx =
+          static_cast<phi::distributed::NCCLCommContext*>(comm_context)
+              ->GetDevContext();
+      dev_ctx->SetCommContext(comm_context);
+      // Note(lizhenxing): In dynamic mode, c_softmax_with_cross_entropy
+      // needs the global calculation stream (original_stream). Using the
+      // comm_ctx's stream leads to synchronization issues, causing an
+      // accuracy diff in test_parallel_dygraph_mp_layers.
+      if (phi::is_gpu_place(place) &&
+          ((attrs.find("use_calc_stream") != attrs.end() &&
+            PADDLE_GET_CONST(bool, attrs.at("use_calc_stream"))) ||
+           phi_kernel_name == "c_softmax_with_cross_entropy")) {
+        static_cast<phi::GPUContext*>(dev_ctx)->SetCUDAStream(
+            original_stream, false);
+        auto& instance =
+            paddle::memory::allocation::AllocatorFacade::Instance();
+        dev_ctx->SetAllocator(
+            instance
+                .GetAllocator(
+                    place, static_cast<phi::GPUContext*>(dev_ctx)->stream())
+                .get());
+      }
+    }
+  }
+#endif
   return PreparedOp(op,
                     empty_ctx,
                     expected_kernel_key,
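
This hunk is the core of the fix: in dygraph mode, when an op carries a ring_id that maps to a registered ProcessGroup, PrepareImpl swaps in the comm context's device context so the kernel picks up the right communicator, then restores the original calculation stream (and the matching allocator) for ops that must stay on it. A minimal sketch of the stream-selection rule, using illustrative names that are not Paddle APIs:

#include <string>

// Mirrors the condition added above: stay on the global calc stream when the
// op requests it via use_calc_stream, or when the kernel is
// c_softmax_with_cross_entropy (which otherwise hits sync/accuracy issues).
bool UseOriginalCalcStream(bool is_gpu_place,
                           bool has_use_calc_stream_attr,
                           bool use_calc_stream_value,
                           const std::string& kernel_name) {
  return is_gpu_place &&
         ((has_use_calc_stream_attr && use_calc_stream_value) ||
          kernel_name == "c_softmax_with_cross_entropy");
}

Resetting the allocator after SetCUDAStream matters because AllocatorFacade hands out stream-aware allocators: the device context should allocate on the stream it will actually run on.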

paddle/fluid/operators/collective/c_concat_op.cc

-10
@@ -117,13 +117,3 @@ REGISTER_OPERATOR(c_concat,
                   ops::CConcatOpGradMaker<paddle::framework::OpDesc>,
                   ops::CConcatOpGradMaker<paddle::imperative::OpBase>,
                   ops::CConcatOpMaker);
-
-PD_REGISTER_STRUCT_KERNEL(c_concat,
-                          CPU,
-                          ALL_LAYOUT,
-                          ops::CConcatOpCPUKernel,
-                          float,
-                          double,
-                          int,
-                          int64_t,
-                          phi::dtype::float16) {}

paddle/fluid/operators/collective/c_concat_op.cu.cc

-181
This file was deleted.
