Skip to content

Commit 8e789dc

Browse files
fix load_pre_sharded_checkpoint (#3152) (#3169)
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
1 parent 5f6fc7f commit 8e789dc

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

fastdeploy/model_executor/load_weight_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,11 +215,13 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafete
215215
"""
216216
load_pre_sharded_checkpoint
217217
"""
218+
from fastdeploy.model_executor.layers.utils import get_tensor
219+
218220
state_dict = {}
219221
_, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
220222
weights_iterator = safetensors_weights_iterator(safetensor_files)
221223
for name, weight in weights_iterator:
222-
state_dict[name] = weight
224+
state_dict[name] = get_tensor(weight)
223225
return state_dict
224226

225227

0 commit comments

Comments
 (0)