Skip to content

Commit 4b7675f

Browse files
committed
[ADD] better names in loss weight
1 parent d3c6532 commit 4b7675f

File tree

2 files changed

+4
-7
lines changed

2 files changed

+4
-7
lines changed

sgm/modules/diffusionmodules/loss_weighting.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
3434
class SevaWeighting(DiffusionLossWeighting):
3535
def __call__(self, sigma: torch.Tensor, mask, max_weight=5.0) -> torch.Tensor:
3636
bools = mask.to(torch.bool)
37-
batch_size, num_frames = bools.shape
38-
indices = torch.arange(num_frames, device=bools.device).unsqueeze(0).expand(batch_size, num_frames)
39-
weights = torch.full((batch_size, num_frames), max_weight, dtype=torch.float, device=bools.device)
37+
batch_size, N = bools.shape
38+
indices = torch.arange(N, device=bools.device).unsqueeze(0).expand(batch_size, N)
39+
weights = torch.full((batch_size, N), max_weight, dtype=torch.float, device=bools.device)
4040

4141
for b in range(batch_size):
4242
true_idx = indices[b][bools[b]]

sgm/modules/diffusionmodules/wrappers.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,12 @@ def forward(
5757
x = rearrange(x, "b f c h w -> (b f) c h w")
5858
dense_y=rearrange(c["plucker"], "b f c h w -> (b f) c h w")
5959

60-
#TODO: remove
61-
c = torch.zeros((b, 1, 1024)).type_as(x).to(x.device)
62-
c = repeat(c, "b 1 c -> (b f) 1 c", f=f)
6360
t = repeat(t, "b -> (b f)", f=f)
6461

6562
out = self.diffusion_model(
6663
x,
6764
t=t,
68-
y=c, # c["crossattn"]
65+
y=c["crossattn"],
6966
dense_y=dense_y,
7067
num_frames=f,
7168
**kwargs,

0 commit comments

Comments (0)