
Commit 399422a

Authored by Julius Berner (juliusberner) and Charlelie Laurent (CharlelieLrt)
CorrDiff fixes (precond., patching, & cfg parsing) (#937)
* Add legacy scaling function
* Adapt warning for `scale_cond_input=True`
* Fix max_patch_per_gpu=1 behavior
* Avoid views for floats and simplify input_interp concat
* Added tests to check differentiability of patching and deterministic sampler
* Fixed error in new test

Signed-off-by: Julius Berner <jberner@nvidia.com>
Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
Co-authored-by: Julius Berner <jberner@nvidia.com>
Co-authored-by: Charlelie Laurent <claurent@nvidia.com>
Co-authored-by: Charlelie Laurent <84199758+CharlelieLrt@users.noreply.github.com>
1 parent b39dd35 commit 399422a

File tree: 5 files changed, +211 −61 lines


examples/weather/corrdiff/train.py

Lines changed: 9 additions & 6 deletions
@@ -369,22 +369,25 @@ def main(cfg: DictConfig) -> None:
     batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu
     logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds")
 
-    patch_num = getattr(cfg.training.hp, "patch_num", 1)
-    max_patch_per_gpu = getattr(cfg.training.hp, "max_patch_per_gpu", 1)
-
     # calculate patch per iter
-    if hasattr(cfg.training.hp, "max_patch_per_gpu") and max_patch_per_gpu > 1:
+    patch_num = getattr(cfg.training.hp, "patch_num", 1)
+    if hasattr(cfg.training.hp, "max_patch_per_gpu"):
+        max_patch_per_gpu = cfg.training.hp.max_patch_per_gpu
+        if max_patch_per_gpu // batch_size_per_gpu < 1:
+            raise ValueError(
+                f"max_patch_per_gpu ({max_patch_per_gpu}) must be greater or equal to batch_size_per_gpu ({batch_size_per_gpu})."
+            )
         max_patch_num_per_iter = min(
             patch_num, (max_patch_per_gpu // batch_size_per_gpu)
-        )  # Ensure at least 1 patch per iter
+        )
         patch_iterations = (
             patch_num + max_patch_num_per_iter - 1
         ) // max_patch_num_per_iter
         patch_nums_iter = [
             min(max_patch_num_per_iter, patch_num - i * max_patch_num_per_iter)
             for i in range(patch_iterations)
         ]
-        print(
+        logger0.info(
            f"max_patch_num_per_iter is {max_patch_num_per_iter}, patch_iterations is {patch_iterations}, patch_nums_iter is {patch_nums_iter}"
         )
     else:
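To make the config-parsing change above concrete, here is a minimal, self-contained sketch of the patch-splitting arithmetic with hypothetical values (the numbers are illustrative, not taken from any shipped config). The new ValueError guards the case where the integer division max_patch_per_gpu // batch_size_per_gpu would otherwise be zero.

# Hypothetical values for illustration only (not from the CorrDiff configs).
patch_num = 16            # cfg.training.hp.patch_num
max_patch_per_gpu = 10    # cfg.training.hp.max_patch_per_gpu
batch_size_per_gpu = 2    # cfg.training.hp.batch_size_per_gpu

# Largest number of patches that fits in one iteration on a single GPU.
max_patch_num_per_iter = min(patch_num, max_patch_per_gpu // batch_size_per_gpu)  # 5

# Ceil division: number of accumulation iterations needed to cover all patches.
patch_iterations = (patch_num + max_patch_num_per_iter - 1) // max_patch_num_per_iter  # 4

# Per-iteration patch counts; the last iteration takes the remainder.
patch_nums_iter = [
    min(max_patch_num_per_iter, patch_num - i * max_patch_num_per_iter)
    for i in range(patch_iterations)
]
assert patch_nums_iter == [5, 5, 5, 1]
assert sum(patch_nums_iter) == patch_num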

physicsnemo/models/diffusion/preconditioning.py

Lines changed: 38 additions & 9 deletions
@@ -1031,15 +1031,6 @@ def __init__(
             stacklevel=2,
         )
 
-        if scale_cond_input:
-            warnings.warn(
-                "scale_cond_input=True does not properly scale the conditional input. "
-                "(see https://github.com/NVIDIA/modulus/issues/229). "
-                "This setup will be deprecated. "
-                "Please set scale_cond_input=False.",
-                DeprecationWarning,
-            )
-
         super().__init__(
             img_resolution=img_resolution,
             img_in_channels=img_in_channels,
@@ -1052,10 +1043,48 @@ def __init__(
             **model_kwargs,
         )
 
+        if scale_cond_input:
+            warnings.warn(
+                "The `scale_cond_input=True` option does not properly scale the conditional input "
+                "and is deprecated. It is highly recommended to set `scale_cond_input=False`. "
+                "However, for loading a checkpoint previously trained with `scale_cond_input=True`, "
+                "this flag must be set to `True` to ensure compatibility. "
+                "For more details, see https://github.com/NVIDIA/modulus/issues/229.",
+                DeprecationWarning,
+            )
+            self.scaling_fn = self._legacy_scaling_fn
+
         # Store deprecated parameters for backward compatibility
         self.img_channels = img_channels
         self.scale_cond_input = scale_cond_input
 
+    @staticmethod
+    def _legacy_scaling_fn(
+        x: torch.Tensor, img_lr: torch.Tensor, c_in: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        This function does not properly scale the conditional input
+        (see https://github.com/NVIDIA/modulus/issues/229)
+        and will be deprecated.
+
+        Concatenate and scale the high-resolution and low-resolution tensors.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Noisy high-resolution image of shape (B, C_hr, H, W).
+        img_lr : torch.Tensor
+            Low-resolution image of shape (B, C_lr, H, W).
+        c_in : torch.Tensor
+            Scaling factor of shape (B, 1, 1, 1).
+
+        Returns
+        -------
+        torch.Tensor
+            Scaled and concatenated tensor of shape (B, C_in+C_out, H, W).
+        """
+        return c_in * torch.cat([x, img_lr.to(x.dtype)], dim=1)
+
     def forward(
         self,
         x,
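As a quick illustration of why this path is retained only for checkpoint compatibility, the standalone sketch below (made-up shapes) reproduces what `_legacy_scaling_fn` computes: the `c_in` factor multiplies the conditional low-resolution channels as well as the noisy input, which is the improper scaling tracked in https://github.com/NVIDIA/modulus/issues/229. The non-legacy path is assumed to leave the conditional input unscaled; that code is not part of this diff.

import torch

# Made-up shapes for illustration.
x = torch.randn(1, 2, 4, 4)       # noisy high-resolution input (B, C_hr, H, W)
img_lr = torch.randn(1, 3, 4, 4)  # conditional low-resolution input (B, C_lr, H, W)
c_in = torch.full((1, 1, 1, 1), 0.5)  # EDM-style input scaling factor

# What the retained legacy function computes:
legacy = c_in * torch.cat([x, img_lr.to(x.dtype)], dim=1)

# The conditional channels end up multiplied by c_in as well,
# which is the behavior the DeprecationWarning points at.
assert torch.allclose(legacy[:, 2:], c_in * img_lr)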

physicsnemo/utils/patching.py

Lines changed: 26 additions & 34 deletions
@@ -591,14 +591,26 @@ def image_batching(
     )  # (padding_left,padding_right,padding_top,padding_bottom)
     input_padded = image_padding(input)
     patch_num = patch_num_x * patch_num_y
+
+    # Cast to float for unfold
+    if input.dtype == torch.int32:
+        input_padded = input_padded.view(torch.float32)
+    elif input.dtype == torch.int64:
+        input_padded = input_padded.view(torch.float64)
+
     x_unfold = torch.nn.functional.unfold(
-        input=input_padded.view(_cast_type(input_padded)),  # Cast to float
+        input=input_padded,
         kernel_size=(patch_shape_y, patch_shape_x),
         stride=(
             patch_shape_y - overlap_pix - boundary_pix,
             patch_shape_x - overlap_pix - boundary_pix,
         ),
-    ).view(input_padded.dtype)
+    )
+
+    # Cast back to original dtype
+    if input.dtype in [torch.int32, torch.int64]:
+        x_unfold = x_unfold.view(input.dtype)
+
     x_unfold = rearrange(
         x_unfold,
         "b (c p_h p_w) (nb_p_h nb_p_w) -> (nb_p_w nb_p_h b) c p_h p_w",
@@ -608,16 +620,7 @@ def image_batching(
         nb_p_w=patch_num_x,
     )
     if input_interp is not None:
-        input_interp_repeated = rearrange(
-            torch.repeat_interleave(
-                input=input_interp,
-                repeats=patch_num,
-                dim=0,
-                output_size=x_unfold.shape[0],
-            ),
-            "(b p) c h w -> (p b) c h w",
-            p=patch_num,
-        )
+        input_interp_repeated = input_interp.repeat(patch_num, 1, 1, 1)
         return torch.cat((x_unfold, input_interp_repeated), dim=1)
     else:
         return x_unfold
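The `input_interp` simplification above replaces a `repeat_interleave` plus einops rearrange with a single `repeat`. The snippet below is a small, self-contained check (toy shapes, illustrative only) that the two formulations place the conditioning images in the same patch-major order.

import torch
from einops import rearrange

input_interp = torch.arange(2 * 1 * 2 * 2, dtype=torch.float32).reshape(2, 1, 2, 2)
patch_num = 3  # toy value

# Previous formulation: interleave along the batch dim, then reorder (b p) -> (p b).
old = rearrange(
    torch.repeat_interleave(input_interp, repeats=patch_num, dim=0),
    "(b p) c h w -> (p b) c h w",
    p=patch_num,
)

# Simplified formulation from this commit: tile the whole batch patch_num times.
new = input_interp.repeat(patch_num, 1, 1, 1)

assert torch.equal(old, new)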
@@ -722,6 +725,13 @@ def image_fuse(
         nb_p_h=patch_num_y,
         nb_p_w=patch_num_x,
     )
+
+    # Cast to float for fold
+    if input.dtype == torch.int32:
+        x = x.view(torch.float32)
+    elif input.dtype == torch.int64:
+        x = x.view(torch.float64)
+
     # Stitch patches together (by summing over overlapping patches)
     x_folded = torch.nn.functional.fold(
         input=x,
@@ -733,6 +743,10 @@ def image_fuse(
         ),
     )
 
+    # Cast back to original dtype
+    if input.dtype in [torch.int32, torch.int64]:
+        x_folded = x_folded.view(input.dtype)
+
     # Remove padding
     x_no_padding = x_folded[
         ..., pad[2] : pad[2] + img_shape_y, pad[0] : pad[0] + img_shape_x
@@ -743,25 +757,3 @@ def image_fuse(
 
     # Normalize by overlap count
     return x_no_padding / overlap_count_no_padding
-
-
-def _cast_type(input: Tensor) -> torch.dtype:
-    """Return float type based on input tensor type.
-
-    Parameters
-    ----------
-    input : Tensor
-        Input tensor to determine float type from
-
-    Returns
-    -------
-    torch.dtype
-        Float type corresponding to input tensor type for int32/64,
-        otherwise returns original dtype
-    """
-    if input.dtype == torch.int32:
-        return torch.float32
-    elif input.dtype == torch.int64:
-        return torch.float64
-    else:
-        return input.dtype
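The integer handling in `image_batching` and `image_fuse` above relies on `Tensor.view(dtype)` reinterpreting the underlying bytes rather than converting values, so unfold (which only gathers elements) round-trips integer data exactly. A minimal standalone check of that trick, with illustrative shapes:

import torch
import torch.nn.functional as F

# An int32 "index" image; unfold does not accept integer inputs directly.
idx = torch.arange(36, dtype=torch.int32).reshape(1, 1, 6, 6)

# Reinterpret the same bytes as float32 (no value conversion), extract patches,
# then reinterpret back to int32.
patches = F.unfold(idx.view(torch.float32), kernel_size=(3, 3), stride=(3, 3))
patches = patches.view(torch.int32)

# The first patch reproduces the top-left 3x3 block of the original integers.
assert patches.dtype == torch.int32
assert torch.equal(patches[0, :, 0].reshape(3, 3), idx[0, 0, :3, :3])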

test/utils/generative/test_stochastic_sampler.py

Lines changed: 96 additions & 4 deletions
@@ -16,6 +16,7 @@
 
 from typing import Callable, Optional
 
+import pytest
 import torch
 from pytest_utils import import_or_fail
 from torch import Tensor
@@ -118,7 +119,8 @@ def test_stochastic_sampler(pytestconfig):
 
 # The test function for edm_sampler with rectangular domain and patching
 @import_or_fail("cftime")
-def test_stochastic_sampler_rectangle_patching(pytestconfig):
+@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+def test_stochastic_sampler_rectangle_patching(device, pytestconfig):
     from physicsnemo.utils.generative import stochastic_sampler
     from physicsnemo.utils.patching import GridPatching2D
 
@@ -127,8 +129,10 @@ def test_stochastic_sampler_rectangle_patching(pytestconfig):
     img_shape_y, img_shape_x = 256, 64
     patch_shape_y, patch_shape_x = 16, 10
 
-    latents = torch.randn(2, 3, img_shape_y, img_shape_x)  # Mock latents
-    img_lr = torch.randn(2, 3, img_shape_y, img_shape_x)  # Mock low-res image
+    latents = torch.randn(2, 3, img_shape_y, img_shape_x, device=device)  # Mock latents
+    img_lr = torch.randn(
+        2, 3, img_shape_y, img_shape_x, device=device
+    )  # Mock low-res image
 
     # Test with patching
     patching = GridPatching2D(
@@ -139,7 +143,7 @@ def test_stochastic_sampler_rectangle_patching(pytestconfig):
     )
 
     # Test with mean_hr conditioning
-    mean_hr = torch.randn(2, 3, img_shape_y, img_shape_x)
+    mean_hr = torch.randn(2, 3, img_shape_y, img_shape_x, device=device)
     result_mean_hr = stochastic_sampler(
         net=net,
         latents=latents,
@@ -159,3 +163,91 @@ def test_stochastic_sampler_rectangle_patching(pytestconfig):
     assert (
         result_mean_hr.shape == latents.shape
     ), "Mean HR conditioned output shape does not match expected shape"
+
+
+# Test that the stochastic sampler is differentiable with rectangular patching
+# (tests differentiation through the patching and fusing)
+@import_or_fail("cftime")
+@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+def test_stochastic_sampler_patching_differentiable(device, pytestconfig):
+    from physicsnemo.utils.generative import stochastic_sampler
+    from physicsnemo.utils.patching import GridPatching2D
+
+    # Mock network class
+    class MockNet:
+        def __init__(self, sigma_min=0.1, sigma_max=1000):
+            self.sigma_min = sigma_min
+            self.sigma_max = sigma_max
+
+        def round_sigma(self, t: Tensor) -> Tensor:
+            return t
+
+        def __call__(
+            self,
+            x: Tensor,
+            x_lr: Tensor,
+            t: Tensor,
+            class_labels: Optional[Tensor],
+            global_index: Optional[Tensor] = None,
+            embedding_selector: Optional[Callable] = None,
+        ) -> Tensor:
+            # Mock behavior: return input tensor for testing purposes
+            return x * 0.9 + x_lr[:, : x.shape[1], :, :] * 0.1
+
+    net = MockNet()
+
+    img_shape_y, img_shape_x = 256, 64
+    patch_shape_y, patch_shape_x = 16, 10
+
+    latents = torch.randn(2, 3, img_shape_y, img_shape_x, device=device)  # Mock latents
+    img_lr = torch.randn(
+        2, 3, img_shape_y, img_shape_x, device=device
+    )  # Mock low-res image
+
+    # Tensors with requires grad
+    a = torch.randn(1, requires_grad=True, device=device)
+    b = torch.randn(1, requires_grad=True, device=device)
+    c = torch.randn(1, requires_grad=True, device=device)
+    d = torch.randn(1, requires_grad=True, device=device)
+    e = torch.randn(1, requires_grad=True, device=device)
+    f = torch.randn(1, requires_grad=True, device=device)
+
+    # Test with patching
+    patching = GridPatching2D(
+        img_shape=(img_shape_y, img_shape_x),
+        patch_shape=(patch_shape_y, patch_shape_x),
+        overlap_pix=4,
+        boundary_pix=2,
+    )
+
+    # Test with mean_hr conditioning
+    mean_hr = torch.randn(2, 3, img_shape_y, img_shape_x, device=device)
+    result_mean_hr = stochastic_sampler(
+        net=net,
+        latents=a * latents + b,
+        img_lr=c * img_lr + d,
+        patching=patching,
+        mean_hr=e * mean_hr + f,
+        num_steps=2,
+        sigma_min=0.002,
+        sigma_max=800,
+        rho=7,
+        S_churn=0,
+        S_min=0,
+        S_max=float("inf"),
+        S_noise=1,
+    )
+
+    assert (
+        result_mean_hr.shape == latents.shape
+    ), "Mean HR conditioned output shape does not match expected shape"
+
+    loss = result_mean_hr.sum()
+    loss.backward()
+
+    assert a.grad is not None
+    assert b.grad is not None
+    assert c.grad is not None
+    assert d.grad is not None
+    assert e.grad is not None
+    assert f.grad is not None
