【Comm】Fix update comm flags #67005

Merged: 14 commits, Aug 7, 2024
27 changes: 19 additions & 8 deletions paddle/fluid/operators/collective/barrier_op.h
@@ -22,11 +22,13 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"

#if defined(PADDLE_WITH_GLOO)
#include <gloo/barrier.h>

#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#endif

namespace paddle {
@@ -37,14 +39,23 @@ class BarrierOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_GLOO)
auto gloo = paddle::framework::GlooWrapper::GetInstance();
PADDLE_ENFORCE_EQ(
gloo->IsInitialized(),
true,
common::errors::PreconditionNotMet(
"You must initialize the gloo environment first to use it."));
gloo::BarrierOptions opts(gloo->GetContext());
gloo::barrier(opts);
int rid = ctx.Attr<int>("ring_id");
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (comm_context_manager.Has(std::to_string(rid))) {
auto* comm_context = static_cast<phi::distributed::GlooCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
comm_context->Barrier();
} else {
auto gloo = paddle::framework::GlooWrapper::GetInstance();
PADDLE_ENFORCE_EQ(
gloo->IsInitialized(),
true,
common::errors::PreconditionNotMet(
"You must initialize the gloo environment first to use it."));
gloo::BarrierOptions opts(gloo->GetContext());
gloo::barrier(opts);
}
Contributor: It would be better to switch this to the phi-style invocation.

Contributor (author): OK, thanks for the review.

#else
PADDLE_THROW(common::errors::Unavailable(
"PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON"));
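For context, a minimal Python-side sketch of how the new branch gets exercised, assuming the script is launched by Paddle's multi-process collective test runner (which supplies trainer endpoints and ranks) and that the gloo backend is available. Once _init_parallel_env registers a comm context, the barrier kernel above takes the CommContextManager/GlooCommContext branch; the legacy GlooWrapper code only runs when no context is registered for the ring id.

import paddle
import paddle.distributed as dist
from paddle import base

# Hedged sketch: needs a distributed launch (endpoints/rank in the environment),
# not a standalone single-process script.
paddle.enable_static()
dist.collective._init_parallel_env("gloo")  # registers a GlooCommContext for ring 0

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    dist.barrier()  # the barrier op looks up the comm context by ring_id at run time

exe = base.Executor(paddle.CPUPlace())
exe.run(startup_prog)
exe.run(main_prog)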
4 changes: 2 additions & 2 deletions paddle/fluid/operators/collective/send_v2_op.cu.cc
@@ -222,9 +222,9 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
} else {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclSend(
x.data<T>(), numel, dtype, peer, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " send "
<< common::product(x.dims()) << " to " << peer;
}
VLOG(3) << "rank " << comm->rank() << " send "
<< common::product(x.dims()) << " to " << peer;
}
return;
}
4 changes: 2 additions & 2 deletions test/collective/parallel_embedding_api.py
@@ -29,7 +29,7 @@ class TestParallelEmbeddingAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0

def get_model(self, main_prog, startup_program, rank):
def get_model(self, main_prog, startup_program, rank, dtype="float32"):
with base.program_guard(main_prog, startup_program):
fleet.init(is_collective=True)
np.random.seed(2020)
@@ -40,7 +40,7 @@ def get_model(self, main_prog, startup_program, rank):
data_in = paddle.randint(0, size[0], shape=(10, 4))

data = paddle.static.data(
name='tindata', shape=[10, 1000], dtype="float32"
name='tindata', shape=[10, 1000], dtype=dtype
)
per_part_size = size[0] // 2
if rank == 0:
6 changes: 1 addition & 5 deletions test/collective/test_collective_barrier_api.py
@@ -38,11 +38,7 @@ def test_barrier_nccl_with_new_comm(self):

def test_barrier_gloo(self):
self.check_with_place(
"collective_barrier_api.py",
"barrier",
"gloo",
"5",
need_envs={"FLAGS_dynamic_static_unified_comm": "false"},
"collective_barrier_api.py", "barrier", "gloo", "5"
)


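If a case still needs the legacy gloo path, the harness continues to honor a per-test override through need_envs (the same keyword this call site just dropped); a hypothetical opt-out would look like this:

# Hypothetical opt-out sketch (not part of this PR): pinning the flag off forces
# the legacy GlooWrapper branch in the barrier kernel.
self.check_with_place(
    "collective_barrier_api.py",
    "barrier",
    "gloo",
    "5",
    need_envs={"FLAGS_dynamic_static_unified_comm": "false"},
)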
8 changes: 2 additions & 6 deletions test/collective/test_collective_sendrecv.py
@@ -30,16 +30,12 @@ def test_sendrecv(self):

def test_sendrecv_dynamic_shape(self):
self.check_with_place(
"collective_sendrecv_op_dynamic_shape.py",
"sendrecv_dynamic_shape",
need_envs={"FLAGS_dynamic_static_unified_comm": "0"},
"collective_sendrecv_op_dynamic_shape.py", "sendrecv_dynamic_shape"
)

def test_sendrecv_array(self):
self.check_with_place(
"collective_sendrecv_op_array.py",
"sendrecv_array",
need_envs={"FLAGS_dynamic_static_unified_comm": "0"},
"collective_sendrecv_op_array.py", "sendrecv_array"
)


5 changes: 1 addition & 4 deletions test/collective/test_collective_split_embedding.py
@@ -27,10 +27,7 @@ def _setup_config(self):

def test_parallel_embedding(self):
self.check_with_place(
"parallel_embedding_api.py",
"parallel_embedding",
"nccl",
need_envs={"FLAGS_dynamic_static_unified_comm": "false"},
"parallel_embedding_api.py", "parallel_embedding", "nccl"
)


6 changes: 5 additions & 1 deletion test/legacy_test/test_collective_api_base.py
@@ -154,7 +154,11 @@ def run_trainer(self, args):
reduce_type=args['reduce_type'],
)
if args["use_comm_context"]
else (self.get_model(train_prog, startup_prog, rank))
else (
self.get_model(
train_prog, startup_prog, rank, dtype=args['dtype']
)
)
)
exe = base.Executor(place)
exe.run(startup_prog)
1 change: 0 additions & 1 deletion test/legacy_test/test_collective_base.py
@@ -263,7 +263,6 @@ def check_with_place(
"LD_PRELOAD": os.getenv("LD_PRELOAD", ""),
"GLOG_v": "3",
"NCCL_P2P_DISABLE": "1",
"FLAGS_dynamic_static_unified_comm": "1",
"DTYPE": "float32",
}
required_envs.update(need_envs)
15 changes: 0 additions & 15 deletions test/xpu/collective_allreduce_api.py
@@ -85,21 +85,6 @@ def get_model_new(
all_reduce_new(tindata, reduce_type)
return [tindata]

def get_model_new_comm(
self,
main_prog,
startup_program,
rank,
dtype='float32',
):
with base.program_guard(main_prog, startup_program):
tindata = paddle.static.data(
name="tindata", shape=[10, 1000], dtype=dtype
)
reduce_type = int(os.getenv("REDUCE_TYPE"))
paddle.distributed.all_reduce(tindata, op=reduce_type)
return [tindata]


if __name__ == "__main__":
test_base.runtime_main(TestCollectiveAllreduceAPI, "allreduce")
11 changes: 0 additions & 11 deletions test/xpu/collective_reduce_api.py
@@ -79,17 +79,6 @@ def get_model_new(
reduce_new(tindata, dst=0, reduce_type=reduce_type)
return [tindata]

def get_model_new_comm(
self, main_prog, startup_program, rank, dtype='float32'
):
with base.program_guard(main_prog, startup_program):
tindata = paddle.static.data(
name="tindata", shape=[-1, 10, 1000], dtype=dtype
)
tindata.desc.set_need_check_feed(False)
paddle.distributed.reduce(tindata, dst=0)
return [tindata]


if __name__ == "__main__":
runtime_main(TestCollectiveReduceAPI, "reduce")
14 changes: 5 additions & 9 deletions test/xpu/test_collective_api_base.py
@@ -125,7 +125,9 @@ def run_trainer(self, args):
rank = args["trainerid"]
current_endpoint = args["currentendpoint"]
nranks = 2
if args["use_comm_context"] or args["dynamic_static_unified_comm"]:
if args['static_mode'] and (
args["use_comm_context"] or args["dynamic_static_unified_comm"]
):
paddle.distributed.collective._init_parallel_env(args["backend"])
else:
paddle.distributed.init_parallel_env()
@@ -153,11 +155,7 @@ def run_trainer(self, args):
)
if args["use_comm_context"]
else (
self.get_model_new_comm(
train_prog, startup_prog, rank, dtype=args['dtype']
)
if args["dynamic_static_unified_comm"]
else self.get_model(
self.get_model(
train_prog, startup_prog, rank, dtype=args['dtype']
)
)
@@ -190,8 +188,7 @@ def runtime_main(test_class, col_type):
args["reduce_type"] = os.getenv("REDUCE_TYPE")
args["use_comm_context"] = bool(int(os.getenv("USE_COMM_CONTEXT", "0")))
args["dynamic_static_unified_comm"] = bool(
os.getenv("FLAGS_dynamic_static_unified_comm", "false").lower()
== "true"
os.getenv("FLAGS_dynamic_static_unified_comm", "true").lower() == "true"
)
model.run_trainer(args)

@@ -352,7 +349,6 @@ def check_with_place(
"PATH_ID": path_id,
"DTYPE": dtype,
"REDUCE_TYPE": str(reduce_type),
"FLAGS_dynamic_static_unified_comm": "0",
}
required_envs.update(additional_envs)
required_envs.update(need_envs)
3 changes: 1 addition & 2 deletions test/xpu/test_collective_base_xpu.py
@@ -175,7 +175,7 @@ def runtime_main(test_class, col_type, sub_type):
args["dtype"] = os.getenv("DTYPE")
args["batch_size"] = os.getenv("BATCH_SIZE")
args["dynamic_static_unified_comm"] = bool(
int(os.getenv("FLAGS_dynamic_static_unified_comm", "0"))
int(os.getenv("FLAGS_dynamic_static_unified_comm", "1"))
)
model.run_trainer(args)

@@ -293,7 +293,6 @@ def check_with_place(
"LD_PRELOAD": os.getenv("LD_PRELOAD", ""),
"GLOG_v": "3",
"DTYPE": dtype,
"FLAGS_dynamic_static_unified_comm": "0",
}
required_envs.update(need_envs)
if check_error_log:
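Taken together, the two xpu harness changes flip the flag's default in both of its parsing forms. A small self-contained sketch of the resulting behavior when the variable is left unset (variable names here are illustrative only):

import os

# The API harness compares a lowercase string, the base harness parses an int;
# after this change both default to the unified comm path being on.
# Note the int form only accepts "0"/"1", not "true"/"false".
api_default = os.getenv("FLAGS_dynamic_static_unified_comm", "true").lower() == "true"
base_default = bool(int(os.getenv("FLAGS_dynamic_static_unified_comm", "1")))
assert api_default and base_default  # holds when the variable is unset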
1 change: 0 additions & 1 deletion test/xpu/test_collective_softmax_with_cross_entropy_xpu.py
@@ -122,7 +122,6 @@ def check_with_place(
"GLOG_v": "3",
"DTYPE": dtype,
"BATCH_SIZE": str(self.batch_size),
"FLAGS_dynamic_static_unified_comm": "0",
}
required_envs.update(need_envs)
if check_error_log: