From 7d95a436c6435a651d17e725b5db0dfdbd38f4e0 Mon Sep 17 00:00:00 2001 From: ZT220501 Date: Fri, 4 Jul 2025 10:47:08 -0700 Subject: [PATCH 1/6] Add STORK Scheduler --- src/diffusers/schedulers/scheduling_stork.py | 1459 ++++++++++++++++++ 1 file changed, 1459 insertions(+) create mode 100644 src/diffusers/schedulers/scheduling_stork.py diff --git a/src/diffusers/schedulers/scheduling_stork.py b/src/diffusers/schedulers/scheduling_stork.py new file mode 100644 index 000000000000..6b8b9e2b8b79 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_stork.py @@ -0,0 +1,1459 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union +import numpy as np +import torch +from scipy.io import loadmat +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +from diffusers.utils import BaseOutput, is_scipy_available, logging +from pathlib import Path + + + +@dataclass +class STORKSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +current_file = Path(__file__) +CONSTANTSFOLDER = f"{current_file.parent.parent}" + + + + + +class STORKScheduler(SchedulerMixin, ConfigMixin): + """ + `STORKScheduler` uses modified stabilized Runge-Kutta method for the backward ODE in the diffusion or flow matching models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + use_dynamic_shifting (`bool`, defaults to False): + Whether to apply timestep shifting on-the-fly based on the image resolution. + base_shift (`float`, defaults to 0.5): + Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent + with desired output. + max_shift (`float`, defaults to 1.15): + Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be + more exaggerated or stylized. + base_image_seq_len (`int`, defaults to 256): + The base image sequence length. + max_image_seq_len (`int`, defaults to 4096): + The maximum image sequence length. + invert_sigmas (`bool`, defaults to False): + Whether to invert the sigmas. + shift_terminal (`float`, defaults to None): + The end value of the shifted timestep schedule. 
+ use_karras_sigmas (`bool`, defaults to False): + Whether to use Karras sigmas for step sizes in the noise schedule during sampling. + use_exponential_sigmas (`bool`, defaults to False): + Whether to use exponential sigmas for step sizes in the noise schedule during sampling. + use_beta_sigmas (`bool`, defaults to False): + Whether to use beta sigmas for step sizes in the noise schedule during sampling. + solver_order (`int`, defaults to 2): + The STORK order which can be `2` or `4`. It is recommended to use `solver_order=2` uniformly. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process) or `flow_prediction`. + time_shift_type (`str`, defaults to "exponential"): + The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear". + derivative_order (`int`, defaults to 2): + The order of the Taylor expansion derivative to use for the sub-step velocity approximation. Only supports 2 or 3. + s (`int`, defaults to 50): + The number of sub-steps to use in the STORK. + precision (`str`, defaults to "float32"): + The precision to use for the scheduler; supports "float32", "bfloat16", or "float16". + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting: bool = False, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + stopping_eps: float = 1e-3, + solver_order: int = 4, + prediction_type: str = "epsilon", + time_shift_type: str = "exponential", + derivative_order: int = 2, + s: int = 50, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, + ): + + super().__init__() + # if prediction_type == "flow_prediction" and sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + # raise ValueError( + # "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." 
+        #     )
+        if time_shift_type not in {"exponential", "linear"}:
+            raise ValueError("`time_shift_type` must either be 'exponential' or 'linear'.")
+
+        # We manually enforce float32 precision to avoid numerical issues.
+        self.np_dtype = np.float32
+        self.dtype = torch.float32
+
+
+        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=self.np_dtype)[::-1].copy()
+        timesteps = torch.from_numpy(timesteps).to(dtype=self.dtype)
+        sigmas = timesteps / num_train_timesteps
+
+
+        if not use_dynamic_shifting:
+            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
+            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+
+        self.timesteps = None  # sigmas * num_train_timesteps
+        self._step_index = None
+        self._begin_index = None
+        self._shift = shift
+        self.sigmas = sigmas  # .to("cpu") to avoid too much CPU/GPU communication
+        self.sigma_min = self.sigmas[-1].item()
+        self.sigma_max = self.sigmas[0].item()
+        # Store the predictions for the velocity/noise for higher order derivative approximations
+        self.velocity_predictions = []
+        self.noise_predictions = []
+        self.s = s
+        self.derivative_order = derivative_order
+
+        self.solver_order = solver_order
+        self.prediction_type = prediction_type
+
+
+        # Set the betas for noise-based models
+        if trained_betas is not None:
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+        elif beta_schedule == "linear":
+            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        # Noise-based models epsilon to avoid numerical issues
+        self.stopping_eps = stopping_eps
+
+
+
+
+    def set_timesteps(
+        self,
+        num_inference_steps: Optional[int] = None,
+        device: Union[str, torch.device] = None,
+        sigmas: Optional[List[float]] = None,
+        mu: Optional[float] = None,
+        timesteps: Optional[List[float]] = None,
+    ):
+        """
+        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+        Args:
+            num_inference_steps (`int`, *optional*):
+                The number of diffusion steps used when generating samples with a pre-trained model.
+            device (`str` or `torch.device`, *optional*):
+                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+            sigmas (`List[float]`, *optional*):
+                Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
+                automatically.
+            mu (`float`, *optional*):
+                Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep
+                shifting.
+            timesteps (`List[float]`, *optional*):
+                Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed
+                automatically.
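+
+        Example (an illustrative sketch, assuming the scheduler class is importable from this module; the
+        step count is arbitrary):
+
+        ```py
+        >>> from diffusers.schedulers.scheduling_stork import STORKScheduler
+
+        >>> scheduler = STORKScheduler(prediction_type="flow_prediction", solver_order=2)
+        >>> scheduler.set_timesteps(num_inference_steps=50, device="cpu")
+        >>> # one extra sub-step is inserted after the first timestep for the startup phase
+        >>> len(scheduler.timesteps)
+        51
+        ```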
+ """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`") + + if sigmas is not None and timesteps is not None: + if len(sigmas) != len(timesteps): + raise ValueError("`sigmas` and `timesteps` should have the same length") + + if num_inference_steps is not None: + if (sigmas is not None and len(sigmas) != num_inference_steps) or ( + timesteps is not None and len(timesteps) != num_inference_steps + ): + raise ValueError( + "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided" + ) + else: + num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps) + + self.num_inference_steps = num_inference_steps + + if self.prediction_type == "epsilon": + self.set_timesteps_noise(num_inference_steps, device) + elif self.prediction_type == "flow_prediction": + self.set_timesteps_flow_matching(num_inference_steps, device, sigmas, mu, timesteps) + else: + raise ValueError(f"Prediction type {self.prediction_type} is not yet supported") + + # Reset the step index and begin index + self._step_index = None + self._begin_index = None + + + + def set_timesteps_noise(self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference), for noise-based models. + + Args: + num_inference_steps (`int`, *optional*): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + seq = np.linspace(0, 1, self.num_inference_steps+1) + seq[0] = self.stopping_eps + seq = seq[:-1] + seq = seq[::-1] + + # Add the intermediate step between the first step and the second step + seq = np.insert(seq, 1, seq[1]) + seq = np.insert(seq, 1, seq[0] + (seq[1] - seq[0]) / 2) + + # The following lines are for the uniform timestepping case + self.dt = (seq[0] - seq[1]) * 2 + seq = seq * self.config.num_train_timesteps + seq[-1] = self.stopping_eps * self.config.num_train_timesteps + self._timesteps = seq + self.timesteps = torch.from_numpy(seq.copy()).to(device) + + + self._step_index = None + self._begin_index = None + + self.noise_predictions = [] + + + + + def set_timesteps_flow_matching(self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + timesteps: Optional[List[float]] = None, + ): + """ + Sets the discrete timesteps used for the flow matching based models (to be run before inference). + + Args: + num_inference_steps (`int`, *optional*): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + sigmas (`List[float]`, *optional*): + Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed + automatically. + mu (`float`, *optional*): + Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep + shifting. + timesteps (`List[float]`, *optional*): + Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed + automatically. + """ + # 1. 
Prepare default sigmas + is_timesteps_provided = timesteps is not None + + if is_timesteps_provided: + timesteps = np.array(timesteps).astype(self.np_dtype) + + if sigmas is None: + if timesteps is None: + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + sigmas = timesteps / self.config.num_train_timesteps + else: + sigmas = np.array(sigmas).astype(self.np_dtype) + num_inference_steps = len(sigmas) + + + # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of + # "exponential" or "linear" type is applied + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + + # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value + if self.config.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + + # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + # 5. Convert sigmas and timesteps to tensors and move to specified device + sigmas = torch.from_numpy(sigmas).to(dtype=self.dtype, device=device) + if not is_timesteps_provided: + timesteps = sigmas * self.config.num_train_timesteps + else: + timesteps = torch.from_numpy(timesteps).to(dtype=self.dtype, device=device) + + # 6. Append the terminal sigma value. + # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the + # `invert_sigmas` flag can be set to `True`. This case is only required in Mochi + if self.config.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.config.num_train_timesteps + sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) + else: + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + # Modify the timesteps to fit in STORK methods (Add the extra NFE) + self.timesteps = timesteps.tolist() + self.timesteps = np.insert(self.timesteps, 1, self.timesteps[0] + (self.timesteps[1] - self.timesteps[0]) / 2) + self.timesteps = torch.tensor(self.timesteps) + self.timesteps = self.timesteps.to(dtype=self.dtype, device=device) + + # Modify the timesteps in order to become sigmas + self.sigmas = self.timesteps.tolist() + self.sigmas.append(0) + self.sigmas = torch.tensor(self.sigmas) + self.sigmas = self.sigmas.to(dtype=self.dtype, device=device) + self.sigmas = self.sigmas / self.config.num_train_timesteps + + # Create the dt list + self.dt_list = self.sigmas[:-1] - self.sigmas[1:] + self.dt_list = self.dt_list.reshape(-1) + + # Modify the initial several dt so that they are convenient for derivative approximations + self.dt_list[0] = self.dt_list[0] * 2 + self.dt_list[1] = self.dt_list[1] * 2 + + self.dt_list = self.dt_list.tolist() + self.dt_list = torch.tensor(self.dt_list).to(self.dtype) + + self.velocity_predictions = [] + + @property + def shift(self): + """ + The value used for shifting. + """ + return self._shift + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. 
+ """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + + + def set_shift(self, shift: float): + self._shift = shift + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) + + if sample.device.type == "mps" and torch.is_floating_point(timestep): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device, dtype=self.dtype) + timestep = timestep.to(sample.device, dtype=self.dtype) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timestep = timestep.to(sample.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timestep.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timestep.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) + + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor = None, + return_dict: bool = True, + **kwargs + ) -> torch.Tensor: + ''' + One step of the STORK update for flow matching or noise-based diffusion models. + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~schedulers.STORKSchedulerOutput`] instead of a plain tuple. + + Returns: + result (Union[Tuple, STORKSchedulerOutput]): + The next sample in the diffusion chain, either as a tuple or as a [`~schedulers.STORKSchedulerOutput`]. The value is converted back to the original dtype of `model_output` to avoid numerical issues. + ''' + original_model_output_dtype = model_output.dtype + # Cast model_output and sample to "torch.float32" to avoid numerical issues + model_output = model_output.to(self.dtype) + sample = sample.to(self.dtype) + # Move sample to model_output's device + sample = sample.to(model_output.device) + + """ + self.velocity_predictions always contain upcasted model_output in torch.float32 dtype. 
+ """ + + if self.prediction_type == "epsilon": + if self.solver_order == 2: + result = self.step_noise_2(model_output, timestep, sample, return_dict) + elif self.solver_order ==4: + result = self.step_noise_4(model_output, timestep, sample, return_dict) + else: + raise ValueError(f"Solver order {self.solver_order} is not yet supported for noise-based models") + elif self.prediction_type == "flow_prediction": + if self.solver_order == 2: + result = self.step_flow_matching_2(model_output, timestep, sample, return_dict) + elif self.solver_order == 4: + result = self.step_flow_matching_4(model_output, timestep, sample, return_dict) + else: + raise ValueError(f"Solver order {self.solver_order} is not yet supported for flow matching models") + else: + raise ValueError(f"Prediction type {self.prediction_type} is not yet supported") + + # Convert the result back to the original dtype of model_output, as this result will be used as the next input to the model + if return_dict: + result.prev_sample = result.prev_sample.to(original_model_output_dtype) + else: + result = (result[0].to(original_model_output_dtype),) + return result + + + + #################################### + # Main phase for the STORK methods # + #################################### + def step_flow_matching_2( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor = None, + return_dict: bool = False, + ) -> torch.Tensor: + ''' + One step of the STORK2 update for flow matching based models. + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~schedulers.STORKSchedulerOutput`] instead of a plain tuple. + + Returns: + result (Union[Tuple, STORKSchedulerOutput]): + The next sample in the diffusion chain, either as a tuple or as a [`~schedulers.STORKSchedulerOutput`]. The value is converted back to the original dtype of `model_output` to avoid numerical issues. 
+        '''
+        # Initialize the step index if it's the first step
+        if self._step_index is None:
+            self._step_index = 0
+
+        # Upcast to avoid precision issues when computing prev_sample
+        sample = sample.to(self.dtype)
+        sample = sample.to(model_output.device)
+
+        # Compute the startup phase or the derivative approximation for the main step
+        if self._step_index <= self.derivative_order:
+            return self.startup_phase_flow_matching(model_output, sample, return_dict=return_dict)
+        else:
+            t = self.sigmas[self._step_index]
+            t_next = self.sigmas[self._step_index + 1]
+
+
+            h1 = self.dt_list[self._step_index - 1]
+            h2 = self.dt_list[self._step_index - 2]
+            h3 = self.dt_list[self._step_index - 3]
+
+
+            if self.derivative_order == 2:
+                velocity_derivative = (-self.velocity_predictions[-2] + 4 * self.velocity_predictions[-1] - 3 * model_output) / (2 * h1)
+                velocity_second_derivative = 2 / (h1 * h2 * (h1 + h2)) * (self.velocity_predictions[-2] * h1 - self.velocity_predictions[-1] * (h1 + h2) + model_output * h2)
+                velocity_third_derivative = None
+            elif self.derivative_order == 3:
+                velocity_derivative = ((h2 * h3) * (self.velocity_predictions[-1] - model_output) - (h1 * h3) * (self.velocity_predictions[-2] - model_output) + (h1 * h2) * (self.velocity_predictions[-3] - model_output)) / (h1 * h2 * h3)
+                velocity_second_derivative = 2 * ((h2 + h3) * (self.velocity_predictions[-1] - model_output) - (h1 + h3) * (self.velocity_predictions[-2] - model_output) + (h1 + h2) * (self.velocity_predictions[-3] - model_output)) / (h1 * h2 * h3)
+                velocity_third_derivative = 6 * ((h2 - h3) * (self.velocity_predictions[-1] - model_output) + (h3 - h1) * (self.velocity_predictions[-2] - model_output) + (h1 - h2) * (self.velocity_predictions[-3] - model_output)) / (h1 * h2 * h3)
+            else:
+                raise ValueError(f"Derivative order {self.derivative_order} is not supported; only 2 and 3 are available.")
+
+        self.velocity_predictions.append(model_output)
+        self._step_index += 1
+
+
+        Y_j_2 = sample
+        Y_j_1 = sample
+        Y_j = sample
+
+
+        # Implementation of our Runge-Kutta-Gegenbauer second order method
+        for j in range(1, self.s + 1):
+            # Fraction of the step interval covered at sub-step j, used to evaluate the Taylor-expanded velocity
+            if j > 1:
+                if j == 2:
+                    fraction = 4 / (3 * (self.s**2 + self.s - 2))
+                else:
+                    fraction = ((j - 1)**2 + (j - 1) - 2) / (self.s**2 + self.s - 2)
+
+            if j == 1:
+                mu_tilde = 6 / ((self.s + 4) * (self.s - 1))
+                dt = (t - t_next) * torch.ones(model_output.shape, device=sample.device)
+                Y_j = Y_j_1 - dt * mu_tilde * model_output
+            else:
+                mu = (2 * j + 1) * self.b_coeff(j) / (j * self.b_coeff(j - 1))
+                nu = -(j + 1) * self.b_coeff(j) / (j * self.b_coeff(j - 2))
+                mu_tilde = mu * 6 / ((self.s + 4) * (self.s - 1))
+                gamma_tilde = -mu_tilde * (1 - j * (j + 1) * self.b_coeff(j - 1) / 2)
+
+
+                # Probability flow ODE update
+                diff = -fraction * (t - t_next) * torch.ones(model_output.shape, device=sample.device)
+                velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative)
+                Y_j = mu * Y_j_1 + nu * Y_j_2 + (1 - mu - nu) * sample - dt * mu_tilde * velocity - dt * gamma_tilde * model_output
+
+            Y_j_2 = Y_j_1
+            Y_j_1 = Y_j
+
+
+
+        img_next = Y_j
+        img_next = img_next.to(model_output.dtype)
+
+        if not return_dict:
+            return (img_next,)
+        return STORKSchedulerOutput(prev_sample=img_next)
+
+
+    def step_flow_matching_4(
+        self,
+        model_output: torch.Tensor,
+        timestep: Union[int, torch.Tensor],
+        sample: torch.Tensor = None,
+        return_dict: bool = False,
+    ) -> torch.Tensor:
+        '''
+        One step of the STORK4
 update for flow matching models.
+
+        Args:
+            model_output (`torch.FloatTensor`):
+                The direct output from learned diffusion model.
+            timestep (`float`):
+                The current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                A current instance of a sample created by the diffusion process.
+            return_dict (`bool`, defaults to `False`):
+                Whether or not to return a [`~schedulers.STORKSchedulerOutput`] instead of a plain tuple.
+
+        Returns:
+            `torch.FloatTensor`: The next sample in the diffusion chain.
+        '''
+
+        # Initialize the step index if it's the first step
+        if self._step_index is None:
+            self._step_index = 0
+
+        # Compute the startup phase or the derivative approximation for the main step
+        if self._step_index <= self.derivative_order:
+            return self.startup_phase_flow_matching(model_output, sample, return_dict=return_dict)
+        else:
+            t = self.sigmas[self._step_index]
+            t_start = torch.ones(model_output.shape, device=sample.device) * t
+            t_next = self.sigmas[self._step_index + 1]
+
+
+            h1 = self.dt_list[self._step_index - 1]
+            h2 = self.dt_list[self._step_index - 2]
+            h3 = self.dt_list[self._step_index - 3]
+
+
+            if self.derivative_order == 2:
+                velocity_derivative = (-self.velocity_predictions[-2] + 4 * self.velocity_predictions[-1] - 3 * model_output) / (2 * h1)
+                velocity_second_derivative = 2 / (h1 * h2 * (h1 + h2)) * (self.velocity_predictions[-2] * h1 - self.velocity_predictions[-1] * (h1 + h2) + model_output * h2)
+                velocity_third_derivative = None
+            elif self.derivative_order == 3:
+                velocity_derivative = ((h2 * h3) * (self.velocity_predictions[-1] - model_output) - (h1 * h3) * (self.velocity_predictions[-2] - model_output) + (h1 * h2) * (self.velocity_predictions[-3] - model_output)) / (h1 * h2 * h3)
+                velocity_second_derivative = 2 * ((h2 + h3) * (self.velocity_predictions[-1] - model_output) - (h1 + h3) * (self.velocity_predictions[-2] - model_output) + (h1 + h2) * (self.velocity_predictions[-3] - model_output)) / (h1 * h2 * h3)
+                velocity_third_derivative = 6 * ((h2 - h3) * (self.velocity_predictions[-1] - model_output) + (h3 - h1) * (self.velocity_predictions[-2] - model_output) + (h1 - h2) * (self.velocity_predictions[-3] - model_output)) / (h1 * h2 * h3)
+            else:
+                raise ValueError(f"Derivative order {self.derivative_order} is not supported; only 2 and 3 are available.")
+
+        self.velocity_predictions.append(model_output)
+        self._step_index += 1
+
+
+
+        Y_j_2 = sample
+        Y_j_1 = sample
+        Y_j = sample
+
+        ci1 = t_start
+        ci2 = t_start
+        ci3 = t_start
+
+        # Coefficients of ROCK4
+        ms, fpa, fpb, fpbe, recf = self.coeff_rock4()
+        # Choose the degree that's in the precomputed table
+        mdeg, mp = self.mdegr(self.s, ms)
+        mz = int(mp[0])
+        mr = int(mp[1])
+
+
+
+        '''
+        The first part of the STORK4 update
+        '''
+        for j in range(1, mdeg + 1):
+
+            # First sub-step in the first part of the STORK4 update
+            if j == 1:
+                temp1 = -(t - t_next) * recf[mr] * torch.ones(model_output.shape, device=sample.device)
+                ci1 = t_start + temp1
+                ci2 = ci1
+                Y_j_2 = sample
+                Y_j_1 = sample + temp1 * model_output
+            # Second and the following sub-steps in the first part of the STORK4 update
+            else:
+                diff = ci1 - t_start
+                velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative)
+
+                temp1 = -(t - t_next) * recf[mr + 2 * (j - 2) + 1] * torch.ones(model_output.shape, device=sample.device)
+                temp3 = -recf[mr + 2 * (j - 2) + 2] * torch.ones(model_output.shape, device=sample.device)
+                temp2 = torch.ones(model_output.shape, device=sample.device) - temp3
+
+                ci1 = temp1 + temp2 * ci2 + temp3 * ci3
+                Y_j = temp1 * velocity + temp2 * Y_j_1 + temp3 * Y_j_2
+
+                # Update the intermediate 
variables + Y_j_2 = Y_j_1 + Y_j_1 = Y_j + + ci3 = ci2 + ci2 = ci1 + + ''' + The finishing four-step procedure as a composition method + ''' + # First finishing step + temp1 = -(t - t_next) * fpa[mz,0] * torch.ones(model_output.shape, device=sample.device) + diff = ci1 - t_start + velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative) + Y_j_1 = velocity + Y_j_3 = Y_j + temp1 * Y_j_1 + + # Second finishing step + ci2 = ci1 + temp1 + temp1 = -(t - t_next) * fpa[mz,1] * torch.ones(model_output.shape, device=sample.device) + temp2 = -(t - t_next) * fpa[mz,2] * torch.ones(model_output.shape, device=sample.device) + diff = ci2 - t_start + velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative) + Y_j_2 = velocity + Y_j_4 = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2 + + # Third finishing step + ci2 = ci1 + temp1 + temp2 + temp1 = -(t - t_next) * fpa[mz,3] * torch.ones(model_output.shape, device=sample.device) + temp2 = -(t - t_next) * fpa[mz,4] * torch.ones(model_output.shape, device=sample.device) + temp3 = -(t - t_next) * fpa[mz,5] * torch.ones(model_output.shape, device=sample.device) + diff = ci2 - t_start + velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative) + Y_j_3 = velocity + fnt = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2 + temp3 * Y_j_3 + + # Fourth finishing step + ci2 = ci1 + temp1 + temp2 + temp3 + temp1 = -(t - t_next) * fpb[mz,0] * torch.ones(model_output.shape, device=sample.device) + temp2 = -(t - t_next) * fpb[mz,1] * torch.ones(model_output.shape, device=sample.device) + temp3 = -(t - t_next) * fpb[mz,2] * torch.ones(model_output.shape, device=sample.device) + temp4 = -(t - t_next) * fpb[mz,3] * torch.ones(model_output.shape, device=sample.device) + diff = ci2 - t_start + velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative) + Y_j_4 = velocity + Y_j = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2 + temp3 * Y_j_3 + temp4 * Y_j_4 + img_next = Y_j + + if not return_dict: + return (img_next,) + return STORKSchedulerOutput(prev_sample=img_next) + + + def step_noise_2( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor = None, + return_dict: bool = False, + ) -> torch.Tensor: + ''' + One step of the STORK2 update for noise-based diffusion models. + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~schedulers.STORKSchedulerOutput`] instead of a plain tuple. + + Returns: + `torch.FloatTensor`: The next sample in the diffusion chain. 
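+
+        Notes:
+            The sub-steps integrate the probability flow ODE of the VP-SDE, whose drift is assembled in
+            `drift_function`:
+
+                `dx/dt = -0.5 * beta(t) * x - 0.5 * beta(t) * score(x, t)`, with `score(x, t) = -epsilon(x, t) / std(t)`
+
+            where `epsilon` is the model's noise prediction and `std(t) = sqrt(1 - exp(2 * log_mean_coeff(t)))`.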
+ ''' + # Initialize the step index if it's the first step + if self._step_index is None: + self._step_index = 0 + self.initial_noise = model_output + + + total_step = self.config.num_train_timesteps + t = self.timesteps[self._step_index] / total_step + + beta_0, beta_1 = self.betas[0], self.betas[-1] + t_start = torch.ones(model_output.shape, device=sample.device) * t + beta_t = (beta_0 + t_start * (beta_1 - beta_0)) * total_step + log_mean_coeff = (-0.25 * t_start ** 2 * (beta_1 - beta_0) - 0.5 * t_start * beta_0) * total_step + std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) + + # Tweedie's trick + if self._step_index == len(self.timesteps) - 1: + noise_last = model_output + img_next = sample - std * noise_last + if not return_dict: + return (img_next,) + return STORKSchedulerOutput(prev_sample=img_next) + + t_next = self.timesteps[self._step_index + 1] / total_step + + # drift, diffusion -> f(x,t), g(t) + drift_initial, diffusion_initial = -0.5 * beta_t * sample, torch.sqrt(beta_t) * torch.ones(sample.shape, device=sample.device) + noise_initial = model_output + score = -noise_initial / std # score -> noise + drift_initial = drift_initial - diffusion_initial ** 2 * score * 0.5 # drift -> dx/dt + + + dt = torch.ones(model_output.shape, device=sample.device) * self.dt + + if self._step_index == 0: + # FIRST RUN + self.initial_sample = sample + img_next = sample - 0.5 * dt * drift_initial + + self.noise_predictions.append(noise_initial) + self._step_index += 1 + + self.initial_sample = sample + self.initial_drift = drift_initial + self.initial_noise = model_output + + return SchedulerOutput(prev_sample=img_next) + elif self._step_index == 1: + # SECOND RUN + t_previous = torch.ones(model_output.shape, device=sample.device) * self.timesteps[0] / 1000 + drift_previous = self.drift_function(self.betas, self.config.num_train_timesteps, t_previous, self.initial_sample, self.noise_predictions[-1]) + + img_next = sample - 0.75 * dt * drift_initial + 0.25 * dt * drift_previous + + self.noise_predictions.append(noise_initial) + self._step_index += 1 + + return SchedulerOutput(prev_sample=img_next) + elif self._step_index == 2: + h = 0.5 * dt + + noise_derivative = (3 * self.noise_predictions[0] - 4 * self.noise_predictions[1] + model_output) / (2 * h) + noise_second_derivative = (self.noise_predictions[0] - 2 * self.noise_predictions[1] + model_output) / (h ** 2) + noise_third_derivative = None + + model_output = self.initial_noise + drift_initial = self.initial_drift + sample = self.initial_sample + + t = self.timesteps[0] / total_step + t_start = torch.ones(model_output.shape, device=sample.device) * t + t_next = self.timesteps[2] / total_step + + noise_approx_order = 2 + elif self._step_index == 3: + h = 0.5 * dt + + noise_derivative = (-3 * noise_initial + 4 * self.noise_predictions[-1] - self.noise_predictions[-2]) / (2 * h) + noise_second_derivative = (noise_initial - 2 * self.noise_predictions[-1] + self.noise_predictions[-2]) / (h ** 2) + noise_third_derivative = None + + self.noise_predictions.append(noise_initial) + noise_approx_order = 2 + elif self._step_index == 4: + h = dt + + noise_derivative = (-3 * noise_initial + 4 * self.noise_predictions[-1] - self.noise_predictions[-2]) / (2 * h) + noise_second_derivative = (noise_initial - 2 * self.noise_predictions[-1] + self.noise_predictions[-2]) / (h ** 2) + noise_third_derivative = None + + self.noise_predictions.append(noise_initial) + noise_approx_order = 2 + else: + # ALL ELSE + h = dt + + noise_derivative = (2 * 
self.noise_predictions[-3] - 9 * self.noise_predictions[-2] + 18 * self.noise_predictions[-1] - 11 * noise_initial) / (6 * h)
+            noise_second_derivative = (-self.noise_predictions[-3] + 4 * self.noise_predictions[-2] - 5 * self.noise_predictions[-1] + 2 * noise_initial) / (h**2)
+            noise_third_derivative = (self.noise_predictions[-3] - 3 * self.noise_predictions[-2] + 3 * self.noise_predictions[-1] - noise_initial) / (h**3)
+
+            self.noise_predictions.append(noise_initial)
+            noise_approx_order = 3
+
+
+        Y_j_2 = sample
+        Y_j_1 = sample
+        Y_j = sample
+
+        ci1 = t_start
+        ci2 = t_start
+        ci3 = t_start
+
+        # Implementation of our Runge-Kutta-Gegenbauer second order method
+        for j in range(1, self.s + 1):
+            # Fraction of the step interval covered at sub-step j, used to evaluate the Taylor-expanded noise
+            if j > 1:
+                if j == 2:
+                    fraction = 4 / (3 * (self.s**2 + self.s - 2))
+                else:
+                    fraction = ((j - 1)**2 + (j - 1) - 2) / (self.s**2 + self.s - 2)
+
+            if j == 1:
+                mu_tilde = 6 / ((self.s + 4) * (self.s - 1))
+                dt = (t - t_next) * torch.ones(model_output.shape, device=sample.device)
+                Y_j = Y_j_1 - dt * mu_tilde * model_output
+            else:
+                mu = (2 * j + 1) * self.b_coeff(j) / (j * self.b_coeff(j - 1))
+                nu = -(j + 1) * self.b_coeff(j) / (j * self.b_coeff(j - 2))
+                mu_tilde = mu * 6 / ((self.s + 4) * (self.s - 1))
+                gamma_tilde = -mu_tilde * (1 - j * (j + 1) * self.b_coeff(j - 1) / 2)
+
+
+                # Probability flow ODE update; the Taylor order matches the noise derivatives computed above
+                diff = -fraction * (t - t_next) * torch.ones(model_output.shape, device=sample.device)
+                noise_approx = self.taylor_approximation(noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative)
+                Y_j = mu * Y_j_1 + nu * Y_j_2 + (1 - mu - nu) * sample - dt * mu_tilde * noise_approx - dt * gamma_tilde * model_output
+
+            Y_j_2 = Y_j_1
+            Y_j_1 = Y_j
+
+
+
+        img_next = Y_j
+        img_next = img_next.to(model_output.dtype)
+        self._step_index += 1
+
+        if not return_dict:
+            return (img_next,)
+        return STORKSchedulerOutput(prev_sample=img_next)
+
+
+    def step_noise_4(
+        self,
+        model_output: torch.Tensor,
+        timestep: Union[int, torch.Tensor],
+        sample: torch.Tensor = None,
+        return_dict: bool = False,
+    ) -> torch.Tensor:
+        '''
+        One step of the STORK4 update for noise-based diffusion models.
+
+        Args:
+            model_output (`torch.FloatTensor`):
+                The direct output from learned diffusion model.
+            timestep (`float`):
+                The current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                A current instance of a sample created by the diffusion process.
+            return_dict (`bool`, defaults to `True`):
+                Whether or not to return a [`~schedulers.STORKSchedulerOutput`] instead of a plain tuple.
+
+        Returns:
+            `torch.FloatTensor`: The next sample in the diffusion chain.
+        '''
+        # Initialize the step index if it's the first step
+        if self._step_index is None:
+            self._step_index = 0
+            self.initial_noise = model_output
+
+
+        total_step = self.config.num_train_timesteps
+        t = self.timesteps[self._step_index] / total_step
+
+        beta_0, beta_1 = self.betas[0], self.betas[-1]
+        t_start = torch.ones(model_output.shape, device=sample.device) * t
+        beta_t = (beta_0 + t_start * (beta_1 - beta_0)) * total_step
+        log_mean_coeff = (-0.25 * t_start ** 2 * (beta_1 - beta_0) - 0.5 * t_start * beta_0) * total_step
+        std = torch.sqrt(1. - torch.exp(2. 
* log_mean_coeff)) + + # Tweedie's trick + if self._step_index == len(self.timesteps) - 1: + noise_last = model_output + img_next = sample - std * noise_last + if not return_dict: + return (img_next,) + return STORKSchedulerOutput(prev_sample=img_next) + + t_next = self.timesteps[self._step_index + 1] / total_step + + # drift, diffusion -> f(x,t), g(t) + drift_initial, diffusion_initial = -0.5 * beta_t * sample, torch.sqrt(beta_t) * torch.ones(sample.shape, device=sample.device) + noise_initial = model_output + score = -noise_initial / std # score -> noise + drift_initial = drift_initial - diffusion_initial ** 2 * score * 0.5 # drift -> dx/dt + + + dt = torch.ones(model_output.shape, device=sample.device) * self.dt + + if self._step_index == 0: + # FIRST RUN + self.initial_sample = sample + img_next = sample - 0.5 * dt * drift_initial + + self.noise_predictions.append(noise_initial) + self._step_index += 1 + + self.initial_sample = sample + self.initial_drift = drift_initial + self.initial_noise = model_output + + return SchedulerOutput(prev_sample=img_next) + elif self._step_index == 1: + # SECOND RUN + t_previous = torch.ones(model_output.shape, device=sample.device) * self.timesteps[0] / 1000 + drift_previous = self.drift_function(self.betas, self.config.num_train_timesteps, t_previous, self.initial_sample, self.noise_predictions[-1]) + + img_next = sample - 0.75 * dt * drift_initial + 0.25 * dt * drift_previous + + self.noise_predictions.append(noise_initial) + self._step_index += 1 + + return SchedulerOutput(prev_sample=img_next) + elif self._step_index == 2: + h = 0.5 * dt + + noise_derivative = (3 * self.noise_predictions[0] - 4 * self.noise_predictions[1] + model_output) / (2 * h) + noise_second_derivative = (self.noise_predictions[0] - 2 * self.noise_predictions[1] + model_output) / (h ** 2) + noise_third_derivative = None + + model_output = self.initial_noise + drift_initial = self.initial_drift + sample = self.initial_sample + + t = self.timesteps[0] / total_step + t_start = torch.ones(model_output.shape, device=sample.device) * t + t_next = self.timesteps[2] / total_step + + noise_approx_order = 2 + elif self._step_index == 3: + h = 0.5 * dt + + noise_derivative = (-3 * noise_initial + 4 * self.noise_predictions[-1] - self.noise_predictions[-2]) / (2 * h) + noise_second_derivative = (noise_initial - 2 * self.noise_predictions[-1] + self.noise_predictions[-2]) / (h ** 2) + noise_third_derivative = None + + self.noise_predictions.append(noise_initial) + noise_approx_order = 2 + elif self._step_index == 4: + h = dt + + noise_derivative = (-3 * noise_initial + 4 * self.noise_predictions[-1] - self.noise_predictions[-2]) / (2 * h) + noise_second_derivative = (noise_initial - 2 * self.noise_predictions[-1] + self.noise_predictions[-2]) / (h ** 2) + noise_third_derivative = None + + self.noise_predictions.append(noise_initial) + noise_approx_order = 2 + else: + # ALL ELSE + h = dt + + noise_derivative = (2 * self.noise_predictions[-3] - 9 * self.noise_predictions[-2] + 18 * self.noise_predictions[-1] - 11 * noise_initial) / (6 * h) + noise_second_derivative = (-self.noise_predictions[-3] + 4 * self.noise_predictions[-2] -5 * self.noise_predictions[-1] + 2 * noise_initial) / (h**2) + noise_third_derivative = (self.noise_predictions[-3] - 3 * self.noise_predictions[-2] + 3 * self.noise_predictions[-1] - noise_initial) / (h**3) + + self.noise_predictions.append(noise_initial) + noise_approx_order = 3 + + + Y_j_2 = sample + Y_j_1 = sample + Y_j = sample + + ci1 = t_start + ci2 = t_start + 
ci3 = t_start + + # Coefficients of ROCK4 + ms, fpa, fpb, fpbe, recf = self.coeff_rock4() + # Choose the degree that's in the precomputed table + mdeg, mp = self.mdegr(self.s, ms) + mz = int(mp[0]) + mr = int(mp[1]) + + ''' + The first part of the STORK4 update + ''' + for j in range(1, mdeg + 1): + + # First sub-step in the first part of the STORK4 update + if j == 1: + temp1 = -(t - t_next) * recf[mr] * torch.ones(model_output.shape, device=sample.device) + ci1 = t_start + temp1 + ci2 = ci1 + Y_j_2 = sample + Y_j_1 = sample + temp1 * drift_initial + # Second and the following sub-steps in the first part of the STORK4 update + else: + diff = ci1 - t_start + noise_approx = self.taylor_approximation(noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative) + drift_approx = self.drift_function(self.betas, self.config.num_train_timesteps, ci1, Y_j_1, noise_approx) + + temp1 = -(t - t_next) * recf[mr + 2 * (j-2) + 1] * torch.ones(model_output.shape, device=sample.device) + temp3 = -recf[mr + 2 * (j-2) + 2] * torch.ones(model_output.shape, device=sample.device) + temp2 = torch.ones(model_output.shape, device=sample.device) - temp3 + + ci1 = temp1 + temp2 * ci2 + temp3 * ci3 + Y_j = temp1 * drift_approx + temp2 * Y_j_1 + temp3 * Y_j_2 + + # Update the intermediate variables + Y_j_2 = Y_j_1 + Y_j_1 = Y_j + + ci3 = ci2 + ci2 = ci1 + + ''' + The finishing four-step procedure as a composition method + ''' + # First finishing step + temp1 = -(t - t_next) * fpa[mz,0] * torch.ones(model_output.shape, device=sample.device) + diff = ci1 - t_start + noise_approx = self.taylor_approximation(noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative) + drift_approx = self.drift_function(self.betas, self.config.num_train_timesteps, ci1, Y_j, noise_approx) + Y_j_1 = drift_approx + Y_j_3 = Y_j + temp1 * Y_j_1 + + # Second finishing step + ci2 = ci1 + temp1 + temp1 = -(t - t_next) * fpa[mz,1] * torch.ones(model_output.shape, device=sample.device) + temp2 = -(t - t_next) * fpa[mz,2] * torch.ones(model_output.shape, device=sample.device) + diff = ci2 - t_start + noise_approx = self.taylor_approximation(noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative) + drift_approx = self.drift_function(self.betas, self.config.num_train_timesteps, ci2, Y_j_3, noise_approx) + Y_j_2 = drift_approx + Y_j_4 = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2 + + # Third finishing step + ci2 = ci1 + temp1 + temp2 + temp1 = -(t - t_next) * fpa[mz,3] * torch.ones(model_output.shape, device=sample.device) + temp2 = -(t - t_next) * fpa[mz,4] * torch.ones(model_output.shape, device=sample.device) + temp3 = -(t - t_next) * fpa[mz,5] * torch.ones(model_output.shape, device=sample.device) + diff = ci2 - t_start + noise_approx = self.taylor_approximation(noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative) + drift_approx = self.drift_function(self.betas, self.config.num_train_timesteps, ci2, Y_j_4, noise_approx) + Y_j_3 = drift_approx + fnt = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2 + temp3 * Y_j_3 + + # Fourth finishing step + ci2 = ci1 + temp1 + temp2 + temp3 + temp1 = -(t - t_next) * fpb[mz,0] * torch.ones(model_output.shape, device=sample.device) + temp2 = -(t - t_next) * fpb[mz,1] * torch.ones(model_output.shape, device=sample.device) + temp3 = -(t - t_next) * fpb[mz,2] * torch.ones(model_output.shape, device=sample.device) + temp4 = -(t - t_next) * 
fpb[mz,3] * torch.ones(model_output.shape, device=sample.device)
+        diff = ci2 - t_start
+        noise_approx = self.taylor_approximation(noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative)
+        drift_approx = self.drift_function(self.betas, self.config.num_train_timesteps, ci2, fnt, noise_approx)
+        Y_j_4 = drift_approx
+        Y_j = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2 + temp3 * Y_j_3 + temp4 * Y_j_4
+
+
+
+        img_next = Y_j
+        self._step_index += 1
+
+        if not return_dict:
+            return (img_next,)
+        return STORKSchedulerOutput(prev_sample=img_next)
+
+
+
+
+
+    def startup_phase_flow_matching(
+        self,
+        model_output: torch.Tensor,
+        sample: torch.Tensor = None,
+        return_dict: bool = True,
+    ) -> torch.Tensor:
+        '''
+        Startup phase for the STORK2 and STORK4 update for flow matching based models.
+
+        Args:
+            model_output (`torch.FloatTensor`):
+                The direct output from learned flow matching model.
+            sample (`torch.FloatTensor`):
+                A current instance of a sample created by the flow matching process.
+            return_dict (`bool`, defaults to `True`):
+                Whether or not to return a [`~schedulers.STORKSchedulerOutput`] instead of a plain tuple.
+
+        Returns:
+            result (Union[Tuple, STORKSchedulerOutput]):
+                The next sample in the diffusion chain, either as a tuple or as a [`~schedulers.STORKSchedulerOutput`]. The value is converted back to the original dtype of `model_output` to avoid numerical issues.
+        '''
+        dt = self.dt_list[self._step_index]
+        dt = torch.ones(model_output.shape, device=sample.device) * dt
+
+        if self._step_index == 0:
+            # Perform Euler's method for a half step
+            img_next = sample - 0.5 * dt * model_output
+            self.velocity_predictions.append(model_output)
+        elif self._step_index == 1:
+            # Perform Heun's method for a half step
+            img_next = sample - 0.75 * dt * model_output + 0.25 * dt * self.velocity_predictions[-1]
+        elif self._step_index == 2 or (self._step_index == 3 and self.derivative_order == 3):
+            dt_previous = self.dt_list[self._step_index - 1]
+            dt_previous = torch.ones(model_output.shape, device=sample.device) * dt_previous
+            img_next = sample + (dt**2 / (2 * (-dt_previous)) - dt) * model_output + (dt**2 / (2 * dt_previous)) * self.velocity_predictions[-1]
+            self.velocity_predictions.append(model_output)
+        else:
+            raise NotImplementedError(
+                f"Startup phase for step {self._step_index} is not implemented. Please check the implementation."
+            )
+
+        self._step_index += 1
+
+        if not return_dict:
+            return (img_next,)
+        return STORKSchedulerOutput(prev_sample=img_next)
+
+    def startup_phase_noise(
+        self,
+        model_output: torch.Tensor,
+        drift: torch.Tensor,
+        sample: torch.Tensor = None,
+        return_dict: bool = False,
+    ) -> torch.Tensor:
+        '''
+        Startup phase for the STORK2 and STORK4 update for noise-based diffusion models.
+
+        Args:
+            model_output (`torch.FloatTensor`):
+                The direct output from learned diffusion model.
+            drift (`torch.FloatTensor`):
+                The drift term from the diffusion model, calculated based on the model_output and the current timestep.
+            sample (`torch.FloatTensor`):
+                A current instance of a sample created by the diffusion process.
+            return_dict (`bool`, defaults to `True`):
+                Whether or not to return a [`~schedulers.STORKSchedulerOutput`] instead of a plain tuple.
+
+        Returns:
+            result (Union[Tuple, STORKSchedulerOutput]):
+                The next sample in the diffusion chain, either as a tuple or as a [`~schedulers.STORKSchedulerOutput`]. The value is converted back to the original dtype of `model_output` to avoid numerical issues.
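+
+        Notes:
+            The startup mirrors the flow matching variant: step 0 takes an explicit Euler half step,
+            `x_next = x - 0.5 * dt * drift`, and step 1 takes a Heun-type half step,
+            `x_next = x - 0.75 * dt * drift + 0.25 * dt * initial_drift`, which builds up the history of noise
+            predictions needed by the finite-difference derivative approximations in the main phase.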
+        '''
+        dt = torch.ones(model_output.shape, device=sample.device) * self.dt
+        if self._step_index == 0:
+            # Perform Euler's method for a half step
+            self.initial_sample = sample
+            self.initial_drift = drift
+
+            img_next = sample - 0.5 * dt * drift
+
+            self.noise_predictions.append(model_output)
+            self._step_index += 1
+
+            if not return_dict:
+                return (img_next,)
+            return STORKSchedulerOutput(prev_sample=img_next)
+        elif self._step_index == 1:
+            # Perform Heun's method for a half step
+            img_next = sample - 0.75 * dt * drift + 0.25 * dt * self.initial_drift
+
+            self.noise_predictions.append(model_output)
+            self._step_index += 1
+
+            if not return_dict:
+                return (img_next,)
+            return STORKSchedulerOutput(prev_sample=img_next)
+        else:
+            raise ValueError("Startup phase is only supported for the first two steps.")
+
+
+
+
+    def __len__(self):
+        return self.config.num_train_timesteps
+
+    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.Tensor`):
+                The input sample.
+
+        Returns:
+            `torch.Tensor`:
+                A scaled input sample.
+        """
+        return sample
+
+    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+        if self.config.time_shift_type == "exponential":
+            return self._time_shift_exponential(mu, sigma, t)
+        elif self.config.time_shift_type == "linear":
+            return self._time_shift_linear(mu, sigma, t)
+
+    def _time_shift_exponential(self, mu, sigma, t):
+        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+    def _time_shift_linear(self, mu, sigma, t):
+        return mu / (mu + (1 / t - 1) ** sigma)
+
+    def taylor_approximation(self, taylor_approx_order, diff, model_output, derivative, second_derivative, third_derivative=None):
+        if taylor_approx_order == 2:
+            if third_derivative is not None:
+                raise ValueError("The third derivative is computed but not used!")
+            approx_value = model_output + diff * derivative + 0.5 * diff**2 * second_derivative
+        elif taylor_approx_order == 3:
+            if third_derivative is None:
+                raise ValueError("The third derivative is not computed!")
+            approx_value = model_output + diff * derivative + 0.5 * diff**2 * second_derivative \
+                + diff**3 * third_derivative / 6
+        else:
+            raise ValueError(f"Taylor approximation order {taylor_approx_order} is not supported; only 2 and 3 are available.")
+
+        return approx_value
+
+
+    def drift_function(self, betas, total_step, t_eval, y_eval, noise):
+        '''
+        Drift function for the probability flow ODE in the noise-based diffusion model.
+
+        Args:
+            betas (`torch.FloatTensor`):
+                The betas of the diffusion model.
+            total_step (`int`):
+                The total number of steps in the diffusion chain.
+            t_eval (`torch.FloatTensor`):
+                The timestep to be evaluated at in the diffusion chain.
+            y_eval (`torch.FloatTensor`):
+                The sample to be evaluated at in the diffusion chain.
+            noise (`torch.FloatTensor`):
+                The noise used at the current timestep in the diffusion chain.
+
+        Returns:
+            `torch.FloatTensor`:
+                The drift term for the probability flow ODE in the diffusion model.
+        '''
+        beta_0, beta_1 = betas[0], betas[-1]
+        beta_t = (beta_0 + t_eval * (beta_1 - beta_0)) * total_step
+        beta_t = beta_t * torch.ones(y_eval.shape, device=y_eval.device)
+
+        log_mean_coeff = (-0.25 * t_eval ** 2 * (beta_1 - beta_0) - 0.5 * t_eval * beta_0) * total_step
+        std = torch.sqrt(1. - torch.exp(2. 
* log_mean_coeff))
+
+        # drift, diffusion -> f(x,t), g(t)
+        drift, diffusion = -0.5 * beta_t * y_eval, torch.sqrt(beta_t) * torch.ones(y_eval.shape, device=y_eval.device)
+        score = -noise / std  # score -> noise
+        drift = drift - diffusion ** 2 * score * 0.5  # drift -> dx/dt
+
+        return drift
+
+    def b_coeff(self, j):
+        '''
+        Coefficients of STORK2. They are based on the second-order Runge-Kutta-Gegenbauer method.
+        Details of the coefficients can be found in https://www.sciencedirect.com/science/article/pii/S0021999120306537
+
+        Args:
+            j (`int`):
+                The sub-step index of the coefficient.
+
+        Returns:
+            `float`:
+                The coefficient of the STORK2.
+        '''
+        if j < 0:
+            raise ValueError("The b_j coefficient in the RKG method can't have a negative j.")
+        if j == 0:
+            return 1
+        if j == 1:
+            return 1 / 3
+
+        return 4 * (j - 1) * (j + 4) / (3 * j * (j + 1) * (j + 2) * (j + 3))
+
+    def coeff_rock4(self):
+        '''
+        Load pre-computed coefficients of STORK4. They are based on the fourth-order orthogonal Runge-Kutta-Chebyshev (ROCK4) method.
+        Details of the coefficients can be found in https://epubs.siam.org/doi/abs/10.1137/S1064827500379549.
+        The pre-computed coefficients are based on the implementation https://www.mathworks.com/matlabcentral/fileexchange/12129-rock4.
+
+        Returns:
+            ms (`torch.FloatTensor`):
+                The degrees that coefficients were pre-computed for STORK4.
+            fpa, fpb, fpbe, recf (`torch.FloatTensor`):
+                The parameters for the finishing procedure.
+        '''
+        # Degrees
+        data = loadmat(f'{CONSTANTSFOLDER}/ms.mat')
+        ms = data['ms'][0]
+
+        # Parameters for the finishing procedure
+        data = loadmat(f'{CONSTANTSFOLDER}/fpa.mat')
+        fpa = data['fpa']
+
+        data = loadmat(f'{CONSTANTSFOLDER}/fpb.mat')
+        fpb = data['fpb']
+
+        data = loadmat(f'{CONSTANTSFOLDER}/fpbe.mat')
+        fpbe = data['fpbe']
+
+        # Parameters for the recurrence procedure
+        data = loadmat(f'{CONSTANTSFOLDER}/recf.mat')
+        recf = data['recf'][0]
+
+        return ms, fpa, fpb, fpbe, recf
+
+
+
+    def mdegr(self, mdeg1, ms):
+        '''
+        Find the optimal degree in the pre-computed degree coefficients table for the STORK4 method.
+
+        Args:
+            mdeg1 (`int`):
+                The degree to be evaluated.
+            ms (`torch.FloatTensor`):
+                The degrees that coefficients were pre-computed for STORK4.
+
+        Returns:
+            mdeg (`int`):
+                The optimal degree in the pre-computed degree coefficients table for the STORK4 method.
+            mp (`torch.FloatTensor`):
+                A pair of pointers into the coefficient tables.
+                mp[0] (`int`): The pointer which selects the degree in ms[i], such that mdeg <= ms[i].
+                mp[1] (`int`): The pointer which gives the corresponding position of a_1 in the data recf for the selected degree.
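+
+        Example (an illustrative sketch with hypothetical degree values; the real `ms` table is loaded from
+        `ms.mat` by `coeff_rock4`):
+
+        ```py
+        >>> ms = [3, 5, 9]  # hypothetical degree table
+        >>> mdeg, mp = scheduler.mdegr(4, ms)
+        >>> mdeg  # smallest tabulated degree >= 4
+        5
+        >>> int(mp[0]), int(mp[1])  # index of that degree, and the offset of a_1 in recf
+        (1, 5)
+        ```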
+        '''
+        mp = torch.zeros(2)
+        mp[1] = 1
+        mdeg = mdeg1
+        for i in range(len(ms)):
+            if (ms[i]/mdeg) >= 1:
+                mdeg = ms[i]
+                mp[0] = i
+                mp[1] = mp[1] - 1
+                break
+            else:
+                mp[1] = mp[1] + ms[i] * 2 - 1
+
+        return mdeg, mp
\ No newline at end of file

From 7462d9f3847a095d7745b07f3b04278056a70256 Mon Sep 17 00:00:00 2001
From: ZT220501
Date: Fri, 4 Jul 2025 10:54:49 -0700
Subject: [PATCH 2/6] Fixed style issues

---
 src/diffusers/schedulers/scheduling_stork.py | 11 +++-------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_stork.py b/src/diffusers/schedulers/scheduling_stork.py
index 6b8b9e2b8b79..4608506e2112 100644
--- a/src/diffusers/schedulers/scheduling_stork.py
+++ b/src/diffusers/schedulers/scheduling_stork.py
@@ -518,11 +518,6 @@ def step(
         result = (result[0].to(original_model_output_dtype),)
         return result
 
-
-
-    ####################################
-    # Main phase for the STORK methods #
-    ####################################
     def step_flow_matching_2(
         self,
         model_output: torch.Tensor,
@@ -693,7 +688,7 @@ def step_flow_matching_4(
         ci3 = t_start
 
         # Coefficients of ROCK4
-        ms, fpa, fpb, fpbe, recf = self.coeff_rock4()
+        ms, fpa, fpb, fpbe, recf = self.coeff_stork4()
         # Choose the degree that's in the precomputed table
         mdeg, mp = self.mdegr(self.s, ms)
         mz = int(mp[0])
@@ -1088,7 +1083,7 @@ def step_noise_4(
         ci3 = t_start
 
         # Coefficients of ROCK4
-        ms, fpa, fpb, fpbe, recf = self.coeff_rock4()
+        ms, fpa, fpb, fpbe, recf = self.coeff_stork4()
         # Choose the degree that's in the precomputed table
         mdeg, mp = self.mdegr(self.s, ms)
         mz = int(mp[0])
@@ -1388,7 +1383,7 @@ def b_coeff(self, j):
 
         return 4 * (j - 1) * (j + 4) / (3 * j * (j + 1) * (j + 2) * (j + 3))
 
-    def coeff_rock4(self):
+    def coeff_stork4(self):
         '''
         Load pre-computed coefficients of STORK4. They are based on the fourth-order orthogonal Runge-Kutta-Chebyshev (ROCK4) method.
        Details of the coefficients can be found in https://epubs.siam.org/doi/abs/10.1137/S1064827500379549. 
From 49faffe7aa2c11bab6c9b6583dda3014ac0f4442 Mon Sep 17 00:00:00 2001 From: ZT220501 Date: Sat, 5 Jul 2025 18:28:23 -0700 Subject: [PATCH 3/6] Add necessary tests --- src/diffusers/__init__.py | 1 + src/diffusers/schedulers/__init__.py | 2 + src/diffusers/schedulers/scheduling_stork.py | 30 ++- .../schedulers/stork_parameters/fpa.mat | Bin 0 -> 2495 bytes .../schedulers/stork_parameters/fpb.mat | Bin 0 -> 1733 bytes .../schedulers/stork_parameters/fpbe.mat | Bin 0 -> 2125 bytes .../schedulers/stork_parameters/ms.mat | Bin 0 -> 224 bytes .../schedulers/stork_parameters/recf.mat | Bin 0 -> 33309 bytes tests/schedulers/test_scheduler_stork.py | 187 ++++++++++++++++++ 9 files changed, 217 insertions(+), 3 deletions(-) create mode 100644 src/diffusers/schedulers/stork_parameters/fpa.mat create mode 100644 src/diffusers/schedulers/stork_parameters/fpb.mat create mode 100644 src/diffusers/schedulers/stork_parameters/fpbe.mat create mode 100644 src/diffusers/schedulers/stork_parameters/ms.mat create mode 100644 src/diffusers/schedulers/stork_parameters/recf.mat create mode 100644 tests/schedulers/test_scheduler_stork.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 4c383c817efe..5e16e2a1853c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -286,6 +286,7 @@ "SchedulerMixin", "SCMScheduler", "ScoreSdeVeScheduler", + "STORKScheduler", "TCDScheduler", "UnCLIPScheduler", "UniPCMultistepScheduler", diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 29052c1ba0cb..bd87ecf14e52 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -71,6 +71,7 @@ _import_structure["scheduling_sasolver"] = ["SASolverScheduler"] _import_structure["scheduling_scm"] = ["SCMScheduler"] _import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"] + _import_structure["scheduling_stork"] = ["STORKScheduler"] _import_structure["scheduling_tcd"] = ["TCDScheduler"] _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"] _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"] @@ -173,6 +174,7 @@ from .scheduling_sasolver import SASolverScheduler from .scheduling_scm import SCMScheduler from .scheduling_sde_ve import ScoreSdeVeScheduler + from .scheduling_stork import STORKScheduler from .scheduling_tcd import TCDScheduler from .scheduling_unclip import UnCLIPScheduler from .scheduling_unipc_multistep import UniPCMultistepScheduler diff --git a/src/diffusers/schedulers/scheduling_stork.py b/src/diffusers/schedulers/scheduling_stork.py index 4608506e2112..23c02dad6160 100644 --- a/src/diffusers/schedulers/scheduling_stork.py +++ b/src/diffusers/schedulers/scheduling_stork.py @@ -38,7 +38,7 @@ class STORKSchedulerOutput(BaseOutput): current_file = Path(__file__) -CONSTANTSFOLDER = f"{current_file.parent.parent}" +CONSTANTSFOLDER = f"{current_file.parent}/stork_parameters" @@ -120,6 +120,7 @@ def __init__( use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False, + set_alpha_to_one: bool = False, ): super().__init__() @@ -160,7 +161,6 @@ def __init__( self.solver_order = solver_order self.prediction_type = prediction_type - # Set the betas for noise-based models if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -172,6 +172,13 @@ def __init__( else: raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + self.alphas = 1.0 - 
self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + # Noise-based models epsilon to avoid numerical issues self.stopping_eps = stopping_eps @@ -1451,4 +1458,21 @@ def mdegr(self, mdeg1, ms): else: mp[1] = mp[1] + ms[i] * 2 - 1 - return mdeg, mp \ No newline at end of file + return mdeg, mp + + + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample diff --git a/src/diffusers/schedulers/stork_parameters/fpa.mat b/src/diffusers/schedulers/stork_parameters/fpa.mat new file mode 100644 index 0000000000000000000000000000000000000000..8ed9e3fd820eec5e9d7803c563d820f682d6e773 GIT binary patch literal 2495 zcmV;w2|)HuK~zjZLLfCRFd$7qR4ry{Y-KDUP;6mzW^ZzBIv`L(S4mDbG%O%Pa%Ew3 zWn>_4ZaN@SVRRrtaB?6ZH6SrFIy5;tH8dbHFfuhDARr(hARr(hARr(hARr(hARr(h zARr(hARr(hARr(hARr(hARr(B00000000000ZB~{0000t2><|ioGq1kIF;!e#}7%1 zbTS>u`a_H*nqo#|d1Q$S8Ke=VBvUCnQ9_h0W65$H63Nmc>#>CQ+&DzWc246cs`^R%#&sGYBlDdgP5hb~i@_+t&*8it<{;LU5)=|9t z&Qi!$${F$!PG@aimS=++Pu?HqUqDR8qv2kcPSE$6w!GZ(0t^^};F!q{upMTgyIVSd zxtSl~YSj+D^=a<1`D~!h8wV#!u%Rw|UDm$x+xaq^kvp8!2F#+X7ty%8Ld;EwB{J*p+YJ0>-uX8EQoleHZwBG0?$@V!>Wx1 zJ*^I&z7JTi+V0L>y3K+*yP^U$4;Bys=RN!^STKGzFBae z`lzZ~#nL$Hu&?aBqlKgD!+!`&kKkx?p!FwnZye3uUARcOi=&|f1GX#a(ee#{Dgo+Lt`=-?E}oH&8*92rs)ND_#Z;Tv;e zJAru?IQt27+CtAn z*@!?f?MGF!5P@zsoWg?531lSdB@C7Xnosq7d7i9WUr;=rY(}6HgA?|{hX}NHgePz~ zOdx$t+Hl~H1nLOF4NjAk&t>5%fn*+q(To#U2{bl-U?uq`nZsDOW1iGf`oOCr_2~&} z{_AuC?ZI@;7$y-YPloe>N+;0U63j5Ki3gQq4k0FL9QdGeai5!EKNOrclc!nr!49jh z*2T(waK`xP!Tu9{U~{c3Hm;@*RN0@VYNYzXdV^6^L0CV8*T*Uh778mL`c-g1Zmn|Wbqfdf zy=kRd0tZZKsoL4S9AGCteK0r7fy<3}?CJyuB+AtFeqG>z-W{~nQiuzIelM*f#kioU zd+66oG%iHMP(=%5xiH?^G<#te7qnboy|q;5!aDDOx8Epm;n*AJs>>=|klr5KR;tAX z?;B2>+I?K0EnKK%m~w&9pb`0nWc_Bb$Unll5P4>OPS-6i9I4^(_S3oWVY+uaJ(deA zDURXF*<5(*Hvqv6TsY;ZD`NDN+^=uYO=oeTs$U*8<}xuoiQ={lnX6Xm+%xI_uqrc4fx%ipXdGWZL>hgqbF#xA`4!1at;+s(INh#toN~#bSMfPeCs=&1*4+N zQ4$7pP;Skb;yk2-{4Gnfx+yyR@%Jga%!k>~Ic;#h|7kXC)cnB6NX-S;bD!N>=5wL3 zZwtNc^*z`>X6DsQy$9bUFe8*^?t#z2RQB&7_hC=J>o4L%c@Sc#h*6qq9^8u)Yxp9W3qf7#M|&(7fVGJnKFOoQ zM5n9J0p~1W?%g<_C6x)COUYJ|LK)zKf&#i$?}8Xjd}5S(7qIIIHMFo)P;3l5>Go4H zJQRB}F|#iTSmiT2D!miIr`kTkB$Y`ThDY{ zGH!#?NrhtOiCd6;fpx6fHU@0#@**?>qv3k#LAqP*O)#stM%yfV6E?P9^2`g3g7(@& zHYU_4*bu}~)~3xTe@OU(?K;-k`F#;{8;A7$RJM_2dp z5m!k`ENBBCnTJf*^$2hj+I+pph=(KWt4GD`OB{Vtf4GCzjw6X58+S09aWon_Gcw$W zqf8g4CgyYUKN>e#cdiac+uY^7)a%KAXQ0aW5vk>bcYChZ;^_74V)-6YyWd`{3MDzD ztqS|ezABAF{bM9m)HJS!HIREmtlN1cU;NfR{XIF?d{+arwhBjwcTV*8R^w>T`BLZa znK;VcW~ZV-&fWjst*D8d$GpUuQzCtN47WvH(&uzcILL_1*IPn4_>m-gAmx+=nS0J| zy#>dOtmZvN6(kN%>+X@ znj8s?^B=}how879#VC%}&Lv2uzroQ=TdGAqfg@3x$4_c=ILi6uhR)y=j+$qtb^AZ! 
J=pRmB52Y!DzlHz+ literal 0 HcmV?d00001 diff --git a/src/diffusers/schedulers/stork_parameters/fpb.mat b/src/diffusers/schedulers/stork_parameters/fpb.mat new file mode 100644 index 0000000000000000000000000000000000000000..497512970ce1bcc59a30216d3f98aeb80e9abee0 GIT binary patch literal 1733 zcmV;$20HmoK~zjZLLfCRFd$7qR4ry{Y-KDUP;6mzW^ZzBIv`L(S4mDbG%O%Pa%Ew3 zWn>_4ZaN@SVRRrtaB?6ZH6SrFIy5;tH8dbHFfuhDARr(hARr(hARr(hARr(hARr(h zARr(hARr(hARr(hARr(hARr(B00000000000ZB~{0000z1^@tfoGp+GIF)xC$6q74 zM9dRfrK{9~W>)8<(a#i-Tahf6Q;U?9BdlpEDyfi0DVNhGE!QlKtakHrTypAI4(Gh@ z>r#uo&N=U?iBTxS-aXs@`Tw5J|MUER&-4BNf5rqs9M&ZWZWDUMfBbjN`rjS7(&0*r z&?ce|?6qyEOcjjuo+HfftHAl@sNtd>6=Vo_d}c-k){9ul6Du`z=fpcQPHLEw6dxxJRKuDz zst2|Q)o|(a15M(~YKUfg%gpN35Iovnd|Rf5L~VzN!)i6$ODZmIc&mn3?Um&>Ce$$H zwrX{j9t)DJ-*Wd_urM{`?2UZ^>E~;ww^fr$1()D(zEg2eY6( zVgAs4Hw#-P9Ro)8v#`IBso z-izCLo`p77Nlyhwfq2!f*ftg}XlZZblPpXu9&(w|u#n}uc6t@f!qJCOzQfO1XxBV- zi=L~22$!XaafAl`u4^Ve%r$T;ebTAcrx(L@!&G4go%sB^JSn=o9fd@FlxNOkG%8<{ zWLo_g^FL`I2EE%+IvzAtVt9Pej>@pJKC(ylRED&NPR^=L9XNMsNW_xz4zzE6 z6Zq;)2j)4<73KJK;wL+Qs5@WSi47Z%nO~OQyvcm+V>DJ?uHO*qL*W&@ zLKKXUXkYDFEl`q}`1HHH#@8hN;t6e4CKS%IFxNHUa>tGB7Or-u&|0{pS1O=z*5K#v zBS$GbkuH`|nG`;)oGp8FokA&*6<|_J;r5UT{!65Ar3k$|+bP_Cu)APgFNISVjmrX* z6w`(mZHf z|E_=Y;;l3q2A*>o^Pw?$i0`mv8;y%si`UN&pixJbsX63LW7U`SZdQIY{t$-NZ?@Ce zFiN@T?x1mGUwdIuD2==Y$rrYAtQ4HSn#Ij=nTx6?IEBU!TVum}&eQ0W@oPv1=S|A) zD1q$hGS(jEsIr3Xgzu4^d18xBm`W3 z=kApOh`d{4cw^fDwA`5~Oj<>RFc5rp zky`Q{0|k=@|M=?`0|KW=(~wFABwzn{M_k50w)gbzsuBiBc}d;Gw+y^uPlcu2V8F=r z%QUZI2J}m1QqeUAc>J_IRvcBXrt35!27J=4!8MLy&X0Qbaoj{ZZFJ{Y>uV>p;#e%r zk1^s{bn!u~rjUURy?%R_aemU`9`h&XYtQq>w&!v83HQxXIF3Eko^Z`)z@t0R(34|p z(CO!wxW0Mm;jc9X4A{*6Uay6Fr&pftdy-Nvwu;=JAa0)D bocL-6Y||&%eN_zfdzAR;b949yC%bf9ZwyP^ literal 0 HcmV?d00001 diff --git a/src/diffusers/schedulers/stork_parameters/fpbe.mat b/src/diffusers/schedulers/stork_parameters/fpbe.mat new file mode 100644 index 0000000000000000000000000000000000000000..4c10da31c43fc11dfcf4e81d386784ed4d939a67 GIT binary patch literal 2125 zcmV-T2(tH0K~zjZLLfCRFd$7qR4ry{Y-KDUP;6mzW^ZzBIv`L(S4mDbG%O%Pa%Ew3 zWn>_4ZaN@SVRRrtaB?6ZH6SrFIy5;tH8dbHFfuhDARr(hARr(hARr(hARr(hARr(h zARr(hARr(hARr(hARr(hARr(B00000000000ZB~{0002R2LJ$goGp@jIF*SO#y7de zDC826GE$ViOIsbeEQ1K;5=u%^q?F{+Cb!W|iU=8{M2ONwxg-oL4x=<>XnfyhE0=F) z``E5EM3c@w=gj)&x8CP{)_UGQ-bD=2$S@dE|HuFM-;w-B8~wA3F(ep=PagDNDBAy! 
zW-#c>=5G7-r(ja1vpH|c1bnkjOxkyT9NaS+=Tp7M!Qj!_MDB)hm|hniIW;y0=C<0s zYqQ57f92~1y&J|rLf4TK-7pGgcAa=&?K29`oTdVv$dAG+wcLSCr9zn0$u<%O2*I{r zaxPa_2n`WOoDT6upfb68O15YO6qnZwoDLoV@3N++EVB^^XKfxHTs#8jY#cH;f??2O z$0WYZ8-`ka?+SJQVF*foHX30x4Ezc0Ekjbn;1Qt(M_N!w6$oufUQ9adl3k&w3gHssqdUIyBUr8FwR9X*eA( zXgRishL@eSMlBLF#05&qmkbWU<=Wdj3wwv)fN-yu!uugGbWGJ0BZi<&W2COBX9)Vg z?__;_Hv~I6KTI)iUX?&)^^Xxokpxaj`h1s-6a9NvIX?Xhft@AMzm0_wSf;(6*_}jS+zQ>@ z&PX78Sz)YMf#@zY>|ZJk1gbZ-Z*~v*6GL40R#&L=(AW89a-%us?NSZSvtC@@C!VuVvlp9N zexN5KKVYBuc$Q$|J^sQeG;j%hj~)X9m2;naP$EYt$Vllyvq{$lZ`So-fI&fr_&Xko z-D2%XKgYue{cRDaOnDe^OVN2>eK$Un{+6k|s~Zb~lz1=ZyO5_f@TRe%3&U*Rq{?!- z@ZY;k>^nAy`ZH4}|3I<}4O2bib^Ey3k}XZp^<3OMeN|g6lZ#asg$E)-xkz61^|SEc zqK}$~RjC;lD`OVRo?I>Blu8G9(Bv(c;lBk&HpnJB0L}P*KY})T6LTKwU{}vMGlf`$Zyd+V7M?8C56^UPF zQ@yBa68)-*cjz~h=uK>SZSkDMl#aM)vQE^$>khY+1foV*OHbT1?}z=ygF9D-48ZP3nf1>0>@qQYw^SbA4aPhNr#QgbTnY}fIjw{qX9a&taB*^-cw{1YD>jGZ>LGmrjN4(V-);2b$E@{M+*LyH@-00L%~!~jd58g1&`9A zEKavkVC5|P=wUMjMbaC;oNu6jlNB$nT~EQOQk%p%QGSaoS@q>r6jZ!l8YwKN;I<$( z&fqBpGr{H_{f{Y#`8dk&ct`==Hy9H5fC3x4{izN2DM%~1#BwO4!1Hm`pk@ID%QJf( z)aO(1E3;57=?(>g?jo~~BAuO7rm#z-quvVT9*X`DVvUO@MAY_9kd7{;;M2*ym1GG8 zEGd?4)FTRf%bSjyigNTZ3xXbt{Fd;5HJ{2T$PoU;hA*|y>859*vabUC_xZmWSYHZy z%g-sA9=Qjd_xIXMhZjJ5OkO%GB^L(cRF0kI8Fcna(h8z3LjM87;oS7wAI8s?R3odvw53<*v=y@7nMm2;&Jx DQIqsD literal 0 HcmV?d00001 diff --git a/src/diffusers/schedulers/stork_parameters/ms.mat b/src/diffusers/schedulers/stork_parameters/ms.mat new file mode 100644 index 0000000000000000000000000000000000000000..f1bb4d41acd757836f46e349dde9063c06295f68 GIT binary patch literal 224 zcmeZu4DoSvQZUssQ1EpO(M`+DN!3vZ$Vn_o%P-2cQV4Jk_w+L}(NS}aZCnTI;NHSA+ z)^JJT0M}Dvrj3F#(vnW-rZG$kuoE`s)|R$rR~I+u*O%YluwcT5ju|aedge4u>YBBz gZ{Ng~ty}kQ-tC_d;S-eR=$d?8q+=(8-4w1a0FrS(<^TWy literal 0 HcmV?d00001 diff --git a/src/diffusers/schedulers/stork_parameters/recf.mat b/src/diffusers/schedulers/stork_parameters/recf.mat new file mode 100644 index 0000000000000000000000000000000000000000..8f672cab4830dd3a4ad11642455f4145de61b979 GIT binary patch literal 33309 zcma%>Q+Fi{fNVRqZQFLzv2C5$C$>8_IyO7DZQHhOd%pW{=XvU>{y?p&wQ9=?tH}wA z5OXjx6Uz&$F_>Ec&FG1hfJQFn_D;6E#7d%?(h8jH^u(f0W=1Y%ro{Gkyu_+TF2uqP zPQ=6<#4K#Q>^!_2?8L0htQ^GuNB#Q$0+E(S`p+i^{QBjQMXw3Pjs;jf&MgpQRY6gq zmM72mVvk1`$N9_Pebe=_?JY4jS_*{G7V^Hi6D>uSM!$cXe%W%q_MCB^7C71d=$`T1 zImxFaAu)O&BUx*~qmlY%`o3BJ_(^kdZ@~ZFk&-a3u6i^6VFWx=kkp;2-e;JF7f z->~GiP6SqppQx&4xU0|KV9^Y;>w(2l>BRlY9Fca7hLcB6R+vVGRK}Z*s2z8vLLp&M zB>6k0fXM=MGYO`m^1RHFNRi@M-|2V#pSc1{8{Y0zcESnre|fDNH)0oZkZP0Fr~bNA zk`u|qA=l-aT0734D_TGrmuXsb>sQnKsu|smrWZ8ySI?fNoXLQ-Q%&iMDactR81k1e z39`oeU_T{t5N|0v%<>JsUQwG3!=C~4(Xx^)tU^VIr#`Z5CtfS!)RiBP*p6Qi_2bql zLh~f6q(iveYOq`0HS7MeX2f=lckFdfXvdjr14I~R1y7G;9S7%6X~emO2p2vGAWz+9 zBK~lpQDQ_~gt5QhC68QEgIH#HfL8+Am%rZpkD#Wxhsd%bj=%i&)GFyJ!^$=&NhG}^ z15}!SVMrj!OFx&0M0RLbgxA3z%!O7uVlDjkJ&rd%?l0MTkc_knUijNPey^m{O6cOa z7~yiyba$9BtR~{krV)$jwZdk;Kah`H?)O)Le{YNB#1}X61;Xme+nO}~En#1EsuPTn z?p)@FJ~njE^FT|rXzQ~C_3QO!XzvOFsu;2Y^wi1BQ_73t7sW;0w2Ikk`{Vs|fe9*h z^#I-XoCJi#6ll=M8N}yis!gN_qHQ%ejRC}xtN7BZbZVO|+cIIo)!0O!1 z$B2dn)lOD&L@6||Tbm0MFX7QZ2D0D~cHqt^g}K&OAcs!;BD}N_g=)9}dQ+L!y0lhd zC?w->Y(ZApkin0pltB5|`XDyY-*h_9oiWJSYEODRjl1{6$1tDV}7>S*R zsxT0ln5@#9eR&gX12mlK*gv74grO$Gd7B6cgL% zR>gJjUrCj@fs=?)b{f)AdzS+g13dLyKUOuXb*GW1cSpU+>xBg)CTi0&iPzOvvS9{? 
z{rYIK^F@6DcwmsQV_||cH3P<=U^-!ul6 zScqcvdvK$<&dVRVx*&)#SsW0wLiS4+&A7^H%tqgVS}s58*q;(S!(;f?yILIrKezha zdgjJl3t9_O?xr1tZnB{)M}d%YMUQz zXp$>5fa+?Ae4ggan4J#1d7Ge09EiDAG?oly+6B0%=^TjWVCmiu4q2t73{vsWQL=^J9Pe3Bm;FH=f0^qUYRb~RC$sXw4Js2!5sI`CK9Pst!q1_ zzuO;mnA1wcM)zMhe#gaDq_HnJdPW`cHwpHVwxJlF(HMO%jzxl%86L-U(7QP&U$e4tlrUyeg%T zQ0pzUaxv~N49!6{SBPa7_FGP~ui=11%Z!?H@@EOA;3U|QP@S6fP;qMV$j#*b$*`?i zk~*-W<~>&GQ6$UCM9`Z89jcrgccJ5hc=gI(NrD2N4mF0M+6nA--Qi0098PqXMxN|0 zM)=Stv$Vq&H+EnhcarsxOlzZiTD}AYL?|sG&VVq?eNwWbSA&n*tP)05wzZeI!1=-^ zfQP3{B_Nw0)H$}cr;}3^O-w&Ryg!KHcEm?EB(N)jyXP8 zByop~Drq{ZJqcjI9wOu-9hQ_WYsEX4-f=^%vW2{3s;GKtmW}x_%!3vxHQq($Vrv0* ziddyigOCP5{HiWGT#wZ6SU-6PtgvHAYghNx8~ZDgd5nUK<$)Z5q?}hpnGfz3AP?`o zLC=%xLx0rkY?N^f$8Mo~r{?TAf}P}Zx6&SD>9y2rmcdisyjA8fyDVrdPCS+O_Y(vf z7nM<89>EkZl@sCQn~~j3aT$p`mZ`&2np3jSHho>h51GrrpW2b5KQ_OTGVgO3P9F=t zdX~M2wtq)E(8pEyUx@bney|r2P^vASjSXerfxDzjdgMglJP*lAC?Xj>g zue+9M`16%Hqll21*A+adJhPde%NbnqJn|X89H!6Z{!jz>jtvW#PFWfX4ICV4NfI$f z)kbZ4iAktp{^9mpRhlI{tY)Y=%f!$3-BCz*Y-c0y-a$*yl^)_FwMy(yqtj_}1w}}M zbClSSs@9nxAoUuLgI$~u{$5TnE1U7_wodKW?;U@lcm);EF;e%*S#SfVA%Be-B=sDu^sZTL*? zj!@PcZ=d$$V1;h3(`UHYI_&q8hh?uZFulY78%rTZ<%@Ge~Gcz;mo0Zj>zo9UD0_l89;c5`X@F^E_wc%vkEF%lJYf_H)`cxMC= zE~gjbh8;1pN-Bu-`g{BHHL&Xx+GT=wLVGugNSk27CZ*6auZ`v^`2sG?GXLL=xXG<) z+t9d6AJV(mTgn&}74rMF6V68-5i|ya^L=%q$7*6B&)PxUzqa6{jYKgYAHM`+O_#kj zj`b{{ARKtMZc%oQyO_O>i#kad$MSS!T%z*70PVn|4uc4WSJkP=s~XI}>SV>Ut9Ahe zyOAh(w$qOV-iRU$c; zPa{=hD5vv$$lFV$4nzg1^B-rolyf%`T&P)k6(!*v>dzavJn~~!8SSzEhhl}yHADqeafuF; z#2D3HhxBKw`bFYU-;IKVkFy5Pk9TQxQz(C4E#ygVZ1o=Pg+$^F6h7PUV6L1 zWtMeB09eD9Y5TNP+0n@H0b{T?{silSTR`sXhA-Zu0QTKjeATW4#7&D5RCPT7J!4S_ z3C3+}>f?`Rtd!7W=T~x8s@_~9g+m9W{V$!$oA08Cwty7&-{c29B$b>z=kz0&=MWNq zskU1}mE9MbB#hS)suWm0r9;V99UqYHg)U+h;nAdDx+ME(d|h6f0{94zs-!c5WAf#t z9ezBT`Cc(L%EkspSh_IY%q+GEN+z|+?I9F@bL~q^@30Y`VqN)oJ#CYd!l}fX-P{ug z_!&o5+oM*1gz7v#5!Pl;Mpn9!7o{3>2bSF3 z(c$pVoCRKlxHH6bqEu2S{9FBY9_&KbDQSjKwX}NSRAnwUDu0{y`iSPSL-zzCL|}<9 za>R}>t)WJ8`OWN1F^u>|N%bVDE{~2{)&q)5#s5_F(8LuY9G{B%;*#O!^@*Z(KP_Y4 zL@`GCf7J4dDO&~Nbvhwold`-RkOKWrA5_*J3Q7DW$&wsaD^_EN2 z+hafx*_rWuq;m!x&Pvv|h?(2K!w$4R+g=$SU_rnuTlk!m-%nGgjE-<}!QjNDpi;npRTlZFB zw)&I8b69U0wUd-2s<^~G$+HdWntG2LgS29N+x6=#Uqe=jpl$1kya*&8(*O}b*c4da zt0(rY2wF_fk*Ld8>1aPwG($DYoA;N}q$$M~tc5WK$>P!NV>L58;LE((5~xxi(Y^@w z)l?QGT5=Wtv!{~&1|AO}0+k6BJ5ysSDP5|5xc8lqXV){A- zJXpPSPiGK~2Dbc<>tewGP90eDXGjN30q+*WrC^BN?AY*8gu=@CMpfCtWOfvUTg%kg z;%-%oQVY70?Di>-Q7JJ2Z*DYEV999s{5(G&`s}e@^;nyTOj+>*SnLextO1y2XbG`S zR1WJA;INT$1zPuSK#n3dz(0O``I&67A(r;Ak)*8r)zQjkr{M5Ch`S09o@s1@h6j&B z1u=!wNoFbiTSQ?HgbJ~w&53B=lD)vcW)s6yn5v_G8@H7vqYx_YcHIcY8N)ai@+3ke zn#ib-VWc0$DC>2tqRD%(pSC6$8QFdK?tipEyC|K)?{sN4c`ZL)<)y7(zPduZP(U!= zZx3|u=LFM3>-TTTK0>2wpHHHmu+~I847!O2eRt3^(pds5tj8UX_#qM-8*sEByhnXw zmyOxZT<;T+oo~4=|5Era{Ia>Fd~en5eHrf74kZA^A6b3iQ<8- z)fVnpHBZ^-gcG>YsW9uP{5_$+)l9S6!1r?o82E%;*d&?1Z(;9+58{L6JmXKDF)J+OlPHgcqiw?)9}McDpPaY zliw=J2jsb2sK)(C#i-uzl!1@g=`j>)IktZ>xe^3H=O^1_e=)*K7m z9@aHKD6Ujlo!1L%S3E(5npoFX3w(vMkA?uZU98hY3{`pN3Q>0brz+(_%w_~X^V(US zJu#1o%=;c*rE0F*hCL}TOW!k6yvM%kj_;=;5t8lsnzD3Q9HRld+E49_BXh_jg(3@~=axKb5eprOMSNYbv8b9= zUSzX@Pbv2)e5j=)Se&bl-Nvp;LM$tg>3@v<>3jmam?o!!Lq-u{%`5xW49_CBWO185 zDHPuobFd!9J8=cTV)d{@jndD{=`U-eU$&2u-mjunw{)_v>tuB#{5eYMnD4eB7U^n= z$zP$$cdF3~4-p(w6S}z_MY~r-o%oOs2*PEIL!^Bf#)g@@L|5yP+SsC1Bj?i9=96>erCmC|OGrZO9#lUID)UuI~59Q4$cz-|2s5s%v)1mR=&K!e}0~>?( z*YAi|XYgL68-b^aPMvm;e(p^!zd<_@cb3B&T9x{`7Wa{-iIFPiAxAU-G`?I)7TtA} z;Wc7<@cD^!Sor=j5R#-7Y-;}fkq}c}HHl2#gzFvD?tT*A*x8`{$Q-LJPS542P3kmB z$5QAAJ{=?UZPVO5xVnbcy)~$uRZ5wm|78k8!lS>O7Q%AreKMXZY!2}xBvA+He4ymt 
zXMpFp0_)~e7P=;YHd^hoV|!r@(dJyN@q$5>|JNa=p?ne6SuVoOv#G<^Pu%Xe#XGqsrF9-s=1#17=&u#b6N6_syFRK<4X^p|NKfrMNVM{0R1aWtSVmC0Oz zL|QMi28t9bpSB-Z{ z2ryGD^X*PzD35DpD5CTib}xeEai3s}onJM*%@J?>CTb{?r{Is-El*q32udnRDZp$% z%keFmZI)Ubd~+Sp#XFgIXik~r@X)~HUm++CUV?T1>WibYec;uS( zJ{RV@7>bEzYkGY>Qo|R6SCTUD3NCI3fimOX=mOyEpTsoGKg%%?vsAqra#;lDW)n>e z*0}XfI5H)2hQLri%pbROcEo4OU(dJ9pw~$K=m|0`1YJX9;PX+b)*5i+p4Fc7Z-KVZ{CEEBT#Pq4BJu zHC<-|$Ox$NB%5qu41o#HZgIP6yG?bwN)P(kFBMlO^VG)x0I++>1d&Y}Ad%f;Qrg6k zvBCoc!5JXk-(BaX*1pftlGe%N@v&gEvIZ@&ODT41*Eak+)@z|}i%J$F0r0(HsI{F+ zFqBUV_Ubbn?Tv;=(?5JLV?t$dmVJPRs-l?YtHmZnBrOo4@3`nikPUXQ9lxDI5%%t4 zGgsO`y?7^+E#NASPuu;mWt>GFOh~)7Cs8%*L8naDHjOY)&PUy$EGCxjPR?LAjF26I zphaR*2y*gxlU_}s2vu)_CX8KJ)m?L-3EqN5DYj+3t6K1Q|GeB+mZp;vY`T0BoIuKuvwxkfWY2LQ3hy2cZ4t=9fw={K_%QG!EJ?m1C#EB08iGV&6Pt) zORY^OFSJdFY{i4&yB&nQ1Q+d(a22FxE*oouRJ_2U-L!HraUyE=@AHi{>^`N%1^V;^KSv+g036Z_7NV%w@d;C$ zG8*z_LzQ-k1acN|@9wFSCTHIJWA<-!{!hU8Fb#pKkLGHQ6E8PnVfYEGj&yM?%4_F> zJWGHwhHJ5CcJVG+hA} zQp)ZLJOGtrGY^|qtuAf#kb*GJxx)g$$ysdj(^G_%E5zn7ahnDISP5}Cg=#NbBJ%el z!!h6S>n4^W9#AOeGJCcXLfv_C{DT$u`aJb2#s6C&_L$nMf+oXBk^f$8~c%IrUrh^x$oUQ8)qQ5CZ z=SvT>RE5jF*5?UAuYhZuI*Ea}#knu`Xp#PT=b7GIVSVTfOR7R>y1)EAC~px7t~AZ% zeb8f?KRM<#+!|^&HnlS4&|R8_*8fP zWZPe%6f$F$(sP^>kxFVUe?tZ8k*0aLW93OiBGqUObX5M9nkFS9*VC?vJPU4U$P_(g z3+49YynKD-e00}Tn8+7v`*EPK(z_NQH2qp~cQKhq7<*kNbUq_1gh$7d#Nnv+&y!ZhfW*Dv#yE*GE|$szF?Hsl`dVZ#d34&BUeWSy zkpsNM{a_+(n*Xg5Yf|3+g??=nuB;Hj^Pa9t`BLF;17&Q!$Zq zn>{{+hwrUQfnJ5V+b`aAa@Ce!MAS*Bs45|0yWS)5Z%rFeC~sS|_z^~(F$uU4bGJi< z6}bh$T(P^9_HE8nt_Jh~ui7$)GitlZcO=51kN!aGekOeO$RdBuT)C|SsQL$7)$~2( z%eA-bbcI$TK;>EjRxwv-EGwr#-8#}6neEk{m5GPTz6C5(jnJU{KdXqIVP*gWF8?q8 z!!3Yuo~3Zm(J}QxqQ$4tk`Y%@>j>=b&mlw(Jd4ZKhlMmTIg>a9SuJ z1L}&2eU$%_(kSN4;Z>B3wk&LJb@@d~nXB2@Rbp~8R#kC;2%dNPl|92YfurQ1f44D4 z`|D!sYR@eieBUEi@Ju<2_g!JF~wM#DTKns5TVdsSzHq)Wtw(A2op2R7gkvj^< zC%e9)@^Q0Rvix2;?aX~Pa69XxZ8UFEwSZH9!K~dsuD;dYq>Br^FBVY_9c88MY<%;V_@Rxz0~-c(g7?sWUd36%!SW5Kp4lajI5uYYD8eij5Csy zx~18x0sy*?t%$(P|9Y)_U!OW>cJ=5fJuusZRBH4(^jPN}Fr$Q&bkgziaCJHqZ@OUJ zQ$sjNV0KHwhP4W#S zv4QW23{IUL=&Y0T*4rB=yIfkmThim628~q{HF;lE$w}X zpy{6G+2NsR`ke%|6W#0pFwBW(y;9fmrHK}QwD319`~SGT_$?3>bW5Ro#Vv9}n!C@E zIxYI5JUp*Oyt>0@Ejwy>+TKLv{N0D6`eDrj4&C}I zx*+-Qc~_~WZ4}xuZ;)=z70fD>5D_=%~s@>TN6bU}rQZQV#OL<#JE^|CerukT9 zf|`>R1b~vP&ZiQ4=Ro$_+ziub4CT4}5^j02 zxx5+sn-b}g!0#NP<7HXgLn4pvZN_~V+w?h(6xmZ|-FKf>-(0G}2~0|X=u?M;)fmG9 zs$e2n5BpAOr~{c{!kDO>RtDqZ_sXbPoWw*~By;cm52$qpHWo-Js;3q#c5l+F>ZGza z7|80KnJC)KXK=oxhfwEikoJ33mFRe^;r_sJXaxAIBv{YNqTP`Gcg%hyLQ26c0Ubv3 z(-5Hunu!$~*1OZGWPEQVCh93J8Q_DF$eqr=$SDXWWPM*rLLp=0$_EP%sSK>7vZ8Yh zeYwLFfm2?+qmbB1O@aMD`dXX{8j=CPDEF~;yKF85-V=B>(NCu#hCr*2Bs$zqs5Yd3 zQKwP5GpD^M0vr@?m|Zo?!*ESogRecr8g|`NTiZbvSMW}YDOKXyZu**`PyQqx9Ki8i z6#-ajPrt@9b`TUrQY5qy>)=qp(TU6{=&+!vOUvC4qDW+Wn4L~T$wUoRTbeQ}yV;#q zjSq9qwLkAhVB_O=Zh_V%gwXFUk2~E{H`G-nJ)Q{)+k3msn96Jx_c)AhwZmk#ila#) zw008ykKu#&8ozQVZX9^^c%e{>=1gyfv`BJ#av@{Hf>n8mZ1{N2E?h2SotlS<5~q;s zoqWC=Y7%yd5a}KMiq6~FLDNMr2p|N&x+{tZm?hLJxD|v*Txk!GkYrf-9L!HMCUv&J zIk{2}^n3TmX!T$HfO&r>b0x5BmgHNM&%;r?`C-$x70Af{k(WLlP@QR6Ce_qqLCssu zSX+s6&4`<+57T!9u2rwihy;1uFAs4F&mpL1uuk~XAsD&nKfFdEa{sNJWi2J!LxR#TU*zoYC%E#2N&X*naeFKv8#mbUIH8CT`b@2N3^pUk{`wuHIA zO*2-9rO4nkH7`bBs+h?%^&ecrjvctQyq^F*Nsi1tzUt^yf?AS5r z9cwho4#YnsRZ%hn$2;>1{y>`7IvuLU(&$iFoso1^WppV)pO%k+Ujz;3wqe?e)P`NO zKg)K~lIa$;%BBJkfbWWCfmX?|jzTE`qrkxMz^p0nu+!Zcj|L`_A}Ez5kgx}w-k_ZcD*PX9r$n(&WTv1`^@)5dPAxC$tp7rSiq_tD>=cegeZ4TRk6 zO-t6sIEf#AAH_il{I>_vUXOX4Xo_+7C=;;!@rj=IYl!}o^ga5QyxXH)aptmnvL^jo zawF@3#{to~C&u3E>9wt5^;&hO0l;(jZ?jJNO%+ddo%W2#f&fnx7IjP)9nhB4F;z!< z4Qol1bcl*Ssq1rIYKN&aJ{9d#{YL6b(H7ve{YMEwFbMSg7d-U57&1d6Td`!Foj_)9 
z$`?ZE0#c3#QKtk%JQ$*2G|GuG;yVSRR*#67`Sxe;W;t4|00!eJUm*4?3th}m>&FbK zG_;NLE%|AzXHXlJ8zzH#d&-FLWL?ozmYCwKb!7{OLgVJ?C#Ke)T@L#|iY=hyQACjbcMGeS zJ>Z4jp+lhCRXQ2^Y;ix%)ctkq1d+Q#PGZCOAEJJ<`SC>v^r}jt!+xtUTC{m)>?Qv- zcwx!CTRp&b)?qKoxScQP^B<+7t^}WX2~bs4K#y2YW7UW6ADS(&NVCC7%(Clxi;Q|D zBqo88Yuye-nXTfJvc3Rs0J4Pf*q^LF@63t0M}q0|1owV^N=hnApeLWa4ykKoW*{3>#H_@evcKyprxqUOGD z8RN2e{$i1+P^p8QtQ}$ z1tbk5V3_gBr`J*R9Ib6ggcYoCj*jpKaE|!Jo;XK=gwy&U?)%(Ko@%o~LBZCJGfHla zJ$30Sxv!kKEGr_{*x;liAH7xgkm<4sDRAA54u!!Lzt!iK(5^}_C{|z@<)c-~7cC0> z#^5&-S-(h^ap>gRT&4-PKf0D>Z8NTCnj#1rxJo%^FxwO8`H6kkyhfh6*>zVo>dQa= zB#DD`XMRLKzks0!N-bDhAhj2!&D4c>riVP(99J+-w-u`JM;5$Ah6G!c?(7yChS_Ed zr4Ng|*HsAS;L)s-|D(-s%0PfiMGvzXN7wtVeQ|Ce#N&{>fu4~D^RV;zap#<7x*WUmkP{A(F1xe_L;NRNU0iLO-+Ry? zVX{buh;9rSZ2qmpoNVHurDloWausq?%pst3H$D9R?SH_G{B*~&;sgIOPk-EP%Mrp- zC0D7=fXW5&+RfFf)64e!IIapcUu%0=`IWchJ!F1XSa0(TY`qUW8R&GizP}gz;a-BkvHmo(9`ikw)sPzh zx(EzVqH4&j^kWBjDK|Q#K^OWw-*yfXCsi_53=N|yYj<$dcv(UdHMV;*t{~UB%gwXQ zyI?Q>b2*F86$9_-`l6n20jF3n$_|MnKUHsj^;zSMe4c!uOU#lUC1jbL>|lwk%!1g; zvmsyI@y;ZEwm>JWqd-`Zk-|pPIe!Y~4^>Z{TDBvgASrHg>a25vKc1Gqk)9aU+NrsI z^khS5v6nI%ye2(h?$K~`Kt_1sHG>`+Hr-I*6~HYPVQA1Oh4W7=C12QRM)=O`D+WTSm->+ zlet*y8wo5YyR4*#?DTga&bYRc>URQ@iMZUCL`w*Or`5NzXzI~Sh50aw}~h zNr%23VY%XmEVCnlWXUCam1QcQL=agABSN%jc)gT-I>#T|qNsseRZ;5Pwa)uYFLEDO z&2*Y>wxM@#8MYmy;y$a5cdXEH%+)KmyFh&h5aWzw{OoP*g6Uc2ySo}zX!_B&snOg& z*3C&%PXkT0%pLXeI|I1oXNuc{Y!|hb?_N>5jR@6{T5`(r8N{+!kL)}aQBxIMX9P&r z1H?TWVN=&`%#MxSZl|7ncAeuG3u`t=Rxze6@1qLgKCt(ZxeX%(zF?|Q$&A>zp~Xsx z1ihs{Je+Uccvtm!*+N(D!BN$-!yCL^wO)jG*CK^vV&?PO*Jr+^=-+Z_oMd?t3BUvP zEu{3<1M~w*(2Nz0Uh@;;&k%crZvIp7(nbI>H!FBtleRiy~d!^hLP4e_Vqg1 zX|&k=c9p%YB=K~vpwi7AJW&Yz-iu0MDIB{Hk=|K!FyKnSHKNcVM((r<#j#IPKSEe& zo24?@_nJx4A>d1h=Ll?(#Z+H9P48zAf@gVXsqD!FyKIA*ug^_bc`QYTkOsHiJNNNN ze`4#Se-CYjPpAdQ0 zHYYWh`bTfT1_ZwI>%nzx5E_Ne@nzL#2B($g&Zy>1la!Z0L-w8(XZG&x{dU0#bvvrO zd+hJ7Q(;QO+mJjtV+&9#iN(Trb#i_`2^7Y&C*~ZrFrB$IemT=KJ{Y5rD^9lCRK%o@ zUuG%^T*ocWcLp1cFU`wHw@mEmaf2tgW7egx5y5u?W zm4ycv;zvoUSQDR@%k{pDz+P094&UQhymn0sU zVf+gZ%<*CWNzXu^~y0Z+FpHbLNst>&cRF5 zM-BxCO_)k{v3=O_umut-$P^qJbhiY zb9xFoPORSkj0B-Q%)4uPlWk}w?`ueu1Ogf!ryj}QEH*M-)%w_fqLu%C!AVDHfINZ*JK{kq0qba&AxjIlUVGv~L8!}JsfrhFcnEU9O)BC+SjSz%DVZS7W* z(|}SLN2!e)_``8iP}7vb4A(5=lV+Ju;0d3jUx@a?cTBU^M%&v~N|%1TJ-P?PciEl4#96(wdR_TsnG!FcV^D2Ys@-S(wZRg^*0{lTqgHc%6b!f#t1PYPH;( zkv-bBrIwu4!b!`4@znbuz3Qo`j;47K!M}eI@efE2Cj@#7lwQ@o?v%8P# z6`^M@qRM$3VncThI##Tn&}!LIi}+`S2SzR=FM@)zl(vnVI8FHoHsGL_NyD-Xy%fnV zgOLdEJlF-TKe+P)Mpl{+DG2JRRQF92vzYJvV*Z;-ytKu|3I0pmdmqx>w{O1i_^p(O zZ0y&;2fX`$^OCl|QO0#6>YSw;p6Se8n|V#?LH zB)f@hf^wW#pGVAfX=!@Of+x`egu|Tt#`OH756Pd5r-73UWgnZa+OMlWFOL8mtjk#) z3!lYIhz_i_*g6N0h7A+(CkD)D)8j_JwRXde?RjYh960Ni9~XfKcVXHZ_@7;BxUBnR zCRN7v-RVBr5b;kSetP9}uB=;MK$>j{+TmKO)i^x^qf6goaI|3{3m%Iiq9Agqdmq@( zbtOQE)0zQ-lYzh~U?p;@%^kH>ppHSI>AL}rP&Wc#aCrR?C(6RfD;q&9o|b{*Qg7s6 z`(sdzj6#rsrLg2V*|)&k^AmzfD3{(vt`s81uYU{e|!1ksg;ygF-UG zpLBgS*;g(jvXF#{QeqZU|1#p_;BqO|Fvvps{gLG6$+2QVq*$Mo0Bq%(U90&_-vGN$ zV6o*uCzV^o&4v&gsq>cJzdte(!|#0PuHS#*MaKFAGM*q~2iX zm`3t21%IO8>DV%9E8WPe5;Hj$BP*q*=}&6$S95s;@8o9vCdKm@^wFp0)^Q^;?2_Bi{L`8Ee*6mO%qhqN(_v+fZlO z7rH-NO##O%6m%SZkD_i@a*qw-c+lZ8_$kVSgX6E`qV=53QdU0hHA7C&R;oV|jm|(4 ztpuOCd0T8;JeU0gOk0%yh#JIX*1&ttrp&UiE96Y46eoEj^byvHn(xSlfsM5l^&i66 zeT{_KdUbg>rS9iKv!GpUH`IAnb%@KpmsDh96mCeiE2>Y|_Gjd;Y>K)f6Zof@XP`Eo zbsSbPe-peoE@N2bsMB=a$Ya|HUh42XG53>1C`a#6es}esWkTG~EA&iBO>RSAs>7!( z4lnjsPUy)ZL9J8pVJg1V$bMl!?aVY{t=96lJZdSA$H-4-s$UyBlP^ zGyZR!yEIR~@Bn65w#M2ikV*>!Z!xkqKc6?9#a#<1W|J~Ogd+G`rbHM1o?#Rm&z2SX zxBlKD{}RTOE63nf!24k>J{lB8t2di%nPuMY_-S6qgvYOOY0zLLaT%fR^`X< 
z$0-RSv%_{jSo-=a`R2MAvi`Q=F?{NV0^IaNzf!qpra!{tV6hni{MCH(QAlk7v^W&} z6c&7Qx8eOIY$$NYmW;_)A9UX&Zg3dJiO2WHV?GxR{9i319Rqb8i9Pb8FlYSMYMI=%_}lpjI%}%l+rE(iK2m(3+oG3yj>K+LH) zam00(s#ug07;Zn8z4~GnsF(R30i*+G&>GX{j2Zy`qOSdVFWGUh&v}gSI0ho6hYV8L zD9_%|8rF!Px$~{T@wXP8EtQsn7v7cK-XG1gzt&ahXxF!c=G$MBeOeLDGz(iWC1Y%M zLW3juRG+mv7nNukD~ViJg8Z@LH|cFDFIh?py?%{cmiSMyjha5Rj{Wr~Z@BMk2jRCx z)d&~v{WKeO=~$PJ;ZHdHD2J}0{`D^%+CwnGUb>?^G|t}sb`g&EYcNXFx;<+*=h30_ z5pfb=Xs32DUr%-S$d^b`Na5YHLA{DUPY*hOSKp7zn&ksQ_Ni)?bC%>q1E*{?f!v`4 zshEV4AsB`Su3K*v=Aq2CKTdCO_B^B2G>OHK!`#3#k!@40sN(rvvZ2mm?p42wfq*k! z?4%w0wM;|EnwD6z?`HW{iwBw*s%q zQ^-gspu?ng#X6>#aDUZ zJKGp&OLAQg5}M$r7UwM#2lOnJW1URFmEIenM`zQ$6F;{i(ZMZnPQR!+I5mW7?(SRY zwbwCkR7!{!(2HkO3;z0Ep^{-JEL)xSUo#FFAkIn+1lo(KhG{nyMiC>>@Xuyu8Ei_l#4LCgA1Lg5s<9|S!^v*u80LA z$6Cten|;{DLcryl-{blj>>7A`coAp>KLp|s1pUuUYP!d zPJ<(w9W1eAbyi=fKSYPxswbhJ^|3T=w$-WyDFj*Be+#i3IYR&3+)Z32yHPu!MROox zkMKpf1x$Hg*6P1O?HFHiH(s5Hc1XDy>bI?Uo8u5{sXg<&Yx5nBHAB8&&+XmYR}7Vr z^R;#m>=VWn7wVVnyK1EqY@g&LE%{52yl}(dWxKn*+O!T_Jkt~8I4%s&th%)pbz1N{ zBY|ZEplpbxu8CUnGOjJf1R#Gkrw+3|TM}*ab#f*LgdyzkGnQ@??vCo6eQ?3!$2bBE zWNIb-J@+5>ci??^PmPVGPq6=o=X-F zB!i{Dwz#EVZR&fs8f?0dl7&T>I}7C79zCz}{rz|bu%St#2+Y0UUj2RjEw;9Yhi0sL zrzq-m>Daj?^Us$9uC6*{!=Uzx1@GF=;R^(lzI*fQ1w$ePIuv=PXZVtI?Hqf<{katu6hEEtfer>I5E|*RRS?mG zY4FTAy+OE$ozp6m`;VZk>3Yup060L$zhL7^UWdHL6$jV}ix@^kuyN_YE1N5~9I%d0 zQTlo&8-~ezxk5e27~+^Uh}A~RAw&(_Nkq`A~BCv3RzIkfY`Y`i0aecpIf-r%CvSWaoI)Dgan z6&D%3T%0Si_s!^a#JOcRO;nU8@Y7?XaDv7meD`%@`Crz9qU^R;mfE3c>{xucC~1ft zM-FW|XOIJD(Ufa%D8`uP8_BvG^cY8FvV{Ixm_pU@;oGmvpJ61DR7n&t$6tlY*i1q_ z(p9B*vf6_K&|Tb|_7YJmWoKV0y0|v_6Fvh z8-~>2X%Rj;8N>G|P4P?oQtVF650s_vz$$jwJ;KPt`|fSC3-fru_;J zLMV8$WXFe$z){SYa%)l_QDCx0-v8a_QA}q9uU_orahE7jJab?StW(oTD;FpbnRvw) zaD5CkzjC2TIV!?xwq|FKj$vh4=CNL5Dn6$;sn)BGqm$yR&h??29ZxI`$qvA#kFBiXTmSI!YFF*q+T_>o`J%KW%99 z=gV~Tyq>ZYny13;v#?!%18=>=*81{VTM%F5?X~|)N7;+2A+EA5mW16Z57^H@(%S>F zhDNpspq-%~In4lHY+!MYqb_3`O z_-BEE*E*ROFQnPR;q+AlRV5~NRy>*xF0@6R)3bXoY?;s~P4;SeVT&9w!JevcCRTgy z(K+?O7Ot9Ck83<+BEmy|)rk>Xv?)fMu=>D6l_-(hIBAQ;@4|JW(@caHn&>1ip&_X2 zs7`GbE}HpdrEQ`iH90Hs&ruc{4%{^zRHk98%bOVC2p0I_ z*o{S6G&H>8vZb2$~hP`B_>oZ=?Ij7hExA7etk{diu?r5SReOKnAUC+u>xF=b`nVI1uX}j!imbhs1ERU#h=3IFPb(vVmZSg|Lc4mo{=Cf9lDV(j&aw z4cq>t&4`Q1ZOfYS{Oll}9{Zr>EEfYSmG${A+d=YyFY$057yFH)f>m;Ox%yVN)^~)9 zG!>s+kuU7vS^MD1tlR{Qf+I(q2JO&$$ny!?a{>`Puaw@KXhKe^vpgnD2gSdZ&Yso@ zh2dCa%yt2N7(AK_*D=gN^P&Prf!`3v&RmUq6aSy)nYC>DT*`_@D9Sl%T{= zOb|^J?fz_8k7MBruKwkw7>OCXDfXiYNe*3#DPPTSI4v*Q`*0if{3y;nAz*=fK_Y>x z;yyq-^ttQ#6$HG{ARoQ`v;!BzyJas=5OB|u@_t`@C!%B9Pc}Wb#M;oT`B8L1UV6>G zkHJ>xJNxErMo%}4k21EX>WhypIz9+}5Ch%gCl{lbrEk46w4qu1-9Ny3P) z{IO+{qv&5ddTXUC39UOX&%F;B#ifk4U#jsW2+dx)tu!=>+wples~(e(@w!1QRBH^I zHy=0Y{X~L9;i-4>Nn;2@(Z%xLB;03c&Yok8;oRzTJDVlR;2*fQ>zVpET;SfKxu1;A z)$a}TBE}K^<)XI&nT#H1Lw=*?ag?rjq@M3h1~GPb6Q3v@8`f0c{TxR|y3aaBsWlyx zz*U1?*<_fCuct&L&=Dz{yt1TM@pBoge)fNMeZ;(9R(4isBTxZPtQvQ8(#SdN09lDD3bI~lkx zXQfoMkAlyUVmYo$d7P6L@^vt#;HqLN;qe|O_SPRK zPoKj*3e+q%*86W}VWW@%-)bn{&f06s*hk z4HLV>LQ^WCXHPc;QwPc&H|DXxczG>$;}`|r!^?AbHL|cyDxoa@D+ON?QU?g*EL4c^ zCH2ix(B|VTdU**ODg!svDwpwcYuowm2^lsRyMZY8ZhNO?{W6SlMifEF1uDzPNZW0$GeoDOg1uzq-hriD#Yy+ z44Qcy^jkfq`^|;QyElo&#%noPEcDTJK21gCkH_PoN*t8SK7L+zmWr!~R?lA2J+ zFFFog4tWfTKaMw4Gz@UJmi5Cz@EPj(EjGSOGK6dDznY!a^(b7`xXg#g-2y|Os^Ona zP`v)rg!9=LHwD*rzB}3meP(0n>W?P8`zjx~^6Cc&RC_nF%1oiCRCDcgbq9*JOtd@= zGK0N{Y%VRK6Pe0p%|>eGFwB3cci5l{`^OSZhx*LHT{`99^SK+P=1zGR&RKxkA2EK{ zt`|a*vI-ST2vC0j;*#fm*p@&D{u@O=W14ivWubojv6yJQGEcy%c15~`!2pgLY}PyG zX$kd@Gle9NLFf#fJnYtKiCW659eqJVIFQyH?W}Hv8n+Q2$uM3o)}4C8cf$&%gFmqA z;0PiVsI?>GJkIY-9Cs}mLCM~GCDHq=vF=iflkWOaoRLmD;1XsH^M(068WE$YA219Z 
zYp{lDc7xB$(NQpSM^>H?uz_Y!PrHup7gsj-3Qn;DsT&2jX7SvVc@otG278T5xSiRY}mHSQ}PA(<7qhkTUlgg^tZA!RR{|q2)=0o7KH* zP4DSA@;X;YH=KyK511Y?f(&T5SGyliBI3gD!ICjB;Fge6S9za^zr@!^SNJm!RN*XT z{2visA?>xI4;ko?cymnV4G{~K-|uYbW8hC+VU|e`5kkioZkP!&A-}MF{vDHu!t+GQ zC;OSW`f%^yvuPq^+)a9XU6}B1=KC$ZoCG%6s6H*83Bet{wqGSkusHB(zd;2P!Y6YA z^fr@lp;1>ZqMr$C5k2yHWfIC{SIeK^WAS)Ra}ML>Tqyt4{iWMkQ1xC{>~BPZ>?g75 zljbaxl-?a&P9&k_bI6f=PZsXa|2eqVi3FAA-h*$i@%D$ASN=Oe!gnDjVMY-PlbbG1 zpFcywnas!Wf7@7ykUQd-5=6p;!sl8AE(?8XvCB`ykT7keyWCZP4d2;{epUhrkJPD( z4f1SoF6pPprIO&)D0qepHo~NwYSv|w;JjA5T*{G+T#iS{n};Om+GV@Sp6Bh4yiuve z%Uyv5Qm$758;8!O&-+!AV6){aO)QU%*kqH;Gp|U9q4a79*Rr9!o0xa7g@mS_qu)Hc z*tq#oN2R@;1QWZg?FYZG5o%~|t@W#4jjhSr(j4qw_CYg- zNrL5uE7q~QIoK2^*>mPA34X8}IcC7Y>~Q9%L%&E^>zFIv!{h1>S3Psi90^~}?a%zp z>&<6(Zm0_7BcrB=5z-dMfi+*`&c6a=+%o;N-z|yP!^1b2pAjQNHRb;Hjt3l+Rl8sQ zyp{}CBl~HEYTkN;rHRo-G7L|-EcJiO!CfQig-SUxROSM{^$l|%?`eAG;0`i=GNaYb zOmUDqL-^jOLdHkms8)LcE>tfoUA@oqjM4IxscN2|zPRS<_)wdSUso&_UaE32Qa;u4 z+mMVeZz2@#m~fH4Xg!!qAfr+4Mcj=eT;wM$1UFI0DDinN)a%bh^!Dy27Y~zh&rnFS zk=NU`)4JI`C&+Lz>=P|5D<3pFrD^IF--IWH5d#lq73RAdy;r=uIveQ{(Cv#N8&4_hh^uA&p`WTTHr}Qt$>A~<`QrbhIXDEwK4moYB4+isC*gl(j{O`ER zVzave41@UQiNBg4=OUy3-q8@Je?Q=sIk%x%A=smCzY)~Vi+uIC_5pFbd;1mVjd1F9 zBfGDr1DlFp)&49t#>gByMCe8*{OX%`CJ{|QAN@+F8g;?_8T+FI!vy5Ovv*1cx=~|E zm{#&OMPR+C=1#|6{MldXPT`tjec^apXKf#n%}?BGIbw#ZUe^5iBK??2gU_92Gx$nw zX$m$Tz}s=j+s}8IV^+q+QowT%Q_)oG?{Vhv<2#n|Icx|jZVuwCL33O!+Li1YHH_*h z&L15W3v|5v85eM91UF(O*DMC=WS#uB)@aV1PfhYvU;9$TD62O9?o- zd(G0Rs8K{8QwjK^M?ghi*5kYMQN;atwrMJu!1GH_F$4WEDC~S&QdCZW!|nBf8tG%O zFF$F1l1+f=h{x=wFJo|S>L$2su*CJ@MTwKz<0!kGuB?2?cuW4K!$F`eWR}EaS zL_?>>rkM6|h%IOrSLay5le&Wc$~rnu3shfEXthLwKtgJW9UZNk-HL*LSfXa{vC70- zbi7U9)_Pyk3SYzKbf-Sj(YU~E->${Wt-`vq??o8s+E0*?J#2;T|D7^6H)5bkKB_i5 z#0tlc>@155WFWm`>+|MJE0k%c{VphCz~o}SUO}}LlHJIRjzI>hpY^`A@3umWl-$bo zqD;_U4j*Iwu!5kX7c)bPiGAu%GhIcjk*Az4bH<&CO?oMe;ceEaFq}v$O=QAb;bYV( zU2F8#y8Y)<&BV4XFP@0ftWiaib$&I%MDvYTjlEvhIB9ftPYpi{-R67Mx}&TS<#{^p z^ez_m8vV^%m1d32(|3x;tXbIOcwZ)|$QqT0Cr9@Bv7kc^p}W>uL-U=Pk9!i2q=hIvVlX4O!Mz!Y+Tq- zmNa2*1J}V@6^nsvyjD<(AGYV^xN>^C#sAni9PPQs!@~x`TV`ySg}mPXq})Hw-v&Y& z*{gbAv$5?)(^lhX8)SZK&r|GYb0-d{UQy%npKWkl-e`EzhXe8*4$F&S1KG~;>^(dmju+K^5x!#U0hqA46+ zKN2`*J8uKON$u*J`5e4YGWze|QX<52u9BO0-1fh0&V?WmJ9plCT+zmX!Aa2+A5kJ? 
zcTL{%806sFuJ|K2RuhrOteN9}lCazxlMJ2P2ZdE8#!zkF;v5mRz&ZkwocQRg?&$FEF8332G~dt)vR3hi;%+e-v# zKDRy1j*D=okb9{Ih_IYl((w8e7Y!+QyCk)V@QmE#?H9tuZpqDcE&4nUv9Vgw!sB|- zsUs~fOo>PrnLIU<%Z2(^FXg4yym?|KUhb*pVp4j`)`wIgJaQdY7k%JDVcXd9*G@#- z>03Eg%;F+(u;A25S0ZkSA5RVAo50TDNyRdrNBu}Qe(fteft8^Qvx@UX=(yNQ_8*$S zo~rww62pk-7`ppi?#Kk}9wnW1y+TCbpwZEy$O+6lXmnVl5K$1+?ZtXDfoJChTTk30 zg3)guZ_v&2_ceAJAIpe%+?nuwm%taaSN8TAg{nf<@TYZ+pgNBJJLu}E7mEGZRJLcU z2I38w(npMQU^{%Vp`iC5;{TU>L_F*G`go?xg zVmgx)Qtl)gz9PAl)7SuJFQl3BCa zk)dvAe(T6P6k!DO29ieZ;a+^c{%g47s}b%bJ+YM$AI%1EF5CZ=iGT?tH&nGh^clpKJ3E%~+neBg_(y>^QA5b{U2~h1Yl4B_ zhbA*F4dYh5t5X@r1lb}%nN>O?aK0|=mA1gYRs|jTYzmt+iuHNcC~TST4}`?buysjUotDuUJ|^rB8q_sI zves-D<@Oj{G{1JN_BMlef#%h4`pT@D?TlrP6xH(3L#y1ex({Y0+TyM0`9K}MB5^hdB?vE@z>Sk*W zZy)jRsTp+afR^rzzd6j8IVP;{p@XkA;F?s5Irg8uwQY?A1B3h`8iP;Fu^{vzztD^U z*(!78qaVz%TJWEsKo|pwVK1(iaLr+-`t!2MV+L*tD~^vWw?Kn>-{j>n21dV&9^~I- z0i~@c+qe=;v~vrrRSsBSzEvyr5}24>_PWf&$^ruK6n>reWFod`uKT5{1@1`wYkQl* z#KfmHI)=O)9*#N}kXy^di${ThiwPE3-K#My!ek<7(TDUW#{!MVWHL62u;AajqfM{U z0uQ}S1V>d_a2{E!JNSn8etm|gHeSirw)r?6BS z3$r6H8kR3uK#Wm&+Tgnse9Pq?shOZY({|BWdkcs0vl8LLdE(t0$!L(CyKbU z@hJMNR;?=mJ$pkZWg>VSS13&CKEvyy^Z$MHOJ_s)Z?^1>Py*=E0m^HhumNvh!7o<` z;3u{x2)45EbNfY$A1MSZS^sIL!zdeVf|dpMatN>=`Y|Lt%SMCV`TAwW1O!j3K~apy zZE0nG=_&%aR=#PE=74ifSiMiT2S zp%J)19UI}`&fVo^_vI`R{l0GNqaPduiGEELSF(h_Zr=y;OL^Sy>3UC7v&0x>U6%c7 zE^ekvzYEo}MD9&{q75$>4p*FOxoltw$yIKJ;%dCy_!;OLVqpoU$vW42`dnOpbgt?M z*%Ik5zsX3DxCq%e6Qt&7iBN%Kx80q&VB3X?@*TH?Z2ZJ@i7ywm&Fb0hrz}ypsjYb? zjEkeG=Hy3ymiTXHD!W~lfcX1$ho4qiL3@jYmXRuf@@uAE6Ys3Ra+r008U)@gkGB5RX@xW1@9vxL zByjsc^2_fdR`4m`!SXpoV2)elFOlz77;?{3zm`D2{FQ#1rKmNwPUmhjyhUKeY+uny zX=|+7ud(pVTLOxuyZcm?t+95I)%*_v0=|Lk`=t!6aelDNF+y5^jqC13c5Ji8jeK|E zRuci7?M{7bP*g)=Nbgil=_ce2#ha|9hQn=BL4e?pHvU~(N%S#t5m{r}p8G);Qo)pG zw<1cAU#!UBe$~e#zi3fOi6^kq+j>#0(GaV7Y8%CDUtw>^o=lk_V~A}R-`Y9(8bV{L z&$Gr%Q2%yRSIDygOItO!xKjaz^M%fd$xR3!-JyKIfC7e})`aV`7WmV9Zq`&&@P6O> zQ(_m|5W8`iO*WMZi8tTZyfg2>UER@v_%bRiy^f#C@Be_}w<|pdSI`jL)EINkwG%6M z5j8atG{pMb-D#};gjHevD(gFGSQGYX%{r+r=+|p`_Nme_DqLV2ZP^WlVW07w5IPo# zoR&!o?16b}%vcw`Bip)!1Vw@`c^;`T_leS{fuYBYr6(GTVw zk;v7@&9F?A{&GWUKd#)ZHSf4@hB+5SD*RUsz_YZ+#%$OOXB1uOO|b(=y0SZ=T!w*% zn!8xfc>`Fi+%e%tW8nAVIM?gULGX$s^Rf>zAQqn6*m-pj*6J0d5tkTv_s{rz`L{tF z(C`p5dB#9@`+@m!v?16$%pMdTWPsW*FgrbQ2<;0uBz>F3M1fzB;_8n>SYEZ@m8LQi zH`!vz9?OR@q1QaWi@`+4cwAbF>o9mxhbWSsOuS<%&+E?~hN7kK&HhLx0z`@0u+PI_ z#(3{Nki~?x;L%K?%m`c~>1nm)Ot2W&(kNCVND@uU{@27rrn~UjvZxW5OExn$a+w&d z*)?C{;RtRk5Aj_7F)_J;Gq#UCg6_eES=qAYn2O0Q>6K+;SNlz~f2!tqd-G9iGMx>x zZq?`KVGe_J_2a1lY`ocSIkw&29QVGKeG$%NBPrUX+s(@yUUTg?j=o|;*yT!1#W8aP z@4O%9%V$GBS!cukcyokMC9j@a$U$kU1QB|Lte2Ol^VH!WC*!u+M3Ff(!_#&E-5$}$jWhH$Ec&hUA3_XU6H?23vPEsxK zOnYVh^{ZSwnwTe4Wo3a3@4-y#M_hQy%qhIT(*i0s0Y@#H$hD58qn(EZ>Wzy7^4MHV zeq=h|I$!~xTAwYu{%|4uv97Kn*aDN6)cj9K^N?=fkZ~%~0-H`eo1CM1?)G1rs_v6T^F4i_yD{;P#?+KvZmr&06XEOOsEVKU5%hYpVIe7yp)zZg%!WGD~! 
z@19O=y<>q>+2u=T&haoTeXm%*+yY~xzUJ{cJk)uAPV{?De$P#Yd~pd6u7yWE556YP zw;WgStmdI`Ux_89(E?3gg*D~vJXAhgSG5M{bVrkozm6z`e%;>+1XCOeT5Q zsN`f?Hf({iBTsyfi}P_w@2N>YVS$M9y3>0X^U>e1 z2povbwbJ3^VdW#c3qmZ=uBxjznDJo>1u?__SeW(j)xwqbd@Mfpx%jmt3)gSy)fT(+ z5tSb#Y9vFh9~_kvJH&?|j6HLC84CvDMmyCb_?VqCRmWb%LMye$UE>@dwJ}S-uU*H2 zNZ^XP1(|$2X@BGHsK$a_Z9Mm5As=a9Z+OLMu^@cuv2@BqKHAuko%an`SoZjCDWjSX z-3scOez0J+yCp=pg^%DCN#~_Z7Q~O=j%gve#5)(|wAYG-`TqHLYXp2W3(IxZ+Oc4s z-gBFn;lpW3PAtobh4jB)9i8S7C^8P;JnhE9LoHVkML7a9_uaa>y;!K>y~^9Oj)0PO zagV4!3jw7c_0;qTj967mh6S;ZcG+08(40Wf!sj1t!&x8}CiRXx5?DCW>3lhc1!=?e zVQPLPheA|OdM2`v^XEZ!Tm*rdZC6C<(pl(ZI9P1HNZ|WbyU@?MEKD`TymKNsR-`Iq zQdZ1Dt-+&8-6{f;in)w{3KpaW?DIvt2nbo-9Z{|&`)dli{&$jq{(?<=r&?JkOTBc# zM^*sCU?cB811wZW-v4<@SAe1duE~yXEEFYviMZ%2fSt{a!KJgdK!bPT>&GYo-j*G* z*%79KsPYz}wjb)?C2BnTNOI3bJvvCPe-l2}6imA^bMfuC3ehsHh2z)iN^r6So|dM2 zf}3>^z58XRtmG5y5s~~W8mI@`a_7KCl6xNi)!yU`>Z2pGMC;4cYa~XLoLZ@72#?T| z7ykPi@OZG%@$e2Kc$YS-J6~wRy|yjqZ|pI~`dJGsf}gh_mNWFL#Lfg|o?Xc{7u(RZ z?qu|Ux+z|_erdL1bwJ|O^!Mp0a$R$(A2s!+r$m%|DpRLYL-qVSf-|tQEH&U>y z?)9bV+E4Ic+a|0IqoD4I`Thi%E^PC*I228C)ySkSWA)Z<^rwB;RxD1%jp<7r0Y`gq zTqEG%dIl9$i5Azy<9c!6cAB*RNh)3}XxMLmz7K6TP4F0(crUcY}FrP z04I-Ji8aZfp>R~U@~_1p2AZ>6=07Lv%^ia6a>=z{b9?6i4WpbviuI2{913xmUn@$- z(gO)+j2S~<^*voZrAUXvU9~B%lp&PsSm!0t=n(a^*OL4^gbC~S62)$G?6DK|_FXxQ zo;guVy~F7E!)s_Ob00=vz*E^zX>`1N+iaSJcO#&I;!eZxjn?%^xiXU)*v^SQZh zjE(2J##WW3n_=2l|?vBn&F(ZZhziC4j2b)c5W15!0ma5+tL+W2t{%4X;O`WD?7LAMaOgDn77nug+2p)N6*BK z<#EyX%W!g0+7Z`Vl@UUprf+Lk)4BYlAF8eFT!#8R99o7L1T$5aNGgY03JKFOZH$oVAt6$un zMdxAhp_Jq|CmFcB-=So=BM&P|<9$kF8JHXId1UDU9dn2jl@JvK4z z+YR~XwkjF?-pjy5#n4_g79Wq-I-a;Q%pm)IV_mtEk4M}mOIdsd3KIAJ{=1Ll*3l-V zx8n>H*;)?&JwoQ68}?EfKNz@`!bx}*#mDi>IxZ(>7GZ8ueXvun#dnIBE z*tOD3D1~g;@t}zh9SMo|*X5XKJaQ^kuA2`-B|Gt>E0~y)iU!9b!Vu82^Oy+l9=FM*FBs$M}XPJr-v_PFd=mI3Ojj~00BmqB==os;`)-3 zl{*XsXtqf3NyIHpf|xE+|Sfm_j0G3_LG_nuEO zQ*zjZ{uSGT>PhZuS1qYeAK#3tH7Y0P&VPc+BB_4{PTKhXby4kEyH`;E(LLajsRNmu zhwc;8uTg9|S8>5(U05A4Fn#6KfRkmr4R==R;gtNS?yS@%xX>fyoUZHRX!KW=&=)Oe zJ@@zS>7xd)&^J19Dzy!D4}8oXX&K^yPIy|$mJSrjK2_g2Xo${^B{AQJJ|Ob>H(}ps zBP>xV&e`qJ3Ff^j>J3?A6n>rkW#-K%m>JQcqq2?B9&eyqv$zY4&6X8L@+Np!RTY(D z(~aV@<#ya86YNt+-nAs82f~{xQ-gk+KzOmX5<8(6i_`ujK5#chQWjm^FS!r*Rdxt% ze`$)lB}w~!()zJim(!uO4}TRVW#Wx)#Gl|Za{n)GPG z09yae`6?ko!Q;?@ZrQN`#J3G^YF$ggTG!EiT9$)&7aE>_mq9_0)lc5_yg?|+?^*EH zlY)9DKkLU+g9z2-P|Tt!FuQ9tS#3Uq)adQ!ag_oV%|3BX>JZKal^K4lpy0m|yZiOs zL#XMC=~>%K!Pb~rr)X=2@xv(XCioPv8MEnyp2G;Pka%D$OhsWH@k}*;7*Y*;G&Gh_ z;q3D1#Z3P&(0}aK`i)e~6FMA3l^;RFxW1|Zg$gq-rD`Sn5&U_&$1}j5iimMsY&biD zV@io1I`>g=ne|kD`_mB&A6~$Z3ZbI+P5tuT(Ge_cUot=VEERT|Lvm+Uu(4#QrRhZm z6}gv}$9h?^k&y9t+~+10A5sqqWrVO1kdl7%_#-OHmxuXmy3U46s_)a^wN!L_=t)^L zu<_9`Dq*UP3Yl%H1LNP==yCnC^zZ-`#2Qy`qoo|gd>&tUWQ>Zh(`ST?O*v4OH|iJq zMTJyKHec9-1FrV&&8z>T;bE8duB_7>WU}cApQLCgYJHut^acl`ml37p5F3uz;bz2)&;+o6SSj6SqxKX*7g4n9bd-z{Ab0e`|ka(J=0IN&c@E z57oVaoqG8+5Q&X%7P3h0@~-b=7140pttLEkHxGQ5Hk*wPXqf4~>2ocRhpfw%`vxn> z`J`JOT^GZ{yl?!S2Vc7?ODouT)19}j_r>7H#pG^|;=JaOO)58NXURKkX6 zU?vPC`UvxJI`qL)VIB?g4+B0WFXH1yRJpa!7!6~4-LA(m4+|2@}Zg@^P*FP4t3)NG?U$Y1lT?=xc46& z(tRHF{(gM)y^Z{QK!OgPP$J18l#g(em3E63(UJHnnKE^TJpbm#?`&B*j0UAn+NbjI zE>7x=^fEeb9&%?N%q6|R{X*6e1@b(9qa^DtnGYp--p|+2q4E4#9H)|x%tdUOAIfxm zR#-7<`Gybc;eRP|8|aYbWPLx@#>by1=~fjDI+Fi;uN~OOherL@MinhO^tX*2(<8|J z%a&)C>d~S4Y~o7eBp;#wc}UC{((!SYQig#Dfw_75N8g#!Q4tuN7$!+z@0v_r3hC)n zgGG($BzL>)Gj3Tj>8M-b+Ml9GK=Ssi*}qtHEMNEQ+(C5$oeRv5HO0oAHkb)UA=k?XXj@90(nSIm`60vze6q$_2JIuYo$kLnTGMMubw zZC>X{4lDV}voDdkz&_EkOxB-(U{_S5+Fm-&xZTdUag0D%UB6Ae7adDpWDv|#1YYKT z3k>k3BgsqC^&iRQ^X$$YEN zNY^++M`V|*hHoDMS&HWDhw*f1xpy!M1SI!s6n6wB(NWr$SRFA#z?@on;6N(bPu)MZ 
zB*}4WtuO4mm+A2H3;5XEs-L;XtpwO_e$RpNfR5d=8?+4e2p~pp;+=X-$GpD3t{y=GTob)IYfBX!jk{0G zS$0l3jjq@=eRKTIra{P(1bKxd4O5EGsG9bX471xIw&20IB8? zZ>p~%{wB96j$c}bRhz9h2+_mw#wg+YVUrDb^V_OJ(>xcvrl5|{0CiM8yd9_)S%ON{ zYVVCXnz%7M*i%n>z)p&4W`6f3G>7V|$lJd{W4`#AQsz3Img_=)fFE1&ya7Jl(L}giHySd2&b=~WfPlsY${L|P zNH@D8{?*P96UV>VswVbA!}D@lSCJu7Mb`iAy3mJNWhb0}{xt+O`---OSw9jo2i!ZX zjKJb*6&hCbW1Fx19DbS+j7u+28kGj%SPrke&qgR+dm1Ot4`9Od5zAfP7f*lsC0XNorE ztTW0lMzBe4{oY4Krcf3Om8tnMf>4*tTQP4<(UN(>{^Tk)tYz+BU(Y5ztM^XD;jL_3 z8dE<|A`E0yFY78l&c^%`f*y_KfZ4%ISsn#!q^*xAe4qmux*gq`(8$I!*(bi{TY=Qz zl2)@xHfAmyX`9&x^vaGs&Rx#I3&XPF&J#enR*GvWm4l6|`zDABfbm23KXLmwSbwdH zXL=L39R9gM`z!~{-qbZ+Pk{p!K@zWSanL+Db>VF@&@{7lHR?$Z2+hitBK2r%tx{G6 zkAs{%V}<9FfM52mZHBYS{4BMga{nv}hMBvZo0Pb)SE#AoC`m!i8AH8j8W$qAnn#Lc zDHzeWdA-Y(3%U0qt=tt9xZ3@5^a zTriAAmg?wG5bdx0T>Lc`)9H&s&l*#3wkf?xrI(ABE?z!`bP5QA5~+e8T(l)cEK1!% zLAtC`?nw!f(|;@?ZEPv{p?;+J#Tp*67YP;g?xeus%V|A7eUkIMnH>~Y3JTZvPsLjE z&@Js^8$x>M&O4#b@*X@iZpk!_@uk3gtL9nNU>-{3f8FytM1foT)c(pi9^T0JDQ*lV zJ#*QU;-)MfB3Fuu*M(7__2O5J|2?w)R>^%z6a`1!biQ4w;=w(;`e6DQ3WUe!&|KPh zC|~#}mS<9M z+G)o#$HjbjW07T04h1*krNdQ}`Ec)!_m;~i=fm#M$<^UwM17q??oA5R$L9{cGUG$) z{?Q!^?@*xHsNj)e&&SgKF}_y`1d_>utT+>xWfvMP9-6faF`W%^KuWBinnP07?a)Xc5 zC)cae-cn$s80`P^0h!mkUfkXBo`OU@$z|~`_^AIp%@=8*pyST3-8~I_s4&A7uC`N< zwkdSykB@w;{M92;62@9^-S#`IYlu$S)^gbxzH3seppY z2X7)9l?nJK({>n*Q$UT(%61_0_kd;IO51M~#BJo;lp7IXNbNl6@`HjahXeWJ43e9F zXuqAODezB{Zx^#6P#e*7(dHNVejMD>XGjlt`{9we$zKXSJDLtPlOB-!>OY;;WDZGa zdVV3{AOVx_mh=CLQla%!uS_k3?~)aVaYPUK7*WcZ0x>8T#uc87e}R z=TuaaULc!(vN4m)J!cgpAFX>#;PS79r}i%+_oHK|>90r+5^^R?R#36GF+yzVI|A31 z@o39dQBjog%(bk8z~ScZze0*+uKE7gCy?|6&FZAK?@Cl`u{4)<;t_C7imIGmPsNy` zf_BgZfo-yW(+kL4)&4H#YV#igP3`*>O*N=cFYwX}nJs|rX)%@9%~Tw`H*a3Bv;Z3# z^|z1fP_f@e&7(^}fJ@&W+l3iWk-nK^#ssR?vInElJ_s1 z4wN+$AY>aya|45lUu|vU&usefjCX8qU}v&@l-Q!~OLqsanLxjH5}yQw%&w_fa+M|#BRud=82Qjw!J zYAE)I^qNr(^OpT&PD?-Iv*?`w5-zH9lMYeA{jYX@AMLwC+6Jqa3??&Dbh*z51dkbzv?32LJ&7{|%9OJXHw;hE@2aEK$vbK0;_! zx)=A}bIzMAk&>GjX6j3_Ml#BtNl_7zrWy*>Fp^3m$&wnC5S4Bz*^;HZont6V*Ot%! z@B4e6-}}xESZ5M#9gF_C+JTCXjS(BpGe5QY8W#AQCi|aWiK@?3NOoK?e0!2?T&Gr} z#$qq?o7w~9P7l;;7@MNhGL)6)@&uik3dR9R>(CLkraE@K7Dp3j5B%%B9{%fj6Se-0 z*coa)I=I3NA)Wv9rd({o`1`Dtm)p%iJx_?4^H_kcINj0IBO4I?P|4({i*Ipl*fn8F zl?6#fUx;AqI~@FYX}sz@3(wA|eDM;#hs{XiuO4b_q_KpVIzAt;O!zTQKaP!((NVqm zjjc$M70)*pv#})p@T7@WJ6N_kgNgR$sCJ2-*s`+&_8-(n`mdQ|+u3~wjbb`snz=Yc zd&C@$Z_5)L&wfNfKz7TWO%}+yp0d~YLKiY+nz8I;3q)_e_hY>6C!7eLo3p&h0-?DZ zn@(4Jf_zvREmgHdsr>%?8;rY=k5zfSTP*Rkaa>_MwHv>j(cG~w-VzGrXuQ*xZnP)0 zS>G$Sgpz-L{vS>~@YdW~_akkIm1ZN^X4iX=SGbq;LW_g0)Hl|LCwd@jE^nG+&%rXs z%;6i3y^w6XkUS8|!H8@QUn`>*UWepwS6=3T!%d+Tx_hzS^NvT#6AmO@`n}hU`rtw> z)<_g_pgm!Dz&xN2+vnS9zf$Icc+=o?Oe#>ud<$|m} zaLG~cGq%1|I294d#o2G46y4oF!@=};>~JC%#>!(ymZyFO$=Y+r;x8_Y|5CrSaUr!Y9dt1eBJ1fvtKCCf*gl#+?7z2QsXo289jkeWA(9^oUW#xo%JrzBV8&+|=h@HwBJ?;# z)*b5R;bq>&GJ2jEGU3A^eBnVB5pg$dg&5m@byrD~Be3OXyLEg!F*01Ms(sZ6EEo8x zbG^ifw2Sn9v6Mhn-mP4hSTUymJUm^qihy%#XKDT)Vk{23;pEC8pzI~M$SW3O)}QVT z$Lt7roUhT>Y7par=Z8Wk7Xp^SDh&=jV!SPrG5_o#a8SW4Log;r&1Dt0)<6Qa^dWA! 
zx&%2^`4Kr01fGnvE65v5FyJzFS0kQ4U#4KPl8`{g%qKXh1V)tXn|t^Y2u@tpS(HiO zqRDAloUa63fgjCpUR32XbB7-w{Di*CgAN*dWw}QK}%YvM82HB`#8UoF4rWG zD5qZ(JSOlwu19XNM1s8~!v&s=1g1_7vdn5F(0bFY~|ql>}X@fHc}2NOD8brR5+ ztgQb@BEb>&u$6a(1imNjy5}@5!L(sH*K>eC%%QCC9%U&UZG6OS-v}Jf@jS3tTZ(Vq z{4oPL5}iIv1^p|fSj=+JC6!1#e>-?TkSoQpD$ba(I*F+TS2TWikYYM_^FrYdBo=%M zE%$YoB31OBiPR;rq(Xj=#8(R4s|f?$MkEYO>8ne_q+o0oM(eL7!4fy6%bEbFc^ zDR$&%yO=qX7^qZ<4z7{HBIC6r(~U$_lm5t~*E2rf!>I0EBpRZA%X`xy#iM->UFYs2 zvE47&|E5?9{<4WxngJy8by%8?!&2M~@Vq}AJkyVL-d{2;#i4bjy-kNnl&Llw3zccW zcTsjiG>NvX`xAYOXeg3qYq3nUa@U7V;;7=8vQZZrK^6|1 z9$g`^;j*Gb*=`!AZNd}o=aERcxwgXDkA_R*2Gz?qNPH>WlUEZ$qrJ$YI;e<*ms@Gk z!e|=b(o*AAmyj^*`}5?ccp61wEuy__X20yGqU|X(?o6=7QB@?IqnT{#0*$X$;R$me z&;0T7M51~Q4NvYBo#baT{Nt>gx_lZt${saLKcC5WwGH~*p|Nv3ro#UfiCf&U;{Gxk zaXQ&Xb#F-Y{V&2|Z#9i5-#Q_6D?*kSQ&GxJW?b-IpMf;w}=i zzNS0xYNeqcDCrpLC2_FSq$9JNhBwptVuP4OkmrFKfdQZaU;c!~M|KAt%8)pwQ(W<% zuQW!N*L#!A^IUQc$rAQ&=sDGrcA%nKnXV Qy`XF-ECUU zpeJ&4QH&ae*M>#ag=_|PgT*Rq7g9(Q@te!38Q%2uTDc_@>aEvw&bDW8epaG~Y$=8O zoVbj0dkdP*?(fPuEgJL^rmDfHOz!}J;%ynUt@ zQ{Y1(W>Pp~|D<$G1d3{tqz1$eR%iRqp@* literal 0 HcmV?d00001 diff --git a/tests/schedulers/test_scheduler_stork.py b/tests/schedulers/test_scheduler_stork.py new file mode 100644 index 000000000000..f7a3133da310 --- /dev/null +++ b/tests/schedulers/test_scheduler_stork.py @@ -0,0 +1,187 @@ +import tempfile +import unittest + +import torch + +from diffusers import STORKScheduler + +from .test_schedulers import SchedulerCommonTest + + +class STORKSchedulerTest(SchedulerCommonTest): + scheduler_classes = (STORKScheduler,) + forward_default_kwargs = (("num_inference_steps", 50),) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + } + + config.update(**kwargs) + return config + + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + + for scheduler_class in self.scheduler_classes: + + scheduler_config = self.get_scheduler_config(**config, prediction_type="epsilon") + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps) + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) + new_scheduler.set_timesteps(num_inference_steps) + + output = scheduler.step_noise_2(residual, time_step, sample, return_dict=True).prev_sample + new_output = new_scheduler.step_noise_2(residual, time_step, sample, return_dict=True).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "STORK2 noise scheduler outputs are not identical" + + output = scheduler.step_noise_4(residual, time_step, sample, return_dict=True).prev_sample + new_output = new_scheduler.step_noise_4(residual, time_step, sample, return_dict=True).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "STORK4 noise scheduler outputs are not identical" + + + + scheduler_config = self.get_scheduler_config(**config, prediction_type="flow_prediction") + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps) + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) + new_scheduler.set_timesteps(num_inference_steps) + + output = 
scheduler.step_flow_matching_2(residual, time_step, sample, return_dict=True).prev_sample + new_output = new_scheduler.step_flow_matching_2(residual, time_step, sample, return_dict=True).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "STORK2 flow matching scheduler outputs are not identical" + + output = scheduler.step_flow_matching_4(residual, time_step, sample, return_dict=True).prev_sample + new_output = new_scheduler.step_flow_matching_4(residual, time_step, sample, return_dict=True).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "STORK4 flow matching scheduler outputs are not identical" + + @unittest.skip("Test not supported.") + def test_from_save_pretrained(self): + pass + + @unittest.skip("Test not supported.") + def test_add_noise_device(self): + pass + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(prediction_type="epsilon") + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps) + + # copy over dummy past residuals (must be after setting timesteps) + scheduler.ets = dummy_past_residuals[:] + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) + # copy over dummy past residuals + new_scheduler.set_timesteps(num_inference_steps) + + # copy over dummy past residual (must be after setting timesteps) + new_scheduler.ets = dummy_past_residuals[:] + + output = scheduler.step_noise_2(residual, time_step, sample, return_dict=True).prev_sample + new_output = new_scheduler.step_noise_2(residual, time_step, sample, return_dict=True).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "STORK2 noise scheduler outputs are not identical" + + output = scheduler.step_noise_4(residual, time_step, sample, return_dict=True).prev_sample + new_output = new_scheduler.step_noise_4(residual, time_step, sample, return_dict=True).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "STORK4 noise scheduler outputs are not identical" + + + scheduler_config = self.get_scheduler_config(prediction_type="flow_prediction") + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps) + + # copy over dummy past residuals (must be after setting timesteps) + scheduler.ets = dummy_past_residuals[:] + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) + # copy over dummy past residuals + new_scheduler.set_timesteps(num_inference_steps) + + # copy over dummy past residual (must be after setting timesteps) + new_scheduler.ets = dummy_past_residuals[:] + + output = scheduler.step_flow_matching_2(residual, time_step, sample, return_dict=True).prev_sample + new_output = new_scheduler.step_flow_matching_2(residual, time_step, sample, return_dict=True).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "STORK2 flow matching scheduler outputs are not identical" + + output = scheduler.step_flow_matching_4(residual, time_step, sample, return_dict=True).prev_sample + new_output = 
new_scheduler.step_flow_matching_4(residual, time_step, sample, return_dict=True).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "STORK4 flow matching scheduler outputs are not identical" + + + def test_timesteps(self): + for timesteps in [100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_steps_offset(self): + # Test for noise based models + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(prediction_type="epsilon", stopping_eps=1e-4) + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(10) + + expected_timesteps = torch.Tensor([900, 850, 800, 800, 700, 600, 500, 400, 300, 200, 100, 0.1]) + expected_timesteps = expected_timesteps.to(dtype=torch.float64) + assert torch.allclose( + scheduler.timesteps, + expected_timesteps, + ) + + # Test for flow matching based models + scheduler_config = self.get_scheduler_config(prediction_type="flow_prediction", shift=3.0, time_shift_type="exponential") + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(10) + assert torch.allclose( + scheduler.timesteps, + torch.Tensor([1000.0000, 980.0647, 960.1293, 913.3490, 857.6923, 790.3683, 707.2785, 602.1506, 464.8760, 278.0488, 8.9286]), + ) + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "scaled_linear"]: + self.check_over_configs(beta_schedule=schedule) + + + def test_time_indices(self): + for t in [1, 5, 10]: + self.check_over_forward(time_step=t) + + def test_inference_steps(self): + for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): + self.check_over_forward(num_inference_steps=num_inference_steps) From cd39dc24e5834dd588830564bcdca7032e4130ae Mon Sep 17 00:00:00 2001 From: weizhen Date: Sat, 5 Jul 2025 18:46:24 -0700 Subject: [PATCH 4/6] Done with styling and QC --- src/diffusers/schedulers/scheduling_stork.py | 705 +++++++++++-------- tests/schedulers/test_scheduler_stork.py | 43 +- 2 files changed, 456 insertions(+), 292 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_stork.py b/src/diffusers/schedulers/scheduling_stork.py index 23c02dad6160..fbaa6906a418 100644 --- a/src/diffusers/schedulers/scheduling_stork.py +++ b/src/diffusers/schedulers/scheduling_stork.py @@ -12,15 +12,16 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from pathlib import Path +from typing import List, Optional, Union + import numpy as np import torch from scipy.io import loadmat + from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput -from diffusers.utils import BaseOutput, is_scipy_available, logging -from pathlib import Path - +from diffusers.utils import BaseOutput @dataclass @@ -41,12 +42,10 @@ class STORKSchedulerOutput(BaseOutput): CONSTANTSFOLDER = f"{current_file.parent}/stork_parameters" - - - class STORKScheduler(SchedulerMixin, ConfigMixin): """ - `STORKScheduler` uses modified stabilized Runge-Kutta method for the backward ODE in the diffusion or flow matching models. + `STORKScheduler` uses modified stabilized Runge-Kutta method for the backward ODE in the diffusion or flow matching + models. This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. 
Check the superclass documentation for the generic
     methods the library implements for all schedulers such as loading and saving.
 
@@ -81,11 +80,13 @@ class STORKScheduler(SchedulerMixin, ConfigMixin):
         solver_order (`int`, defaults to 2):
             The STORK order which can be `2` or `4`. It is recommended to use `solver_order=2` uniformly.
         prediction_type (`str`, defaults to `epsilon`, *optional*):
-            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process) or `flow_prediction`.
+            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process)
+            or `flow_prediction`.
         time_shift_type (`str`, defaults to "exponential"):
             The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
         derivative_order (`int`, defaults to 2):
-            The order of the Taylor expansion derivative to use for the sub-step velocity approximation. Only supports 2 or 3.
+            The order of the Taylor expansion derivative to use for the sub-step velocity approximation. Only supports
+            2 or 3.
         s (`int`, defaults to 50):
             The number of sub-steps to use in the STORK.
         precision (`str`, defaults to "float32"):
@@ -122,7 +123,6 @@ def __init__(
         use_beta_sigmas: Optional[bool] = False,
         set_alpha_to_one: bool = False,
     ):
-
         super().__init__()
         # if prediction_type == "flow_prediction" and sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
         #     raise ValueError(
@@ -130,26 +130,24 @@ def __init__(
         #     )
         if time_shift_type not in {"exponential", "linear"}:
             raise ValueError("`time_shift_type` must either be 'exponential' or 'linear'.")
-        
+
         # We manually enforce precision to float32 for numerical issues.
         self.np_dtype = np.float32
         self.dtype = torch.float32
-
         timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=self.np_dtype)[::-1].copy()
         timesteps = torch.from_numpy(timesteps).to(dtype=self.dtype)
 
         sigmas = timesteps / num_train_timesteps
-
         if not use_dynamic_shifting:
             # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
             sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
 
-        self.timesteps = None #sigmas * num_train_timesteps
+        self.timesteps = None  # sigmas * num_train_timesteps
         self._step_index = None
         self._begin_index = None
         self._shift = shift
-        self.sigmas = sigmas #.to("cpu")  # to avoid too much CPU/GPU communication
+        self.sigmas = sigmas  # .to("cpu")  # to avoid too much CPU/GPU communication
         self.sigma_min = self.sigmas[-1].item()
         self.sigma_max = self.sigmas[0].item()
         # Store the predictions for the velocity/noise for higher order derivative approximations
@@ -171,20 +169,17 @@ def __init__(
             self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
-        
+
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
 
         self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
         # standard deviation of the initial noise distribution
         self.init_noise_sigma = 1.0
-        
+
         # Noise-based models epsilon to avoid numerical issues
         self.stopping_eps = stopping_eps
-
-
-
     def set_timesteps(
         self,
         num_inference_steps: Optional[int] = None,
@@ -211,7 +206,7 @@ def set_timesteps(
             Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed
             automatically.
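+
+        Example (an illustrative sketch; the 50-step count and CPU device are arbitrary choices):
+
+        >>> from diffusers import STORKScheduler
+
+        >>> scheduler = STORKScheduler(prediction_type="flow_prediction", solver_order=2)
+        >>> scheduler.set_timesteps(num_inference_steps=50, device="cpu")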
""" - + if self.config.use_dynamic_shifting and mu is None: raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`") @@ -237,14 +232,13 @@ def set_timesteps( self.set_timesteps_flow_matching(num_inference_steps, device, sigmas, mu, timesteps) else: raise ValueError(f"Prediction type {self.prediction_type} is not yet supported") - + # Reset the step index and begin index self._step_index = None self._begin_index = None - - - def set_timesteps_noise(self, + def set_timesteps_noise( + self, num_inference_steps: Optional[int] = None, device: Union[str, torch.device] = None, ): @@ -257,7 +251,7 @@ def set_timesteps_noise(self, device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ - seq = np.linspace(0, 1, self.num_inference_steps+1) + seq = np.linspace(0, 1, self.num_inference_steps + 1) seq[0] = self.stopping_eps seq = seq[:-1] seq = seq[::-1] @@ -273,16 +267,13 @@ def set_timesteps_noise(self, self._timesteps = seq self.timesteps = torch.from_numpy(seq.copy()).to(device) - self._step_index = None self._begin_index = None self.noise_predictions = [] - - - - def set_timesteps_flow_matching(self, + def set_timesteps_flow_matching( + self, num_inference_steps: Optional[int] = None, device: Union[str, torch.device] = None, sigmas: Optional[List[float]] = None, @@ -312,7 +303,7 @@ def set_timesteps_flow_matching(self, if is_timesteps_provided: timesteps = np.array(timesteps).astype(self.np_dtype) - + if sigmas is None: if timesteps is None: timesteps = np.linspace( @@ -323,7 +314,6 @@ def set_timesteps_flow_matching(self, sigmas = np.array(sigmas).astype(self.np_dtype) num_inference_steps = len(sigmas) - # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of # "exponential" or "linear" type is applied if self.config.use_dynamic_shifting: @@ -407,8 +397,6 @@ def begin_index(self): """ return self._begin_index - - def set_shift(self, shift: float): self._shift = shift @@ -462,8 +450,6 @@ def scale_noise( def _sigma_to_t(self, sigma): return sigma * self.config.num_train_timesteps - - def step( self, @@ -471,9 +457,9 @@ def step( timestep: Union[int, torch.Tensor], sample: torch.Tensor = None, return_dict: bool = True, - **kwargs + **kwargs, ) -> torch.Tensor: - ''' + """ One step of the STORK update for flow matching or noise-based diffusion models. Args: @@ -485,26 +471,26 @@ def step( A current instance of a sample created by the diffusion process. return_dict (`bool`, defaults to `True`): Whether or not to return a [`~schedulers.STORKSchedulerOutput`] instead of a plain tuple. - Returns: result (Union[Tuple, STORKSchedulerOutput]): - The next sample in the diffusion chain, either as a tuple or as a [`~schedulers.STORKSchedulerOutput`]. The value is converted back to the original dtype of `model_output` to avoid numerical issues. - ''' + The next sample in the diffusion chain, either as a tuple or as a [`~schedulers.STORKSchedulerOutput`]. + The value is converted back to the original dtype of `model_output` to avoid numerical issues. + """ original_model_output_dtype = model_output.dtype # Cast model_output and sample to "torch.float32" to avoid numerical issues model_output = model_output.to(self.dtype) sample = sample.to(self.dtype) # Move sample to model_output's device sample = sample.to(model_output.device) - + """ self.velocity_predictions always contain upcasted model_output in torch.float32 dtype. 
""" - + if self.prediction_type == "epsilon": if self.solver_order == 2: result = self.step_noise_2(model_output, timestep, sample, return_dict) - elif self.solver_order ==4: + elif self.solver_order == 4: result = self.step_noise_4(model_output, timestep, sample, return_dict) else: raise ValueError(f"Solver order {self.solver_order} is not yet supported for noise-based models") @@ -517,14 +503,14 @@ def step( raise ValueError(f"Solver order {self.solver_order} is not yet supported for flow matching models") else: raise ValueError(f"Prediction type {self.prediction_type} is not yet supported") - + # Convert the result back to the original dtype of model_output, as this result will be used as the next input to the model if return_dict: result.prev_sample = result.prev_sample.to(original_model_output_dtype) else: result = (result[0].to(original_model_output_dtype),) return result - + def step_flow_matching_2( self, model_output: torch.Tensor, @@ -532,7 +518,7 @@ def step_flow_matching_2( sample: torch.Tensor = None, return_dict: bool = False, ) -> torch.Tensor: - ''' + """ One step of the STORK2 update for flow matching based models. Args: @@ -544,11 +530,11 @@ def step_flow_matching_2( A current instance of a sample created by the diffusion process. return_dict (`bool`, defaults to `True`): Whether or not to return a [`~schedulers.STORKSchedulerOutput`] instead of a plain tuple. - Returns: result (Union[Tuple, STORKSchedulerOutput]): - The next sample in the diffusion chain, either as a tuple or as a [`~schedulers.STORKSchedulerOutput`]. The value is converted back to the original dtype of `model_output` to avoid numerical issues. - ''' + The next sample in the diffusion chain, either as a tuple or as a [`~schedulers.STORKSchedulerOutput`]. + The value is converted back to the original dtype of `model_output` to avoid numerical issues. 
+ """ # Initialize the step index if it's the first step if self._step_index is None: self._step_index = 0 @@ -564,33 +550,59 @@ def step_flow_matching_2( t = self.sigmas[self._step_index] t_next = self.sigmas[self._step_index + 1] - - h1 = self.dt_list[self._step_index-1] - h2 = self.dt_list[self._step_index-2] - h3 = self.dt_list[self._step_index-3] - + h1 = self.dt_list[self._step_index - 1] + h2 = self.dt_list[self._step_index - 2] + h3 = self.dt_list[self._step_index - 3] if self.derivative_order == 2: - velocity_derivative = (-self.velocity_predictions[-2] + 4 * self.velocity_predictions[-1] - 3 * model_output) / (2 * h1) - velocity_second_derivative = 2 / (h1 * h2 * (h1 + h2)) * (self.velocity_predictions[-2] * h1 - self.velocity_predictions[-1] * (h1 + h2) + model_output * h2) + velocity_derivative = ( + -self.velocity_predictions[-2] + 4 * self.velocity_predictions[-1] - 3 * model_output + ) / (2 * h1) + velocity_second_derivative = ( + 2 + / (h1 * h2 * (h1 + h2)) + * ( + self.velocity_predictions[-2] * h1 + - self.velocity_predictions[-1] * (h1 + h2) + + model_output * h2 + ) + ) velocity_third_derivative = None elif self.derivative_order == 3: - velocity_derivative = ((h2 * h3) * (self.velocity_predictions[-1] - model_output) - (h1 * h3) * (self.velocity_predictions[-2] - model_output) + (h1 * h2) * (self.velocity_predictions[-3] - model_output)) / (h1 * h2 * h3) - velocity_second_derivative = 2 * ((h2 + h3) * (self.velocity_predictions[-1] - model_output) - (h1 + h3) * (self.velocity_predictions[-2] - model_output) + (h1 + h2) * (self.velocity_predictions[-3] - model_output)) / (h1 * h2 * h3) - velocity_third_derivative = 6 * ((h2 - h3) * (self.velocity_predictions[-1] - model_output) + (h3 - h1) * (self.velocity_predictions[-2] - model_output) + (h1 - h2) * (self.velocity_predictions[-3] - model_output)) / (h1 * h2 * h3) + velocity_derivative = ( + (h2 * h3) * (self.velocity_predictions[-1] - model_output) + - (h1 * h3) * (self.velocity_predictions[-2] - model_output) + + (h1 * h2) * (self.velocity_predictions[-3] - model_output) + ) / (h1 * h2 * h3) + velocity_second_derivative = ( + 2 + * ( + (h2 + h3) * (self.velocity_predictions[-1] - model_output) + - (h1 + h3) * (self.velocity_predictions[-2] - model_output) + + (h1 + h2) * (self.velocity_predictions[-3] - model_output) + ) + / (h1 * h2 * h3) + ) + velocity_third_derivative = ( + 6 + * ( + (h2 - h3) * (self.velocity_predictions[-1] - model_output) + + (h3 - h1) * (self.velocity_predictions[-2] - model_output) + + (h1 - h2) * (self.velocity_predictions[-3] - model_output) + ) + / (h1 * h2 * h3) + ) else: print("The noise approximation order is not supported!") exit() - + self.velocity_predictions.append(model_output) self._step_index += 1 - Y_j_2 = sample Y_j_1 = sample Y_j = sample - # Implementation of our Runge-Kutta-Gegenbauer second order method for j in range(1, self.s + 1): # Calculate the corresponding \bar{alpha}_t and beta_t that aligns with the correct timestep @@ -598,8 +610,8 @@ def step_flow_matching_2( if j == 2: fraction = 4 / (3 * (self.s**2 + self.s - 2)) else: - fraction = ((j - 1)**2 + (j - 1) - 2) / (self.s**2 + self.s - 2) - + fraction = ((j - 1) ** 2 + (j - 1) - 2) / (self.s**2 + self.s - 2) + if j == 1: mu_tilde = 6 / ((self.s + 4) * (self.s - 1)) dt = (t - t_next) * torch.ones(model_output.shape, device=sample.device) @@ -608,27 +620,36 @@ def step_flow_matching_2( mu = (2 * j + 1) * self.b_coeff(j) / (j * self.b_coeff(j - 1)) nu = -(j + 1) * self.b_coeff(j) / (j * self.b_coeff(j - 2)) mu_tilde = 
mu * 6 / ((self.s + 4) * (self.s - 1)) - gamma_tilde = -mu_tilde * (1 - j * (j + 1) * self.b_coeff(j-1)/ 2) - + gamma_tilde = -mu_tilde * (1 - j * (j + 1) * self.b_coeff(j - 1) / 2) # Probability flow ODE update diff = -fraction * (t - t_next) * torch.ones(model_output.shape, device=sample.device) - velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative) - Y_j = mu * Y_j_1 + nu * Y_j_2 + (1 - mu - nu) * sample - dt * mu_tilde * velocity - dt * gamma_tilde * model_output - + velocity = self.taylor_approximation( + self.derivative_order, + diff, + model_output, + velocity_derivative, + velocity_second_derivative, + velocity_third_derivative, + ) + Y_j = ( + mu * Y_j_1 + + nu * Y_j_2 + + (1 - mu - nu) * sample + - dt * mu_tilde * velocity + - dt * gamma_tilde * model_output + ) + Y_j_2 = Y_j_1 Y_j_1 = Y_j - - img_next = Y_j img_next = img_next.to(model_output.dtype) if not return_dict: - return (img_next,) + return (img_next,) return STORKSchedulerOutput(prev_sample=img_next) - def step_flow_matching_4( self, model_output: torch.Tensor, @@ -636,7 +657,7 @@ def step_flow_matching_4( sample: torch.Tensor = None, return_dict: bool = False, ) -> torch.Tensor: - ''' + """ One step of the STORK4 update for flow matching models Args: @@ -649,8 +670,8 @@ def step_flow_matching_4( Returns: `torch.FloatTensor`: The next sample in the diffusion chain. - ''' - + """ + # Initialize the step index if it's the first step if self._step_index is None: self._step_index = 0 @@ -663,29 +684,55 @@ def step_flow_matching_4( t_start = torch.ones(model_output.shape, device=sample.device) * t t_next = self.sigmas[self._step_index + 1] - - h1 = self.dt_list[self._step_index-1] - h2 = self.dt_list[self._step_index-2] - h3 = self.dt_list[self._step_index-3] - + h1 = self.dt_list[self._step_index - 1] + h2 = self.dt_list[self._step_index - 2] + h3 = self.dt_list[self._step_index - 3] if self.derivative_order == 2: - velocity_derivative = (-self.velocity_predictions[-2] + 4 * self.velocity_predictions[-1] - 3 * model_output) / (2 * h1) - velocity_second_derivative = 2 / (h1 * h2 * (h1 + h2)) * (self.velocity_predictions[-2] * h1 - self.velocity_predictions[-1] * (h1 + h2) + model_output * h2) + velocity_derivative = ( + -self.velocity_predictions[-2] + 4 * self.velocity_predictions[-1] - 3 * model_output + ) / (2 * h1) + velocity_second_derivative = ( + 2 + / (h1 * h2 * (h1 + h2)) + * ( + self.velocity_predictions[-2] * h1 + - self.velocity_predictions[-1] * (h1 + h2) + + model_output * h2 + ) + ) velocity_third_derivative = None elif self.derivative_order == 3: - velocity_derivative = ((h2 * h3) * (self.velocity_predictions[-1] - model_output) - (h1 * h3) * (self.velocity_predictions[-2] - model_output) + (h1 * h2) * (self.velocity_predictions[-3] - model_output)) / (h1 * h2 * h3) - velocity_second_derivative = 2 * ((h2 + h3) * (self.velocity_predictions[-1] - model_output) - (h1 + h3) * (self.velocity_predictions[-2] - model_output) + (h1 + h2) * (self.velocity_predictions[-3] - model_output)) / (h1 * h2 * h3) - velocity_third_derivative = 6 * ((h2 - h3) * (self.velocity_predictions[-1] - model_output) + (h3 - h1) * (self.velocity_predictions[-2] - model_output) + (h1 - h2) * (self.velocity_predictions[-3] - model_output)) / (h1 * h2 * h3) + velocity_derivative = ( + (h2 * h3) * (self.velocity_predictions[-1] - model_output) + - (h1 * h3) * (self.velocity_predictions[-2] - model_output) + + (h1 * h2) * 
(self.velocity_predictions[-3] - model_output) + ) / (h1 * h2 * h3) + velocity_second_derivative = ( + 2 + * ( + (h2 + h3) * (self.velocity_predictions[-1] - model_output) + - (h1 + h3) * (self.velocity_predictions[-2] - model_output) + + (h1 + h2) * (self.velocity_predictions[-3] - model_output) + ) + / (h1 * h2 * h3) + ) + velocity_third_derivative = ( + 6 + * ( + (h2 - h3) * (self.velocity_predictions[-1] - model_output) + + (h3 - h1) * (self.velocity_predictions[-2] - model_output) + + (h1 - h2) * (self.velocity_predictions[-3] - model_output) + ) + / (h1 * h2 * h3) + ) else: print("The noise approximation order is not supported!") exit() - + self.velocity_predictions.append(model_output) self._step_index += 1 - - Y_j_2 = sample Y_j_1 = sample Y_j = sample @@ -701,13 +748,10 @@ def step_flow_matching_4( mz = int(mp[0]) mr = int(mp[1]) - - - ''' + """ The first part of the STORK4 update - ''' + """ for j in range(1, mdeg + 1): - # First sub-step in the first part of the STORK4 update if j == 1: temp1 = -(t - t_next) * recf[mr] * torch.ones(model_output.shape, device=sample.device) @@ -718,10 +762,19 @@ def step_flow_matching_4( # Second and the following sub-steps in the first part of the STORK4 update else: diff = ci1 - t_start - velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative) + velocity = self.taylor_approximation( + self.derivative_order, + diff, + model_output, + velocity_derivative, + velocity_second_derivative, + velocity_third_derivative, + ) - temp1 = -(t - t_next) * recf[mr + 2 * (j-2) + 1] * torch.ones(model_output.shape, device=sample.device) - temp3 = -recf[mr + 2 * (j-2) + 2] * torch.ones(model_output.shape, device=sample.device) + temp1 = ( + -(t - t_next) * recf[mr + 2 * (j - 2) + 1] * torch.ones(model_output.shape, device=sample.device) + ) + temp3 = -recf[mr + 2 * (j - 2) + 2] * torch.ones(model_output.shape, device=sample.device) temp2 = torch.ones(model_output.shape, device=sample.device) - temp3 ci1 = temp1 + temp2 * ci2 + temp3 * ci3 @@ -734,43 +787,72 @@ def step_flow_matching_4( ci3 = ci2 ci2 = ci1 - ''' + """ The finishing four-step procedure as a composition method - ''' + """ # First finishing step - temp1 = -(t - t_next) * fpa[mz,0] * torch.ones(model_output.shape, device=sample.device) + temp1 = -(t - t_next) * fpa[mz, 0] * torch.ones(model_output.shape, device=sample.device) diff = ci1 - t_start - velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative) + velocity = self.taylor_approximation( + self.derivative_order, + diff, + model_output, + velocity_derivative, + velocity_second_derivative, + velocity_third_derivative, + ) Y_j_1 = velocity Y_j_3 = Y_j + temp1 * Y_j_1 # Second finishing step ci2 = ci1 + temp1 - temp1 = -(t - t_next) * fpa[mz,1] * torch.ones(model_output.shape, device=sample.device) - temp2 = -(t - t_next) * fpa[mz,2] * torch.ones(model_output.shape, device=sample.device) + temp1 = -(t - t_next) * fpa[mz, 1] * torch.ones(model_output.shape, device=sample.device) + temp2 = -(t - t_next) * fpa[mz, 2] * torch.ones(model_output.shape, device=sample.device) diff = ci2 - t_start - velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative) + velocity = self.taylor_approximation( + self.derivative_order, + diff, + model_output, + velocity_derivative, 
+ velocity_second_derivative, + velocity_third_derivative, + ) Y_j_2 = velocity Y_j_4 = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2 # Third finishing step ci2 = ci1 + temp1 + temp2 - temp1 = -(t - t_next) * fpa[mz,3] * torch.ones(model_output.shape, device=sample.device) - temp2 = -(t - t_next) * fpa[mz,4] * torch.ones(model_output.shape, device=sample.device) - temp3 = -(t - t_next) * fpa[mz,5] * torch.ones(model_output.shape, device=sample.device) + temp1 = -(t - t_next) * fpa[mz, 3] * torch.ones(model_output.shape, device=sample.device) + temp2 = -(t - t_next) * fpa[mz, 4] * torch.ones(model_output.shape, device=sample.device) + temp3 = -(t - t_next) * fpa[mz, 5] * torch.ones(model_output.shape, device=sample.device) diff = ci2 - t_start - velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative) + velocity = self.taylor_approximation( + self.derivative_order, + diff, + model_output, + velocity_derivative, + velocity_second_derivative, + velocity_third_derivative, + ) Y_j_3 = velocity - fnt = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2 + temp3 * Y_j_3 + # This corresponds to the noise-prediction counterpart. + # fnt = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2 + temp3 * Y_j_3 # Fourth finishing step ci2 = ci1 + temp1 + temp2 + temp3 - temp1 = -(t - t_next) * fpb[mz,0] * torch.ones(model_output.shape, device=sample.device) - temp2 = -(t - t_next) * fpb[mz,1] * torch.ones(model_output.shape, device=sample.device) - temp3 = -(t - t_next) * fpb[mz,2] * torch.ones(model_output.shape, device=sample.device) - temp4 = -(t - t_next) * fpb[mz,3] * torch.ones(model_output.shape, device=sample.device) + temp1 = -(t - t_next) * fpb[mz, 0] * torch.ones(model_output.shape, device=sample.device) + temp2 = -(t - t_next) * fpb[mz, 1] * torch.ones(model_output.shape, device=sample.device) + temp3 = -(t - t_next) * fpb[mz, 2] * torch.ones(model_output.shape, device=sample.device) + temp4 = -(t - t_next) * fpb[mz, 3] * torch.ones(model_output.shape, device=sample.device) diff = ci2 - t_start - velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative) + velocity = self.taylor_approximation( + self.derivative_order, + diff, + model_output, + velocity_derivative, + velocity_second_derivative, + velocity_third_derivative, + ) Y_j_4 = velocity Y_j = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2 + temp3 * Y_j_3 + temp4 * Y_j_4 img_next = Y_j @@ -778,7 +860,6 @@ if not return_dict: return (img_next,) return STORKSchedulerOutput(prev_sample=img_next) - def step_noise_2( self, @@ -787,7 +868,7 @@ def step_noise_2( sample: torch.Tensor = None, return_dict: bool = False, ) -> torch.Tensor: - ''' + """ One step of the STORK2 update for noise-based diffusion models. Args: @@ -802,21 +883,20 @@ def step_noise_2( Returns: `torch.FloatTensor`: The next sample in the diffusion chain. - ''' + """ # Initialize the step index if it's the first step if self._step_index is None: self._step_index = 0 self.initial_noise = model_output - total_step = self.config.num_train_timesteps t = self.timesteps[self._step_index] / total_step beta_0, beta_1 = self.betas[0], self.betas[-1] t_start = torch.ones(model_output.shape, device=sample.device) * t beta_t = (beta_0 + t_start * (beta_1 - beta_0)) * total_step - log_mean_coeff = (-0.25 * t_start ** 2 * (beta_1 - beta_0) - 0.5 * t_start * beta_0) * total_step - std = torch.sqrt(1. - torch.exp(2.
* log_mean_coeff)) + log_mean_coeff = (-0.25 * t_start**2 * (beta_1 - beta_0) - 0.5 * t_start * beta_0) * total_step + std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) # Tweedie's trick if self._step_index == len(self.timesteps) - 1: @@ -825,15 +905,17 @@ def step_noise_2( if not return_dict: return (img_next,) return STORKSchedulerOutput(prev_sample=img_next) - + t_next = self.timesteps[self._step_index + 1] / total_step # drift, diffusion -> f(x,t), g(t) - drift_initial, diffusion_initial = -0.5 * beta_t * sample, torch.sqrt(beta_t) * torch.ones(sample.shape, device=sample.device) + drift_initial, diffusion_initial = ( + -0.5 * beta_t * sample, + torch.sqrt(beta_t) * torch.ones(sample.shape, device=sample.device), + ) noise_initial = model_output score = -noise_initial / std # score -> noise - drift_initial = drift_initial - diffusion_initial ** 2 * score * 0.5 # drift -> dx/dt - + drift_initial = drift_initial - diffusion_initial**2 * score * 0.5 # drift -> dx/dt dt = torch.ones(model_output.shape, device=sample.device) * self.dt @@ -853,7 +935,13 @@ def step_noise_2( elif self._step_index == 1: # SECOND RUN t_previous = torch.ones(model_output.shape, device=sample.device) * self.timesteps[0] / 1000 - drift_previous = self.drift_function(self.betas, self.config.num_train_timesteps, t_previous, self.initial_sample, self.noise_predictions[-1]) + drift_previous = self.drift_function( + self.betas, + self.config.num_train_timesteps, + t_previous, + self.initial_sample, + self.noise_predictions[-1], + ) img_next = sample - 0.75 * dt * drift_initial + 0.25 * dt * drift_previous @@ -863,9 +951,11 @@ def step_noise_2( return SchedulerOutput(prev_sample=img_next) elif self._step_index == 2: h = 0.5 * dt - + noise_derivative = (3 * self.noise_predictions[0] - 4 * self.noise_predictions[1] + model_output) / (2 * h) - noise_second_derivative = (self.noise_predictions[0] - 2 * self.noise_predictions[1] + model_output) / (h ** 2) + noise_second_derivative = (self.noise_predictions[0] - 2 * self.noise_predictions[1] + model_output) / ( + h**2 + ) noise_third_derivative = None model_output = self.initial_noise @@ -880,8 +970,12 @@ def step_noise_2( elif self._step_index == 3: h = 0.5 * dt - noise_derivative = (-3 * noise_initial + 4 * self.noise_predictions[-1] - self.noise_predictions[-2]) / (2 * h) - noise_second_derivative = (noise_initial - 2 * self.noise_predictions[-1] + self.noise_predictions[-2]) / (h ** 2) + noise_derivative = (-3 * noise_initial + 4 * self.noise_predictions[-1] - self.noise_predictions[-2]) / ( + 2 * h + ) + noise_second_derivative = (noise_initial - 2 * self.noise_predictions[-1] + self.noise_predictions[-2]) / ( + h**2 + ) noise_third_derivative = None self.noise_predictions.append(noise_initial) @@ -889,32 +983,46 @@ def step_noise_2( elif self._step_index == 4: h = dt - noise_derivative = (-3 * noise_initial + 4 * self.noise_predictions[-1] - self.noise_predictions[-2]) / (2 * h) - noise_second_derivative = (noise_initial - 2 * self.noise_predictions[-1] + self.noise_predictions[-2]) / (h ** 2) + noise_derivative = (-3 * noise_initial + 4 * self.noise_predictions[-1] - self.noise_predictions[-2]) / ( + 2 * h + ) + noise_second_derivative = (noise_initial - 2 * self.noise_predictions[-1] + self.noise_predictions[-2]) / ( + h**2 + ) noise_third_derivative = None - + self.noise_predictions.append(noise_initial) noise_approx_order = 2 else: # ALL ELSE h = dt - - noise_derivative = (2 * self.noise_predictions[-3] - 9 * self.noise_predictions[-2] + 18 * 
self.noise_predictions[-1] - 11 * noise_initial) / (6 * h) - noise_second_derivative = (-self.noise_predictions[-3] + 4 * self.noise_predictions[-2] -5 * self.noise_predictions[-1] + 2 * noise_initial) / (h**2) - noise_third_derivative = (self.noise_predictions[-3] - 3 * self.noise_predictions[-2] + 3 * self.noise_predictions[-1] - noise_initial) / (h**3) + + noise_derivative = ( + 2 * self.noise_predictions[-3] + - 9 * self.noise_predictions[-2] + + 18 * self.noise_predictions[-1] + - 11 * noise_initial + ) / (6 * h) + noise_second_derivative = ( + -self.noise_predictions[-3] + + 4 * self.noise_predictions[-2] + - 5 * self.noise_predictions[-1] + + 2 * noise_initial + ) / (h**2) + noise_third_derivative = ( + self.noise_predictions[-3] + - 3 * self.noise_predictions[-2] + + 3 * self.noise_predictions[-1] + - noise_initial + ) / (h**3) self.noise_predictions.append(noise_initial) noise_approx_order = 3 - Y_j_2 = sample Y_j_1 = sample Y_j = sample - ci1 = t_start - ci2 = t_start - ci3 = t_start - # Implementation of our Runge-Kutta-Gegenbauer second order method for j in range(1, self.s + 1): # Calculate the corresponding \bar{alpha}_t and beta_t that aligns with the correct timestep @@ -922,8 +1030,8 @@ def step_noise_2( if j == 2: fraction = 4 / (3 * (self.s**2 + self.s - 2)) else: - fraction = ((j - 1)**2 + (j - 1) - 2) / (self.s**2 + self.s - 2) - + fraction = ((j - 1) ** 2 + (j - 1) - 2) / (self.s**2 + self.s - 2) + if j == 1: mu_tilde = 6 / ((self.s + 4) * (self.s - 1)) dt = (t - t_next) * torch.ones(model_output.shape, device=sample.device) @@ -932,19 +1040,29 @@ def step_noise_2( mu = (2 * j + 1) * self.b_coeff(j) / (j * self.b_coeff(j - 1)) nu = -(j + 1) * self.b_coeff(j) / (j * self.b_coeff(j - 2)) mu_tilde = mu * 6 / ((self.s + 4) * (self.s - 1)) - gamma_tilde = -mu_tilde * (1 - j * (j + 1) * self.b_coeff(j-1)/ 2) - + gamma_tilde = -mu_tilde * (1 - j * (j + 1) * self.b_coeff(j - 1) / 2) # Probability flow ODE update diff = -fraction * (t - t_next) * torch.ones(model_output.shape, device=sample.device) - velocity = self.taylor_approximation(self.derivative_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative) - Y_j = mu * Y_j_1 + nu * Y_j_2 + (1 - mu - nu) * sample - dt * mu_tilde * velocity - dt * gamma_tilde * model_output - + velocity = self.taylor_approximation( + noise_approx_order, + diff, + model_output, + noise_derivative, + noise_second_derivative, + noise_third_derivative, + ) + Y_j = ( + mu * Y_j_1 + + nu * Y_j_2 + + (1 - mu - nu) * sample + - dt * mu_tilde * velocity + - dt * gamma_tilde * model_output + ) + Y_j_2 = Y_j_1 Y_j_1 = Y_j - - img_next = Y_j img_next = img_next.to(model_output.dtype) self._step_index += 1 @@ -953,7 +1071,6 @@ def step_noise_2( return (img_next,) return STORKSchedulerOutput(prev_sample=img_next) - def step_noise_4( self, model_output: torch.Tensor, @@ -961,7 +1078,7 @@ def step_noise_4( sample: torch.Tensor = None, return_dict: bool = False, ) -> torch.Tensor: - ''' + """ One step of the STORK4 update for noise-based diffusion models. Args: @@ -976,21 +1093,20 @@ def step_noise_4( Returns: `torch.FloatTensor`: The next sample in the diffusion chain. 
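A quick sanity check (illustrative, not part of the patch) on the backward-difference stencils used above: with index -1 the most recent prediction and `h` the spacing between consecutive predictions, all three formulas are exact for a cubic.

```python
import torch

h = 0.1
f = lambda t: torch.tensor(t) ** 3
preds = [f(3 * h), f(2 * h), f(1 * h)]  # older predictions, newest last
newest = f(0.0)                          # plays the role of noise_initial

d1 = (2 * preds[-3] - 9 * preds[-2] + 18 * preds[-1] - 11 * newest) / (6 * h)
d2 = (-preds[-3] + 4 * preds[-2] - 5 * preds[-1] + 2 * newest) / h**2
d3 = (preds[-3] - 3 * preds[-2] + 3 * preds[-1] - newest) / h**3
print(d1.item(), d2.item(), d3.item())  # ~0.0, ~0.0, ~6.0 for f(t) = t**3 at t = 0
```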
- ''' + """ # Initialize the step index if it's the first step if self._step_index is None: self._step_index = 0 self.initial_noise = model_output - total_step = self.config.num_train_timesteps t = self.timesteps[self._step_index] / total_step beta_0, beta_1 = self.betas[0], self.betas[-1] t_start = torch.ones(model_output.shape, device=sample.device) * t beta_t = (beta_0 + t_start * (beta_1 - beta_0)) * total_step - log_mean_coeff = (-0.25 * t_start ** 2 * (beta_1 - beta_0) - 0.5 * t_start * beta_0) * total_step - std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) + log_mean_coeff = (-0.25 * t_start**2 * (beta_1 - beta_0) - 0.5 * t_start * beta_0) * total_step + std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) # Tweedie's trick if self._step_index == len(self.timesteps) - 1: @@ -999,15 +1115,17 @@ def step_noise_4( if not return_dict: return (img_next,) return STORKSchedulerOutput(prev_sample=img_next) - + t_next = self.timesteps[self._step_index + 1] / total_step # drift, diffusion -> f(x,t), g(t) - drift_initial, diffusion_initial = -0.5 * beta_t * sample, torch.sqrt(beta_t) * torch.ones(sample.shape, device=sample.device) + drift_initial, diffusion_initial = ( + -0.5 * beta_t * sample, + torch.sqrt(beta_t) * torch.ones(sample.shape, device=sample.device), + ) noise_initial = model_output score = -noise_initial / std # score -> noise - drift_initial = drift_initial - diffusion_initial ** 2 * score * 0.5 # drift -> dx/dt - + drift_initial = drift_initial - diffusion_initial**2 * score * 0.5 # drift -> dx/dt dt = torch.ones(model_output.shape, device=sample.device) * self.dt @@ -1027,7 +1145,13 @@ def step_noise_4( elif self._step_index == 1: # SECOND RUN t_previous = torch.ones(model_output.shape, device=sample.device) * self.timesteps[0] / 1000 - drift_previous = self.drift_function(self.betas, self.config.num_train_timesteps, t_previous, self.initial_sample, self.noise_predictions[-1]) + drift_previous = self.drift_function( + self.betas, + self.config.num_train_timesteps, + t_previous, + self.initial_sample, + self.noise_predictions[-1], + ) img_next = sample - 0.75 * dt * drift_initial + 0.25 * dt * drift_previous @@ -1037,9 +1161,11 @@ def step_noise_4( return SchedulerOutput(prev_sample=img_next) elif self._step_index == 2: h = 0.5 * dt - + noise_derivative = (3 * self.noise_predictions[0] - 4 * self.noise_predictions[1] + model_output) / (2 * h) - noise_second_derivative = (self.noise_predictions[0] - 2 * self.noise_predictions[1] + model_output) / (h ** 2) + noise_second_derivative = (self.noise_predictions[0] - 2 * self.noise_predictions[1] + model_output) / ( + h**2 + ) noise_third_derivative = None model_output = self.initial_noise @@ -1054,8 +1180,12 @@ def step_noise_4( elif self._step_index == 3: h = 0.5 * dt - noise_derivative = (-3 * noise_initial + 4 * self.noise_predictions[-1] - self.noise_predictions[-2]) / (2 * h) - noise_second_derivative = (noise_initial - 2 * self.noise_predictions[-1] + self.noise_predictions[-2]) / (h ** 2) + noise_derivative = (-3 * noise_initial + 4 * self.noise_predictions[-1] - self.noise_predictions[-2]) / ( + 2 * h + ) + noise_second_derivative = (noise_initial - 2 * self.noise_predictions[-1] + self.noise_predictions[-2]) / ( + h**2 + ) noise_third_derivative = None self.noise_predictions.append(noise_initial) @@ -1063,24 +1193,42 @@ def step_noise_4( elif self._step_index == 4: h = dt - noise_derivative = (-3 * noise_initial + 4 * self.noise_predictions[-1] - self.noise_predictions[-2]) / (2 * h) - noise_second_derivative = 
(noise_initial - 2 * self.noise_predictions[-1] + self.noise_predictions[-2]) / (h ** 2) + noise_derivative = (-3 * noise_initial + 4 * self.noise_predictions[-1] - self.noise_predictions[-2]) / ( + 2 * h + ) + noise_second_derivative = (noise_initial - 2 * self.noise_predictions[-1] + self.noise_predictions[-2]) / ( + h**2 + ) noise_third_derivative = None - + self.noise_predictions.append(noise_initial) noise_approx_order = 2 else: # ALL ELSE h = dt - - noise_derivative = (2 * self.noise_predictions[-3] - 9 * self.noise_predictions[-2] + 18 * self.noise_predictions[-1] - 11 * noise_initial) / (6 * h) - noise_second_derivative = (-self.noise_predictions[-3] + 4 * self.noise_predictions[-2] -5 * self.noise_predictions[-1] + 2 * noise_initial) / (h**2) - noise_third_derivative = (self.noise_predictions[-3] - 3 * self.noise_predictions[-2] + 3 * self.noise_predictions[-1] - noise_initial) / (h**3) + + noise_derivative = ( + 2 * self.noise_predictions[-3] + - 9 * self.noise_predictions[-2] + + 18 * self.noise_predictions[-1] + - 11 * noise_initial + ) / (6 * h) + noise_second_derivative = ( + -self.noise_predictions[-3] + + 4 * self.noise_predictions[-2] + - 5 * self.noise_predictions[-1] + + 2 * noise_initial + ) / (h**2) + noise_third_derivative = ( + self.noise_predictions[-3] + - 3 * self.noise_predictions[-2] + + 3 * self.noise_predictions[-1] + - noise_initial + ) / (h**3) self.noise_predictions.append(noise_initial) noise_approx_order = 3 - Y_j_2 = sample Y_j_1 = sample Y_j = sample @@ -1096,11 +1244,10 @@ def step_noise_4( mz = int(mp[0]) mr = int(mp[1]) - ''' + """ The first part of the STORK4 update - ''' + """ for j in range(1, mdeg + 1): - # First sub-step in the first part of the STORK4 update if j == 1: temp1 = -(t - t_next) * recf[mr] * torch.ones(model_output.shape, device=sample.device) @@ -1111,11 +1258,22 @@ def step_noise_4( # Second and the following sub-steps in the first part of the STORK4 update else: diff = ci1 - t_start - noise_approx = self.taylor_approximation(noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative) - drift_approx = self.drift_function(self.betas, self.config.num_train_timesteps, ci1, Y_j_1, noise_approx) + noise_approx = self.taylor_approximation( + noise_approx_order, + diff, + model_output, + noise_derivative, + noise_second_derivative, + noise_third_derivative, + ) + drift_approx = self.drift_function( + self.betas, self.config.num_train_timesteps, ci1, Y_j_1, noise_approx + ) - temp1 = -(t - t_next) * recf[mr + 2 * (j-2) + 1] * torch.ones(model_output.shape, device=sample.device) - temp3 = -recf[mr + 2 * (j-2) + 2] * torch.ones(model_output.shape, device=sample.device) + temp1 = ( + -(t - t_next) * recf[mr + 2 * (j - 2) + 1] * torch.ones(model_output.shape, device=sample.device) + ) + temp3 = -recf[mr + 2 * (j - 2) + 2] * torch.ones(model_output.shape, device=sample.device) temp2 = torch.ones(model_output.shape, device=sample.device) - temp3 ci1 = temp1 + temp2 * ci2 + temp3 * ci3 @@ -1128,62 +1286,64 @@ def step_noise_4( ci3 = ci2 ci2 = ci1 - ''' + """ The finishing four-step procedure as a composition method - ''' + """ # First finishing step - temp1 = -(t - t_next) * fpa[mz,0] * torch.ones(model_output.shape, device=sample.device) + temp1 = -(t - t_next) * fpa[mz, 0] * torch.ones(model_output.shape, device=sample.device) diff = ci1 - t_start - noise_approx = self.taylor_approximation(noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative) + 
noise_approx = self.taylor_approximation( + noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative + ) drift_approx = self.drift_function(self.betas, self.config.num_train_timesteps, ci1, Y_j, noise_approx) Y_j_1 = drift_approx Y_j_3 = Y_j + temp1 * Y_j_1 # Second finishing step ci2 = ci1 + temp1 - temp1 = -(t - t_next) * fpa[mz,1] * torch.ones(model_output.shape, device=sample.device) - temp2 = -(t - t_next) * fpa[mz,2] * torch.ones(model_output.shape, device=sample.device) + temp1 = -(t - t_next) * fpa[mz, 1] * torch.ones(model_output.shape, device=sample.device) + temp2 = -(t - t_next) * fpa[mz, 2] * torch.ones(model_output.shape, device=sample.device) diff = ci2 - t_start - noise_approx = self.taylor_approximation(noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative) + noise_approx = self.taylor_approximation( + noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative + ) drift_approx = self.drift_function(self.betas, self.config.num_train_timesteps, ci2, Y_j_3, noise_approx) Y_j_2 = drift_approx Y_j_4 = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2 # Third finishing step ci2 = ci1 + temp1 + temp2 - temp1 = -(t - t_next) * fpa[mz,3] * torch.ones(model_output.shape, device=sample.device) - temp2 = -(t - t_next) * fpa[mz,4] * torch.ones(model_output.shape, device=sample.device) - temp3 = -(t - t_next) * fpa[mz,5] * torch.ones(model_output.shape, device=sample.device) + temp1 = -(t - t_next) * fpa[mz, 3] * torch.ones(model_output.shape, device=sample.device) + temp2 = -(t - t_next) * fpa[mz, 4] * torch.ones(model_output.shape, device=sample.device) + temp3 = -(t - t_next) * fpa[mz, 5] * torch.ones(model_output.shape, device=sample.device) diff = ci2 - t_start - noise_approx = self.taylor_approximation(noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative) + noise_approx = self.taylor_approximation( + noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative + ) drift_approx = self.drift_function(self.betas, self.config.num_train_timesteps, ci2, Y_j_4, noise_approx) Y_j_3 = drift_approx fnt = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2 + temp3 * Y_j_3 # Fourth finishing step ci2 = ci1 + temp1 + temp2 + temp3 - temp1 = -(t - t_next) * fpb[mz,0] * torch.ones(model_output.shape, device=sample.device) - temp2 = -(t - t_next) * fpb[mz,1] * torch.ones(model_output.shape, device=sample.device) - temp3 = -(t - t_next) * fpb[mz,2] * torch.ones(model_output.shape, device=sample.device) - temp4 = -(t - t_next) * fpb[mz,3] * torch.ones(model_output.shape, device=sample.device) + temp1 = -(t - t_next) * fpb[mz, 0] * torch.ones(model_output.shape, device=sample.device) + temp2 = -(t - t_next) * fpb[mz, 1] * torch.ones(model_output.shape, device=sample.device) + temp3 = -(t - t_next) * fpb[mz, 2] * torch.ones(model_output.shape, device=sample.device) + temp4 = -(t - t_next) * fpb[mz, 3] * torch.ones(model_output.shape, device=sample.device) diff = ci2 - t_start - noise_approx = self.taylor_approximation(noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative) + noise_approx = self.taylor_approximation( + noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative + ) drift_approx = self.drift_function(self.betas, self.config.num_train_timesteps, ci2, fnt, noise_approx) Y_j_4 = 
drift_approx Y_j = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2 + temp3 * Y_j_3 + temp4 * Y_j_4 - - img_next = Y_j self._step_index += 1 if not return_dict: return (img_next,) return STORKSchedulerOutput(prev_sample=img_next) - - - - def startup_phase_flow_matching( self, @@ -1191,7 +1351,7 @@ def startup_phase_flow_matching( sample: torch.Tensor = None, return_dict: bool = True, ) -> torch.Tensor: - ''' + """ Startup phase for the STORK2 and STORK4 update for flow matching based models. Args: @@ -1204,30 +1364,35 @@ def startup_phase_flow_matching( Returns: result (Union[Tuple, STORKSchedulerOutput]): - The next sample in the diffusion chain, either as a tuple or as a [`~schedulers.STORKSchedulerOutput`]. The value is converted back to the original dtype of `model_output` to avoid numerical issues. - ''' + The next sample in the diffusion chain, either as a tuple or as a [`~schedulers.STORKSchedulerOutput`]. + The value is converted back to the original dtype of `model_output` to avoid numerical issues. + """ dt = self.dt_list[self._step_index] dt = torch.ones(model_output.shape, device=sample.device) * dt - + if self._step_index == 0: # Perform Euler's method for a half step img_next = sample - 0.5 * dt * model_output - self.velocity_predictions.append(model_output) + self.velocity_predictions.append(model_output) elif self._step_index == 1: # Perform Heun's method for a half step img_next = sample - 0.75 * dt * model_output + 0.25 * dt * self.velocity_predictions[-1] elif self._step_index == 2 or (self._step_index == 3 and self.derivative_order == 3): - dt_previous = self.dt_list[self._step_index-1] + dt_previous = self.dt_list[self._step_index - 1] dt_previous = torch.ones(model_output.shape, device=sample.device) * dt_previous - img_next = sample + (dt**2 / (2 * (-dt_previous)) - dt) * model_output + (dt**2 / (2 * dt_previous)) * self.velocity_predictions[-1] + img_next = ( + sample + + (dt**2 / (2 * (-dt_previous)) - dt) * model_output + + (dt**2 / (2 * dt_previous)) * self.velocity_predictions[-1] + ) self.velocity_predictions.append(model_output) else: raise NotImplementedError( f"Startup phase for step {self._step_index} is not implemented. Please check the implementation." ) - + self._step_index += 1 - + if not return_dict: return (img_next,) return STORKSchedulerOutput(prev_sample=img_next) @@ -1238,8 +1403,8 @@ def startup_phase_noise( drift: torch.Tensor, sample: torch.Tensor = None, return_dict: bool = False, - ) -> torch.Tensor: - ''' + ) -> torch.Tensor: + """ Startup phase for the STORK2 and STORK4 update for noise-based diffusion models. Args: @@ -1254,8 +1419,9 @@ def startup_phase_noise( Returns: result (Union[Tuple, STORKSchedulerOutput]): - The next sample in the diffusion chain, either as a tuple or as a [`~schedulers.STORKSchedulerOutput`]. The value is converted back to the original dtype of `model_output` to avoid numerical issues. - ''' + The next sample in the diffusion chain, either as a tuple or as a [`~schedulers.STORKSchedulerOutput`]. + The value is converted back to the original dtype of `model_output` to avoid numerical issues.
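The startup logic above exists because the main STORK update consumes a history of predictions that is still empty at the first steps. A toy sketch (not from the patch) of the same Euler and Heun-style half steps:

```python
import torch

sample = torch.tensor(1.0)
dt = torch.tensor(0.1)
velocity = lambda x: x  # toy velocity field in place of the network

v0 = velocity(sample)
sample = sample - 0.5 * dt * v0                    # step 0: Euler half step
v1 = velocity(sample)
sample = sample - 0.75 * dt * v1 + 0.25 * dt * v0  # step 1: Heun-style half step
print(sample)
```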
+ """ dt = torch.ones(model_output.shape, device=sample.device) * self.dt if self._step_index == 0: # Perfrom Euler's method for a half step @@ -1282,13 +1448,10 @@ def startup_phase_noise( return STORKSchedulerOutput(prev_sample=img_next) else: raise ValueError("Startup phase is only supported for the first two steps.") - - - def __len__(self): return self.config.num_train_timesteps - + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the @@ -1303,7 +1466,7 @@ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tens A scaled input sample. """ return sample - + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): if self.config.time_shift_type == "exponential": return self._time_shift_exponential(mu, sigma, t) @@ -1315,8 +1478,10 @@ def _time_shift_exponential(self, mu, sigma, t): def _time_shift_linear(self, mu, sigma, t): return mu / (mu + (1 / t - 1) ** sigma) - - def taylor_approximation(self, taylor_approx_order, diff, model_output, derivative, second_derivative, third_derivative=None): + + def taylor_approximation( + self, taylor_approx_order, diff, model_output, derivative, second_derivative, third_derivative=None + ): if taylor_approx_order == 2: if third_derivative is not None: raise ValueError("The third derivative is computed but not used!") @@ -1324,17 +1489,17 @@ def taylor_approximation(self, taylor_approx_order, diff, model_output, derivati elif taylor_approx_order == 3: if third_derivative is None: raise ValueError("The third derivative is not computed!") - approx_value = model_output + diff * derivative + 0.5 * diff**2 * second_derivative \ - + diff**3 * third_derivative / 6 + approx_value = ( + model_output + diff * derivative + 0.5 * diff**2 * second_derivative + diff**3 * third_derivative / 6 + ) else: print("The noise approximation order is not supported!") exit() return approx_value - def drift_function(self, betas, total_step, t_eval, y_eval, noise): - ''' + """ Drift function for the probability flow ODE in the noise-based diffusion model. Args: @@ -1352,25 +1517,25 @@ def drift_function(self, betas, total_step, t_eval, y_eval, noise): Returns: `torch.FloatTensor`: The drift term for the probability flow ODE in the diffusion model. - ''' + """ beta_0, beta_1 = betas[0], betas[-1] beta_t = (beta_0 + t_eval * (beta_1 - beta_0)) * total_step beta_t = beta_t * torch.ones(y_eval.shape, device=y_eval.device) - log_mean_coeff = (-0.25 * t_eval ** 2 * (beta_1 - beta_0) - 0.5 * t_eval * beta_0) * total_step - std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) + log_mean_coeff = (-0.25 * t_eval**2 * (beta_1 - beta_0) - 0.5 * t_eval * beta_0) * total_step + std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) # drift, diffusion -> f(x,t), g(t) drift, diffusion = -0.5 * beta_t * y_eval, torch.sqrt(beta_t) * torch.ones(y_eval.shape, device=y_eval.device) score = -noise / std # score -> noise - drift = drift - diffusion ** 2 * score * 0.5 # drift -> dx/dt + drift = drift - diffusion**2 * score * 0.5 # drift -> dx/dt return drift def b_coeff(self, j): - ''' - Coefficients of STORK2. The are based on the second order Runge-Kutta-Gegenbauer method. - Details of the coefficients can be found in https://www.sciencedirect.com/science/article/pii/S0021999120306537 + """ + Coefficients of STORK2. The are based on the second order Runge-Kutta-Gegenbauer method. 
Details of the + coefficients can be found in https://www.sciencedirect.com/science/article/pii/S0021999120306537 Args: j (`int`): @@ -1379,7 +1544,7 @@ def b_coeff(self, j): Returns: `float`: The coefficient of the STORK2. - ''' + """ if j < 0: print("The b_j coefficient in the RKG method can't have j negative") return @@ -1387,14 +1552,15 @@ def b_coeff(self, j): return 1 if j == 1: return 1 / 3 - + return 4 * (j - 1) * (j + 4) / (3 * j * (j + 1) * (j + 2) * (j + 3)) def coeff_stork4(self): - ''' - Load pre-computed coefficients of STORK4. The are based on the fourth order orthogonal Runge-Kutta-Chebyshev (ROCK4) method. - Details of the coefficients can be found in https://epubs.siam.org/doi/abs/10.1137/S1064827500379549. - The pre-computed coefficients are based on the implementation https://www.mathworks.com/matlabcentral/fileexchange/12129-rock4. + """ + Load pre-computed coefficients of STORK4. They are based on the fourth order orthogonal Runge-Kutta-Chebyshev + (ROCK4) method. Details of the coefficients can be found in + https://epubs.siam.org/doi/abs/10.1137/S1064827500379549. The pre-computed coefficients are based on the + implementation https://www.mathworks.com/matlabcentral/fileexchange/12129-rock4. Args: j (`int`): @@ -1405,31 +1571,29 @@ def coeff_stork4(self): The degrees that coefficients were pre-computed for STORK4. fpa, fpb, fpbe, recf (`torch.FloatTensor`): The parameters for the finishing procedure. - ''' + """ # Degrees - data = loadmat(f'{CONSTANTSFOLDER}/ms.mat') - ms = data['ms'][0] + data = loadmat(f"{CONSTANTSFOLDER}/ms.mat") + ms = data["ms"][0] # Parameters for the finishing procedure - data = loadmat(f'{CONSTANTSFOLDER}/fpa.mat') - fpa = data['fpa'] + data = loadmat(f"{CONSTANTSFOLDER}/fpa.mat") + fpa = data["fpa"] - data = loadmat(f'{CONSTANTSFOLDER}/fpb.mat') - fpb = data['fpb'] + data = loadmat(f"{CONSTANTSFOLDER}/fpb.mat") + fpb = data["fpb"] - data = loadmat(f'{CONSTANTSFOLDER}/fpbe.mat') - fpbe = data['fpbe'] + data = loadmat(f"{CONSTANTSFOLDER}/fpbe.mat") + fpbe = data["fpbe"] # Parameters for the recurrence procedure - data = loadmat(f'{CONSTANTSFOLDER}/recf.mat') - recf = data['recf'][0] + data = loadmat(f"{CONSTANTSFOLDER}/recf.mat") + recf = data["recf"][0] return ms, fpa, fpb, fpbe, recf - - def mdegr(self, mdeg1, ms): - ''' + """ Find the optimal degree in the pre-computed degree coefficients table for the STORK4 method. Args: @@ -1442,37 +1606,20 @@ def mdegr(self, mdeg1, ms): mdeg (`int`): The optimal degree in the pre-computed degree coefficients table for the STORK4 method. mp (`torch.FloatTensor`): - The pointer which select the degree in ms[i], such that mdeg<=ms[i]. - mp[0] (`int`): The pointer which select the degree in ms[i], such that mdeg<=ms[i]. - mp[1] (`int`): The pointer which gives the corresponding position of a_1 in the data recf for the selected degree. - ''' + The pointer which selects the degree in ms[i], such that mdeg<=ms[i]. mp[0] (`int`): The pointer which + selects the degree in ms[i], such that mdeg<=ms[i]. mp[1] (`int`): The pointer which gives the + corresponding position of a_1 in the data recf for the selected degree.
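A small walk-through (table values invented) of the degree lookup implemented below: pick the first tabulated degree `ms[i]` at least as large as the requested one, while accumulating an offset into the flattened recurrence-coefficient array.

```python
import torch

ms = [3, 5, 9, 13]  # stand-in for the precomputed degree table
mdeg1 = 7           # requested stage count

mp = torch.zeros(2)
mp[1] = 1
mdeg = mdeg1
for i in range(len(ms)):
    if ms[i] / mdeg >= 1:
        mdeg = ms[i]  # first tabulated degree >= 7, i.e. 9
        mp[0] = i
        mp[1] = mp[1] - 1
        break
    else:
        mp[1] = mp[1] + ms[i] * 2 - 1
print(mdeg, mp)  # 9, tensor([ 2., 14.])
```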
+ """ mp = torch.zeros(2) mp[1] = 1 mdeg = mdeg1 for i in range(len(ms)): - if (ms[i]/mdeg) >= 1: + if (ms[i] / mdeg) >= 1: mdeg = ms[i] mp[0] = i mp[1] = mp[1] - 1 break - else: + else: mp[1] = mp[1] + ms[i] * 2 - 1 return mdeg, mp - - - - def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - sample (`torch.Tensor`): - The input sample. - - Returns: - `torch.Tensor`: - A scaled input sample. - """ - return sample diff --git a/tests/schedulers/test_scheduler_stork.py b/tests/schedulers/test_scheduler_stork.py index f7a3133da310..fc96d5598fb3 100644 --- a/tests/schedulers/test_scheduler_stork.py +++ b/tests/schedulers/test_scheduler_stork.py @@ -23,7 +23,6 @@ def get_scheduler_config(self, **kwargs): config.update(**kwargs) return config - def check_over_configs(self, time_step=0, **config): kwargs = dict(self.forward_default_kwargs) num_inference_steps = kwargs.pop("num_inference_steps", None) @@ -31,7 +30,6 @@ def check_over_configs(self, time_step=0, **config): residual = 0.1 * sample for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config(**config, prediction_type="epsilon") scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(num_inference_steps) @@ -51,8 +49,6 @@ def check_over_configs(self, time_step=0, **config): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "STORK4 noise scheduler outputs are not identical" - - scheduler_config = self.get_scheduler_config(**config, prediction_type="flow_prediction") scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(num_inference_steps) @@ -65,12 +61,16 @@ def check_over_configs(self, time_step=0, **config): output = scheduler.step_flow_matching_2(residual, time_step, sample, return_dict=True).prev_sample new_output = new_scheduler.step_flow_matching_2(residual, time_step, sample, return_dict=True).prev_sample - assert torch.sum(torch.abs(output - new_output)) < 1e-5, "STORK2 flow matching scheduler outputs are not identical" + assert torch.sum(torch.abs(output - new_output)) < 1e-5, ( + "STORK2 flow matching scheduler outputs are not identical" + ) output = scheduler.step_flow_matching_4(residual, time_step, sample, return_dict=True).prev_sample new_output = new_scheduler.step_flow_matching_4(residual, time_step, sample, return_dict=True).prev_sample - assert torch.sum(torch.abs(output - new_output)) < 1e-5, "STORK4 flow matching scheduler outputs are not identical" + assert torch.sum(torch.abs(output - new_output)) < 1e-5, ( + "STORK4 flow matching scheduler outputs are not identical" + ) @unittest.skip("Test not supported.") def test_from_save_pretrained(self): @@ -114,7 +114,6 @@ def check_over_forward(self, time_step=0, **forward_kwargs): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "STORK4 noise scheduler outputs are not identical" - scheduler_config = self.get_scheduler_config(prediction_type="flow_prediction") scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(num_inference_steps) @@ -134,13 +133,16 @@ def check_over_forward(self, time_step=0, **forward_kwargs): output = scheduler.step_flow_matching_2(residual, time_step, sample, return_dict=True).prev_sample new_output = new_scheduler.step_flow_matching_2(residual, time_step, sample, return_dict=True).prev_sample - assert torch.sum(torch.abs(output - new_output)) < 1e-5, "STORK2 flow matching 
scheduler outputs are not identical" + assert torch.sum(torch.abs(output - new_output)) < 1e-5, ( + "STORK2 flow matching scheduler outputs are not identical" + ) output = scheduler.step_flow_matching_4(residual, time_step, sample, return_dict=True).prev_sample new_output = new_scheduler.step_flow_matching_4(residual, time_step, sample, return_dict=True).prev_sample - assert torch.sum(torch.abs(output - new_output)) < 1e-5, "STORK4 flow matching scheduler outputs are not identical" - + assert torch.sum(torch.abs(output - new_output)) < 1e-5, ( + "STORK4 flow matching scheduler outputs are not identical" + ) def test_timesteps(self): for timesteps in [100, 1000]: @@ -161,12 +163,28 @@ def test_steps_offset(self): ) # Test for flow matching based models - scheduler_config = self.get_scheduler_config(prediction_type="flow_prediction", shift=3.0, time_shift_type="exponential") + scheduler_config = self.get_scheduler_config( + prediction_type="flow_prediction", shift=3.0, time_shift_type="exponential" + ) scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(10) assert torch.allclose( scheduler.timesteps, - torch.Tensor([1000.0000, 980.0647, 960.1293, 913.3490, 857.6923, 790.3683, 707.2785, 602.1506, 464.8760, 278.0488, 8.9286]), + torch.Tensor( + [ + 1000.0000, + 980.0647, + 960.1293, + 913.3490, + 857.6923, + 790.3683, + 707.2785, + 602.1506, + 464.8760, + 278.0488, + 8.9286, + ] + ), ) def test_betas(self): @@ -177,7 +195,6 @@ def test_schedules(self): for schedule in ["linear", "scaled_linear"]: self.check_over_configs(beta_schedule=schedule) - def test_time_indices(self): for t in [1, 5, 10]: self.check_over_forward(time_step=t) From edf036e46d28ed3f1fa0734934da9dc27e531718 Mon Sep 17 00:00:00 2001 From: ZT220501 Date: Sat, 5 Jul 2025 19:07:37 -0700 Subject: [PATCH 5/6] Add doc --- docs/source/en/api/schedulers/stork.md | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 docs/source/en/api/schedulers/stork.md diff --git a/docs/source/en/api/schedulers/stork.md b/docs/source/en/api/schedulers/stork.md new file mode 100644 index 000000000000..8c7597b34515 --- /dev/null +++ b/docs/source/en/api/schedulers/stork.md @@ -0,0 +1,23 @@ + + +# STORKScheduler +`STORKScheduler` is the sampling method from the paper [STORK: Improving the Fidelity of Mid-NFE Sampling for Diffusion and Flow Matching Models](https://arxiv.org/abs/2505.24210) by [Zheng Tan](https://zt220501.github.io/), [Weizhen Wang](https://weizhenwang-1210.github.io/), [Andrea L. Bertozzi](https://www.math.ucla.edu/~bertozzi/), and [Ernest K. Ryu](https://ernestryu.com/). It is motivated by stabilized Runge-Kutta methods, which it adapts to diffusion and flow matching models through Taylor expansion.
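A short usage sketch could round out this doc page. Everything below is illustrative: the checkpoint id is a placeholder, and the snippet assumes `STORKScheduler` is exported from `diffusers` as in this PR.

```python
import torch

from diffusers import DiffusionPipeline, STORKScheduler

# Placeholder checkpoint id; any pipeline whose scheduler config is compatible
# with STORK (diffusion or flow matching) would do.
pipe = DiffusionPipeline.from_pretrained("org/model-id", torch_dtype=torch.float16)
pipe.scheduler = STORKScheduler.from_config(pipe.scheduler.config, solver_order=2)
image = pipe("a photo of an astronaut", num_inference_steps=30).images[0]
```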
+ +-------------------- + +## STORKScheduler +[[autodoc]] STORKScheduler + +## SchedulerOutput +[[autodoc]] schedulers.scheduling_utils.SchedulerOutput + From 7d0db8c05a38c46e47b5b7ae4ef0d75c67dc1d40 Mon Sep 17 00:00:00 2001 From: weizhen Date: Sat, 5 Jul 2025 19:14:49 -0700 Subject: [PATCH 6/6] add doc markdown for STORK --- docs/source/en/_toctree.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 770093438ed5..6b5bab5ad93d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -30,7 +30,8 @@ - local: using-diffusers/push_to_hub title: Push files to the Hub title: Load pipelines and adapters -- sections: +- isExpanded: false + sections: - local: tutorials/using_peft_for_inference title: LoRA - local: using-diffusers/ip_adapter @@ -44,7 +45,6 @@ - local: using-diffusers/textual_inversion_inference title: Textual inversion title: Adapters - isExpanded: false - sections: - local: using-diffusers/unconditional_image_generation title: Unconditional image generation @@ -652,6 +652,8 @@ title: ScoreSdeVeScheduler - local: api/schedulers/score_sde_vp title: ScoreSdeVpScheduler + - local: api/schedulers/stork + title: STORKScheduler - local: api/schedulers/tcd title: TCDScheduler - local: api/schedulers/unipc