Skip to content

Commit 9b993b2

Browse files
authored
[LLM] fix low cpu mem device. (PaddlePaddle#6300)
1 parent aa4a62c commit 9b993b2

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

paddlenlp/transformers/model_utils.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -712,6 +712,9 @@ def _load_state_dict_into_meta_model(
712712
if param_name.startswith(start_prefix):
713713
param_name = param_name[len(start_prefix) :]
714714

715+
if param.place != paddle.framework._current_expected_place():
716+
param = param._copy_to(paddle.framework._current_expected_place(), False)
717+
715718
# # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
716719
# # in int/uint/bool and not cast them.
717720
if dtype is not None and paddle.is_floating_point(param):
@@ -733,7 +736,7 @@ def _load_state_dict_into_meta_model(
733736
break
734737

735738
if old_param is not None:
736-
param = param.to(dtype=old_param.dtype)
739+
param = param.astype(dtype=old_param.dtype)
737740

738741
with paddle.no_grad():
739742
model.state_dict()[param_name].get_tensor()._share_data_with(param.value().get_tensor())

0 commit comments

Comments
 (0)