[`from_pretrained`] Small refactor `from_pretrained`: move around unrelated stuff #41445
Merged
Changes from 20 commits
41 commits
ff27956 drafts (ArthurZucker)
cfe8766 up (ArthurZucker)
e8e32f6 simplify modeling utils (ArthurZucker)
e9b9270 more simplifications (ArthurZucker)
9e95067 type kwargs (ArthurZucker)
e3da536 up (ArthurZucker)
98b0c74 move more accelerate related stuff (ArthurZucker)
281fcdd safeguarding? (ArthurZucker)
0acea98 nits (ArthurZucker)
f1ef942 remove func when func is NOPE (ArthurZucker)
c7790bf more (ArthurZucker)
7055ef7 nits (ArthurZucker)
65ca13e styling (ArthurZucker)
eeb19ec yups (ArthurZucker)
3830685 Merge branch 'main' of github.com:huggingface/transformers into dynam… (ArthurZucker)
9d994cc up (ArthurZucker)
16372b3 ups (ArthurZucker)
f06524a revert (ArthurZucker)
b6a9600 protect trainer utils iport (ArthurZucker)
0ce1046 fix doc (ArthurZucker)
2e4f5b8 Update src/transformers/integrations/peft.py (ArthurZucker)
61e15b4 review (ArthurZucker)
88ff305 Merge branch 'dynamic-weight-loader' of github.com:huggingface/transf… (ArthurZucker)
1d81fa7 update (ArthurZucker)
2199d3b ? (ArthurZucker)
d8a135f fixx (ArthurZucker)
4985198 Merge branch 'main' into dynamic-weight-loader (ArthurZucker)
6356518 Merge branch 'main' into dynamic-weight-loader (ArthurZucker)
62b7961 update (ArthurZucker)
6e4a896 Merge branch 'dynamic-weight-loader' of github.com:huggingface/transf… (ArthurZucker)
8e1c1fb super small update (ArthurZucker)
2522c95 ups (ArthurZucker)
f70fae6 style (ArthurZucker)
7b4b65f this is stupid (ArthurZucker)
155b44e Merge branch 'main' into dynamic-weight-loader (ArthurZucker)
ac053a3 :facepalm: well this was the issue (ArthurZucker)
c84d675 small nit (ArthurZucker)
06515b6 fix (ArthurZucker)
0b4c662 nit (ArthurZucker)
4605d79 damn the missing return (ArthurZucker)
d067ae3 one last stupid fix (ArthurZucker)
@@ -20,15 +20,24 @@
 `find_tied_parameters` was copied from `accelerate.utils.modeling.py`
 """

+import collections
+import inspect
+import os
 from contextlib import contextmanager

-from ..utils import is_torch_available, logging
+from ..utils import is_accelerate_available, is_torch_available, logging
+from ..utils.quantization_config import QuantizationMethod
+from .deepspeed import is_deepspeed_zero3_enabled
+from .fsdp import is_fsdp_enabled


 if is_torch_available():
     import torch
     import torch.nn as nn

+if is_accelerate_available():
+    from accelerate import dispatch_model
+

 logger = logging.get_logger(__name__)

@@ -194,3 +203,149 @@ def find_tied_parameters(model: "nn.Module", **kwargs):
                 tied_param_groups[param_name].append(tied_param_name)

     return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]
+
+
+def auto_set_device_map(device_map):
+
+    from ..modeling_utils import get_torch_context_manager_or_global_device
+
+    # Potentially detect context manager or global device, and use it (only if no device_map was provided)
+    if device_map is None and not is_deepspeed_zero3_enabled():
+        device_in_context = get_torch_context_manager_or_global_device()
+        if device_in_context == torch.device("meta"):
+            raise RuntimeError(
+                "You are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`.\n"
+                "This is an anti-pattern as `from_pretrained` wants to load existing weights.\nIf you want to initialize an "
+                "empty model on the meta device, use the context manager or global device with `from_config`, or `ModelClass(config)`"
+            )
+        device_map = device_in_context
+
+    # change device_map into a map if we passed an int, a str or a torch.device
+    if isinstance(device_map, torch.device):
+        device_map = {"": device_map}
+    elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
+        try:
+            device_map = {"": torch.device(device_map)}
+        except RuntimeError:
+            raise ValueError(
+                "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
+                f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
+            )
+    elif isinstance(device_map, int):
+        if device_map < 0:
+            raise ValueError(
+                "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
+            )
+        else:
+            device_map = {"": device_map}
+
+    if device_map is not None:
+        if is_deepspeed_zero3_enabled():
+            raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
+        if not is_accelerate_available():
+            raise ValueError(
+                "Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` "
+                "requires `accelerate`. You can install it with `pip install accelerate`"
+            )
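Review note on `auto_set_device_map`: the normalization step turns a single device (string or int) into a one-entry dict keyed by the empty string, and leaves `None`, dicts and the strategy names untouched. The sketch below is a torch-free illustration of that step only. The name `normalize_device_map_sketch` is hypothetical, device strings are kept as plain strings instead of being validated through `torch.device(...)`, and the context-manager and DeepSpeed checks are left out. As displayed in this view, `auto_set_device_map` has no `return`; the sketch returns the normalized value, which is an assumption about the intended behavior.

```python
# Hypothetical, torch-free sketch of the normalization step in `auto_set_device_map`.
STRATEGIES = ("auto", "balanced", "balanced_low_0", "sequential")


def normalize_device_map_sketch(device_map):
    """Return `device_map` in dict form when a single device was given."""
    if isinstance(device_map, str) and device_map not in STRATEGIES:
        # Assumption: the string is kept as-is; the diff validates it via torch.device(...).
        return {"": device_map}
    if isinstance(device_map, int):
        if device_map < 0:
            raise ValueError("device_map cannot be a negative int; pass 'cpu' for CPU placement")
        return {"": device_map}
    # None, dicts and strategy names pass through unchanged.
    return device_map


print(normalize_device_map_sketch("cuda:0"))  # {'': 'cuda:0'}
print(normalize_device_map_sketch(1))         # {'': 1}
print(normalize_device_map_sketch("auto"))    # auto
```

The empty-string key means "the whole model", which is why a single device collapses to a one-entry map.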
+
+
+def accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers):
+    device_map_kwargs = {
+        "device_map": device_map,
+        "offload_dir": offload_folder,
+        "offload_index": offload_index,
+        "offload_buffers": offload_buffers,
+    }
+    if "skip_keys" in inspect.signature(dispatch_model).parameters:
+        device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
+    # For HQQ method we force-set the hooks for single GPU envs
+    if (
+        "force_hooks" in inspect.signature(dispatch_model).parameters
+        and hf_quantizer is not None
+        and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
+    ):
+        device_map_kwargs["force_hooks"] = True
+    if (
+        hf_quantizer is not None
+        and hf_quantizer.quantization_config.quant_method == QuantizationMethod.FBGEMM_FP8
+        and isinstance(device_map, dict)
+        and ("cpu" in device_map.values() or "disk" in device_map.values())
+    ):
+        device_map_kwargs["offload_buffers"] = True
+
+    if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
+        dispatch_model(model, **device_map_kwargs)
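Review note on `accelerate_dispatch`: the optional `skip_keys` and `force_hooks` arguments are only passed when the installed `dispatch_model` signature accepts them. A minimal sketch of that feature-detection pattern follows. `build_supported_kwargs`, `dispatch_old` and `dispatch_new` are hypothetical stand-ins written for this example, not part of `accelerate` or of this PR.

```python
import inspect


def build_supported_kwargs(func, required, optional):
    """Keep every required kwarg, and only the optional ones `func` declares."""
    params = inspect.signature(func).parameters
    kwargs = dict(required)
    for name, value in optional.items():
        if name in params:
            kwargs[name] = value
    return kwargs


# Two hypothetical versions of a dispatch function with different signatures.
def dispatch_old(model, device_map=None, offload_dir=None):
    return sorted(["device_map", "offload_dir"])


def dispatch_new(model, device_map=None, offload_dir=None, skip_keys=None, force_hooks=False):
    return sorted(["device_map", "offload_dir", "skip_keys", "force_hooks"])


required = {"device_map": {"": "cpu"}, "offload_dir": None}
optional = {"skip_keys": ["past_key_values"], "force_hooks": True}

old_kwargs = build_supported_kwargs(dispatch_old, required, optional)
new_kwargs = build_supported_kwargs(dispatch_new, required, optional)
print(sorted(old_kwargs))  # ['device_map', 'offload_dir']
print(sorted(new_kwargs))  # ['device_map', 'force_hooks', 'offload_dir', 'skip_keys']
```

Checking the signature rather than a version number keeps the call working across releases without hard-coding which release introduced each parameter.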
+
+
+def get_disk_only_shard_files(device_map, weight_map):
+    """
+    Returns the list of shard files containing only weights offloaded to disk.
+    """
+    files_content = collections.defaultdict(list)
+    for weight_name, filename in weight_map.items():
+        while len(weight_name) > 0 and weight_name not in device_map:
+            weight_name = ".".join(weight_name.split(".")[:-1])
+        files_content[filename].append(device_map[weight_name])
+
+    return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
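Review note on `get_disk_only_shard_files`: each weight name is shortened one dotted component at a time until it matches a key in `device_map`, and a shard is reported only if every weight in it resolves to `"disk"`. The example below copies the function from the diff so it runs on its own; the module names and shard file names are made-up toy data.

```python
import collections


def get_disk_only_shard_files(device_map, weight_map):
    """Copied from the diff: shard files whose weights are all offloaded to disk."""
    files_content = collections.defaultdict(list)
    for weight_name, filename in weight_map.items():
        # Walk up the module hierarchy until a device_map entry matches.
        while len(weight_name) > 0 and weight_name not in device_map:
            weight_name = ".".join(weight_name.split(".")[:-1])
        files_content[filename].append(device_map[weight_name])
    return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]


# Toy data (hypothetical names).
device_map = {"model.embed": 0, "model.layers.0": "cpu", "model.layers.1": "disk", "lm_head": "disk"}
weight_map = {
    "model.embed.weight": "shard-1.safetensors",
    "model.layers.0.mlp.weight": "shard-1.safetensors",
    "model.layers.1.mlp.weight": "shard-2.safetensors",
    "lm_head.weight": "shard-2.safetensors",
}

disk_only = get_disk_only_shard_files(device_map, weight_map)
print(disk_only)  # ['shard-2.safetensors']
```

If no prefix of a weight name matches, the loop ends on the empty string, so the lookup only succeeds when `device_map` contains the whole-model key `""`.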
+
+
+def expand_device_map(device_map, param_names):
+    """
+    Expand a device map to return the correspondence parameter name to device.
+    """
+    new_device_map = {}
+    for module, device in device_map.items():
+        new_device_map.update(
+            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+        )
+    return new_device_map
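Review note on `expand_device_map`: the match is on the exact name or on the module name followed by a dot, so `model.layers.1` does not capture parameters of `model.layers.10`. The example copies the function from the diff so it is self-contained; the parameter names are toy data.

```python
def expand_device_map(device_map, param_names):
    """Copied from the diff: map every parameter name to its module's device."""
    new_device_map = {}
    for module, device in device_map.items():
        new_device_map.update(
            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
        )
    return new_device_map


# Toy data (hypothetical names).
device_map = {"model.layers.0": 0, "model.layers.1": "disk", "lm_head": "cpu"}
param_names = [
    "model.layers.0.weight",
    "model.layers.1.weight",
    "model.layers.10.weight",  # not covered by any entry above
    "lm_head.weight",
]

expanded = expand_device_map(device_map, param_names)
print(expanded)
# {'model.layers.0.weight': 0, 'model.layers.1.weight': 'disk', 'lm_head.weight': 'cpu'}
```

Parameters without a matching entry are simply absent from the result, so a later lookup such as `param_device_map[name]` raises `KeyError` for them unless the map contains the whole-model key `""`.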
+
+
+def accelerate_disk_offload(
+    disk_offload_folder,
+    checkpoint_files,
+    device_map,
+    checkpoint_keys,
+    key_renaming_mapping,
+    sharded_metadata,
+    dtype,
+    reverse_key_renaming_mapping,
+):
+    disk_only_shard_files = []
+    if disk_offload_folder is not None:
+        os.makedirs(disk_offload_folder, exist_ok=True)
+    is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")
+    if disk_offload_folder is None and not is_offloaded_safetensors:
+        raise ValueError(
+            "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
+            " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
+            " offers the weights in this format."
+        )
+    if is_offloaded_safetensors:
+        param_device_map = expand_device_map(device_map, checkpoint_keys)
+        str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
+        if sharded_metadata is None:
+            weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
+        else:
+            folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
+            # Fix the weight map keys according to the key mapping
+            weight_map = {
+                key_renaming_mapping[k]: v
+                for k, v in sharded_metadata["weight_map"].items()
+                if k in key_renaming_mapping
+            }
+            weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
+            # Find potential checkpoints containing only offloaded weights
+            disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
+        disk_offload_index = {
+            name: {
+                "safetensors_file": file,
+                "weight_name": reverse_key_renaming_mapping[name],
+                "dtype": str_dtype,
+            }
+            for name, file in weight_map.items()
+            if param_device_map[name] == "disk"
+        }
+    else:
+        disk_offload_index = {}
+    return disk_offload_index, disk_only_shard_files
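Review note on `accelerate_disk_offload`: for safetensors checkpoints, the returned index has one entry per parameter placed on `"disk"`, pointing at the shard file, the original (pre-renaming) weight name, and the dtype as a string. The sketch below isolates that final comprehension with toy inputs. The function name `build_disk_offload_index_sketch`, the paths and the parameter names are hypothetical, and the dtype is passed as a plain string instead of being derived from a `torch.dtype`.

```python
def build_disk_offload_index_sketch(weight_map, param_device_map, reverse_key_renaming_mapping, str_dtype="float32"):
    """Illustrative version of the index-building step for disk-offloaded weights."""
    return {
        name: {
            "safetensors_file": file,
            "weight_name": reverse_key_renaming_mapping[name],
            "dtype": str_dtype,
        }
        for name, file in weight_map.items()
        if param_device_map[name] == "disk"
    }


# Toy data (hypothetical names and paths).
weight_map = {
    "model.layers.0.weight": "/ckpt/model-00001.safetensors",
    "model.layers.1.weight": "/ckpt/model-00002.safetensors",
}
param_device_map = {"model.layers.0.weight": 0, "model.layers.1.weight": "disk"}
reverse_key_renaming_mapping = {
    "model.layers.0.weight": "layers.0.weight",
    "model.layers.1.weight": "layers.1.weight",
}

index = build_disk_offload_index_sketch(weight_map, param_device_map, reverse_key_renaming_mapping)
print(index)
```

Only the disk-placed parameter appears in the index, and its `weight_name` is the name stored in the checkpoint file rather than the renamed key used by the model.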