Skip to content

Commit c928a35

Browse files
set_state_dict return missing_keys and unexpected_keys (#48436)
* refine set_state_dict
1 parent f5c520b commit c928a35

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

python/paddle/fluid/dygraph/layers.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1600,7 +1600,8 @@ def set_state_dict(self, state_dict, use_structured_name=True):
16001600
use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
16011601
Default: True
16021602
Returns:
1603-
None
1603+
missing_keys(list): A list of str containing the missing keys.
1604+
unexpected_keys(list): A list of str containing the unexpected keys.
16041605
16051606
Examples:
16061607
.. code-block:: python
@@ -1615,22 +1616,28 @@ def set_state_dict(self, state_dict, use_structured_name=True):
16151616
emb.set_state_dict(para_state_dict)
16161617
16171618
'''
1619+
missing_keys = []
1620+
match_keys = set()
1621+
unexpected_keys = []
16181622

16191623
def _check_match(key, param):
16201624
state = state_dict.get(key, None)
16211625
if state is None:
1626+
missing_keys.append(key)
16221627
raise ValueError(
16231628
"{} is not found in the provided dict.".format(key)
16241629
)
16251630
if isinstance(state, dict) or isinstance(state, list):
16261631
if len(state) != len(param):
1632+
missing_keys.append(key)
16271633
raise ValueError(
16281634
"{} receieves the length of {}, "
16291635
"but the expected shape is {}".format(
16301636
key, len(state), len(param)
16311637
)
16321638
)
16331639
else:
1640+
match_keys.add(key)
16341641
return param, state
16351642
else:
16361643
state_shape = (
@@ -1640,11 +1647,13 @@ def _check_match(key, param):
16401647
)
16411648

16421649
if list(state_shape) != list(param.shape):
1650+
missing_keys.append(key)
16431651
raise ValueError(
16441652
"{} receives a shape {}, but the expected shape is {}.".format(
16451653
key, list(state_shape), list(param.shape)
16461654
)
16471655
)
1656+
match_keys.add(key)
16481657
return param, state
16491658

16501659
matched_param_state = []
@@ -1655,7 +1664,9 @@ def _check_match(key, param):
16551664
matched_param_state.append(match_res)
16561665
except ValueError as err:
16571666
warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
1658-
1667+
for key in state_dict.keys():
1668+
if key not in match_keys:
1669+
unexpected_keys.append(key)
16591670
if _non_static_mode():
16601671
for param, state in matched_param_state:
16611672
param.set_value(state)
@@ -1693,6 +1704,8 @@ def _set_var(var, ndarray):
16931704
"This error might happens in dy2static, while calling 'set_state_dict' dynamicly in 'forward', which is not supported. If you only need call 'set_state_dict' once, move it to '__init__'."
16941705
)
16951706

1707+
return missing_keys, unexpected_keys
1708+
16961709
def to(self, device=None, dtype=None, blocking=None):
16971710
'''
16981711
Cast the parameters and buffers of Layer by the given device, dtype and blocking.

python/paddle/fluid/tests/unittests/test_state_dict_convert.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,15 @@ def set_state_dict(self, state_dict, use_structured_name=True):
5353
return super().set_state_dict(state_dict)
5454

5555

56+
class MyModel2(nn.Layer):
    """Minimal layer holding a single ``nn.Linear(100, 300)`` sublayer.

    The sublayer is stored as ``self.linear``; the state-dict return-value
    test in this file expects its keys to be ``linear.weight`` and
    ``linear.bias``.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(100, 300)

    def forward(self, x):
        """Apply the ``linear`` sublayer to ``x`` and return its output."""
        projected = self.linear(x)
        return projected
64+
5665
def is_state_dict_equal(model1, model2):
5766
st1 = model1.state_dict()
5867
st2 = model2.state_dict()
@@ -73,5 +82,18 @@ def test_main(self):
7382
self.assertTrue(is_state_dict_equal(model1, model2))
7483

7584

85+
class TestStateDictReturn(unittest.TestCase):
    """Checks the ``(missing_keys, unexpected_keys)`` pair returned by ``set_state_dict``."""

    def test_missing_keys_and_unexpected_keys(self):
        layer = MyModel2()
        # The supplied dict contains none of the layer's own keys and exactly
        # one key the layer does not own.
        state = {"unexpected_keys": paddle.to_tensor(1)}
        missing, unexpected = layer.set_state_dict(state)

        # Both parameters of the linear sublayer must be reported, in order.
        expected_missing = ["linear.weight", "linear.bias"]
        self.assertEqual(len(missing), len(expected_missing))
        for position, name in enumerate(expected_missing):
            self.assertEqual(missing[position], name)

        # The foreign entry is the only unexpected key.
        self.assertEqual(len(unexpected), 1)
        self.assertEqual(unexpected[0], "unexpected_keys")
96+
97+
7698
if __name__ == "__main__":
7799
unittest.main()

0 commit comments

Comments
 (0)