Add broadcast operators #17503

Merged
76 changes: 76 additions & 0 deletions paddle/fluid/operators/distributed_ops/broadcast_op.cc
@@ -0,0 +1,76 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <ostream>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class BroadcastOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of BroadcastOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Output) of ConvOp should not be null.");
}
};

class BroadcastOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor), tensor to be broadcast.");
AddOutput("Out", "(Tensor) the result of broadcast.");
AddAttr<bool>(
"sync_mode",
"(bool) whether to synchronize the CUDA stream after nccl call.")
.SetDefault(false);
AddAttr<int>("root", "(int).").SetDefault(0).EqualGreaterThan(0);
AddComment(R"DOC(
***Broadcast Operator***

Call NCCL Broadcast internally. Note that this op can only be used when
each thread manages exactly one GPU device.
)DOC");
}
};

template <typename T>
class BroadcastOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW("Broadcast op can run on gpu place only for now.");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(broadcast, ops::BroadcastOp,
ops::BroadcastOpMaker);

REGISTER_OP_CPU_KERNEL(broadcast, ops::BroadcastOpKernel<float>,
ops::BroadcastOpKernel<double>,
ops::BroadcastOpKernel<int>,
ops::BroadcastOpKernel<int64_t>,
ops::BroadcastOpKernel<plat::float16>);
81 changes: 81 additions & 0 deletions paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc
@@ -0,0 +1,81 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace ops = paddle::operators;
namespace plat = paddle::platform;

namespace paddle {
namespace operators {

template <typename T>
class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"The place of ExecutionContext should be CUDAPlace.");

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
int dev_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).device;
int root_dev_id = ctx.Attr<int>("root");

auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
PADDLE_ENFORCE(out->IsInitialized(),
"Currently, the output of broadcast op must be initialized, "
"because this op can only be an In-Place operation.");
void* send_recv_buffer = out->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_EQ(
send_recv_buffer, in->data<void>(),
"Currently, the broadcast op can only be an In-Place operation.");

auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto comm = dev_ctx.nccl_comm();
auto stream = dev_ctx.stream();

PADDLE_ENFORCE(platform::dynload::ncclBcast(
send_recv_buffer, static_cast<size_t>(in->numel()),
platform::ToNCCLDataType(in->type()), root_dev_id, comm, stream));

VLOG(3) << "Bcast " << ctx.Inputs("X")[0] << ", (" << in->numel() << ")"
<< " From " << root_dev_id << " to " << dev_id;

if (ctx.Attr<bool>("sync_mode")) {
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
}
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};

} // namespace operators
} // namespace paddle

REGISTER_OP_CUDA_KERNEL(broadcast, ops::NCCLBroadcastOpKernel<float>,
ops::NCCLBroadcastOpKernel<double>,
ops::NCCLBroadcastOpKernel<int>,
ops::NCCLBroadcastOpKernel<int64_t>,
ops::NCCLBroadcastOpKernel<plat::float16>);
3 changes: 3 additions & 0 deletions python/paddle/fluid/dygraph/layers.py
@@ -18,6 +18,7 @@
import numpy as np
import collections
import six
from . import parallel_helper
from .. import unique_name
from paddle.fluid import core
from .layer_object_helper import LayerObjectHelper
@@ -154,6 +155,8 @@ def build_once(self, *args):
def __call__(self, *inputs):
if not self._built:
self.build_once(*inputs)
if parallel_helper._is_data_parallel_mode():
parallel_helper._broadcast_parameters(self._parameters.values())

outputs = self.forward(*inputs)
self._built = True
110 changes: 95 additions & 15 deletions python/paddle/fluid/dygraph/parallel.py
@@ -17,33 +17,38 @@

from .. import core
from . import layers
from . import parallel_helper
from .. import framework

from ..layers import collective
from . import to_variable

__all__ = ["prepare_context"]

ParallelStrategy = core.ParallelStrategy

__parallel_ctx__clz__ = None


def prepare_context(parallel_strategy):
global __parallel_ctx__clz__
assert __parallel_ctx__clz__ is None, "ParallelContext can only be initialized once."
assert framework.in_dygraph_mode(
) is True, "dygraph.parallel.prepare_context should be used with dygraph mode."
def prepare_context(strategy=None):
if strategy is None:
strategy = ParallelStrategy()
strategy.nranks = Env().nranks
strategy.local_rank = Env().local_rank
strategy.trainer_endpoints = Env().trainer_endpoints
strategy.current_endpoint = Env().current_endpoint
if strategy.nranks < 2:
return
assert framework.in_dygraph_mode() is True,\
"dygraph.parallel.prepare_context should be used with dygrahp mode."
place = framework._current_expected_place()
assert place is not None, "dygraph.parallel.prepare_context should be used in fluid.dygraph.guard(place) guard."

assert place is not None, \
"dygraph.parallel.prepare_context should be used in fluid.dygraph.guard(place) guard."
if isinstance(place, core.CUDAPlace):
__parallel_ctx__clz__ = core.NCCLParallelContext(parallel_strategy,
place)
parallel_helper._set_parallel_ctx(
core.NCCLParallelContext(strategy, place))
else:
# TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
assert ("Only support CUDAPlace for now.")
__parallel_ctx__clz__.init()
parallel_helper._init_parallel_ctx()
return strategy
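
A minimal usage sketch, assuming the launcher (e.g. paddle.distributed.launch) has populated the PADDLE_* environment variables that Env() reads; nothing here beyond prepare_context itself is new API.

import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph

place = fluid.CUDAPlace(0)  # one process drives exactly one GPU
with fluid.dygraph.guard(place):
    # Builds a default ParallelStrategy from Env(); returns None when
    # nranks < 2, in which case training proceeds single-process.
    strategy = dygraph.parallel.prepare_context()
    if strategy is not None:
        print("data parallel over %d trainers" % strategy.nranks)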


class Env(object):
@@ -77,33 +82,108 @@ def trainer_endpoints(self):


class DataParallel(layers.Layer):
"""
Runs the module with data parallelism.

Currently, DataParallel only supports running the dynamic graph
with multiple processes. The usage is:
`python -m paddle.distributed.launch --gpus 2 dynamic_graph_test.py`,
where `dynamic_graph_test.py` contains the code in the example below.

Examples:
.. code-block:: python

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.dygraph.nn import FC
from paddle.fluid.dygraph.base import to_variable

place = fluid.CUDAPlace(0)
with fluid.dygraph.guard(place=place):

# prepare the data parallel context
strategy = dygraph.parallel.prepare_context()

fc_layer = FC("FC", 10, act="softmax")
adam = AdamOptimizer()

# make the module become the data parallelism module
fc_layer = dygraph.parallel.DataParallel(fc_layer, strategy)

x_data = np.random.random(size=[10, 1]).astype(np.float32)
data = to_variable(x_data)

hidden = fc_layer(data)
avg_loss = fluid.layers.mean(hidden)

# scale the loss according to the number of trainers.
avg_loss = fc_layer.scale_loss(avg_loss)

avg_loss.backward()

# collect the gradients of trainers.
fc_layer.apply_collective_grads()

adam.minimize(avg_loss)
fc_layer.clear_gradients()

Args:
layers(Layer): The module that should be executed by data parallel.
strategy(ParallelStrategy): The strategy of data parallelism.

Returns:
Layer: The data-parallelized module.
"""

def __init__(self, layers, strategy):
super(DataParallel,
self).__init__(layers.full_name() + "_data_parallel")

self._layers = layers
self._strategy = strategy

def forward(self, *inputs, **kwargs):
return self._layers(*inputs, **kwargs)

def scale_loss(self, loss):
if self._strategy.nranks < 2:
"""
Scale the loss. In data parallel mode, the loss should be scaled by
the number of trainers. If not in data parallel mode, return the loss
directly.

Args:
loss(Variable): The loss of the current model.

Returns:
Variable: the scaled loss.
"""
if not self._is_data_parallel_mode():
return loss

loss_scale = to_variable(
np.array([self._strategy.nranks]).astype("float32"))
loss_scale.stop_gradient = True
loss = loss / loss_scale
return loss

def apply_collective_grads(self):
if self._strategy.nranks < 2:
"""
AllReduce the parameters' gradients.
"""
if not self._is_data_parallel_mode():
return

for param in self._layers.parameters():
# NOTE(zcd): The grad_ivar may not have been generated.
if param.trainable and param._ivar._grad_ivar():
g_var = framework.Variable(
block=self._helper.main_program.current_block(),
name=param._ivar._grad_name(),
stop_gradient=True,
ivar=param._ivar._grad_ivar())
collective._allreduce(g_var, g_var, sync_mode=True)

def _is_data_parallel_mode(self):
return self._strategy.nranks > 1
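
Why scale_loss divides before backward while apply_collective_grads sums: dividing each trainer's loss by nranks and then sum-allreducing the gradients is equivalent to averaging the gradients across trainers. A tiny numeric illustration (plain Python, no Paddle required):

nranks = 2
local_grads = [0.5, 0.25]                       # d(loss_i)/dw on each trainer
scaled = [g / nranks for g in local_grads]      # effect of scale_loss on backward
allreduced = sum(scaled)                        # effect of the sum-allreduce
assert allreduced == sum(local_grads) / nranks  # equals the averaged gradient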
43 changes: 43 additions & 0 deletions python/paddle/fluid/dygraph/parallel_helper.py
@@ -0,0 +1,43 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from ..layers import collective

__parallel_ctx__clz__ = None


def _is_data_parallel_mode():
global __parallel_ctx__clz__
return __parallel_ctx__clz__ is not None and int(
os.getenv("PADDLE_TRAINERS_NUM", "1")) > 1


def _set_parallel_ctx(nccl_parallel_context):
global __parallel_ctx__clz__
assert __parallel_ctx__clz__ is None, \
"ParallelContext can only be initialized once."
__parallel_ctx__clz__ = nccl_parallel_context


def _init_parallel_ctx():
global __parallel_ctx__clz__
assert __parallel_ctx__clz__ is not None, \
"ParallelContext should be initialized."
__parallel_ctx__clz__.init()


def _broadcast_parameters(parameters):
for param in parameters:
if param.trainable:
collective._broadcast(param, 0, sync_mode=True)
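
To see how these helpers fit together, a hedged sketch of the intended lifecycle; normally prepare_context and Layer.__call__ drive these calls, so user code would not write this directly.

import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.dygraph import parallel_helper

strategy = core.ParallelStrategy()  # normally filled in from Env()
place = fluid.CUDAPlace(0)
parallel_helper._set_parallel_ctx(core.NCCLParallelContext(strategy, place))
parallel_helper._init_parallel_ctx()  # calls NCCLParallelContext.init()

# Layer.__call__ (see layers.py above) then replicates trainer 0's weights
# before the first forward pass; `layer` stands in for any dygraph Layer:
if parallel_helper._is_data_parallel_mode():
    parallel_helper._broadcast_parameters(layer._parameters.values())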
11 changes: 11 additions & 0 deletions python/paddle/fluid/layers/collective.py
@@ -46,3 +46,14 @@ def _allreduce(x, out=None, reduce_type="sum", sync_mode=False):
attrs={"reduce_type": red_typ_int,
"sync_mode": sync_mode})
return out


def _broadcast(x, root, sync_mode=False):
helper = LayerHelper("broadcast", **locals())
helper.append_op(
type='broadcast',
inputs={'X': [x]},
outputs={'Out': [x]},
attrs={"sync_mode": sync_mode,
"root": root})
return x
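
A short usage sketch of the new wrapper: an in-place broadcast from trainer 0, mirroring what parallel_helper._broadcast_parameters does above; `model` stands in for any dygraph Layer.

from paddle.fluid.layers import collective

for param in model.parameters():
    if param.trainable:
        # Overwrites each trainer's value with trainer 0's copy and
        # blocks until the NCCL broadcast has finished.
        collective._broadcast(param, 0, sync_mode=True)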