Fix casting in SongUNetPosEmbd and shape in CorrDiff generation #982

Open · wants to merge 6 commits into main
2 changes: 1 addition & 1 deletion examples/weather/corrdiff/generate.py
@@ -196,7 +196,7 @@ def generate_fn():
net=net_reg,
img_lr=img_lr,
latents_shape=(
-cfg.generation.seed_batch_size,
+sum(map(len, rank_batches)),
Collaborator:
@juliusberner could you explain the reason for this change? AFAIK the batch dimension of latents_shape is never really used, right?

Contributor Author:
This is the batch size that the output of regression_step is expanded to. Since we later compute image_out = image_reg + image_res, it needs to match the batch size of the output of diffusion_step. (A minimal sketch of this constraint follows this file's diff.)

img_out_channels,
img_shape[0],
img_shape[1],
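To make the review discussion above concrete, here is a minimal, self-contained sketch of the shape constraint. The tensor names mirror the CorrDiff generation code, but the shapes, values, and seed_batch_size are made up for illustration; this is not the actual implementation.

import torch

rank_batches = [[0, 1, 2], [3, 4]]          # hypothetical seed batches handled by this rank
batch_size = sum(map(len, rank_batches))    # 5 samples generated in total by this rank
seed_batch_size = 2                         # stand-in for cfg.generation.seed_batch_size

image_res = torch.randn(batch_size, 4, 32, 32)   # diffusion_step: one residual per seed
image_reg = torch.randn(1, 4, 32, 32)            # regression_step: a single prediction

# Expanding the regression output to seed_batch_size (the old latents_shape[0]) breaks
# image_out = image_reg + image_res, because batch sizes 2 and 5 do not broadcast:
try:
    image_out = image_reg.expand(seed_batch_size, -1, -1, -1) + image_res
except RuntimeError as err:
    print(err)

# Expanding to sum(map(len, rank_batches)) (the new latents_shape[0]) matches image_res:
image_out = image_reg.expand(batch_size, -1, -1, -1) + image_res
assert image_out.shape == (batch_size, 4, 32, 32)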
2 changes: 1 addition & 1 deletion physicsnemo/models/diffusion/layers.py
@@ -354,7 +354,7 @@ def __init__(
self.act_fn = None
self.amp_mode = amp_mode
if self.use_apex_gn:
-if self.act:
+if self.fused_act:
self.gn = ApexGroupNorm(
num_groups=self.num_groups,
num_channels=num_channels,
50 changes: 23 additions & 27 deletions physicsnemo/models/diffusion/song_unet.py
@@ -857,9 +857,6 @@ def forward(
"Cannot provide both embedding_selector and global_index."
)

-if x.dtype != self.pos_embd.dtype:
-self.pos_embd = self.pos_embd.to(x.dtype)

# Append positional embedding to input conditioning
if self.pos_embd is not None:
# Select positional embeddings with a selector function
@@ -877,18 +874,19 @@

out = super().forward(x, noise_labels, class_labels, augment_labels)

-if self.lead_time_mode:
+if self.lead_time_mode and self.prob_channels:
# if training mode, let crossEntropyLoss do softmax. The model outputs logits.
# if eval mode, the model outputs probability
-if self.prob_channels and out.dtype != self.scalar.dtype:
-self.scalar.data = self.scalar.data.to(out.dtype)
-if self.prob_channels and (not self.training):
-out[:, self.prob_channels] = (
-out[:, self.prob_channels] * self.scalar
-).softmax(dim=1)
-elif self.prob_channels and self.training:
+scalar = self.scalar
+if out.dtype != scalar.dtype:
+scalar = scalar.to(out.dtype)
+if self.training:
+out[:, self.prob_channels] = out[:, self.prob_channels] * scalar
+else:
out[:, self.prob_channels] = (
-out[:, self.prob_channels] * self.scalar
+(out[:, self.prob_channels] * scalar)
+.softmax(dim=1)
+.to(out.dtype)
Comment on lines -883 to +889
Collaborator:
LGTM, but could you just post below an MRE of the bug you encountered with the former casting logic?

Contributor Author (@juliusberner, Jun 27, 2025):
In amp-bf16 training, the output of softmax is float32 while out.dtype is bfloat16, which gives: RuntimeError: Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source. (A standalone sketch follows this hunk.)

)
return out
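
As a companion to the description above, a standalone sketch of the failure mode (not the SongUNet code; the explicit float() cast stands in for softmax returning float32 under bf16 autocast, and the shapes are made up):

import torch

prob_channels = [0, 1]                               # list of channel indices, as for self.prob_channels
out = torch.zeros(1, 4, 8, 8, dtype=torch.bfloat16)  # stand-in for the bf16 model output

# Under amp-bf16, softmax returns float32 even for bfloat16 inputs; emulate that here.
src = out[:, prob_channels].softmax(dim=1).float()

try:
    # Indexing with a list dispatches to index_put_, which requires matching dtypes.
    out[:, prob_channels] = src
except RuntimeError as err:
    print(err)  # Index put requires the source and destination dtypes match, ...

# The fix casts the result back before the assignment:
out[:, prob_channels] = src.to(out.dtype)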

@@ -947,15 +945,16 @@ def positional_embedding_indexing(
"""
# If no global indices are provided, select all embeddings and expand
# to match the batch size of the input
-if x.dtype != self.pos_embd.dtype:
-self.pos_embd = self.pos_embd.to(x.dtype)
+pos_embd = self.pos_embd
+if x.dtype != pos_embd.dtype:
+pos_embd = pos_embd.to(x.dtype)
Comment on lines +948 to +950
Collaborator:
Two remarks:

  1. Same as above: could you post an MRE below of the bug that you would get with the former casting logic? (The MRE could be grouped with the one above.)
  2. Is there some logic problem here? We access pos_embd.dtype, and right below we check if pos_embd is not None. But if self.pos_embd is None, don't we return None right away?

Contributor Author (@juliusberner, Jun 27, 2025):

  1. The assignment self.pos_embd = self.pos_embd.to(dtype) only works if pos_embd is a buffer, not if it is a parameter (which is the case if self.gridtype == "learnable"). Thus, we define a new local variable, which works in both cases (see the sketch below).
  2. positional_embedding_indexing is only called in the forward if self.pos_embd is not None. If it is called from outside, it would return an empty list []. How should we handle that?
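
Regarding remark 1, a small toy illustration of the buffer-vs-parameter difference (the module below is made up and only mimics the relevant attribute handling, not SongUNetPosEmbd itself):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, learnable: bool):
        super().__init__()
        emb = torch.zeros(4, 8, 8)
        if learnable:
            self.pos_embd = nn.Parameter(emb)      # analogous to gridtype == "learnable"
        else:
            self.register_buffer("pos_embd", emb)  # analogous to the non-learnable grids

buf = Toy(learnable=False)
buf.pos_embd = buf.pos_embd.to(torch.bfloat16)      # assigning a Tensor to a buffer is fine

par = Toy(learnable=True)
try:
    par.pos_embd = par.pos_embd.to(torch.bfloat16)  # .to() returns a plain Tensor, not a Parameter
except TypeError as err:
    print(err)  # cannot assign a Tensor as parameter 'pos_embd' (torch.nn.Parameter or None expected)

# The PR's approach: cast into a local variable, which works for both cases.
pos_embd = par.pos_embd
if pos_embd.dtype != torch.bfloat16:
    pos_embd = pos_embd.to(torch.bfloat16)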


if global_index is None:
if self.lead_time_mode:
selected_pos_embd = []
-if self.pos_embd is not None:
+if pos_embd is not None:
selected_pos_embd.append(
-self.pos_embd[None].expand((x.shape[0], -1, -1, -1))
+pos_embd[None].expand((x.shape[0], -1, -1, -1))
)
if self.lt_embd is not None:
selected_pos_embd.append(
@@ -972,7 +971,7 @@
if len(selected_pos_embd) > 0:
selected_pos_embd = torch.cat(selected_pos_embd, dim=1)
else:
-selected_pos_embd = self.pos_embd[None].expand(
+selected_pos_embd = pos_embd[None].expand(
(x.shape[0], -1, -1, -1)
) # (B, C_{PE}, H, W)

@@ -985,11 +984,11 @@
global_index = torch.reshape(
torch.permute(global_index, (1, 0, 2, 3)), (2, -1)
) # (P, 2, X, Y) to (2, P*X*Y)
-selected_pos_embd = self.pos_embd[
+selected_pos_embd = pos_embd[
:, global_index[0], global_index[1]
] # (C_pe, P*X*Y)
selected_pos_embd = torch.permute(
-torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)),
+torch.reshape(selected_pos_embd, (pos_embd.shape[0], P, H, W)),
(1, 0, 2, 3),
) # (P, C_pe, X, Y)

@@ -1000,7 +999,7 @@
# Append positional and lead time embeddings to input conditioning
if self.lead_time_mode:
embeds = []
-if self.pos_embd is not None:
+if pos_embd is not None:
embeds.append(selected_pos_embd) # reuse code below
if self.lt_embd is not None:
lt_embds = self.lt_embd[
@@ -1086,15 +1085,12 @@ def positional_embedding_selector(
... return patching.apply(emb[None].expand(batch_size, -1, -1, -1))
>>>
"""
-if x.dtype != self.pos_embd.dtype:
-self.pos_embd = self.pos_embd.to(x.dtype)
+embeddings = self.pos_embd
+if x.dtype != embeddings.dtype:
+embeddings = embeddings.to(x.dtype)
Comment on lines +1088 to +1090
Collaborator:
Two remarks:

  1. Same as above, it would be great if you could post an MRE below (it can be grouped with the other MREs for these casting bugs).
  2. Is there a specific reason to call it embeddings here, whereas it was called pos_embd in the positional_embedding_indexing method? If not, let's keep the names consistent.

Contributor Author:

  1. Copying from above: The assignment self.pos_embd = self.pos_embd.to(dtype) only works if pos_embd is a buffer but not if it is a parameter (which is the case if self.gridtype == "learnable"). Thus, we define a new local variable which works in both cases.
  2. I took it from the existing code, but it makes sense to rename it to pos_embd.

if lead_time_label is not None:
# all patches share same lead_time_label
-embeddings = torch.cat(
-[self.pos_embd, self.lt_embd[lead_time_label[0].int()]]
-)
-else:
-embeddings = self.pos_embd
+embeddings = torch.cat([embeddings, self.lt_embd[lead_time_label[0].int()]])
return embedding_selector(embeddings) # (B, N_pe, H, W)

def _get_positional_embedding(self):
12 changes: 12 additions & 0 deletions test/models/common/optimization.py
@@ -204,6 +204,18 @@ def forward(*args, **kwargs):
return forward


def torch_compile_model(
model: physicsnemo.Module, fullgraph: bool = True, error_on_recompile: bool = False
) -> bool:
backend = (
nop_backend # for fast compilation for fx graph capture, use a nop backend
)
torch._dynamo.reset()
torch._dynamo.config.error_on_recompile = error_on_recompile
model = torch.compile(model, backend=backend, fullgraph=fullgraph)
return model


def validate_torch_compile(
model: physicsnemo.Module,
in_args: Tuple[Tensor] = (),
92 changes: 78 additions & 14 deletions test/models/diffusion/test_song_unet_pos_embd_agn_amp.py
@@ -28,14 +28,58 @@
from physicsnemo.models.diffusion import SongUNetPosEmbd as UNet


def setup_model_learnable_embd():
# Smaller architecture variant with learnable positional embeddings
# (more similar to CorrDiff example)
N_pos = 4
model = UNet(
img_resolution=128,
in_channels=2 + N_pos,
out_channels=2,
model_channels=32,
channel_mult_emb=2,
gridtype="learnable",
N_grid_channels=N_pos,
use_apex_gn=True,
amp_mode=True,
)
return model


# Test forward pass with AMP, Apex GN, and compile
@pytest.mark.parametrize("device", ["cuda:0"])
def test_song_unet_forward(device):
torch.manual_seed(0)
H, W = 32, 64
model = (
setup_model_learnable_embd().to(device).to(memory_format=torch.channels_last)
)
input_image = torch.ones([1, 2, H, W]).to(device)
noise_labels = torch.randn([1]).to(device)
class_labels = torch.randint(0, 1, (1, 1)).to(device)
idx_x = torch.arange(45, 45 + H)
idx_y = torch.arange(12, 12 + W)
mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y)
global_index = torch.stack((mesh_x, mesh_y), dim=0)[None].to(device)

# Compile model
model = common.torch_compile_model(model)

with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
output_image = model(input_image, noise_labels, class_labels, global_index)
assert output_image.shape == (1, 2, H, W)

# TODO: add non-regression test
return


@pytest.mark.parametrize("device", ["cuda:0"])
def test_song_unet_global_indexing(device):
torch.manual_seed(0)
N_pos = 2
-batch_shape_x = 32
-batch_shape_y = 64
-# Construct the DDM++ UNet model
+H, W = 32, 64
+
+# Construct the DDM++ UNet model
model = (
UNet(
img_resolution=128,
@@ -49,20 +93,40 @@ def test_song_unet_global_indexing(device):
.to(device)
.to(memory_format=torch.channels_last)
)
-input_image = torch.ones([1, 2, batch_shape_x, batch_shape_y]).to(device)
-noise_labels = noise_labels = torch.randn([1]).to(device)
+input_image = torch.ones([1, 2, H, W]).to(device)
+noise_labels = torch.randn([1]).to(device)
class_labels = torch.randint(0, 1, (1, 1)).to(device)
-idx_x = torch.arange(45, 45 + batch_shape_x)
-idx_y = torch.arange(12, 12 + batch_shape_y)
+idx_x = torch.arange(45, 45 + H)
+idx_y = torch.arange(12, 12 + W)
mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y)
global_index = torch.stack((mesh_x, mesh_y), dim=0)[None].to(device)

with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
output_image = model(input_image, noise_labels, class_labels, global_index)
pos_embed = model.positional_embedding_indexing(input_image, global_index)
-assert output_image.shape == (1, 2, batch_shape_x, batch_shape_y)
+assert output_image.shape == (1, 2, H, W)
assert torch.equal(pos_embed, global_index)

# Smaller architecture variant with learnable positional embeddings
# (more similar to CorrDiff example)
model = (
setup_model_learnable_embd().to(device).to(memory_format=torch.channels_last)
)
input_image = torch.ones([1, 2, H, W]).to(device)
noise_labels = torch.randn([1]).to(device)
class_labels = torch.randint(0, 1, (1, 1)).to(device)
idx_x = torch.arange(45, 45 + H)
idx_y = torch.arange(12, 12 + W)
mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y)
global_index = torch.stack((mesh_x, mesh_y), dim=0)[None].to(device)

with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
output_image = model(input_image, noise_labels, class_labels, global_index)
pos_embed = model.positional_embedding_indexing(input_image, global_index)
assert output_image.shape == (1, 2, H, W)
assert pos_embed.shape == (1, N_pos, H, W)
assert torch.equal(pos_embed, model.pos_embd[:, 45 : 45 + H, 12 : 12 + W])


@pytest.mark.parametrize("device", ["cuda:0"])
def test_song_unet_constructor(device):
Expand Down Expand Up @@ -138,12 +202,6 @@ def test_song_unet_position_embedding(device):
.to(device)
.to(memory_format=torch.channels_last)
)
-noise_labels = torch.randn([1]).to(device)
-class_labels = torch.randint(0, 1, (1, 1)).to(device)
-input_image = torch.ones([1, 2, 16, 16]).to(device)
-with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
-output_image = model(input_image, noise_labels, class_labels)
-assert output_image.shape == (1, out_channels, img_resolution, img_resolution)
assert model.pos_embd.shape == (100, img_resolution, img_resolution)

model = (
@@ -160,6 +218,12 @@
)
assert model.pos_embd.shape == (40, img_resolution, img_resolution)

# Test with learnable positional embeddings
model = (
setup_model_learnable_embd().to(device).to(memory_format=torch.channels_last)
)
assert model.pos_embd.shape == (4, img_resolution, img_resolution)


def test_fails_if_grid_is_invalid():
"""Test the positional embedding options. "linear" gridtype only support 2 channels, and N_grid_channels in "sinusoidal" should be a factor of 4"""