# See the License for the specific language governing permissions and
# limitations under the License.

- import hashlib
import random

import numpy as np

import paddle
import paddle.distributed as dist
from paddle.distributed import ProcessMesh, fleet, get_rank, shard_dataloader
- from paddle.distributed.auto_parallel.local_layer import LocalLayer
from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler

base_lr = 0.01  # Learning rate
batch_num = 100  # Number of batches per epoch
batch_size = 32  # Batch size for training
class_dim = 10
+ global_local_loss_list = []
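+ # Filled by the hand-written data-parallel run below: one local masked loss per batch, later checked against the dist.local_map run.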


class RandomDataset(paddle.io.Dataset):
@@ -65,7 +64,8 @@ def forward(self, x):
        return x


- def masked_lm_loss_func(pred, label):
+ def masked_lm_loss_func(pred, label, global_local_loss_list_item=None):
+     """Custom loss function that masks the loss according to the current rank."""
    lossmask = paddle.zeros_like(label).astype('float32')
    if dist.get_rank() == 0:
        lossmask[:3] = 1
@@ -74,35 +74,23 @@ def masked_lm_loss_func(pred, label):

    pred_sub = pred[:, 0:1]  # shape [B, 1]
    label_float = paddle.cast(label, 'float32')  # shape [B, 1]
-
    raw_loss = paddle.abs(pred_sub - label_float)
-
    lossmask_ = lossmask.reshape([-1]).cast('float32')
    raw_loss_flat = raw_loss.reshape([-1]).cast('float32')

    masked_lm_loss_sum = paddle.sum(raw_loss_flat * lossmask_)
    valid_count = paddle.sum(lossmask_)

    loss = masked_lm_loss_sum / (valid_count + 1e-8)
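+     # When called from the local_map run, check this rank's local loss against the value recorded by the hand-written DP run.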
+     if global_local_loss_list_item is not None:
+         np.testing.assert_allclose(
+             global_local_loss_list_item,
+             loss,
+             rtol=1e-8,
+         )
    return loss


- class LocalViewMaskLoss(LocalLayer):
-     def __init__(self, out_dist_attrs, grad_dist_attrs):
-         super().__init__(out_dist_attrs, grad_dist_attrs)
-         self.local_loss = None
-
-     def forward(self, pred, label):
-         loss = masked_lm_loss_func(pred, label)
-         self.local_loss = loss
-         return loss
-
-
- def get_md5(tensor):
-     tensor_numpy = tensor.cpu().numpy()
-     return hashlib.md5(tensor_numpy.tobytes()).hexdigest()
-
-
class TestLocalViewCompute:
    def __init__(self):
        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
@@ -125,23 +113,8 @@ def create_dataset(self):
    def run_test_cases(self):
        self.set_random_seed()
        dataset = self.create_dataset()
-         dy_hand_loss_list = self.run_dy_hand(dataset)
-         self.set_random_seed()
-         dataset = self.create_dataset()
-         dy_semi_auto_local_loss_list = self.run_dy_semi_auto(dataset)
-         self.set_random_seed()
-         dy2s_semi_auto_local_loss_list = self.run_dy2s_semi_auto(dataset)
-
-         np.testing.assert_allclose(
-             dy_hand_loss_list[-1], dy_semi_auto_local_loss_list[-1], rtol=1e-8
-         )
-         np.testing.assert_allclose(
-             dy_semi_auto_local_loss_list[-1],
-             dy2s_semi_auto_local_loss_list[-1],
-             rtol=1e-8,
-         )

-     def run_dy_hand(self, dataset):
+         # run_dy_hand_get_local_loss
        dist_strategy = fleet.DistributedStrategy()
        dist_strategy.hybrid_configs = {
            "dp_degree": 2,
@@ -182,21 +155,14 @@ def run_dy_hand(self, dataset):
            img, label = data

            out = model(img)
-             lossmask = paddle.zeros_like(label).astype('float32')
-             if dist.get_rank() == 0:
-                 lossmask[:3] = 1
-             else:
-                 lossmask[4:9] = 1

            avg_loss = masked_lm_loss_func(out, label)
            avg_loss.backward()
            optimizer.step()
            model.clear_gradients()
+             global_local_loss_list.append(avg_loss.numpy())
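+             # Record this rank's local loss so the dist.local_map run below can verify it reproduces the same per-batch values (rtol=1e-8).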

-         loss_list.append(avg_loss.numpy())
-         return loss_list
-
-     def run_dy_semi_auto(self, dataset):
+         # run_dy_semi_auto
        world_process_mesh = ProcessMesh([0, 1], dim_names=["dp"])
        model = SimpleNet(
            input_size=256, inner_size=102400, output_size=class_dim
@@ -219,73 +185,30 @@ def run_dy_semi_auto(self, dataset):
        )

        model.train()
-         out_process_mesh = ProcessMesh([0, 1], dim_names=["dp"])
+         process_mesh = ProcessMesh([0, 1], dim_names=["dp"])
        out_placements = [dist.Partial(dist.ReduceType.kRedAvg)]

-         local_loss_list = []
-
        for batch_id, data in enumerate(dist_dataloader()):
            if batch_id > 10:
                break

            img, label = data

            out = model(img)
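+             # dist.local_map wraps masked_lm_loss_func so it runs on each rank's local tensors; the local scalar loss is re-wrapped as a dist tensor with the Partial(kRedAvg) out_placements defined above.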
-             loss_func = LocalViewMaskLoss(
-                 out_dist_attrs=[(out_process_mesh, out_placements)],
-                 grad_dist_attrs=[None, None],
+             loss_func = dist.local_map(
+                 masked_lm_loss_func,
+                 out_placements=out_placements,
+                 in_placements=[None, None],
+                 process_mesh=process_mesh,
+             )
+             avg_loss = loss_func(
+                 out,
+                 label,
+                 global_local_loss_list_item=global_local_loss_list[batch_id],
            )
-             avg_loss = loss_func(out, label)
            avg_loss.backward()
-             local_loss_list.append(loss_func.local_loss)
            optimizer.step()
            model.clear_gradients()
-         return local_loss_list
-
-     def run_dy2s_semi_auto(self, dataset):
-         world_process_mesh = ProcessMesh([0, 1], dim_names=["dp"])
-         model = SimpleNet(
-             input_size=256, inner_size=102400, output_size=class_dim
-         )
-         optimizer = paddle.optimizer.AdamW(
-             learning_rate=base_lr,
-             weight_decay=l2_decay,
-             parameters=model.parameters(),
-         )
-
-         sampler = BatchSampler(
-             dataset, batch_size=batch_size, shuffle=False, drop_last=True
-         )
-         train_loader = DataLoader(
-             dataset, batch_sampler=sampler, num_workers=1, shuffle=False
-         )
-
-         dist_dataloader = shard_dataloader(
-             dataloader=train_loader, meshes=world_process_mesh, shard_dims="dp"
-         )
-
-         process_mesh = ProcessMesh([0, 1], dim_names=["dp"])
-         out_placements = [dist.Partial(dist.ReduceType.kRedAvg)]
-         in_grad_placements = [dist.Shard(0)]
-         loss_func = LocalViewMaskLoss(
-             out_dist_attrs=[(process_mesh, out_placements)],
-             grad_dist_attrs=[(process_mesh, in_grad_placements), None],
-         )
-         dist_model = dist.to_static(
-             model, dist_dataloader, loss_func, optimizer
-         )
-         dist_model.train()
-
-         local_loss_list = []
-         for batch_id, data in enumerate(dist_dataloader()):
-             if batch_id > 10:
-                 break
-
-             img, label = data
-             loss = dist_model(img, label)
-             local_loss_list.append(loss)
-
-         return local_loss_list


if __name__ == '__main__':