
Commit ce8acc0

【Comm】Fix update comm flags (#67005)
1 parent 7311ff7 commit ce8acc0

13 files changed: +38 −67 lines

paddle/fluid/operators/collective/barrier_op.h

+19-8
@@ -22,11 +22,13 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/distributed/comm_context_manager.h"

 #if defined(PADDLE_WITH_GLOO)
 #include <gloo/barrier.h>

 #include "paddle/fluid/framework/fleet/gloo_wrapper.h"
+#include "paddle/phi/core/distributed/gloo_comm_context.h"
 #endif

 namespace paddle {
@@ -37,14 +39,23 @@ class BarrierOpCPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
 #if defined(PADDLE_WITH_GLOO)
-    auto gloo = paddle::framework::GlooWrapper::GetInstance();
-    PADDLE_ENFORCE_EQ(
-        gloo->IsInitialized(),
-        true,
-        common::errors::PreconditionNotMet(
-            "You must initialize the gloo environment first to use it."));
-    gloo::BarrierOptions opts(gloo->GetContext());
-    gloo::barrier(opts);
+    int rid = ctx.Attr<int>("ring_id");
+    const auto& comm_context_manager =
+        phi::distributed::CommContextManager::GetInstance();
+    if (comm_context_manager.Has(std::to_string(rid))) {
+      auto* comm_context = static_cast<phi::distributed::GlooCommContext*>(
+          comm_context_manager.Get(std::to_string(rid)));
+      comm_context->Barrier();
+    } else {
+      auto gloo = paddle::framework::GlooWrapper::GetInstance();
+      PADDLE_ENFORCE_EQ(
+          gloo->IsInitialized(),
+          true,
+          common::errors::PreconditionNotMet(
+              "You must initialize the gloo environment first to use it."));
+      gloo::BarrierOptions opts(gloo->GetContext());
+      gloo::barrier(opts);
+    }
 #else
     PADDLE_THROW(common::errors::Unavailable(
         "PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON"));

paddle/fluid/operators/collective/send_v2_op.cu.cc

+2-2
@@ -222,9 +222,9 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
     } else {
       PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclSend(
           x.data<T>(), numel, dtype, peer, comm->comm(), stream));
+      VLOG(3) << "rank " << comm->rank() << " send "
+              << common::product(x.dims()) << " to " << peer;
     }
-    VLOG(3) << "rank " << comm->rank() << " send "
-            << common::product(x.dims()) << " to " << peer;
   }
   return;
 }

test/collective/parallel_embedding_api.py

+2-2
@@ -29,7 +29,7 @@ class TestParallelEmbeddingAPI(TestCollectiveAPIRunnerBase):
     def __init__(self):
         self.global_ring_id = 0

-    def get_model(self, main_prog, startup_program, rank):
+    def get_model(self, main_prog, startup_program, rank, dtype="float32"):
         with base.program_guard(main_prog, startup_program):
             fleet.init(is_collective=True)
             np.random.seed(2020)
@@ -40,7 +40,7 @@ def get_model(self, main_prog, startup_program, rank):
             data_in = paddle.randint(0, size[0], shape=(10, 4))

             data = paddle.static.data(
-                name='tindata', shape=[10, 1000], dtype="float32"
+                name='tindata', shape=[10, 1000], dtype=dtype
             )
             per_part_size = size[0] // 2
             if rank == 0:

test/collective/test_collective_barrier_api.py

+1-5
@@ -38,11 +38,7 @@ def test_barrier_nccl_with_new_comm(self):

     def test_barrier_gloo(self):
         self.check_with_place(
-            "collective_barrier_api.py",
-            "barrier",
-            "gloo",
-            "5",
-            need_envs={"FLAGS_dynamic_static_unified_comm": "false"},
+            "collective_barrier_api.py", "barrier", "gloo", "5"
         )

test/collective/test_collective_sendrecv.py

+2-6
@@ -30,16 +30,12 @@ def test_sendrecv(self):

     def test_sendrecv_dynamic_shape(self):
         self.check_with_place(
-            "collective_sendrecv_op_dynamic_shape.py",
-            "sendrecv_dynamic_shape",
-            need_envs={"FLAGS_dynamic_static_unified_comm": "0"},
+            "collective_sendrecv_op_dynamic_shape.py", "sendrecv_dynamic_shape"
         )

     def test_sendrecv_array(self):
         self.check_with_place(
-            "collective_sendrecv_op_array.py",
-            "sendrecv_array",
-            need_envs={"FLAGS_dynamic_static_unified_comm": "0"},
+            "collective_sendrecv_op_array.py", "sendrecv_array"
         )


test/collective/test_collective_split_embedding.py

+1-4
@@ -27,10 +27,7 @@ def _setup_config(self):

     def test_parallel_embedding(self):
         self.check_with_place(
-            "parallel_embedding_api.py",
-            "parallel_embedding",
-            "nccl",
-            need_envs={"FLAGS_dynamic_static_unified_comm": "false"},
+            "parallel_embedding_api.py", "parallel_embedding", "nccl"
         )


test/legacy_test/test_collective_api_base.py

+5-1
@@ -154,7 +154,11 @@ def run_trainer(self, args):
                 reduce_type=args['reduce_type'],
             )
             if args["use_comm_context"]
-            else (self.get_model(train_prog, startup_prog, rank))
+            else (
+                self.get_model(
+                    train_prog, startup_prog, rank, dtype=args['dtype']
+                )
+            )
         )
         exe = base.Executor(place)
         exe.run(startup_prog)

test/legacy_test/test_collective_base.py

-1
@@ -263,7 +263,6 @@ def check_with_place(
             "LD_PRELOAD": os.getenv("LD_PRELOAD", ""),
             "GLOG_v": "3",
             "NCCL_P2P_DISABLE": "1",
-            "FLAGS_dynamic_static_unified_comm": "1",
             "DTYPE": "float32",
         }
         required_envs.update(need_envs)

test/xpu/collective_allreduce_api.py

-15
@@ -85,21 +85,6 @@ def get_model_new(
             all_reduce_new(tindata, reduce_type)
             return [tindata]

-    def get_model_new_comm(
-        self,
-        main_prog,
-        startup_program,
-        rank,
-        dtype='float32',
-    ):
-        with base.program_guard(main_prog, startup_program):
-            tindata = paddle.static.data(
-                name="tindata", shape=[10, 1000], dtype=dtype
-            )
-            reduce_type = int(os.getenv("REDUCE_TYPE"))
-            paddle.distributed.all_reduce(tindata, op=reduce_type)
-            return [tindata]
-

 if __name__ == "__main__":
     test_base.runtime_main(TestCollectiveAllreduceAPI, "allreduce")

test/xpu/collective_reduce_api.py

-11
@@ -79,17 +79,6 @@ def get_model_new(
             reduce_new(tindata, dst=0, reduce_type=reduce_type)
             return [tindata]

-    def get_model_new_comm(
-        self, main_prog, startup_program, rank, dtype='float32'
-    ):
-        with base.program_guard(main_prog, startup_program):
-            tindata = paddle.static.data(
-                name="tindata", shape=[-1, 10, 1000], dtype=dtype
-            )
-            tindata.desc.set_need_check_feed(False)
-            paddle.distributed.reduce(tindata, dst=0)
-            return [tindata]
-

 if __name__ == "__main__":
     runtime_main(TestCollectiveReduceAPI, "reduce")

test/xpu/test_collective_api_base.py

+5-9
@@ -125,7 +125,9 @@ def run_trainer(self, args):
         rank = args["trainerid"]
         current_endpoint = args["currentendpoint"]
         nranks = 2
-        if args["use_comm_context"] or args["dynamic_static_unified_comm"]:
+        if args['static_mode'] and (
+            args["use_comm_context"] or args["dynamic_static_unified_comm"]
+        ):
             paddle.distributed.collective._init_parallel_env(args["backend"])
         else:
             paddle.distributed.init_parallel_env()
@@ -153,11 +155,7 @@ def run_trainer(self, args):
             )
             if args["use_comm_context"]
             else (
-                self.get_model_new_comm(
-                    train_prog, startup_prog, rank, dtype=args['dtype']
-                )
-                if args["dynamic_static_unified_comm"]
-                else self.get_model(
+                self.get_model(
                     train_prog, startup_prog, rank, dtype=args['dtype']
                 )
             )
@@ -190,8 +188,7 @@ def runtime_main(test_class, col_type):
     args["reduce_type"] = os.getenv("REDUCE_TYPE")
     args["use_comm_context"] = bool(int(os.getenv("USE_COMM_CONTEXT", "0")))
     args["dynamic_static_unified_comm"] = bool(
-        os.getenv("FLAGS_dynamic_static_unified_comm", "false").lower()
-        == "true"
+        os.getenv("FLAGS_dynamic_static_unified_comm", "true").lower() == "true"
     )
     model.run_trainer(args)

@@ -352,7 +349,6 @@ def check_with_place(
             "PATH_ID": path_id,
             "DTYPE": dtype,
             "REDUCE_TYPE": str(reduce_type),
-            "FLAGS_dynamic_static_unified_comm": "0",
         }
         required_envs.update(additional_envs)
         required_envs.update(need_envs)

test/xpu/test_collective_base_xpu.py

+1-2
@@ -175,7 +175,7 @@ def runtime_main(test_class, col_type, sub_type):
     args["dtype"] = os.getenv("DTYPE")
     args["batch_size"] = os.getenv("BATCH_SIZE")
     args["dynamic_static_unified_comm"] = bool(
-        int(os.getenv("FLAGS_dynamic_static_unified_comm", "0"))
+        int(os.getenv("FLAGS_dynamic_static_unified_comm", "1"))
     )
     model.run_trainer(args)

@@ -293,7 +293,6 @@ def check_with_place(
             "LD_PRELOAD": os.getenv("LD_PRELOAD", ""),
             "GLOG_v": "3",
             "DTYPE": dtype,
-            "FLAGS_dynamic_static_unified_comm": "0",
         }
         required_envs.update(need_envs)
         if check_error_log:

test/xpu/test_collective_softmax_with_cross_entropy_xpu.py

-1
@@ -122,7 +122,6 @@ def check_with_place(
             "GLOG_v": "3",
             "DTYPE": dtype,
             "BATCH_SIZE": str(self.batch_size),
-            "FLAGS_dynamic_static_unified_comm": "0",
         }
         required_envs.update(need_envs)
         if check_error_log:
