Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions comfy/k_diffusion/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1342,12 +1342,15 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()

model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)

phi1_fn = lambda t: torch.expm1(t) / t
phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t

old_sigma_down = None
old_sigma_next = None
old_denoised = None
uncond_denoised = None
def post_cfg_function(args):
Expand Down Expand Up @@ -1375,29 +1378,34 @@ def post_cfg_function(args):
x = x + d * dt
else:
# Second order multistep method in https://arxiv.org/pdf/2308.02157
t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1])
t, t_old, t_next, t_prev = lambda_fn(sigmas[i]), lambda_fn(old_sigma_next), lambda_fn(sigmas[i + 1]), lambda_fn(sigmas[i - 1])
h = t_next - t
h_eta = h * (eta + 1)
c2 = (t_prev - t_old) / h

phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
alpha_next = sigmas[i + 1] * t_next.exp()

phi1_val, phi2_val = phi1_fn(-h_eta), phi2_fn(-h_eta)
b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
b2 = torch.nan_to_num(phi2_val / c2, nan=0.0)

if cfg_pp:
x = x + (denoised - uncond_denoised)
x = sigma_fn(h) * x + h * (b1 * uncond_denoised + b2 * old_denoised)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_next * h_eta * (b1 * uncond_denoised + b2 * old_denoised)
else:
x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_next * h_eta * (b1 * denoised + b2 * old_denoised)

# Noise addition
if sigmas[i + 1] > 0:
if old_denoised is not None:
sigma_up = sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt()
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up

if cfg_pp:
old_denoised = uncond_denoised
else:
old_denoised = denoised
old_sigma_down = sigma_down
old_sigma_next = sigmas[i + 1]
return x

@torch.no_grad()
Expand Down
Loading