Skip to content

Commit 310a862

Browse files
author
Lukas Sanner
committed
fixing type and device casting
1 parent e43ceb6 commit 310a862

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

segmentation_models_pytorch/losses/_functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def focal_loss_with_logits(
6666
References:
6767
https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py
6868
"""
69-
target = target.type(output.type())
69+
target = target.to(dtype=output.dtype, device=output.device)
7070

7171
logpt = F.binary_cross_entropy_with_logits(output, target, reduction="none")
7272
pt = torch.exp(-logpt)

0 commit comments

Comments
 (0)