diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py index 2856737ff5..01f81ac13d 100644 --- a/fastdeploy/model_executor/load_weight_utils.py +++ b/fastdeploy/model_executor/load_weight_utils.py @@ -215,11 +215,13 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafete """ load_pre_sharded_checkpoint """ + from fastdeploy.model_executor.layers.utils import get_tensor + state_dict = {} _, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}")) weights_iterator = safetensors_weights_iterator(safetensor_files) for name, weight in weights_iterator: - state_dict[name] = weight + state_dict[name] = get_tensor(weight) return state_dict