Skip to content

Commit aa66ef0

Browse files
authored
Merge pull request #295 from HydrogenSulfate/fix_amp_accu_grad
add accumulate grad in amp mode
2 parents 8966b4e + 6ccbbc3 commit aa66ef0

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

paddlevideo/tasks/train.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
paddle.framework.seed(1234)
4040
np.random.seed(1234)
4141

42+
4243
def train_model(cfg,
4344
weights=None,
4445
parallel=True,
@@ -188,12 +189,21 @@ def train_model(cfg,
188189
outputs = model(data, mode='train')
189190

190191
avg_loss = outputs['loss']
191-
scaled = scaler.scale(avg_loss)
192-
scaled.backward()
193-
# keep prior to 2.0 design
194-
scaler.minimize(optimizer, scaled)
195-
optimizer.clear_grad()
196-
192+
if use_gradient_accumulation:
193+
if i == 0:
194+
optimizer.clear_grad()
195+
avg_loss /= cfg.GRADIENT_ACCUMULATION.num_iters
196+
scaled = scaler.scale(avg_loss)
197+
scaled.backward()
198+
if (i + 1) % cfg.GRADIENT_ACCUMULATION.num_iters == 0:
199+
scaler.minimize(optimizer, scaled)
200+
optimizer.clear_grad()
201+
else:
202+
scaled = scaler.scale(avg_loss)
203+
scaled.backward()
204+
# keep prior to 2.0 design
205+
scaler.minimize(optimizer, scaled)
206+
optimizer.clear_grad()
197207
else:
198208
outputs = model(data, mode='train')
199209

@@ -259,7 +269,6 @@ def evaluate(best):
259269
if cfg.MODEL.framework != "FastRCNN":
260270
for name, value in outputs.items():
261271
record_list[name].update(value, batch_size)
262-
263272

264273
record_list['batch_time'].update(time.time() - tic)
265274
tic = time.time()
@@ -271,21 +280,21 @@ def evaluate(best):
271280
if cfg.MODEL.framework == "FastRCNN":
272281
if parallel:
273282
results = collect_results_cpu(results, len(valid_dataset))
274-
if not parallel or (parallel and rank==0):
275-
eval_res = valid_dataset.evaluate( results)
283+
if not parallel or (parallel and rank == 0):
284+
eval_res = valid_dataset.evaluate(results)
276285
for name, value in eval_res.items():
277286
record_list[name].update(value, valid_batch_size)
278287

279-
280288
ips = "avg_ips: {:.5f} instance/sec.".format(
281289
valid_batch_size * record_list["batch_time"].count /
282290
record_list["batch_time"].sum)
283291
log_epoch(record_list, epoch + 1, "val", ips)
284292

285293
best_flag = False
286-
if cfg.MODEL.framework == "FastRCNN" and (not parallel or (parallel and rank==0)):
294+
if cfg.MODEL.framework == "FastRCNN" and (not parallel or
295+
(parallel and rank == 0)):
287296
if record_list["mAP@0.5IOU"].val > best:
288-
best = record_list["mAP@0.5IOU"].val
297+
best = record_list["mAP@0.5IOU"].val
289298
best_flag = True
290299
return best, best_flag
291300
#best2, cfg.MODEL.framework != "FastRCNN":

0 commit comments

Comments
 (0)