
Commit 116529c

add check for sparse parameters with weight_decay
1 parent ac48599 commit 116529c
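
Why the check is needed (my reading of the change, not stated in the commit message itself): with sparse=True, the embedding only produces gradient rows for the indices that were actually looked up, while weight_decay is a regularizer over the whole dense parameter; applying it would force a dense update of every row, so the optimizer now fails fast instead. A minimal sketch of the configuration that remains supported, assuming the Paddle 2.x dygraph API used in the test below:

    import numpy as np
    import paddle

    paddle.disable_static()

    # sparse=True makes the embedding emit sparse (row-wise) gradients
    emb = paddle.nn.Embedding(10, 10, sparse=True)

    # weight_decay must stay None for sparse parameters; passing a float
    # here (e.g. 0.01) is the case the new check rejects with RuntimeError
    adam = paddle.optimizer.Adam(
        learning_rate=0.001, parameters=emb.parameters(), weight_decay=None)

    x = paddle.to_tensor(
        np.arange(0, 10).reshape((10, 1)).astype(np.int64),
        stop_gradient=False)
    out = emb(x)
    out.backward()
    adam.step()  # succeeds: no regularizer is configured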

File tree: 2 files changed (+20 -1 lines)

python/paddle/fluid/tests/unittests/test_adam_op.py (+14 -1)
@@ -448,7 +448,6 @@ def test_adam_op_dygraph(self):
 
     def test_adam_op_with_state_dict(self):
 
-        import paddle
         paddle.disable_static()
         emb = paddle.nn.Embedding(10, 10)
 
@@ -517,6 +516,20 @@ def test_adam_op_invalid_input(self):
         adam = paddle.optimizer.Adam(
             0.1, epsilon=-1, parameters=linear.parameters())
 
+    def test_adam_op_with_sparse_input_and_weight_decay(self):
+
+        paddle.disable_static()
+        x_data = np.arange(0, 10).reshape((10, 1)).astype(np.int64)
+        x = paddle.to_tensor(x_data, stop_gradient=False)
+        emb = paddle.nn.Embedding(10, 10, sparse=True)
+        adam = paddle.optimizer.Adam(
+            0.001, parameters=emb.parameters(), weight_decay=0.01)
+
+        with self.assertRaises(RuntimeError):
+            out = emb(x)
+            out.backward()
+            adam.step()
+
 
 if __name__ == "__main__":
     unittest.main()
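
Note that the assertRaises block wraps the whole forward/backward/step sequence, but the only call that can raise here is adam.step(): the check added in optimizer.py below runs when step() walks the parameter list. A tighter but equivalent scoping (my rephrasing, not part of the commit) would be:

    out = emb(x)
    out.backward()
    # Only step() consults self.regularization, so only it can raise
    # the new RuntimeError for the sparse embedding parameter.
    with self.assertRaises(RuntimeError):
        adam.step()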

python/paddle/optimizer/optimizer.py (+6)
@@ -913,6 +913,12 @@ def step(self):
         for param in self._parameter_list:
             if not param.trainable:
                 continue
+            if hasattr(
+                    param, "_is_sparse"
+            ) and param._is_sparse and self.regularization is not None:
+                raise RuntimeError(
+                    "Optimizer don't support weight_decay with sparse parameters, please set it to None."
+                )
             if param._grad_ivar() is not None:
                 grad_var = param._grad_ivar()
                 params_grads.append((param, grad_var))
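
The guard calls hasattr before reading param._is_sparse, since only parameters created by sparse layers carry that attribute, and it only fires when a regularizer is configured (passing weight_decay to the optimizer constructor is what populates self.regularization). A hypothetical standalone restatement of the same logic, using getattr with a default in place of the hasattr-then-access pattern:

    def reject_sparse_weight_decay(parameter_list, regularization):
        # Hypothetical helper, not part of the commit: mirrors the guard
        # added to Optimizer.step() above. getattr with a False default
        # collapses the hasattr check and the attribute read into one
        # expression.
        for param in parameter_list:
            if getattr(param, "_is_sparse", False) and \
                    regularization is not None:
                raise RuntimeError(
                    "Optimizer don't support weight_decay with sparse "
                    "parameters, please set it to None.")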

0 commit comments
