Fix unique memory address when doing group-offloading with disk #11767

Merged · 25 commits · Jul 9, 2025
Changes from 6 commits
111 changes: 109 additions & 2 deletions src/diffusers/hooks/group_offloading.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import hashlib
import os
from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Set, Tuple, Union
@@ -35,7 +37,8 @@
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"

_GROUP_ID_LAZY_LEAF = "lazy_leafs"
_GROUP_ID_UNMATCHED_GROUP = "top_level_unmatched_modules"
_SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -62,6 +65,7 @@ def __init__(
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
offload_to_disk_path: Optional[str] = None,
_group_id: Optional[str] = None,
) -> None:
self.modules = modules
self.offload_device = offload_device
@@ -80,7 +84,9 @@ def __init__(
self._is_offloaded_to_disk = False

if self.offload_to_disk_path:
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
self._group_id = _group_id
short_hash = _compute_group_hash(self._group_id)
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")
Collaborator
The passed-in group_id argument should be unique, no? I don't think we need to compute a hash.

Member Author
I don't see any edge cases either, but having a hash is a bit more future-proof to me.


all_tensors = []
for module in self.modules:
@@ -603,6 +609,9 @@ def _apply_group_offloading_block_level(

for i in range(0, len(submodule), num_blocks_per_group):
current_modules = submodule[i : i + num_blocks_per_group]
start_idx = i
end_idx = i + len(current_modules) - 1
group_id = f"{name}.{start_idx}_to_{end_idx}"
group = ModuleGroup(
modules=current_modules,
offload_device=offload_device,
@@ -615,6 +624,7 @@
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
_group_id=group_id,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
@@ -649,6 +659,7 @@ def _apply_group_offloading_block_level(
stream=None,
record_stream=False,
onload_self=True,
_group_id=_GROUP_ID_UNMATCHED_GROUP,
)
if stream is None:
_apply_group_offloading_hook(module, unmatched_group, None)
@@ -715,6 +726,7 @@ def _apply_group_offloading_leaf_level(
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
_group_id=name,
)
_apply_group_offloading_hook(submodule, group, None)
modules_with_group_offloading.add(name)
@@ -762,6 +774,7 @@ def _apply_group_offloading_leaf_level(
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
_group_id=name,
)
_apply_group_offloading_hook(parent_module, group, None)

@@ -783,6 +796,7 @@ def _apply_group_offloading_leaf_level(
record_stream=False,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
_group_id=_GROUP_ID_LAZY_LEAF,
Collaborator
I don't think this is needed. The lazy hook doesn't do anything with it.

Member Author
You mean to exclude group_id here? In that case, it would default to id(self), I think.

)
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
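
For context, a minimal sketch (my illustration, not part of the diff) of why a stable _group_id still helps here even though the lazy hook ignores it: the on-disk filename derived from it is deterministic, whereas a fallback keyed on id(self) changes from process to process, so the set of expected safetensors files could not be precomputed.

import hashlib

_GROUP_ID_LAZY_LEAF = "lazy_leafs"  # same constant as in this diff

def filename_for(group_id: str) -> str:
    # Same scheme as _compute_group_hash below: SHA-256 of the id, truncated to 16 hex characters.
    return f"group_{hashlib.sha256(group_id.encode('utf-8')).hexdigest()[:16]}.safetensors"

# A stable string id maps to the same filename in every run and process.
print(filename_for(_GROUP_ID_LAZY_LEAF))

# An id(self)-style fallback yields a process-dependent name that cannot be predicted ahead of time.
class Dummy:
    pass

print(f"group_{id(Dummy())}.safetensors")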

@@ -890,3 +904,96 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
raise ValueError("Group offloading is not enabled for the provided module.")


def _compute_group_hash(group_id):
Collaborator
I don't think we need to hash the group id strings; they should be unique already.

hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
# first 16 characters for a reasonably short but unique name
return hashed_id[:16]
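
As a rough illustration of the thread above (a minimal sketch, not part of the diff; the group ids are hypothetical), the hash keeps the filename short and filesystem-safe regardless of how long or deeply nested the module path in the group id is, while staying deterministic across runs:

import hashlib
import os

def compute_group_hash(group_id: str) -> str:
    # Mirrors _compute_group_hash above: SHA-256 of the group id, truncated to 16 hex characters.
    return hashlib.sha256(group_id.encode("utf-8")).hexdigest()[:16]

# Hypothetical group ids, shaped like the ones built by the block-level and leaf-level code paths.
for group_id in ["transformer_blocks.0_to_3", "transformer_blocks.4_to_7", "proj_out"]:
    filename = os.path.join("offload_dir", f"group_{compute_group_hash(group_id)}.safetensors")
    print(f"{group_id} -> {filename}")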


def _get_expected_safetensors_files(
module: torch.nn.Module,
offload_to_disk_path: str,
offload_type: str,
num_blocks_per_group: Optional[int] = None,
) -> Set[str]:
expected_files = set()

def get_hashed_filename(group_id: str) -> str:
short_hash = _compute_group_hash(group_id)
return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors")

if offload_type == "block_level":
if num_blocks_per_group is None:
raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")

# Handle groups of ModuleList and Sequential blocks
for name, submodule in module.named_children():
if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
continue

for i in range(0, len(submodule), num_blocks_per_group):
current_modules = submodule[i : i + num_blocks_per_group]
if not current_modules:
continue
start_idx = i
end_idx = i + len(current_modules) - 1
group_id = f"{name}.{start_idx}_to_{end_idx}"
expected_files.add(get_hashed_filename(group_id))

# Handle the group for unmatched top-level modules and parameters
group_id = _GROUP_ID_UNMATCHED_GROUP
expected_files.add(get_hashed_filename(group_id))

elif offload_type == "leaf_level":
# Handle leaf-level module groups
for name, submodule in module.named_modules():
if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
# These groups will always have parameters, so a file is expected
expected_files.add(get_hashed_filename(name))

# Handle groups for non-leaf parameters/buffers
modules_with_group_offloading = {
name for name, sm in module.named_modules() if isinstance(sm, _SUPPORTED_PYTORCH_LAYERS)
}
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)

all_orphans = parameters + buffers
if all_orphans:
parent_to_tensors = {}
module_dict = dict(module.named_modules())
for tensor_name, _ in all_orphans:
parent_name = _find_parent_module_in_module_dict(tensor_name, module_dict)
if parent_name not in parent_to_tensors:
parent_to_tensors[parent_name] = []
parent_to_tensors[parent_name].append(tensor_name)

for parent_name in parent_to_tensors:
# A file is expected for each parent that gathers orphaned tensors
expected_files.add(get_hashed_filename(parent_name))
expected_files.add(get_hashed_filename(_GROUP_ID_LAZY_LEAF))

else:
raise ValueError(f"Unsupported offload_type: {offload_type}")

return expected_files


def _check_safetensors_serialization(
module: torch.nn.Module,
offload_to_disk_path: str,
offload_type: str,
num_blocks_per_group: Optional[int] = None,
) -> Tuple[bool, Optional[Set[str]], Optional[Set[str]]]:
if not os.path.isdir(offload_to_disk_path):
return False, None, None

expected_files = _get_expected_safetensors_files(module, offload_to_disk_path, offload_type, num_blocks_per_group)
actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors")))
missing_files = expected_files - actual_files
extra_files = actual_files - expected_files

is_correct = not missing_files and not extra_files
return is_correct, extra_files, missing_files
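
A rough usage sketch of the new helper outside the test suite (my illustration, not part of the diff: the checkpoint path is a placeholder and the enable_group_offload keyword arguments mirror the test below; treat them as assumptions and adjust to your diffusers version):

import torch
from diffusers.hooks.group_offloading import _check_safetensors_serialization
from diffusers.models import UNet2DConditionModel

model = UNet2DConditionModel.from_pretrained("path/to/checkpoint", subfolder="unet")
model.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    offload_to_disk_path="/tmp/group_offload",
)

# Groups are serialized to disk when offloading is enabled, so the check can run immediately.
is_correct, extra_files, missing_files = _check_safetensors_serialization(
    module=model,
    offload_to_disk_path="/tmp/group_offload",
    offload_type="block_level",
    num_blocks_per_group=1,
)
if not is_correct:
    print("extra files:", extra_files, "missing files:", missing_files)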
55 changes: 47 additions & 8 deletions tests/models/test_modeling_common.py
@@ -38,8 +38,10 @@
from huggingface_hub import ModelCard, delete_repo, snapshot_download
from huggingface_hub.utils import is_jinja_available
from parameterized import parameterized
from pytest import skip
from requests.exceptions import HTTPError

from diffusers.hooks.group_offloading import _check_safetensors_serialization
from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
from diffusers.models.attention_processor import (
AttnProcessor,
@@ -1345,7 +1347,6 @@ def test_model_parallelism(self):
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
print(f" new_model.hf_device_map:{new_model.hf_device_map}")

self.check_device_map_is_respected(new_model, new_model.hf_device_map)

@@ -1694,22 +1695,43 @@ def test_group_offloading_with_layerwise_casting(self, record_stream, offload_ty
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
_ = model(**inputs_dict)[0]

@parameterized.expand([(False, "block_level"), (True, "leaf_level")])
@parameterized.expand([("block_level", False), ("leaf_level", True)])
@require_torch_accelerator
@torch.no_grad()
def test_group_offloading_with_disk(self, record_stream, offload_type):
torch.manual_seed(0)
@torch.inference_mode()
def test_group_offloading_with_disk(self, offload_type, record_stream, atol=1e-5):
def _has_generator_arg(model):
sig = inspect.signature(model.forward)
params = sig.parameters
return "generator" in params

def _run_forward(model, inputs_dict):
accepts_generator = _has_generator_arg(model)
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
torch.manual_seed(0)
return model(**inputs_dict)[0]

if self.__class__.__name__ == "AutoencoderKLCosmosTests" and offload_type == "leaf_level":
skip("With `leaf_type` as the offloading type, it fails. Needs investigation.")

init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict)

if not getattr(model, "_supports_group_offloading", True):
return

model.eval()
model.to(torch_device)
output_without_group_offloading = _run_forward(model, inputs_dict)

torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.eval()
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}

num_blocks_per_group = None if offload_type == "leaf_level" else 1
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
with tempfile.TemporaryDirectory() as tmpdir:
model.enable_group_offload(
torch_device,
@@ -1720,8 +1742,25 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
**additional_kwargs,
)
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
assert has_safetensors, "No safetensors found in the directory."
_ = model(**inputs_dict)[0]
self.assertTrue(has_safetensors, "No safetensors found in the directory.")

# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
# in nature. So, skip it.
if offload_type != "leaf_level":
is_correct, extra_files, missing_files = _check_safetensors_serialization(
module=model,
offload_to_disk_path=tmpdir,
offload_type=offload_type,
num_blocks_per_group=num_blocks_per_group,
)
if not is_correct:
if extra_files:
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
elif missing_files:
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")

output_with_group_offloading = _run_forward(model, inputs_dict)
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol))

def test_auto_model(self, expected_max_diff=5e-5):
if self.forward_requires_fresh_args: