diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
index 87ac40659212..93cfa37ed0b7 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
@@ -673,6 +673,8 @@ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     """
 
     _supports_gradient_checkpointing = False
+    # Declaring `_no_split_modules` (empty: nothing must stay on one device) enables string `device_map` strategies.
+    _no_split_modules: List[str] = []
 
     # fmt: off
     @register_to_config
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 2e07f55e0064..44aa391db93f 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -95,7 +95,13 @@ def _determine_device_map(
                 "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
            )
 
-        if device_map != "sequential":
+        user_max_memory = max_memory  # keep the caller-provided caps; `max_memory` is overwritten below
+        # Support several strategies:
+        # - "sequential": respect user-provided max_memory (or detect it) without balancing
+        if device_map == "sequential":
+            max_memory = get_max_memory(max_memory)
+        else:
+            # Includes "balanced", "balanced_low_0", and "auto".
             max_memory = get_balanced_memory(
                 model,
                 dtype=torch_dtype,
@@ -103,12 +109,62 @@ def _determine_device_map(
                 max_memory=max_memory,
                 **device_map_kwargs,
             )
-        else:
-            max_memory = get_max_memory(max_memory)
 
         if hf_quantizer is not None:
             max_memory = hf_quantizer.adjust_max_memory(max_memory)
 
+        # Align with Transformers: add currently unused reserved memory on accelerators,
+        # then optionally apply a safety headroom.
+        try:
+            inferred = dict(max_memory) if max_memory is not None else {}
+            for device_name in list(inferred.keys()):
+                # Only consider accelerator devices; skip CPU.
+                if device_name == "cpu":
+                    continue
+                unused_mem = 0
+                # `device_name` can be an int (preferred) or a string key.
+                dev_index = None
+                if isinstance(device_name, int):
+                    dev_index = device_name
+                elif isinstance(device_name, str):
+                    # Parse patterns like "cuda:0" or "xpu:0".
+                    if ":" in device_name:
+                        try:
+                            dev_index = int(device_name.split(":", 1)[1])
+                        except Exception:
+                            dev_index = None
+                # Query backend-specific reserved/allocated memory; the difference is the unused cache.
+                try:
+                    if hasattr(torch, "xpu") and torch.xpu.is_available():
+                        if dev_index is not None:
+                            unused_mem = torch.xpu.memory_reserved(dev_index) - torch.xpu.memory_allocated(dev_index)
+                    elif torch.cuda.is_available():
+                        if dev_index is not None:
+                            unused_mem = torch.cuda.memory_reserved(dev_index) - torch.cuda.memory_allocated(dev_index)
+                except Exception:
+                    unused_mem = 0
+
+                if unused_mem and unused_mem > 0:
+                    inferred[device_name] = inferred.get(device_name, 0) + unused_mem
+
+                # Respect an explicit, numeric user-provided cap for the same key.
+                if isinstance(user_max_memory, dict) and isinstance(user_max_memory.get(device_name), (int, float)):
+                    inferred[device_name] = min(inferred[device_name], user_max_memory[device_name])
+
+            # Apply a slightly safer occupancy for "auto" to reduce the risk of OOMs after the model is dispatched.
+            if device_map == "auto":
+                for k in list(inferred.keys()):
+                    if k != "cpu":
+                        try:
+                            inferred[k] = int(inferred[k] * 0.85)
+                        except Exception:
+                            pass
+
+            max_memory = inferred
+        except Exception:
+            # If any backend call fails, proceed with the baseline max_memory.
+            pass
+
         device_map_kwargs["max_memory"] = max_memory
         device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
 
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 2c611aa2c033..4cebf7e79011 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -613,7 +613,8 @@ def _assign_components_to_devices(
 
 
 def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
-    # TODO: seperate out different device_map methods when it gets to it.
+    # Only implement pipeline-level component balancing for "balanced".
+    # For other strategies, return the value as-is so sub-models (Transformers/Diffusers) can shard internally.
     if device_map != "balanced":
         return device_map
     # To avoid circular import problem.
@@ -801,6 +802,7 @@ def load_sub_model(
     # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
     # This makes sure that the weights won't be initialized which significantly speeds up loading.
     if is_diffusers_model or is_transformers_model:
+        # Forward placement hints to the sub-model loader by default.
         loading_kwargs["device_map"] = device_map
         loading_kwargs["max_memory"] = max_memory
         loading_kwargs["offload_folder"] = offload_folder
@@ -830,6 +832,22 @@ def load_sub_model(
         else:
             loading_kwargs["low_cpu_mem_usage"] = False
 
+    # Transformers models accept the string strategies ("auto", "balanced", "balanced_low_0", "sequential"),
+    # device names, or explicit dicts, so the value set above can be forwarded to them as-is.
+
+    # Diffusers models that don't implement `_no_split_modules` cannot honor string strategies
+    # (auto/balanced/etc.) for intra-module sharding; in that case, place the whole module on a single device.
+    if is_diffusers_model and isinstance(device_map, str) and device_map in {"auto", "balanced", "sequential"}:
+        try:
+            # `class_obj` is the class, not an instance.
+            supports_sharding = getattr(class_obj, "_no_split_modules", None) is not None
+        except Exception:
+            supports_sharding = False
+        if not supports_sharding:
+            # Prefer the primary accelerator if available; otherwise fall back to CPU.
+            preferred = 0 if (torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available())) else "cpu"
+            loading_kwargs["device_map"] = {"": preferred}
+
     if (
         quantization_config is not None
         and isinstance(quantization_config, PipelineQuantizationConfig)
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index d231989973e4..12a558aed825 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -108,7 +108,13 @@ for library in LOADABLE_CLASSES:
     LIBRARIES.append(library)
 
-SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]
+# Supported high-level device placement strategies for pipeline loading.
+# Strings here are forwarded to the sub-model loaders, which leverage Accelerate for sharding/offload.
+SUPPORTED_DEVICE_MAP = [
+    "balanced",  # pipeline-level component balancing (current behavior)
+    "sequential",  # fill devices in order (Accelerate semantics)
+    "auto",  # Accelerate best-effort automatic mapping
+] + [get_device()]
 
 logger = logging.get_logger(__name__)
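
Usage sketch: with the strategies above wired through, pipeline- and model-level loading could look roughly like
the snippet below. The checkpoint name, dtype, and memory caps are illustrative assumptions rather than values
taken from this change; `device_map`, `max_memory`, and `subfolder` are existing `from_pretrained` arguments.

    import torch
    from diffusers import AutoencoderKLQwenImage, DiffusionPipeline

    # Pipeline-level: "auto" (like "sequential") is forwarded to the sub-model loaders,
    # which rely on Accelerate (get_balanced_memory / infer_auto_device_map) for placement.
    pipe = DiffusionPipeline.from_pretrained(
        "Qwen/Qwen-Image",  # illustrative checkpoint
        torch_dtype=torch.bfloat16,
        device_map="auto",
        max_memory={0: "20GiB", "cpu": "64GiB"},  # optional per-device caps, respected as before
    )

    # Model-level: with `_no_split_modules` declared, the VAE itself can accept a string strategy.
    vae = AutoencoderKLQwenImage.from_pretrained(
        "Qwen/Qwen-Image",  # illustrative checkpoint
        subfolder="vae",
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )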