2 files changed: +20 / -1 lines

File 1 (Adam optimizer unit tests):

@@ -448,7 +448,6 @@ def test_adam_op_dygraph(self):
 
     def test_adam_op_with_state_dict(self):
-        import paddle
         paddle.disable_static()
         emb = paddle.nn.Embedding(10, 10)
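(Note: the deleted local `import paddle` appears redundant; the surrounding unchanged lines keep using `paddle` without it, which suggests the test module already imports paddle at file scope.)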
@@ -517,6 +516,20 @@ def test_adam_op_invalid_input(self):
         adam = paddle.optimizer.Adam(
             0.1, epsilon=-1, parameters=linear.parameters())
 
+    def test_adam_op_with_sparse_input_and_weight_decay(self):
+
+        paddle.disable_static()
+        x_data = np.arange(0, 10).reshape((10, 1)).astype(np.int64)
+        x = paddle.to_tensor(x_data, stop_gradient=False)
+        emb = paddle.nn.Embedding(10, 10, sparse=True)
+        adam = paddle.optimizer.Adam(
+            0.001, parameters=emb.parameters(), weight_decay=0.01)
+
+        with self.assertRaises(RuntimeError):
+            out = emb(x)
+            out.backward()
+            adam.step()
+
 
 if __name__ == "__main__":
     unittest.main()
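For context, here is a minimal standalone sketch of the user-facing behavior this new test pins down. It assumes a build that includes the guard added in the second file below; everything else mirrors the test body itself.

    # Sketch: combining a sparse Embedding with weight_decay should fail at
    # adam.step(), assuming the guard from the optimizer diff below is present.
    import numpy as np
    import paddle

    paddle.disable_static()
    x = paddle.to_tensor(
        np.arange(0, 10).reshape((10, 1)).astype(np.int64),
        stop_gradient=False)
    emb = paddle.nn.Embedding(10, 10, sparse=True)  # sparse gradient updates
    adam = paddle.optimizer.Adam(
        0.001, parameters=emb.parameters(), weight_decay=0.01)

    out = emb(x)
    out.backward()
    try:
        adam.step()  # the guard raises before any parameter update runs
    except RuntimeError as err:
        print(err)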
File 2 (optimizer `step()` implementation):

@@ -913,6 +913,12 @@ def step(self):
         for param in self._parameter_list:
             if not param.trainable:
                 continue
+            if hasattr(
+                    param, "_is_sparse"
+            ) and param._is_sparse and self.regularization is not None:
+                raise RuntimeError(
+                    "Optimizer doesn't support weight_decay with sparse parameters, please set it to None."
+                )
             if param._grad_ivar() is not None:
                 grad_var = param._grad_ivar()
                 params_grads.append((param, grad_var))
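For readers who want the new check in isolation, a minimal sketch of the same logic as a standalone helper. The function name is hypothetical; `_is_sparse` and the regularization argument are taken from the diff above.

    def reject_sparse_weight_decay(param, regularization):
        # Hypothetical standalone version of the guard in step(): refuse to
        # combine a regularization/weight_decay term with a parameter whose
        # gradients are updated sparsely, e.g. paddle.nn.Embedding(sparse=True).
        # getattr with a False default collapses the hasattr/attribute pair
        # used in the diff into one call with the same truth table.
        if getattr(param, "_is_sparse", False) and regularization is not None:
            raise RuntimeError(
                "Optimizer doesn't support weight_decay with sparse "
                "parameters, please set it to None.")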