Merged

41 commits (diff below shows changes from 20 commits)
ff27956
drafts
ArthurZucker Oct 7, 2025
cfe8766
up
ArthurZucker Oct 8, 2025
e8e32f6
simplify modeling utils
ArthurZucker Oct 8, 2025
e9b9270
more simplifications
ArthurZucker Oct 8, 2025
9e95067
type kwargs
ArthurZucker Oct 8, 2025
e3da536
up
ArthurZucker Oct 8, 2025
98b0c74
move more accelerate related stuff
ArthurZucker Oct 8, 2025
281fcdd
safeguarding?
ArthurZucker Oct 8, 2025
0acea98
nits
ArthurZucker Oct 8, 2025
f1ef942
remove func when func is NOPE
ArthurZucker Oct 8, 2025
c7790bf
more
ArthurZucker Oct 8, 2025
7055ef7
nits
ArthurZucker Oct 8, 2025
65ca13e
styling
ArthurZucker Oct 8, 2025
eeb19ec
yups
ArthurZucker Oct 8, 2025
3830685
Merge branch 'main' of github.com:huggingface/transformers into dynam…
ArthurZucker Oct 8, 2025
9d994cc
up
ArthurZucker Oct 8, 2025
16372b3
ups
ArthurZucker Oct 8, 2025
f06524a
revert
ArthurZucker Oct 8, 2025
b6a9600
protect trainer utils iport
ArthurZucker Oct 8, 2025
0ce1046
fix doc
ArthurZucker Oct 8, 2025
2e4f5b8
Update src/transformers/integrations/peft.py
ArthurZucker Oct 9, 2025
61e15b4
review
ArthurZucker Oct 9, 2025
88ff305
Merge branch 'dynamic-weight-loader' of github.com:huggingface/transf…
ArthurZucker Oct 9, 2025
1d81fa7
update
ArthurZucker Oct 9, 2025
2199d3b
?
ArthurZucker Oct 9, 2025
d8a135f
fixx
ArthurZucker Oct 9, 2025
4985198
Merge branch 'main' into dynamic-weight-loader
ArthurZucker Oct 9, 2025
6356518
Merge branch 'main' into dynamic-weight-loader
ArthurZucker Oct 9, 2025
62b7961
update
ArthurZucker Oct 9, 2025
6e4a896
Merge branch 'dynamic-weight-loader' of github.com:huggingface/transf…
ArthurZucker Oct 9, 2025
8e1c1fb
super small update
ArthurZucker Oct 10, 2025
2522c95
ups
ArthurZucker Oct 10, 2025
f70fae6
style
ArthurZucker Oct 10, 2025
7b4b65f
this is stupid
ArthurZucker Oct 10, 2025
155b44e
Merge branch 'main' into dynamic-weight-loader
ArthurZucker Oct 13, 2025
ac053a3
:facepalm: well this was the issue
ArthurZucker Oct 13, 2025
c84d675
small nit
ArthurZucker Oct 13, 2025
06515b6
fix
ArthurZucker Oct 13, 2025
0b4c662
nit
ArthurZucker Oct 13, 2025
4605d79
damn the missing return
ArthurZucker Oct 13, 2025
d067ae3
one last stupid fix
ArthurZucker Oct 13, 2025
4 changes: 0 additions & 4 deletions docs/source/en/main_classes/model.md
@@ -42,7 +42,3 @@ set this to `False`.
## Pushing to the Hub

[[autodoc]] utils.PushToHubMixin

-## Sharded checkpoints
-
-[[autodoc]] modeling_utils.load_sharded_checkpoint
54 changes: 54 additions & 0 deletions src/transformers/generation/utils.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
+import functools
import inspect
import os
import warnings
@@ -363,6 +364,59 @@ class GenerationMixin(ContinuousMixin):
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
"""

    def adjust_generation_function(
        self,
        generation_config,
        from_auto_class,
        from_pipeline,
        pretrained_model_name_or_path,
        cache_dir,
        force_download,
        proxies,
        local_files_only,
        token,
        revision,
        subfolder,
        trust_remote_code,
        **kwargs,
    ):
        if self.can_generate() and generation_config is not None:
            logger.info("The user-defined `generation_config` will be used to override the default generation config.")
            self.generation_config = self.generation_config.from_dict(generation_config.to_dict())
        elif self.can_generate() and pretrained_model_name_or_path is not None:
            repo_loading_kwargs = {
                "cache_dir": cache_dir,
                "force_download": force_download,
                "proxies": proxies,
                "local_files_only": local_files_only,
                "token": token,
                "revision": revision,
                "subfolder": subfolder,
                **kwargs,
            }
            # Load generation config
            try:
                self.generation_config = GenerationConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    _from_auto=from_auto_class,
                    _from_pipeline=from_pipeline,
                    **repo_loading_kwargs,
                )
            except OSError:
                logger.info(
                    "Generation config file not found, using a generation config created from the model config."
                )
            # Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
            if hasattr(self, "load_custom_generate"):
                try:
                    custom_generate = self.load_custom_generate(
                        pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **repo_loading_kwargs
                    )
                    self.generate = functools.partial(custom_generate, model=self)
                except OSError:  # there is no custom generate function
                    pass

    def load_custom_generate(
        self,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
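The custom-generate override above works by binding the repo-provided function to the model instance with `functools.partial`, shadowing the class-level `generate`. A minimal, self-contained sketch of that pattern (`ToyModel` and `my_generate` are hypothetical stand-ins, not part of this PR):

```python
import functools


class ToyModel:
    def generate(self, prompt: str) -> str:
        return f"default: {prompt}"


def my_generate(prompt: str, model: ToyModel) -> str:
    # A repo-provided decoding strategy would live here.
    return f"custom ({type(model).__name__}): {prompt}"


model = ToyModel()
# Assigning to the instance attribute shadows the class method, just like
# `self.generate = functools.partial(custom_generate, model=self)` above.
model.generate = functools.partial(my_generate, model=model)
print(model.generate("hello"))  # custom (ToyModel): hello
```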
157 changes: 156 additions & 1 deletion src/transformers/integrations/accelerate.py
@@ -20,15 +20,24 @@
`find_tied_parameters` was copied from `accelerate.utils.modeling.py`
"""

import collections
import inspect
import os
from contextlib import contextmanager

-from ..utils import is_torch_available, logging
+from ..utils import is_accelerate_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationMethod
from .deepspeed import is_deepspeed_zero3_enabled
from .fsdp import is_fsdp_enabled


if is_torch_available():
    import torch
    import torch.nn as nn

if is_accelerate_available():
    from accelerate import dispatch_model


logger = logging.get_logger(__name__)

@@ -194,3 +203,149 @@ def find_tied_parameters(model: "nn.Module", **kwargs):
                tied_param_groups[param_name].append(tied_param_name)

    return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]


def auto_set_device_map(device_map):
> Review comment (Member): Not a huge fan of the function name - maybe `check_and_set_device_map` or something similar? For maximum clarity

    from ..modeling_utils import get_torch_context_manager_or_global_device

    # Potentially detect context manager or global device, and use it (only if no device_map was provided)
    if device_map is None and not is_deepspeed_zero3_enabled():
        device_in_context = get_torch_context_manager_or_global_device()
        if device_in_context == torch.device("meta"):
            raise RuntimeError(
                "You are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`.\n"
                "This is an anti-pattern as `from_pretrained` wants to load existing weights.\nIf you want to initialize an "
                "empty model on the meta device, use the context manager or global device with `from_config`, or `ModelClass(config)`"
            )
        device_map = device_in_context

    # change device_map into a map if we passed an int, a str or a torch.device
    if isinstance(device_map, torch.device):
        device_map = {"": device_map}
    elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
        try:
            device_map = {"": torch.device(device_map)}
        except RuntimeError:
            raise ValueError(
                "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
                f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
            )
    elif isinstance(device_map, int):
        if device_map < 0:
            raise ValueError(
                "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
            )
        else:
            device_map = {"": device_map}

    if device_map is not None:
        if is_deepspeed_zero3_enabled():
            raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
        if not is_accelerate_available():
            raise ValueError(
                "Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` "
                "requires `accelerate`. You can install it with `pip install accelerate`"
            )
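To make the normalization rules above concrete, here is a hedged, standalone sketch that mirrors the branches as assertions (only `torch` is required; the `normalize` helper is illustrative, not the PR's API):

```python
import torch


def normalize(device_map):
    # Mirrors the `device_map` normalization branches in `auto_set_device_map`.
    if isinstance(device_map, torch.device):
        return {"": device_map}
    if isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
        return {"": torch.device(device_map)}
    if isinstance(device_map, int):
        if device_map < 0:
            raise ValueError("negative device indices are rejected")
        return {"": device_map}
    return device_map


assert normalize("cpu") == {"": torch.device("cpu")}
assert normalize(torch.device("cpu")) == {"": torch.device("cpu")}
assert normalize(0) == {"": 0}
assert normalize("auto") == "auto"  # special strings pass through to accelerate
```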


def accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers):
    device_map_kwargs = {
        "device_map": device_map,
        "offload_dir": offload_folder,
        "offload_index": offload_index,
        "offload_buffers": offload_buffers,
    }
    if "skip_keys" in inspect.signature(dispatch_model).parameters:
        device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
    # For HQQ method we force-set the hooks for single GPU envs
    if (
        "force_hooks" in inspect.signature(dispatch_model).parameters
        and hf_quantizer is not None
        and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
    ):
        device_map_kwargs["force_hooks"] = True
    if (
        hf_quantizer is not None
        and hf_quantizer.quantization_config.quant_method == QuantizationMethod.FBGEMM_FP8
        and isinstance(device_map, dict)
        and ("cpu" in device_map.values() or "disk" in device_map.values())
    ):
        device_map_kwargs["offload_buffers"] = True

    if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
        dispatch_model(model, **device_map_kwargs)
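The `inspect.signature` checks above are a version-compatibility pattern: an optional kwarg is forwarded only when the installed `accelerate` exposes it. A standalone sketch of the pattern with a stand-in function (not accelerate's real signature):

```python
import inspect


def dispatch_model(model, device_map=None, skip_keys=None):  # stand-in, not accelerate's API
    return model


kwargs = {"device_map": {"": "cpu"}}
if "skip_keys" in inspect.signature(dispatch_model).parameters:
    kwargs["skip_keys"] = ["past_key_values"]  # only passed when supported
print(sorted(kwargs))  # ['device_map', 'skip_keys']
```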


def get_disk_only_shard_files(device_map, weight_map):
    """
    Returns the list of shard files containing only weights offloaded to disk.
    """
    files_content = collections.defaultdict(list)
    for weight_name, filename in weight_map.items():
        while len(weight_name) > 0 and weight_name not in device_map:
            weight_name = ".".join(weight_name.split(".")[:-1])
        files_content[filename].append(device_map[weight_name])

    return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
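A hedged worked example of the helper just defined, with hypothetical parameter and shard names; a shard qualifies only if every weight it holds resolves to "disk":

```python
device_map = {"model.embed": 0, "model.layers": "disk"}
weight_map = {
    "model.embed.weight": "shard-00001.safetensors",
    "model.layers.0.w": "shard-00002.safetensors",
    "model.layers.1.w": "shard-00002.safetensors",
}
# "model.embed.weight" falls back to "model.embed" (device 0); both
# "model.layers.*" weights fall back to "model.layers" ("disk"), so only
# the second shard is disk-only.
print(get_disk_only_shard_files(device_map, weight_map))  # ['shard-00002.safetensors']
```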


def expand_device_map(device_map, param_names):
    """
    Expand a device map to return the correspondence parameter name to device.
    """
    new_device_map = {}
    for module, device in device_map.items():
        new_device_map.update(
            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
        )
    return new_device_map
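Similarly, a quick sketch of `expand_device_map` with hypothetical names, showing module-level devices fanning out to every matching parameter:

```python
params = ["encoder.layer.0.weight", "encoder.layer.1.weight", "lm_head.weight"]
print(expand_device_map({"encoder": 0, "lm_head": "cpu"}, params))
# {'encoder.layer.0.weight': 0, 'encoder.layer.1.weight': 0, 'lm_head.weight': 'cpu'}
```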


def accelerate_disk_offload(
    disk_offload_folder,
    checkpoint_files,
    device_map,
    checkpoint_keys,
    key_renaming_mapping,
    sharded_metadata,
    dtype,
    reverse_key_renaming_mapping,
):
    disk_only_shard_files = []
    if disk_offload_folder is not None:
        os.makedirs(disk_offload_folder, exist_ok=True)
    is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")
    if disk_offload_folder is None and not is_offloaded_safetensors:
        raise ValueError(
            "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
            " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
            " offers the weights in this format."
        )
    if is_offloaded_safetensors:
        param_device_map = expand_device_map(device_map, checkpoint_keys)
        str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
        if sharded_metadata is None:
            weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
        else:
            folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
            # Fix the weight map keys according to the key mapping
            weight_map = {
                key_renaming_mapping[k]: v
                for k, v in sharded_metadata["weight_map"].items()
                if k in key_renaming_mapping
            }
            weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
        # Find potential checkpoints containing only offloaded weights
        disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
        disk_offload_index = {
            name: {
                "safetensors_file": file,
                "weight_name": reverse_key_renaming_mapping[name],
                "dtype": str_dtype,
            }
            for name, file in weight_map.items()
            if param_device_map[name] == "disk"
        }
    else:
        disk_offload_index = {}
    return disk_offload_index, disk_only_shard_files
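For reference, the `disk_offload_index` built in the safetensors branch maps each disk-offloaded parameter to its shard file, its original (pre-renaming) weight name, and a string dtype. A sketch of the expected shape, with hypothetical names:

```python
disk_offload_index = {
    "model.layers.0.w": {
        "safetensors_file": "/path/to/ckpt/shard-00002.safetensors",
        "weight_name": "model.layers.0.w",  # key before renaming (reverse_key_renaming_mapping)
        "dtype": "float16",
    },
}
```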
63 changes: 61 additions & 2 deletions src/transformers/integrations/peft.py
@@ -12,21 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import importlib.metadata
import inspect
import json
import os
import re
from typing import Any, Optional, Union

from packaging import version

from ..utils import (
CONFIG_NAME,
cached_file,
check_peft_version,
extract_commit_hash,
find_adapter_config_file,
is_accelerate_available,
is_peft_available,
is_torch_available,
logging,
)
from ..utils.hub import DownloadKwargs


if is_torch_available():
@@ -249,7 +255,7 @@ def load_adapter(
            else:
                new_key = key

-           if key_mapping:
+           if key_mapping:  # TODO dynamic weight loader for adapters
                for pattern, replacement in key_mapping.items():
                    new_key, n_replace = re.subn(pattern, replacement, new_key)
                    # Early exit of the loop
@@ -614,3 +620,56 @@ def old_delete_adapter(model, adapter_name, prefix=None):
        if len(self.peft_config) == 0:
            del self.peft_config
            self._hf_peft_config_loaded = False


def maybe_load_adapters(
    pretrained_model_name_or_path,
    download_kwargs: DownloadKwargs,
    **adapter_kwargs,
):
    if pretrained_model_name_or_path is None or not is_peft_available():
        return None

    token = download_kwargs.get("token")
    if token is not None and adapter_kwargs is not None and "token" not in adapter_kwargs:
        adapter_kwargs["token"] = token

    if download_kwargs.get("commit_hash") is None:
        resolved_config_file = cached_file(
            pretrained_model_name_or_path,
            CONFIG_NAME,
            cache_dir=download_kwargs.get("cache_dir"),
            force_download=bool(download_kwargs.get("force_download", False)),
            proxies=download_kwargs.get("proxies"),
            local_files_only=bool(download_kwargs.get("local_files_only", False)),
            token=token,
            revision=download_kwargs.get("revision"),
            subfolder=download_kwargs.get("subfolder", ""),
            _raise_exceptions_for_gated_repo=False,
            _raise_exceptions_for_missing_entries=False,
            _raise_exceptions_for_connection_errors=False,
        )
        download_kwargs["commit_hash"] = extract_commit_hash(resolved_config_file, None)

    _adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)

    if _adapter_model_path is None:
        _adapter_model_path = find_adapter_config_file(
            pretrained_model_name_or_path,
            cache_dir=download_kwargs.get("cache_dir"),
            force_download=bool(download_kwargs.get("force_download", False)),
            proxies=download_kwargs.get("proxies"),
            token=token,
            revision=download_kwargs.get("revision"),
            local_files_only=bool(download_kwargs.get("local_files_only", False)),
            subfolder=download_kwargs.get("subfolder", ""),
            _commit_hash=download_kwargs.get("commit_hash"),
            **adapter_kwargs,
        )

    if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
        with open(_adapter_model_path, "r", encoding="utf-8") as f:
            _adapter_model_path = pretrained_model_name_or_path
            pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]

    return _adapter_model_path
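The final branch resolves an adapter checkpoint to its base model by reading `base_model_name_or_path` out of `adapter_config.json`. A minimal runnable sketch of just that step (the temp file and model id are illustrative):

```python
import json
import os
import tempfile

# Write a toy adapter_config.json, then recover the base model id from it,
# mirroring the last step of `maybe_load_adapters` above.
tmp_dir = tempfile.mkdtemp()
config_path = os.path.join(tmp_dir, "adapter_config.json")
with open(config_path, "w", encoding="utf-8") as f:
    json.dump({"base_model_name_or_path": "my-org/my-base-model"}, f)

with open(config_path, "r", encoding="utf-8") as f:
    base_model = json.load(f)["base_model_name_or_path"]
print(base_model)  # my-org/my-base-model
```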