Fix unique memory address when doing group-offloading with disk #11767
Changes from 6 commits
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import glob
+import hashlib
 import os
 from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple, Union
@@ -35,7 +37,8 @@
 _GROUP_OFFLOADING = "group_offloading"
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
 _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
-
+_GROUP_ID_LAZY_LEAF = "lazy_leafs"
+_GROUP_ID_UNMATCHED_GROUP = "top_level_unmatched_modules"
 _SUPPORTED_PYTORCH_LAYERS = (
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -62,6 +65,7 @@ def __init__(
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
         offload_to_disk_path: Optional[str] = None,
+        _group_id: Optional[str] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -80,7 +84,9 @@ def __init__(
         self._is_offloaded_to_disk = False

         if self.offload_to_disk_path:
-            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
+            self._group_id = _group_id
+            short_hash = _compute_group_hash(self._group_id)
+            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")

             all_tensors = []
             for module in self.modules:

Review comment: The passed in […]
Review comment: I don't see any edge cases, either. But having a hash is a bit more future-proof to me.
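For context on the change above: `id(self)` is only unique within a single Python process, so the old file name changed on every run and could not be matched against previously written files, whereas hashing a stable group id yields the same file name every time. A minimal sketch of the new naming scheme, using a hypothetical group id that is not taken from the diff:

```python
import hashlib
import os


def filename_for_group(offload_dir: str, group_id: str) -> str:
    # Same idea as the PR: sha256 of the group id, truncated to 16 hex chars.
    short_hash = hashlib.sha256(group_id.encode("utf-8")).hexdigest()[:16]
    return os.path.join(offload_dir, f"group_{short_hash}.safetensors")


# Deterministic across runs and processes: the same group id always maps to
# the same file, so files written by an earlier run can be found again.
print(filename_for_group("/tmp/offload", "blocks.0_to_1"))

# The old scheme used the object's memory address, which differs per process:
# f"group_{id(self)}.safetensors"
```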
@@ -603,6 +609,9 @@ def _apply_group_offloading_block_level(

         for i in range(0, len(submodule), num_blocks_per_group):
             current_modules = submodule[i : i + num_blocks_per_group]
+            start_idx = i
+            end_idx = i + len(current_modules) - 1
+            group_id = f"{name}.{start_idx}_to_{end_idx}"
             group = ModuleGroup(
                 modules=current_modules,
                 offload_device=offload_device,
@@ -615,6 +624,7 @@ def _apply_group_offloading_block_level(
                 record_stream=record_stream,
                 low_cpu_mem_usage=low_cpu_mem_usage,
                 onload_self=True,
+                _group_id=group_id,
             )
             matched_module_groups.append(group)
             for j in range(i, i + len(current_modules)):
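To make the `group_id` format used above concrete, a small hypothetical illustration (the module name and block count are made up): for a `ModuleList` registered as `blocks` with 4 entries and `num_blocks_per_group=2`, the loop produces the ids that are later hashed into file names.

```python
# Hypothetical illustration of the block-level group-id pattern.
name = "blocks"              # name of a ModuleList child
num_blocks_per_group = 2
block_count = 4

for i in range(0, block_count, num_blocks_per_group):
    end_idx = min(i + num_blocks_per_group, block_count) - 1
    print(f"{name}.{i}_to_{end_idx}")
# blocks.0_to_1
# blocks.2_to_3
```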
@@ -649,6 +659,7 @@ def _apply_group_offloading_block_level(
         stream=None,
         record_stream=False,
         onload_self=True,
+        _group_id=_GROUP_ID_UNMATCHED_GROUP,
     )
     if stream is None:
         _apply_group_offloading_hook(module, unmatched_group, None)
@@ -715,6 +726,7 @@ def _apply_group_offloading_leaf_level(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
+            _group_id=name,
         )
         _apply_group_offloading_hook(submodule, group, None)
         modules_with_group_offloading.add(name)
@@ -762,6 +774,7 @@ def _apply_group_offloading_leaf_level(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
+            _group_id=name,
         )
         _apply_group_offloading_hook(parent_module, group, None)

@@ -783,6 +796,7 @@ def _apply_group_offloading_leaf_level(
         record_stream=False,
         low_cpu_mem_usage=low_cpu_mem_usage,
         onload_self=True,
+        _group_id=_GROUP_ID_LAZY_LEAF,
     )
     _apply_lazy_group_offloading_hook(module, unmatched_group, None)

Review comment: Don't think this is needed. Lazy hook doesn't do anything with it.
Review comment: You mean to exclude […]
@@ -890,3 +904,96 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
         if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
             return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
     raise ValueError("Group offloading is not enabled for the provided module.")
+
+
+def _compute_group_hash(group_id):
+    hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
+    # first 16 characters for a reasonably short but unique name
+    return hashed_id[:16]

Review comment: Don't think we need to hash the group id strings, they should be unique already.
+
+
+def _get_expected_safetensors_files(
+    module: torch.nn.Module,
+    offload_to_disk_path: str,
+    offload_type: str,
+    num_blocks_per_group: Optional[int] = None,
+) -> Set[str]:
+    expected_files = set()
+
+    def get_hashed_filename(group_id: str) -> str:
+        short_hash = _compute_group_hash(group_id)
+        return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors")
+
+    if offload_type == "block_level":
+        if num_blocks_per_group is None:
+            raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")
+
+        # Handle groups of ModuleList and Sequential blocks
+        for name, submodule in module.named_children():
+            if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+                continue
+
+            for i in range(0, len(submodule), num_blocks_per_group):
+                current_modules = submodule[i : i + num_blocks_per_group]
+                if not current_modules:
+                    continue
+                start_idx = i
+                end_idx = i + len(current_modules) - 1
+                group_id = f"{name}.{start_idx}_to_{end_idx}"
+                expected_files.add(get_hashed_filename(group_id))
+
+        # Handle the group for unmatched top-level modules and parameters
+        group_id = _GROUP_ID_UNMATCHED_GROUP
+        expected_files.add(get_hashed_filename(group_id))
+
+    elif offload_type == "leaf_level":
+        # Handle leaf-level module groups
+        for name, submodule in module.named_modules():
+            if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
+                # These groups will always have parameters, so a file is expected
+                expected_files.add(get_hashed_filename(name))
+
+        # Handle groups for non-leaf parameters/buffers
+        modules_with_group_offloading = {
+            name for name, sm in module.named_modules() if isinstance(sm, _SUPPORTED_PYTORCH_LAYERS)
+        }
+        parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
+        buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
+
+        all_orphans = parameters + buffers
+        if all_orphans:
+            parent_to_tensors = {}
+            module_dict = dict(module.named_modules())
+            for tensor_name, _ in all_orphans:
+                parent_name = _find_parent_module_in_module_dict(tensor_name, module_dict)
+                if parent_name not in parent_to_tensors:
+                    parent_to_tensors[parent_name] = []
+                parent_to_tensors[parent_name].append(tensor_name)
+
+            for parent_name in parent_to_tensors:
+                # A file is expected for each parent that gathers orphaned tensors
+                expected_files.add(get_hashed_filename(parent_name))
+        expected_files.add(get_hashed_filename(_GROUP_ID_LAZY_LEAF))
+
+    else:
+        raise ValueError(f"Unsupported offload_type: {offload_type}")
+
+    return expected_files
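A rough usage sketch for the helper above, assuming it is importable from `diffusers.hooks.group_offloading` and using a hypothetical toy model; this only illustrates what the returned set contains and is not a recipe from the diff:

```python
import torch

from diffusers.hooks.group_offloading import _get_expected_safetensors_files  # private helper added in this PR


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_in = torch.nn.Linear(8, 8)
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])
        self.proj_out = torch.nn.Linear(8, 8)

    def forward(self, x):
        x = self.proj_in(x)
        for block in self.blocks:
            x = block(x)
        return self.proj_out(x)


expected = _get_expected_safetensors_files(
    ToyModel(),
    offload_to_disk_path="/tmp/offload",
    offload_type="block_level",
    num_blocks_per_group=2,
)
# One hashed file per block group (blocks.0_to_1, blocks.2_to_3) plus one for
# the top-level unmatched group (proj_in, proj_out, and any loose tensors).
print(sorted(expected))
```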
+
+
+def _check_safetensors_serialization(
+    module: torch.nn.Module,
+    offload_to_disk_path: str,
+    offload_type: str,
+    num_blocks_per_group: Optional[int] = None,
+) -> Tuple[bool, Optional[Set[str]], Optional[Set[str]]]:
+    if not os.path.isdir(offload_to_disk_path):
+        return False, None, None
+
+    expected_files = _get_expected_safetensors_files(module, offload_to_disk_path, offload_type, num_blocks_per_group)
+    actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors")))
+    missing_files = expected_files - actual_files
+    extra_files = actual_files - expected_files
+
+    is_correct = not missing_files and not extra_files
+    return is_correct, extra_files, missing_files
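A hedged end-to-end sketch of how the two helpers could be used together to validate the on-disk layout; the `apply_group_offloading` keyword arguments (including `offload_to_disk_path`) are assumptions based on this PR's branch, and `ToyModel` is the hypothetical model from the previous sketch:

```python
import torch

from diffusers.hooks import apply_group_offloading
from diffusers.hooks.group_offloading import _check_safetensors_serialization

offload_dir = "/tmp/group_offload"
model = ToyModel()  # hypothetical model from the previous sketch

# Assumed call for the disk-offload feature this PR fixes.
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_type="block_level",
    num_blocks_per_group=2,
    offload_to_disk_path=offload_dir,
)

# Run a forward pass so every group has been offloaded to disk at least once.
with torch.no_grad():
    model(torch.randn(1, 8, device="cuda"))

is_correct, extra_files, missing_files = _check_safetensors_serialization(
    model,
    offload_to_disk_path=offload_dir,
    offload_type="block_level",
    num_blocks_per_group=2,
)
assert is_correct, f"unexpected files: {extra_files}, missing files: {missing_files}"
```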