From d3090fc722a4bb59e28739fc1077d44f80ac06ab Mon Sep 17 00:00:00 2001 From: chengduozh Date: Mon, 20 May 2019 16:13:43 +0800 Subject: [PATCH 1/7] Add broadcast test=develop --- .../operators/distributed_ops/allreduce_op.h | 3 + .../operators/distributed_ops/broadcast_op.cc | 71 ++++++++++++++++ .../distributed_ops/broadcast_op.cu.cc | 83 +++++++++++++++++++ python/paddle/fluid/dygraph/parallel.py | 8 ++ python/paddle/fluid/layers/collective.py | 11 +++ 5 files changed, 176 insertions(+) create mode 100644 paddle/fluid/operators/distributed_ops/broadcast_op.cc create mode 100644 paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc diff --git a/paddle/fluid/operators/distributed_ops/allreduce_op.h b/paddle/fluid/operators/distributed_ops/allreduce_op.h index 0275f6a9cf3aa8..ef4001d8b957a7 100644 --- a/paddle/fluid/operators/distributed_ops/allreduce_op.h +++ b/paddle/fluid/operators/distributed_ops/allreduce_op.h @@ -40,6 +40,7 @@ class AllReduceOpKernel : public framework::OpKernel { auto in = ctx.Input("X"); auto out = ctx.Output("Out"); + int in_dev_id = boost::get(in->place()).device; int dtype = platform::ToNCCLDataType(in->type()); int64_t numel = in->numel(); auto* sendbuff = in->data(); @@ -67,6 +68,8 @@ class AllReduceOpKernel : public framework::OpKernel { red_type = ncclMin; break; } + VLOG(3) << "AllReduce " << ctx.Inputs("X")[0] << " On " << in_dev_id; + PADDLE_ENFORCE(platform::dynload::ncclAllReduce( sendbuff, recvbuff, numel, static_cast(dtype), red_type, comm, stream)); diff --git a/paddle/fluid/operators/distributed_ops/broadcast_op.cc b/paddle/fluid/operators/distributed_ops/broadcast_op.cc new file mode 100644 index 00000000000000..70ce64bd6b7ed1 --- /dev/null +++ b/paddle/fluid/operators/distributed_ops/broadcast_op.cc @@ -0,0 +1,71 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class BroadcastOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override {} +}; + +class BroadcastOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "(Tensor), tensor to be broadcast."); + AddOutput("Out", "(Tensor) the result of broadcast."); + AddAttr( + "sync_mode", + "(bool) whether to synchronize the CUDA stream after nccl call.") + .SetDefault(false); + AddAttr("root", "(int).").SetDefault(0); + AddComment(R"DOC( +***Broadcast Operator*** + +Call NCCL Broadcast internally. Note that this op must be used when one +thread is managing one GPU device. 
+)DOC"); + } +}; + +template +class BroadcastOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW("Broadcast op can run on gpu place only for now."); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(broadcast, ops::BroadcastOp, + ops::BroadcastOpMaker); + +REGISTER_OP_CPU_KERNEL(broadcast, ops::BroadcastOpKernel, + ops::BroadcastOpKernel, + ops::BroadcastOpKernel, + ops::BroadcastOpKernel, + ops::BroadcastOpKernel); diff --git a/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc b/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc new file mode 100644 index 00000000000000..bf33bd787f6b5a --- /dev/null +++ b/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc @@ -0,0 +1,83 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +namespace paddle { +namespace operators { + +template +class NCCLBroadcastOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace())); + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + int dev_id = boost::get(ctx.GetPlace()).device; + int root_dev_id = ctx.Attr("root"); + + auto in = ctx.Input("X"); + auto out = ctx.Output("Out"); + out->Resize(in->dims()); + + const int in_dev_id = boost::get(in->place()).device; + PADDLE_ENFORCE_EQ(dev_id, in_dev_id); + + auto& dev_ctx = ctx.template device_context(); + auto comm = dev_ctx.nccl_comm(); + auto stream = dev_ctx.stream(); + PADDLE_ENFORCE_NOT_NULL(stream, "Should initialize NCCL firstly."); + + void* data_buffer = const_cast(in->data()); + if (root_dev_id != in_dev_id) { + data_buffer = out->mutable_data(ctx.GetPlace()); + } + + VLOG(3) << "Bcast " << ctx.Inputs("X")[0] << " From " << root_dev_id + << " to " << in_dev_id; + + auto dtype = platform::ToNCCLDataType(in->type()); + + PADDLE_ENFORCE(platform::dynload::ncclBcast( + data_buffer, static_cast(in->numel()), dtype, root_dev_id, comm, + stream)); + + if (ctx.Attr("sync_mode")) { + PADDLE_ENFORCE(cudaStreamSynchronize(stream)); + } +#endif + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_CUDA_KERNEL(broadcast, ops::NCCLBroadcastOpKernel, + ops::NCCLBroadcastOpKernel, + ops::NCCLBroadcastOpKernel, + ops::NCCLBroadcastOpKernel, + ops::NCCLBroadcastOpKernel); diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 1378f91402874f..b1e12e02c89468 100644 --- 
a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -82,6 +82,14 @@ def __init__(self, layers, strategy): self).__init__(layers.full_name() + "_data_parallel") self._layers = layers self._strategy = strategy + # Broadcast parameter to other devices. + for param in self._layers.parameters(): + p_var = framework.Variable( + block=self._helper.main_program.current_block(), + name=param.name, + stop_gradient=True, + ivar=param._ivar) + collective._broadcast(p_var, 0, sync_mode=True) def forward(self, *inputs, **kwargs): return self._layers(*inputs, **kwargs) diff --git a/python/paddle/fluid/layers/collective.py b/python/paddle/fluid/layers/collective.py index 97c290f5a99da5..4fa0d1eb2ccd25 100644 --- a/python/paddle/fluid/layers/collective.py +++ b/python/paddle/fluid/layers/collective.py @@ -46,3 +46,14 @@ def _allreduce(x, out=None, reduce_type="sum", sync_mode=False): attrs={"reduce_type": red_typ_int, "sync_mode": sync_mode}) return out + + +def _broadcast(x, root, sync_mode=False): + helper = LayerHelper("broadcast", **locals()) + helper.append_op( + type='broadcast', + inputs={'X': [x]}, + outputs={'Out': [x]}, + attrs={"sync_mode": sync_mode, + "root": root}) + return x From e71d181dd132ca79164abae053d87770d8daabe6 Mon Sep 17 00:00:00 2001 From: chengduozh Date: Mon, 20 May 2019 20:30:22 +0800 Subject: [PATCH 2/7] polish code test=develop --- .../operators/distributed_ops/broadcast_op.cc | 9 +++++++-- .../operators/distributed_ops/broadcast_op.cu.cc | 16 +++++++--------- python/paddle/fluid/dygraph/__init__.py | 1 - python/paddle/fluid/dygraph/parallel.py | 7 ++++--- .../fluid/tests/unittests/test_dist_base.py | 2 +- 5 files changed, 19 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/operators/distributed_ops/broadcast_op.cc b/paddle/fluid/operators/distributed_ops/broadcast_op.cc index 70ce64bd6b7ed1..6ae98af1e2ac19 100644 --- a/paddle/fluid/operators/distributed_ops/broadcast_op.cc +++ b/paddle/fluid/operators/distributed_ops/broadcast_op.cc @@ -25,7 +25,12 @@ class BroadcastOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override {} + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of BroadcastOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Output) of ConvOp should not be null."); + } }; class BroadcastOpMaker : public framework::OpProtoAndCheckerMaker { @@ -37,7 +42,7 @@ class BroadcastOpMaker : public framework::OpProtoAndCheckerMaker { "sync_mode", "(bool) whether to synchronize the CUDA stream after nccl call.") .SetDefault(false); - AddAttr("root", "(int).").SetDefault(0); + AddAttr("root", "(int).").SetDefault(0).EqualGreaterThan(0); AddComment(R"DOC( ***Broadcast Operator*** diff --git a/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc b/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc index bf33bd787f6b5a..590f059a074166 100644 --- a/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc +++ b/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc @@ -50,21 +50,19 @@ class NCCLBroadcastOpKernel : public framework::OpKernel { auto& dev_ctx = ctx.template device_context(); auto comm = dev_ctx.nccl_comm(); auto stream = dev_ctx.stream(); - PADDLE_ENFORCE_NOT_NULL(stream, "Should initialize NCCL firstly."); - void* data_buffer = const_cast(in->data()); + void* send_recv_buffer = 
const_cast(in->data()); if (root_dev_id != in_dev_id) { - data_buffer = out->mutable_data(ctx.GetPlace()); + send_recv_buffer = out->mutable_data(ctx.GetPlace()); } - VLOG(3) << "Bcast " << ctx.Inputs("X")[0] << " From " << root_dev_id - << " to " << in_dev_id; - - auto dtype = platform::ToNCCLDataType(in->type()); + VLOG(3) << "Bcast " << ctx.Inputs("X")[0] << ", (" + << static_cast(in->numel()) << ")" + << " From " << root_dev_id << " to " << in_dev_id; PADDLE_ENFORCE(platform::dynload::ncclBcast( - data_buffer, static_cast(in->numel()), dtype, root_dev_id, comm, - stream)); + send_recv_buffer, static_cast(in->numel()), + platform::ToNCCLDataType(in->type()), root_dev_id, comm, stream)); if (ctx.Attr("sync_mode")) { PADDLE_ENFORCE(cudaStreamSynchronize(stream)); diff --git a/python/paddle/fluid/dygraph/__init__.py b/python/paddle/fluid/dygraph/__init__.py index 7ab1dfdf767749..727bf5a3570044 100644 --- a/python/paddle/fluid/dygraph/__init__.py +++ b/python/paddle/fluid/dygraph/__init__.py @@ -47,7 +47,6 @@ __all__ += nn.__all__ __all__ += tracer.__all__ __all__ += profiler.__all__ -__all__ += parallel.__all__ __all__ += checkpoint.__all__ __all__ += learning_rate_scheduler.__all__ __all__ += backward_strategy.__all__ diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index b1e12e02c89468..97e83ede37aef5 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -22,14 +22,12 @@ from ..layers import collective from . import to_variable -__all__ = ["prepare_context"] - ParallelStrategy = core.ParallelStrategy __parallel_ctx__clz__ = None -def prepare_context(parallel_strategy): +def _prepare_context(parallel_strategy): global __parallel_ctx__clz__ assert __parallel_ctx__clz__ is None, "ParallelContext can only be initialized once." assert framework.in_dygraph_mode( @@ -82,6 +80,9 @@ def __init__(self, layers, strategy): self).__init__(layers.full_name() + "_data_parallel") self._layers = layers self._strategy = strategy + if self._strategy.nranks > 1: + _prepare_context(strategy) + # Broadcast parameter to other devices. 
for param in self._layers.parameters(): p_var = framework.Variable( diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index b479966d1fb372..998aafd5158ff0 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -230,7 +230,7 @@ def _get_data(batch): strategy.local_rank = args.trainer_id strategy.trainer_endpoints = args.endpoints.split(",") strategy.current_endpoint = args.current_endpoint - dygraph.parallel.prepare_context(strategy) + dygraph.parallel._prepare_context(strategy) model = dygraph.parallel.DataParallel(model, strategy) out_losses = [] for step_id, data in enumerate(train_reader()): From 795f92c060962fee877fb24570cab49352aba09a Mon Sep 17 00:00:00 2001 From: chengduozh Date: Tue, 21 May 2019 14:58:48 +0800 Subject: [PATCH 3/7] Add doc test=develop --- paddle/fluid/platform/nccl_helper.h | 2 +- python/paddle/fluid/dygraph/__init__.py | 1 + python/paddle/fluid/dygraph/parallel.py | 82 ++++++++++++++----- .../fluid/tests/unittests/test_dist_base.py | 2 +- 4 files changed, 66 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index b8b14b3d15efb4..ce51df4eff2ce7 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -104,7 +104,7 @@ struct NCCLContextMap { PADDLE_ENFORCE_EQ( order_.size(), contexts_.size(), "NCCL Context Map does not support contain two or more same device"); - + // NOTE(paddle-dev): Why use std::unique_ptr and the T is ncclComm_t[] here? std::unique_ptr comms(new ncclComm_t[order_.size()]); // if num_trainers == 1, should create a new nccl id for local comms. if (num_trainers == 1 && nccl_id == nullptr) { diff --git a/python/paddle/fluid/dygraph/__init__.py b/python/paddle/fluid/dygraph/__init__.py index 727bf5a3570044..7860ea25c515a3 100644 --- a/python/paddle/fluid/dygraph/__init__.py +++ b/python/paddle/fluid/dygraph/__init__.py @@ -45,6 +45,7 @@ __all__ += layers.__all__ __all__ += base.__all__ __all__ += nn.__all__ +__all__ += parallel.__all__ __all__ += tracer.__all__ __all__ += profiler.__all__ __all__ += checkpoint.__all__ diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 97e83ede37aef5..04f6f48b42de2d 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -18,22 +18,24 @@ from .. import core from . import layers from .. import framework - from ..layers import collective from . import to_variable +__all__ = ["prepare_context"] ParallelStrategy = core.ParallelStrategy __parallel_ctx__clz__ = None -def _prepare_context(parallel_strategy): +def prepare_context(parallel_strategy): global __parallel_ctx__clz__ - assert __parallel_ctx__clz__ is None, "ParallelContext can only be initialized once." - assert framework.in_dygraph_mode( - ) is True, "dygraph.parallel.prepare_context should be used with dygrahp mode." + assert __parallel_ctx__clz__ is None, \ + "ParallelContext can only be initialized once." + assert framework.in_dygraph_mode() is True,\ + "dygraph.parallel.prepare_context should be used with dygrahp mode." place = framework._current_expected_place() - assert place is not None, "dygraph.parallel.prepare_context should be used in fluid.dygraph.guard(place) guard." + assert place is not None, \ + "dygraph.parallel.prepare_context should be used in fluid.dygraph.guard(place) guard." 
if isinstance(place, core.CUDAPlace): __parallel_ctx__clz__ = core.NCCLParallelContext(parallel_strategy, @@ -45,6 +47,9 @@ def _prepare_context(parallel_strategy): class Env(object): + """ + """ + def __init__(self): self._nranks = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) self._local_rank = int(os.getenv("PADDLE_TRAINER_ID", "0")) @@ -75,29 +80,61 @@ def trainer_endpoints(self): class DataParallel(layers.Layer): - def __init__(self, layers, strategy): + """ + DataParallel. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import numpy + import os + ... + + Args: + layers(Layer): The layer.Layer. + strategy(ParallelStrategy): The dygraph.parallel.ParallelStrategy. + + Returns: + Layer: The layer.Layer.. + + Raises: + TypeError: If share_vars_from is provided, but not ParallelExecutor object. + """ + + def __init__(self, layers, strategy=None): super(DataParallel, self).__init__(layers.full_name() + "_data_parallel") + self._layers = layers + if strategy is None: + strategy = ParallelStrategy() + strategy.nranks = Env().nranks + strategy.local_rank = Env().local_rank + strategy.trainer_endpoints = Env().trainer_endpoints + strategy.current_endpoint = Env().current_endpoint self._strategy = strategy - if self._strategy.nranks > 1: - _prepare_context(strategy) - # Broadcast parameter to other devices. - for param in self._layers.parameters(): - p_var = framework.Variable( - block=self._helper.main_program.current_block(), - name=param.name, - stop_gradient=True, - ivar=param._ivar) - collective._broadcast(p_var, 0, sync_mode=True) + if self._is_data_parallel_mode(): + prepare_context(strategy) + for param in self._layers.parameters(): + collective._broadcast(param, 0, sync_mode=True) def forward(self, *inputs, **kwargs): return self._layers(*inputs, **kwargs) def scale_loss(self, loss): - if self._strategy.nranks < 2: + """ + + Args: + loss(Layer): The layer.Layer. + + Returns: + Layer: The layer.Layer. + """ + if not self._is_data_parallel_mode(): return loss + loss_scale = to_variable( np.array([self._strategy.nranks]).astype("float32")) loss_scale.stop_gradient = True @@ -105,10 +142,14 @@ def scale_loss(self, loss): return loss def apply_collective_grads(self): - if self._strategy.nranks < 2: + """ + AllReduce the Parameters' gradient. + """ + if not self._is_data_parallel_mode(): return for param in self._layers.parameters(): + # NOTE(zcd): The grad_ivar maybe no generated. 
if param.trainable and param._ivar._grad_ivar(): g_var = framework.Variable( block=self._helper.main_program.current_block(), @@ -116,3 +157,6 @@ def apply_collective_grads(self): stop_gradient=True, ivar=param._ivar._grad_ivar()) collective._allreduce(g_var, g_var, sync_mode=True) + + def _is_data_parallel_mode(self): + return self._strategy.nranks > 1 diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 998aafd5158ff0..b479966d1fb372 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -230,7 +230,7 @@ def _get_data(batch): strategy.local_rank = args.trainer_id strategy.trainer_endpoints = args.endpoints.split(",") strategy.current_endpoint = args.current_endpoint - dygraph.parallel._prepare_context(strategy) + dygraph.parallel.prepare_context(strategy) model = dygraph.parallel.DataParallel(model, strategy) out_losses = [] for step_id, data in enumerate(train_reader()): From 2b3b07ced63f824f546a5644be8f9e46112506aa Mon Sep 17 00:00:00 2001 From: chengduozh Date: Tue, 21 May 2019 22:47:45 +0800 Subject: [PATCH 4/7] bcast parameters test=develop --- python/paddle/fluid/dygraph/__init__.py | 2 +- python/paddle/fluid/dygraph/layers.py | 3 ++ python/paddle/fluid/dygraph/parallel.py | 38 +++++++---------- .../paddle/fluid/dygraph/parallel_helper.py | 42 +++++++++++++++++++ 4 files changed, 62 insertions(+), 23 deletions(-) create mode 100644 python/paddle/fluid/dygraph/parallel_helper.py diff --git a/python/paddle/fluid/dygraph/__init__.py b/python/paddle/fluid/dygraph/__init__.py index 7860ea25c515a3..7ab1dfdf767749 100644 --- a/python/paddle/fluid/dygraph/__init__.py +++ b/python/paddle/fluid/dygraph/__init__.py @@ -45,9 +45,9 @@ __all__ += layers.__all__ __all__ += base.__all__ __all__ += nn.__all__ -__all__ += parallel.__all__ __all__ += tracer.__all__ __all__ += profiler.__all__ +__all__ += parallel.__all__ __all__ += checkpoint.__all__ __all__ += learning_rate_scheduler.__all__ __all__ += backward_strategy.__all__ diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 7ddf94146c776e..c2a58411402cc3 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -18,6 +18,7 @@ import numpy as np import collections import six +from . import parallel_helper from .. import unique_name from paddle.fluid import core from .layer_object_helper import LayerObjectHelper @@ -154,6 +155,8 @@ def build_once(self, *args): def __call__(self, *inputs): if not self._built: self.build_once(*inputs) + if parallel_helper._is_data_parallel_mode(): + parallel_helper._broadcast_parameters(self.parameters()) outputs = self.forward(*inputs) self._built = True diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 04f6f48b42de2d..91258c4029d752 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -17,33 +17,38 @@ from .. import core from . import layers +from . import parallel_helper from .. import framework from ..layers import collective from . import to_variable __all__ = ["prepare_context"] -ParallelStrategy = core.ParallelStrategy -__parallel_ctx__clz__ = None +ParallelStrategy = core.ParallelStrategy -def prepare_context(parallel_strategy): - global __parallel_ctx__clz__ - assert __parallel_ctx__clz__ is None, \ - "ParallelContext can only be initialized once." 
+def prepare_context(strategy=None): + if strategy is None: + strategy = ParallelStrategy() + strategy.nranks = Env().nranks + strategy.local_rank = Env().local_rank + strategy.trainer_endpoints = Env().trainer_endpoints + strategy.current_endpoint = Env().current_endpoint + if strategy.nranks < 2: + return assert framework.in_dygraph_mode() is True,\ "dygraph.parallel.prepare_context should be used with dygrahp mode." place = framework._current_expected_place() assert place is not None, \ "dygraph.parallel.prepare_context should be used in fluid.dygraph.guard(place) guard." - if isinstance(place, core.CUDAPlace): - __parallel_ctx__clz__ = core.NCCLParallelContext(parallel_strategy, - place) + parallel_helper._set_parallel_ctx( + core.NCCLParallelContext(strategy, place)) else: # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation assert ("Only support CUDAPlace for now.") - __parallel_ctx__clz__.init() + parallel_helper._init_parallel_ctx() + return strategy class Env(object): @@ -102,24 +107,13 @@ class DataParallel(layers.Layer): TypeError: If share_vars_from is provided, but not ParallelExecutor object. """ - def __init__(self, layers, strategy=None): + def __init__(self, layers, strategy): super(DataParallel, self).__init__(layers.full_name() + "_data_parallel") self._layers = layers - if strategy is None: - strategy = ParallelStrategy() - strategy.nranks = Env().nranks - strategy.local_rank = Env().local_rank - strategy.trainer_endpoints = Env().trainer_endpoints - strategy.current_endpoint = Env().current_endpoint self._strategy = strategy - if self._is_data_parallel_mode(): - prepare_context(strategy) - for param in self._layers.parameters(): - collective._broadcast(param, 0, sync_mode=True) - def forward(self, *inputs, **kwargs): return self._layers(*inputs, **kwargs) diff --git a/python/paddle/fluid/dygraph/parallel_helper.py b/python/paddle/fluid/dygraph/parallel_helper.py new file mode 100644 index 00000000000000..3a8e39e62cb5d6 --- /dev/null +++ b/python/paddle/fluid/dygraph/parallel_helper.py @@ -0,0 +1,42 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except jin compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from ..layers import collective + +__parallel_ctx__clz__ = None + + +def _is_data_parallel_mode(): + global __parallel_ctx__clz__ + return __parallel_ctx__clz__ is not None and int( + os.getenv("PADDLE_TRAINERS_NUM", "1")) > 1 + + +def _set_parallel_ctx(nccl_parallel_context): + global __parallel_ctx__clz__ + assert __parallel_ctx__clz__ is None, \ + "ParallelContext can only be initialized once." + __parallel_ctx__clz__ = nccl_parallel_context + + +def _init_parallel_ctx(): + global __parallel_ctx__clz__ + assert __parallel_ctx__clz__ is not None, \ + "ParallelContext should be initialized." 
+ __parallel_ctx__clz__.init() + + +def _broadcast_parameters(parameters): + for param in parameters: + collective._broadcast(param, 0, sync_mode=True) From 1aedf4b78d7b7a807557d4b462ee8878a5031969 Mon Sep 17 00:00:00 2001 From: chengduozh Date: Wed, 22 May 2019 11:18:34 +0800 Subject: [PATCH 5/7] code refine test=develop --- .../operators/distributed_ops/allreduce_op.h | 3 - paddle/fluid/platform/nccl_helper.h | 2 +- python/paddle/fluid/dygraph/layers.py | 2 +- python/paddle/fluid/dygraph/parallel.py | 65 ++++++++++++++----- .../paddle/fluid/dygraph/parallel_helper.py | 3 +- 5 files changed, 53 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/distributed_ops/allreduce_op.h b/paddle/fluid/operators/distributed_ops/allreduce_op.h index ef4001d8b957a7..0275f6a9cf3aa8 100644 --- a/paddle/fluid/operators/distributed_ops/allreduce_op.h +++ b/paddle/fluid/operators/distributed_ops/allreduce_op.h @@ -40,7 +40,6 @@ class AllReduceOpKernel : public framework::OpKernel { auto in = ctx.Input("X"); auto out = ctx.Output("Out"); - int in_dev_id = boost::get(in->place()).device; int dtype = platform::ToNCCLDataType(in->type()); int64_t numel = in->numel(); auto* sendbuff = in->data(); @@ -68,8 +67,6 @@ class AllReduceOpKernel : public framework::OpKernel { red_type = ncclMin; break; } - VLOG(3) << "AllReduce " << ctx.Inputs("X")[0] << " On " << in_dev_id; - PADDLE_ENFORCE(platform::dynload::ncclAllReduce( sendbuff, recvbuff, numel, static_cast(dtype), red_type, comm, stream)); diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index ce51df4eff2ce7..b8b14b3d15efb4 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -104,7 +104,7 @@ struct NCCLContextMap { PADDLE_ENFORCE_EQ( order_.size(), contexts_.size(), "NCCL Context Map does not support contain two or more same device"); - // NOTE(paddle-dev): Why use std::unique_ptr and the T is ncclComm_t[] here? + std::unique_ptr comms(new ncclComm_t[order_.size()]); // if num_trainers == 1, should create a new nccl id for local comms. if (num_trainers == 1 && nccl_id == nullptr) { diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index c2a58411402cc3..54b34919eae213 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -156,7 +156,7 @@ def __call__(self, *inputs): if not self._built: self.build_once(*inputs) if parallel_helper._is_data_parallel_mode(): - parallel_helper._broadcast_parameters(self.parameters()) + parallel_helper._broadcast_parameters(self._parameters.values()) outputs = self.forward(*inputs) self._built = True diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 91258c4029d752..37716cea14c016 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -52,9 +52,6 @@ def prepare_context(strategy=None): class Env(object): - """ - """ - def __init__(self): self._nranks = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) self._local_rank = int(os.getenv("PADDLE_TRAINER_ID", "0")) @@ -86,25 +83,58 @@ def trainer_endpoints(self): class DataParallel(layers.Layer): """ - DataParallel. + Runs the module with data parallelism. + + Currently, DataParallel only supports to run the dynamic graph + with multi-process. The usage is: + `python -m paddle.distributed.launch --gpus 2 dynamic_graph_test.py`. + And the content of `dynamic_graph_test.py` is the code of examples. Examples: .. 
code-block:: python - import paddle.fluid as fluid - import numpy - import os - ... + import numpy as np + import paddle.fluid as fluid + import paddle.fluid.dygraph as dygraph + from paddle.fluid.optimizer import AdamOptimizer + from paddle.fluid.dygraph.nn import FC + from paddle.fluid.dygraph.base import to_variable + + place = fluid.CUDAPlace(0) + with fluid.dygraph.guard(place=place): + + # prepare the data parallel context + strategy=dygraph.parallel.prepare_context() + + fc_layer = FC("FC", 10, act="softmax") + adam = fluid.optimizer.AdamOptimizer() + + # make the module become the data parallelism module + fc_layer = dygraph.parallel.DataParallel(fc_layer, strategy) + + x_data = np.random.random(size=[10, 1]).astype(np.float32) + data = to_variable(x_data) + + hidden = fc_layer(data) + avg_loss = fluid.layers.mean(hidden) + + # scale the loss according to the number of trainers. + avg_loss = fc_layer.scale_loss(avg_loss) + + avg_loss.backward() + + # collect the gradients of trainers. + fc_layer.apply_collective_grads() + + adam.minimize(avg_loss) + fc_layer.clear_gradients() Args: - layers(Layer): The layer.Layer. - strategy(ParallelStrategy): The dygraph.parallel.ParallelStrategy. + layers(Layer): The module that should be executed by data parallel. + strategy(ParallelStrategy): The strategy of data parallelism. Returns: - Layer: The layer.Layer.. - - Raises: - TypeError: If share_vars_from is provided, but not ParallelExecutor object. + Layer: The data paralleled module. """ def __init__(self, layers, strategy): @@ -119,12 +149,15 @@ def forward(self, *inputs, **kwargs): def scale_loss(self, loss): """ + Scale the loss. In data parallel mode, the loss should be scale with + the number of trainers. If not in data parallel mode, return the loss + directly. Args: - loss(Layer): The layer.Layer. + loss(Layer): The loss of the current Model. Returns: - Layer: The layer.Layer. + Layer: the scaled loss. 
""" if not self._is_data_parallel_mode(): return loss diff --git a/python/paddle/fluid/dygraph/parallel_helper.py b/python/paddle/fluid/dygraph/parallel_helper.py index 3a8e39e62cb5d6..7932c327e44eea 100644 --- a/python/paddle/fluid/dygraph/parallel_helper.py +++ b/python/paddle/fluid/dygraph/parallel_helper.py @@ -39,4 +39,5 @@ def _init_parallel_ctx(): def _broadcast_parameters(parameters): for param in parameters: - collective._broadcast(param, 0, sync_mode=True) + if param.trainable: + collective._broadcast(param, 0, sync_mode=True) From 5d641ba9981282db62c6159dff077078e226ed53 Mon Sep 17 00:00:00 2001 From: chengduozh Date: Thu, 23 May 2019 11:16:34 +0800 Subject: [PATCH 6/7] use ncclBroadcast test=develop --- .../distributed_ops/broadcast_op.cu.cc | 23 +++++++++---------- paddle/fluid/platform/dynload/nccl.h | 1 + 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc b/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc index 590f059a074166..4d307c87efe0f1 100644 --- a/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc +++ b/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc @@ -34,7 +34,8 @@ template class NCCLBroadcastOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace())); + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "The place of ExecutionContext should be CUDAPlace."); #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) int dev_id = boost::get(ctx.GetPlace()).device; @@ -43,30 +44,28 @@ class NCCLBroadcastOpKernel : public framework::OpKernel { auto in = ctx.Input("X"); auto out = ctx.Output("Out"); out->Resize(in->dims()); + void* recv_buffer = out->mutable_data(ctx.GetPlace()); + const void* send_buffer = in->data(); - const int in_dev_id = boost::get(in->place()).device; + int in_dev_id = boost::get(in->place()).device; PADDLE_ENFORCE_EQ(dev_id, in_dev_id); auto& dev_ctx = ctx.template device_context(); auto comm = dev_ctx.nccl_comm(); auto stream = dev_ctx.stream(); - void* send_recv_buffer = const_cast(in->data()); - if (root_dev_id != in_dev_id) { - send_recv_buffer = out->mutable_data(ctx.GetPlace()); - } + PADDLE_ENFORCE(platform::dynload::ncclBroadcast( + send_buffer, recv_buffer, static_cast(in->numel()), + platform::ToNCCLDataType(in->type()), root_dev_id, comm, stream)); - VLOG(3) << "Bcast " << ctx.Inputs("X")[0] << ", (" - << static_cast(in->numel()) << ")" + VLOG(3) << "Bcast " << ctx.Inputs("X")[0] << ", (" << in->numel() << ")" << " From " << root_dev_id << " to " << in_dev_id; - PADDLE_ENFORCE(platform::dynload::ncclBcast( - send_recv_buffer, static_cast(in->numel()), - platform::ToNCCLDataType(in->type()), root_dev_id, comm, stream)); - if (ctx.Attr("sync_mode")) { PADDLE_ENFORCE(cudaStreamSynchronize(stream)); } +#else + PADDLE_THROW("PaddlePaddle should compile with GPU."); #endif } }; diff --git a/paddle/fluid/platform/dynload/nccl.h b/paddle/fluid/platform/dynload/nccl.h index 331ca9908e126d..24d31d942e5508 100644 --- a/paddle/fluid/platform/dynload/nccl.h +++ b/paddle/fluid/platform/dynload/nccl.h @@ -62,6 +62,7 @@ extern void* nccl_dso_handle; __macro(ncclCommUserRank); \ __macro(ncclAllReduce); \ __macro(ncclBcast); \ + __macro(ncclBroadcast); \ __macro(ncclAllGather); \ __macro(ncclGroupStart); \ __macro(ncclGroupEnd); \ From 4b43823084ad243cba68802646fb848697414611 Mon Sep 17 00:00:00 2001 From: chengduozh Date: Thu, 23 May 2019 
16:44:00 +0800
Subject: [PATCH 7/7] remove ncclBroadcast test=develop

---
 .../distributed_ops/broadcast_op.cu.cc        | 19 ++++++++++---------
 paddle/fluid/platform/dynload/nccl.h          |  1 -
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc b/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc
index 4d307c87efe0f1..c9b40e6863f444 100644
--- a/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc
+++ b/paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc
@@ -43,23 +43,24 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
 
     auto in = ctx.Input<framework::Tensor>("X");
     auto out = ctx.Output<framework::Tensor>("Out");
-    out->Resize(in->dims());
-    void* recv_buffer = out->mutable_data<T>(ctx.GetPlace());
-    const void* send_buffer = in->data<void>();
-
-    int in_dev_id = boost::get<platform::CUDAPlace>(in->place()).device;
-    PADDLE_ENFORCE_EQ(dev_id, in_dev_id);
+    PADDLE_ENFORCE(out->IsInitialized(),
+                   "Currently, the output of broadcast op must be initialized, "
+                   "because this op can only be an In-Place operation.");
+    void* send_recv_buffer = out->mutable_data<T>(ctx.GetPlace());
+    PADDLE_ENFORCE_EQ(
+        send_recv_buffer, in->data<void>(),
+        "Currently, the broadcast op can only be an In-Place operation.");
 
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     auto comm = dev_ctx.nccl_comm();
     auto stream = dev_ctx.stream();
 
-    PADDLE_ENFORCE(platform::dynload::ncclBroadcast(
-        send_buffer, recv_buffer, static_cast<size_t>(in->numel()),
+    PADDLE_ENFORCE(platform::dynload::ncclBcast(
+        send_recv_buffer, static_cast<size_t>(in->numel()),
         platform::ToNCCLDataType(in->type()), root_dev_id, comm, stream));
 
     VLOG(3) << "Bcast " << ctx.Inputs("X")[0] << ", (" << in->numel() << ")"
-            << " From " << root_dev_id << " to " << in_dev_id;
+            << " From " << root_dev_id << " to " << dev_id;
 
     if (ctx.Attr<bool>("sync_mode")) {
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
diff --git a/paddle/fluid/platform/dynload/nccl.h b/paddle/fluid/platform/dynload/nccl.h
index 24d31d942e5508..331ca9908e126d 100644
--- a/paddle/fluid/platform/dynload/nccl.h
+++ b/paddle/fluid/platform/dynload/nccl.h
@@ -62,7 +62,6 @@ extern void* nccl_dso_handle;
   __macro(ncclCommUserRank);  \
   __macro(ncclAllReduce);     \
   __macro(ncclBcast);         \
-  __macro(ncclBroadcast);     \
   __macro(ncclAllGather);     \
   __macro(ncclGroupStart);    \
   __macro(ncclGroupEnd);      \
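
Note on the net behavior of this series: after patches 4-7, DataParallel.__init__ no longer broadcasts parameters itself; each Layer broadcasts its freshly built trainable parameters from rank 0 (via parallel_helper._broadcast_parameters) on its first __call__ when a parallel context is active, and gradients are still aggregated explicitly with apply_collective_grads. Below is a minimal usage sketch, assuming two GPUs started with `python -m paddle.distributed.launch --gpus 2 train.py` (the script name is illustrative); it mirrors the FC example from the DataParallel docstring:

    import numpy as np
    import paddle.fluid as fluid
    import paddle.fluid.dygraph as dygraph
    from paddle.fluid.dygraph.nn import FC
    from paddle.fluid.dygraph.base import to_variable

    place = fluid.CUDAPlace(0)
    with fluid.dygraph.guard(place=place):
        # Build the NCCL parallel context from the PADDLE_* environment
        # variables set by the launcher (returns early when nranks < 2).
        strategy = dygraph.parallel.prepare_context()

        fc_layer = FC("FC", 10, act="softmax")
        adam = fluid.optimizer.AdamOptimizer()
        fc_layer = dygraph.parallel.DataParallel(fc_layer, strategy)

        x_data = np.random.random(size=[10, 1]).astype(np.float32)
        data = to_variable(x_data)

        # First call: Layer.__call__ runs build_once and then broadcasts the
        # newly created parameters from rank 0 before executing forward.
        hidden = fc_layer(data)
        avg_loss = fluid.layers.mean(hidden)

        # Scale the loss by the number of trainers, backprop, then allreduce
        # the gradients across trainers before the optimizer step.
        avg_loss = fc_layer.scale_loss(avg_loss)
        avg_loss.backward()
        fc_layer.apply_collective_grads()

        adam.minimize(avg_loss)
        fc_layer.clear_gradients()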
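
Patch 7 also pins down the op's contract: the broadcast is strictly in-place. The kernel requires Out to already be initialized and its buffer to alias X's data, which matches the Python wrapper _broadcast in layers/collective.py that feeds the same Variable as both X and Out. Continuing the sketch above, the synchronization that _broadcast_parameters performs on first call can also be issued by hand; this is a sketch only, assuming the parallel context from prepare_context has been initialized:

    from paddle.fluid.layers import collective

    # Push rank 0's parameter values into every other rank's buffers.
    # The op runs in place: X and Out are the same Variable, and the CUDA
    # kernel checks that Out's buffer aliases X's data.
    for param in fc_layer.parameters():
        if param.trainable:
            collective._broadcast(param, 0, sync_mode=True)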