Skip to content

Commit c8f4a72

Browse files
FarbdruckerLukas Sannerqubvel
authored
Fix/issue 1198 value error invalid type torch.mps.float tensor (#1199)
* fixing type and device casting * updating version with fix * Revert "fixing type and device casting" This reverts commit 62682f0. * Revert "updating version with fix" This reverts commit 0b431b0. * fixing type and device casting * updating version with fix * Update segmentation_models_pytorch/__version__.py --------- Co-authored-by: Lukas Sanner <lukas.sanner@lgln.niedersachsen.de> Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
1 parent 5ba632d commit c8f4a72

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.5.0"
1+
__version__ = "0.5.1.dev0"

segmentation_models_pytorch/losses/_functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def focal_loss_with_logits(
6666
References:
6767
https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py
6868
"""
69-
target = target.type(output.type())
69+
target = target.to(dtype=output.dtype, device=output.device)
7070

7171
logpt = F.binary_cross_entropy_with_logits(output, target, reduction="none")
7272
pt = torch.exp(-logpt)

0 commit comments

Comments
 (0)