Fix casting in SongUNetPosEmbd and shape in CorrDiff generation #982
base: main
Conversation
force-pushed from 92a58a3 to b7e0382
force-pushed from b7e0382 to f124ba9
Overall looks good! Just need a few clarifications and a few MREs for the bugs that this PR fixes.
@@ -196,7 +196,7 @@ def generate_fn():
                 net=net_reg,
                 img_lr=img_lr,
                 latents_shape=(
-                    cfg.generation.seed_batch_size,
+                    sum(map(len, rank_batches)),
@juliusberner could you explain the reason for this change? AFAIK the batch dimension of `latents_shape` is never really used, right?
This is the batch size that the output of `regression_step` is expanded to. Since we compute `image_out = image_reg + image_res` later, this needs to match the batch size of the output of `diffusion_step`.
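A minimal sketch of the bookkeeping (the seed split below is made up; `rank_batches` is assumed to be the per-rank list of seed batches used in the generation script):

```python
# Illustrative only: suppose rank_batches is the list of seed batches assigned
# to this rank, split with cfg.generation.seed_batch_size = 4.
rank_batches = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]  # last batch is smaller

# Batch size used for latents_shape: the total number of seeds on this rank ...
num_samples = sum(map(len, rank_batches))  # 10

# ... which is generally different from the nominal seed_batch_size (4), so the
# regression output expanded to this size lines up with the diffusion output
# when image_out = image_reg + image_res is computed.
print(num_samples)
```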
-        if self.prob_channels and out.dtype != self.scalar.dtype:
-            self.scalar.data = self.scalar.data.to(out.dtype)
-        if self.prob_channels and (not self.training):
-            out[:, self.prob_channels] = (
-                out[:, self.prob_channels] * self.scalar
-            ).softmax(dim=1)
-        elif self.prob_channels and self.training:
-            out[:, self.prob_channels] = out[:, self.prob_channels] * self.scalar
+        if self.prob_channels:
+            scalar = self.scalar
+            if out.dtype != scalar.dtype:
+                scalar = scalar.to(out.dtype)
+            if self.training:
+                out[:, self.prob_channels] = out[:, self.prob_channels] * scalar
+            else:
+                out[:, self.prob_channels] = (
+                    (out[:, self.prob_channels] * scalar)
+                    .softmax(dim=1)
+                    .to(out.dtype)
+                )
LGTM, but could you just post below an MRE of the bug you encountered with the former casting logic?
In amp-bf16 training, the output of `softmax` is `float32`, while `out.dtype = bfloat16`, which gives a `RuntimeError: Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source.`
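For reference, a minimal sketch of such an MRE (assuming a CUDA device; the tensor shapes, the `prob_channels` indices, and the `scalar` value are made-up stand-ins for the model's buffers):

```python
import torch

# Stand-ins for the model tensors (shapes and values are illustrative only).
out = torch.randn(2, 4, 8, 8, device="cuda", dtype=torch.bfloat16)
scalar = torch.tensor(1.0, device="cuda", dtype=torch.bfloat16)
prob_channels = [0, 1]

with torch.autocast("cuda", dtype=torch.bfloat16):
    # Under autocast, softmax runs in float32, so the right-hand side is float32
    # while `out` is bfloat16. The sliced assignment (index_put) then raises:
    # RuntimeError: Index put requires the source and destination dtypes match,
    # got BFloat16 for the destination and Float for the source.
    out[:, prob_channels] = (out[:, prob_channels] * scalar).softmax(dim=1)
```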
+        pos_embd = self.pos_embd
+        if x.dtype != pos_embd.dtype:
+            pos_embd = pos_embd.to(x.dtype)
Two remarks:
- Same as above: could you post below an MRE of the bug that you would get with the former casting logic? (The MRE could be grouped with the one above.)
- Is there some logic problem here? We are accessing `pos_embd.dtype` and right below we are checking `if pos_embd is not None`? I think if `self.pos_embd` is `None` then we do `return None` right away?
- The assignment `self.pos_embd = self.pos_embd.to(dtype)` only works if `pos_embd` is a buffer but not if it is a parameter (which is the case if `self.gridtype == "learnable"`). Thus, we define a new local variable, which works in both cases (see the sketch below).
- `positional_embedding_indexing` is only called in the `forward` if `self.pos_embd is not None`. If it is called from outside, it would return an empty list `[]`. How should we handle it?
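A minimal sketch of the buffer-vs-parameter behavior with a toy module (not the actual `SongUNetPosEmbd`; names and shapes are illustrative):

```python
import torch

class Toy(torch.nn.Module):
    def __init__(self, learnable: bool):
        super().__init__()
        emb = torch.randn(4, 8)
        if learnable:
            # analogous to gridtype == "learnable": pos_embd is a Parameter
            self.pos_embd = torch.nn.Parameter(emb)
        else:
            # otherwise pos_embd is a registered buffer
            self.register_buffer("pos_embd", emb)

buf = Toy(learnable=False)
buf.pos_embd = buf.pos_embd.to(torch.bfloat16)  # OK: a buffer accepts a plain tensor

par = Toy(learnable=True)
try:
    par.pos_embd = par.pos_embd.to(torch.bfloat16)  # .to() returns a Tensor, not a Parameter
except TypeError as err:
    print(err)  # cannot assign 'torch.BFloat16Tensor' as parameter 'pos_embd' ...

# A local variable sidesteps the module-attribute assignment in both cases:
pos_embd = par.pos_embd
if pos_embd.dtype != torch.bfloat16:
    pos_embd = pos_embd.to(torch.bfloat16)
```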
+        embeddings = self.pos_embd
+        if x.dtype != embeddings.dtype:
+            embeddings = embeddings.to(x.dtype)
Two remarks:
- Same as above: it would be great if you could post an MRE below (it can be grouped with the other MREs for these casting bugs).
- Is there a specific reason to call it `embeddings` here, whereas it was called `pos_embd` in the `positional_embedding_indexing` method? If not, let's try to remain consistent in the names.
- Copying from above: the assignment `self.pos_embd = self.pos_embd.to(dtype)` only works if `pos_embd` is a buffer but not if it is a parameter (which is the case if `self.gridtype == "learnable"`). Thus, we define a new local variable, which works in both cases.
- I took it from the existing code, but it makes sense to rename it to `pos_embd`.
…ding, lead time aware, with compile, apex_gn, etc... Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
PhysicsNeMo Pull Request
Description
- Fix casting when `fused_act` is `True` in `ApexGroupNorm`.
- Fix casting of the positional embedding (`self.pos_embd` can be a buffer or a parameter) and fix the dtype of the softmax output (which is `fp32`) for `SongUNetPosEmbd`.
- Avoid the `.data` attribute of `self.scalar` to enable `.compile` for `SongUNetPosEmbd`.
Checklist