@@ -857,9 +857,6 @@ def forward(
                 "Cannot provide both embedding_selector and global_index."
             )
 
-        if x.dtype != self.pos_embd.dtype:
-            self.pos_embd = self.pos_embd.to(x.dtype)
-
         # Append positional embedding to input conditioning
         if self.pos_embd is not None:
             # Select positional embeddings with a selector function
@@ -877,18 +874,19 @@ def forward(
 
         out = super().forward(x, noise_labels, class_labels, augment_labels)
 
-        if self.lead_time_mode:
+        if self.lead_time_mode and self.prob_channels:
             # if training mode, let crossEntropyLoss do softmax. The model outputs logits.
             # if eval mode, the model outputs probability
-            if self.prob_channels and out.dtype != self.scalar.dtype:
-                self.scalar.data = self.scalar.data.to(out.dtype)
-            if self.prob_channels and (not self.training):
-                out[:, self.prob_channels] = (
-                    out[:, self.prob_channels] * self.scalar
-                ).softmax(dim=1)
-            elif self.prob_channels and self.training:
+            scalar = self.scalar
+            if out.dtype != scalar.dtype:
+                scalar = scalar.to(out.dtype)
+            if self.training:
+                out[:, self.prob_channels] = out[:, self.prob_channels] * scalar
+            else:
                 out[:, self.prob_channels] = (
-                    out[:, self.prob_channels] * self.scalar
+                    (out[:, self.prob_channels] * scalar)
+                    .softmax(dim=1)
+                    .to(out.dtype)
                 )
         return out
 
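Note: the hunk above replaces the in-place dtype rewrite of self.scalar.data with a locally cast copy, so the registered parameter keeps its original (typically fp32) dtype when activations arrive in reduced precision, e.g. under torch.autocast. Below is a minimal sketch of the same pattern on a hypothetical module; the class, argument names, and shapes are illustrative only, not the library's API.

import torch
from torch import nn


class ProbHead(nn.Module):
    """Illustrative module, not the library's SongUNet head."""

    def __init__(self, n_prob_channels: int):
        super().__init__()
        # Stored in fp32 and left untouched; only a local alias is ever cast.
        self.scalar = nn.Parameter(torch.ones(1, n_prob_channels, 1, 1))

    def forward(self, out: torch.Tensor, prob_channels: list) -> torch.Tensor:
        scalar = self.scalar
        if out.dtype != scalar.dtype:
            # Cast a local alias only; self.scalar keeps its original dtype.
            scalar = scalar.to(out.dtype)
        if self.training:
            # Keep scaled logits; a cross-entropy loss applies softmax itself.
            out[:, prob_channels] = out[:, prob_channels] * scalar
        else:
            # At inference, emit probabilities cast back to the activation dtype.
            out[:, prob_channels] = (
                (out[:, prob_channels] * scalar).softmax(dim=1).to(out.dtype)
            )
        return out


head = ProbHead(n_prob_channels=3).eval()
half_out = head(torch.randn(2, 5, 8, 8, dtype=torch.float16), prob_channels=[2, 3, 4])
assert head.scalar.dtype == torch.float32  # module state was not rewritten

Casting a local alias rather than rewriting the parameter's storage avoids silently changing state that the optimizer and checkpoints reference.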
@@ -947,15 +945,16 @@ def positional_embedding_indexing(
         """
         # If no global indices are provided, select all embeddings and expand
         # to match the batch size of the input
-        if x.dtype != self.pos_embd.dtype:
-            self.pos_embd = self.pos_embd.to(x.dtype)
+        pos_embd = self.pos_embd
+        if x.dtype != pos_embd.dtype:
+            pos_embd = pos_embd.to(x.dtype)
 
         if global_index is None:
             if self.lead_time_mode:
                 selected_pos_embd = []
-                if self.pos_embd is not None:
+                if pos_embd is not None:
                     selected_pos_embd.append(
-                        self.pos_embd[None].expand((x.shape[0], -1, -1, -1))
+                        pos_embd[None].expand((x.shape[0], -1, -1, -1))
                     )
                 if self.lt_embd is not None:
                     selected_pos_embd.append(
@@ -972,7 +971,7 @@ def positional_embedding_indexing(
                 if len(selected_pos_embd) > 0:
                     selected_pos_embd = torch.cat(selected_pos_embd, dim=1)
             else:
-                selected_pos_embd = self.pos_embd[None].expand(
+                selected_pos_embd = pos_embd[None].expand(
                     (x.shape[0], -1, -1, -1)
                 )  # (B, C_{PE}, H, W)
 
@@ -985,11 +984,11 @@ def positional_embedding_indexing(
             global_index = torch.reshape(
                 torch.permute(global_index, (1, 0, 2, 3)), (2, -1)
             )  # (P, 2, X, Y) to (2, P*X*Y)
-            selected_pos_embd = self.pos_embd[
+            selected_pos_embd = pos_embd[
                 :, global_index[0], global_index[1]
             ]  # (C_pe, P*X*Y)
             selected_pos_embd = torch.permute(
-                torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)),
+                torch.reshape(selected_pos_embd, (pos_embd.shape[0], P, H, W)),
                 (1, 0, 2, 3),
             )  # (P, C_pe, X, Y)
 
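For the global_index branch touched above, the gather now runs on the local pos_embd alias but otherwise works as before: the per-patch coordinates are flattened, used as advanced indices into the embedding grid, and the result is reshaped back to a per-patch layout. A toy, self-contained illustration of that indexing; the shapes below are assumptions for the sketch only.

import torch

C_pe, H_global, W_global = 4, 8, 8
P, X, Y = 3, 2, 2

pos_embd = torch.randn(C_pe, H_global, W_global)
global_index = torch.randint(0, 8, (P, 2, X, Y))  # per-patch grid coordinates

# (P, 2, X, Y) -> (2, P*X*Y): one row per coordinate axis
flat_index = torch.reshape(torch.permute(global_index, (1, 0, 2, 3)), (2, -1))

# Advanced indexing gathers one embedding vector per coordinate: (C_pe, P*X*Y)
gathered = pos_embd[:, flat_index[0], flat_index[1]]

# Restore the patch layout: (C_pe, P, X, Y) -> (P, C_pe, X, Y)
selected = torch.permute(torch.reshape(gathered, (C_pe, P, X, Y)), (1, 0, 2, 3))
print(selected.shape)  # torch.Size([3, 4, 2, 2])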
@@ -1000,7 +999,7 @@ def positional_embedding_indexing(
             # Append positional and lead time embeddings to input conditioning
             if self.lead_time_mode:
                 embeds = []
-                if self.pos_embd is not None:
+                if pos_embd is not None:
                     embeds.append(selected_pos_embd)  # reuse code below
                 if self.lt_embd is not None:
                     lt_embds = self.lt_embd[
@@ -1086,15 +1085,12 @@ def positional_embedding_selector(
         ... return patching.apply(emb[None].expand(batch_size, -1, -1, -1))
         >>>
         """
-        if x.dtype != self.pos_embd.dtype:
-            self.pos_embd = self.pos_embd.to(x.dtype)
+        embeddings = self.pos_embd
+        if x.dtype != embeddings.dtype:
+            embeddings = embeddings.to(x.dtype)
         if lead_time_label is not None:
             # all patches share same lead_time_label
-            embeddings = torch.cat(
-                [self.pos_embd, self.lt_embd[lead_time_label[0].int()]]
-            )
-        else:
-            embeddings = self.pos_embd
+            embeddings = torch.cat([embeddings, self.lt_embd[lead_time_label[0].int()]])
         return embedding_selector(embeddings)  # (B, N_pe, H, W)
 
     def _get_positional_embedding(self):
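positional_embedding_selector now builds the (optionally lead-time-augmented) embedding tensor in a local variable, cast to the input dtype, and hands it to the caller-supplied embedding_selector. A hedged sketch of what such a selector might look like from the outside, using a toy crop in place of the patching-based selector shown in the docstring; names and shapes here are assumptions, not the library API.

import torch

C_pe, H_global, W_global = 4, 32, 32
B, H_patch, W_patch = 2, 8, 8
embeddings = torch.randn(C_pe, H_global, W_global)

def crop_selector(emb: torch.Tensor) -> torch.Tensor:
    # Take the top-left patch of the global grid and expand over the batch:
    # (C_pe, H_global, W_global) -> (B, C_pe, H_patch, W_patch)
    return emb[None, :, :H_patch, :W_patch].expand(B, -1, -1, -1)

selected = crop_selector(embeddings)
print(selected.shape)  # torch.Size([2, 4, 8, 8])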