Add fast_sampler.py with optimized sampling and VAE decoding, enhance PreviewImage #8136

Closed · wants to merge 6 commits

54 changes: 31 additions & 23 deletions comfy/cli_args.py
@@ -1,3 +1,8 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
"""

import argparse
import enum
import os
@@ -11,7 +16,7 @@ class EnumAction(argparse.Action):
    def __init__(self, **kwargs):
        # Pop off the type value
        enum_type = kwargs.pop("type", None)

        # Ensure an Enum subclass is provided
        if enum_type is None:
            raise ValueError("type must be assigned an Enum when using EnumAction")
@@ -22,9 +27,7 @@ def __init__(self, **kwargs):
        choices = tuple(e.value for e in enum_type)
        kwargs.setdefault("choices", choices)
        kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
-
        super(EnumAction, self).__init__(**kwargs)
-
        self._enum = enum_type

    def __call__(self, parser, namespace, values, option_string=None):
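
The body of __call__ is truncated in the diff view above. For context, a minimal, self-contained sketch of the pattern this class implements (the __call__ body and the Preview enum here are assumptions for illustration, not part of the PR):

import argparse
import enum

class EnumAction(argparse.Action):
    """Argparse action that accepts an Enum's values and stores the Enum member."""
    def __init__(self, **kwargs):
        enum_type = kwargs.pop("type", None)
        if enum_type is None:
            raise ValueError("type must be assigned an Enum when using EnumAction")
        # Advertise the enum's values as the legal choices
        choices = tuple(e.value for e in enum_type)
        kwargs.setdefault("choices", choices)
        kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
        super().__init__(**kwargs)
        self._enum = enum_type

    def __call__(self, parser, namespace, values, option_string=None):
        # Assumed body: map the validated string back to its Enum member
        setattr(namespace, self.dest, self._enum(values))

class Preview(enum.Enum):  # hypothetical enum, for the demo only
    NONE = "none"
    AUTO = "auto"

p = argparse.ArgumentParser()
p.add_argument("--preview-method", type=Preview, default=Preview.NONE, action=EnumAction)
assert p.parse_args(["--preview-method", "auto"]).preview_method is Preview.AUTO
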
@@ -35,6 +38,8 @@ def __call__(self, parser, namespace, values, option_string=None):

parser = argparse.ArgumentParser()

parser.add_argument("--debug", action="store_true", help="Enable debug logging.")
parser.add_argument("--profile", action="store_true", help="Enable profiling.")
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
@@ -53,34 +58,37 @@ def __call__(self, parser, namespace, values, option_string=None):
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")

cm_group.add_argument("--model-dtype", type=str, choices=["fp16", "bf16", "fp32"], help="Force model data type (fp16, bf16, fp32)")

fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
fp_group.add_argument("--force-fp32-vae", action="store_true", help="Force VAE to use FP32 precision")
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
fp_group.add_argument("--force-fp16-vae", action="store_true", help="Force VAE to use FP16 precision")

fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.")
fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64 (not recommended).")
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16 (requires SM >= 8.0).")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16 (may cause black images on VRAM < 6 GB).")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Run UNet in fp8 e4m3fn (requires SM >= 9.0 or SM 8.9 with PyTorch >= 2.3).")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Run UNet in fp8 e5m2 (requires SM >= 9.0 or SM 8.9 with PyTorch >= 2.3).")
fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Run UNet in fp8 e8m0fnu (requires SM >= 9.0 or SM 8.9 with PyTorch >= 2.3).") # UPDATED: Clarified requirements

fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16 (risks black images on VRAM < 6 GB).")
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in fp32 (recommended for GPUs with VRAM < 6 GB).")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16 (requires SM >= 8.0).")

parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU (slower, but safe for low VRAM).")

fpte_group = parser.add_mutually_exclusive_group()
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Run text encoder in fp8 e4m3fn (requires SM >= 9.0 or SM 8.9 with PyTorch >= 2.3).")
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Run text encoder in fp8 e5m2 (requires SM >= 9.0 or SM 8.9 with PyTorch >= 2.3).")
fpte_group.add_argument("--fp8_e8m0fnu-text-enc", action="store_true", help="Run text encoder in fp8 e8m0fnu (requires SM >= 9.0 or SM 8.9 with PyTorch >= 2.3).") # NEW
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Run text encoder in fp16 (may cause issues on VRAM < 6 GB).")
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Run text encoder in fp32 (recommended for GPUs with VRAM < 6 GB).")
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Run text encoder in bf16 (requires SM >= 8.0).")

parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")

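
The updated help texts above gate bf16 and fp8 on the GPU's SM (compute capability). The PR's actual gating logic is not in this file; a hedged sketch of the check those requirements imply (helper names are hypothetical):

import torch

def _compute_capability():
    # (major, minor) of the active CUDA device; (0, 0) when CUDA is unavailable
    return torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)

def bf16_ok():
    return _compute_capability() >= (8, 0)   # Ampere (SM 8.0) or newer

def fp8_ok():
    cc = _compute_capability()
    if cc >= (9, 0):                         # Hopper supports fp8 natively
        return True
    # Ada (SM 8.9) also needs a PyTorch build with fp8 dtypes; the presence of
    # torch.float8_e4m3fn is a rough proxy for the ">= 2.3" requirement
    return cc == (8, 9) and hasattr(torch, "float8_e4m3fn")
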
@@ -96,7 +104,6 @@ class LatentPreviewMethod(enum.Enum):
TAESD = "taesd"

parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
-
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")

cache_group = parser.add_mutually_exclusive_group()
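
This file leans heavily on mutually exclusive groups (fp_group, fpunet_group, fpvae_group, fpte_group, cache_group): passing two flags from the same group is a parse-time error. A minimal illustration with two of the VAE flags:

import argparse

p = argparse.ArgumentParser()
g = p.add_mutually_exclusive_group()
g.add_argument("--fp16-vae", action="store_true")
g.add_argument("--bf16-vae", action="store_true")

assert p.parse_args(["--fp16-vae"]).fp16_vae is True
# p.parse_args(["--fp16-vae", "--bf16-vae"])
#   -> error: argument --bf16-vae: not allowed with argument --fp16-vae
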
@@ -117,7 +124,6 @@ class LatentPreviewMethod(enum.Enum):
upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")

-
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
@@ -199,7 +205,7 @@ def is_valid_directory(path: str) -> str:
"--comfy-api-base",
type=str,
default="https://api.comfy.org",
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
help="Set the base URL for the ComfyUI API (default: https://api.comfy.org).",
)

if comfy.options.args_parsing:
@@ -215,6 +221,8 @@ def is_valid_directory(path: str) -> str:

if args.force_fp16:
    args.fp16_unet = True
+    args.fp16_vae = True
+    args.fp16_text_enc = True


# '--fast' is not provided, use an empty set
@@ -225,4 +233,4 @@ def is_valid_directory(path: str) -> str:
    args.fast = set(PerformanceFeature)
# '--fast' is provided with a list of performance features, use that list
else:
-    args.fast = set(args.fast)
+    args.fast = set(args.fast)
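
The tail of the file normalizes args.fast into a set, with three cases: flag absent, bare --fast, and an explicit feature list. A standalone restatement of that logic (the two PerformanceFeature members are assumed for the demo):

import enum

class PerformanceFeature(enum.Enum):  # members assumed for illustration
    Fp16Accumulation = "fp16_accumulation"
    Fp8MatrixMult = "fp8_matrix_mult"

def normalize_fast(fast):
    if fast is None:   # '--fast' not provided -> no performance features
        return set()
    if fast == []:     # bare '--fast' -> enable every feature
        return set(PerformanceFeature)
    return set(fast)   # explicit list -> exactly those features

assert normalize_fast(None) == set()
assert normalize_fast([]) == set(PerformanceFeature)
assert normalize_fast([PerformanceFeature.Fp16Accumulation]) == {PerformanceFeature.Fp16Accumulation}
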