Commit 1aedf4b

Author: chengduozh

Commit message:
    code refine
    test=develop

Parent: 2b3b07c

File tree

5 files changed: +53 -22 lines


paddle/fluid/operators/distributed_ops/allreduce_op.h

Lines changed: 0 additions & 3 deletions
@@ -40,7 +40,6 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
     auto in = ctx.Input<framework::Tensor>("X");
     auto out = ctx.Output<framework::Tensor>("Out");
 
-    int in_dev_id = boost::get<platform::CUDAPlace>(in->place()).device;
     int dtype = platform::ToNCCLDataType(in->type());
     int64_t numel = in->numel();
     auto* sendbuff = in->data<void>();
@@ -68,8 +67,6 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
         red_type = ncclMin;
         break;
     }
-    VLOG(3) << "AllReduce " << ctx.Inputs("X")[0] << " On " << in_dev_id;
-
     PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
         sendbuff, recvbuff, numel, static_cast<ncclDataType_t>(dtype), red_type,
         comm, stream));

paddle/fluid/platform/nccl_helper.h

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ struct NCCLContextMap {
     PADDLE_ENFORCE_EQ(
         order_.size(), contexts_.size(),
         "NCCL Context Map does not support contain two or more same device");
-    // NOTE(paddle-dev): Why use std::unique_ptr and the T is ncclComm_t[] here?
+
     std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
     // if num_trainers == 1, should create a new nccl id for local comms.
     if (num_trainers == 1 && nccl_id == nullptr) {

python/paddle/fluid/dygraph/layers.py

Lines changed: 1 addition & 1 deletion
@@ -156,7 +156,7 @@ def __call__(self, *inputs):
         if not self._built:
             self.build_once(*inputs)
             if parallel_helper._is_data_parallel_mode():
-                parallel_helper._broadcast_parameters(self.parameters())
+                parallel_helper._broadcast_parameters(self._parameters.values())
 
         outputs = self.forward(*inputs)
         self._built = True

python/paddle/fluid/dygraph/parallel.py

Lines changed: 49 additions & 16 deletions
@@ -52,9 +52,6 @@ def prepare_context(strategy=None):
 
 
 class Env(object):
-    """
-    """
-
     def __init__(self):
         self._nranks = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
         self._local_rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
@@ -86,25 +83,58 @@ def trainer_endpoints(self):
 
 class DataParallel(layers.Layer):
     """
-    DataParallel.
+    Runs the module with data parallelism.
+
+    Currently, DataParallel only supports to run the dynamic graph
+    with multi-process. The usage is:
+    `python -m paddle.distributed.launch --gpus 2 dynamic_graph_test.py`.
+    And the content of `dynamic_graph_test.py` is the code of examples.
 
     Examples:
         .. code-block:: python
 
-           import paddle.fluid as fluid
-           import numpy
-           import os
-           ...
+           import numpy as np
+           import paddle.fluid as fluid
+           import paddle.fluid.dygraph as dygraph
+           from paddle.fluid.optimizer import AdamOptimizer
+           from paddle.fluid.dygraph.nn import FC
+           from paddle.fluid.dygraph.base import to_variable
+
+           place = fluid.CUDAPlace(0)
+           with fluid.dygraph.guard(place=place):
+
+               # prepare the data parallel context
+               strategy=dygraph.parallel.prepare_context()
+
+               fc_layer = FC("FC", 10, act="softmax")
+               adam = fluid.optimizer.AdamOptimizer()
+
+               # make the module become the data parallelism module
+               fc_layer = dygraph.parallel.DataParallel(fc_layer, strategy)
+
+               x_data = np.random.random(size=[10, 1]).astype(np.float32)
+               data = to_variable(x_data)
+
+               hidden = fc_layer(data)
+               avg_loss = fluid.layers.mean(hidden)
+
+               # scale the loss according to the number of trainers.
+               avg_loss = fc_layer.scale_loss(avg_loss)
+
+               avg_loss.backward()
+
+               # collect the gradients of trainers.
+               fc_layer.apply_collective_grads()
+
+               adam.minimize(avg_loss)
+               fc_layer.clear_gradients()
 
     Args:
-        layers(Layer): The layer.Layer.
-        strategy(ParallelStrategy): The dygraph.parallel.ParallelStrategy.
+        layers(Layer): The module that should be executed by data parallel.
+        strategy(ParallelStrategy): The strategy of data parallelism.
 
     Returns:
-        Layer: The layer.Layer..
-
-    Raises:
-        TypeError: If share_vars_from is provided, but not ParallelExecutor object.
+        Layer: The data paralleled module.
     """
 
     def __init__(self, layers, strategy):
@@ -119,12 +149,15 @@ def forward(self, *inputs, **kwargs):
 
     def scale_loss(self, loss):
         """
+        Scale the loss. In data parallel mode, the loss should be scale with
+        the number of trainers. If not in data parallel mode, return the loss
+        directly.
 
         Args:
-            loss(Layer): The layer.Layer.
+            loss(Layer): The loss of the current Model.
 
         Returns:
-            Layer: The layer.Layer.
+            Layer: the scaled loss.
         """
         if not self._is_data_parallel_mode():
             return loss
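
Note on the new `scale_loss` docstring above: scaling the loss by the trainer count is what keeps gradients accumulated across processes comparable to single-process training. A minimal, self-contained sketch of that pattern (plain Python, not Paddle's actual implementation; `nranks` stands in for the trainer count carried by the parallel strategy):

    def scale_loss_sketch(loss_value, nranks):
        # In data parallel mode each trainer computes a loss on its own data
        # shard; dividing by the trainer count keeps the combined gradient
        # magnitude in line with single-process training.
        if nranks <= 1:  # not in data parallel mode: return the loss unchanged
            return loss_value
        return loss_value / nranks

    print(scale_loss_sketch(2.0, 4))  # 0.5
    print(scale_loss_sketch(2.0, 1))  # 2.0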

python/paddle/fluid/dygraph/parallel_helper.py

Lines changed: 2 additions & 1 deletion
@@ -39,4 +39,5 @@ def _init_parallel_ctx():
 
 def _broadcast_parameters(parameters):
     for param in parameters:
-        collective._broadcast(param, 0, sync_mode=True)
+        if param.trainable:
+            collective._broadcast(param, 0, sync_mode=True)
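
Note on the `parallel_helper.py` change above: only parameters whose `trainable` flag is set are broadcast from rank 0. A self-contained sketch of that filtering pattern, using a stand-in parameter type and a fake broadcast function in place of Paddle's `collective._broadcast` (all names here are illustrative only):

    from collections import namedtuple

    # Stand-in for a dygraph parameter; only the `trainable` flag matters here.
    Param = namedtuple("Param", ["name", "trainable"])

    def broadcast_trainable(parameters, broadcast_fn, root=0):
        # Mirrors _broadcast_parameters: trainable parameters are synchronized
        # from the root rank; frozen (non-trainable) parameters are skipped.
        for param in parameters:
            if param.trainable:
                broadcast_fn(param, root)

    params = [Param("fc_0.w_0", True), Param("frozen_emb.w_0", False)]
    broadcast_trainable(params, lambda p, root: print("broadcast", p.name, "from rank", root))
    # Only "fc_0.w_0" is broadcast.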
