
Commit f124ba9

Author: Julius Berner (committed)
Avoid dtype change of buffer/param and fix softmax dtype
1 parent bab3815 commit f124ba9
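
For context, the pattern this commit applies: instead of reassigning a registered buffer or parameter to the input dtype inside forward (which permanently changes the module's state), cast a local copy for the current call only. The sketch below is illustrative and not part of the commit; the toy module, buffer name, and shapes are made up.

import torch
import torch.nn as nn


class ToyModule(nn.Module):
    """Hypothetical module illustrating the local-cast pattern."""

    def __init__(self):
        super().__init__()
        # Registered buffer kept in float32; it should stay float32 even when
        # the forward pass runs in reduced precision.
        self.register_buffer("pos_embd", torch.randn(4, 8, 8))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Anti-pattern (what the commit removes):
        #     self.pos_embd = self.pos_embd.to(x.dtype)
        # That reassignment silently changes the module's state on the first
        # reduced-precision call. Cast a local copy for this call instead:
        pos_embd = self.pos_embd
        if x.dtype != pos_embd.dtype:
            pos_embd = pos_embd.to(x.dtype)
        return x + pos_embd


module = ToyModule()
out = module(torch.randn(2, 4, 8, 8, dtype=torch.bfloat16))
assert module.pos_embd.dtype == torch.float32  # buffer dtype is unchanged
assert out.dtype == torch.bfloat16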

File tree

1 file changed (+23, -27 lines)

physicsnemo/models/diffusion/song_unet.py

Lines changed: 23 additions & 27 deletions
@@ -857,9 +857,6 @@ def forward(
                 "Cannot provide both embedding_selector and global_index."
             )

-        if x.dtype != self.pos_embd.dtype:
-            self.pos_embd = self.pos_embd.to(x.dtype)
-
         # Append positional embedding to input conditioning
         if self.pos_embd is not None:
             # Select positional embeddings with a selector function
@@ -877,18 +874,19 @@ def forward(

         out = super().forward(x, noise_labels, class_labels, augment_labels)

-        if self.lead_time_mode:
+        if self.lead_time_mode and self.prob_channels:
             # if training mode, let crossEntropyLoss do softmax. The model outputs logits.
             # if eval mode, the model outputs probability
-            if self.prob_channels and out.dtype != self.scalar.dtype:
-                self.scalar.data = self.scalar.data.to(out.dtype)
-            if self.prob_channels and (not self.training):
-                out[:, self.prob_channels] = (
-                    out[:, self.prob_channels] * self.scalar
-                ).softmax(dim=1)
-            elif self.prob_channels and self.training:
+            scalar = self.scalar
+            if out.dtype != scalar.dtype:
+                scalar = scalar.to(out.dtype)
+            if self.training:
+                out[:, self.prob_channels] = out[:, self.prob_channels] * scalar
+            else:
                 out[:, self.prob_channels] = (
-                    out[:, self.prob_channels] * self.scalar
+                    (out[:, self.prob_channels] * scalar)
+                    .softmax(dim=1)
+                    .to(out.dtype)
                 )
         return out

@@ -947,15 +945,16 @@ def positional_embedding_indexing(
         """
         # If no global indices are provided, select all embeddings and expand
         # to match the batch size of the input
-        if x.dtype != self.pos_embd.dtype:
-            self.pos_embd = self.pos_embd.to(x.dtype)
+        pos_embd = self.pos_embd
+        if x.dtype != pos_embd.dtype:
+            pos_embd = pos_embd.to(x.dtype)

         if global_index is None:
             if self.lead_time_mode:
                 selected_pos_embd = []
-                if self.pos_embd is not None:
+                if pos_embd is not None:
                     selected_pos_embd.append(
-                        self.pos_embd[None].expand((x.shape[0], -1, -1, -1))
+                        pos_embd[None].expand((x.shape[0], -1, -1, -1))
                     )
                 if self.lt_embd is not None:
                     selected_pos_embd.append(
@@ -972,7 +971,7 @@ def positional_embedding_indexing(
                 if len(selected_pos_embd) > 0:
                     selected_pos_embd = torch.cat(selected_pos_embd, dim=1)
             else:
-                selected_pos_embd = self.pos_embd[None].expand(
+                selected_pos_embd = pos_embd[None].expand(
                     (x.shape[0], -1, -1, -1)
                 )  # (B, C_{PE}, H, W)

@@ -985,11 +984,11 @@ def positional_embedding_indexing(
             global_index = torch.reshape(
                 torch.permute(global_index, (1, 0, 2, 3)), (2, -1)
             )  # (P, 2, X, Y) to (2, P*X*Y)
-            selected_pos_embd = self.pos_embd[
+            selected_pos_embd = pos_embd[
                 :, global_index[0], global_index[1]
             ]  # (C_pe, P*X*Y)
             selected_pos_embd = torch.permute(
-                torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)),
+                torch.reshape(selected_pos_embd, (pos_embd.shape[0], P, H, W)),
                 (1, 0, 2, 3),
             )  # (P, C_pe, X, Y)

@@ -1000,7 +999,7 @@ def positional_embedding_indexing(
             # Append positional and lead time embeddings to input conditioning
             if self.lead_time_mode:
                 embeds = []
-                if self.pos_embd is not None:
+                if pos_embd is not None:
                     embeds.append(selected_pos_embd)  # reuse code below
                 if self.lt_embd is not None:
                     lt_embds = self.lt_embd[
@@ -1086,15 +1085,12 @@ def positional_embedding_selector(
         ...     return patching.apply(emb[None].expand(batch_size, -1, -1, -1))
         >>>
         """
-        if x.dtype != self.pos_embd.dtype:
-            self.pos_embd = self.pos_embd.to(x.dtype)
+        embeddings = self.pos_embd
+        if x.dtype != embeddings.dtype:
+            embeddings = embeddings.to(x.dtype)
         if lead_time_label is not None:
             # all patches share same lead_time_label
-            embeddings = torch.cat(
-                [self.pos_embd, self.lt_embd[lead_time_label[0].int()]]
-            )
-        else:
-            embeddings = self.pos_embd
+            embeddings = torch.cat([embeddings, self.lt_embd[lead_time_label[0].int()]])
         return embedding_selector(embeddings)  # (B, N_pe, H, W)

     def _get_positional_embedding(self):
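
The softmax change in the forward hunk above follows the same idea: use a locally cast copy of self.scalar rather than mutating scalar.data, and cast the softmax output back to out.dtype before writing it into the output slice, since softmax may run in a higher precision (for example under autocast). Below is a minimal sketch of the eval-time path with made-up shapes and channel indices; scalar stands in for the registered parameter.

import torch

prob_channels = [0, 1, 2]  # hypothetical probability-channel indices
scalar = torch.ones(len(prob_channels), 1, 1)  # stand-in for self.scalar (float32)
out = torch.randn(2, 5, 8, 8, dtype=torch.bfloat16)  # stand-in for the model output

# Cast a local copy instead of changing the parameter's dtype in place.
s = scalar
if out.dtype != s.dtype:
    s = s.to(out.dtype)

# Eval-time path: scale the probability channels, apply softmax, and cast the
# result back to out.dtype so the slice assignment stays dtype-consistent.
out[:, prob_channels] = (out[:, prob_channels] * s).softmax(dim=1).to(out.dtype)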
