3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -63,6 +63,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
warnings are raised when attempting to use Apex group norm on CPU.
- Diffusion utils: systematic compilation of patching operations in `stochastic_sampler`
for improved performance.
- Diffusion utils: patch-based inference and lead time support with deterministic sampler
- CorrDiff example: added option for Student-t EDM (t-EDM) in `train.py` and
`generate.py`. When training a CorrDiff diffusion model, this feature can be
enabled with the hydra overrides `++training.hp.distribution=student_t` and
@@ -74,6 +75,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
overrides `++training.hp.P_mean=<P_mean_value>` and
`++training.hp.P_std=<P_std_value>` for training (and similar ones with
`training.hp` replaced by `generation` for generation).
- Diffusion utils: patch-based inference and lead time support with
deterministic sampler.

### Deprecated

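Aside (not part of the diff): the changelog entries above enable t-EDM and adjust the noise-level distribution purely through Hydra overrides. A minimal sketch of passing the same overrides programmatically through Hydra's compose API is shown below; the config path `conf` and config name `config_training` are assumptions, and the `P_mean`/`P_std` values are placeholders rather than recommended settings.

```python
# Hypothetical sketch: composing a CorrDiff training config with the overrides
# mentioned in the changelog entry above. "conf" and "config_training" are
# assumed names; the numeric values are placeholders.
from hydra import compose, initialize

with initialize(version_base=None, config_path="conf"):
    cfg = compose(
        config_name="config_training",
        overrides=[
            "++training.hp.distribution=student_t",  # enable Student-t EDM (t-EDM)
            "++training.hp.P_mean=-1.2",             # placeholder value
            "++training.hp.P_std=1.2",               # placeholder value
        ],
    )
```

For generation, the changelog notes that the analogous overrides use `generation` in place of `training.hp`.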
5 changes: 1 addition & 4 deletions examples/weather/corrdiff/generate.py
@@ -184,15 +184,12 @@ def main(cfg: DictConfig) -> None:

# Partially instantiate the sampler based on the configs
if cfg.sampler.type == "deterministic":
if cfg.generation.hr_mean_conditioning:
raise NotImplementedError(
"High-res mean conditioning is not yet implemented for the deterministic sampler"
)
sampler_fn = partial(
deterministic_sampler,
num_steps=cfg.sampler.num_steps,
# num_ensembles=cfg.generation.num_ensembles,
solver=cfg.sampler.solver,
patching=patching,
)
elif cfg.sampler.type == "stochastic":
sampler_fn = partial(stochastic_sampler, patching=patching)
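For context, the `generate.py` change above now forwards a `patching` object directly into the partially-instantiated deterministic sampler (the previous `NotImplementedError` guard for high-res mean conditioning is removed). A minimal sketch of that wiring follows; the `GridPatching2D` constructor arguments and all sizes are assumptions for illustration, and the `deterministic_sampler` import path is inferred from the file path in this diff.

```python
from functools import partial

from physicsnemo.utils.diffusion import deterministic_sampler
from physicsnemo.utils.patching import GridPatching2D

# Assumed image and patch sizes -- placeholders, not values from this PR.
patching = GridPatching2D(img_shape=(448, 448), patch_shape=(112, 112))

# Mirrors the partial instantiation in generate.py above.
sampler_fn = partial(
    deterministic_sampler,
    num_steps=18,     # cfg.sampler.num_steps
    solver="heun",    # cfg.sampler.solver
    patching=patching,
)

# The sampler is later invoked with the network, latents, and low-res conditioning:
# images = sampler_fn(net=net, latents=latents, img_lr=img_lr)
```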
194 changes: 180 additions & 14 deletions physicsnemo/utils/diffusion/deterministic_sampler.py
@@ -21,17 +21,73 @@
import torch

from physicsnemo.models.diffusion import EDMPrecond
from physicsnemo.utils.patching import GridPatching2D

# ruff: noqa: E731


# NOTE: use separate compiled wrappers for apply (C_in channels, and C_out channels with/without grad) to avoid recompilation when the input shape or gradient requirement changes
@torch.compile()
def _apply_wrapper_Cin_channels(patching, input, additional_input=None):
"""
Apply the patching operation to the input tensor with :math:`C_{in}` channels.
"""
return patching.apply(input=input, additional_input=additional_input)


@torch.compile()
def _apply_wrapper_Cout_channels_no_grad(patching, input, additional_input=None):
"""
Apply the patching operation to an input tensor with :math:`C_{out}`
channels that does not require gradients.
"""
return patching.apply(input=input, additional_input=additional_input)


@torch.compile()
def _apply_wrapper_Cout_channels_grad(patching, input, additional_input=None):
"""
Apply the patching operation to an input tensor with :math:`C_{out}`
channels that requires gradients.
"""
return patching.apply(input=input, additional_input=additional_input)


@torch.compile()
def _fuse_wrapper(patching, input, batch_size):
return patching.fuse(input=input, batch_size=batch_size)


def _apply_wrapper_select(
input: torch.Tensor, patching: GridPatching2D | None
) -> Callable:
"""
Select the patching wrapper based on ``patching`` and on the input tensor's
``requires_grad`` attribute: if ``patching`` is None, return an identity
function; otherwise return ``_apply_wrapper_Cout_channels_grad`` when
``input.requires_grad`` is True and ``_apply_wrapper_Cout_channels_no_grad``
when it is False.
"""
if patching:
if input.requires_grad:
return _apply_wrapper_Cout_channels_grad
else:
return _apply_wrapper_Cout_channels_no_grad
else:
return lambda patching, input, additional_input=None: input


@nvtx.annotate(message="deterministic_sampler", color="red")
def deterministic_sampler(
net: torch.nn.Module,
latents: torch.Tensor,
img_lr: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
randn_like: Callable = torch.randn_like,
patching: Optional[GridPatching2D] = None,
mean_hr: Optional[torch.Tensor] = None,
lead_time_label: Optional[torch.Tensor] = None,
num_steps: int = 18,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
@@ -49,6 +105,7 @@ def deterministic_sampler(
S_min: float = 0.0,
S_max: float = float("inf"),
S_noise: float = 1.0,
dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
r"""
Generalized sampler, representing the superset of all sampling methods
@@ -81,6 +138,22 @@
during the stochastic sampling. Must have the same signature as
torch.randn_like and return torch.Tensor. Defaults to
torch.randn_like.
patching : Optional[GridPatching2D], default=None
A patching utility for patch-based diffusion. Implements methods to
extract patches from an image and batch the patches along dim=0.
Should also implement a ``fuse`` method to reconstruct the original
image from a batch of patches. See
:class:`~physicsnemo.utils.patching.GridPatching2D` for details. By
default ``None``, in which case non-patched diffusion is used.
mean_hr : Optional[Tensor], optional
Optional tensor containing mean high-resolution images for
conditioning. Must have the same height and width as ``img_lr``, with shape
:math:`(B_{hr}, C_{hr}, H, W)`, where the batch dimension
:math:`B_{hr}` can either be 1, be equal to batch_size, or be omitted. If
:math:`B_{hr} = 1` or is omitted, ``mean_hr`` is expanded to match the shape
of ``img_lr``. By default ``None``.
lead_time_label : Optional[Tensor], optional
Optional lead time labels. By default None.
num_steps : Optional[int]
Number of time-steps for the stochastic ODE integration. Defaults
to 18.
Expand Down Expand Up @@ -157,15 +230,44 @@ def deterministic_sampler(
stochastic sampler. Added signal noise is proportional to
:math:`\epsilon_i` where :math:`\epsilon_i \sim \mathcal{N}(0, S_{noise}^2)`. Defaults
to 1.0.

dtype : torch.dtype, optional
Controls the floating-point precision used for sampling. By default
``torch.float64``.

Returns
-------
torch.Tensor:
Generated batch of samples. Same shape as the input ``latents``.
"""

# conditioning
# conditioning = [mean_hr, img_lr, global_lr]
x_lr = img_lr
if mean_hr is not None:
if mean_hr.shape[-2:] != img_lr.shape[-2:]:
raise ValueError(
f"mean_hr and img_lr must have the same height and width, "
f"but found {mean_hr.shape[-2:]} vs {img_lr.shape[-2:]}."
)
x_lr = torch.cat((mean_hr.expand(x_lr.shape[0], -1, -1, -1), x_lr), dim=1)

# Safety check on type of patching
if patching is not None and not isinstance(patching, GridPatching2D):
raise ValueError("patching must be an instance of GridPatching2D.")

# Safety check: if patching is used then img_lr and latents must have same
# height and width, otherwise there is mismatch in the number
# of patches extracted to form the final batch_size.
if patching:
if img_lr.shape[-2:] != latents.shape[-2:]:
raise ValueError(
f"img_lr and latents must have the same height and width, "
f"but found {img_lr.shape[-2:]} vs {latents.shape[-2:]}. "
)
# img_lr and latents must also have the same batch_size, otherwise mismatch
# when processed by the network
if img_lr.shape[0] != latents.shape[0]:
raise ValueError(
f"img_lr and latents must have the same batch size, but found "
f"{img_lr.shape[0]} vs {latents.shape[0]}."
)

if solver not in ["euler", "heun"]:
raise ValueError(f"Unknown solver {solver}")
Expand Down Expand Up @@ -213,6 +315,24 @@ def deterministic_sampler(
sigma_min = max(sigma_min, net.sigma_min)
sigma_max = min(sigma_max, net.sigma_max)

batch_size = img_lr.shape[0]
# input and position padding + patching
if patching:
# Patched conditioning [x_lr, mean_hr]
# (batch_size * patch_num, C_in + C_out, patch_shape_y, patch_shape_x)
x_lr = _apply_wrapper_Cin_channels(
patching=patching, input=x_lr, additional_input=img_lr
)

# Function to select the correct positional embedding for each patch
def patch_embedding_selector(emb):
# emb: (N_pe, image_shape_y, image_shape_x)
# return: (batch_size * patch_num, N_pe, patch_shape_y, patch_shape_x)
return patching.apply(emb[None].expand(batch_size, -1, -1, -1))

else:
patch_embedding_selector = None

# Compute corresponding betas for VP.
vp_beta_d = (
2
@@ -222,7 +342,7 @@
vp_beta_min = np.log(sigma_max**2 + 1) - 0.5 * vp_beta_d

# Define time steps in terms of noise level.
step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
step_indices = torch.arange(num_steps, dtype=dtype, device=latents.device)
if discretization == "vp":
orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
@@ -232,7 +352,7 @@
)
sigma_steps = ve_sigma(orig_t_steps)
elif discretization == "iddpm":
u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
u = torch.zeros(M + 1, dtype=dtype, device=latents.device)
alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1
u[j - 1] = (
@@ -280,7 +400,14 @@

# Main sampling loop.
t_next = t_steps[0]
x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
x_next = latents.to(dtype) * (sigma(t_next) * s(t_next))

optional_args = {}
if lead_time_label is not None:
optional_args["lead_time_label"] = lead_time_label
if patching:
optional_args["embedding_selector"] = patch_embedding_selector

for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
x_cur = x_next

@@ -295,20 +422,40 @@
sigma(t_hat) ** 2 - sigma(t_cur) ** 2
).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur)

# Euler step.
# Euler step. If patch-based generation is used, extract patches from the
# denoiser input before calling the network.

h = t_next - t_hat
x_hat_batch = _apply_wrapper_select(input=x_hat, patching=patching)(
patching=patching, input=x_hat
).to(latents.device)

if isinstance(net, EDMPrecond):
# Conditioning info is passed as keyword arg
denoised = net(
x_hat / s(t_hat),
x_hat_batch / s(t_hat),
sigma(t_hat),
condition=x_lr,
class_labels=class_labels,
).to(torch.float64)
**optional_args,
).to(dtype)
else:
denoised = net(x_hat / s(t_hat), x_lr, sigma(t_hat), class_labels).to(
torch.float64
denoised = net(
x_hat_batch / s(t_hat),
x_lr,
sigma(t_hat),
class_labels,
**optional_args,
).to(dtype)

if patching:
# Un-patch the denoised image
# (batch_size, C_out, img_shape_y, img_shape_x)
denoised = _fuse_wrapper(
patching=patching, input=denoised, batch_size=batch_size
)

d_cur = (
sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)
) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised
Expand All @@ -319,18 +466,37 @@ def deterministic_sampler(
if solver == "euler" or i == num_steps - 1:
x_next = x_hat + h * d_cur
else:
# Patched input
# (batch_size * patch_num, C_out, patch_shape_y, patch_shape_x)
x_prime_batch = _apply_wrapper_select(input=x_prime, patching=patching)(
patching=patching, input=x_prime
).to(latents.device)

if isinstance(net, EDMPrecond):
# Conditioning info is passed as keyword arg
denoised = net(
x_prime / s(t_prime),
x_prime_batch / s(t_prime),
sigma(t_prime),
condition=x_lr,
class_labels=class_labels,
).to(torch.float64)
**optional_args,
).to(dtype)
else:
denoised = net(
x_prime / s(t_prime), x_lr, sigma(t_prime), class_labels
).to(torch.float64)
x_prime_batch / s(t_prime),
x_lr,
sigma(t_prime),
class_labels,
**optional_args,
).to(dtype)

if patching:
# Un-patch the denoised image
# (batch_size, C_out, img_shape_y, img_shape_x)
denoised = _fuse_wrapper(
patching=patching, input=denoised, batch_size=batch_size
)

d_prime = (
sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)
) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised
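Taken together, the parameters documented above (``patching``, ``mean_hr``, ``lead_time_label``, ``dtype``) change only how the sampler is called; the return value keeps the shape of ``latents``. Below is a hedged end-to-end sketch of a call using all of them. Every shape, the lead-time encoding, the `GridPatching2D` arguments, and the import paths are assumptions made up for illustration, not values taken from this PR.

```python
import torch

from physicsnemo.utils.diffusion import deterministic_sampler
from physicsnemo.utils.patching import GridPatching2D


def generate(net: torch.nn.Module) -> torch.Tensor:
    # Assumed shapes: batch of 2, 12 low-res channels, 4 output channels, 448x448 grid.
    img_lr = torch.randn(2, 12, 448, 448)
    latents = torch.randn(2, 4, 448, 448)
    mean_hr = torch.randn(1, 4, 448, 448)        # B_hr = 1, expanded internally
    lead_time_label = torch.tensor([[0], [1]])   # assumed encoding, purely illustrative

    # Assumed image and patch sizes.
    patching = GridPatching2D(img_shape=(448, 448), patch_shape=(112, 112))

    images = deterministic_sampler(
        net=net,                 # a trained diffusion model, e.g. EDMPrecond-wrapped
        latents=latents,
        img_lr=img_lr,
        patching=patching,       # enables patch-based inference
        mean_hr=mean_hr,         # high-res mean conditioning
        lead_time_label=lead_time_label,
        num_steps=18,
        solver="heun",
        dtype=torch.float32,     # reduced precision via the new dtype argument
    )
    # Same shape as `latents`: (2, 4, 448, 448).
    return images
```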