From b9f5145f4d02d86b131ea599fd9785356cb86b01 Mon Sep 17 00:00:00 2001
From: loxotron
Date: Thu, 15 May 2025 06:54:33 +0300
Subject: [PATCH 1/6] Add fast_sampler.py with optimized sampling and VAE decoding, enhance PreviewImage
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This commit introduces `fast_sampler.py`, a new module that speeds up sampling and VAE decoding in ComfyUI. It replaces or augments functionality previously handled in `model_management.py`, providing better VRAM management, FP16 support, and tiled decoding for low-memory scenarios. It also updates the `PreviewImage` node in `nodes.py` for faster, lighter preview generation. Together these changes improve efficiency, stability, and usability, particularly for GPU-based workflows.

**Key Changes:**
- Implemented `fast_ksampler` for optimized sampling with improved memory management, FP16 support via `torch.amp.autocast`, and the `channels_last` memory format for better GPU performance.
- Added `fast_vae_decode` for efficient VAE decoding, with FP16 support, `channels_last`, and selective VRAM clearing to prevent out-of-memory errors.
- Introduced `fast_vae_tiled_decode` for tiled VAE decoding, enabling large latents to be processed on GPUs with limited VRAM via configurable tile sizes and overlaps.
- Added profiling and debugging utilities (`profile_section`, `profile_cuda_sync`) that track execution times and VRAM usage when the `--profile` or `--debug` flags are enabled.
- Improved VRAM management with `clear_vram`, which ensures sufficient free memory before loading models or the VAE, with configurable thresholds and minimum free-memory requirements.
- Implemented `is_fp16_safe` to check GPU compatibility for FP16 operations, disabling them on unsupported hardware (e.g., GTX 1660/Turing).
- Optimized tensor transfers with `optimized_transfer` and `optimized_conditioning` for synchronous device placement and dtype casting.
- Enhanced model preloading with `preload_model`, which unloads the VAE before loading the U-Net to conserve VRAM and skips redundant transfers when the VAE is already loaded.
- Integrated `cudnn.benchmark` for testing; disabled by default.
- Gated VRAM cache clears on free-memory thresholds instead of running them unconditionally, reducing redundant cache flushes.
- Updated the `PreviewImage` node in `nodes.py` to adaptively resize previews to a maximum dimension of ~512 pixels while preserving aspect ratio, using `Image.LANCZOS` for quality, and raised `compress_level` from 1 to 4 to produce smaller preview PNGs.

**Impact:**
- Significantly reduces VRAM usage during sampling and VAE decoding, making ComfyUI more stable on GPUs with limited memory.
- Improves performance for large-scale image generation through tiled decoding and FP16 optimizations.
- Enhances debugging with detailed profiling and logging, aiding development and optimization.

**Dependencies:**
- Relies on `nodes.py` for integration with the `KSampler`, `VAEDecode`, `VAEDecodeTiled`, and `PreviewImage` nodes.
- Assumes compatibility with existing `ModelPatcher` functionality for model patching (e.g., in `LoraLoader`).

**Notes:**
- Users should enable the `--profile` or `--debug` flag to access detailed performance logs.
- FP16 support requires compatible GPU hardware (compute capability above 7.0, ideally 8.0 or higher); `is_fp16_safe` disables FP16 on known-problematic cards such as the GTX 1660 (Turing).
- Tiled decoding parameters (`tile_size`, `overlap`, etc.) may need tuning for specific workflows; see the illustrative sketch below.
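  As a rough sketch of the overlapping-tile idea only (the real logic lives in `fast_vae_tiled_decode` in `fast_sampler.py` and its parameters may differ):

      # Illustrative: split one latent dimension into overlapping tiles.
      # Not the actual fast_sampler implementation.
      def tile_ranges(size: int, tile: int, overlap: int):
          step = max(tile - overlap, 1)
          starts = list(range(0, max(size - tile, 0) + 1, step))
          if starts[-1] + tile < size:  # make sure the tail of the dimension is covered
              starts.append(size - tile)
          return [(s, min(s + tile, size)) for s in starts]

      print(tile_ranges(96, 64, 16))  # -> [(0, 64), (32, 96)]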
- Preview images are now smaller and faster to generate, but users can adjust `max_size` in `PreviewImage` if higher resolution previews are needed. This is a foundational change to improve ComfyUI's performance and scalability, particularly for resource-constrained environments. Thanks to Grok @ xAI for help. --- comfy/cli_args.py | 54 +- comfy/model_management.py | 2499 ++++++++++++++++++++++--------------- comfy/model_patcher.py | 187 ++- comfy/ops.py | 14 +- comfy/utils.py | 20 + fast_sampler.py | 477 +++++++ folder_paths.py | 8 +- main.py | 4 + nodes.py | 188 ++- 9 files changed, 2358 insertions(+), 1093 deletions(-) create mode 100644 fast_sampler.py diff --git a/comfy/cli_args.py b/comfy/cli_args.py index de292d9b323..ee50f4622f5 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -1,3 +1,8 @@ +""" +This file is part of ComfyUI. +Copyright (C) 2024 Comfy +""" + import argparse import enum import os @@ -11,7 +16,7 @@ class EnumAction(argparse.Action): def __init__(self, **kwargs): # Pop off the type value enum_type = kwargs.pop("type", None) - + # Ensure an Enum subclass is provided if enum_type is None: raise ValueError("type must be assigned an Enum when using EnumAction") @@ -22,9 +27,7 @@ def __init__(self, **kwargs): choices = tuple(e.value for e in enum_type) kwargs.setdefault("choices", choices) kwargs.setdefault("metavar", f"[{','.join(list(choices))}]") - super(EnumAction, self).__init__(**kwargs) - self._enum = enum_type def __call__(self, parser, namespace, values, option_string=None): @@ -35,6 +38,8 @@ def __call__(self, parser, namespace, values, option_string=None): parser = argparse.ArgumentParser() +parser.add_argument("--debug", action="store_true", help="Enable debug logging.") +parser.add_argument("--profile", action="store_true", help="Enable profiling.") parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)") parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... 
requires --tls-certfile to function") @@ -53,34 +58,37 @@ def __call__(self, parser, namespace, values, option_string=None): cm_group = parser.add_mutually_exclusive_group() cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.") - +cm_group.add_argument("--model-dtype", type=str, choices=["fp16", "bf16", "fp32"], help="Force model data type (fp16, bf16, fp32)") fp_group = parser.add_mutually_exclusive_group() fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") +fp_group.add_argument("--force-fp32-vae", action="store_true", help="Force VAE to use FP32 precision") fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.") +fp_group.add_argument("--force-fp16-vae", action="store_true", help="Force VAE to use FP16 precision") fpunet_group = parser.add_mutually_exclusive_group() fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.") -fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.") -fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.") -fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16") -fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.") -fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.") -fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.") +fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64 (not recommended).") +fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16 (requires SM >= 8.0).") +fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16 (may cause black images on VRAM < 6 GB).") +fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Run UNet in fp8 e4m3fn (requires SM >= 9.0 or SM 8.9 with PyTorch >= 2.3).") +fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Run UNet in fp8 e5m2 (requires SM >= 9.0 or SM 8.9 with PyTorch >= 2.3).") +fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Run UNet in fp8 e8m0fnu (requires SM >= 9.0 or SM 8.9 with PyTorch >= 2.3).") # UPDATED: Clarified requirements fpvae_group = parser.add_mutually_exclusive_group() -fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.") -fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.") -fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.") +fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16 (risks black images on VRAM < 6 GB).") +fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in fp32 (recommended for GPUs with VRAM < 6 GB).") +fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16 (requires SM >= 8.0).") -parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.") +parser.add_argument("--cpu-vae", 
action="store_true", help="Run the VAE on the CPU (slower, but safe for low VRAM).") fpte_group = parser.add_mutually_exclusive_group() -fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).") -fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).") -fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.") -fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.") -fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.") +fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Run text encoder in fp8 e4m3fn (requires SM >= 9.0 or SM 8.9 with PyTorch >= 2.3).") +fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Run text encoder in fp8 e5m2 (requires SM >= 9.0 or SM 8.9 with PyTorch >= 2.3).") +fpte_group.add_argument("--fp8_e8m0fnu-text-enc", action="store_true", help="Run text encoder in fp8 e8m0fnu (requires SM >= 9.0 or SM 8.9 with PyTorch >= 2.3).") # NEW +fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Run text encoder in fp16 (may cause issues on VRAM < 6 GB).") +fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Run text encoder in fp32 (recommended for GPUs with VRAM < 6 GB).") +fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Run text encoder in bf16 (requires SM >= 8.0).") parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.") @@ -96,7 +104,6 @@ class LatentPreviewMethod(enum.Enum): TAESD = "taesd" parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction) - parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.") cache_group = parser.add_mutually_exclusive_group() @@ -117,7 +124,6 @@ class LatentPreviewMethod(enum.Enum): upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.") upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.") - vram_group = parser.add_mutually_exclusive_group() vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).") vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") @@ -199,7 +205,7 @@ def is_valid_directory(path: str) -> str: "--comfy-api-base", type=str, default="https://api.comfy.org", - help="Set the base URL for the ComfyUI API. 
(default: https://api.comfy.org)", + help="Set the base URL for the ComfyUI API (default: https://api.comfy.org).", ) if comfy.options.args_parsing: @@ -215,6 +221,8 @@ def is_valid_directory(path: str) -> str: if args.force_fp16: args.fp16_unet = True + args.fp16_vae = True + args.fp16_text_enc = True # '--fast' is not provided, use an empty set @@ -225,4 +233,4 @@ def is_valid_directory(path: str) -> str: args.fast = set(PerformanceFeature) # '--fast' is provided with a list of performance features, use that list else: - args.fast = set(args.fast) + args.fast = set(args.fast) \ No newline at end of file diff --git a/comfy/model_management.py b/comfy/model_management.py index 44aff37625c..5f19e4b01c7 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,187 +1,322 @@ """ - This file is part of ComfyUI. - Copyright (C) 2024 Comfy - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . +This file is part of ComfyUI. +Copyright (C) 2024 Comfy +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. +You should have received a copy of the GNU General Public License +along with this program. If not, see . """ - import psutil import logging -from enum import Enum -from comfy.cli_args import args, PerformanceFeature import torch import sys import platform +import contextlib import weakref +import time import gc +import re +import threading +import traceback +from enum import Enum +from comfy.cli_args import args, PerformanceFeature +from comfy.ldm.models.autoencoder import AutoencoderKL -class VRAMState(Enum): - DISABLED = 0 #No vram present: no need to move models to vram - NO_VRAM = 1 #Very low vram: enable all the options to save vram - LOW_VRAM = 2 - NORMAL_VRAM = 3 - HIGH_VRAM = 4 - SHARED = 5 #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both. 
+try: + import torch_directml + _torch_directml_available = True +except ImportError: + _torch_directml_available = False + +def log_vram_state(device=None): + if not DEBUG_ENABLED: + return + if device is None: + device = get_torch_device() + free_vram, free_torch = get_free_memory(device, torch_free_too=True) + active_models = [(m.model.__class__.__name__, m.model_memory_required(device) / 1024**3) + for m in current_loaded_models if m.device == device] + logging.debug( + f"VRAM state: free_vram={free_vram / 1024**3:.2f} GB, free_torch={free_torch / 1024**3:.2f} GB, models={active_models}") class CPUState(Enum): GPU = 0 CPU = 1 MPS = 2 -# Determine VRAM State +cpu_state = CPUState.GPU # Default to GPU + +# Global flags +PROFILING_ENABLED = args.profile +DEBUG_ENABLED = args.debug +VERBOSE_ENABLED = False + +# Configure logging +logging.basicConfig(level=logging.DEBUG if args.debug or args.profile else logging.INFO) + +# Cache for device and dtype checks +_device_cache = {} + +# VRAM optimizers for extensibility +_vram_optimizers = [] + +class VRAMState(Enum): + DISABLED = 0 # No VRAM: models stay on CPU + NO_VRAM = 1 # Very low VRAM: maximum memory saving + LOW_VRAM = 2 # Low VRAM: partial model loading + NORMAL_VRAM = 3 # Default: balanced memory management + HIGH_VRAM = 4 # High VRAM: keep models in VRAM + SHARED = 5 # Shared CPU/GPU memory (e.g., MPS) + +# Global state vram_state = VRAMState.NORMAL_VRAM set_vram_to = VRAMState.NORMAL_VRAM cpu_state = CPUState.GPU - total_vram = 0 +total_ram = psutil.virtual_memory().total / (1024 * 1024) + +# Cache for DirectML VRAM +_directml_vram_cache = {} + +# Cache for active models memory in DirectML +_directml_active_memory_cache = {} + +def cpu_mode(): + """Check if system is in CPU mode.""" + global cpu_state + return cpu_state == CPUState.CPU + +def mps_mode(): + """Check if system is in MPS (Apple Metal) mode.""" + global cpu_state + return cpu_state == CPUState.MPS + +def is_device_cpu(device): + return is_device_type(device, 'cpu') + +def is_device_mps(device): + return is_device_type(device, 'mps') + +def is_device_cuda(device): + return is_device_type(device, 'cuda') + +def is_directml_enabled(): + global directml_enabled + if directml_enabled: + return True + + return False def get_supported_float8_types(): + """Get supported float8 data types.""" float8_types = [] - try: - float8_types.append(torch.float8_e4m3fn) - except: - pass - try: - float8_types.append(torch.float8_e4m3fnuz) - except: - pass - try: - float8_types.append(torch.float8_e5m2) - except: - pass - try: - float8_types.append(torch.float8_e5m2fnuz) - except: - pass - try: - float8_types.append(torch.float8_e8m0fnu) - except: - pass + for dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz, torch.float8_e8m0fnu]: + try: + float8_types.append(dtype) + except: + pass return float8_types +def get_directml_vram(dev): + """ + Estimate VRAM for DirectML device, trying CUDA first, then heuristic, then fallback. + + Args: + dev: Torch device (DirectML). + + Returns: + int: Estimated VRAM in bytes. 
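+    Example (illustrative): a DirectML adapter whose name contains "rtx 3060"
+    matches the heuristic table below and returns 12 * 1024**3 bytes (12 GiB);
+    unknown adapters fall back to the 6 GiB default.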
+ """ + if dev in _directml_vram_cache: + return _directml_vram_cache[dev] + + # Use args.reserve_vram if provided + if args.reserve_vram is not None: + vram = int(args.reserve_vram * 1024 * 1024 * 1024) + _directml_vram_cache[dev] = vram + return vram + + # Try CUDA if available + if torch.cuda.is_available(): + try: + free_vram, total_vram = torch.cuda.mem_get_info() + _directml_vram_cache[dev] = total_vram + if DEBUG_ENABLED: + logging.debug(f"DirectML VRAM from CUDA: {total_vram / (1024**3):.0f} GB") + return total_vram + except Exception as e: + logging.warning(f"Failed to get CUDA VRAM: {e}") + + # Try torch_directml heuristic + if _torch_directml_available: + try: + device_index = dev.index if hasattr(dev, 'index') else 0 + device_name = torch_directml.device_name(device_index).lower() + vram_map = { + 'gtx 1660': 6 * 1024 * 1024 * 1024, + 'gtx 1650': 4 * 1024 * 1024 * 1024, + 'rtx 2060': 6 * 1024 * 1024 * 1024, + 'rtx 3060': 12 * 1024 * 1024 * 1024, + 'rtx 4060': 8 * 1024 * 1024 * 1024, + 'rx 580': 8 * 1024 * 1024 * 1024, + 'rx 570': 8 * 1024 * 1024 * 1024, + 'rx 6700': 12 * 1024 * 1024 * 1024, + 'arc a770': 16 * 1024 * 1024 * 1024, + } + vram = 6 * 1024 * 1024 * 1024 + for key, value in vram_map.items(): + if key in device_name: + vram = value + break + _directml_vram_cache[dev] = vram + if DEBUG_ENABLED: + logging.debug(f"DirectML VRAM for {device_name}: {vram / (1024**3):.0f} GB") + return vram + except Exception as e: + logging.warning(f"Failed to get DirectML device name: {e}") + + # Fallback to safe default + vram = 6 * 1024 * 1024 * 1024 + _directml_vram_cache[dev] = vram + if DEBUG_ENABLED: + logging.debug(f"DirectML VRAM fallback: {vram / (1024**3):.0f} GB") + return vram + FLOAT8_TYPES = get_supported_float8_types() +XFORMERS_IS_AVAILABLE = False +XFORMERS_ENABLED_VAE = True +ENABLE_PYTORCH_ATTENTION = True # Enable PyTorch attention for better performance +FORCE_FP32 = args.force_fp32 +DISABLE_SMART_MEMORY = args.disable_smart_memory + +# Async offload setup +STREAMS = {} +NUM_STREAMS = 1 +stream_counters = {} +if args.async_offload: + logging.info(f"Using async weight offloading with {NUM_STREAMS} streams") + # Protection for older GPUs + if is_nvidia(): + props = torch.cuda.get_device_properties(get_torch_device()) + if props.major < 8: # Turing (7.5) or Pascal (6.x) + args.async_offload = False + NUM_STREAMS = 1 + logging.warning("Async offload disabled for GPUs with SM < 8.0 to prevent memory leaks") +# Device initialization xpu_available = False -torch_version = "" +npu_available = False +mlu_available = False +directml_enabled = args.directml is not None +torch_version_numeric = (0, 0) try: - torch_version = torch.version.__version__ + torch_version = torch.__version__ temp = torch_version.split(".") torch_version_numeric = (int(temp[0]), int(temp[1])) - xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available() + xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and hasattr( + torch, "xpu") and torch.xpu.is_available() except: pass - -lowvram_available = True -if args.deterministic: - logging.info("Using deterministic algorithms for pytorch") - torch.use_deterministic_algorithms(True, warn_only=True) - -directml_enabled = False -if args.directml is not None: +if directml_enabled: import torch_directml - directml_enabled = True - device_index = args.directml - if device_index < 0: - directml_device = 
torch_directml.device() - else: - directml_device = torch_directml.device(device_index) - logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index))) - # torch_directml.disable_tiled_resources(True) - lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. - + device_index = args.directml if args.directml >= 0 else 0 + directml_device = torch_directml.device(device_index) + logging.info(f"Using DirectML with device: {torch_directml.device_name(device_index)}") try: import intel_extension_for_pytorch as ipex - _ = torch.xpu.device_count() xpu_available = xpu_available or torch.xpu.is_available() except: xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available()) - try: if torch.backends.mps.is_available(): cpu_state = CPUState.MPS import torch.mps except: pass - try: - import torch_npu # noqa: F401 - _ = torch.npu.device_count() + import torch_npu npu_available = torch.npu.is_available() except: npu_available = False - try: - import torch_mlu # noqa: F401 - _ = torch.mlu.device_count() + import torch_mlu mlu_available = torch.mlu.is_available() except: mlu_available = False - if args.cpu: cpu_state = CPUState.CPU +# Device and memory utilities +def is_nvidia(): + """Check if the device is NVIDIA GPU.""" + return cpu_state == CPUState.GPU and torch.version.cuda + +def is_amd(): + """Check if the device is AMD GPU.""" + return cpu_state == CPUState.GPU and torch.version.hip + def is_intel_xpu(): - global cpu_state - global xpu_available - if cpu_state == CPUState.GPU: - if xpu_available: - return True - return False + """Check if the device is Intel XPU.""" + return cpu_state == CPUState.GPU and xpu_available def is_ascend_npu(): - global npu_available - if npu_available: - return True - return False + """Check if the device is Ascend NPU.""" + return npu_available def is_mlu(): - global mlu_available - if mlu_available: - return True - return False + """Check if the device is MLU.""" + return mlu_available + +def is_device_cuda(device): + """Check if the device is CUDA.""" + return hasattr(device, 'type') and device.type == 'cuda' + +def is_device_type(device, device_type): + """Check if the device matches the given type.""" + return hasattr(device, 'type') and device.type == device_type def get_torch_device(): - global directml_enabled - global cpu_state + """Get the current PyTorch device.""" if directml_enabled: - global directml_device return directml_device if cpu_state == CPUState.MPS: return torch.device("mps") if cpu_state == CPUState.CPU: return torch.device("cpu") - else: - if is_intel_xpu(): - return torch.device("xpu", torch.xpu.current_device()) - elif is_ascend_npu(): - return torch.device("npu", torch.npu.current_device()) - elif is_mlu(): - return torch.device("mlu", torch.mlu.current_device()) - else: - return torch.device(torch.cuda.current_device()) + if is_intel_xpu(): + return torch.device("xpu", torch.xpu.current_device()) + if is_ascend_npu(): + return torch.device("npu", torch.npu.current_device()) + if is_mlu(): + return torch.device("mlu", torch.mlu.current_device()) + return torch.device(torch.cuda.current_device()) + def get_total_memory(dev=None, torch_total_too=False): - global directml_enabled + """ + Get total memory available on the device. + + Args: + dev: Torch device (optional, defaults to current device). + torch_total_too: If True, return (total, torch_total). 
+ + Returns: + int or tuple: Total memory in bytes (or tuple with torch_total). + """ if dev is None: dev = get_torch_device() - if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): mem_total = psutil.virtual_memory().total mem_total_torch = mem_total else: if directml_enabled: - mem_total = 1024 * 1024 * 1024 #TODO + mem_total = get_directml_vram(dev) mem_total_torch = mem_total elif is_intel_xpu(): stats = torch.xpu.memory_stats(dev) @@ -191,14 +326,14 @@ def get_total_memory(dev=None, torch_total_too=False): elif is_ascend_npu(): stats = torch.npu.memory_stats(dev) mem_reserved = stats['reserved_bytes.all.current'] - _, mem_total_npu = torch.npu.mem_get_info(dev) mem_total_torch = mem_reserved + _, mem_total_npu = torch.npu.mem_get_info(dev) mem_total = mem_total_npu elif is_mlu(): stats = torch.mlu.memory_stats(dev) mem_reserved = stats['reserved_bytes.all.current'] - _, mem_total_mlu = torch.mlu.mem_get_info(dev) mem_total_torch = mem_reserved + _, mem_total_mlu = torch.mlu.mem_get_info(dev) mem_total = mem_total_mlu else: stats = torch.cuda.memory_stats(dev) @@ -206,37 +341,51 @@ def get_total_memory(dev=None, torch_total_too=False): _, mem_total_cuda = torch.cuda.mem_get_info(dev) mem_total_torch = mem_reserved mem_total = mem_total_cuda + return (mem_total, mem_total_torch) if torch_total_too else mem_total - if torch_total_too: - return (mem_total, mem_total_torch) - else: - return mem_total - -def mac_version(): - try: - return tuple(int(n) for n in platform.mac_ver()[0].split(".")) - except: - return None - +# Initialize VRAM state total_vram = get_total_memory(get_torch_device()) / (1024 * 1024) -total_ram = psutil.virtual_memory().total / (1024 * 1024) -logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) - -try: - logging.info("pytorch version: {}".format(torch_version)) - mac_ver = mac_version() - if mac_ver is not None: - logging.info("Mac Version {}".format(mac_ver)) -except: - pass - -try: - OOM_EXCEPTION = torch.cuda.OutOfMemoryError -except: - OOM_EXCEPTION = Exception +logging.info(f"Total VRAM {total_vram:.0f} MB, total RAM {total_ram:.0f} MB") +logging.info(f"Pytorch version: {torch_version}") + +def get_extra_reserved_vram(): + """ + Determine extra VRAM to reserve based on total VRAM and args. + + Returns: + int: Reserved VRAM in bytes. 
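+    Example (illustrative): with no --reserve-vram flag, a card with less than
+    7.9 GB of VRAM reserves 150 MB and a larger card reserves 200 MB;
+    passing --reserve-vram 0.5 reserves 0.5 GB instead.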
+ """ + total_vram = get_total_memory(get_torch_device()) / (1024 * 1024 * 1024) # VRAM in GB + if args.reserve_vram is not None: + return args.reserve_vram * 1024 * 1024 * 1024 + if total_vram < 7.9: + return 150 * 1024 * 1024 # 150 MB for low VRAM (<7.9 GB) + return 200 * 1024 * 1024 # 200 MB for high VRAM (?7.9 GB) + +EXTRA_RESERVED_VRAM = get_extra_reserved_vram() +logging.info(f"EXTRA_RESERVED_VRAM set to {EXTRA_RESERVED_VRAM / (1024 * 1024):.0f} MB") +if args.lowvram: + set_vram_to = VRAMState.LOW_VRAM +elif args.novram: + set_vram_to = VRAMState.NO_VRAM +elif args.highvram or args.gpu_only: + vram_state = VRAMState.HIGH_VRAM +if cpu_state != CPUState.GPU: + vram_state = VRAMState.DISABLED +elif cpu_state == CPUState.MPS: + vram_state = VRAMState.SHARED +if directml_enabled: + lowvram_available = False +else: + lowvram_available = True +if lowvram_available and set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): + vram_state = set_vram_to +logging.info(f"Set VRAM state to: {vram_state.name}") +if DISABLE_SMART_MEMORY: + logging.info("Disabling smart memory management") +# XFormers and attention settings XFORMERS_VERSION = "" -XFORMERS_ENABLED_VAE = True if args.disable_xformers: XFORMERS_IS_AVAILABLE = False else: @@ -248,151 +397,354 @@ def mac_version(): XFORMERS_IS_AVAILABLE = xformers._has_cpp_library except: pass - try: - XFORMERS_VERSION = xformers.version.__version__ - logging.info("xformers version: {}".format(XFORMERS_VERSION)) - if XFORMERS_VERSION.startswith("0.0.18"): - logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.") - logging.warning("Please downgrade or upgrade xformers to a different version.\n") - XFORMERS_ENABLED_VAE = False - except: - pass + XFORMERS_VERSION = xformers.version.__version__ + logging.info(f"xformers version: {XFORMERS_VERSION}") + if XFORMERS_VERSION.startswith("0.0.18"): + logging.warning( + "WARNING: xformers 0.0.18 has a bug causing black images at high resolutions. 
Please downgrade or upgrade.") + XFORMERS_ENABLED_VAE = False except: XFORMERS_IS_AVAILABLE = False -def is_nvidia(): - global cpu_state - if cpu_state == CPUState.GPU: - if torch.version.cuda: - return True - return False - -def is_amd(): - global cpu_state - if cpu_state == CPUState.GPU: - if torch.version.hip: - return True - return False +def xformers_enabled(): + """Check if xformers is enabled and available.""" + global directml_enabled, cpu_state + if cpu_state != CPUState.GPU or is_intel_xpu() or is_ascend_npu() or is_mlu() or directml_enabled: + return False + return XFORMERS_IS_AVAILABLE and not args.disable_xformers and not args.use_pytorch_cross_attention -MIN_WEIGHT_MEMORY_RATIO = 0.4 -if is_nvidia(): - MIN_WEIGHT_MEMORY_RATIO = 0.0 +def xformers_enabled_vae(): + """Check if xformers is enabled for VAE.""" + enabled = xformers_enabled() + if not enabled: + return False + return XFORMERS_ENABLED_VAE -ENABLE_PYTORCH_ATTENTION = False -if args.use_pytorch_cross_attention: - ENABLE_PYTORCH_ATTENTION = True - XFORMERS_IS_AVAILABLE = False +def sage_attention_enabled(): + """Check if Sage Attention is enabled.""" + global directml_enabled, cpu_state + if cpu_state != CPUState.GPU or is_intel_xpu() or is_ascend_npu() or is_mlu() or directml_enabled: + return False + return hasattr(args, 'use_sage_attention') and args.use_sage_attention -try: - if is_nvidia(): - if torch_version_numeric[0] >= 2: - if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: - ENABLE_PYTORCH_ATTENTION = True - if is_intel_xpu() or is_ascend_npu() or is_mlu(): - if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: - ENABLE_PYTORCH_ATTENTION = True -except: - pass +def flash_attention_enabled(): + """Check if Flash Attention is enabled.""" + global directml_enabled, cpu_state + if cpu_state != CPUState.GPU or is_intel_xpu() or is_ascend_npu() or is_mlu() or directml_enabled: + return False + return hasattr(args, 'use_flash_attention') and args.use_flash_attention +def pytorch_attention_enabled(): + """Check if PyTorch attention is enabled.""" + global ENABLE_PYTORCH_ATTENTION + return ENABLE_PYTORCH_ATTENTION or not (xformers_enabled() or sage_attention_enabled() or flash_attention_enabled()) -try: +def pytorch_attention_enabled_vae(): + """Check if PyTorch attention is enabled for VAE.""" if is_amd(): - arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName - logging.info("AMD arch: {}".format(arch)) - if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: - if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much - if any((a in arch) for a in ["gfx1100", "gfx1101"]): # TODO: more arches - ENABLE_PYTORCH_ATTENTION = True -except: - pass - - -if ENABLE_PYTORCH_ATTENTION: - torch.backends.cuda.enable_math_sdp(True) - torch.backends.cuda.enable_flash_sdp(True) - torch.backends.cuda.enable_mem_efficient_sdp(True) - - -PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other -try: - if is_nvidia() and PerformanceFeature.Fp16Accumulation in args.fast: - torch.backends.cuda.matmul.allow_fp16_accumulation = True - PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance - logging.info("Enabled fp16 accumulation.") -except: - pass - -try: - if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5: - 
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) -except: - logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp") - -if args.lowvram: - set_vram_to = VRAMState.LOW_VRAM - lowvram_available = True -elif args.novram: - set_vram_to = VRAMState.NO_VRAM -elif args.highvram or args.gpu_only: - vram_state = VRAMState.HIGH_VRAM - -FORCE_FP32 = False -if args.force_fp32: - logging.info("Forcing FP32, if this improves things please report it.") - FORCE_FP32 = True - -if lowvram_available: - if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): - vram_state = set_vram_to - - -if cpu_state != CPUState.GPU: - vram_state = VRAMState.DISABLED - -if cpu_state == CPUState.MPS: - vram_state = VRAMState.SHARED + return False # Enabling PyTorch attention on AMD causes crashes at high resolutions + return pytorch_attention_enabled() -logging.info(f"Set vram state to: {vram_state.name}") +def pytorch_attention_flash_attention(): + """Check if PyTorch Flash Attention is supported.""" + if pytorch_attention_enabled(): + if is_nvidia() or is_intel_xpu() or is_ascend_npu() or is_mlu() or is_amd(): + return True + return False + return False -DISABLE_SMART_MEMORY = args.disable_smart_memory +def force_upcast_attention_dtype(): + """Check if attention dtype should be upcast (e.g., FP16 to FP32).""" + upcast = args.force_upcast_attention + macos_version = mac_version() + if macos_version is not None and ((14, 5) <= macos_version < (16,)): + upcast = True # Workaround for macOS black image bug + if upcast: + return {torch.float16: torch.float32} + return None -if DISABLE_SMART_MEMORY: - logging.info("Disabling smart memory management") +def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): + """Cast tensor to specified dtype and device, compatible with comfy.ops.""" + if device is None or weight.device == device: + if not copy: + if dtype is None or weight.dtype == dtype: + return weight + if stream is not None: + with stream: + return weight.to(dtype=dtype, copy=copy) + return weight.to(dtype=dtype, copy=copy) + if stream is not None: + with stream: + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight, non_blocking=non_blocking) + else: + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight, non_blocking=non_blocking) + return r -def get_torch_device_name(device): - if hasattr(device, 'type'): +def get_torch_device_name(device=None): + """Get the name of the torch device.""" + if device is None: + device = get_torch_device() + if isinstance(device, str): + return device + if isinstance(device, torch.device): if device.type == "cuda": try: - allocator_backend = torch.cuda.get_allocator_backend() + allocator = torch.cuda.get_allocator_backend() except: - allocator_backend = "" - return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend) - else: - return "{}".format(device.type) - elif is_intel_xpu(): - return "{} {}".format(device, torch.xpu.get_device_name(device)) - elif is_ascend_npu(): - return "{} {}".format(device, torch.npu.get_device_name(device)) - elif is_mlu(): - return "{} {}".format(device, torch.mlu.get_device_name(device)) - else: - return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) + allocator = "" + return f"{device.type}:{device.index if device.index is not None else 0} {allocator}" + return device.type + return str(device) -try: - logging.info("Device: {}".format(get_torch_device_name(get_torch_device()))) -except: - 
logging.warning("Could not pick default device.") +class OOM_EXCEPTION(Exception): + """Exception raised for out-of-memory errors.""" + pass +if args.use_pytorch_cross_attention: + ENABLE_PYTORCH_ATTENTION = True + XFORMERS_IS_AVAILABLE = False +MIN_WEIGHT_MEMORY_RATIO = 0.4 if is_nvidia() else 0.0 +if is_nvidia() and torch_version_numeric[0] >= 2: + if not (ENABLE_PYTORCH_ATTENTION or args.use_split_cross_attention or args.use_quad_cross_attention): + ENABLE_PYTORCH_ATTENTION = True +elif is_intel_xpu() or is_ascend_npu() or is_mlu(): + if not (args.use_split_cross_attention or args.use_quad_cross_attention): + ENABLE_PYTORCH_ATTENTION = True +elif is_amd() and torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: + arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName + logging.info(f"AMD arch: {arch}") + if any(a in arch for a in ["gfx1100", "gfx1101"]) and not (args.use_split_cross_attention or args.use_quad_cross_attention): + ENABLE_PYTORCH_ATTENTION = True +if ENABLE_PYTORCH_ATTENTION: + torch.backends.cuda.enable_math_sdp(True) + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(True) +if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5: + torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) +else: + logging.warning("Could not set allow_fp16_bf16_reduction_math_sdp") -current_loaded_models = [] +def get_free_memory(dev=None, torch_free_too=False): + """ + Get free memory available on the device. -def module_size(module): - module_mem = 0 - sd = module.state_dict() - for k in sd: - t = sd[k] - module_mem += t.nelement() * t.element_size() - return module_mem + Args: + dev: Torch device (optional, defaults to current device). + torch_free_too: If True, return (free_total, free_torch). + + Returns: + int or tuple: Free memory in bytes (or tuple with free_torch). 
+ """ + if dev is None: + dev = get_torch_device() + if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): + mem_free_total = psutil.virtual_memory().available + mem_free_torch = mem_free_total + else: + if directml_enabled: + total_vram = get_directml_vram(dev) + cache_key = (dev, 'active_models') + if cache_key not in _directml_active_memory_cache: + active_models = sum(m.model_loaded_memory() for m in current_loaded_models if m.device == dev) + _directml_active_memory_cache[cache_key] = active_models + active_models = _directml_active_memory_cache[cache_key] + mem_free_total = max(1024 * 1024 * 1024, total_vram - active_models * 1.2) + mem_free_torch = mem_free_total + if DEBUG_ENABLED: + logging.debug(f"DirectML: total_vram={total_vram / (1024**3):.0f} GB, active_models={active_models / (1024**3):.2f} GB, free={mem_free_total / (1024**3):.2f} GB") + elif is_intel_xpu(): + stats = torch.xpu.memory_stats(dev) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_torch = mem_reserved - mem_active + mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved + mem_free_total = mem_free_xpu + mem_free_torch + elif is_ascend_npu(): + stats = torch.npu.memory_stats(dev) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_npu, _ = torch.npu.mem_get_info(dev) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_npu + mem_free_torch + elif is_mlu(): + stats = torch.mlu.memory_stats(dev) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_mlu, _ = torch.mlu.mem_get_info(dev) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_mlu + mem_free_torch + else: + stats = torch.cuda.memory_stats(dev) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(dev) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + return (mem_free_total, mem_free_torch) if torch_free_too else mem_free_total + +def get_adaptive_min_free(mem_total, memory_required=None): + """ + Calculate adaptive min_free VRAM based on GPU memory and model requirements. + + Args: + mem_total (float): Total GPU VRAM in GB. + memory_required (float, optional): Estimated memory required by the model in GB. + + Returns: + float: Minimum free VRAM required in GB. 
+ """ + # Base min_free as a fraction of total VRAM + base_min_free = mem_total * 0.25 # 25% of total VRAM as baseline + + if memory_required is not None: + min_free = max(base_min_free, memory_required) # Use memory_required directly, no extra multiplier + else: + min_free = base_min_free + + # Cap min_free to avoid excessive requirements + min_free = min(min_free, mem_total * 0.5) # Never exceed 50% of total VRAM + + # Minimum threshold for very small GPUs + min_free = max(min_free, 1.0 if mem_total < 6.0 else 1.5) + + if PROFILING_ENABLED: + memory_required_str = f"{memory_required:.2f}" if memory_required is not None else "None" + logging.debug(f"get_adaptive_min_free: mem_total={mem_total:.2f} GB, memory_required={memory_required_str} GB, min_free={min_free:.2f} GB") + + return min_free + +def memory_monitor(device, interval=5.0): + """Monitor memory usage in a background thread.""" + if not DEBUG_ENABLED: + return + + def monitor(): + while True: + log_vram_state(device) + time.sleep(interval) + threading.Thread(target=monitor, daemon=True).start() + +memory_monitor(get_torch_device()) + +def soft_empty_cache(clear=False, device=None, caller="unknown"): + """ + Clear PyTorch memory cache efficiently with VRAM check. + + Args: + clear (bool): Force cache clearing regardless of memory state. + device (torch.device): Device to clear cache for. Defaults to current device. + caller (str): Source of the call for debugging. + """ + if device is None: + device = get_torch_device() + + if PROFILING_ENABLED: + start_time = time.time() + logging.debug(f"soft_empty_cache called with clear={clear}, device={device}, caller={caller}") + + # Fixed threshold in bytes (100 MB) + MEMORY_THRESHOLD = 100 * 1024 * 1024 # 100 MB + cache_key = (device, 'free_memory') + + mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True) + total_vram = get_total_memory(device) + if not clear and (mem_free_torch <= MEMORY_THRESHOLD or mem_free_total > 0.4 * total_vram): + if PROFILING_ENABLED: + logging.debug(f"soft_empty_cache: Skipped (free_vram={mem_free_total/1024**3:.2f} GB, free_torch={mem_free_torch/1024**3:.2f} GB)") + return + + try: + if is_device_cuda(device): + torch.cuda.empty_cache() + if clear and mem_free_torch < min(0.1 * total_vram, 1.0 * 1024**3): + gc.collect() + if torch.distributed.is_initialized(): + torch.cuda.ipc_collect() + elif cpu_state == CPUState.MPS: + torch.mps.empty_cache() + elif is_intel_xpu(): + torch.xpu.empty_cache() + elif is_ascend_npu(): + torch.npu.empty_cache() + elif is_mlu(): + torch.mlu.empty_cache() + + if PROFILING_ENABLED: + free_vram_after, free_torch_after = get_free_memory(device, torch_free_too=True) + logging.debug(f"After clear: free_vram={free_vram_after/1024**3:.2f} GB, free_torch={free_torch_after/1024**3:.2f} GB") + logging.debug(f"soft_empty_cache took {time.time() - start_time:.3f} s, gained={(free_vram_after - mem_free_total)/1024**3:.2f} GB") + except Exception as e: + if PROFILING_ENABLED: + logging.warning(f"Failed to clear cache for {device}: {str(e)}") + +def unload_all_models(): + """ + Unload all models from memory and clear cache if necessary. 
+ """ + if PROFILING_ENABLED: + start_time = time.time() + logging.debug("unload_all_models called") + + for model in list(current_loaded_models): + model.model_unload() + + current_loaded_models.clear() + + device = get_torch_device() + free_vram = get_free_memory(device)[0] + total_vram = get_total_memory(device) + clear_aggressive = free_vram < 0.4 * total_vram + if PROFILING_ENABLED: + logging.debug(f"unload_all_models: free_vram={free_vram/1024**3:.2f} GB, aggressive_clear={clear_aggressive}") + + soft_empty_cache(clear=clear_aggressive, caller="unload_all_models") + + if PROFILING_ENABLED: + new_free_vram = get_free_memory(device)[0] + logging.debug(f"unload_all_models done: free={new_free_vram/1024**3:.2f} GB, " + f"gained={(new_free_vram - free_vram)/1024**3:.2f} GB, took={time.time() - start_time:.3f} s") + +def get_offload_stream(device): + """Get a stream for asynchronous weight offloading.""" + stream_counter = stream_counters.get(device, 0) + if NUM_STREAMS <= 1 or not is_device_cuda(device): + return None + if device in STREAMS: + ss = STREAMS[device] + s = ss[stream_counter] + stream_counter = (stream_counter + 1) % len(ss) + if is_device_cuda(device): + ss[stream_counter].wait_stream(torch.cuda.current_stream()) + stream_counters[device] = stream_counter + return s + elif is_device_cuda(device): + ss = [torch.cuda.Stream(device=device, priority=0) for _ in range(NUM_STREAMS)] + STREAMS[device] = ss + s = ss[stream_counter] + stream_counter = (stream_counter + 1) % len(ss) + stream_counters[device] = stream_counter + return s + return None + +def sync_stream(device, stream): + """Synchronize the given stream with the current CUDA stream.""" + if stream is None or not is_device_cuda(device): + return + torch.cuda.current_stream().wait_stream(stream) + +def cast_to_device(tensor, device, dtype, copy=False): + """Cast tensor to specified device and dtype with non-blocking support.""" + non_blocking = device_supports_non_blocking(device) + return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy) + +def register_vram_optimizer(optimizer): + """Register a VRAM optimizer.""" + _vram_optimizers.append(optimizer) + +# Model management +current_loaded_models = [] class LoadedModel: def __init__(self, model): @@ -400,90 +752,295 @@ def __init__(self, model): self.device = model.load_device self.real_model = None self.currently_used = True + self.model_offloaded = False self.model_finalizer = None self._patcher_finalizer = None def _set_model(self, model): self._model = weakref.ref(model) - if model.parent is not None: + if hasattr(model, 'parent') and model.parent is not None: self._parent_model = weakref.ref(model.parent) self._patcher_finalizer = weakref.finalize(model, self._switch_parent) def _switch_parent(self): - model = self._parent_model() - if model is not None: - self._set_model(model) + if hasattr(self, '_parent_model'): + model = self._parent_model() + if model is not None: + self._set_model(model) @property def model(self): return self._model() def model_memory(self): - return self.model.model_size() + return self.model.model_size() if hasattr(self.model, 'model_size') else module_size(self.model) def model_loaded_memory(self): - return self.model.loaded_size() + return self.model.loaded_size() if hasattr(self.model, 'loaded_size') else module_size(self.model) def model_offloaded_memory(self): - return self.model.model_size() - self.model.loaded_size() + return self.model_memory() - self.model_loaded_memory() def model_memory_required(self, 
device): - if device == self.model.current_loaded_device(): - return self.model_offloaded_memory() - else: - return self.model_memory() + """ + Estimate memory required for the model on the specified device. - def model_load(self, lowvram_model_memory=0, force_patch_weights=False): - self.model.model_patches_to(self.device) - self.model.model_patches_to(self.model.model_dtype()) + Args: + device (torch.device): Target device for memory estimation. - # if self.model.loaded_size() > 0: - use_more_vram = lowvram_model_memory - if use_more_vram == 0: - use_more_vram = 1e32 - self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights) - real_model = self.model.model + Returns: + int: Memory required in bytes. + """ + # Fast path: use size if available + if hasattr(self.model, 'size') and self.model.size > 0: + return self.model.size - if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None: - with torch.no_grad(): - real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True) + # Check if model is already on the target device + if hasattr(self.model, 'current_loaded_device') and device == self.model.current_loaded_device(): + return self.model_offloaded_memory() - self.real_model = weakref.ref(real_model) - self.model_finalizer = weakref.finalize(real_model, cleanup_models) - return real_model + # Handle AutoencoderKL + if self.model.model is not None and isinstance(self.model.model, AutoencoderKL): + shape = getattr(self.model, 'last_shape', (1, 4, 64, 64)) + dtype = getattr(self.model, 'model_dtype', torch.float32)() + return estimate_vae_decode_memory(self.model.model, shape, dtype) + + # Sum memory for additional models + loaded_memory = 0 + if hasattr(self.model, 'additional_models'): + model_device = device + if hasattr(self.model.model, 'device'): + model_device = self.model.model.device + if DEBUG_ENABLED: + logging.debug(f"[DEBUG_CLONES] Model {self.model.__class__.__name__} using device {model_device}") + for m in self.model.additional_models: + try: + loaded_memory += m.model_memory_required(model_device) + except Exception as e: + if DEBUG_ENABLED: + logging.warning(f"[DEBUG_CLONES] Error calculating memory for additional model: {e}") + + return self.model_memory() + loaded_memory + + def model_load(self, lowvram_model_memory=0, force_patch_weights=False): + with profile_section("Model load"): + self.model.model_patches_to(self.device) + self.model.model_patches_to(self.model.model_dtype()) + use_more_vram = lowvram_model_memory if lowvram_model_memory > 0 else float('inf') + self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights) + real_model = self.model.model + if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals(): + with torch.no_grad(): + real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True) + self.real_model = weakref.ref(real_model) + self.model_finalizer = weakref.finalize(real_model, cleanup_models) + return real_model def should_reload_model(self, force_patch_weights=False): - if force_patch_weights and self.model.lowvram_patch_counter() > 0: - return True - return False + return force_patch_weights and self.model.lowvram_patch_counter() > 0 def model_unload(self, memory_to_free=None, unpatch_weights=True): - if memory_to_free is not None: - if memory_to_free < self.model.loaded_size(): - freed = self.model.partially_unload(self.model.offload_device, memory_to_free) - if freed >= 
memory_to_free: - return False - self.model.detach(unpatch_weights) - self.model_finalizer.detach() - self.model_finalizer = None - self.real_model = None - return True + """ + Unload the model, freeing memory on both CPU and GPU. + Clears CUDA cache if needed, logs critical information. + + Args: + memory_to_free: Amount of memory to free (bytes), if partial unloading is needed. + unpatch_weights: Whether to unpatch model weights during unloading. + + Returns: + float: Estimated memory freed (in bytes). + """ + with profile_section("Model unload"): + if self.is_dead() or self.real_model is None: + if DEBUG_ENABLED: + logging.debug("[DEBUG_CLONES] Model is dead or real_model is None, skipping unload") + return 0 + + mem_freed = getattr(self.model, 'model_loaded_weight_memory', 0) if self.model is not None else 0 + is_cuda = is_device_cuda(self.device) - def model_use_more_vram(self, extra_memory, force_patch_weights=False): + try: + model_name = self.model.__class__.__name__ if self.model is not None else "None" + model_type = self.model.model.__class__.__name__ if self.model is not None and hasattr(self.model, 'model') else "Unknown" + if DEBUG_ENABLED: + logging.debug(f"[DEBUG_CLONES] Starting unload for {model_name}(type={model_type})") + + # Partial unload if requested and supported + if memory_to_free is not None and memory_to_free < mem_freed and hasattr(self.model, 'partially_unload'): + freed = self.model.partially_unload(self.model.offload_device, memory_to_free) + if freed >= memory_to_free: + if PROFILING_ENABLED: + logging.debug(f"[DEBUG_CLONES] Partial unload freed {freed / 1024**3:.2f} GB") + return freed + + # Full unload + if self.model is not None and hasattr(self.model, 'detach'): + self.model.detach(unpatch_all=unpatch_weights) + if self.model_finalizer is not None: + self.model_finalizer.detach() + self.model_finalizer = None + self.real_model = None + self._model = lambda: None + + # Garbage collection for non-CUDA devices + if not is_cuda: + gc.collect() + + # Clear CUDA cache if on CUDA device + if is_cuda: + device = self.device + free_vram = get_free_memory(device)[0] + total_vram = get_total_memory(device) + clear_aggressive = free_vram < 0.4 * total_vram + soft_empty_cache(clear=clear_aggressive, caller="model_unload") + + if PROFILING_ENABLED: + logging.debug(f"[DEBUG_CLONES] Unload complete for {model_name}") + return mem_freed + + except Exception as e: + if DEBUG_ENABLED: + logging.warning(f"[DEBUG_CLONES] Error during model_unload for {model_name}(type={model_type}): {e}") + return mem_freed + + def model_use_more_vram(self, use_more_vram, force_patch_weights=False): + if not use_more_vram: + if PROFILING_ENABLED: + logging.debug( + "model_use_more_vram: use_more_vram=False, returning 0") + return 0 + mem_required = self.model_memory_required(self.device) + extra_memory = min(mem_required * 0.3, 50 * 1024 * 1024 * 1024) # Reduced to 50 MB chunks return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights) def __eq__(self, other): return self.model is other.model def __del__(self): - if self._patcher_finalizer is not None: + if hasattr(self, '_patcher_finalizer') and self._patcher_finalizer is not None: self._patcher_finalizer.detach() + if hasattr(self, '_model_finalizer') and self._model_finalizer is not None: + self._model_finalizer.detach() def is_dead(self): + """ + Check if the model is dead (real_model exists but model is garbage collected). + Returns True if the model is dead, False otherwise. 
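+    Example (illustrative): this happens when the real_model weakref is still
+    alive but the owning patcher has been garbage collected; cleanup_models_gc()
+    later drops such entries from current_loaded_models.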
+ """ + if self.real_model is None: + return False # Model was never loaded or already unloaded return self.real_model() is not None and self.model is None +def module_size(model, shape=None, dtype=None): + """ + Estimate memory size of a module by summing parameter and buffer sizes, + or using VAE-specific estimation if shape and dtype are provided. + """ + from diffusers import AutoencoderKL + + module_mem = 0 + if shape is not None and dtype is not None and isinstance(model, AutoencoderKL): + try: + batch, channels, height, width = shape + # Adjusted memory estimate for VAE: reduced multiplier from 64*1.1 to 32*1.05 to avoid overestimation + base_memory = height * width * channels * 32 * 1.05 + size_of_dtype = dtype_size(dtype) + module_mem = base_memory * size_of_dtype + # Add parameter memory for VAE to account for model weights + param_mem = sum(p.numel() * p.element_size() for p in model.parameters()) + module_mem += param_mem + if DEBUG_ENABLED: + logging.debug( + f"Estimated VAE memory: shape={shape}, dtype={dtype}, " + f"params={param_mem / (1024**3):.2f} GB, total={module_mem / (1024**3):.2f} GB" + ) + except Exception as e: + logging.warning(f"Failed to estimate VAE memory for {model.__class__.__name__}: {str(e)}") + + if module_mem == 0: + try: + # Sum memory of state dict (parameters and buffers) + module_mem = sum(p.numel() * p.element_size() for p in model.state_dict().values()) + except AttributeError: + # Fallback: sum parameters and buffers separately + if hasattr(model, 'parameters'): + module_mem += sum(p.numel() * p.element_size() for p in model.parameters()) + if hasattr(model, 'buffers'): + module_mem += sum(b.numel() * b.element_size() for b in model.buffers()) + if module_mem == 0: + model_name = model.__class__.__name__.lower() + if 'vae' in model_name or isinstance(model, AutoencoderKL): + # Reduced fallback from 3.5 GB to 2.5 GB for VAE + module_mem = 2.5 * 1024**3 + logging.warning( + f"Could not estimate module size for {model.__class__.__name__}, " + f"assuming 2.5 GB for VAE" + ) + else: + # Minimal memory assumption for unknown models + module_mem = 1024 * 1024 + logging.warning( + f"Could not estimate module size for {model.__class__.__name__}, " + f"assuming minimal memory (1 MB)" + ) + + if VERBOSE_ENABLED: + logging.debug(f"Module size for {model.__class__.__name__}: {module_mem / (1024**3):.2f} GB") + return module_mem + +def dtype_size(dtype): + """Get the size of a data type in bytes.""" + dtype_size = 4 + if dtype in (torch.float16, torch.bfloat16): + dtype_size = 2 + elif dtype == torch.float32: + dtype_size = 4 + elif dtype in FLOAT8_TYPES: + dtype_size = 1 + else: + try: + dtype_size = dtype.itemsize + except: + pass + return dtype_size + +def get_adaptive_buffer(device): + """ + Calculate adaptive memory buffer based on VRAM size and available memory. + + Args: + device: Device to calculate buffer for. + + Returns: + Buffer size in bytes. 
+ """ + mem_total = get_total_memory(device) if is_device_cuda(device) else 6 * 1024**3 + mem_free_total, _ = get_free_memory(device, torch_free_too=True) + # Use 2% for low VRAM (<50% free) or small GPUs, 5% otherwise + fraction = 0.02 if mem_free_total < mem_total * 0.5 or mem_total < 8 * 1024**3 else 0.05 + buffer = min(max(fraction * mem_total, 0.05 * 1024**3), 0.2 * 1024**3) # 0.05-0.2 GB + if PROFILING_ENABLED: + logging.debug(f"get_adaptive_buffer: mem_total={mem_total / 1024**3:.2f} GB, " + f"mem_free_total={mem_free_total / 1024**3:.2f} GB, buffer={buffer / 1024**3:.2f} GB") + return buffer + +def estimate_vae_decode_memory(model, shape, dtype): + """ + Estimate memory required for VAE decoding. + Uses module_size with shape and dtype for accurate estimation. + """ + total_memory = module_size(model, shape=shape, dtype=dtype) + if PROFILING_ENABLED: + logging.debug( + f"Estimated VAE decode memory: shape={shape}, dtype={dtype}, " + f"total={total_memory / (1024**3):.2f} GB" + ) + return total_memory + def use_more_memory(extra_memory, loaded_models, device): + """Use additional VRAM for loaded models.""" for m in loaded_models: if m.device == device: extra_memory -= m.model_use_more_vram(extra_memory) @@ -491,835 +1048,823 @@ def use_more_memory(extra_memory, loaded_models, device): break def offloaded_memory(loaded_models, device): + """Calculate offloaded memory for loaded models.""" offloaded_mem = 0 for m in loaded_models: if m.device == device: offloaded_mem += m.model_offloaded_memory() return offloaded_mem -WINDOWS = any(platform.win32_ver()) - -EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 -if WINDOWS: - EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue - -if args.reserve_vram is not None: - EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024 - logging.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024))) - def extra_reserved_memory(): + """Get extra reserved VRAM.""" return EXTRA_RESERVED_VRAM def minimum_inference_memory(): - return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() - -def free_memory(memory_required, device, keep_loaded=[]): - cleanup_models_gc() - unloaded_model = [] - can_unload = [] - unloaded_models = [] - - for i in range(len(current_loaded_models) -1, -1, -1): - shift_model = current_loaded_models[i] - if shift_model.device == device: - if shift_model not in keep_loaded and not shift_model.is_dead(): - can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) - shift_model.currently_used = False - - for x in sorted(can_unload): - i = x[-1] - memory_to_free = None - if not DISABLE_SMART_MEMORY: - free_mem = get_free_memory(device) - if free_mem > memory_required: - break - memory_to_free = memory_required - free_mem - logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") - if current_loaded_models[i].model_unload(memory_to_free): - unloaded_model.append(i) + """Get minimum memory required for inference.""" + return (1024 * 1024 * 1024) * 0.6 + extra_reserved_memory() # Reduced to 600 MB - for i in sorted(unloaded_model, reverse=True): - unloaded_models.append(current_loaded_models.pop(i)) - if len(unloaded_model) > 0: - soft_empty_cache() - else: - if vram_state != VRAMState.HIGH_VRAM: - mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True) - if mem_free_torch > mem_free_total * 0.25: - soft_empty_cache() - return unloaded_models +def 
cleanup_models_gc(): + """Clean up dead models and collect garbage if significant memory is freed.""" + dead_memory = 0 + for cur in current_loaded_models: + if cur.is_dead(): + dead_memory += cur.model_memory() + + if dead_memory > 50 * 1024 * 1024: # 50 MB threshold + if PROFILING_ENABLED: + device = get_torch_device() + free_vram = get_free_memory(device)[0] + total_vram = get_total_memory(device) + logging.debug(f"cleanup_models_gc: dead_memory={dead_memory/1024**2:.2f} MB, " + f"free_vram={free_vram/1024**3:.2f} GB") + + soft_empty_cache(clear=False, caller="cleanup_models_gc") + + i = len(current_loaded_models) - 1 + while i >= 0: + if current_loaded_models[i].is_dead(): + logging.warning(f"Removing dead model {current_loaded_models[i].real_model().__class__.__name__}") + current_loaded_models.pop(i) + i -= 1 + +def free_memory(memory_required, device, keep_loaded=None, loaded_models=None, caller="unknown"): + """ + Free memory on the device by unloading models efficiently, prioritizing unused models. + + Args: + memory_required (int): Memory needed in bytes. + device (torch.device): Device to free memory on. + keep_loaded (list, optional): Models to keep loaded. Defaults to []. + loaded_models (list, optional): List of models to consider for unloading. Defaults to current_loaded_models. + caller (str): Source of the call for debugging. + + Returns: + list: Unloaded models. + """ + with profile_section("free_memory"): + # Initialize defaults + if keep_loaded is None: + keep_loaded = [] + if loaded_models is None: + loaded_models = current_loaded_models + + # Cache memory state to avoid redundant calls + cache_key = (device, 'free_memory') + mem_free = _device_cache.get(cache_key, None) + if mem_free is None: + mem_free = get_free_memory(device, torch_free_too=True) + _device_cache[cache_key] = mem_free + mem_free_total = mem_free[0] if isinstance(mem_free, tuple) else mem_free + total_vram = get_total_memory(device) + + # Log initial state if profiling is enabled + if PROFILING_ENABLED: + logging.debug( + f"free_memory: requested={memory_required / 1024**3:.2f} GB, " + f"free={mem_free_total / 1024**3:.2f} GB, models={len(loaded_models)}, " + f"device={device}, caller={caller}" + ) + + # Skip if enough VRAM (>20% of total or required memory available) + if mem_free_total > max(memory_required, 0.4 * total_vram): + if PROFILING_ENABLED: + logging.debug(f"free_memory: Skipped (free_vram={mem_free_total / 1024**3:.2f} GB)") + return [] + + # Apply VRAM optimizers + for optimizer in _vram_optimizers: + memory_required = optimizer(memory_required, device, keep_loaded) + + # Ensure minimum inference memory + memory_required = max(memory_required, minimum_inference_memory()) + + # Clean up dead models + cleanup_models_gc() + + unloaded_models = [] + can_unload = [] + + # Collect models that can be unloaded + for i in range(len(loaded_models) - 1, -1, -1): + model = loaded_models[i] + if model.device == device and model not in keep_loaded and not model.is_dead() and not model.currently_used: + mem_required = model.model_memory_required(device) + can_unload.append((mem_required, model, i)) + can_unload.sort(reverse=True) # Prioritize models using more memory + + # Calculate memory to free + memory_to_free = memory_required - mem_free_total + extra_reserved_memory() + + # Unload models to free required memory + for mem, model, index in can_unload: + try: + model_id = getattr(model.model, 'model_id', id(model.model) if hasattr(model.model, 'model') else model.model.__class__.__name__) + 
model_type = model.model.__class__.__name__ if hasattr(model.model, 'model') else 'Unknown' + mem_freed = model.model_unload(memory_to_free=memory_to_free) + loaded_models.pop(index) + unloaded_models.append(model) + if model.model is not None and hasattr(model.model, 'detach'): + model.model.detach(unpatch_all=True) + mem_free_total += mem_freed + _device_cache[cache_key] = (mem_free_total, mem_free[1] if isinstance(mem_free, tuple) else 0) + if PROFILING_ENABLED: + logging.debug( + f"Unloaded model: id={model_id}, type={model_type}, " + f"freed={mem_freed / 1024**3:.2f} GB, free_vram={mem_free_total / 1024**3:.2f} GB" + ) + if mem_free_total >= memory_required: + break + except Exception as e: + if DEBUG_ENABLED: + logging.warning(f"Failed to unload model at index {index}: {e}") + + # Utilize excess memory if available + use_more_memory(mem_free_total - memory_required, loaded_models, device) + + if PROFILING_ENABLED: + logging.debug( + f"free_memory done: free={mem_free_total / 1024**3:.2f} GB, unloaded={len(unloaded_models)} models" + ) + + return unloaded_models def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): + """ + Load multiple models to GPU, managing VRAM efficiently. + + Args: + models: List of models to load. + memory_required: Estimated memory needed (bytes). + force_patch_weights: Force re-patching model weights. + minimum_memory_required: Minimum memory needed for inference. + force_full_load: Force full model loading regardless of VRAM state. + """ cleanup_models_gc() - global vram_state - - inference_memory = minimum_inference_memory() - extra_mem = max(inference_memory, memory_required + extra_reserved_memory()) - if minimum_memory_required is None: - minimum_memory_required = extra_mem - else: - minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory()) - - models = set(models) - - models_to_load = [] - - for x in models: - loaded_model = LoadedModel(x) - try: - loaded_model_index = current_loaded_models.index(loaded_model) - except: - loaded_model_index = None + with profile_section("load_models_gpu"): + # Memory cache for efficient memory queries + memory_cache = {} + def get_cached_memory(device, torch_free_too=False): + cache_key = (device, torch_free_too) + if cache_key not in memory_cache: + try: + memory_cache[cache_key] = get_free_memory(device, torch_free_too) + except Exception as e: + logging.error(f"Failed to get memory for {device}: {e}") + memory_cache[cache_key] = (0, 0) if torch_free_too else 0 + return memory_cache[cache_key] + + if minimum_memory_required is None: + minimum_memory_required = minimum_inference_memory() + device = get_torch_device() + if vram_state in (VRAMState.DISABLED, VRAMState.SHARED): + return + + model_lookup = {m.model: m for m in current_loaded_models if m.model is not None} + + # Reset currently_used flag for all loaded models + for loaded_model in current_loaded_models: + loaded_model.currently_used = False + + loaded = [] + # Prepare models to load + for model in models: + if not hasattr(model, "model"): + continue + loaded_model = model_lookup.get(model) + if loaded_model is None: + loaded_model = LoadedModel(model) + model_lookup[model] = loaded_model + loaded_model.currently_used = True + loaded.append(loaded_model) + + # Unload unused models only if necessary + device = get_torch_device() + to_remove = [] + if len(current_loaded_models) > 10 or (is_device_cuda(device) and get_cached_memory(device) < 1 * 
1024 * 1024 * 1024): # >10 models or <1GB VRAM + for i, loaded_model in enumerate(current_loaded_models): + if not loaded_model.currently_used: + model = loaded_model.model + if model is None: + to_remove.append(i) + continue + try: + mem_freed = loaded_model.model_unload() + to_remove.append(i) + if hasattr(model, 'detach') and hasattr(model, 'patched_weights') and model.patched_weights: + model.detach(unpatch_all=True) + except Exception as e: + logging.error(f"Failed to unload model at index {i}: {e}") + for i in reversed(to_remove): + current_loaded_models.pop(i) - if loaded_model_index is not None: - loaded = current_loaded_models[loaded_model_index] - loaded.currently_used = True - models_to_load.append(loaded) - else: - if hasattr(x, "model"): - logging.info(f"Requested to load {x.model.__class__.__name__}") - models_to_load.append(loaded_model) - - for loaded_model in models_to_load: - to_unload = [] - for i in range(len(current_loaded_models)): - if loaded_model.model.is_clone(current_loaded_models[i].model): - to_unload = [i] + to_unload - for i in to_unload: - current_loaded_models.pop(i).model.detach(unpatch_all=False) - - total_memory_required = {} - for loaded_model in models_to_load: - total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) - - for device in total_memory_required: - if device != torch.device("cpu"): - free_memory(total_memory_required[device] * 1.1 + extra_mem, device) - - for device in total_memory_required: - if device != torch.device("cpu"): - free_mem = get_free_memory(device) - if free_mem < minimum_memory_required: - models_l = free_memory(minimum_memory_required, device) - logging.info("{} models unloaded.".format(len(models_l))) - - for loaded_model in models_to_load: - model = loaded_model.model - torch_dev = model.load_device - if is_device_cpu(torch_dev): - vram_set_state = VRAMState.DISABLED - else: - vram_set_state = vram_state lowvram_model_memory = 0 - if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load: - loaded_memory = loaded_model.model_loaded_memory() - current_free_mem = get_free_memory(torch_dev) + loaded_memory + if vram_state == VRAMState.LOW_VRAM and not force_full_load: + lowvram_model_memory = max( + int(get_total_memory(device) * MIN_WEIGHT_MEMORY_RATIO), 400 * 1024 * 1024) + elif vram_state == VRAMState.NO_VRAM: + lowvram_model_memory = 1 + + for l in loaded: + l.currently_used = True + if l.should_reload_model(force_patch_weights=force_patch_weights) or l.real_model is None: + mem_needed = l.model_memory_required(device) + mem_free = get_free_memory(device) + if DEBUG_ENABLED: + logging.debug( + f"Loading {l.model.__class__.__name__}: mem_needed={mem_needed / 1024**3:.2f} GB, free={mem_free / 1024**3:.2f} GB") + + if mem_free < mem_needed + minimum_memory_required: + free_memory(mem_needed + minimum_memory_required, + device, keep_loaded=loaded) + mem_free = get_free_memory(device) + + stream = get_offload_stream(device) + with torch.cuda.stream(stream) if stream is not None else torch.no_grad(): + l.model_load(lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights) + if loaded_model not in current_loaded_models: + current_loaded_models.append(l) # append for efficiency + sync_stream(device, stream) + if DEBUG_ENABLED: + logging.debug( + f"Loaded {l.model.__class__.__name__}: free={get_free_memory(device) / 1024**3:.2f} GB") - 
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) - lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory) - - if vram_set_state == VRAMState.NO_VRAM: - lowvram_model_memory = 0.1 - - loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights) - current_loaded_models.insert(0, loaded_model) return def load_model_gpu(model): + """Load a single model to GPU, wrapper around load_models_gpu.""" return load_models_gpu([model]) + def loaded_models(only_currently_used=False): + """Return list of loaded models, optionally only those currently used.""" output = [] for m in current_loaded_models: - if only_currently_used: - if not m.currently_used: - continue - + if only_currently_used and not m.currently_used: + continue output.append(m.model) return output +# Data type selection +def supports_fp8_compute(device=None): + """Check if the device supports FP8 computation.""" + if not is_nvidia(): + return False + if device is None: + device = get_torch_device() + props = torch.cuda.get_device_properties(device) + if props.major >= 9: # Ada Lovelace + return True + if props.major == 8 and props.minor >= 9 and torch_version_numeric >= (2, 3): + if any(platform.win32_ver()) and torch_version_numeric < (2, 4): + return False + return True + return False -def cleanup_models_gc(): - do_gc = False - for i in range(len(current_loaded_models)): - cur = current_loaded_models[i] - if cur.is_dead(): - logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__)) - do_gc = True - break - - if do_gc: - gc.collect() - soft_empty_cache() - - for i in range(len(current_loaded_models)): - cur = current_loaded_models[i] - if cur.is_dead(): - logging.warning("WARNING, memory leak with model {}. 
Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__)) - - - -def cleanup_models(): - to_delete = [] - for i in range(len(current_loaded_models)): - if current_loaded_models[i].real_model() is None: - to_delete = [i] + to_delete +def supports_dtype(dtype, device): + """Check if the device supports the given data type.""" + if dtype == torch.bfloat16: + if is_nvidia(): + return torch.cuda.get_device_properties(device).major >= 8 + elif is_amd(): + arch = torch.cuda.get_device_properties(device).gcnArchName + return any(a in arch for a in ["gfx941", "gfx942"]) + return False + elif dtype in (torch.float16, torch.float32): + return True + elif dtype in FLOAT8_TYPES: + return supports_fp8_compute(device) + return False - for i in to_delete: - x = current_loaded_models.pop(i) - del x +def supports_cast(dtype, device): + """Check if the device supports casting to the given data type.""" + if dtype == torch.bfloat16: + return True + return supports_dtype(dtype, device) -def dtype_size(dtype): - dtype_size = 4 - if dtype == torch.float16 or dtype == torch.bfloat16: - dtype_size = 2 - elif dtype == torch.float32: - dtype_size = 4 - else: - try: - dtype_size = dtype.itemsize - except: #Old pytorch doesn't have .itemsize - pass - return dtype_size +def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): + """Determine if FP16 should be used for the device.""" + if device is None: + device = get_torch_device() + if FORCE_FP32: + return False + if args.force_fp16: + return supports_cast(torch.float16, device) + if is_intel_xpu(): + return True + if is_mlu(): + props = torch.mlu.get_device_properties(device) + return props.major >= 3 + if is_ascend_npu(): + return False + if is_amd(): + arch = torch.cuda.get_device_properties(device).gcnArchName + if any(a in arch for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): + return manual_cast + return True + props = torch.cuda.get_device_properties(device) + if is_nvidia(): + # Prefer FP32 for low VRAM or older GPUs + total_vram = get_total_memory(device) / (1024**3) + if total_vram < 5.9 or props.major <= 7: # Turing (7.5) or Pascal (6.x) + return False + if any(platform.win32_ver()) and props.major <= 7: + return manual_cast and torch.cuda.is_bf16_supported() + if props.major >= 8: + return True + return torch.cuda.is_bf16_supported() and manual_cast and (not prioritize_performance or model_params * 4 > get_total_memory(device)) -def unet_offload_device(): - if vram_state == VRAMState.HIGH_VRAM: - return get_torch_device() - else: - return torch.device("cpu") +def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): + """Determine if BF16 should be used for the device.""" + if device is None: + device = get_torch_device() + if args.force_fp16 or FORCE_FP32: + return False + if not is_device_cuda(device): + return False # BF16 not supported on CPU or MPS + props = torch.cuda.get_device_properties(device) + if is_nvidia(): + return props.major >= 8 and supports_cast(torch.bfloat16, device) + elif is_amd(): + arch = props.gcnArchName + return any(a in arch for a in ["gfx941", "gfx942"]) and supports_cast(torch.bfloat16, device) + return False -def unet_inital_load_device(parameters, dtype): - torch_dev = get_torch_device() - if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED: - return torch_dev +def vae_dtype(device=None, model=None): + """ + Select appropriate data 
type for VAE. + + Args: + device: PyTorch device (e.g., 'cuda', 'cpu'). Defaults to get_torch_device(). + model: Optional model to check compatibility (not used in this implementation). + + Returns: + torch.dtype: Appropriate data type for VAE (e.g., torch.float32, torch.float16, torch.bfloat16). + """ + if device is None: + device = get_torch_device() + + # Handle CPU case explicitly to avoid CUDA calls + if device.type == 'cpu': + logging.debug(f"VAE dtype: torch.float32 (CPU device)") + return torch.float32 - cpu_dev = torch.device("cpu") - if DISABLE_SMART_MEMORY: - return cpu_dev + # Handle forced FP32/FP16 via command-line arguments + if args.force_fp32_vae: + if DEBUG_ENABLED: + logging.debug(f"VAE dtype: torch.float32 (forced via --force-fp32-vae)") + return torch.float32 + if args.force_fp16_vae: + if supports_cast(torch.float16, device): + if DEBUG_ENABLED: + logging.debug(f"VAE dtype: torch.float16 (forced via --force-fp16-vae)") + return torch.float16 + if DEBUG_ENABLED: + logging.debug(f"VAE dtype: torch.float32 (FP16 not supported on {device})") + return torch.float32 - model_size = dtype_size(dtype) * parameters + # Handle NVIDIA GPUs + if is_nvidia(): + props = torch.cuda.get_device_properties(device) + total_vram = get_total_memory(device) / (1024**3) + if total_vram < 5.9 or props.major <= 7: # Turing (7.5) or Pascal (6.x) + # Try FP16 with fallback to FP32 if unstable + if supports_cast(torch.float16, device) and total_vram >= 3.9: + if DEBUG_ENABLED: + logging.debug(f"VAE dtype: torch.float16 (Turing SM {props.major}.{props.minor}, VRAM {total_vram:.1f} GB)") + return torch.float16 + if DEBUG_ENABLED: + logging.debug(f"VAE dtype: torch.float32 (Turing SM {props.major}.{props.minor}, low VRAM {total_vram:.1f} GB)") + return torch.float32 + + # Handle bfloat16 and FP16 for other devices + if should_use_bf16(device=device, prioritize_performance=False): + if DEBUG_ENABLED: + logging.debug(f"VAE dtype: torch.bfloat16 (device supports BF16)") + return torch.float16 + if should_use_fp16(device=device, prioritize_performance=False): + if DEBUG_ENABLED: + logging.debug(f"VAE dtype: torch.float16 (device supports FP16)") + return torch.float16 - mem_dev = get_free_memory(torch_dev) - mem_cpu = get_free_memory(cpu_dev) - if mem_dev > mem_cpu and model_size < mem_dev: - return torch_dev - else: - return cpu_dev + # Default fallback + if DEBUG_ENABLED: + logging.debug(f"VAE dtype: torch.float32 (default fallback)") + return torch.float32 -def maximum_vram_for_weights(device=None): - return (get_total_memory(device) * 0.88 - minimum_inference_memory()) +def unet_dtype(device=None, model=None, model_params=None, supported_dtypes=None, weight_dtype=None): + """Select appropriate data type for UNet.""" + if device is None: + device = get_torch_device() + model_params = module_size(model) // 4 if model is not None else 0 -def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32], weight_dtype=None): - if model_params < 0: - model_params = 1000000000000000000000 - if args.fp32_unet: - return torch.float32 - if args.fp64_unet: - return torch.float64 - if args.bf16_unet: - return torch.bfloat16 - if args.fp16_unet: - return torch.float16 - if args.fp8_e4m3fn_unet: + # FP8 support + if args.fp8_e4m3fn_unet and supports_fp8_compute(device): return torch.float8_e4m3fn - if args.fp8_e5m2_unet: + if args.fp8_e5m2_unet and supports_fp8_compute(device): return torch.float8_e5m2 - if args.fp8_e8m0fnu_unet: - return torch.float8_e8m0fnu - fp8_dtype 
= None if weight_dtype in FLOAT8_TYPES: fp8_dtype = weight_dtype - if fp8_dtype is not None: - if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive + if supports_fp8_compute(device): return fp8_dtype - free_model_memory = maximum_vram_for_weights(device) if model_params * 2 > free_model_memory: return fp8_dtype - if PRIORITIZE_FP16 or weight_dtype == torch.float16: - if torch.float16 in supported_dtypes and should_use_fp16(device=device, model_params=model_params): - return torch.float16 - - for dt in supported_dtypes: - if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params): - if torch.float16 in supported_dtypes: - return torch.float16 - if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params): - if torch.bfloat16 in supported_dtypes: - return torch.bfloat16 + # Check supported_dtypes and weight_dtype + if supported_dtypes is not None and weight_dtype is not None: + for dtype in supported_dtypes: + if dtype == weight_dtype: + return dtype + # Fallback to bf16/fp16/fp32 based on device and args + if args.force_fp16 and supports_cast(torch.float16, device): + return torch.float16 + if args.force_fp32: + return torch.float32 + if should_use_bf16(device, model_params, prioritize_performance=True): + return torch.bfloat16 + if should_use_fp16(device, model_params, prioritize_performance=True): + return torch.float16 for dt in supported_dtypes: if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True): - if torch.float16 in supported_dtypes: - return torch.float16 + return torch.float16 if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True): - if torch.bfloat16 in supported_dtypes: - return torch.bfloat16 - + return torch.bfloat16 return torch.float32 -# None means no manual cast +def unet_offload_device(): + """Determine device for UNet offloading (GPU or CPU).""" + if vram_state == VRAMState.HIGH_VRAM: + return get_torch_device() + return torch.device("cpu") + +def unet_inital_load_device(parameters, dtype): + """Determine initial load device for UNet based on model size and dtype.""" + torch_dev = get_torch_device() + if vram_state in [VRAMState.HIGH_VRAM, VRAMState.SHARED]: + return torch_dev + cpu_dev = torch.device("cpu") + if DISABLE_SMART_MEMORY: + return cpu_dev + model_size = dtype_size(dtype) * parameters + mem_dev = get_free_memory(torch_dev) + mem_cpu = get_free_memory(cpu_dev) + if mem_dev > mem_cpu and model_size < mem_dev * 0.8: # 80% threshold + return torch_dev + return cpu_dev + def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): - if weight_dtype == torch.float32 or weight_dtype == torch.float64: + """Determine if manual casting is needed for UNet dtype.""" + # No cast needed for fp32/fp64 + if weight_dtype in [torch.float32, torch.float64]: + return None + + # Check FP8 support + if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and supports_fp8_compute(inference_device): return None - fp16_supported = should_use_fp16(inference_device, prioritize_performance=False) + # Check FP16 support + fp16_supported = should_use_fp16(inference_device, prioritize_performance=True) if fp16_supported and weight_dtype == torch.float16: return None + # Check BF16 support bf16_supported = should_use_bf16(inference_device) if bf16_supported and weight_dtype == torch.bfloat16: return None - fp16_supported = 
should_use_fp16(inference_device, prioritize_performance=True) - if PRIORITIZE_FP16 and fp16_supported and torch.float16 in supported_dtypes: + # Prioritize FP16 if supported and in supported_dtypes + if fp16_supported and torch.float16 in supported_dtypes: return torch.float16 + # Check other supported dtypes for dt in supported_dtypes: if dt == torch.float16 and fp16_supported: return torch.float16 if dt == torch.bfloat16 and bf16_supported: return torch.bfloat16 - + if dt in [torch.float8_e4m3fn, torch.float8_e5m2] and supports_fp8_compute(inference_device): + return dt + # Fallback to FP32 return torch.float32 + def text_encoder_offload_device(): - if args.gpu_only: - return get_torch_device() - else: - return torch.device("cpu") + """Determine device for offloading text encoder.""" + return torch.device("cpu") # Keep offload on CPU to save VRAM + def text_encoder_device(): - if args.gpu_only: - return get_torch_device() - elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM: - if should_use_fp16(prioritize_performance=False): - return get_torch_device() - else: - return torch.device("cpu") - else: - return torch.device("cpu") + """Determine device for text encoder (prefer GPU).""" + if vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.LOW_VRAM): + return get_torch_device() # Prefer GPU for low VRAM + return torch.device("cpu") def text_encoder_initial_device(load_device, offload_device, model_size=0): - if load_device == offload_device or model_size <= 1024 * 1024 * 1024: - return offload_device - + """Determine initial device for text encoder.""" + if load_device == offload_device or model_size <= 512 * 1024 * 1024: + return load_device if is_device_mps(load_device): return load_device - mem_l = get_free_memory(load_device) mem_o = get_free_memory(offload_device) if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l: return load_device - else: - return offload_device + return offload_device -def text_encoder_dtype(device=None): - if args.fp8_e4m3fn_text_enc: - return torch.float8_e4m3fn - elif args.fp8_e5m2_text_enc: - return torch.float8_e5m2 - elif args.fp16_text_enc: - return torch.float16 - elif args.bf16_text_enc: - return torch.bfloat16 - elif args.fp32_text_enc: - return torch.float32 +def unet_inital_load_device(parameters, dtype): + """Determine initial load device for UNet based on model size and dtype.""" + torch_dev = get_torch_device() + if vram_state in [VRAMState.HIGH_VRAM, VRAMState.SHARED]: + return torch_dev - if is_device_cpu(device): - return torch.float16 + cpu_dev = torch.device("cpu") + if DISABLE_SMART_MEMORY: + return cpu_dev - return torch.float16 + model_size = dtype_size(dtype) * parameters # Size in bytes + mem_dev = get_free_memory(torch_dev) # Free VRAM + mem_cpu = get_free_memory(cpu_dev) # Free RAM + # Prefer GPU if VRAM > RAM and model fits in VRAM + if mem_dev > mem_cpu and model_size < mem_dev * 0.8: # 80% threshold + return torch_dev + return cpu_dev -def intermediate_device(): - if args.gpu_only: - return get_torch_device() - else: - return torch.device("cpu") +def maximum_vram_for_weights(device=None): + """Calculate maximum VRAM available for model weights.""" + if device is None: + device = get_torch_device() + return (get_total_memory(device) * 0.9 - minimum_inference_memory()) -def vae_device(): - if args.cpu_vae: - return torch.device("cpu") - return get_torch_device() -def vae_offload_device(): +def force_channels_last(): + """ + Check if channels_last format should be used for tensors. 
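+    channels_last keeps 4D tensors in NHWC memory order, which generally gives
+    better convolution throughput on recent NVIDIA GPUs.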
+ Safe for Turing GPUs with FP32 VAE. + """ + if args.force_channels_last: + if DEBUG_ENABLED: + logging.debug("force_channels_last: Enabled via --force-channels-last") + return True + if cpu_state == CPUState.GPU and is_nvidia() and torch.cuda.is_available(): + if DEBUG_ENABLED: + total_vram = get_total_memory(get_torch_device()) / (1024 * 1024 * 1024) # VRAM in GB + logging.debug( + f"force_channels_last: Enabled for NVIDIA GPU with {total_vram:.1f} GB VRAM") + return True + logging.debug("force_channels_last: Disabled") + return False + +def intermediate_device(): + """Determine device for intermediate computations (GPU or CPU).""" if args.gpu_only: return get_torch_device() - else: - return torch.device("cpu") - -def vae_dtype(device=None, allowed_dtypes=[]): - if args.fp16_vae: - return torch.float16 - elif args.bf16_vae: - return torch.bfloat16 - elif args.fp32_vae: - return torch.float32 - - for d in allowed_dtypes: - if d == torch.float16 and should_use_fp16(device): - return d - - # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32 - if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device): - return d - - return torch.float32 + return torch.device("cpu") def get_autocast_device(dev): + """Determine device type for autocast (e.g., cuda, cpu, mps).""" if hasattr(dev, 'type'): return dev.type return "cuda" -def supports_dtype(device, dtype): #TODO - if dtype == torch.float32: - return True - if is_device_cpu(device): - return False - if dtype == torch.float16: - return True - if dtype == torch.bfloat16: - return True - return False +def vae_offload_device(): + """Determine device for VAE offloading (GPU or CPU).""" + if args.gpu_only: + return get_torch_device() + return torch.device("cpu") -def supports_cast(device, dtype): #TODO - if dtype == torch.float32: - return True - if dtype == torch.float16: - return True - if directml_enabled: #TODO: test this - return False - if dtype == torch.bfloat16: - return True - if is_device_mps(device): - return False - if dtype == torch.float8_e4m3fn: - return True - if dtype == torch.float8_e5m2: - return True - return False +def vae_device(): + """Determine device for VAE (GPU or CPU).""" + if args.cpu_vae: + return torch.device("cpu") + return get_torch_device() def pick_weight_dtype(dtype, fallback_dtype, device=None): + """Select appropriate dtype for model weights, using fallback if needed.""" if dtype is None: dtype = fallback_dtype elif dtype_size(dtype) > dtype_size(fallback_dtype): dtype = fallback_dtype - if not supports_cast(device, dtype): dtype = fallback_dtype - return dtype +def is_device_mps(device): + """Check if device is MPS (Apple Silicon).""" + return isinstance(device, torch.device) and device.type == "mps" + def device_supports_non_blocking(device): + """Check if device supports non-blocking data transfers.""" if is_device_mps(device): - return False #pytorch bug? mps doesn't support non blocking + return False # pytorch bug? 
mps doesn't support non blocking if is_intel_xpu(): return False - if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews) + if args.deterministic: # TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews) return False if directml_enabled: return False return True def device_should_use_non_blocking(device): + """Determine if non-blocking transfers should be used (disabled due to memory issues).""" if not device_supports_non_blocking(device): return False return False # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others -def force_channels_last(): - if args.force_channels_last: - return True - - #TODO - return False - - -STREAMS = {} -NUM_STREAMS = 1 -if args.async_offload: - NUM_STREAMS = 2 - logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS)) - -stream_counters = {} -def get_offload_stream(device): - stream_counter = stream_counters.get(device, 0) - if NUM_STREAMS <= 1: - return None - - if device in STREAMS: - ss = STREAMS[device] - s = ss[stream_counter] - stream_counter = (stream_counter + 1) % len(ss) - if is_device_cuda(device): - ss[stream_counter].wait_stream(torch.cuda.current_stream()) - stream_counters[device] = stream_counter - return s - elif is_device_cuda(device): - ss = [] - for k in range(NUM_STREAMS): - ss.append(torch.cuda.Stream(device=device, priority=0)) - STREAMS[device] = ss - s = ss[stream_counter] - stream_counter = (stream_counter + 1) % len(ss) - stream_counters[device] = stream_counter - return s - return None - -def sync_stream(device, stream): - if stream is None: - return +def text_encoder_dtype(device=None, model=None): + """Select appropriate data type for text encoder.""" + if device is None: + device = get_torch_device() + model_params = module_size(model) // 4 if model is not None else 0 + # FP8 support (only for CUDA devices) if is_device_cuda(device): - torch.cuda.current_stream().wait_stream(stream) - -def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): - if device is None or weight.device == device: - if not copy: - if dtype is None or weight.dtype == dtype: - return weight - if stream is not None: - with stream: - return weight.to(dtype=dtype, copy=copy) - return weight.to(dtype=dtype, copy=copy) - - if stream is not None: - with stream: - r = torch.empty_like(weight, dtype=dtype, device=device) - r.copy_(weight, non_blocking=non_blocking) - else: - r = torch.empty_like(weight, dtype=dtype, device=device) - r.copy_(weight, non_blocking=non_blocking) - return r - -def cast_to_device(tensor, device, dtype, copy=False): - non_blocking = device_supports_non_blocking(device) - return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy) - -def sage_attention_enabled(): - return args.use_sage_attention - -def flash_attention_enabled(): - return args.use_flash_attention - -def xformers_enabled(): - global directml_enabled - global cpu_state - if cpu_state != CPUState.GPU: - return False - if is_intel_xpu(): - return False - if is_ascend_npu(): - return False - if is_mlu(): - return False - if directml_enabled: - return False - return XFORMERS_IS_AVAILABLE - - -def xformers_enabled_vae(): - enabled = xformers_enabled() - if not enabled: - return False - - return XFORMERS_ENABLED_VAE - -def pytorch_attention_enabled(): - global ENABLE_PYTORCH_ATTENTION - return ENABLE_PYTORCH_ATTENTION - -def pytorch_attention_enabled_vae(): - if is_amd(): - return 
False # enabling pytorch attention on AMD currently causes crash when doing high res - return pytorch_attention_enabled() - -def pytorch_attention_flash_attention(): - global ENABLE_PYTORCH_ATTENTION - if ENABLE_PYTORCH_ATTENTION: - #TODO: more reliable way of checking for flash attention? - if is_nvidia(): #pytorch flash attention only works on Nvidia - return True - if is_intel_xpu(): - return True - if is_ascend_npu(): - return True - if is_mlu(): - return True - if is_amd(): - return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention - return False - -def force_upcast_attention_dtype(): - upcast = args.force_upcast_attention + if getattr(args, 'fp8_e4m3fn_text_enc', False) and supports_fp8_compute(device): + return torch.float8_e4m3fn + if getattr(args, 'fp8_e5m2_text_enc', False) and supports_fp8_compute(device): + return torch.float8_e5m2 + # Check model_dtype safely + model_dtype = getattr(args, 'model_dtype', None) + if model_dtype is not None: + dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32} + if model_dtype in dtype_map: + # Only allow BF16 on supported devices + if model_dtype == "bf16" and not (is_device_cuda(device) and should_use_bf16(device)): + return torch.float16 + if supports_cast(dtype_map[model_dtype], device): + return dtype_map[model_dtype] + # CPU/MPS fallback to FP32 + if not is_device_cuda(device): + return torch.float32 + # CUDA devices: BF16/FP16 based on device support + if should_use_bf16(device, model_params, prioritize_performance=True): + return torch.bfloat16 + if should_use_fp16(device, model_params, prioritize_performance=True): + return torch.float16 + return torch.float16 # Default to FP16 for GPU - macos_version = mac_version() - if macos_version is not None and ((14, 5) <= macos_version < (16,)): # black image bug on recent versions of macOS - upcast = True - if upcast: - return {torch.float16: torch.float32} - else: +def mac_version(): + """Get macOS version as a tuple.""" + try: + return tuple(int(n) for n in platform.mac_ver()[0].split(".")) + except: return None -def get_free_memory(dev=None, torch_free_too=False): - global directml_enabled - if dev is None: - dev = get_torch_device() - - if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): - mem_free_total = psutil.virtual_memory().available - mem_free_torch = mem_free_total - else: - if directml_enabled: - mem_free_total = 1024 * 1024 * 1024 #TODO - mem_free_torch = mem_free_total - elif is_intel_xpu(): - stats = torch.xpu.memory_stats(dev) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_torch = mem_reserved - mem_active - mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved - mem_free_total = mem_free_xpu + mem_free_torch - elif is_ascend_npu(): - stats = torch.npu.memory_stats(dev) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_npu, _ = torch.npu.mem_get_info(dev) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_npu + mem_free_torch - elif is_mlu(): - stats = torch.mlu.memory_stats(dev) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_mlu, _ = torch.mlu.mem_get_info(dev) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_mlu + mem_free_torch - else: - stats = torch.cuda.memory_stats(dev) - mem_active = 
stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(dev) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - - if torch_free_too: - return (mem_free_total, mem_free_torch) - else: - return mem_free_total - -def cpu_mode(): - global cpu_state - return cpu_state == CPUState.CPU - -def mps_mode(): - global cpu_state - return cpu_state == CPUState.MPS - -def is_device_type(device, type): - if hasattr(device, 'type'): - if (device.type == type): - return True - return False - -def is_device_cpu(device): - return is_device_type(device, 'cpu') - -def is_device_mps(device): - return is_device_type(device, 'mps') - -def is_device_cuda(device): - return is_device_type(device, 'cuda') - -def is_directml_enabled(): - global directml_enabled - if directml_enabled: - return True - - return False - -def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): - if device is not None: - if is_device_cpu(device): - return False - - if args.force_fp16: - return True - - if FORCE_FP32: - return False - - if is_directml_enabled(): - return True - - if (device is not None and is_device_mps(device)) or mps_mode(): - return True - - if cpu_mode(): - return False - - if is_intel_xpu(): - return True - - if is_ascend_npu(): - return True - - if is_mlu(): - return True - - if torch.version.hip: - return True - - props = torch.cuda.get_device_properties(device) - if props.major >= 8: - return True - - if props.major < 6: - return False - - #FP16 is confirmed working on a 1080 (GP104) and on latest pytorch actually seems faster than fp32 - nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"] - for x in nvidia_10_series: - if x in props.name.lower(): - if WINDOWS or manual_cast: - return True - else: - return False #weird linux behavior where fp32 is faster - - if manual_cast: - free_model_memory = maximum_vram_for_weights(device) - if (not prioritize_performance) or model_params * 4 > free_model_memory: - return True - - if props.major < 7: - return False - - #FP16 is just broken on these cards - nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"] - for x in nvidia_16_series: - if x in props.name: - return False - - return True - -def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): - if device is not None: - if is_device_cpu(device): #TODO ? 
bf16 works on CPU but is extremely slow - return False - - if FORCE_FP32: - return False - - if directml_enabled: - return False - - if (device is not None and is_device_mps(device)) or mps_mode(): - if mac_version() < (14,): - return False - return True - - if cpu_mode(): - return False - - if is_intel_xpu(): - return True - - if is_ascend_npu(): - return True - - if is_amd(): - arch = torch.cuda.get_device_properties(device).gcnArchName - if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16 - if manual_cast: - return True - return False - - props = torch.cuda.get_device_properties(device) - - if is_mlu(): - if props.major > 3: - return True - - if props.major >= 8: - return True - - bf16_works = torch.cuda.is_bf16_supported() - - if bf16_works and manual_cast: - free_model_memory = maximum_vram_for_weights(device) - if (not prioritize_performance) or model_params * 4 > free_model_memory: - return True - - return False - -def supports_fp8_compute(device=None): - if not is_nvidia(): - return False - - props = torch.cuda.get_device_properties(device) - if props.major >= 9: - return True - if props.major < 8: - return False - if props.minor < 9: - return False - - if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3): - return False - - if WINDOWS: - if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4): - return False - - return True - -def soft_empty_cache(force=False): - global cpu_state - if cpu_state == CPUState.MPS: - torch.mps.empty_cache() - elif is_intel_xpu(): - torch.xpu.empty_cache() - elif is_ascend_npu(): - torch.npu.empty_cache() - elif is_mlu(): - torch.mlu.empty_cache() - elif torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - -def unload_all_models(): - free_memory(1e30, get_torch_device()) - - -#TODO: might be cleaner to put this somewhere else -import threading - +# Interrupt handling class InterruptProcessingException(Exception): pass interrupt_processing_mutex = threading.RLock() - interrupt_processing = False + def interrupt_current_processing(value=True): + """Set interrupt flag for processing.""" global interrupt_processing global interrupt_processing_mutex with interrupt_processing_mutex: interrupt_processing = value +def lowvram_enabled(): + """Check if low VRAM mode is enabled.""" + return vram_state == VRAMState.LOW_VRAM + +def noram_enabled(): + """Check if no VRAM mode is enabled.""" + return vram_state == VRAMState.NO_VRAM + + def processing_interrupted(): + """Check if processing is interrupted.""" global interrupt_processing global interrupt_processing_mutex with interrupt_processing_mutex: return interrupt_processing def throw_exception_if_processing_interrupted(): + """Throw exception if processing is interrupted.""" global interrupt_processing global interrupt_processing_mutex with interrupt_processing_mutex: if interrupt_processing: interrupt_processing = False raise InterruptProcessingException() + +def controlnet_device(): + """Determine device for ControlNet (GPU or CPU).""" + if args.gpu_only: + return get_torch_device() + return torch.device("cpu") + +def controlnet_dtype(device=None, model=None): + """Select appropriate data type for ControlNet.""" + if device is None: + device = get_torch_device() + model_params = module_size(model) // 4 if model is not None else 0 + if args.force_fp16: + if supports_cast(torch.float16, device): + logging.debug(f"ControlNet 
dtype: torch.float16 (forced via --force-fp16)") + return torch.float16 + logging.debug(f"ControlNet dtype: torch.float32 (FP16 not supported)") + return torch.float32 + if args.force_fp32: + logging.debug(f"ControlNet dtype: torch.float32 (forced via --force-fp32)") + return torch.float32 + if should_use_bf16(device=device, model_params=model_params, prioritize_performance=False): + logging.debug(f"ControlNet dtype: torch.bfloat16 (device supports BF16)") + return torch.bfloat16 + if should_use_fp16(device=device, model_params=model_params, prioritize_performance=False): + logging.debug(f"ControlNet dtype: torch.float16 (device supports FP16)") + return torch.float16 + logging.debug(f"ControlNet dtype: torch.float32 (default fallback)") + return torch.float32 + +def cleanup_models(): + """Clean up models on finalization.""" + soft_empty_cache(clear=False, caller="cleanup_models") + # Check memory state after soft clear + device = get_torch_device() + cache_key = (device, 'free_memory') + mem_free_total, _ = _device_cache.get(cache_key, (0, 0)) + if mem_free_total == 0: + mem_free_total, _ = get_free_memory(device, torch_free_too=True) + total_vram = get_total_memory(device) + if mem_free_total < 0.4 * total_vram: + if PROFILING_ENABLED: + logging.debug(f"cleanup_models: Insufficient VRAM ({mem_free_total/1024**3:.2f} GB < 20% of {total_vram/1024**3:.2f} GB), forcing aggressive clear") + soft_empty_cache(clear=True, caller="cleanup_models_aggressive") + +# Profiling context manager +@contextlib.contextmanager +def profile_section(name): + """Context manager for profiling code sections.""" + if PROFILING_ENABLED: + start = time.time() + if DEBUG_ENABLED: + stack = [frame for frame in traceback.format_stack( + limit=10) if "model_management" in frame] + logging.debug(f"Starting {name}, stack: {''.join(stack)}") + try: + yield + finally: + logging.debug(f"{name}: {time.time() - start:.3f} s") + else: + yield + +def mac_version(): + """Get macOS version if running on macOS.""" + if platform.system() == "Darwin": + try: + version = platform.mac_ver()[0] + version_parts = version.split(".") + return (int(version_parts[0]), int(version_parts[1])) + except: + return None + return None + +# Additional utilities for memory management +def get_device_memory_info(device=None): + """Get detailed memory information for a device.""" + if device is None: + device = get_torch_device() + mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True) + mem_total = get_total_memory(device) + return { + "free_total": mem_free_total, + "free_torch": mem_free_torch, + "total": mem_total, + "used": mem_total - mem_free_total + } + +def optimize_memory_for_device(device=None): + """Optimize memory settings based on device capabilities.""" + if device is None: + device = get_torch_device() + total_vram = get_total_memory(device) / (1024 * 1024 * 1024) # VRAM in GB + global vram_state + if total_vram < 3.9: + vram_state = VRAMState.NO_VRAM + logging.info(f"Low VRAM ({total_vram:.1f} GB), enabling NO_VRAM mode") + elif total_vram < 7.9: + vram_state = VRAMState.LOW_VRAM + logging.info(f"Moderate VRAM ({total_vram:.1f} GB), enabling LOW_VRAM mode") + else: + vram_state = VRAMState.NORMAL_VRAM + logging.info(f"Sufficient VRAM ({total_vram:.1f} GB), using NORMAL_VRAM mode") + +# Initialize device and memory settings +try: + optimize_memory_for_device() + if PROFILING_ENABLED: + logging.debug("Memory optimization completed") +except Exception as e: + logging.error(f"Failed to optimize memory: {e}") + 
vram_state = VRAMState.DISABLED + +def get_device_cache_state(): + """Return the current state of _device_cache for logging.""" + return _device_cache \ No newline at end of file diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b7cb12dfcf1..48a5695f4d9 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -34,6 +34,15 @@ import comfy.patcher_extension from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection from comfy.comfy_types import UnetWrapperFunction +from comfy.cli_args import args +from comfy.ldm.models.autoencoder import AutoencoderKL +from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel +from comfy.model_management import get_free_memory, get_torch_device + +# Global flag for profiling +PROFILING_ENABLED = args.profile +DEBUG_ENABLED = args.debug +VERBOSE_ENABLED = False def string_to_seed(data): crc = 0xFFFFFFFF @@ -200,10 +209,13 @@ def decrement(self, used: int): class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): + if PROFILING_ENABLED: + logging.debug(f"ModelPatcher init: model={type(model).__name__}, load_device={load_device}, offload_device={offload_device}, size={size / (1024**3):.2f} GB") self.size = size self.model = model if not hasattr(self.model, 'device'): - logging.debug("Model doesn't have a device attribute.") + if DEBUG_ENABLED: + logging.debug("Model doesn't have a device attribute.") self.model.device = offload_device elif self.model.device is None: self.model.device = offload_device @@ -323,8 +335,44 @@ def clone(self): return n def is_clone(self, other): - if hasattr(other, 'model') and self.model is other.model: + """ + Check if another ModelPatcher is a clone of this one. + Compares model type, patches_uuid, patches content, and base model equivalence. 
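+        Order of checks: object identity, model class, patches_uuid (with the
+        patch contents compared as a safeguard), and finally whether both
+        patchers wrap the same underlying base model object.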
+ """ + if not isinstance(other, ModelPatcher) or not hasattr(other, 'model'): + if DEBUG_ENABLED: + logging.debug("[DEBUG_CLONES] Not clones: invalid other ModelPatcher") + return False + + if self is other: + if DEBUG_ENABLED: + logging.debug("[DEBUG_CLONES] Models are clones: same ModelPatcher object") return True + + if self.model.__class__ != other.model.__class__: + if DEBUG_ENABLED: + logging.debug(f"[DEBUG_CLONES] Not clones: different model types {self.model.__class__.__name__} vs {other.model.__class__.__name__}") + return False + + if self.patches_uuid == other.patches_uuid: + if self.patches != other.patches: + if DEBUG_ENABLED: + logging.debug(f"[DEBUG_CLONES] Not clones: same patches_uuid={self.patches_uuid}, but different patches") + return False + if DEBUG_ENABLED: + logging.debug(f"[DEBUG_CLONES] Models are clones: same patches_uuid={self.patches_uuid} and matching patches") + return True + + self_base = getattr(self.model, 'real_model', getattr(self.model, 'model', self.model)) + other_base = getattr(other.model, 'real_model', getattr(other.model, 'model', other.model)) + + if self_base is other_base: + if DEBUG_ENABLED: + logging.debug(f"[DEBUG_CLONES] Models are clones: same base model object, type={self.model.__class__.__name__}") + return True + + if DEBUG_ENABLED: + logging.debug(f"[DEBUG_CLONES] Not clones: different base model objects, type={self.model.__class__.__name__}") return False def clone_has_same_weights(self, clone: 'ModelPatcher'): @@ -584,6 +632,9 @@ def _load_list(self): return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): + # Set default device if device_to is None to avoid errors in get_free_memory + device = device_to if device_to is not None else comfy.model_management.get_torch_device() + with self.use_ejected(): self.unpatch_hooks() mem_counter = 0 @@ -591,6 +642,30 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False lowvram_counter = 0 loading = self._load_list() + if PROFILING_ENABLED: + # Determine module name for logging + module_name = "Unknown" + if isinstance(self.model, (AutoencoderKL, comfy.sd.VAE)): + module_name = "VAE" + elif isinstance(self.model, UNetModel): + module_name = "UNet" + elif hasattr(self, "is_clip") and self.is_clip: + module_name = "CLIP" + elif "diffusion_model" in str(type(self.model)): + module_name = "DiffusionModel" + elif isinstance(self.model, torch.nn.Module): + module_name = f"{type(self.model).__name__}" + logging.debug(f"Loading module: {module_name}, type: {type(self.model).__name__}") + + # Validate and normalize lowvram_model_memory + if not isinstance(lowvram_model_memory, (int, float)) or lowvram_model_memory < 0: + if PROFILING_ENABLED: + logging.warning(f"Invalid lowvram_model_memory: {lowvram_model_memory}, resetting to 0") + lowvram_model_memory = 0 + + if PROFILING_ENABLED: + logging.debug(f"ModelPatcher.load: model type: {type(self.model).__name__}, device_to: {device_to}, lowvram_model_memory: {lowvram_model_memory / (1024 * 1024):.2f} MB, full_load: {full_load}") + load_completely = [] loading.sort(reverse=True) for x in loading: @@ -604,11 +679,23 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) + is_vae = isinstance(self.model, (AutoencoderKL, comfy.sd.VAE)) + if VERBOSE_ENABLED: + logging.debug(f"Processing module: {n}, type: {type(m).__name__}, is_vae: {is_vae}, module_mem: {module_mem / (1024 * 
1024):.2f} MB") + + # Skip VAE module if already loaded on the target device + if is_vae and hasattr(self.model, 'first_stage_model') and hasattr(self.model.first_stage_model, 'device') and self.model.first_stage_model.device == device and hasattr(self.model, '_loaded_to_device') and self.model._loaded_to_device == device: + if PROFILING_ENABLED: + logging.debug(f"Skipping VAE module {n}, already on {device} with _loaded_to_device={self.model._loaded_to_device}") + continue + if not full_load and hasattr(m, "comfy_cast_weights"): if mem_counter + module_mem >= lowvram_model_memory: lowvram_weight = True lowvram_counter += 1 if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed + if VERBOSE_ENABLED: + logging.debug(f"Skipping module {n}: already in lowvram mode") continue cast_weight = self.force_cast_weights @@ -630,6 +717,9 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False m.bias_function = [LowVramPatch(bias_key, self.patches)] patch_counter += 1 + if VERBOSE_ENABLED: + logging.debug(f"Module {n} set to lowvram, weight_key={weight_key}, bias_key={bias_key}, patch_counter={patch_counter}, lowvram_weight={lowvram_weight}") + cast_weight = True else: if hasattr(m, "comfy_cast_weights"): @@ -638,6 +728,8 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False if full_load or mem_counter + module_mem < lowvram_model_memory: mem_counter += module_mem load_completely.append((module_mem, n, m, params)) + if VERBOSE_ENABLED: + logging.debug(f"Module {n} added to load_completely, mem_counter={mem_counter / (1024**3):.2f} GB") if cast_weight and hasattr(m, "comfy_cast_weights"): m.prev_comfy_cast_weights = m.comfy_cast_weights @@ -649,7 +741,9 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False if bias_key in self.weight_wrapper_patches: m.bias_function.extend(self.weight_wrapper_patches[bias_key]) - mem_counter += move_weight_functions(m, device_to) + mem_counter += move_weight_functions(m, device) + if VERBOSE_ENABLED: + logging.debug(f"Moved weight functions for {n} to device={device}, mem_counter={mem_counter / (1024**3):.2f} GB") load_completely.sort(reverse=True) for x in load_completely: @@ -658,34 +752,49 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False params = x[3] if hasattr(m, "comfy_patched_weights"): if m.comfy_patched_weights == True: + if VERBOSE_ENABLED: + logging.debug(f"Skipping module {n}: already patched") continue for param in params: - self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to) + self.patch_weight_to_device("{}.{}".format(n, param), device_to=device) - logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) + if VERBOSE_ENABLED: + logging.debug(f"Loaded module {n} regularly, lowvram={self.model.model_lowvram}") m.comfy_patched_weights = True for x in load_completely: - x[2].to(device_to) + x[2].to(device) + if VERBOSE_ENABLED: + logging.debug(f"Moved module {x[1]} to device={device}") + # Safe logging with module name + lowvram_mb = lowvram_model_memory / (1024 * 1024) + mem_counter_mb = mem_counter / (1024 * 1024) if lowvram_counter > 0: - logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) + if PROFILING_ENABLED: + logging.info(f"Loaded partially {module_name}: {lowvram_mb:.2f} MB, {mem_counter_mb:.2f} MB, patches: {patch_counter}") self.model.model_lowvram = True else: - logging.info("loaded completely {} {} 
{}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) + if PROFILING_ENABLED: + logging.info(f"Loaded completely {module_name}: {lowvram_mb:.2f} MB, {mem_counter_mb:.2f} MB, full_load: {full_load}") self.model.model_lowvram = False if full_load: - self.model.to(device_to) + self.model.to(device) mem_counter = self.model_size() + if PROFILING_ENABLED: + logging.info(f"Moved entire model to device: {device}, mem_counter: {mem_counter / (1024 * 1024):.2f} MB") self.model.lowvram_patch_counter += patch_counter - self.model.device = device_to + self.model.device = device self.model.model_loaded_weight_memory = mem_counter self.model.current_weight_patches_uuid = self.patches_uuid + if PROFILING_ENABLED: + logging.debug(f"Load completed: model type: {type(self.model).__name__}, device: {self.model.device}, loaded_weight_memory: {self.model.model_loaded_weight_memory / (1024 * 1024):.2f} MB, lowvram: {self.model.model_lowvram}, patch_counter: {self.model.lowvram_patch_counter}") + for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): - callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load) + callback(self, device, lowvram_model_memory, force_patch_weights, full_load) self.apply_hooks(self.forced_hooks, force_apply=True) @@ -696,6 +805,13 @@ def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, if k not in self.object_patches_backup: self.object_patches_backup[k] = old + # Validate and normalize lowvram_model_memory + if PROFILING_ENABLED: + logging.debug(f"patch_model: model={type(self.model).__name__}, lowvram_model_memory={lowvram_model_memory / (1024**3):.2f} GB, device_to={device_to}") + if not isinstance(lowvram_model_memory, (int, float)) or lowvram_model_memory < 0: + logging.warning(f"Invalid lowvram_model_memory in patch_model: {lowvram_model_memory}, resetting to 0") + lowvram_model_memory = 0 + if lowvram_model_memory == 0: full_load = True else: @@ -812,21 +928,37 @@ def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): with self.use_ejected(skip_and_inject_on_exit_only=True): unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights) # TODO: force_patch_weights should not unload + reload full model + if PROFILING_ENABLED: + logging.debug(f"partially_load: unpatch_weights={unpatch_weights}, patches_uuid={self.patches_uuid}, current_weight_patches_uuid={self.model.current_weight_patches_uuid}") used = self.model.model_loaded_weight_memory + if PROFILING_ENABLED: + logging.debug(f"partially_load: used={used / (1024**3):.2f} GB, model_loaded_weight_memory={self.model.model_loaded_weight_memory / (1024**3):.2f} GB") + self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights) if unpatch_weights: extra_memory += (used - self.model.model_loaded_weight_memory) + if PROFILING_ENABLED: + logging.debug(f"partially_load: updated extra_memory={extra_memory / (1024**3):.2f} GB after unpatch, used={used / (1024**3):.2f} GB, model_loaded_weight_memory={self.model.model_loaded_weight_memory / (1024**3):.2f} GB") self.patch_model(load_weights=False) full_load = False if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0: self.apply_hooks(self.forced_hooks, force_apply=True) + if PROFILING_ENABLED: + logging.debug(f"partially_load: early return, model_lowvram={self.model.model_lowvram}, 
model_loaded_weight_memory={self.model.model_loaded_weight_memory / (1024**3):.2f} GB") return 0 if self.model.model_loaded_weight_memory + extra_memory > self.model_size(): full_load = True + if PROFILING_ENABLED: + logging.debug(f"partially_load: full_load=True, model_loaded_weight_memory={self.model.model_loaded_weight_memory / (1024**3):.2f} GB, extra_memory={extra_memory / (1024**3):.2f} GB, model_size={self.model_size() / (1024**3):.2f} GB") + current_used = self.model.model_loaded_weight_memory + lowvram_model_memory = current_used + extra_memory + if PROFILING_ENABLED: + logging.debug(f"partially_load: calling load with lowvram_model_memory={lowvram_model_memory / (1024**3):.2f} GB, current_used={current_used / (1024**3):.2f} GB, extra_memory={extra_memory / (1024**3):.2f} GB") + try: - self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load) + self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load) except Exception as e: self.detach() raise e @@ -834,12 +966,34 @@ def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): return self.model.model_loaded_weight_memory - current_used def detach(self, unpatch_all=True): + if PROFILING_ENABLED: + free_vram_before = get_free_memory(get_torch_device()) / 1024**3 + logging.debug(f"detach: Before, free_vram={free_vram_before:.2f} GB, model={self.model.__class__.__name__}") + + if hasattr(self.model, 'on_patched'): + if DEBUG_ENABLED: + logging.debug(f"Calling on_patched for {self.model.__class__.__name__}") + self.model.on_patched() self.eject_model() self.model_patches_to(self.offload_device) if unpatch_all: - self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all) - for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH): - callback(self, unpatch_all) + self.unpatch_model(self.offload_device) + + if hasattr(self.model, 'to'): + self.model.to(self.offload_device) + self.model.device = self.offload_device + self.model.model_loaded_weight_memory = 0 + + self.patches = [] + self.model_patches = 0 + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + if PROFILING_ENABLED: + free_vram_after = get_free_memory(get_torch_device()) / 1024**3 + logging.debug(f"detach: After, free_vram={free_vram_after:.2f} GB, freed={(free_vram_after-free_vram_before):.2f} GB, model={self.model.__class__.__name__}") + return self.model def current_loaded_device(self): @@ -1205,5 +1359,4 @@ def clean_hooks(self): self.clear_cached_hook_weights() def __del__(self): - self.detach(unpatch_all=False) - + self.detach(unpatch_all=False) \ No newline at end of file diff --git a/comfy/ops.py b/comfy/ops.py index 431c8f89d2e..8d4219073e0 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -18,16 +18,16 @@ import torch import logging -import comfy.model_management from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm import contextlib -cast_to = comfy.model_management.cast_to #TODO: remove once no more references +from comfy.utils import cast_to +cast_to = cast_to # Maintain compatibility with code expecting comfy.ops.cast_to def cast_to_input(weight, input, non_blocking=False, copy=True): - return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) + return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) def cast_bias_weight(s, input=None, dtype=None, device=None, 
bias_dtype=None): if input is not None: @@ -48,7 +48,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): non_blocking = comfy.model_management.device_supports_non_blocking(device) if s.bias is not None: has_function = len(s.bias_function) > 0 - bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) + bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) if has_function: with wf_context: @@ -56,7 +56,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): bias = f(bias) has_function = len(s.weight_function) > 0 - weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) + weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) if has_function: with wf_context: for f in s.weight_function: @@ -308,10 +308,10 @@ def fp8_linear(self, input): if scale_input is None: scale_input = torch.ones((), device=input.device, dtype=torch.float32) input = torch.clamp(input, min=-448, max=448, out=input) - input = input.reshape(-1, input_shape[2]).to(dtype).contiguous() + input = input.reshape(-1, input_shape[2]).to(dtype) else: scale_input = scale_input.to(input.device) - input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous() + input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype) if bias is not None: o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) diff --git a/comfy/utils.py b/comfy/utils.py index 561e1b85859..22c3d8c13c4 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1078,3 +1078,23 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out): dim=1 ) return out + +def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): + """Cast tensor to specified dtype and device.""" + if device is None or weight.device == device: + if not copy: + if dtype is None or weight.dtype == dtype: + return weight + if stream is not None: + with stream: + return weight.to(dtype=dtype, copy=copy) + return weight.to(dtype=dtype, copy=copy) + + if stream is not None: + with stream: + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight, non_blocking=non_blocking) + else: + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight, non_blocking=non_blocking) + return r \ No newline at end of file diff --git a/fast_sampler.py b/fast_sampler.py new file mode 100644 index 00000000000..f8f62e12ae5 --- /dev/null +++ b/fast_sampler.py @@ -0,0 +1,477 @@ +import torch +import comfy +import gc +import time +from torch.amp import autocast +from comfy.cli_args import args +from comfy.model_management import get_torch_device, vae_dtype, soft_empty_cache, free_memory, force_channels_last, estimate_vae_decode_memory, device_supports_non_blocking +from contextlib import contextmanager +import latent_preview +import logging + +# Global flag for profiling +PROFILING_ENABLED = args.profile +DEBUG_ENABLED = args.debug +CUDNN_BENCHMARK_ENABLED = getattr(args, 'cudnn_benchmark', False) # Default: False + +# Configure logging +logging.basicConfig(level=logging.DEBUG if PROFILING_ENABLED or DEBUG_ENABLED else logging.INFO) + +# Cache for FP16 safety check +_fp16_safe_cache = {} + +@contextmanager 
+def profile_section(name): + """Context manager for profiling execution time.""" + if PROFILING_ENABLED: + start = time.time() + try: + yield + finally: + logging.debug(f"{name}: {time.time() - start:.3f} s") + else: + yield + +def profile_cuda_sync(is_gpu, message="CUDA sync"): + """Profile CUDA synchronization time if GPU is used.""" + if PROFILING_ENABLED and is_gpu: + logging.debug(f"{message} started") + sync_start = time.time() + torch.cuda.synchronize() + logging.debug(f"{message} took {time.time() - sync_start:.3f} s") + +def is_fp16_safe(device): + """Check if FP16 is safe for the GPU (disabled for GTX 1660/Turing).""" + if device.type != 'cuda': + return False + if device in _fp16_safe_cache: + return _fp16_safe_cache[device] + try: + props = torch.cuda.get_device_properties(device) + is_safe = props.major >= 8 or props.compute_capability[0] > 7 + _fp16_safe_cache[device] = is_safe + return is_safe + except Exception: + _fp16_safe_cache[device] = False + return False + +def initialize_device_and_dtype(model, device=None): + """Initialize device and dtype from model.""" + if device is None: + device = get_torch_device() + dtype = getattr(model, 'dtype', torch.float32) + is_gpu = device.type == 'cuda' and torch.cuda.is_available() + return device, dtype, is_gpu + +def clear_vram(device, threshold=0.5, min_free=1.5): + """Clear VRAM if usage exceeds threshold or free memory is below min_free (in GB).""" + if device.type == 'cuda': + if PROFILING_ENABLED: + start_time = time.time() + mem_allocated = torch.cuda.memory_allocated(device) / 1024**3 + mem_total = torch.cuda.get_device_properties(device).total_memory / 1024**3 + critical_threshold = 0.05 * mem_total + 0.1 # 5% VRAM + 100 MB + if mem_allocated > threshold * mem_total or (mem_total - mem_allocated) < max(min_free, critical_threshold): + logging.debug(f"Clearing VRAM: allocated {mem_allocated:.2f} GB, free {mem_total - mem_allocated:.2f} GB, threshold {critical_threshold:.2f} GB") + torch.cuda.empty_cache() + #soft_empty_cache(clear=False) + mem_after = torch.cuda.memory_allocated(device) / 1024**3 + if PROFILING_ENABLED: + logging.debug(f"VRAM cleared: {mem_allocated:.2f} GB -> {mem_after:.2f} GB, took {time.time() - start_time:.3f} s") + else: + if PROFILING_ENABLED: + logging.debug(f"VRAM not cleared: {mem_allocated:.2f} GB / {mem_total:.2f} GB, sufficient free memory") + return mem_allocated, mem_total + +def preload_model(model, device, is_vae=False): + """Preload model or VAE to device, avoiding unnecessary unloading.""" + with profile_section("Model preload"): + if is_vae: + if PROFILING_ENABLED: + start_time = time.time() + logging.debug(f"Checking VAE device for {model.__class__.__name__}") + + # Check if VAE is already loaded + if (hasattr(model, 'first_stage_model') and + hasattr(model.first_stage_model, 'device') and + model.first_stage_model.device == device and + hasattr(model, '_loaded_to_device') and + model._loaded_to_device == device): + if PROFILING_ENABLED: + logging.debug(f"VAE already loaded on {device}, skipping transfer, check took {time.time() - start_time:.3f} s") + return + + # Load VAE + if PROFILING_ENABLED: + logging.debug(f"Loading VAE to {device}") + transfer_start = time.time() + model.first_stage_model.to(device) + model._loaded_to_device = device + if PROFILING_ENABLED: + logging.debug(f"VAE transferred to {device}, took {time.time() - transfer_start:.3f} s") + logging.debug(f"VAE first_stage_model device: {model.first_stage_model.device}") + logging.debug(f"VAE has decode_tiled: 
{hasattr(model, 'decode_tiled')}") + else: + # Check if model is already loaded + if hasattr(model, '_loaded_to_device') and model._loaded_to_device == device: + if PROFILING_ENABLED: + logging.debug(f"Model already loaded on {device}, skipping preload") + return + # Load U-Net + if PROFILING_ENABLED: + logging.debug(f"Loading U-Net {model.__class__.__name__} to {device}") + torch.cuda.empty_cache() + comfy.model_management.load_model_gpu(model) + model._loaded_to_device = device + if PROFILING_ENABLED: + free_mem = (torch.cuda.get_device_properties(device).total_memory - torch.cuda.memory_allocated(device)) / 1024**3 + logging.debug(f"U-Net loaded to {device}, VRAM free: {free_mem:.2f} GB") + +def optimized_transfer(tensor, device, dtype): + """Synchronous tensor transfer to device.""" + pin_memory = comfy.model_management.is_device_cuda(device) + if isinstance(tensor, torch.Tensor) and tensor.device != device: + tensor = tensor.to(device=device, dtype=dtype, pin_memory=pin_memory) + return tensor + +def optimized_conditioning(conditioning, device, dtype): + """Efficiently transfer conditioning tensors.""" + return [ + optimized_transfer(p, device, dtype) if isinstance(p, torch.Tensor) else p + for p in conditioning + ] + +def finalize_images(images, device): + """Process and finalize output images.""" + if len(images.shape) == 5: # Combine batches + images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) + return images.to(device=device, memory_format=torch.channels_last) + +def fast_sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, + denoise, disable_noise, start_step, last_step, force_full_denoise, noise_mask, callback, seed, device, dtype, is_gpu): + """Optimized sampling function.""" + if PROFILING_ENABLED: + start_time = time.time() + logging.debug(f"Starting sampling") + + with torch.no_grad(): + use_amp = is_gpu and dtype == torch.float16 and is_fp16_safe(device) + with autocast(device_type='cuda', enabled=use_amp): + samples = comfy.sample.sample( + model, noise, steps, cfg, sampler_name, scheduler, + positive, negative, latent_image, + denoise=denoise, disable_noise=disable_noise, + start_step=start_step, last_step=last_step, + force_full_denoise=force_full_denoise, + noise_mask=noise_mask, callback=callback, seed=seed + ) + samples = samples.to(device=device, dtype=dtype, memory_format=torch.channels_last) + + if PROFILING_ENABLED: + logging.debug(f"Sampling completed, took {time.time() - start_time:.3f} s") + + return samples + +def fast_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, + denoise=1.0, disable_noise=False, start_step=None, last_step=None, + force_full_denoise=False, device=None, dtype=None, is_gpu=None): + """ + Fast KSampler implementation with optimized memory management and optional cuDNN benchmark. 
+ """ + if DEBUG_ENABLED: + if model is None: + logging.warning("fast_ksampler: model is None") + + if device is None or dtype is None or is_gpu is None: + device, dtype, is_gpu = initialize_device_and_dtype(model.model) + + try: + # Enable cuDNN benchmarking if requested + if is_gpu and comfy.model_management.is_device_cuda(device) and CUDNN_BENCHMARK_ENABLED: + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + + # Check and move model parameters once + if is_gpu: + if not hasattr(model, '_device_checked') or not model._device_checked: + for param in model.model.parameters(): + if param.device.type != device.type: + if DEBUG_ENABLED: + logging.warning(f"U-Net parameter {param.shape} on {param.device.type}, moving to {device}") + model.model.to(device) + if PROFILING_ENABLED: + logging.debug(f"VRAM after moving U-Net: {torch.cuda.memory_allocated(device)/1024**3:.2f} GB") + model._device = device + model._device_checked = True + break + if hasattr(model, 'control_model'): + for param in model.control_model.parameters(): + if param.device.type != device.type: + if DEBUG_ENABLED: + logging.warning(f"ControlNet parameter {param.shape} on {param.device.type}, moving to {device}") + model.control_model.to(device) + if PROFILING_ENABLED: + logging.debug(f"VRAM after moving ControlNet: {torch.cuda.memory_allocated(device)/1024**3:.2f} GB") + model._control_device = device + model._device_checked = True + break + + # Preload model + preload_model(model, device) + + # Transfer latents + with profile_section("Latent transfer"): + latent_image = latent["samples"] + latent_image = optimized_transfer(latent_image, device, dtype) + latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image) + + # Transfer conditioning + with profile_section("Conditioning transfer"): + positive = optimized_conditioning(positive, device, dtype) + negative = optimized_conditioning(negative, device, dtype) + + # Prepare noise + if disable_noise: + noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + else: + batch_inds = latent["batch_index"] if "batch_index" in latent else None + noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds) + + # Handle noise mask if present + noise_mask = latent.get("noise_mask") + if noise_mask is not None: + noise_mask = optimized_transfer(noise_mask, device, dtype) + + # Allocate output tensor + samples = torch.empty_like(latent_image, device=device, dtype=dtype) + + # Perform sampling + with torch.no_grad(): + callback = None if not comfy.utils.PROGRESS_BAR_ENABLED else latent_preview.prepare_callback(model, steps) + samples = fast_sample( + model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, + denoise, disable_noise, start_step, last_step, force_full_denoise, noise_mask, callback, seed, + device, dtype, is_gpu + ) + + # Log VRAM state after sampling + if is_gpu and PROFILING_ENABLED: + mem_total = torch.cuda.get_device_properties(device).total_memory / 1024**3 + mem_allocated = torch.cuda.memory_allocated(device) / 1024**3 + logging.debug(f"VRAM after sampling: {mem_allocated:.2f} GB / {mem_total:.2f} GB") + + # Log completion of sampling + if PROFILING_ENABLED: + logging.debug(f"Sampling completed, preparing for VAE") + profile_cuda_sync(is_gpu) + + # Clear VRAM after sampling + if is_gpu: + if not PROFILING_ENABLED: + clear_vram(device, threshold=0.5, min_free=1.5) + else: + clear_start = time.time() + mem_allocated, mem_total = 
clear_vram(device, threshold=0.5, min_free=1.5) + logging.debug(f"VRAM after sampling: {mem_allocated:.2f} GB / {mem_total:.2f} GB, clear took {time.time() - clear_start:.3f} s") + logging.debug(f"Post-VRAM checkpoint: {time.time()}") + + out = latent.copy() + out["samples"] = samples + return (out,) + + finally: + if PROFILING_ENABLED: + finally_start = time.time() + if is_gpu and CUDNN_BENCHMARK_ENABLED: + torch.backends.cudnn.benchmark = False + if PROFILING_ENABLED: + logging.debug(f"Final cleanup took {time.time() - finally_start:.3f} s") + +def fast_vae_decode(vae, samples): + """ + Fast VAE decoding with FP16, channels_last, universal VRAM management, and full logging. + """ + device = get_torch_device() + vae_dtype_val = vae_dtype(device=device) + is_gpu = device.type == 'cuda' and torch.cuda.is_available() + + if DEBUG_ENABLED: + logging.debug(f"VAE dtype: {vae_dtype_val}") + logging.debug(f"Pre-VAE checkpoint: {time.time()}") + + try: + # Disable cuDNN benchmark for VAE stability if enabled + if is_gpu and comfy.model_management.is_device_cuda(device) and CUDNN_BENCHMARK_ENABLED: + torch.backends.cudnn.benchmark = False + + # Prepare VRAM for VAE + if is_gpu: + mem_total = torch.cuda.get_device_properties(device).total_memory / 1024**3 + latent_size = samples["samples"].shape + model_for_memory = getattr(vae, 'first_stage_model', vae) + vae_memory_required = estimate_vae_decode_memory(model_for_memory, latent_size, vae_dtype_val) / 1024**3 + vram_threshold = 1.0 if mem_total < 5.9 else 1.1 + vae_memory_required *= vram_threshold + if PROFILING_ENABLED: + logging.debug(f"Estimated VAE memory: {vae_memory_required:.2f} GB") + mem_allocated = torch.cuda.memory_allocated(device) / 1024**3 + free_mem = mem_total - mem_allocated + if free_mem < vae_memory_required: + #free_memory(vae_memory_required) + mem_allocated, mem_total = clear_vram(device, threshold=0.4, min_free=2.0) + if PROFILING_ENABLED: + logging.debug(f"VRAM after free_memory: {mem_allocated:.2f} GB / {mem_total:.2f} GB") + + # Preload VAE to device + preload_model(vae, device, is_vae=True) + + # Transfer latents with channels_last + with profile_section("VAE latent transfer"): + non_blocking = is_gpu and device_supports_non_blocking(device) + latent_samples = samples["samples"].to(device, dtype=vae_dtype_val, non_blocking=non_blocking) + if is_gpu and force_channels_last(): + latent_samples = latent_samples.to(memory_format=torch.channels_last) + vae.first_stage_model.to(memory_format=torch.channels_last) + if PROFILING_ENABLED: + logging.debug(f"Latent samples device: {latent_samples.device}, dtype: {latent_samples.dtype}") + + # Decode latents + with torch.no_grad(): + use_amp = is_gpu and is_fp16_safe(device) + with autocast(device_type='cuda', enabled=use_amp, dtype=torch.float16 if use_amp else torch.float32): + if PROFILING_ENABLED: + logging.debug(f"Decoding VAE, use_amp={use_amp}") + decode_start = time.time() + images = vae.decode(latent_samples).clamp(0, 1) + if PROFILING_ENABLED: + logging.debug(f"VAE decode took {time.time() - decode_start:.3f} s") + images = finalize_images(images, device) + + return (images,) + + except Exception as e: + if PROFILING_ENABLED: + logging.error(f"VAE decode failed: {e}\n{traceback.format_exc()}") + raise + finally: + if PROFILING_ENABLED: + finally_start = time.time() + if PROFILING_ENABLED: + logging.debug(f"Final cleanup took {time.time() - finally_start:.3f} s") + +def fast_vae_tiled_decode(vae, samples, tile_size=512, overlap=64, temporal_size=64, temporal_overlap=8): + 
"""Fast VAE decoding with tiling for low VRAM, consistent with fast_vae_decode.""" + device, dtype, is_gpu = initialize_device_and_dtype(vae) + vae_dtype = vae_dtype(device=device) + if DEBUG_ENABLED: + logging.debug(f"VAE dtype: {vae_dtype}") + logging.debug(f"Pre-VAE checkpoint: {time.time()}") + + try: + # Disable cuDNN benchmark for tiled decoding stability if enabled + if is_gpu and comfy.model_management.is_device_cuda(device) and CUDNN_BENCHMARK_ENABLED: + torch.backends.cudnn.benchmark = False # Ensure stability for variable tile sizes + + # Clear VRAM before VAE + if is_gpu: + mem_total = torch.cuda.get_device_properties(device).total_memory / 1024**3 + mem_allocated = torch.cuda.memory_allocated(device) / 1024**3 + free_mem = mem_total - mem_allocated + # Estimate memory for tiled decoding (conservative, ~50% of full decode) + vae_memory_required = (vae.memory_used_decode(samples["samples"].shape, vae_dtype) / 1024**3 * 0.5 + if hasattr(vae, 'memory_used_decode') else 0.75) + if PROFILING_ENABLED: + logging.debug(f"VRAM before tiled VAE: {mem_allocated:.2f} GB / {mem_total:.2f} GB") + logging.debug(f"Estimated tiled VAE memory: {vae_memory_required:.2f} GB") + + # Skip VRAM cleanup if VAE is already loaded and memory is sufficient + if (hasattr(vae, '_loaded_to_device') and vae._loaded_to_device == device and + free_mem >= vae_memory_required * 1.1): + if PROFILING_ENABLED: + logging.debug(f"VAE already loaded, sufficient memory: {free_mem:.2f} GB") + elif mem_allocated > 0.4 * mem_total or free_mem < vae_memory_required: + if PROFILING_ENABLED: + logging.debug(f"Clearing VRAM: {mem_allocated:.2f} GB used of {mem_total:.2f} GB") + mem_allocated, mem_total = clear_vram(device, threshold=0.4, min_free=0.75) + + # Preload VAE + if not PROFILING_ENABLED: + preload_model(vae, device, is_vae=True) + else: + preload_start = time.time() + preload_model(vae, device, is_vae=True) + logging.debug(f"VAE preload took {time.time() - preload_start:.3f} s") + logging.debug(f"Post-preload checkpoint: {time.time()}") + + # Transfer latents + with profile_section("VAE latent transfer"): + latent_samples = samples["samples"] + if PROFILING_ENABLED: + logging.debug(f"Latent samples device: {latent_samples.device}, dtype: {latent_samples.dtype}") + latent_samples = optimized_transfer(latent_samples, device, vae_dtype) + if is_gpu and force_channels_last(): + latent_samples = latent_samples.to(memory_format=torch.channels_last) + vae.first_stage_model.to(memory_format=torch.channels_last) + + # Log before decoding + if PROFILING_ENABLED: + logging.debug(f"Starting tiled VAE decoding") + logging.debug(f"Pre-decode checkpoint: {time.time()}") + + with torch.no_grad(): + use_amp = is_gpu and is_fp16_safe(device) + with autocast(device_type='cuda', enabled=use_amp, dtype=torch.float16 if use_amp else torch.float32): + if PROFILING_ENABLED: + logging.debug(f"Tiled VAE decoding with tile_size={tile_size}, overlap={overlap}, " + f"temporal_size={temporal_size}, temporal_overlap={temporal_overlap}, use_amp={use_amp}, dtype={'torch.float16' if use_amp else 'torch.float32'}") + + # Adjust tile parameters + if tile_size < overlap * 4: + overlap = tile_size // 4 + if temporal_size < temporal_overlap * 2: + temporal_overlap = temporal_overlap // 2 + + temporal_compression = getattr(vae, 'temporal_compression_decode', lambda: None)() + spacial_compression = getattr(vae, 'spacial_compression_decode', lambda: 8)() + + if temporal_compression is not None: + temporal_size = max(2, temporal_size // temporal_compression) 
+ temporal_overlap = max(1, min(temporal_size // 2, temporal_overlap // temporal_compression)) + else: + temporal_size = None + temporal_overlap = None + + # Perform tiled decoding + decode_start = time.time() + images = vae.decode_tiled( + latent_samples, + tile_x=tile_size // spacial_compression, + tile_y=tile_size // spacial_compression, + overlap=overlap // spacial_compression, + tile_t=temporal_size, + overlap_t=temporal_overlap + ) + if PROFILING_ENABLED: + logging.debug(f"VAE tiled decode took {time.time() - decode_start:.3f} s") + + images = finalize_images(images, device) + + if is_gpu and PROFILING_ENABLED: + mem_allocated = torch.cuda.memory_allocated(device) / 1024**3 + mem_total = torch.cuda.get_device_properties(device).total_memory / 1024**3 + logging.debug(f"VRAM after tiled decoding: {mem_allocated:.2f} GB / {mem_total:.2f} GB") + logging.debug(f"Post-decode checkpoint: {time.time()}") + + if PROFILING_ENABLED: + logging.debug(f"VAE tiled decode finished, returning images: {time.time()}") + return (images,) + + except Exception as e: + logging.error(f"VAE tiled decode failed: {e}\n{traceback.format_exc()}") + raise + finally: + if PROFILING_ENABLED: + finally_start = time.time() + if PROFILING_ENABLED: + logging.debug(f"Final cleanup took {time.time() - finally_start:.3f} s") + logging.debug(f"Post-final cleanup checkpoint: {time.time()}") \ No newline at end of file diff --git a/folder_paths.py b/folder_paths.py index f0b3fd10373..2a2844e39ee 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -9,6 +9,8 @@ from comfy.cli_args import args +DEBUG_ENABLED = args.debug + supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft'} folder_names_and_paths: dict[str, tuple[list[str], set[str]]] = {} @@ -245,7 +247,8 @@ def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) except FileNotFoundError: logging.warning(f"Warning: Unable to access {directory}. Skipping this path.") - logging.debug("recursive file list on directory {}".format(directory)) + if DEBUG_ENABLED: + logging.debug("recursive file list on directory {}".format(directory)) dirpath: str subdirs: list[str] filenames: list[str] @@ -267,7 +270,8 @@ def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) except FileNotFoundError: logging.warning(f"Warning: Unable to access {path}. Skipping this path.") continue - logging.debug("found {} files".format(len(result))) + if DEBUG_ENABLED: + logging.debug("found {} files".format(len(result))) return result, dirs def filter_files_extensions(files: Collection[str], extensions: Collection[str]) -> list[str]: diff --git a/main.py b/main.py index 221e48e41e6..1cad01b90e6 100644 --- a/main.py +++ b/main.py @@ -8,9 +8,13 @@ from comfy.cli_args import args from app.logger import setup_logger import itertools +import comfy.model_management import utils.extra_config import logging import sys +import atexit + +atexit.register(comfy.model_management.soft_empty_cache, clear=True) if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. 
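Usage note (illustrative sketch, not part of the patch itself): the nodes.py hunks below route the stock `VAEDecode`, `VAEDecodeTiled` and `KSampler` nodes through the new fast_sampler helpers, and the same helpers can be called directly from custom nodes or scripts. A minimal sketch, assuming `vae` and `latent` come from a normal ComfyUI workflow (a loaded VAE object and a {"samples": tensor} dict) and that fast_sampler.py is importable from the repository root as in this patch:

    from fast_sampler import fast_vae_decode, fast_vae_tiled_decode

    def decode_latent(vae, latent, low_vram=False):
        # The tiled path trades some speed for a much smaller VRAM peak;
        # the parameters here simply mirror the VAEDecodeTiled node defaults.
        if low_vram:
            images, = fast_vae_tiled_decode(vae, latent, tile_size=512, overlap=64)
        else:
            images, = fast_vae_decode(vae, latent)
        return images

Both helpers return a one-element tuple (matching the node FUNCTION convention), hence the single-element unpacking above.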
diff --git a/nodes.py b/nodes.py index 54e3886a306..28a38035bba 100644 --- a/nodes.py +++ b/nodes.py @@ -27,6 +27,10 @@ import comfy.controlnet from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator +from fast_sampler import fast_vae_decode +from fast_sampler import fast_ksampler +from fast_sampler import fast_vae_tiled_decode + import comfy.clip_vision import comfy.model_management @@ -282,49 +286,34 @@ def INPUT_TYPES(s): RETURN_TYPES = ("IMAGE",) OUTPUT_TOOLTIPS = ("The decoded image.",) FUNCTION = "decode" - CATEGORY = "latent" DESCRIPTION = "Decodes latent images back into pixel space images." def decode(self, vae, samples): - images = vae.decode(samples["samples"]) - if len(images.shape) == 5: #Combine batches - images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) - return (images, ) + return fast_vae_decode(vae, samples) class VAEDecodeTiled: @classmethod def INPUT_TYPES(s): - return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ), - "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32}), - "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}), - "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time."}), - "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}), - }} + return { + "required": { + "samples": ("LATENT", {"tooltip": "The latent to be decoded."}), + "vae": ("VAE", {"tooltip": "The VAE model used for decoding the latent."}), + "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32, "tooltip": "Tile size for tiled decoding."}), + "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "tooltip": "Tile overlap for tiled decoding."}), + "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time."}), + "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}), + } + } RETURN_TYPES = ("IMAGE",) + OUTPUT_TOOLTIPS = ("The decoded image.",) FUNCTION = "decode" - CATEGORY = "_for_testing" + DESCRIPTION = "Decodes latent images back into pixel space images using tiled decoding for VRAM efficiency." 
- def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8): - if tile_size < overlap * 4: - overlap = tile_size // 4 - if temporal_size < temporal_overlap * 2: - temporal_overlap = temporal_overlap // 2 - temporal_compression = vae.temporal_compression_decode() - if temporal_compression is not None: - temporal_size = max(2, temporal_size // temporal_compression) - temporal_overlap = max(1, min(temporal_size // 2, temporal_overlap // temporal_compression)) - else: - temporal_size = None - temporal_overlap = None - - compression = vae.spacial_compression_decode() - images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap) - if len(images.shape) == 5: #Combine batches - images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) - return (images, ) - + def decode(self, vae, samples, tile_size=512, overlap=64, temporal_size=64, temporal_overlap=8): + return fast_vae_tiled_decode(vae, samples, tile_size=tile_size, overlap=overlap, + temporal_size=temporal_size, temporal_overlap=temporal_overlap) class VAEEncode: @classmethod def INPUT_TYPES(s): @@ -1473,28 +1462,26 @@ def set_mask(self, samples, mask): s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) return (s,) -def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): +def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, + denoise=1.0, disable_noise=False, start_step=None, last_step=None, + force_full_denoise=False): + # Get device and dtype + device = comfy.model_management.get_torch_device() + dtype = getattr(model.model, 'dtype', torch.float32) + is_gpu = device.type == 'cuda' and torch.cuda.is_available() + + # Prepare latent image latent_image = latent["samples"] latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image) - if disable_noise: - noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") - else: - batch_inds = latent["batch_index"] if "batch_index" in latent else None - noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds) - - noise_mask = None - if "noise_mask" in latent: - noise_mask = latent["noise_mask"] - - callback = latent_preview.prepare_callback(model, steps) - disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED - samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, - denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, - force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) - out = latent.copy() - out["samples"] = samples - return (out, ) + # Call fast_ksampler with device, dtype, and is_gpu + out = fast_ksampler( + model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, + denoise=denoise, disable_noise=disable_noise, start_step=start_step, + last_step=last_step, force_full_denoise=force_full_denoise, + device=device, dtype=dtype, is_gpu=is_gpu + ) + return out class KSampler: @classmethod @@ -1502,27 +1489,64 @@ def INPUT_TYPES(s): return { "required": { "model": ("MODEL", {"tooltip": "The model used for denoising the input latent."}), - "seed": ("INT", {"default": 
0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}), - "steps": ("INT", {"default": 20, "min": 1, "max": 10000, "tooltip": "The number of steps used in the denoising process."}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01, "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."}), - "sampler_name": (comfy.samplers.KSampler.SAMPLERS, {"tooltip": "The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."}), - "scheduler": (comfy.samplers.KSampler.SCHEDULERS, {"tooltip": "The scheduler controls how noise is gradually removed to form the image."}), - "positive": ("CONDITIONING", {"tooltip": "The conditioning describing the attributes you want to include in the image."}), - "negative": ("CONDITIONING", {"tooltip": "The conditioning describing the attributes you want to exclude from the image."}), + "seed": ("INT", { + "default": 0, + "min": 0, + "max": 0xffffffffffffffff, + "control_after_generate": True, + "tooltip": "The random seed used for creating the noise." + }), + "steps": ("INT", { + "default": 20, + "min": 1, + "max": 10000, + "tooltip": "The number of steps used in the denoising process." + }), + "cfg": ("FLOAT", { + "default": 8.0, + "min": 0.0, + "max": 100.0, + "step": 0.1, + "round": 0.01, + "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt." + }), + "sampler_name": (comfy.samplers.KSampler.SAMPLERS, { + "tooltip": "The algorithm used when sampling." + }), + "scheduler": (comfy.samplers.KSampler.SCHEDULERS, { + "tooltip": "The scheduler controls how noise is gradually removed to form the image." + }), + "positive": ("CONDITIONING", { + "tooltip": "The conditioning describing the attributes to include." + }), + "negative": ("CONDITIONING", { + "tooltip": "The conditioning describing the attributes to exclude." + }), "latent_image": ("LATENT", {"tooltip": "The latent image to denoise."}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."}), + "denoise": ("FLOAT", { + "default": 1.0, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "tooltip": "The amount of denoising applied." + }), } } RETURN_TYPES = ("LATENT",) OUTPUT_TOOLTIPS = ("The denoised latent.",) FUNCTION = "sample" - CATEGORY = "sampling" - DESCRIPTION = "Uses the provided model, positive and negative conditioning to denoise the latent image." + DESCRIPTION = "Denoises the latent image using the provided model and conditioning." 
- def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0): - return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) + def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, + denoise=1.0): + latent = latent_image.copy() + if "samples" in latent: + latent["samples"] = latent["samples"].to( + comfy.model_management.get_torch_device(), non_blocking=True) + return common_ksampler( + model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=denoise) class KSamplerAdvanced: @classmethod @@ -1618,14 +1642,44 @@ def __init__(self): self.output_dir = folder_paths.get_temp_directory() self.type = "temp" self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) - self.compress_level = 1 + self.compress_level = 4 # Faster for previews, SaveImage keeps 1 for fork @classmethod def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ), }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + return {"required": {"images": ("IMAGE", )}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}} + + def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + from PIL import Image + import numpy as np + import os + + filename_prefix += self.prefix_append + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( + filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + results = [] + + for batch_number, image in enumerate(images): + i = 255. * image.cpu().numpy() + img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) + # Adaptive resize to max dimension ~512, preserve aspect ratio + max_size = 512 + if max(img.width, img.height) > max_size: + scale = max_size / max(img.width, img.height) + new_width = int(img.width * scale) + new_height = int(img.height * scale) + img = img.resize((new_width, new_height), Image.LANCZOS) + filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) + file = f"{filename_with_batch_num}_{counter:05}_.png" + img.save(os.path.join(full_output_folder, file), format="PNG", compress_level=self.compress_level, optimize=True) + results.append({ + "filename": file, + "subfolder": subfolder, + "type": self.type + }) + counter += 1 + + return {"ui": {"images": results}} class LoadImage: @classmethod From aaed282c3adf0bc5e6b7a87105358f42f185e33c Mon Sep 17 00:00:00 2001 From: loxotron Date: Thu, 15 May 2025 12:13:57 +0300 Subject: [PATCH 2/6] vae tiled fixes and few other mistakes --- comfy/model_management.py | 2 +- comfy/sd.py | 40 +++++--- comfy/utils.py | 208 +++++++++++++++++++++++++++----------- fast_sampler.py | 19 ++-- 4 files changed, 188 insertions(+), 81 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 5f19e4b01c7..7bc3e5075b7 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -800,7 +800,7 @@ def model_memory_required(self, device): return self.model_offloaded_memory() # Handle AutoencoderKL - if self.model.model is not None and isinstance(self.model.model, AutoencoderKL): + if self.model is not None and isinstance(self.model.model, AutoencoderKL): shape = getattr(self.model, 'last_shape', (1, 4, 64, 64)) dtype = getattr(self.model, 'model_dtype', torch.float32)() return estimate_vae_decode_memory(self.model.model, 
shape, dtype) diff --git a/comfy/sd.py b/comfy/sd.py index e98a3aa87ce..046e74c439d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -28,6 +28,7 @@ from . import sd1_clip from . import sdxl_clip +from comfy.cli_args import args import comfy.text_encoders.sd2_clip import comfy.text_encoders.sd3_clip import comfy.text_encoders.sa_t5 @@ -54,6 +55,8 @@ import comfy.ldm.flux.redux +DEBUG_ENABLED = args.debug + def load_lora_for_models(model, clip, lora, strength_model, strength_clip): key_map = {} if model is not None: @@ -497,18 +500,23 @@ def vae_encode_crop_pixels(self, pixels): pixels = pixels.narrow(d + 1, x_offset, x) return pixels - def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): + def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16): + # Calculate progress bar steps for a single pass steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) - steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) - steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = comfy.utils.ProgressBar(steps) - decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() - output = self.process_output( - (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + - comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + - comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar)) - / 3.0) + # Define decode function with tile logging + if not DEBUG_ENABLED: + decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() + else: + decode_fn = lambda a: (logging.debug(f"Tile shape: {a.shape}, min: {a.min()}, max: {a.max()}"), + self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float())[1] + + # Single pass with provided tile sizes + output = comfy.utils.tiled_scale( + samples, decode_fn, tile_x, tile_y, overlap, + upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar + ) return output def decode_tiled_1d(self, samples, tile_x=128, overlap=32): @@ -1068,9 +1076,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c else: logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") - left_over = sd.keys() - if len(left_over) > 0: - logging.debug("left over keys: {}".format(left_over)) + if DEBUG_ENABLED: + left_over = sd.keys() + if len(left_over) > 0: + logging.debug("left over keys: {}".format(left_over)) if output_model: model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) @@ -1137,9 +1146,10 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse model = model_config.get_model(new_sd, "") model = model.to(offload_device) model.load_model_weights(new_sd, "") - left_over = sd.keys() - if len(left_over) > 0: - logging.info("left over keys in unet: {}".format(left_over)) + if DEBUG_ENABLED: + left_over = sd.keys() + if len(left_over) > 0: + logging.info("left over 
keys in unet: {}".format(left_over)) return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) diff --git a/comfy/utils.py b/comfy/utils.py index 22c3d8c13c4..8113da7a52d 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -33,6 +33,9 @@ MMAP_TORCH_FILES = args.mmap_torch_files ALWAYS_SAFE_LOAD = False + +DEBUG_ENABLED = args.debug + if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated class ModelCheckpoint: pass @@ -876,116 +879,205 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): @torch.inference_mode() def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None): + """ + Perform tiled scaling of input samples using the provided function with overlap blending. + + Args: + samples: Input tensor of shape [batch, channels, *spatial_dims]. + function: Function to process each tile (e.g., VAE decode). + tile: Tuple of tile sizes for each spatial dimension. + overlap: Overlap size or list of overlaps for each dimension. + upscale_amount: Scaling factor or list of factors for each dimension. + out_channels: Number of output channels. + output_device: Device for output tensor. + downscale: If True, downscale instead of upscale. + index_formulas: Scaling factors for tile positions (defaults to upscale_amount). + pbar: Optional progress bar object. + + Returns: + Scaled output tensor of shape [batch, out_channels, *scaled_spatial_dims]. + """ dims = len(tile) - if not (isinstance(upscale_amount, (tuple, list))): - upscale_amount = [upscale_amount] * dims + # Wrap function to ensure FP32 + def fp32_function(x): + x_fp32 = x.to(dtype=torch.float32) + result = function(x_fp32) + return result.to(dtype=torch.float32) - if not (isinstance(overlap, (tuple, list))): + # Convert parameters to lists for multidimensional support + if not isinstance(upscale_amount, (tuple, list)): + upscale_amount = [upscale_amount] * dims + if not isinstance(overlap, (tuple, list)): overlap = [overlap] * dims - if index_formulas is None: index_formulas = upscale_amount - - if not (isinstance(index_formulas, (tuple, list))): + if not isinstance(index_formulas, (tuple, list)): index_formulas = [index_formulas] * dims + # Define scaling functions def get_upscale(dim, val): up = upscale_amount[dim] - if callable(up): - return up(val) - else: - return up * val + return up(val) if callable(up) else up * val def get_downscale(dim, val): up = upscale_amount[dim] - if callable(up): - return up(val) - else: - return val / up + return up(val) if callable(up) else val / up def get_upscale_pos(dim, val): up = index_formulas[dim] - if callable(up): - return up(val) - else: - return up * val + return up(val) if callable(up) else up * val def get_downscale_pos(dim, val): up = index_formulas[dim] - if callable(up): - return up(val) - else: - return val / up + return up(val) if callable(up) else val / up - if downscale: - get_scale = get_downscale - get_pos = get_downscale_pos - else: - get_scale = get_upscale - get_pos = get_upscale_pos + get_scale = get_downscale if downscale else get_upscale + get_pos = get_downscale_pos if downscale else get_upscale_pos def mult_list_upscale(a): - out = [] - for i in range(len(a)): - out.append(round(get_scale(i, a[i]))) - return out - - output = torch.empty([samples.shape[0], out_channels] + 
mult_list_upscale(samples.shape[2:]), device=output_device) + """Compute scaled dimensions for output tensor.""" + return [round(get_scale(i, a[i])) for i in range(len(a))] + + # Initialize output tensor + output_shape = [samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]) + output = torch.empty(output_shape, device=output_device, dtype=torch.float32) + if DEBUG_ENABLED: + logging.debug(f"Input shape: {samples.shape}, output shape: {output_shape}, tile: {tile}") + logging.debug(f"Input stats: min={samples.min():.4f}, max={samples.max():.4f}, mean={samples.mean():.4f}") + + # Test VAE + try: + test_input = samples[:1].to(dtype=torch.float32) + test_result = fp32_function(test_input).to(output_device, dtype=torch.float32) + if DEBUG_ENABLED: + logging.debug(f"VAE test result: shape={test_result.shape}, min={test_result.min():.4f}, max={test_result.max():.4f}") + if torch.isnan(test_result).any() or torch.isinf(test_result).any(): + logging.error("VAE produces NaN or Inf in test output. Check VAE model or input latents.") + raise RuntimeError("VAE output contains NaN or Inf") + except Exception as e: + logging.error(f"VAE test failed: {e}") + raise for b in range(samples.shape[0]): s = samples[b:b+1] - # handle entire input fitting in a single tile + # Handle case where input fits in a single tile if all(s.shape[d+2] <= tile[d] for d in range(dims)): - output[b:b+1] = function(s).to(output_device) + s_fp32 = s.to(dtype=torch.float32) + result = fp32_function(s_fp32).to(output_device, dtype=torch.float32) + if DEBUG_ENABLED: + logging.debug(f"Single tile result: shape={result.shape}, min={result.min():.4f}, max={result.max():.4f}") + if result.shape == output_shape[1:]: + output[b:b+1] = result + else: + result = result.narrow(1, 0, output_shape[1]) + for d in range(dims): + result = result.narrow(d + 2, 0, output_shape[d + 2]) + output[b:b+1] = result if pbar is not None: pbar.update(1) continue - out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) - out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) - - positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] - + # Initialize accumulation tensors + out = torch.zeros(output_shape[1:], device=output_device, dtype=torch.float32) + out_div = torch.full_like(out, 1e-6) + + # Compute tile positions + positions = [] + tile_counts = [] + for d in range(dims): + step = max(1, tile[d] - overlap[d]) + end = max(0, s.shape[d+2] - tile[d]) + pos = list(range(0, end + 1, step)) + if pos and (pos[-1] < end or s.shape[d+2] > tile[d]): + pos.append(end) + positions.append(pos if pos else [0]) + tile_counts.append(len(pos)) + if DEBUG_ENABLED: + logging.debug(f"Tile positions: {positions}") + + # Process each tile + total_tiles = max(1, len(list(itertools.product(*positions)))) for it in itertools.product(*positions): s_in = s upscaled = [] + # Extract tile for d in range(dims): - pos = max(0, min(s.shape[d + 2] - overlap[d], it[d])) - l = min(tile[d], s.shape[d + 2] - pos) - s_in = s_in.narrow(d + 2, pos, l) + pos = max(0, min(s.shape[d+2] - tile[d], it[d])) + length = min(tile[d], s.shape[d+2] - pos) + s_in = s_in.narrow(d + 2, pos, length) upscaled.append(round(get_pos(d, pos))) - ps = function(s_in).to(output_device) - mask = torch.ones_like(ps) - - for d in range(2, dims + 2): - feather = round(get_scale(d - 2, overlap[d - 2])) - if feather >= mask.shape[d]: + # 
Process tile + s_in_fp32 = s_in.to(dtype=torch.float32) + ps = fp32_function(s_in_fp32).to(output_device, dtype=torch.float32) + if DEBUG_ENABLED: + logging.debug(f"Tile at {it}: input={s_in.shape}, output={ps.shape}, min={ps.min():.4f}, max={ps.max():.4f}") + if torch.isnan(ps).any() or torch.isinf(ps).any(): + if DEBUG_ENABLED: + logging.warning(f"Tile at {it} contains NaN or Inf, clamping values") + ps = torch.clamp(ps, min=-1e6, max=1e6) + mask = torch.ones_like(ps, dtype=torch.float32) / total_tiles + + # Apply feathering for smooth overlap blending + for d in range(dims): + feather = min(round(get_scale(d, overlap[d])), ps.shape[d + 2] // 16) + if feather < 1: continue for t in range(feather): a = (t + 1) / feather - mask.narrow(d, t, 1).mul_(a) - mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a) + mask.narrow(d + 2, t, 1).mul_(a) + mask.narrow(d + 2, mask.shape[d + 2] - 1 - t, 1).mul_(a) + mask = mask.clamp(min=1e-6) + # Accumulate results o = out o_d = out_div for d in range(dims): - o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2]) - o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2]) - + start = upscaled[d] + size = min(ps.shape[d + 2], output_shape[d + 2] - start) + o = o.narrow(d + 1, start, size) + o_d = o_d.narrow(d + 1, start, size) + ps = ps.narrow(d + 2, 0, size) + mask = mask.narrow(d + 2, 0, size) + + # Squeeze batch dimension + ps = ps.squeeze(0) + mask = mask.squeeze(0) o.add_(ps * mask) o_d.add_(mask) if pbar is not None: pbar.update(1) - output[b:b+1] = out/out_div - return output + # Fallback to non-tiled if NaN + if torch.isnan(out).any(): + if DEBUG_ENABLED: + logging.warning("NaN detected in tiled output, falling back to non-tiled") + s_fp32 = s.to(dtype=torch.float32) + result = fp32_function(s_fp32).to(output_device, dtype=torch.float32) + if result.shape == output_shape[1:]: + output[b:b+1] = result + else: + result = result.narrow(1, 0, output_shape[1]) + for d in range(dims): + result = result.narrow(d + 2, 0, output_shape[d + 2]) + output[b:b+1] = result + else: + if DEBUG_ENABLED: + logging.debug(f"out stats: min={out.min():.4f}, max={out.max():.4f}") + logging.debug(f"out_div stats: min={out_div.min():.4f}, max={out_div.max():.4f}") + output[b:b+1] = out / out_div + + if pbar is not None: + pbar.update(1) -def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): + return output + +def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None): + """Wrapper for 2D tiled scaling.""" return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) PROGRESS_BAR_ENABLED = True diff --git a/fast_sampler.py b/fast_sampler.py index f8f62e12ae5..47a8efc6d13 100644 --- a/fast_sampler.py +++ b/fast_sampler.py @@ -8,6 +8,7 @@ from contextlib import contextmanager import latent_preview import logging +import traceback # Global flag for profiling PROFILING_ENABLED = args.profile @@ -41,15 +42,18 @@ def profile_cuda_sync(is_gpu, message="CUDA sync"): logging.debug(f"{message} took {time.time() - sync_start:.3f} s") def is_fp16_safe(device): - """Check if FP16 is safe for the GPU (disabled for GTX 1660/Turing).""" + """Check if FP16 is safe for the GPU (disabled for Turing).""" if device.type != 'cuda': return False if device in _fp16_safe_cache: return _fp16_safe_cache[device] try: props = 
torch.cuda.get_device_properties(device) - is_safe = props.major >= 8 or props.compute_capability[0] > 7 + # Disable FP16 for Turing (major == 7) and earlier architectures + is_safe = props.major >= 8 # Allow FP16 only for Ampere (8.x) and later _fp16_safe_cache[device] = is_safe + if DEBUG_ENABLED: + logging.debug(f"FP16 safety check for {props.name}: major={props.major}, is_safe={is_safe}") return is_safe except Exception: _fp16_safe_cache[device] = False @@ -72,7 +76,8 @@ def clear_vram(device, threshold=0.5, min_free=1.5): mem_total = torch.cuda.get_device_properties(device).total_memory / 1024**3 critical_threshold = 0.05 * mem_total + 0.1 # 5% VRAM + 100 MB if mem_allocated > threshold * mem_total or (mem_total - mem_allocated) < max(min_free, critical_threshold): - logging.debug(f"Clearing VRAM: allocated {mem_allocated:.2f} GB, free {mem_total - mem_allocated:.2f} GB, threshold {critical_threshold:.2f} GB") + if PROFILING_ENABLED: + logging.debug(f"Clearing VRAM: allocated {mem_allocated:.2f} GB, free {mem_total - mem_allocated:.2f} GB, threshold {critical_threshold:.2f} GB") torch.cuda.empty_cache() #soft_empty_cache(clear=False) mem_after = torch.cuda.memory_allocated(device) / 1024**3 @@ -362,9 +367,9 @@ def fast_vae_decode(vae, samples): def fast_vae_tiled_decode(vae, samples, tile_size=512, overlap=64, temporal_size=64, temporal_overlap=8): """Fast VAE decoding with tiling for low VRAM, consistent with fast_vae_decode.""" device, dtype, is_gpu = initialize_device_and_dtype(vae) - vae_dtype = vae_dtype(device=device) + vae_dtype_val = vae_dtype(device=device) if DEBUG_ENABLED: - logging.debug(f"VAE dtype: {vae_dtype}") + logging.debug(f"VAE dtype: {vae_dtype_val}") logging.debug(f"Pre-VAE checkpoint: {time.time()}") try: @@ -378,7 +383,7 @@ def fast_vae_tiled_decode(vae, samples, tile_size=512, overlap=64, temporal_size mem_allocated = torch.cuda.memory_allocated(device) / 1024**3 free_mem = mem_total - mem_allocated # Estimate memory for tiled decoding (conservative, ~50% of full decode) - vae_memory_required = (vae.memory_used_decode(samples["samples"].shape, vae_dtype) / 1024**3 * 0.5 + vae_memory_required = (vae.memory_used_decode(samples["samples"].shape, vae_dtype_val) / 1024**3 * 0.5 if hasattr(vae, 'memory_used_decode') else 0.75) if PROFILING_ENABLED: logging.debug(f"VRAM before tiled VAE: {mem_allocated:.2f} GB / {mem_total:.2f} GB") @@ -408,7 +413,7 @@ def fast_vae_tiled_decode(vae, samples, tile_size=512, overlap=64, temporal_size latent_samples = samples["samples"] if PROFILING_ENABLED: logging.debug(f"Latent samples device: {latent_samples.device}, dtype: {latent_samples.dtype}") - latent_samples = optimized_transfer(latent_samples, device, vae_dtype) + latent_samples = optimized_transfer(latent_samples, device, vae_dtype_val) if is_gpu and force_channels_last(): latent_samples = latent_samples.to(memory_format=torch.channels_last) vae.first_stage_model.to(memory_format=torch.channels_last) From 07b066c5108f033c983cc656741d298513c968e8 Mon Sep 17 00:00:00 2001 From: loxotron Date: Fri, 16 May 2025 19:45:50 +0300 Subject: [PATCH 3/6] fixes for directml, use_pytorch_cross_attention and channels_last args workaround for torch.count_nonzero on DirectML --- comfy/model_management.py | 135 ++++++++++++++++++++++++++------------ comfy/samplers.py | 19 +++++- fast_sampler.py | 37 ++++++++--- 3 files changed, 139 insertions(+), 52 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 7bc3e5075b7..dac5985dc2d 100644 --- 
a/comfy/model_management.py +++ b/comfy/model_management.py @@ -114,13 +114,34 @@ def is_directml_enabled(): return False def get_supported_float8_types(): - """Get supported float8 data types.""" + """Get supported float8 data types available in the current PyTorch version.""" float8_types = [] - for dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz, torch.float8_e8m0fnu]: + # List of potential float8 type names to check + float8_type_names = [ + 'float8_e4m3fn', + 'float8_e4m3fnuz', + 'float8_e5m2', + 'float8_e5m2fnuz', + 'float8_e8m0fnu', + ] + + for dtype_name in float8_type_names: try: - float8_types.append(dtype) - except: + # Check if the dtype exists in torch module + if hasattr(torch, dtype_name): + dtype = getattr(torch, dtype_name) + # Verify that the dtype is a valid torch.dtype + if isinstance(dtype, torch.dtype): + float8_types.append(dtype) + except Exception as e: + # Log the error only in debug mode to avoid clutter + if DEBUG_ENABLED: + logging.debug(f"Failed to access torch.{dtype_name}: {str(e)}") pass + + if DEBUG_ENABLED: + logging.debug(f"Supported float8 types: {[str(dtype) for dtype in float8_types]}") + return float8_types def get_directml_vram(dev): @@ -191,7 +212,11 @@ def get_directml_vram(dev): FLOAT8_TYPES = get_supported_float8_types() XFORMERS_IS_AVAILABLE = False XFORMERS_ENABLED_VAE = True -ENABLE_PYTORCH_ATTENTION = True # Enable PyTorch attention for better performance +ENABLE_PYTORCH_ATTENTION = False +if args.use_pytorch_cross_attention: + ENABLE_PYTORCH_ATTENTION = True + XFORMERS_IS_AVAILABLE = False + FORCE_FP32 = args.force_fp32 DISABLE_SMART_MEMORY = args.disable_smart_memory @@ -437,7 +462,7 @@ def flash_attention_enabled(): def pytorch_attention_enabled(): """Check if PyTorch attention is enabled.""" global ENABLE_PYTORCH_ATTENTION - return ENABLE_PYTORCH_ATTENTION or not (xformers_enabled() or sage_attention_enabled() or flash_attention_enabled()) + return ENABLE_PYTORCH_ATTENTION def pytorch_attention_enabled_vae(): """Check if PyTorch attention is enabled for VAE.""" @@ -502,29 +527,37 @@ class OOM_EXCEPTION(Exception): """Exception raised for out-of-memory errors.""" pass -if args.use_pytorch_cross_attention: - ENABLE_PYTORCH_ATTENTION = True - XFORMERS_IS_AVAILABLE = False MIN_WEIGHT_MEMORY_RATIO = 0.4 if is_nvidia() else 0.0 -if is_nvidia() and torch_version_numeric[0] >= 2: - if not (ENABLE_PYTORCH_ATTENTION or args.use_split_cross_attention or args.use_quad_cross_attention): - ENABLE_PYTORCH_ATTENTION = True -elif is_intel_xpu() or is_ascend_npu() or is_mlu(): - if not (args.use_split_cross_attention or args.use_quad_cross_attention): - ENABLE_PYTORCH_ATTENTION = True -elif is_amd() and torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: - arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName - logging.info(f"AMD arch: {arch}") - if any(a in arch for a in ["gfx1100", "gfx1101"]) and not (args.use_split_cross_attention or args.use_quad_cross_attention): - ENABLE_PYTORCH_ATTENTION = True -if ENABLE_PYTORCH_ATTENTION: - torch.backends.cuda.enable_math_sdp(True) - torch.backends.cuda.enable_flash_sdp(True) - torch.backends.cuda.enable_mem_efficient_sdp(True) -if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5: - torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) -else: - logging.warning("Could not set allow_fp16_bf16_reduction_math_sdp") + +try: + if is_nvidia() and torch_version_numeric[0] >= 2: + if not (ENABLE_PYTORCH_ATTENTION or 
args.use_split_cross_attention or args.use_quad_cross_attention): + ENABLE_PYTORCH_ATTENTION = True + elif is_intel_xpu() or is_ascend_npu() or is_mlu(): + if not (args.use_split_cross_attention or args.use_quad_cross_attention): + ENABLE_PYTORCH_ATTENTION = True + elif is_amd() and torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much + arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName + logging.info(f"AMD arch: {arch}") + if any(a in arch for a in ["gfx1100", "gfx1101", "gfx1030", "gfx1031", "gfx1032"]) and not (args.use_split_cross_attention or args.use_quad_cross_attention): + ENABLE_PYTORCH_ATTENTION = True +except: + pass + +if ENABLE_PYTORCH_ATTENTION and not directml_enabled: + try: + torch.backends.cuda.enable_math_sdp(True) + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(True) + if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5: + torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) + elif DEBUG_ENABLED: + logging.debug("Could not set allow_fp16_bf16_reduction_math_sdp due to PyTorch version < 2.5") + except Exception as e: + if DEBUG_ENABLED: + logging.debug(f"Failed to enable CUDA SDP optimizations: {str(e)}") +elif directml_enabled and DEBUG_ENABLED: + logging.debug("Skipped CUDA-specific SDP optimizations (math_sdp, flash_sdp, mem_efficient_sdp, allow_fp16_bf16_reduction_math_sdp) for DirectML") def get_free_memory(dev=None, torch_free_too=False): """ @@ -1350,6 +1383,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma return False if args.force_fp16: return supports_cast(torch.float16, device) + if directml_enabled: + if DEBUG_ENABLED: + logging.debug("should_use_fp16: DirectML detected, disabling FP16 due to potential instability") + return False if is_intel_xpu(): return True if is_mlu(): @@ -1358,21 +1395,37 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma if is_ascend_npu(): return False if is_amd(): - arch = torch.cuda.get_device_properties(device).gcnArchName - if any(a in arch for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): - return manual_cast - return True - props = torch.cuda.get_device_properties(device) + try: + arch = torch.cuda.get_device_properties(device).gcnArchName + if any(a in arch for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): + return manual_cast + return True + except AssertionError: + # Fallback for non-CUDA AMD GPUs (e.g., via DirectML) + if DEBUG_ENABLED: + logging.debug("should_use_fp16: Fallback to False for AMD GPU without CUDA") + return False if is_nvidia(): - # Prefer FP32 for low VRAM or older GPUs - total_vram = get_total_memory(device) / (1024**3) - if total_vram < 5.9 or props.major <= 7: # Turing (7.5) or Pascal (6.x) + try: + props = torch.cuda.get_device_properties(device) + # Prefer FP32 for low VRAM or older GPUs + total_vram = get_total_memory(device) / (1024**3) + if total_vram < 5.9 or props.major <= 7: # Turing (7.5) or Pascal (6.x) + return False + if any(platform.win32_ver()) and props.major <= 7: + return manual_cast and torch.cuda.is_bf16_supported() + if props.major >= 8: + return True + return torch.cuda.is_bf16_supported() and manual_cast and (not prioritize_performance or model_params * 4 > get_total_memory(device)) + except AssertionError: + # Fallback for non-CUDA NVIDIA GPUs + if DEBUG_ENABLED: + 
logging.debug("should_use_fp16: Fallback to False for NVIDIA GPU without CUDA") return False - if any(platform.win32_ver()) and props.major <= 7: - return manual_cast and torch.cuda.is_bf16_supported() - if props.major >= 8: - return True - return torch.cuda.is_bf16_supported() and manual_cast and (not prioritize_performance or model_params * 4 > get_total_memory(device)) + # Fallback for other devices + if DEBUG_ENABLED: + logging.debug("should_use_fp16: Fallback to False for unknown device") + return False def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): """Determine if BF16 should be used for the device.""" diff --git a/comfy/samplers.py b/comfy/samplers.py index 67ae09a2551..a672d870afd 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -10,6 +10,7 @@ from functools import partial import collections from comfy import model_management +from comfy.cli_args import args import math import logging import comfy.sampler_helpers @@ -19,6 +20,7 @@ import scipy.stats import numpy +DEBUG_ENABLED = args.debug def add_area_dims(area, num_dims): while (len(area) // 2) < num_dims: @@ -942,15 +944,28 @@ def predict_noise(self, x, timestep, model_options={}, seed=None): return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed) def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed): - if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. - latent_image = self.inner_model.process_latent_in(latent_image) + # Workaround for torch.count_nonzero on DirectML + if latent_image is not None: + if model_management.is_directml_enabled(): + nonzero_count = torch.sum(latent_image != 0).item() + if DEBUG_ENABLED: + logging.debug(f"inner_sample: DirectML count_nonzero replacement: nonzero_count={nonzero_count}") + else: + nonzero_count = torch.count_nonzero(latent_image).item() + if nonzero_count > 0: # Don't shift the empty latent image + latent_image = self.inner_model.process_latent_in(latent_image) + else: + nonzero_count = 0 + # Process conditions self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) + # Clone model options and add sample sigmas extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options) extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas extra_args = {"model_options": extra_model_options, "seed": seed} + # Execute sampler with wrappers executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( sampler.sample, sampler, diff --git a/fast_sampler.py b/fast_sampler.py index 47a8efc6d13..a821d78955e 100644 --- a/fast_sampler.py +++ b/fast_sampler.py @@ -4,7 +4,7 @@ import time from torch.amp import autocast from comfy.cli_args import args -from comfy.model_management import get_torch_device, vae_dtype, soft_empty_cache, free_memory, force_channels_last, estimate_vae_decode_memory, device_supports_non_blocking +from comfy.model_management import get_torch_device, vae_dtype, soft_empty_cache, free_memory, force_channels_last, estimate_vae_decode_memory, device_supports_non_blocking, directml_enabled from contextlib import contextmanager import latent_preview import logging @@ -150,7 +150,12 @@ def finalize_images(images, device): """Process and finalize output images.""" if len(images.shape) == 5: # 
Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) - return images.to(device=device, memory_format=torch.channels_last) + # Apply channels_last only for CUDA devices if force_channels_last is enabled + is_gpu = device.type == 'cuda' and torch.cuda.is_available() + memory_format = torch.channels_last if (is_gpu and not directml_enabled and force_channels_last()) else torch.contiguous_format + if DEBUG_ENABLED: + logging.debug(f"finalize_images: Using memory_format={memory_format} for device={device}, directml_enabled={directml_enabled}") + return images.to(device=device, memory_format=memory_format) def fast_sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise, disable_noise, start_step, last_step, force_full_denoise, noise_mask, callback, seed, device, dtype, is_gpu): @@ -170,7 +175,11 @@ def fast_sample(model, noise, steps, cfg, sampler_name, scheduler, positive, neg force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, seed=seed ) - samples = samples.to(device=device, dtype=dtype, memory_format=torch.channels_last) + # Apply channels_last only for CUDA devices if force_channels_last is enabled + memory_format = torch.channels_last if (is_gpu and not directml_enabled and force_channels_last()) else torch.contiguous_format + if DEBUG_ENABLED: + logging.debug(f"fast_sample: Using memory_format={memory_format} for device={device}, directml_enabled={directml_enabled}") + samples = samples.to(device=device, dtype=dtype, memory_format=memory_format) if PROFILING_ENABLED: logging.debug(f"Sampling completed, took {time.time() - start_time:.3f} s") @@ -330,15 +339,19 @@ def fast_vae_decode(vae, samples): # Preload VAE to device preload_model(vae, device, is_vae=True) - # Transfer latents with channels_last + # Transfer latents with appropriate memory format with profile_section("VAE latent transfer"): non_blocking = is_gpu and device_supports_non_blocking(device) latent_samples = samples["samples"].to(device, dtype=vae_dtype_val, non_blocking=non_blocking) - if is_gpu and force_channels_last(): + # Apply channels_last only for CUDA devices if force_channels_last is enabled + memory_format = torch.channels_last if (is_gpu and not directml_enabled and force_channels_last()) else torch.contiguous_format + if is_gpu and memory_format == torch.channels_last: + if DEBUG_ENABLED: + logging.debug(f"fast_vae_decode: Using memory_format={memory_format} for device={device}, directml_enabled={directml_enabled}") latent_samples = latent_samples.to(memory_format=torch.channels_last) vae.first_stage_model.to(memory_format=torch.channels_last) - if PROFILING_ENABLED: - logging.debug(f"Latent samples device: {latent_samples.device}, dtype: {latent_samples.dtype}") + elif DEBUG_ENABLED: + logging.debug(f"fast_vae_decode: Using memory_format={memory_format} for device={device}, directml_enabled={directml_enabled}") # Decode latents with torch.no_grad(): @@ -408,15 +421,21 @@ def fast_vae_tiled_decode(vae, samples, tile_size=512, overlap=64, temporal_size logging.debug(f"VAE preload took {time.time() - preload_start:.3f} s") logging.debug(f"Post-preload checkpoint: {time.time()}") - # Transfer latents + # Transfer latents with appropriate memory format with profile_section("VAE latent transfer"): latent_samples = samples["samples"] if PROFILING_ENABLED: logging.debug(f"Latent samples device: {latent_samples.device}, dtype: {latent_samples.dtype}") latent_samples = optimized_transfer(latent_samples, 
device, vae_dtype_val) - if is_gpu and force_channels_last(): + # Apply channels_last only for CUDA devices if force_channels_last is enabled + memory_format = torch.channels_last if (is_gpu and not directml_enabled and force_channels_last()) else torch.contiguous_format + if is_gpu and memory_format == torch.channels_last: + if DEBUG_ENABLED: + logging.debug(f"fast_vae_tiled_decode: Using memory_format={memory_format} for device={device}, directml_enabled={directml_enabled}") latent_samples = latent_samples.to(memory_format=torch.channels_last) vae.first_stage_model.to(memory_format=torch.channels_last) + elif DEBUG_ENABLED: + logging.debug(f"fast_vae_tiled_decode: Using memory_format={memory_format} for device={device}, directml_enabled={directml_enabled}") # Log before decoding if PROFILING_ENABLED: From 2fd0a1296f9223b4b045759b631115787513a17b Mon Sep 17 00:00:00 2001 From: loxotron Date: Sat, 17 May 2025 22:22:55 +0300 Subject: [PATCH 4/6] Fixes for DirectML detection in fast_sampler.py, bugfixes in VRAM management for DirectML devices, improved logging for better debugging and profiling --- comfy/model_management.py | 71 ++++++++++++++++--- fast_sampler.py | 139 ++++++++++++++++++++++++++------------ nodes.py | 2 +- 3 files changed, 158 insertions(+), 54 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index dac5985dc2d..27d79545dfc 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -87,6 +87,9 @@ class VRAMState(Enum): # Cache for active models memory in DirectML _directml_active_memory_cache = {} +# Model management +current_loaded_models = [] + def cpu_mode(): """Check if system is in CPU mode.""" global cpu_state @@ -177,7 +180,7 @@ def get_directml_vram(dev): # Try torch_directml heuristic if _torch_directml_available: try: - device_index = dev.index if hasattr(dev, 'index') else 0 + device_index = dev.index if hasattr(dev, 'index') and dev.index is not None else 0 device_name = torch_directml.device_name(device_index).lower() vram_map = { 'gtx 1660': 6 * 1024 * 1024 * 1024, @@ -188,6 +191,7 @@ def get_directml_vram(dev): 'rx 580': 8 * 1024 * 1024 * 1024, 'rx 570': 8 * 1024 * 1024 * 1024, 'rx 6700': 12 * 1024 * 1024 * 1024, + 'rx 6800': 16 * 1024 * 1024 * 1024, 'arc a770': 16 * 1024 * 1024 * 1024, } vram = 6 * 1024 * 1024 * 1024 @@ -579,10 +583,55 @@ def get_free_memory(dev=None, torch_free_too=False): if directml_enabled: total_vram = get_directml_vram(dev) cache_key = (dev, 'active_models') - if cache_key not in _directml_active_memory_cache: - active_models = sum(m.model_loaded_memory() for m in current_loaded_models if m.device == dev) - _directml_active_memory_cache[cache_key] = active_models - active_models = _directml_active_memory_cache[cache_key] + # Invalidate cache if models list has changed + current_models_hash = hash(tuple((id(m), m.model_loaded_memory() if not m.is_dead() else 0) for m in current_loaded_models)) + if cache_key in _directml_active_memory_cache: + cached_hash, cached_active_models = _directml_active_memory_cache[cache_key] + if cached_hash == current_models_hash: + active_models = cached_active_models + if DEBUG_ENABLED: + logging.debug(f"Using cached active_models={active_models / (1024**3):.2f} GB for device {dev}") + else: + if DEBUG_ENABLED: + logging.debug(f"Cache invalidated for {dev}: models list changed") + active_models = None + else: + active_models = None + + if active_models is None: + active_models = 0 + try: + if DEBUG_ENABLED: + logging.debug(f"Processing 
{len(current_loaded_models)} models in get_free_memory for device {dev}") + for m in current_loaded_models: + model_name = m.model.__class__.__name__ if m.model else "Unknown" + if m.device != dev: + if DEBUG_ENABLED: + logging.debug(f"Skipping model {model_name}: device mismatch (model on {m.device}, expected {dev})") + continue + if m.is_dead(): + if DEBUG_ENABLED: + logging.debug(f"Skipping model {model_name}: model is dead") + continue + try: + mem = m.model_loaded_memory() + if DEBUG_ENABLED: + logging.debug(f"Loaded model {model_name} on device {m.device}, memory={mem / (1024**3):.2f} GB, is_dead={m.is_dead()}") + if mem <= 0: + logging.warning(f"Model {model_name} returned invalid memory: {mem}. Skipping.") + continue + active_models += mem + if DEBUG_ENABLED: + logging.debug(f"Model {model_name} on {dev}: loaded_memory={mem / (1024**3):.2f} GB") + except Exception as e: + logging.warning(f"Failed to calculate memory for model {model_name}: {str(e)}") + # Update cache + _directml_active_memory_cache[cache_key] = (current_models_hash, active_models) + except NameError: + logging.warning("current_loaded_models not defined yet in get_free_memory") + _directml_active_memory_cache[cache_key] = (current_models_hash, 0) + + # Apply safety margin (1.2x) and ensure at least 1 GB free mem_free_total = max(1024 * 1024 * 1024, total_vram - active_models * 1.2) mem_free_torch = mem_free_total if DEBUG_ENABLED: @@ -776,9 +825,6 @@ def register_vram_optimizer(optimizer): """Register a VRAM optimizer.""" _vram_optimizers.append(optimizer) -# Model management -current_loaded_models = [] - class LoadedModel: def __init__(self, model): self._set_model(model) @@ -972,6 +1018,15 @@ def module_size(model, shape=None, dtype=None): """ from diffusers import AutoencoderKL + # Early check for None model to avoid unnecessary processing + if model is None: + if DEBUG_ENABLED: + logging.warning( + f"module_size: Received None model. Assuming minimal memory (1 MB). 
" + f"Call stack: {''.join(traceback.format_stack(limit=5))}" + ) + return 1024 * 1024 # Minimal memory assumption for None model + module_mem = 0 if shape is not None and dtype is not None and isinstance(model, AutoencoderKL): try: diff --git a/fast_sampler.py b/fast_sampler.py index a821d78955e..ae752cbb904 100644 --- a/fast_sampler.py +++ b/fast_sampler.py @@ -33,9 +33,9 @@ def profile_section(name): else: yield -def profile_cuda_sync(is_gpu, message="CUDA sync"): +def profile_cuda_sync(is_gpu, device, message="CUDA sync"): """Profile CUDA synchronization time if GPU is used.""" - if PROFILING_ENABLED and is_gpu: + if PROFILING_ENABLED and is_gpu and device.type == 'cuda': logging.debug(f"{message} started") sync_start = time.time() torch.cuda.synchronize() @@ -64,7 +64,7 @@ def initialize_device_and_dtype(model, device=None): if device is None: device = get_torch_device() dtype = getattr(model, 'dtype', torch.float32) - is_gpu = device.type == 'cuda' and torch.cuda.is_available() + is_gpu = (device.type == 'cuda' and torch.cuda.is_available()) or (device.type == 'privateuseone') return device, dtype, is_gpu def clear_vram(device, threshold=0.5, min_free=1.5): @@ -87,6 +87,13 @@ def clear_vram(device, threshold=0.5, min_free=1.5): if PROFILING_ENABLED: logging.debug(f"VRAM not cleared: {mem_allocated:.2f} GB / {mem_total:.2f} GB, sufficient free memory") return mem_allocated, mem_total + elif device.type == 'privateuseone': + # For DirectML just clearing cache, cause mem_get_info not working + if PROFILING_ENABLED: + logging.debug("Clearing VRAM for DirectML (mem_get_info unavailable)") + torch.cuda.empty_cache() + return 0, 0 + return 0, 0 def preload_model(model, device, is_vae=False): """Preload model or VAE to device, avoiding unnecessary unloading.""" @@ -129,8 +136,13 @@ def preload_model(model, device, is_vae=False): comfy.model_management.load_model_gpu(model) model._loaded_to_device = device if PROFILING_ENABLED: - free_mem = (torch.cuda.get_device_properties(device).total_memory - torch.cuda.memory_allocated(device)) / 1024**3 - logging.debug(f"U-Net loaded to {device}, VRAM free: {free_mem:.2f} GB") + if device.type == 'cuda': + free_mem = (torch.cuda.get_device_properties(device).total_memory - torch.cuda.memory_allocated(device)) / 1024**3 + logging.debug(f"U-Net loaded to {device}, VRAM free: {free_mem:.2f} GB") + elif device.type == 'privateuseone': + logging.debug(f"U-Net loaded to {device}, VRAM info unavailable") + else: + logging.debug(f"U-Net loaded to {device}, no VRAM info for non-GPU device") def optimized_transfer(tensor, device, dtype): """Synchronous tensor transfer to device.""" @@ -150,11 +162,11 @@ def finalize_images(images, device): """Process and finalize output images.""" if len(images.shape) == 5: # Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) - # Apply channels_last only for CUDA devices if force_channels_last is enabled - is_gpu = device.type == 'cuda' and torch.cuda.is_available() + # Apply channels_last for CUDA or DirectML devices if force_channels_last is enabled + is_gpu = (device.type == 'cuda' and torch.cuda.is_available()) or (device.type == 'privateuseone') memory_format = torch.channels_last if (is_gpu and not directml_enabled and force_channels_last()) else torch.contiguous_format if DEBUG_ENABLED: - logging.debug(f"finalize_images: Using memory_format={memory_format} for device={device}, directml_enabled={directml_enabled}") + logging.debug(f"finalize_images: Using 
memory_format={memory_format} for device={device}, directml_enabled={directml_enabled}, is_gpu={is_gpu}") return images.to(device=device, memory_format=memory_format) def fast_sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, @@ -195,9 +207,11 @@ def fast_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, ne if DEBUG_ENABLED: if model is None: logging.warning("fast_ksampler: model is None") - + logging.debug(f"Starting fast_ksampler, device={device}, is_gpu={is_gpu}") if device is None or dtype is None or is_gpu is None: device, dtype, is_gpu = initialize_device_and_dtype(model.model) + if DEBUG_ENABLED: + logging.debug(f"Initialized device: {device}, dtype: {dtype}, is_gpu: {is_gpu}") try: # Enable cuDNN benchmarking if requested @@ -270,14 +284,17 @@ def fast_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, ne # Log VRAM state after sampling if is_gpu and PROFILING_ENABLED: - mem_total = torch.cuda.get_device_properties(device).total_memory / 1024**3 - mem_allocated = torch.cuda.memory_allocated(device) / 1024**3 - logging.debug(f"VRAM after sampling: {mem_allocated:.2f} GB / {mem_total:.2f} GB") + if device.type == 'cuda': + mem_total = torch.cuda.get_device_properties(device).total_memory / 1024**3 + mem_allocated = torch.cuda.memory_allocated(device) / 1024**3 + logging.debug(f"VRAM after sampling: {mem_allocated:.2f} GB / {mem_total:.2f} GB") + elif device.type == 'privateuseone': + logging.debug("VRAM info unavailable for DirectML after sampling") # Log completion of sampling if PROFILING_ENABLED: logging.debug(f"Sampling completed, preparing for VAE") - profile_cuda_sync(is_gpu) + profile_cuda_sync(is_gpu, device) # Clear VRAM after sampling if is_gpu: @@ -285,8 +302,12 @@ def fast_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, ne clear_vram(device, threshold=0.5, min_free=1.5) else: clear_start = time.time() - mem_allocated, mem_total = clear_vram(device, threshold=0.5, min_free=1.5) - logging.debug(f"VRAM after sampling: {mem_allocated:.2f} GB / {mem_total:.2f} GB, clear took {time.time() - clear_start:.3f} s") + if device.type == 'cuda': + mem_allocated, mem_total = clear_vram(device, threshold=0.5, min_free=1.5) + logging.debug(f"VRAM after sampling: {mem_allocated:.2f} GB / {mem_total:.2f} GB, clear took {time.time() - clear_start:.3f} s") + elif device.type == 'privateuseone': + clear_vram(device, threshold=0.5, min_free=1.5) + logging.debug(f"VRAM clear after sampling took {time.time() - clear_start:.3f} s (VRAM info unavailable for DirectML)") logging.debug(f"Post-VRAM checkpoint: {time.time()}") out = latent.copy() @@ -307,11 +328,13 @@ def fast_vae_decode(vae, samples): """ device = get_torch_device() vae_dtype_val = vae_dtype(device=device) - is_gpu = device.type == 'cuda' and torch.cuda.is_available() + is_gpu = (device.type == 'cuda' and torch.cuda.is_available()) or (device.type == 'privateuseone') if DEBUG_ENABLED: logging.debug(f"VAE dtype: {vae_dtype_val}") logging.debug(f"Pre-VAE checkpoint: {time.time()}") + logging.debug(f"Starting fast_vae_decode, device={device}, dtype={vae_dtype_val}, is_gpu={is_gpu}") + logging.debug(f"Latent samples shape: {samples['samples'].shape}") try: # Disable cuDNN benchmark for VAE stability if enabled @@ -320,21 +343,25 @@ def fast_vae_decode(vae, samples): # Prepare VRAM for VAE if is_gpu: - mem_total = torch.cuda.get_device_properties(device).total_memory / 1024**3 - latent_size = samples["samples"].shape - model_for_memory = 
getattr(vae, 'first_stage_model', vae) - vae_memory_required = estimate_vae_decode_memory(model_for_memory, latent_size, vae_dtype_val) / 1024**3 - vram_threshold = 1.0 if mem_total < 5.9 else 1.1 - vae_memory_required *= vram_threshold - if PROFILING_ENABLED: - logging.debug(f"Estimated VAE memory: {vae_memory_required:.2f} GB") - mem_allocated = torch.cuda.memory_allocated(device) / 1024**3 - free_mem = mem_total - mem_allocated - if free_mem < vae_memory_required: - #free_memory(vae_memory_required) - mem_allocated, mem_total = clear_vram(device, threshold=0.4, min_free=2.0) + if device.type == 'cuda': + mem_total = torch.cuda.get_device_properties(device).total_memory / 1024**3 + latent_size = samples["samples"].shape + model_for_memory = getattr(vae, 'first_stage_model', vae) + vae_memory_required = estimate_vae_decode_memory(model_for_memory, latent_size, vae_dtype_val) / 1024**3 + vram_threshold = 1.0 if mem_total < 5.9 else 1.1 + vae_memory_required *= vram_threshold + if PROFILING_ENABLED: + logging.debug(f"Estimated VAE memory: {vae_memory_required:.2f} GB") + mem_allocated = torch.cuda.memory_allocated(device) / 1024**3 + free_mem = mem_total - mem_allocated + if free_mem < vae_memory_required: + #free_memory(vae_memory_required) + mem_allocated, mem_total = clear_vram(device, threshold=0.4, min_free=2.0) + if PROFILING_ENABLED: + logging.debug(f"VRAM after free_memory: {mem_allocated:.2f} GB / {mem_total:.2f} GB") + elif device.type == 'privateuseone': if PROFILING_ENABLED: - logging.debug(f"VRAM after free_memory: {mem_allocated:.2f} GB / {mem_total:.2f} GB") + logging.debug("Memory info unavailable for DirectML, skipping VRAM check") # Preload VAE to device preload_model(vae, device, is_vae=True) @@ -384,6 +411,8 @@ def fast_vae_tiled_decode(vae, samples, tile_size=512, overlap=64, temporal_size if DEBUG_ENABLED: logging.debug(f"VAE dtype: {vae_dtype_val}") logging.debug(f"Pre-VAE checkpoint: {time.time()}") + logging.debug(f"Starting fast_vae_tiled_decode, device={device}, dtype={vae_dtype_val}, is_gpu={is_gpu}") + logging.debug(f"Latent samples shape: {samples['samples'].shape}, tile_size={tile_size}, overlap={overlap}") try: # Disable cuDNN benchmark for tiled decoding stability if enabled @@ -392,25 +421,42 @@ def fast_vae_tiled_decode(vae, samples, tile_size=512, overlap=64, temporal_size # Clear VRAM before VAE if is_gpu: - mem_total = torch.cuda.get_device_properties(device).total_memory / 1024**3 - mem_allocated = torch.cuda.memory_allocated(device) / 1024**3 - free_mem = mem_total - mem_allocated + if device.type == 'cuda': + mem_total = torch.cuda.get_device_properties(device).total_memory / 1024**3 + mem_allocated = torch.cuda.memory_allocated(device) / 1024**3 + free_mem = mem_total - mem_allocated + elif device.type == 'privateuseone': + if PROFILING_ENABLED: + logging.debug("Memory info unavailable for DirectML, skipping VRAM check") + # Estimate memory for tiled decoding (conservative, ~50% of full decode) vae_memory_required = (vae.memory_used_decode(samples["samples"].shape, vae_dtype_val) / 1024**3 * 0.5 if hasattr(vae, 'memory_used_decode') else 0.75) if PROFILING_ENABLED: - logging.debug(f"VRAM before tiled VAE: {mem_allocated:.2f} GB / {mem_total:.2f} GB") + if device.type == 'cuda': + logging.debug(f"VRAM before tiled VAE: {mem_allocated:.2f} GB / {mem_total:.2f} GB") + else: + logging.debug("VRAM before tiled VAE: unavailable for DirectML") logging.debug(f"Estimated tiled VAE memory: {vae_memory_required:.2f} GB") # Skip VRAM cleanup if VAE is already 
loaded and memory is sufficient - if (hasattr(vae, '_loaded_to_device') and vae._loaded_to_device == device and - free_mem >= vae_memory_required * 1.1): - if PROFILING_ENABLED: - logging.debug(f"VAE already loaded, sufficient memory: {free_mem:.2f} GB") - elif mem_allocated > 0.4 * mem_total or free_mem < vae_memory_required: - if PROFILING_ENABLED: - logging.debug(f"Clearing VRAM: {mem_allocated:.2f} GB used of {mem_total:.2f} GB") - mem_allocated, mem_total = clear_vram(device, threshold=0.4, min_free=0.75) + if device.type == 'cuda': + if (hasattr(vae, '_loaded_to_device') and vae._loaded_to_device == device and + free_mem >= vae_memory_required * 1.1): + if PROFILING_ENABLED: + logging.debug(f"VAE already loaded, sufficient memory: {free_mem:.2f} GB") + elif mem_allocated > 0.4 * mem_total or free_mem < vae_memory_required: + if PROFILING_ENABLED: + logging.debug(f"Clearing VRAM: {mem_allocated:.2f} GB used of {mem_total:.2f} GB") + mem_allocated, mem_total = clear_vram(device, threshold=0.4, min_free=0.75) + else: + if (hasattr(vae, '_loaded_to_device') and vae._loaded_to_device == device): + if PROFILING_ENABLED: + logging.debug("VAE already loaded on DirectML, skipping VRAM cleanup") + else: + if PROFILING_ENABLED: + logging.debug("Clearing VRAM for DirectML") + clear_vram(device, threshold=0.4, min_free=0.75) # Preload VAE if not PROFILING_ENABLED: @@ -481,9 +527,12 @@ def fast_vae_tiled_decode(vae, samples, tile_size=512, overlap=64, temporal_size images = finalize_images(images, device) if is_gpu and PROFILING_ENABLED: - mem_allocated = torch.cuda.memory_allocated(device) / 1024**3 - mem_total = torch.cuda.get_device_properties(device).total_memory / 1024**3 - logging.debug(f"VRAM after tiled decoding: {mem_allocated:.2f} GB / {mem_total:.2f} GB") + if device.type == 'cuda': + mem_allocated = torch.cuda.memory_allocated(device) / 1024**3 + mem_total = torch.cuda.get_device_properties(device).total_memory / 1024**3 + logging.debug(f"VRAM after tiled decoding: {mem_allocated:.2f} GB / {mem_total:.2f} GB") + else: + logging.debug("VRAM after tiled decoding: unavailable for DirectML") logging.debug(f"Post-decode checkpoint: {time.time()}") if PROFILING_ENABLED: diff --git a/nodes.py b/nodes.py index 28a38035bba..e3e3d961a50 100644 --- a/nodes.py +++ b/nodes.py @@ -1468,7 +1468,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, # Get device and dtype device = comfy.model_management.get_torch_device() dtype = getattr(model.model, 'dtype', torch.float32) - is_gpu = device.type == 'cuda' and torch.cuda.is_available() + is_gpu = (device.type == 'cuda' and torch.cuda.is_available()) or (device.type == 'privateuseone') # Prepare latent image latent_image = latent["samples"] From cd47286e2faf6d6be17f6bac6e0f580c8ea03c2d Mon Sep 17 00:00:00 2001 From: loxotron Date: Sun, 18 May 2025 07:58:11 +0300 Subject: [PATCH 5/6] vramhellfix without precision vram calculation and tensor's size accounting (for full version check optimization branch) --- comfy/model_management.py | 233 ++++++++++++++++++++++++++++---------- 1 file changed, 173 insertions(+), 60 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 27d79545dfc..e5bed8f2007 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -28,6 +28,8 @@ from comfy.cli_args import args, PerformanceFeature from comfy.ldm.models.autoencoder import AutoencoderKL +_directml_active_memory_cache = {} + try: import torch_directml _torch_directml_available = True @@ -574,6 
+576,7 @@ def get_free_memory(dev=None, torch_free_too=False): Returns: int or tuple: Free memory in bytes (or tuple with free_torch). """ + global _directml_active_memory_cache if dev is None: dev = get_torch_device() if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): @@ -584,7 +587,7 @@ def get_free_memory(dev=None, torch_free_too=False): total_vram = get_directml_vram(dev) cache_key = (dev, 'active_models') # Invalidate cache if models list has changed - current_models_hash = hash(tuple((id(m), m.model_loaded_memory() if not m.is_dead() else 0) for m in current_loaded_models)) + current_models_hash = hash(tuple((id(m), m.model_memory()) for m in current_loaded_models if not m.is_dead())) if cache_key in _directml_active_memory_cache: cached_hash, cached_active_models = _directml_active_memory_cache[cache_key] if cached_hash == current_models_hash: @@ -632,7 +635,7 @@ def get_free_memory(dev=None, torch_free_too=False): _directml_active_memory_cache[cache_key] = (current_models_hash, 0) # Apply safety margin (1.2x) and ensure at least 1 GB free - mem_free_total = max(1024 * 1024 * 1024, total_vram - active_models * 1.2) + mem_free_total = max(total_vram // 4, total_vram - active_models * 1.4) # Assume at least 25% VRAM free mem_free_torch = mem_free_total if DEBUG_ENABLED: logging.debug(f"DirectML: total_vram={total_vram / (1024**3):.0f} GB, active_models={active_models / (1024**3):.2f} GB, free={mem_free_total / (1024**3):.2f} GB") @@ -726,8 +729,8 @@ def soft_empty_cache(clear=False, device=None, caller="unknown"): start_time = time.time() logging.debug(f"soft_empty_cache called with clear={clear}, device={device}, caller={caller}") - # Fixed threshold in bytes (100 MB) - MEMORY_THRESHOLD = 100 * 1024 * 1024 # 100 MB + # Use lower threshold for DirectML (50 MB) due to lack of empty_cache support; 100 MB for others + MEMORY_THRESHOLD = 50 * 1024 * 1024 if directml_enabled else 100 * 1024 * 1024 cache_key = (device, 'free_memory') mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True) @@ -752,6 +755,11 @@ def soft_empty_cache(clear=False, device=None, caller="unknown"): torch.npu.empty_cache() elif is_mlu(): torch.mlu.empty_cache() + # For DirectML, only run garbage collection as empty_cache is not supported + elif directml_enabled: + gc.collect() # Minimal cleanup; torch_directml.empty_cache not available + if PROFILING_ENABLED: + logging.debug("DirectML: Ran gc.collect for minimal cleanup (empty_cache not supported)") if PROFILING_ENABLED: free_vram_after, free_torch_after = get_free_memory(device, torch_free_too=True) @@ -855,7 +863,31 @@ def model_memory(self): return self.model.model_size() if hasattr(self.model, 'model_size') else module_size(self.model) def model_loaded_memory(self): - return self.model.loaded_size() if hasattr(self.model, 'loaded_size') else module_size(self.model) + """ + Get the memory footprint of the loaded model. + + Returns: + int: Memory size in bytes. 
+ """ + if self.is_dead(): + return 0 + + # Check cached memory + if hasattr(self, '_cached_memory') and self._cached_memory is not None: + return self._cached_memory + + try: + if hasattr(self.model, 'loaded_size'): + memory = self.model.loaded_size() + else: + memory = module_size(self.model) + except Exception as e: + logging.warning(f"Error when calculating memory model: {e}") + memory = 0 + + # Cache the result + self._cached_memory = memory + return memory def model_offloaded_memory(self): return self.model_memory() - self.model_loaded_memory() @@ -913,6 +945,11 @@ def model_load(self, lowvram_model_memory=0, force_patch_weights=False): real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True) self.real_model = weakref.ref(real_model) self.model_finalizer = weakref.finalize(real_model, cleanup_models) + + # Invalidate cache + if hasattr(self, '_cached_memory'): + self._cached_memory = None + return real_model def should_reload_model(self, force_patch_weights=False): @@ -984,14 +1021,36 @@ def model_unload(self, memory_to_free=None, unpatch_weights=True): return mem_freed def model_use_more_vram(self, use_more_vram, force_patch_weights=False): + """ + Load additional model weights to VRAM if available. + + Args: + use_more_vram: Available memory in bytes. + force_patch_weights: Force re-patching weights. + + Returns: + Memory used in bytes, or 0 if model is invalid or no VRAM used. + """ if not use_more_vram: if PROFILING_ENABLED: logging.debug( "model_use_more_vram: use_more_vram=False, returning 0") return 0 - mem_required = self.model_memory_required(self.device) - extra_memory = min(mem_required * 0.3, 50 * 1024 * 1024 * 1024) # Reduced to 50 MB chunks - return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights) + if self.model is None or self.is_dead(): + if DEBUG_ENABLED: + model_name = self.real_model().__class__.__name__ if self.real_model is not None else "None" + logging.debug(f"Skipping model_use_more_vram: model_is_none={self.model is None}, is_dead={self.is_dead()}, name={model_name}") + return 0 + try: + mem_required = self.model_memory_required(self.device) + extra_memory = min(mem_required * 0.3, 50 * 1024 * 1024 * 1024) # Reduced to 50 MB chunks + memory_used = self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights) + if DEBUG_ENABLED: + logging.debug(f"model_use_more_vram: Loaded {memory_used / 1024**3:.2f} GB for {self.model.__class__.__name__}") + return memory_used + except Exception as e: + logging.error(f"Failed to partially load model {self.model.__class__.__name__}: {e}") + return 0 def __eq__(self, other): return self.model is other.model @@ -1015,6 +1074,14 @@ def module_size(model, shape=None, dtype=None): """ Estimate memory size of a module by summing parameter and buffer sizes, or using VAE-specific estimation if shape and dtype are provided. + + Args: + model: PyTorch module instance. + shape: Tuple of (batch, channels, height, width) for VAE estimation. + dtype: Data type for VAE estimation (e.g., torch.float16). + + Returns: + int: Memory size in bytes. 
""" from diffusers import AutoencoderKL @@ -1028,7 +1095,9 @@ def module_size(model, shape=None, dtype=None): return 1024 * 1024 # Minimal memory assumption for None model module_mem = 0 - if shape is not None and dtype is not None and isinstance(model, AutoencoderKL): + + # VAE-specific estimation + if shape is not None and dtype is not None and isinstance(module, AutoencoderKL): try: batch, channels, height, width = shape # Adjusted memory estimate for VAE: reduced multiplier from 64*1.1 to 32*1.05 to avoid overestimation @@ -1056,22 +1125,22 @@ def module_size(model, shape=None, dtype=None): module_mem += sum(p.numel() * p.element_size() for p in model.parameters()) if hasattr(model, 'buffers'): module_mem += sum(b.numel() * b.element_size() for b in model.buffers()) - if module_mem == 0: - model_name = model.__class__.__name__.lower() - if 'vae' in model_name or isinstance(model, AutoencoderKL): - # Reduced fallback from 3.5 GB to 2.5 GB for VAE - module_mem = 2.5 * 1024**3 - logging.warning( - f"Could not estimate module size for {model.__class__.__name__}, " - f"assuming 2.5 GB for VAE" - ) - else: - # Minimal memory assumption for unknown models - module_mem = 1024 * 1024 - logging.warning( - f"Could not estimate module size for {model.__class__.__name__}, " - f"assuming minimal memory (1 MB)" - ) + if module_mem == 0: + model_name = model.__class__.__name__.lower() + if 'vae' in model_name or isinstance(model, AutoencoderKL): + # Reduced fallback from 3.5 GB to 2.5 GB for VAE + module_mem = 2.5 * 1024**3 + logging.warning( + f"Could not estimate module size for {model.__class__.__name__}, " + f"assuming 2.5 GB for VAE" + ) + else: + # Minimal memory assumption for unknown models + module_mem = 1024 * 1024 + logging.warning( + f"Could not estimate module size for {model.__class__.__name__}, " + f"assuming minimal memory (1 MB)" + ) if VERBOSE_ENABLED: logging.debug(f"Module size for {model.__class__.__name__}: {module_mem / (1024**3):.2f} GB") @@ -1153,11 +1222,20 @@ def minimum_inference_memory(): def cleanup_models_gc(): - """Clean up dead models and collect garbage if significant memory is freed.""" + """ + Clean up dead or invalid models from current_loaded_models and collect garbage. + Removes models where is_dead() is True or model is None, with aggressive cleanup for DirectML. 
+ """ dead_memory = 0 - for cur in current_loaded_models: - if cur.is_dead(): - dead_memory += cur.model_memory() + to_remove = [] + for i, cur in enumerate(current_loaded_models): + if cur.is_dead() or cur.model is None: + dead_memory += cur.model_memory() if cur.model is not None else 0 + to_remove.append(i) + if DEBUG_ENABLED: + model_name = cur.real_model().__class__.__name__ if cur.real_model is not None else "None" + logging.debug(f"Removing invalid model at index {i}: is_dead={cur.is_dead()}, model_is_none={cur.model is None}, name={model_name}") + if dead_memory > 50 * 1024 * 1024: # 50 MB threshold if PROFILING_ENABLED: @@ -1169,12 +1247,8 @@ def cleanup_models_gc(): soft_empty_cache(clear=False, caller="cleanup_models_gc") - i = len(current_loaded_models) - 1 - while i >= 0: - if current_loaded_models[i].is_dead(): - logging.warning(f"Removing dead model {current_loaded_models[i].real_model().__class__.__name__}") - current_loaded_models.pop(i) - i -= 1 + for i in reversed(to_remove): + current_loaded_models.pop(i) def free_memory(memory_required, device, keep_loaded=None, loaded_models=None, caller="unknown"): """ @@ -1283,16 +1357,32 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu Args: models: List of models to load. - memory_required: Estimated memory needed (bytes). - force_patch_weights: Force re-patching model weights. - minimum_memory_required: Minimum memory needed for inference. - force_full_load: Force full model loading regardless of VRAM state. + memory_required: Estimated memory needed in bytes for all models. + force_patch_weights: Force re-patching model weights, even if already patched. + minimum_memory_required: Minimum memory needed for inference (optional, defaults to minimum_inference_memory). + force_full_load: Force full model loading, ignoring low VRAM state. """ + # Clean up dead models and run garbage collection cleanup_models_gc() + + if DEBUG_ENABLED: + model_names = [m.model.__class__.__name__ if m.model else "None" for m in current_loaded_models] + logging.debug(f"Current loaded models before load: {model_names}, total={len(current_loaded_models)}") + with profile_section("load_models_gpu"): - # Memory cache for efficient memory queries + # Cache memory queries to reduce overhead memory_cache = {} def get_cached_memory(device, torch_free_too=False): + """ + Get cached memory stats for a device to avoid redundant calls to get_free_memory. + + Args: + device: The torch device (e.g., DirectML or CUDA). + torch_free_too: If True, return both total and torch-specific free memory. + + Returns: + Free memory in bytes (single value or tuple if torch_free_too=True). 
+ """ cache_key = (device, torch_free_too) if cache_key not in memory_cache: try: @@ -1302,20 +1392,32 @@ def get_cached_memory(device, torch_free_too=False): memory_cache[cache_key] = (0, 0) if torch_free_too else 0 return memory_cache[cache_key] + # Set default minimum memory if not provided if minimum_memory_required is None: minimum_memory_required = minimum_inference_memory() + + # Get the current torch device device = get_torch_device() + + # Skip loading if VRAM is disabled or shared (e.g., CPU offload) if vram_state in (VRAMState.DISABLED, VRAMState.SHARED): return - + + # Clear VRAM cache aggressively for DirectML to minimize fragmentation + if directml_enabled: + soft_empty_cache(clear=True, device=device, caller="load_models_gpu") + if DEBUG_ENABLED: + logging.debug(f"VRAM stats after initial clear: {torch_directml.memory_stats()}") + + # Create a lookup table for currently loaded models model_lookup = {m.model: m for m in current_loaded_models if m.model is not None} - - # Reset currently_used flag for all loaded models + + # Mark all currently loaded models as unused for loaded_model in current_loaded_models: loaded_model.currently_used = False - + + # Prepare list of models to load loaded = [] - # Prepare models to load for model in models: if not hasattr(model, "model"): continue @@ -1325,11 +1427,11 @@ def get_cached_memory(device, torch_free_too=False): model_lookup[model] = loaded_model loaded_model.currently_used = True loaded.append(loaded_model) - - # Unload unused models only if necessary - device = get_torch_device() + + # Unload unused models if too many models or low VRAM to_remove = [] - if len(current_loaded_models) > 10 or (is_device_cuda(device) and get_cached_memory(device) < 1 * 1024 * 1024 * 1024): # >10 models or <1GB VRAM + mem_free = get_cached_memory(device) + if len(current_loaded_models) > 10 or (is_device_cuda(device) and mem_free < 1 * 1024 * 1024 * 1024): # >10 models or <1GB VRAM for i, loaded_model in enumerate(current_loaded_models): if not loaded_model.currently_used: model = loaded_model.model @@ -1345,33 +1447,44 @@ def get_cached_memory(device, torch_free_too=False): logging.error(f"Failed to unload model at index {i}: {e}") for i in reversed(to_remove): current_loaded_models.pop(i) - + + # Configure low VRAM mode if applicable lowvram_model_memory = 0 if vram_state == VRAMState.LOW_VRAM and not force_full_load: lowvram_model_memory = max( - int(get_total_memory(device) * MIN_WEIGHT_MEMORY_RATIO), 400 * 1024 * 1024) + int(total_vram * MIN_WEIGHT_MEMORY_RATIO), 400 * 1024 * 1024) elif vram_state == VRAMState.NO_VRAM: lowvram_model_memory = 1 - + + # Load each model, ensuring sufficient VRAM for l in loaded: l.currently_used = True if l.should_reload_model(force_patch_weights=force_patch_weights) or l.real_model is None: + # Calculate memory needed for the model mem_needed = l.model_memory_required(device) - mem_free = get_free_memory(device) + mem_free = get_cached_memory(device) + if DEBUG_ENABLED: logging.debug( f"Loading {l.model.__class__.__name__}: mem_needed={mem_needed / 1024**3:.2f} GB, free={mem_free / 1024**3:.2f} GB") - + + # Check if there's enough VRAM; free memory if needed if mem_free < mem_needed + minimum_memory_required: - free_memory(mem_needed + minimum_memory_required, - device, keep_loaded=loaded) - mem_free = get_free_memory(device) - + free_memory(mem_needed + minimum_memory_required, device, keep_loaded=loaded) + if DEBUG_ENABLED: + mem_free = get_cached_memory(device) + logging.debug(f"After free_memory: 
free={mem_free / 1024**3:.2f} GB") + + # Load the model using a stream for offloading stream = get_offload_stream(device) with torch.cuda.stream(stream) if stream is not None else torch.no_grad(): l.model_load(lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights) - if loaded_model not in current_loaded_models: - current_loaded_models.append(l) # append for efficiency + + # Add the model to current_loaded_models if not already present + if l not in current_loaded_models: + current_loaded_models.append(l) + + # Synchronize the stream to ensure loading is complete sync_stream(device, stream) if DEBUG_ENABLED: logging.debug( From c38bb9709613dff7767af0163e1068eb5018d2ad Mon Sep 17 00:00:00 2001 From: loxotron Date: Sun, 18 May 2025 08:00:40 +0300 Subject: [PATCH 6/6] nasty hack to avoid some memory crashes on higher resolutions with DirectML --- comfy/model_management.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index e5bed8f2007..a5a9e14d565 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -184,23 +184,7 @@ def get_directml_vram(dev): try: device_index = dev.index if hasattr(dev, 'index') and dev.index is not None else 0 device_name = torch_directml.device_name(device_index).lower() - vram_map = { - 'gtx 1660': 6 * 1024 * 1024 * 1024, - 'gtx 1650': 4 * 1024 * 1024 * 1024, - 'rtx 2060': 6 * 1024 * 1024 * 1024, - 'rtx 3060': 12 * 1024 * 1024 * 1024, - 'rtx 4060': 8 * 1024 * 1024 * 1024, - 'rx 580': 8 * 1024 * 1024 * 1024, - 'rx 570': 8 * 1024 * 1024 * 1024, - 'rx 6700': 12 * 1024 * 1024 * 1024, - 'rx 6800': 16 * 1024 * 1024 * 1024, - 'arc a770': 16 * 1024 * 1024 * 1024, - } - vram = 6 * 1024 * 1024 * 1024 - for key, value in vram_map.items(): - if key in device_name: - vram = value - break + vram = 6 * 1024 * 1024 * 1024 #NASTY HACK _directml_vram_cache[dev] = vram if DEBUG_ENABLED: logging.debug(f"DirectML VRAM for {device_name}: {vram / (1024**3):.0f} GB")
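
Note: the DirectML-specific behavior added across this series boils down to three small, self-contained rules. The sketch below restates them outside the patch for quick reference; the helper names (`count_nonzero_safe`, `fp16_is_safe`, `directml_free_vram_estimate`) are illustrative only and do not exist in the tree, and the only assumptions are that `torch` is importable and that DirectML devices report `device.type == 'privateuseone'`, as they do throughout these patches.

    import torch

    def count_nonzero_safe(t: torch.Tensor, is_directml: bool) -> int:
        # Patch 3 works around torch.count_nonzero on DirectML by summing a
        # boolean mask instead; both branches return a plain Python int.
        if is_directml:
            return int(torch.sum(t != 0).item())
        return int(torch.count_nonzero(t).item())

    def fp16_is_safe(device: torch.device) -> bool:
        # The series gates FP16/autocast on CUDA devices with compute
        # capability >= 8 (Ampere and newer); Turing and older, plus any
        # non-CUDA device (CPU, DirectML), fall back to FP32.
        if device.type != 'cuda':
            return False
        props = torch.cuda.get_device_properties(device)
        return props.major >= 8

    def directml_free_vram_estimate(total_vram: int, active_model_bytes: int) -> int:
        # Patch 5's heuristic for DirectML, where mem_get_info is unavailable:
        # subtract loaded-model memory with a 1.4x safety margin, but never
        # report less than 25% of total VRAM as free.
        return int(max(total_vram // 4, total_vram - active_model_bytes * 1.4))

These mirror the in-tree changes: `inner_sample` in comfy/samplers.py uses the boolean-sum fallback when `model_management.is_directml_enabled()` returns True, `is_fp16_safe` in fast_sampler.py applies the `major >= 8` gate before enabling autocast, and `get_free_memory` in comfy/model_management.py applies the margin-and-floor estimate for DirectML devices.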