Commit 2d3d376

sayakpaul and DN6 authored
Fix unique memory address when doing group-offloading with disk (#11767)
* fix memory address problem
* add more tests
* updates
* updates
* update
* _group_id = group_id
* update
* Apply suggestions from code review
  Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
* update
* update
* update
* fix

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent db715e2 commit 2d3d376
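
For context, here is a minimal sketch (not part of the commit; plain CPython, with a hypothetical stand-in class) of the bug being fixed: the old code named each on-disk offload file `group_{id(self)}.safetensors`, but CPython can reuse a freed object's memory address, so two different groups can end up with the same `id()` and silently overwrite each other's files. Keying the filename on a stable group ID, hashed for brevity, avoids this.

import hashlib

class Group:  # hypothetical stand-in for ModuleGroup
    pass

a = Group()
old_addr = id(a)
del a  # frees the object; CPython may hand this address out again

b = Group()
# Often True on CPython: b reuses a's address, so a filename keyed on
# id(self) would collide with (and overwrite) the earlier group's file.
print(id(b) == old_addr)

# The commit's approach: key the filename on a stable, human-readable
# group ID instead, hashed to keep names short and filesystem-safe.
def compute_group_hash(group_id: str) -> str:
    return hashlib.sha256(group_id.encode("utf-8")).hexdigest()[:16]

print(f"group_{compute_group_hash('blocks_0_0')}.safetensors")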

File tree

3 files changed: +167 additions, −9 deletions


src/diffusers/hooks/group_offloading.py

Lines changed: 19 additions & 2 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import hashlib
 import os
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
@@ -37,7 +38,7 @@
 _GROUP_OFFLOADING = "group_offloading"
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
 _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
-
+_GROUP_ID_LAZY_LEAF = "lazy_leafs"
 _SUPPORTED_PYTORCH_LAYERS = (
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -82,6 +83,7 @@ def __init__(
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
         offload_to_disk_path: Optional[str] = None,
+        group_id: Optional[int] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -100,7 +102,10 @@ def __init__(
         self._is_offloaded_to_disk = False

         if self.offload_to_disk_path:
-            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
+            # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
+            self.group_id = group_id if group_id is not None else str(id(self))
+            short_hash = _compute_group_hash(self.group_id)
+            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")

         all_tensors = []
         for module in self.modules:
@@ -609,6 +614,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf

     for i in range(0, len(submodule), config.num_blocks_per_group):
         current_modules = submodule[i : i + config.num_blocks_per_group]
+        group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
         group = ModuleGroup(
             modules=current_modules,
             offload_device=config.offload_device,
@@ -621,6 +627,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
             record_stream=config.record_stream,
             low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=group_id,
         )
         matched_module_groups.append(group)
         for j in range(i, i + len(current_modules)):
@@ -655,6 +662,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
         stream=None,
         record_stream=False,
         onload_self=True,
+        group_id=f"{module.__class__.__name__}_unmatched_group",
     )
     if config.stream is None:
         _apply_group_offloading_hook(module, unmatched_group, None, config=config)
@@ -686,6 +694,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
             record_stream=config.record_stream,
             low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=name,
         )
         _apply_group_offloading_hook(submodule, group, None, config=config)
         modules_with_group_offloading.add(name)
@@ -732,6 +741,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
             record_stream=config.record_stream,
             low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=name,
         )
         _apply_group_offloading_hook(parent_module, group, None, config=config)

@@ -753,6 +763,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
         record_stream=False,
         low_cpu_mem_usage=config.low_cpu_mem_usage,
         onload_self=True,
+        group_id=_GROUP_ID_LAZY_LEAF,
     )
     _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)

@@ -873,6 +884,12 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
     raise ValueError("Group offloading is not enabled for the provided module.")


+def _compute_group_hash(group_id):
+    hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
+    # first 16 characters for a reasonably short but unique name
+    return hashed_id[:16]
+
+
 def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
     r"""
     Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
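
As a quick sanity check on the naming scheme above, a small sketch (a standalone copy of `_compute_group_hash`; the example group IDs are hypothetical):

import hashlib

def _compute_group_hash(group_id):
    hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
    # first 16 characters for a reasonably short but unique name
    return hashed_id[:16]

# Deterministic across runs and processes, unlike id(self):
assert _compute_group_hash("blocks_0_1") == _compute_group_hash("blocks_0_1")

# Distinct group IDs yield distinct filenames. The empty string is a valid
# ID too, which is why __init__ checks `group_id is not None` instead of
# relying on truthiness.
print(f"group_{_compute_group_hash('blocks_0_1')}.safetensors")
print(f"group_{_compute_group_hash('')}.safetensors")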

src/diffusers/utils/testing_utils.py

Lines changed: 99 additions & 1 deletion
@@ -1,4 +1,5 @@
 import functools
+import glob
 import importlib
 import importlib.metadata
 import inspect
@@ -18,7 +19,7 @@
 from contextlib import contextmanager
 from io import BytesIO, StringIO
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union

 import numpy as np
 import PIL.Image
@@ -1392,6 +1393,103 @@ def get_device_properties() -> DeviceProperties:
 else:
     DevicePropertiesUserDict = UserDict

+if is_torch_available():
+    from diffusers.hooks.group_offloading import (
+        _GROUP_ID_LAZY_LEAF,
+        _SUPPORTED_PYTORCH_LAYERS,
+        _compute_group_hash,
+        _find_parent_module_in_module_dict,
+        _gather_buffers_with_no_group_offloading_parent,
+        _gather_parameters_with_no_group_offloading_parent,
+    )
+
+    def _get_expected_safetensors_files(
+        module: torch.nn.Module,
+        offload_to_disk_path: str,
+        offload_type: str,
+        num_blocks_per_group: Optional[int] = None,
+    ) -> Set[str]:
+        expected_files = set()
+
+        def get_hashed_filename(group_id: str) -> str:
+            short_hash = _compute_group_hash(group_id)
+            return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors")
+
+        if offload_type == "block_level":
+            if num_blocks_per_group is None:
+                raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")
+
+            # Handle groups of ModuleList and Sequential blocks
+            unmatched_modules = []
+            for name, submodule in module.named_children():
+                if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+                    unmatched_modules.append(module)
+                    continue
+
+                for i in range(0, len(submodule), num_blocks_per_group):
+                    current_modules = submodule[i : i + num_blocks_per_group]
+                    if not current_modules:
+                        continue
+                    group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
+                    expected_files.add(get_hashed_filename(group_id))
+
+            # Handle the group for unmatched top-level modules and parameters
+            for module in unmatched_modules:
+                expected_files.add(get_hashed_filename(f"{module.__class__.__name__}_unmatched_group"))
+
+        elif offload_type == "leaf_level":
+            # Handle leaf-level module groups
+            for name, submodule in module.named_modules():
+                if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
+                    # These groups will always have parameters, so a file is expected
+                    expected_files.add(get_hashed_filename(name))
+
+            # Handle groups for non-leaf parameters/buffers
+            modules_with_group_offloading = {
+                name for name, sm in module.named_modules() if isinstance(sm, _SUPPORTED_PYTORCH_LAYERS)
+            }
+            parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
+            buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
+
+            all_orphans = parameters + buffers
+            if all_orphans:
+                parent_to_tensors = {}
+                module_dict = dict(module.named_modules())
+                for tensor_name, _ in all_orphans:
+                    parent_name = _find_parent_module_in_module_dict(tensor_name, module_dict)
+                    if parent_name not in parent_to_tensors:
+                        parent_to_tensors[parent_name] = []
+                    parent_to_tensors[parent_name].append(tensor_name)
+
+                for parent_name in parent_to_tensors:
+                    # A file is expected for each parent that gathers orphaned tensors
+                    expected_files.add(get_hashed_filename(parent_name))
+            expected_files.add(get_hashed_filename(_GROUP_ID_LAZY_LEAF))
+
+        else:
+            raise ValueError(f"Unsupported offload_type: {offload_type}")
+
+        return expected_files
+
+    def _check_safetensors_serialization(
+        module: torch.nn.Module,
+        offload_to_disk_path: str,
+        offload_type: str,
+        num_blocks_per_group: Optional[int] = None,
+    ) -> bool:
+        if not os.path.isdir(offload_to_disk_path):
+            return False, None, None
+
+        expected_files = _get_expected_safetensors_files(
+            module, offload_to_disk_path, offload_type, num_blocks_per_group
+        )
+        actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors")))
+        missing_files = expected_files - actual_files
+        extra_files = actual_files - expected_files
+
+        is_correct = not missing_files and not extra_files
+        return is_correct, extra_files, missing_files

 class Expectations(DevicePropertiesUserDict):
     def get_expectation(self) -> Any:
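
A sketch of how a caller might consume the `(is_correct, extra_files, missing_files)` triple returned by the helper above (`model` and `tmpdir` are placeholders, not part of the commit):

# Hypothetical usage of _check_safetensors_serialization from above.
is_correct, extra_files, missing_files = _check_safetensors_serialization(
    module=model,
    offload_to_disk_path=tmpdir,
    offload_type="block_level",
    num_blocks_per_group=1,
)
if not is_correct:
    # extra_files/missing_files are sets of paths, or None when the
    # offload directory did not exist at all.
    raise AssertionError(
        f"extra={sorted(extra_files or [])}, missing={sorted(missing_files or [])}"
    )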

tests/models/test_modeling_common.py

Lines changed: 49 additions & 6 deletions
@@ -61,6 +61,7 @@
 from diffusers.utils.hub_utils import _add_variant
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    _check_safetensors_serialization,
     backend_empty_cache,
     backend_max_memory_allocated,
     backend_reset_peak_memory_stats,
@@ -1702,18 +1703,43 @@ def test_group_offloading_with_layerwise_casting(self, record_stream, offload_ty
         model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
         _ = model(**inputs_dict)[0]

-    @parameterized.expand([(False, "block_level"), (True, "leaf_level")])
+    @parameterized.expand([("block_level", False), ("leaf_level", True)])
     @require_torch_accelerator
     @torch.no_grad()
-    def test_group_offloading_with_disk(self, record_stream, offload_type):
+    @torch.inference_mode()
+    def test_group_offloading_with_disk(self, offload_type, record_stream, atol=1e-5):
         if not self.model_class._supports_group_offloading:
             pytest.skip("Model does not support group offloading.")

-        torch.manual_seed(0)
+        def _has_generator_arg(model):
+            sig = inspect.signature(model.forward)
+            params = sig.parameters
+            return "generator" in params
+
+        def _run_forward(model, inputs_dict):
+            accepts_generator = _has_generator_arg(model)
+            if accepts_generator:
+                inputs_dict["generator"] = torch.manual_seed(0)
+            torch.manual_seed(0)
+            return model(**inputs_dict)[0]
+
+        if self.__class__.__name__ == "AutoencoderKLCosmosTests" and offload_type == "leaf_level":
+            pytest.skip("With `leaf_type` as the offloading type, it fails. Needs investigation.")
+
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        torch.manual_seed(0)
         model = self.model_class(**init_dict)
+
         model.eval()
-        additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
+        model.to(torch_device)
+        output_without_group_offloading = _run_forward(model, inputs_dict)
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.eval()
+
+        num_blocks_per_group = None if offload_type == "leaf_level" else 1
+        additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
         with tempfile.TemporaryDirectory() as tmpdir:
             model.enable_group_offload(
                 torch_device,
@@ -1724,8 +1750,25 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
                 **additional_kwargs,
             )
             has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
-            self.assertTrue(len(has_safetensors) > 0, "No safetensors found in the offload directory.")
-            _ = model(**inputs_dict)[0]
+            self.assertTrue(has_safetensors, "No safetensors found in the directory.")
+
+            # For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
+            # in nature. So, skip it.
+            if offload_type != "leaf_level":
+                is_correct, extra_files, missing_files = _check_safetensors_serialization(
+                    module=model,
+                    offload_to_disk_path=tmpdir,
+                    offload_type=offload_type,
+                    num_blocks_per_group=num_blocks_per_group,
+                )
+                if not is_correct:
+                    if extra_files:
+                        raise ValueError(f"Found extra files: {', '.join(extra_files)}")
+                    elif missing_files:
+                        raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
+
+            output_with_group_offloading = _run_forward(model, inputs_dict)
+            self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol))

     def test_auto_model(self, expected_max_diff=5e-5):
         if self.forward_requires_fresh_args:
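
For reference, a hedged sketch of the user-facing flow this test exercises (keyword names assumed from the call in the diff and diffusers' group-offloading docs; `model` and `inputs_dict` are placeholders):

import tempfile
import torch

with tempfile.TemporaryDirectory() as tmpdir:
    model.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="block_level",
        num_blocks_per_group=1,
        offload_to_disk_path=tmpdir,  # each group saved as group_<hash>.safetensors
    )
    # Parameters stream in from disk group by group during the forward pass.
    out = model(**inputs_dict)[0]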
