From 4ce6a1dd9978d97c14435dfb17588d69c42b1219 Mon Sep 17 00:00:00 2001 From: Balladie Date: Wed, 20 Aug 2025 17:16:17 +0900 Subject: [PATCH 1/3] fix res_multistep and its ancestral. --- comfy/k_diffusion/sampling.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index a2bc492fd011..ea128425459c 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1342,12 +1342,16 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) - sigma_fn = lambda t: t.neg().exp() - t_fn = lambda sigma: sigma.log().neg() + + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + phi1_fn = lambda t: torch.expm1(t) / t phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t - old_sigma_down = None + old_sigma_next = None old_denoised = None uncond_denoised = None def post_cfg_function(args): @@ -1361,43 +1365,46 @@ def post_cfg_function(args): for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) - sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) + # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised}) - if sigma_down == 0 or old_denoised is None: + if sigmas[i + 1] == 0 or old_denoised is None: # Euler method if cfg_pp: d = to_d(x, sigmas[i], uncond_denoised) - x = denoised + d * sigma_down + x = denoised + d * sigmas[i + 1] else: d = to_d(x, sigmas[i], denoised) - dt = sigma_down - sigmas[i] + dt = sigmas[i + 1] - sigmas[i] x = x + d * dt else: # Second order multistep method in https://arxiv.org/pdf/2308.02157 - t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1]) + t, t_old, t_next, t_prev = lambda_fn(sigmas[i]), lambda_fn(old_sigma_next), lambda_fn(sigmas[i + 1]), lambda_fn(sigmas[i - 1]) h = t_next - t + h_eta = h * (eta + 1) c2 = (t_prev - t_old) / h - phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h) + alpha_next = sigmas[i + 1] * t_next.exp() + + phi1_val, phi2_val = phi1_fn(-h_eta), phi2_fn(-h_eta) b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0) b2 = torch.nan_to_num(phi2_val / c2, nan=0.0) if cfg_pp: x = x + (denoised - uncond_denoised) - x = sigma_fn(h) * x + h * (b1 * uncond_denoised + b2 * old_denoised) + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_next * h_eta * (b1 * uncond_denoised + b2 * old_denoised) else: - x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised) + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_next * h_eta * (b1 * denoised + b2 * old_denoised) - # Noise addition - if sigmas[i + 1] > 0: + # Noise addition + sigma_up = sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up if cfg_pp: old_denoised = uncond_denoised else: old_denoised = denoised - old_sigma_down = sigma_down + old_sigma_next = sigmas[i + 1] return x @torch.no_grad() From 8c6c3e5e0405f22f849053570ee09eceb158e082 Mon Sep 17 00:00:00 2001 From: Balladie Date: Wed, 20 Aug 2025 18:38:22 +0900 Subject: [PATCH 2/3] remove unused var --- comfy/k_diffusion/sampling.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index ea128425459c..d63051552b36 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1344,7 +1344,6 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None s_in = x.new_ones([x.shape[0]]) model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') - sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) From e82f886351a12d57a956d894e028417e7ce5e693 Mon Sep 17 00:00:00 2001 From: Balladie Date: Fri, 22 Aug 2025 00:38:22 +0900 Subject: [PATCH 3/3] restore first stochastic step. --- comfy/k_diffusion/sampling.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index d63051552b36..738bb23ec678 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1364,17 +1364,17 @@ def post_cfg_function(args): for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) - # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised}) - if sigmas[i + 1] == 0 or old_denoised is None: + if sigma_down == 0 or old_denoised is None: # Euler method if cfg_pp: d = to_d(x, sigmas[i], uncond_denoised) - x = denoised + d * sigmas[i + 1] + x = denoised + d * sigma_down else: d = to_d(x, sigmas[i], denoised) - dt = sigmas[i + 1] - sigmas[i] + dt = sigma_down - sigmas[i] x = x + d * dt else: # Second order multistep method in https://arxiv.org/pdf/2308.02157 @@ -1395,8 +1395,10 @@ def post_cfg_function(args): else: x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_next * h_eta * (b1 * denoised + b2 * old_denoised) - # Noise addition - sigma_up = sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() + # Noise addition + if sigmas[i + 1] > 0: + if old_denoised is not None: + sigma_up = sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up if cfg_pp: