
Fix casting in SongUNetPosEmbd and shape in CorrDiff generation #982


Open
juliusberner wants to merge 6 commits into main from jberner/fix_corrdiff

Conversation

@juliusberner (Contributor) commented Jun 17, 2025

PhysicsNeMo Pull Request

Description

  1. Fix regression output shape in CorrDiff
  2. Only use act if fused_act is True in ApexGroupNorm
  3. Avoid changing the dtype of module attributes (since self.pos_embd can be a buffer or a parameter) and cast the fp32 softmax output back to the output dtype in SongUNetPosEmbd.
  4. Avoid changing the dtype of the .data attribute of self.scalar so that SongUNetPosEmbd works with torch.compile (see the sketch below).
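As a rough illustration of items 3 and 4, the sketch below shows the intended casting pattern with a stand-in parameter (names and shapes are made up, not the actual SongUNetPosEmbd code): cast into a local variable instead of mutating the registered parameter's .data, so the module state stays untouched and torch.compile does not have to trace a side effect.

```python
import torch
from torch import nn

# Illustrative stand-ins, not the actual SongUNetPosEmbd attributes.
scalar = nn.Parameter(torch.ones(4))            # registered fp32 parameter
out = torch.randn(2, 4, dtype=torch.bfloat16)   # e.g. network output under amp-bf16

# Pattern removed by this PR: in-place dtype mutation of the parameter storage,
# a side effect that permanently changes the module.
# scalar.data = scalar.data.to(out.dtype)

# Pattern introduced by this PR: cast into a local variable; the parameter
# registered on the module keeps its original dtype.
local_scalar = scalar
if out.dtype != local_scalar.dtype:
    local_scalar = local_scalar.to(out.dtype)
out = out * local_scalar
```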

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

@juliusberner juliusberner changed the title Fix dtype in SognUNet and shape in CorrDiff generation Fix dtype in SongUNetPosEmbd and shape in CorrDiff generation Jun 18, 2025
@juliusberner juliusberner changed the title Fix dtype in SongUNetPosEmbd and shape in CorrDiff generation Fix dtypes in SongUNetPosEmbd and shape in CorrDiff generation Jun 18, 2025
@juliusberner juliusberner changed the title Fix dtypes in SongUNetPosEmbd and shape in CorrDiff generation Fix casting in SongUNetPosEmbd and shape in CorrDiff generation Jun 18, 2025
@juliusberner juliusberner force-pushed the jberner/fix_corrdiff branch from 92a58a3 to b7e0382 Compare June 18, 2025 00:07
@juliusberner juliusberner force-pushed the jberner/fix_corrdiff branch from b7e0382 to f124ba9 Compare June 18, 2025 00:14
@CharlelieLrt CharlelieLrt self-requested a review June 18, 2025 00:17
@CharlelieLrt CharlelieLrt added the labels bug ("Something isn't working") and 2 - In Progress ("Currently a work in progress") on Jun 18, 2025
@CharlelieLrt (Collaborator) left a comment


Overall looks good! Just need a few clarifications and a few MREs for the bugs that this PR fixes.

@@ -196,7 +196,7 @@ def generate_fn():
             net=net_reg,
             img_lr=img_lr,
             latents_shape=(
-                cfg.generation.seed_batch_size,
+                sum(map(len, rank_batches)),
@CharlelieLrt (Collaborator) commented:

@juliusberner could you explain the reason for this change? AFAIK the batch dimension of latents_shape is never really used, right?

@juliusberner (Contributor, author) replied:

This is the batch size that the output of regression_step is expanded to. Since we later compute image_out = image_reg + image_res, it needs to match the batch size of the output of diffusion_step.
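A hedged sketch of that constraint with made-up shapes (rank_batches here is just a nested list standing in for the real per-rank index batches): the regression output is expanded to latents_shape[0] and later added to the diffusion output, so both batch sizes must agree.

```python
import torch

# Stand-in for the real rank_batches: sample indices split across batches.
rank_batches = [[0, 1], [2, 3], [4]]
n_total = sum(map(len, rank_batches))  # 5

# Regression output computed once, then expanded to the full batch size ...
image_reg = torch.zeros(1, 3, 64, 64).expand(n_total, -1, -1, -1)
# ... so it can be added to the diffusion output of the same batch size.
image_res = torch.randn(n_total, 3, 64, 64)
image_out = image_reg + image_res
print(image_out.shape)  # torch.Size([5, 3, 64, 64])
```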

Comment on lines -883 to +889
-        if self.prob_channels and out.dtype != self.scalar.dtype:
-            self.scalar.data = self.scalar.data.to(out.dtype)
-        if self.prob_channels and (not self.training):
-            out[:, self.prob_channels] = (
-                out[:, self.prob_channels] * self.scalar
-            ).softmax(dim=1)
-        elif self.prob_channels and self.training:
+            scalar = self.scalar
+            if out.dtype != scalar.dtype:
+                scalar = scalar.to(out.dtype)
+            if self.training:
+                out[:, self.prob_channels] = out[:, self.prob_channels] * scalar
+            else:
+                out[:, self.prob_channels] = (
+                    (out[:, self.prob_channels] * scalar)
+                    .softmax(dim=1)
+                    .to(out.dtype)
@CharlelieLrt (Collaborator) commented:

LGTM, but could you just post below an MRE of the bug you encountered with the former casting logic?

@juliusberner (Contributor, author) replied on Jun 27, 2025:

In amp-bf16 training, the output of softmax is float32 while out.dtype is bfloat16, which gives RuntimeError: Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source.
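A minimal sketch of the failing index put (a plain CPU reproduction of the same RuntimeError, not the actual amp-bf16 CorrDiff forward; the float32 tensor below plays the role of the softmax output):

```python
import torch

out = torch.randn(2, 4, dtype=torch.bfloat16)  # destination, e.g. bf16 network output
src = torch.randn(2, 2, dtype=torch.float32)   # e.g. the fp32 softmax result

try:
    # RuntimeError: Index put requires the source and destination dtypes match,
    # got BFloat16 for the destination and Float for the source.
    out[:, [0, 1]] = src
except RuntimeError as e:
    print(e)

# The fix in this PR casts the result back to the destination dtype:
out[:, [0, 1]] = src.to(out.dtype)
```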

Comment on lines +948 to +950
+        pos_embd = self.pos_embd
+        if x.dtype != pos_embd.dtype:
+            pos_embd = pos_embd.to(x.dtype)
@CharlelieLrt (Collaborator) commented:

Two remarks:

  1. Same as above: could you post below an MRE of the bug you would get with the former casting logic? (It can be grouped with the one above.)
  2. Is there a logic problem here? We access pos_embd.dtype, yet right below we check whether pos_embd is not None. If self.pos_embd is None, shouldn't we return None right away?

@juliusberner (Contributor, author) replied on Jun 27, 2025:

  1. The assignment self.pos_embd = self.pos_embd.to(dtype) only works if self.pos_embd is a buffer, not if it is a parameter (which is the case if self.gridtype == "learnable"): nn.Module raises a TypeError when a plain tensor is assigned to a registered parameter. Defining a new local variable works in both cases (see the sketch below).
  2. positional_embedding_indexing is only called in forward if self.pos_embd is not None. If it is called from outside, it would return an empty list []. How should we handle that?
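A hedged sketch of point 1 with a toy module (not SongUNetPosEmbd): overwriting a parameter attribute with a plain tensor raises a TypeError, while a local variable handles both the buffer and the "learnable" (parameter) case.

```python
import torch
from torch import nn

class ToyPosEmbd(nn.Module):
    def __init__(self, learnable: bool):
        super().__init__()
        embd = torch.randn(4, 8, 8)
        if learnable:  # analogous to gridtype == "learnable"
            self.pos_embd = nn.Parameter(embd)
        else:
            self.register_buffer("pos_embd", embd)

m = ToyPosEmbd(learnable=True)
try:
    # .to() returns a plain tensor, and nn.Module refuses to overwrite a
    # registered parameter with it -> TypeError.
    m.pos_embd = m.pos_embd.to(torch.float16)
except TypeError as e:
    print(e)

# Local-variable cast works for both buffers and parameters.
x_dtype = torch.float16
pos_embd = m.pos_embd
if pos_embd.dtype != x_dtype:
    pos_embd = pos_embd.to(x_dtype)
```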

Comment on lines +1088 to +1090
+        embeddings = self.pos_embd
+        if x.dtype != embeddings.dtype:
+            embeddings = embeddings.to(x.dtype)
@CharlelieLrt (Collaborator) commented:

Two remarks:

  1. Same as above, it would be great if you could post an MRE below (it can be grouped with the other MREs for these casting bugs).
  2. Is there a specific reason to call it embeddings here, whereas it was called pos_embd in the positional_embedding_indexing method? If not, let's keep the names consistent.

@juliusberner (Contributor, author) replied:

  1. Copying from above: the assignment self.pos_embd = self.pos_embd.to(dtype) only works if self.pos_embd is a buffer, not if it is a parameter (which is the case if self.gridtype == "learnable"), so we define a new local variable that works in both cases.
  2. I took it from the existing code, but it makes sense to rename it to pos_embd.

CharlelieLrt and others added 2 commits June 26, 2025 16:40
…ding, lead time aware, with compile, apex_gn, etc...

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>