@@ -281,6 +281,7 @@ def __init__(
281
281
self ._use_multi_tensor = None
282
282
self .regularization = None
283
283
self ._auxiliary_vars = {}
284
+ self ._already_create_accumulater = set ()
284
285
285
286
def _set_auxiliary_var (self , key , val ):
286
287
self ._auxiliary_vars [key ] = val
@@ -422,9 +423,12 @@ def _create_accumulators(self, block, parameters):
422
423
423
424
# Create accumulator tensors for first and second moments
424
425
for p in parameters :
426
+ if p .name in self ._already_create_accumulater :
427
+ continue
425
428
if self ._multi_precision and self ._is_dtype_fp16_or_bf16 (p .dtype ):
426
429
master_p = self ._create_master_weight (p )
427
430
self ._add_moments_pows (master_p )
431
+ self ._already_create_accumulater .add (p .name )
428
432
continue
429
433
if (
430
434
self ._is_dtype_fp16_or_bf16 (p .dtype )
@@ -435,6 +439,7 @@ def _create_accumulators(self, block, parameters):
435
439
"Consider using multi_precision=True option of the Adam optimizer."
436
440
)
437
441
self ._add_moments_pows (p )
442
+ self ._already_create_accumulater .add (p .name )
438
443
439
444
def _append_optimize_op (self , block , param_and_grad ):
440
445
assert isinstance (block , framework .Block )
0 commit comments