Commit a206f43

Replace LocalLayer with local_map
1 parent 938bc06 commit a206f43
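
In short: instead of subclassing paddle.distributed.LocalLayer and passing (mesh, placements) pairs for outputs and gradients, the tests now wrap a plain Python function with paddle.distributed.local_map. A minimal sketch of the new pattern follows; it is illustrative only (local_loss, loss_fn and the one-dimensional two-card mesh are example names, not taken from the diff), and the positional argument order mirrors the call sites in the diffs below: func, out_placements, in_placements, process_mesh.

import paddle
import paddle.distributed as dist

# Example mesh; the real tests build their own ProcessMesh objects.
mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])

def local_loss(x, mask):
    # Plain eager code: local_map hands it each rank's local shard of x and mask.
    return paddle.mean(paddle.masked_select(x, mask).astype("float32")).unsqueeze(0)

loss_fn = dist.local_map(
    local_loss,
    [[dist.Shard(0)]],        # out_placements: one placement list per output
    [[dist.Shard(0)], None],  # in_placements: one entry per input (None = leave as-is)
    mesh,                     # process_mesh the local views belong to
)
# Calling loss_fn(dist_x, dist_mask) only makes sense in a job launched on both
# ranks of the mesh (e.g. via python -m paddle.distributed.launch).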

3 files changed: 40 additions (+), 137 deletions (-)

test/auto_parallel/hybrid_strategy/single_llama_model.py

+10 -26

@@ -256,35 +256,19 @@ def forward(self, prediction_scores, masked_lm_labels):
             prediction_scores.astype("float32"),
             masked_lm_labels.unsqueeze(2),
         )
-        # XPU dose not support allgather mask with bool dtype, so we use LocalLayer here.
         if paddle.device.is_compiled_with_xpu():

-            class LocalLossLayer(paddle.distributed.LocalLayer):
-                def __init__(self, out_dist_attrs, grad_dist_attrs):
-                    super().__init__(out_dist_attrs, grad_dist_attrs)
-
-                def forward(self, x, mask):
-                    masked_lm_loss = paddle.masked_select(x, mask).astype(
-                        "float32"
-                    )
-                    loss = paddle.mean(masked_lm_loss).unsqueeze(0)
-                    return loss.unsqueeze(0)
-
-            out_dist_attrs = [
-                (
-                    masked_lm_loss.process_mesh,
-                    [dist.Shard(0), dist.Replicate()],
-                ),
-            ]
-            grad_dist_attrs = [
-                (
-                    masked_lm_loss.process_mesh,
-                    [dist.Shard(0), dist.Replicate()],
-                ),
-                None,
-            ]
-            loss_func = LocalLossLayer(out_dist_attrs, grad_dist_attrs)
+            def LocalLoss(x, mask):
+                masked_lm_loss = paddle.masked_select(x, mask).astype("float32")
+                loss = paddle.mean(masked_lm_loss).unsqueeze(0)
+                return loss.unsqueeze(0)

+            loss_func = dist.local_map(
+                LocalLoss,
+                [[dist.Shard(0), dist.Replicate()]],
+                [[dist.Shard(0), dist.Replicate()], None],
+                masked_lm_loss.process_mesh,
+            )
             loss = loss_func(masked_lm_loss, masked_lm_loss > 0)
             loss = loss.mean()
             return loss
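
Note: the masked_select path stays behind is_compiled_with_xpu() for the reason given in the comment removed above: XPU does not support allgather on a bool-dtype mask, so the select has to run on each rank's local shard, which the dist.local_map wrapper (like the old LocalLayer subclass) provides.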

test/auto_parallel/local_view_compute.py

+23 -100

@@ -12,15 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import hashlib
 import random

 import numpy as np

 import paddle
 import paddle.distributed as dist
 from paddle.distributed import ProcessMesh, fleet, get_rank, shard_dataloader
-from paddle.distributed.auto_parallel.local_layer import LocalLayer
 from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler

 base_lr = 0.01  # Learning rate
@@ -30,6 +28,7 @@
 batch_num = 100  # Number of batches per epoch
 batch_size = 32  # Batch size for training
 class_dim = 10
+global_local_loss_list = []


 class RandomDataset(paddle.io.Dataset):
@@ -65,7 +64,8 @@ def forward(self, x):
         return x


-def masked_lm_loss_func(pred, label):
+def masked_lm_loss_func(pred, label, global_local_loss_list_item=None):
+    """Custom loss function; the mask is chosen based on rank."""
     lossmask = paddle.zeros_like(label).astype('float32')
     if dist.get_rank() == 0:
         lossmask[:3] = 1
@@ -74,35 +74,23 @@ def masked_lm_loss_func(pred, label):

     pred_sub = pred[:, 0:1]  # shape [B,1]
     label_float = paddle.cast(label, 'float32')  # shape [B,1]
-
     raw_loss = paddle.abs(pred_sub - label_float)
-
     lossmask_ = lossmask.reshape([-1]).cast('float32')
     raw_loss_flat = raw_loss.reshape([-1]).cast('float32')

     masked_lm_loss_sum = paddle.sum(raw_loss_flat * lossmask_)
     valid_count = paddle.sum(lossmask_)

     loss = masked_lm_loss_sum / (valid_count + 1e-8)
+    if global_local_loss_list_item is not None:
+        np.testing.assert_allclose(
+            global_local_loss_list_item,
+            loss,
+            rtol=1e-8,
+        )
     return loss


-class LocalViewMaskLoss(LocalLayer):
-    def __init__(self, out_dist_attrs, grad_dist_attrs):
-        super().__init__(out_dist_attrs, grad_dist_attrs)
-        self.local_loss = None
-
-    def forward(self, pred, label):
-        loss = masked_lm_loss_func(pred, label)
-        self.local_loss = loss
-        return loss
-
-
-def get_md5(tensor):
-    tensor_numpy = tensor.cpu().numpy()
-    return hashlib.md5(tensor_numpy.tobytes()).hexdigest()
-
-
 class TestLocalViewCompute:
     def __init__(self):
         self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
@@ -125,23 +113,8 @@ def create_dataset(self):
     def run_test_cases(self):
         self.set_random_seed()
         dataset = self.create_dataset()
-        dy_hand_loss_list = self.run_dy_hand(dataset)
-        self.set_random_seed()
-        dataset = self.create_dataset()
-        dy_semi_auto_local_loss_list = self.run_dy_semi_auto(dataset)
-        self.set_random_seed()
-        dy2s_semi_auto_local_loss_list = self.run_dy2s_semi_auto(dataset)
-
-        np.testing.assert_allclose(
-            dy_hand_loss_list[-1], dy_semi_auto_local_loss_list[-1], rtol=1e-8
-        )
-        np.testing.assert_allclose(
-            dy_semi_auto_local_loss_list[-1],
-            dy2s_semi_auto_local_loss_list[-1],
-            rtol=1e-8,
-        )

-    def run_dy_hand(self, dataset):
+        # run_dy_hand_get_local_loss
         dist_strategy = fleet.DistributedStrategy()
         dist_strategy.hybrid_configs = {
             "dp_degree": 2,
@@ -182,21 +155,14 @@ def run_dy_hand(self, dataset):
             img, label = data

             out = model(img)
-            lossmask = paddle.zeros_like(label).astype('float32')
-            if dist.get_rank() == 0:
-                lossmask[:3] = 1
-            else:
-                lossmask[4:9] = 1

             avg_loss = masked_lm_loss_func(out, label)
             avg_loss.backward()
             optimizer.step()
             model.clear_gradients()
+            global_local_loss_list.append(avg_loss.numpy())

-            loss_list.append(avg_loss.numpy())
-        return loss_list
-
-    def run_dy_semi_auto(self, dataset):
+        # run_dy_semi_auto
         world_process_mesh = ProcessMesh([0, 1], dim_names=["dp"])
         model = SimpleNet(
             input_size=256, inner_size=102400, output_size=class_dim
@@ -219,73 +185,30 @@ def run_dy_semi_auto(self, dataset):
         )

         model.train()
-        out_process_mesh = ProcessMesh([0, 1], dim_names=["dp"])
+        process_mesh = ProcessMesh([0, 1], dim_names=["dp"])
         out_placements = [dist.Partial(dist.ReduceType.kRedAvg)]

-        local_loss_list = []
-
         for batch_id, data in enumerate(dist_dataloader()):
             if batch_id > 10:
                 break

             img, label = data

             out = model(img)
-            loss_func = LocalViewMaskLoss(
-                out_dist_attrs=[(out_process_mesh, out_placements)],
-                grad_dist_attrs=[None, None],
+            loss_func = dist.local_map(
+                masked_lm_loss_func,
+                out_placements=out_placements,
+                in_placements=[None, None],
+                process_mesh=process_mesh,
+            )
+            avg_loss = loss_func(
+                out,
+                label,
+                global_local_loss_list_item=global_local_loss_list[batch_id],
             )
-            avg_loss = loss_func(out, label)
             avg_loss.backward()
-            local_loss_list.append(loss_func.local_loss)
             optimizer.step()
             model.clear_gradients()
-        return local_loss_list
-
-    def run_dy2s_semi_auto(self, dataset):
-        world_process_mesh = ProcessMesh([0, 1], dim_names=["dp"])
-        model = SimpleNet(
-            input_size=256, inner_size=102400, output_size=class_dim
-        )
-        optimizer = paddle.optimizer.AdamW(
-            learning_rate=base_lr,
-            weight_decay=l2_decay,
-            parameters=model.parameters(),
-        )
-
-        sampler = BatchSampler(
-            dataset, batch_size=batch_size, shuffle=False, drop_last=True
-        )
-        train_loader = DataLoader(
-            dataset, batch_sampler=sampler, num_workers=1, shuffle=False
-        )
-
-        dist_dataloader = shard_dataloader(
-            dataloader=train_loader, meshes=world_process_mesh, shard_dims="dp"
-        )
-
-        process_mesh = ProcessMesh([0, 1], dim_names=["dp"])
-        out_placements = [dist.Partial(dist.ReduceType.kRedAvg)]
-        in_grad_placements = [dist.Shard(0)]
-        loss_func = LocalViewMaskLoss(
-            out_dist_attrs=[(process_mesh, out_placements)],
-            grad_dist_attrs=[(process_mesh, in_grad_placements), None],
-        )
-        dist_model = dist.to_static(
-            model, dist_dataloader, loss_func, optimizer
-        )
-        dist_model.train()
-
-        local_loss_list = []
-        for batch_id, data in enumerate(dist_dataloader()):
-            if batch_id > 10:
-                break
-
-            img, label = data
-            loss = dist_model(img, label)
-            local_loss_list.append(loss)
-
-        return local_loss_list


 if __name__ == '__main__':

test/auto_parallel/pir/vpp_pass_unittest_local_view_pir.py

+7 -11

@@ -53,15 +53,11 @@ def is_optimize_op(op):
     return False


-class CustomLayer(dist.LocalLayer):
-    def __init__(self, out_dist_attrs, grad_dist_attrs):
-        super().__init__(out_dist_attrs, grad_dist_attrs)
-
-    def forward(self, input):
-        input += 0.1
-        input -= 0.3
-        input *= 0.5
-        return input
+def customFunction(input):
+    input += 0.1
+    input -= 0.3
+    input *= 0.5
+    return input


 class MyLinear(nn.Layer):
@@ -95,8 +91,8 @@ def __init__(
             [dist.Replicate()],
             stop_gradient=False,
         )
-        self.custom_local_layer = CustomLayer(
-            [(mesh, [dist.Replicate()])], [(mesh, [dist.Replicate()])]
+        self.custom_local_layer = dist.local_map(
+            customFunction, [[dist.Replicate()]], [[dist.Replicate()]], mesh
         )

     def forward(self, input):
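
Usage note: local_view_compute.py above also shows the keyword form of the call, with a Partial placement so each rank's local loss is averaged across the mesh when the result is turned back into a distributed view. A minimal sketch under the same assumptions (two ranks launched with python -m paddle.distributed.launch; per_rank_loss, loss_fn and mesh are illustrative names, not from the diff):

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])

def per_rank_loss(pred, label):
    # Runs independently on each rank's local tensors.
    return paddle.mean(paddle.abs(pred - label.astype("float32")))

loss_fn = dist.local_map(
    per_rank_loss,
    out_placements=[dist.Partial(dist.ReduceType.kRedAvg)],  # average the per-rank losses
    in_placements=[None, None],  # no resharding requested for the two inputs
    process_mesh=mesh,
)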
