
Commit b5f4d5e

Author: chengduo
Add broadcast operators (#17503)
* This PR adds a broadcast op for multi-process training. It can be used in dynamic graph mode to broadcast parameters.
1 parent 2280f18 commit b5f4d5e
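
Below is a minimal usage sketch (not part of this commit) of how the new broadcast path is exercised in dynamic graph mode. It assumes the script is started with `python -m paddle.distributed.launch --gpus 2 ...` so that more than one trainer process exists (otherwise `prepare_context` returns early), and it reuses the `FC` layer from the `DataParallel` docstring example further down; the device index 0 is only illustrative.

    import numpy as np
    import paddle.fluid as fluid
    import paddle.fluid.dygraph as dygraph
    from paddle.fluid.dygraph.nn import FC
    from paddle.fluid.dygraph.base import to_variable

    place = fluid.CUDAPlace(0)  # each launched process guards its own GPU
    with fluid.dygraph.guard(place=place):
        # Set up the NCCL parallel context for this trainer.
        strategy = dygraph.parallel.prepare_context()

        fc_layer = FC("FC", 10, act="softmax")
        fc_layer = dygraph.parallel.DataParallel(fc_layer, strategy)

        x = to_variable(np.random.random(size=[10, 1]).astype(np.float32))
        # The first forward call triggers parallel_helper._broadcast_parameters(),
        # which broadcasts every trainable parameter from root rank 0 through the
        # new `broadcast` op, so all trainers start from identical weights.
        out = fc_layer(x)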

6 files changed: +309 −15 lines changed
@@ -0,0 +1,76 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <ostream>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class BroadcastOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of BroadcastOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of BroadcastOp should not be null.");
  }
};

class BroadcastOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "(Tensor), tensor to be broadcast.");
    AddOutput("Out", "(Tensor) the result of broadcast.");
    AddAttr<bool>(
        "sync_mode",
        "(bool) whether to synchronize the CUDA stream after nccl call.")
        .SetDefault(false);
    AddAttr<int>("root", "(int) root id of the broadcast.")
        .SetDefault(0)
        .EqualGreaterThan(0);
    AddComment(R"DOC(
***Broadcast Operator***

Call NCCL Broadcast internally. Note that this op must be used when one
thread is managing one GPU device.
)DOC");
  }
};

template <typename T>
class BroadcastOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_THROW("Broadcast op can run on gpu place only for now.");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(broadcast, ops::BroadcastOp,
                             ops::BroadcastOpMaker);

REGISTER_OP_CPU_KERNEL(broadcast, ops::BroadcastOpKernel<float>,
                       ops::BroadcastOpKernel<double>,
                       ops::BroadcastOpKernel<int>,
                       ops::BroadcastOpKernel<int64_t>,
                       ops::BroadcastOpKernel<plat::float16>);
@@ -0,0 +1,81 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace ops = paddle::operators;
namespace plat = paddle::platform;

namespace paddle {
namespace operators {

template <typename T>
class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "The place of ExecutionContext should be CUDAPlace.");

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
    int dev_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).device;
    int root_dev_id = ctx.Attr<int>("root");

    auto in = ctx.Input<framework::Tensor>("X");
    auto out = ctx.Output<framework::Tensor>("Out");
    PADDLE_ENFORCE(out->IsInitialized(),
                   "Currently, the output of broadcast op must be initialized, "
                   "because this op can only be an In-Place operation.");
    void* send_recv_buffer = out->mutable_data<T>(ctx.GetPlace());
    PADDLE_ENFORCE_EQ(
        send_recv_buffer, in->data<void>(),
        "Currently, the broadcast op can only be an In-Place operation.");

    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto comm = dev_ctx.nccl_comm();
    auto stream = dev_ctx.stream();

    PADDLE_ENFORCE(platform::dynload::ncclBcast(
        send_recv_buffer, static_cast<size_t>(in->numel()),
        platform::ToNCCLDataType(in->type()), root_dev_id, comm, stream));

    VLOG(3) << "Bcast " << ctx.Inputs("X")[0] << ", (" << in->numel() << ")"
            << " From " << root_dev_id << " to " << dev_id;

    if (ctx.Attr<bool>("sync_mode")) {
      PADDLE_ENFORCE(cudaStreamSynchronize(stream));
    }
#else
    PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OP_CUDA_KERNEL(broadcast, ops::NCCLBroadcastOpKernel<float>,
                        ops::NCCLBroadcastOpKernel<double>,
                        ops::NCCLBroadcastOpKernel<int>,
                        ops::NCCLBroadcastOpKernel<int64_t>,
                        ops::NCCLBroadcastOpKernel<plat::float16>);

python/paddle/fluid/dygraph/layers.py

+3
@@ -18,6 +18,7 @@
 import numpy as np
 import collections
 import six
+from . import parallel_helper
 from .. import unique_name
 from paddle.fluid import core
 from .layer_object_helper import LayerObjectHelper
@@ -154,6 +155,8 @@ def build_once(self, *args):
     def __call__(self, *inputs):
         if not self._built:
             self.build_once(*inputs)
+            if parallel_helper._is_data_parallel_mode():
+                parallel_helper._broadcast_parameters(self._parameters.values())
 
         outputs = self.forward(*inputs)
         self._built = True

python/paddle/fluid/dygraph/parallel.py

+95 −15
@@ -17,33 +17,38 @@
 
 from .. import core
 from . import layers
+from . import parallel_helper
 from .. import framework
-
 from ..layers import collective
 from . import to_variable
 
 __all__ = ["prepare_context"]
 
 ParallelStrategy = core.ParallelStrategy
 
-__parallel_ctx__clz__ = None
-
 
-def prepare_context(parallel_strategy):
-    global __parallel_ctx__clz__
-    assert __parallel_ctx__clz__ is None, "ParallelContext can only be initialized once."
-    assert framework.in_dygraph_mode(
-    ) is True, "dygraph.parallel.prepare_context should be used with dygrahp mode."
+def prepare_context(strategy=None):
+    if strategy is None:
+        strategy = ParallelStrategy()
+        strategy.nranks = Env().nranks
+        strategy.local_rank = Env().local_rank
+        strategy.trainer_endpoints = Env().trainer_endpoints
+        strategy.current_endpoint = Env().current_endpoint
+    if strategy.nranks < 2:
+        return
+    assert framework.in_dygraph_mode() is True,\
+        "dygraph.parallel.prepare_context should be used with dygraph mode."
     place = framework._current_expected_place()
-    assert place is not None, "dygraph.parallel.prepare_context should be used in fluid.dygraph.guard(place) guard."
-
+    assert place is not None, \
+        "dygraph.parallel.prepare_context should be used in fluid.dygraph.guard(place) guard."
     if isinstance(place, core.CUDAPlace):
-        __parallel_ctx__clz__ = core.NCCLParallelContext(parallel_strategy,
-                                                         place)
+        parallel_helper._set_parallel_ctx(
+            core.NCCLParallelContext(strategy, place))
     else:
         # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
         assert ("Only support CUDAPlace for now.")
-    __parallel_ctx__clz__.init()
+    parallel_helper._init_parallel_ctx()
+    return strategy
 
 
 class Env(object):
@@ -77,33 +82,108 @@ def trainer_endpoints(self):
 
 
 class DataParallel(layers.Layer):
+    """
+    Runs the module with data parallelism.
+
+    Currently, DataParallel only supports running a dynamic graph with
+    multiple processes. The usage is:
+    `python -m paddle.distributed.launch --gpus 2 dynamic_graph_test.py`,
+    where the content of `dynamic_graph_test.py` is the example code below.
+
+    Examples:
+        .. code-block:: python
+
+            import numpy as np
+            import paddle.fluid as fluid
+            import paddle.fluid.dygraph as dygraph
+            from paddle.fluid.optimizer import AdamOptimizer
+            from paddle.fluid.dygraph.nn import FC
+            from paddle.fluid.dygraph.base import to_variable
+
+            place = fluid.CUDAPlace(0)
+            with fluid.dygraph.guard(place=place):
+
+                # prepare the data parallel context
+                strategy = dygraph.parallel.prepare_context()
+
+                fc_layer = FC("FC", 10, act="softmax")
+                adam = fluid.optimizer.AdamOptimizer()
+
+                # wrap the module to make it a data parallel module
+                fc_layer = dygraph.parallel.DataParallel(fc_layer, strategy)
+
+                x_data = np.random.random(size=[10, 1]).astype(np.float32)
+                data = to_variable(x_data)
+
+                hidden = fc_layer(data)
+                avg_loss = fluid.layers.mean(hidden)
+
+                # scale the loss by the number of trainers
+                avg_loss = fc_layer.scale_loss(avg_loss)
+
+                avg_loss.backward()
+
+                # collect the gradients across trainers
+                fc_layer.apply_collective_grads()
+
+                adam.minimize(avg_loss)
+                fc_layer.clear_gradients()
+
+    Args:
+        layers(Layer): The module that should be executed in data parallel.
+        strategy(ParallelStrategy): The strategy of data parallelism.
+
+    Returns:
+        Layer: The data parallel module.
+    """
+
     def __init__(self, layers, strategy):
         super(DataParallel,
               self).__init__(layers.full_name() + "_data_parallel")
+
         self._layers = layers
         self._strategy = strategy
 
     def forward(self, *inputs, **kwargs):
         return self._layers(*inputs, **kwargs)
 
     def scale_loss(self, loss):
-        if self._strategy.nranks < 2:
+        """
+        Scale the loss. In data parallel mode, the loss should be scaled by
+        the number of trainers. If not in data parallel mode, return the loss
+        directly.
+
+        Args:
+            loss(Layer): The loss of the current model.
+
+        Returns:
+            Layer: the scaled loss.
+        """
+        if not self._is_data_parallel_mode():
             return loss
+
         loss_scale = to_variable(
             np.array([self._strategy.nranks]).astype("float32"))
         loss_scale.stop_gradient = True
         loss = loss / loss_scale
         return loss
 
     def apply_collective_grads(self):
-        if self._strategy.nranks < 2:
+        """
+        AllReduce the parameters' gradients.
+        """
+        if not self._is_data_parallel_mode():
             return
 
         for param in self._layers.parameters():
+            # NOTE(zcd): The grad_ivar may not be generated.
             if param.trainable and param._ivar._grad_ivar():
                 g_var = framework.Variable(
                     block=self._helper.main_program.current_block(),
                     name=param._ivar._grad_name(),
                     stop_gradient=True,
                     ivar=param._ivar._grad_ivar())
                 collective._allreduce(g_var, g_var, sync_mode=True)
+
+    def _is_data_parallel_mode(self):
+        return self._strategy.nranks > 1
@@ -0,0 +1,43 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from ..layers import collective

__parallel_ctx__clz__ = None


def _is_data_parallel_mode():
    global __parallel_ctx__clz__
    return __parallel_ctx__clz__ is not None and int(
        os.getenv("PADDLE_TRAINERS_NUM", "1")) > 1


def _set_parallel_ctx(nccl_parallel_context):
    global __parallel_ctx__clz__
    assert __parallel_ctx__clz__ is None, \
        "ParallelContext can only be initialized once."
    __parallel_ctx__clz__ = nccl_parallel_context


def _init_parallel_ctx():
    global __parallel_ctx__clz__
    assert __parallel_ctx__clz__ is not None, \
        "ParallelContext should be initialized."
    __parallel_ctx__clz__.init()


def _broadcast_parameters(parameters):
    for param in parameters:
        if param.trainable:
            collective._broadcast(param, 0, sync_mode=True)

python/paddle/fluid/layers/collective.py

+11
@@ -46,3 +46,14 @@ def _allreduce(x, out=None, reduce_type="sum", sync_mode=False):
         attrs={"reduce_type": red_typ_int,
                "sync_mode": sync_mode})
     return out
+
+
+def _broadcast(x, root, sync_mode=False):
+    helper = LayerHelper("broadcast", **locals())
+    helper.append_op(
+        type='broadcast',
+        inputs={'X': [x]},
+        outputs={'Out': [x]},
+        attrs={"sync_mode": sync_mode,
+               "root": root})
+    return x
