Fix casting in SongUNetPosEmbd and shape in CorrDiff generation #982
@@ -857,9 +857,6 @@ def forward(
                 "Cannot provide both embedding_selector and global_index."
             )

-        if x.dtype != self.pos_embd.dtype:
-            self.pos_embd = self.pos_embd.to(x.dtype)
-
         # Append positional embedding to input conditioning
         if self.pos_embd is not None:
             # Select positional embeddings with a selector function

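The removed block cast `self.pos_embd` by reassigning the module attribute inside `forward`. As a hedged illustration of why that pattern is fragile (a toy module with assumed sizes, not the actual SongUNetPosEmbd code or the author's reported reproduction): if the positional embedding is registered as a learnable `nn.Parameter`, PyTorch refuses to overwrite the parameter attribute with a plain cast tensor.

```python
# Toy reproduction sketch (assumed setup, not the real model): a learnable
# positional embedding plus the old "cast by reassignment" logic.
import torch
import torch.nn as nn


class ToyPosEmbd(nn.Module):
    def __init__(self):
        super().__init__()
        # Analogous to a learnable positional embedding grid
        self.pos_embd = nn.Parameter(torch.randn(4, 8, 8))

    def forward(self, x):
        # Old-style casting: overwrite the module attribute with a cast tensor
        if x.dtype != self.pos_embd.dtype:
            self.pos_embd = self.pos_embd.to(x.dtype)
        return x + self.pos_embd


model = ToyPosEmbd()
x = torch.randn(1, 4, 8, 8, dtype=torch.float16)
try:
    model(x)
except TypeError as err:
    # nn.Module.__setattr__ only accepts nn.Parameter (or None) for parameter
    # attributes, so the reassignment with a plain tensor fails here.
    print(err)
```

Even when the embedding is stored as a plain buffer (so no exception is raised), the reassignment silently keeps the module in reduced precision for every later call, which is why the fix below casts into a local variable instead.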
@@ -877,18 +874,19 @@ def forward(

         out = super().forward(x, noise_labels, class_labels, augment_labels)

-        if self.lead_time_mode:
+        if self.lead_time_mode and self.prob_channels:
             # if training mode, let crossEntropyLoss do softmax. The model outputs logits.
             # if eval mode, the model outputs probability
-            if self.prob_channels and out.dtype != self.scalar.dtype:
-                self.scalar.data = self.scalar.data.to(out.dtype)
-            if self.prob_channels and (not self.training):
-                out[:, self.prob_channels] = (
-                    out[:, self.prob_channels] * self.scalar
-                ).softmax(dim=1)
-            elif self.prob_channels and self.training:
-                out[:, self.prob_channels] = out[:, self.prob_channels] * self.scalar
+            scalar = self.scalar
+            if out.dtype != scalar.dtype:
+                scalar = scalar.to(out.dtype)
+            if self.training:
+                out[:, self.prob_channels] = out[:, self.prob_channels] * scalar
+            else:
+                out[:, self.prob_channels] = (
+                    (out[:, self.prob_channels] * scalar)
+                    .softmax(dim=1)
+                    .to(out.dtype)
+                )
         return out

Review comment on lines -883 to +889: LGTM, but could you just post below an MRE of the bug you encountered with the former casting logic?

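For reference, a minimal standalone sketch of the revised branch (illustrative channel indices and shapes, not the model's actual configuration): the learned scale is cast into a local copy, training mode keeps logits for `CrossEntropyLoss`, and eval mode applies the softmax over the probability channels and casts back to the output dtype.

```python
# Toy sketch of the revised prob_channels handling with made-up shapes.
import torch

prob_channels = [0, 1, 2]                             # channels holding class logits
scalar_param = torch.ones(len(prob_channels), 1, 1)   # stored in float32
training = False                                       # pretend we are in eval mode

out = torch.randn(2, 5, 4, 4, dtype=torch.float16)    # model output under autocast

scalar = scalar_param
if out.dtype != scalar.dtype:
    scalar = scalar.to(out.dtype)                      # local copy; parameter stays float32

if training:
    # Training: leave logits untouched so the loss can apply softmax itself
    out[:, prob_channels] = out[:, prob_channels] * scalar
else:
    # Eval: normalize the class channels and cast back to the output dtype
    out[:, prob_channels] = (
        (out[:, prob_channels] * scalar).softmax(dim=1).to(out.dtype)
    )

print(out[:, prob_channels].sum(dim=1))  # ~1 per pixel in eval mode
```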
@@ -947,15 +945,16 @@ def positional_embedding_indexing(
         """
         # If no global indices are provided, select all embeddings and expand
         # to match the batch size of the input
-        if x.dtype != self.pos_embd.dtype:
-            self.pos_embd = self.pos_embd.to(x.dtype)
+        pos_embd = self.pos_embd
+        if x.dtype != pos_embd.dtype:
+            pos_embd = pos_embd.to(x.dtype)

         if global_index is None:
             if self.lead_time_mode:
                 selected_pos_embd = []
-                if self.pos_embd is not None:
+                if pos_embd is not None:
                     selected_pos_embd.append(
-                        self.pos_embd[None].expand((x.shape[0], -1, -1, -1))
+                        pos_embd[None].expand((x.shape[0], -1, -1, -1))
                     )
                 if self.lt_embd is not None:
                     selected_pos_embd.append(

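A small standalone sketch of the path above (made-up sizes): the stored embedding is cast into a local variable so the module keeps its full-precision copy, and `expand` broadcasts it over the batch without allocating new memory.

```python
# Sketch of the no-global-index path with hypothetical sizes.
import torch

B, C_pe, H, W = 2, 4, 8, 8
x = torch.randn(B, 3, H, W, dtype=torch.float16)   # reduced-precision input
pos_embd_param = torch.randn(C_pe, H, W)            # stored in float32

pos_embd = pos_embd_param
if x.dtype != pos_embd.dtype:
    pos_embd = pos_embd.to(x.dtype)                 # local cast only

selected_pos_embd = pos_embd[None].expand((B, -1, -1, -1))  # (B, C_pe, H, W)
print(selected_pos_embd.shape, selected_pos_embd.dtype, pos_embd_param.dtype)
# torch.Size([2, 4, 8, 8]) torch.float16 torch.float32
```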
@@ -972,7 +971,7 @@ def positional_embedding_indexing(
                 if len(selected_pos_embd) > 0:
                     selected_pos_embd = torch.cat(selected_pos_embd, dim=1)
             else:
-                selected_pos_embd = self.pos_embd[None].expand(
+                selected_pos_embd = pos_embd[None].expand(
                     (x.shape[0], -1, -1, -1)
                 )  # (B, C_{PE}, H, W)

@@ -985,11 +984,11 @@ def positional_embedding_indexing(
             global_index = torch.reshape(
                 torch.permute(global_index, (1, 0, 2, 3)), (2, -1)
             )  # (P, 2, X, Y) to (2, P*X*Y)
-            selected_pos_embd = self.pos_embd[
+            selected_pos_embd = pos_embd[
                 :, global_index[0], global_index[1]
             ]  # (C_pe, P*X*Y)
             selected_pos_embd = torch.permute(
-                torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)),
+                torch.reshape(selected_pos_embd, (pos_embd.shape[0], P, H, W)),
                 (1, 0, 2, 3),
             )  # (P, C_pe, X, Y)

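A standalone sketch of the indexing performed in this hunk, with hypothetical sizes: `C_pe` embedding channels on an H x W grid, and P patches of size X x Y whose per-pixel (row, col) coordinates are stored in `global_index`.

```python
import torch

C_pe, H, W = 6, 16, 16
P, X, Y = 3, 4, 4

pos_embd = torch.randn(C_pe, H, W)
# global_index: (P, 2, X, Y) -- channel 0 holds row indices, channel 1 column indices
rows = torch.randint(0, H, (P, 1, X, Y))
cols = torch.randint(0, W, (P, 1, X, Y))
global_index = torch.cat([rows, cols], dim=1)

flat = torch.reshape(
    torch.permute(global_index, (1, 0, 2, 3)), (2, -1)
)  # (P, 2, X, Y) to (2, P*X*Y)
selected = pos_embd[:, flat[0], flat[1]]  # (C_pe, P*X*Y)
selected = torch.permute(
    torch.reshape(selected, (C_pe, P, X, Y)), (1, 0, 2, 3)
)  # (P, C_pe, X, Y)
print(selected.shape)  # torch.Size([3, 6, 4, 4])
```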
@@ -1000,7 +999,7 @@ def positional_embedding_indexing(
             # Append positional and lead time embeddings to input conditioning
             if self.lead_time_mode:
                 embeds = []
-                if self.pos_embd is not None:
+                if pos_embd is not None:
                     embeds.append(selected_pos_embd)  # reuse code below
                 if self.lt_embd is not None:
                     lt_embds = self.lt_embd[

@@ -1086,15 +1085,12 @@ def positional_embedding_selector(
         ...     return patching.apply(emb[None].expand(batch_size, -1, -1, -1))
         >>>
         """
-        if x.dtype != self.pos_embd.dtype:
-            self.pos_embd = self.pos_embd.to(x.dtype)
+        embeddings = self.pos_embd
+        if x.dtype != embeddings.dtype:
+            embeddings = embeddings.to(x.dtype)
         if lead_time_label is not None:
             # all patches share same lead_time_label
-            embeddings = torch.cat(
-                [self.pos_embd, self.lt_embd[lead_time_label[0].int()]]
-            )
-        else:
-            embeddings = self.pos_embd
+            embeddings = torch.cat([embeddings, self.lt_embd[lead_time_label[0].int()]])
         return embedding_selector(embeddings)  # (B, N_pe, H, W)

     def _get_positional_embedding(self):
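A standalone sketch of what the revised selector path computes (hypothetical sizes, and a trivial selector standing in for the patching-based one from the docstring): the positional embedding is concatenated channel-wise with the lead-time embedding shared by all patches, then handed to the selector.

```python
import torch

B, C_pe, C_lt, H, W = 3, 4, 2, 8, 8

pos_embd = torch.randn(C_pe, H, W)
lt_embd = torch.randn(5, C_lt, H, W)        # embeddings for 5 possible lead times
lead_time_label = torch.tensor([2, 2, 2])   # all patches share the same label

embeddings = pos_embd
# Concatenate the lead-time slice along the channel dimension: (C_pe + C_lt, H, W)
embeddings = torch.cat([embeddings, lt_embd[lead_time_label[0].int()]])


def embedding_selector(emb):
    # Trivial selector: broadcast the full grid over the batch
    return emb[None].expand(B, -1, -1, -1)


selected = embedding_selector(embeddings)
print(selected.shape)  # torch.Size([3, 6, 8, 8])
```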

Review comment: @juliusberner could you explain the reason for this change? AFAIK the batch dimension of latents_shape is never really used, right?

Reply: This is the batch size that the output of regression_step is expanded to. Since we compute image_out = image_reg + image_res later, this needs to match the batch size of the output of diffusion_step.
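A hedged sketch of the constraint described in the reply (illustrative tensor names and sizes, not the actual CorrDiff generation code): the single regression output has to be expanded to the batch size produced by the diffusion step before the two are summed.

```python
import torch

batch_size, C, H, W = 4, 3, 32, 32

image_reg = torch.randn(1, C, H, W)            # regression output, computed once
image_res = torch.randn(batch_size, C, H, W)   # residual from the diffusion step

# Expand the regression output so the sum matches the diffusion batch size
image_out = image_reg.expand(batch_size, -1, -1, -1) + image_res
print(image_out.shape)  # torch.Size([4, 3, 32, 32])
```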