39
39
paddle .framework .seed (1234 )
40
40
np .random .seed (1234 )
41
41
42
+
42
43
def train_model (cfg ,
43
44
weights = None ,
44
45
parallel = True ,
@@ -188,12 +189,21 @@ def train_model(cfg,
188
189
outputs = model (data , mode = 'train' )
189
190
190
191
avg_loss = outputs ['loss' ]
191
- scaled = scaler .scale (avg_loss )
192
- scaled .backward ()
193
- # keep prior to 2.0 design
194
- scaler .minimize (optimizer , scaled )
195
- optimizer .clear_grad ()
196
-
192
+ if use_gradient_accumulation :
193
+ if i == 0 :
194
+ optimizer .clear_grad ()
195
+ avg_loss /= cfg .GRADIENT_ACCUMULATION .num_iters
196
+ scaled = scaler .scale (avg_loss )
197
+ scaled .backward ()
198
+ if (i + 1 ) % cfg .GRADIENT_ACCUMULATION .num_iters == 0 :
199
+ scaler .minimize (optimizer , scaled )
200
+ optimizer .clear_grad ()
201
+ else :
202
+ scaled = scaler .scale (avg_loss )
203
+ scaled .backward ()
204
+ # keep prior to 2.0 design
205
+ scaler .minimize (optimizer , scaled )
206
+ optimizer .clear_grad ()
197
207
else :
198
208
outputs = model (data , mode = 'train' )
199
209
@@ -259,7 +269,6 @@ def evaluate(best):
259
269
if cfg .MODEL .framework != "FastRCNN" :
260
270
for name , value in outputs .items ():
261
271
record_list [name ].update (value , batch_size )
262
-
263
272
264
273
record_list ['batch_time' ].update (time .time () - tic )
265
274
tic = time .time ()
@@ -271,21 +280,21 @@ def evaluate(best):
271
280
if cfg .MODEL .framework == "FastRCNN" :
272
281
if parallel :
273
282
results = collect_results_cpu (results , len (valid_dataset ))
274
- if not parallel or (parallel and rank == 0 ):
275
- eval_res = valid_dataset .evaluate ( results )
283
+ if not parallel or (parallel and rank == 0 ):
284
+ eval_res = valid_dataset .evaluate (results )
276
285
for name , value in eval_res .items ():
277
286
record_list [name ].update (value , valid_batch_size )
278
287
279
-
280
288
ips = "avg_ips: {:.5f} instance/sec." .format (
281
289
valid_batch_size * record_list ["batch_time" ].count /
282
290
record_list ["batch_time" ].sum )
283
291
log_epoch (record_list , epoch + 1 , "val" , ips )
284
292
285
293
best_flag = False
286
- if cfg .MODEL .framework == "FastRCNN" and (not parallel or (parallel and rank == 0 )):
294
+ if cfg .MODEL .framework == "FastRCNN" and (not parallel or
295
+ (parallel and rank == 0 )):
287
296
if record_list ["mAP@0.5IOU" ].val > best :
288
- best = record_list ["mAP@0.5IOU" ].val
297
+ best = record_list ["mAP@0.5IOU" ].val
289
298
best_flag = True
290
299
return best , best_flag
291
300
#best2, cfg.MODEL.framework != "FastRCNN":
0 commit comments