src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py (2 additions & 0 deletions)
@@ -673,6 +673,8 @@ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"""

    _supports_gradient_checkpointing = False
    # Allow string device_map strategies; an empty list means no submodule needs to stay on a single device.
    _no_split_modules: List[str] = []

    # fmt: off
    @register_to_config
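For context, a brief usage sketch (not part of the diff) of what defining `_no_split_modules` unlocks: once the attribute exists, even as an empty list, string `device_map` values can be dispatched through Accelerate for this model. The repo id and subfolder below are placeholders.

# Usage sketch (assumes this PR); repo id and subfolder are placeholders.
import torch
from diffusers import AutoencoderKLQwenImage

vae = AutoencoderKLQwenImage.from_pretrained(
    "some-org/some-qwen-image-checkpoint",  # placeholder repo id
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    device_map="auto",  # accepted now that _no_split_modules is defined; any submodule may be placed independently
)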
src/diffusers/models/model_loading_utils.py (58 additions & 3 deletions)
@@ -95,20 +95,75 @@ def _determine_device_map(
"`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
)

-        if device_map != "sequential":
+        # Support several strategies:
+        # - "sequential": respect user-provided max_memory (or detect it) without balancing
+        # - "balanced", "balanced_low_0", "auto": balance memory across devices via Accelerate
+        if device_map == "sequential":
+            max_memory = get_max_memory(max_memory)
+        else:
+            # Includes "balanced", "balanced_low_0", and "auto"
            max_memory = get_balanced_memory(
                model,
                dtype=torch_dtype,
                low_zero=(device_map == "balanced_low_0"),
                max_memory=max_memory,
                **device_map_kwargs,
            )
-        else:
-            max_memory = get_max_memory(max_memory)

        if hf_quantizer is not None:
            max_memory = hf_quantizer.adjust_max_memory(max_memory)

        # Align with Transformers: add currently unused reserved memory on accelerators,
        # then optionally apply a safety headroom.
        try:
            inferred = dict(max_memory) if max_memory is not None else {}
            for device_name in list(inferred.keys()):
                # Only consider accelerator devices; skip CPU.
                if device_name == "cpu":
                    continue
                unused_mem = 0
                # device_name could be an int (preferred) or a string key.
                dev_index = None
                if isinstance(device_name, int):
                    dev_index = device_name
                elif isinstance(device_name, str):
                    # Parse patterns like "cuda:0" or "xpu:0".
                    if ":" in device_name:
                        try:
                            dev_index = int(device_name.split(":", 1)[1])
                        except Exception:
                            dev_index = None
                # Query backend-specific reserved/allocated memory.
                try:
                    if hasattr(torch, "xpu") and torch.xpu.is_available():
                        if dev_index is not None:
                            unused_mem = torch.xpu.memory_reserved(dev_index) - torch.xpu.memory_allocated(dev_index)
                    elif torch.cuda.is_available():
                        if dev_index is not None:
                            unused_mem = torch.cuda.memory_reserved(dev_index) - torch.cuda.memory_allocated(dev_index)
                except Exception:
                    unused_mem = 0

                if unused_mem and unused_mem > 0:
                    inferred[device_name] = inferred.get(device_name, 0) + unused_mem

                # Respect an explicit user cap if provided with the same key.
                if max_memory is not None and device_name in max_memory:
                    inferred[device_name] = min(inferred[device_name], max_memory[device_name])

            # Apply a slightly lower occupancy for "auto" to reduce the risk of OOMs after loading.
            if device_map == "auto":
                for k in list(inferred.keys()):
                    if k != "cpu":
                        try:
                            inferred[k] = int(inferred[k] * 0.85)
                        except Exception:
                            pass

            max_memory = inferred
        except Exception:
            # If any backend call fails, proceed with the baseline max_memory.
            pass

        device_map_kwargs["max_memory"] = max_memory
        device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)

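A minimal standalone sketch of the headroom idea in the block above (plain PyTorch, CUDA assumed; not the diffusers API): memory the caching allocator has reserved but is not currently using is added back to a device's budget, and the "auto" path then keeps 85% of the result as a safety margin. The `budget` values are illustrative.

import torch

def unused_reserved_bytes(device_index: int) -> int:
    # Bytes the CUDA caching allocator holds but has not handed out to tensors.
    if not torch.cuda.is_available():
        return 0
    return torch.cuda.memory_reserved(device_index) - torch.cuda.memory_allocated(device_index)

# Illustrative per-device budget (bytes), widened then scaled like the "auto" path above.
budget = {0: 20 * 1024**3, "cpu": 64 * 1024**3}
budget[0] = int((budget[0] + unused_reserved_bytes(0)) * 0.85)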
src/diffusers/pipelines/pipeline_loading_utils.py (19 additions & 1 deletion)
@@ -613,7 +613,8 @@ def _assign_components_to_devices(


def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
-    # TODO: seperate out different device_map methods when it gets to it.
+    # Only implement pipeline-level component balancing for "balanced".
+    # For other strategies, return as-is so sub-models (Transformers/Diffusers) can perform intra-module sharding.
    if device_map != "balanced":
        return device_map
    # To avoid circular import problem.
@@ -801,6 +802,7 @@ def load_sub_model(
    # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
    # This makes sure that the weights won't be initialized which significantly speeds up loading.
    if is_diffusers_model or is_transformers_model:
        # By default, forward device placement hints to the sub-model loader.
        loading_kwargs["device_map"] = device_map
        loading_kwargs["max_memory"] = max_memory
        loading_kwargs["offload_folder"] = offload_folder

@@ -830,6 +832,22 @@ def load_sub_model(
        else:
            loading_kwargs["low_cpu_mem_usage"] = False

    # Translate Diffusers-specific strategies for Transformers models:
    # Transformers only accepts 'auto', 'balanced', 'balanced_low_0', 'sequential', device names, or dicts.

    # Diffusers models that don't implement `_no_split_modules` cannot accept string strategies
    # (auto/balanced/etc.) for intra-module sharding. In that case, place the whole module on a single device.
    if is_diffusers_model and isinstance(device_map, str) and device_map in {"auto", "balanced", "sequential"}:
        try:
            # class_obj is the class, not an instance.
            supports_sharding = getattr(class_obj, "_no_split_modules", None) is not None
        except Exception:
            supports_sharding = False
        if not supports_sharding:
            # Prefer the primary accelerator if available; otherwise fall back to CPU.
            preferred = 0 if (hasattr(torch, "cuda") and torch.cuda.is_available()) else (
                0 if (hasattr(torch, "xpu") and torch.xpu.is_available()) else "cpu"
            )
            loading_kwargs["device_map"] = {"": preferred}

    if (
        quantization_config is not None
        and isinstance(quantization_config, PipelineQuantizationConfig)
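For reference, a small sketch of the single-device fallback convention used above (Accelerate semantics, not diffusers-specific): a `device_map` whose only key is the empty string places the entire module on that device. The toy model is illustrative.

import torch
from accelerate import dispatch_model

toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 8))
device = 0 if torch.cuda.is_available() else "cpu"
toy = dispatch_model(toy, device_map={"": device})  # the whole module lands on `device`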
src/diffusers/pipelines/pipeline_utils.py (7 additions & 1 deletion)
@@ -108,7 +108,13 @@
for library in LOADABLE_CLASSES:
    LIBRARIES.append(library)

-SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]
+# Supported high-level device placement strategies for pipeline loading.
+# Strings here are forwarded to sub-model loaders which leverage Accelerate for sharding/offload.
+SUPPORTED_DEVICE_MAP = [
+    "balanced",  # pipeline-level component balancing (current behavior)
+    "sequential",  # fill devices in order (Accelerate semantics)
+    "auto",  # Accelerate best-effort automatic mapping
+] + [get_device()]

logger = logging.get_logger(__name__)

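Finally, a hedged end-to-end sketch of what the expanded `SUPPORTED_DEVICE_MAP` allows at the pipeline level; the repo id and memory caps are placeholders, and actual placement depends on Accelerate and each sub-model's `_no_split_modules`.

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "some-org/some-pipeline",  # placeholder repo id
    torch_dtype=torch.float16,
    device_map="auto",  # "balanced" and "sequential" are also accepted with this change
    max_memory={0: "10GiB", "cpu": "30GiB"},  # optional per-device caps (illustrative)
)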