1 file changed: 4 additions, 1 deletion

@@ -712,6 +712,9 @@ def _load_state_dict_into_meta_model(
         if param_name.startswith(start_prefix):
             param_name = param_name[len(start_prefix) :]
 
+        if param.place != paddle.framework._current_expected_place():
+            param = param._copy_to(paddle.framework._current_expected_place(), False)
+
         # # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
         # # in int/uint/bool and not cast them.
         if dtype is not None and paddle.is_floating_point(param):
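The first hunk normalises device placement of each loaded parameter before any dtype handling: if the tensor is not already on the framework's expected place, it is copied there. Below is a minimal sketch of that pattern outside the loader, using a hypothetical freshly created tensor as a stand-in for the checkpoint value; the `paddle.framework._current_expected_place()` and `_copy_to` calls are the ones appearing in the diff.

```python
import paddle

# Hypothetical stand-in for a parameter read from a checkpoint state dict.
param = paddle.ones([2, 2], dtype="float32")

# Same pattern as the added lines: move the tensor to the framework's
# currently expected place only when it is not already there.
expected_place = paddle.framework._current_expected_place()
if param.place != expected_place:
    # The second positional argument is the blocking flag used in the diff (False).
    param = param._copy_to(expected_place, False)

print(param.place)
```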
@@ -733,7 +736,7 @@ def _load_state_dict_into_meta_model(
                     break
 
         if old_param is not None:
-            param = param.to(dtype=old_param.dtype)
+            param = param.astype(dtype=old_param.dtype)
 
         with paddle.no_grad():
             model.state_dict()[param_name].get_tensor()._share_data_with(param.value().get_tensor())
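The second hunk swaps `Tensor.to(dtype=...)` for `Tensor.astype(dtype=...)` when aligning the incoming parameter's dtype with the parameter already registered on the model. A small sketch of that cast, assuming `old_param` and `param` are just stand-in tensors rather than the model's real attributes:

```python
import paddle

# Stand-ins for the model's existing parameter and the checkpoint value.
old_param = paddle.zeros([4], dtype="float16")
param = paddle.ones([4], dtype="float32")

# Same cast as the modified line: align the loaded value's dtype with the
# parameter that already lives on the model.
if old_param is not None:
    param = param.astype(dtype=old_param.dtype)

print(param.dtype)  # paddle.float16
```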