Modify the Dice loss #3825

Status: Open
Wants to merge 1 commit into base: develop
7 changes: 6 additions & 1 deletion paddleseg/models/losses/dice_loss.py
@@ -21,6 +21,10 @@ class DiceLoss(nn.Layer):
     """
     The implements of the dice loss.
 
+    The original article refers to
+    Wang, Zifu et. al. "Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels"
+    (https://arxiv.org/abs/2303.16296)
+
     Args:
         weight (list[float], optional): The weight for each class. Default: None.
         ignore_index (int64): ignore_index (int64, optional): Specifies a target value that
@@ -70,8 +74,9 @@ def dice_loss_helper(logit, label, mask, smooth, eps):
     mask = paddle.reshape(mask, [0, -1])
     logit *= mask
     label *= mask
-    intersection = paddle.sum(logit * label, axis=1)
+    difference = paddle.sum(paddle.abs(logit - label), axis=1)
     cardinality = paddle.sum(logit + label, axis=1)
+    intersection = (cardinality - difference) / 2
     dice_loss = 1 - (2 * intersection + smooth) / (cardinality + smooth + eps)
     dice_loss = dice_loss.mean()
     return dice_loss
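
The substantive change above follows the semimetric formulation from the referenced paper: for predictions p and labels y in [0, 1], the soft intersection is taken as (sum(p) + sum(y) - sum|p - y|) / 2, i.e. a sum of element-wise minima, instead of sum(p * y). With hard 0/1 labels the two coincide, but with soft labels only the difference-based form reaches zero loss when the prediction equals the target. Below is a small standalone NumPy sketch (illustrative only, not part of this PR; the helper names are made up) that checks both properties:

import numpy as np


def dice_loss_product(p, y, smooth=1.0, eps=1e-8):
    # Classic Dice loss: soft intersection via the element-wise product.
    intersection = (p * y).sum()
    cardinality = (p + y).sum()
    return 1 - (2 * intersection + smooth) / (cardinality + smooth + eps)


def dice_loss_semimetric(p, y, smooth=1.0, eps=1e-8):
    # Formulation used in this PR: intersection recovered from the L1 difference,
    # which equals a sum of element-wise minima.
    difference = np.abs(p - y).sum()
    cardinality = (p + y).sum()
    intersection = (cardinality - difference) / 2
    return 1 - (2 * intersection + smooth) / (cardinality + smooth + eps)


soft_pred = np.array([0.2, 0.9, 0.7, 0.1])
hard_label = np.array([0.0, 1.0, 1.0, 0.0])
soft_label = np.array([0.1, 0.8, 0.6, 0.2])

# With hard labels both formulations coincide (min(p, y) == p * y when y is 0 or 1).
assert np.isclose(dice_loss_product(soft_pred, hard_label),
                  dice_loss_semimetric(soft_pred, hard_label))

# With soft labels only the semimetric form is minimized by a perfect prediction.
print(dice_loss_product(soft_label, soft_label))     # > 0 although p == y
print(dice_loss_semimetric(soft_label, soft_label))  # ~ 0 (up to smooth/eps)

The dice_loss_helper change in this file and the dice_loss change in maskformer_loss.py below are the same substitution applied to Paddle tensors.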
102 changes: 56 additions & 46 deletions paddleseg/models/losses/maskformer_loss.py
@@ -29,6 +29,11 @@
 def dice_loss(inputs, targets, num_masks):
     """
     Compute the DICE loss, similar to generalized IOU for masks
+
+    The original article refers to
+    Wang, Zifu et. al. "Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels"
+    (https://arxiv.org/abs/2303.16296)
+
     Args:
         inputs: A float tensor of arbitrary shape.
                 The predictions for each example.
@@ -38,8 +43,9 @@ def dice_loss(inputs, targets, num_masks):
     """
     inputs = F.sigmoid(inputs)
     inputs = paddle.flatten(inputs, 1)
-    numerator = 2 * (inputs * targets).sum(-1)
+    difference = paddle.abs(inputs - targets).sum(-1)
     denominator = inputs.sum(-1) + targets.sum(-1)
+    numerator = denominator - difference
     loss = 1 - (numerator + 1) / (denominator + 1)
     return loss.sum() / num_masks

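As a quick sanity check of the change just above (numbers chosen arbitrarily): with inputs = [0.8, 0.3] after the sigmoid and a hard target [1, 0], the old numerator 2 * (inputs * targets).sum(-1) gives 2 * 0.8 = 1.6, and the new numerator = denominator - difference gives (0.8 + 0.3 + 1 + 0) - (0.2 + 0.3) = 1.6, so hard-label behaviour is unchanged; with soft targets the new form additionally drops to zero loss when inputs equal the targets.
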
@@ -61,8 +67,9 @@ def sigmoid_focal_loss(inputs, targets, num_masks, alpha=0.25, gamma=2):
         Loss tensor
     """
     prob = F.sigmoid(inputs)
-    ce_loss = F.binary_cross_entropy_with_logits(
-        inputs, targets, reduction="none")
+    ce_loss = F.binary_cross_entropy_with_logits(inputs,
+                                                 targets,
+                                                 reduction="none")
     p_t = prob * targets + (1 - prob) * (1 - targets)
     loss = ce_loss * ((1 - p_t)**gamma)

@@ -173,27 +180,26 @@ def forward(self, outputs, targets):
 
         # Iterate through batch size
         for b in range(bs):
-            out_prob = F.softmax(
-                outputs["pred_logits"][b],
-                axis=-1)  # [num_queries, num_classes]
+            out_prob = F.softmax(outputs["pred_logits"][b],
+                                 axis=-1)  # [num_queries, num_classes]
             out_mask = outputs["pred_masks"][b]  # [num_queries, H_pred, W_pred]
 
             tgt_ids = targets[b]["labels"]
            # gt masks are already padded when preparing target
             if targets[b]["labels"].shape[0] == 0:
-                indices.append((np.array(
-                    [], dtype='int64'), np.array(
-                        [], dtype='int64')))
+                indices.append(
+                    (np.array([], dtype='int64'), np.array([], dtype='int64')))
                 continue
             tgt_mask = paddle.cast(targets[b]["masks"], out_mask.dtype)
             # Compute the classification cost. Contrary to the loss, we don't use the NLL,
             # but approximate it in 1 - proba[target class].
-            # The 1 is a constant that doesn't change the matching, it can be ommitted.
+            # The 1 is a constant that doesn't change the matching, it can be ommitted.
             cost_class = -paddle.gather(out_prob, index=tgt_ids, axis=1)
 
             # Downsample gt masks to save memory
-            tgt_mask = F.interpolate(
-                tgt_mask[:, None], size=out_mask.shape[-2:], mode="nearest")
+            tgt_mask = F.interpolate(tgt_mask[:, None],
+                                     size=out_mask.shape[-2:],
+                                     mode="nearest")
 
             # Flatten spatial dimension
             out_mask = out_mask.flatten(1)  # [batch_size * num_queries, H*W]
@@ -212,12 +218,12 @@ def forward(self, outputs, targets):
 
             indices.append(linear_sum_assignment(C))
 
-        return [(paddle.to_tensor(
-            i, dtype='int64'), paddle.to_tensor(
-                j, dtype='int64')) for i, j in indices]
+        return [(paddle.to_tensor(i, dtype='int64'),
+                 paddle.to_tensor(j, dtype='int64')) for i, j in indices]
 
 
 def nested_tensor_from_tensor_list(tensor_list):
+
     def _max_by_axis(the_list):
         maxes = the_list[0]
         for sublist in the_list[1:]:
@@ -234,8 +240,8 @@ def _max_by_axis(the_list):
 
         for i in range(tensor.shape[0]):
             img = tensor_list[i]
-            tensor[i, :img.shape[0], :img.shape[1], :img.shape[
-                2]] = copy.deepcopy(img)
+            tensor[i, :img.shape[0], :img.shape[1], :img.
+                   shape[2]] = copy.deepcopy(img)
             mask[i, :img.shape[1], :img.shape[2]] = False
     else:
         raise ValueError("not supported")
@@ -273,15 +279,17 @@ def __init__(self,
         dec_layers = 6
         aux_weight_dict = {}
         for i in range(dec_layers - 1):
-            aux_weight_dict.update(
-                {k + f"_{i}": v
-                 for k, v in weight_dict.items()})
+            aux_weight_dict.update({
+                k + f"_{i}": v
+                for k, v in weight_dict.items()
+            })
         weight_dict.update(aux_weight_dict)
         self.num_classes = num_classes
         self.ignore_index = ignore_index
         self.weight_dict = weight_dict
-        self.matcher = HungarianMatcher(
-            cost_class=1, cost_mask=mask_weight, cost_dice=dice_weight)
+        self.matcher = HungarianMatcher(cost_class=1,
+                                        cost_mask=mask_weight,
+                                        cost_dice=dice_weight)
         self.losses = losses
         self.empty_weight = paddle.ones(shape=(num_classes + 1, ))
         self.empty_weight[-1] = eos_coef
@@ -307,17 +315,18 @@ def loss_labels(self, outputs, targets, indices, num_masks):
         idx = self._get_src_permutation_idx(indices)
         target_classes_o = paddle.concat(
             [t["labels"][J] for t, (_, J) in zip(targets_cpt, indices_cpt)])
-        target_classes = paddle.full(
-            src_logits.shape[:2], self.num_classes, dtype='int64')
+        target_classes = paddle.full(src_logits.shape[:2],
+                                     self.num_classes,
+                                     dtype='int64')
         target_classes[idx] = target_classes_o
 
-        loss_ce = F.cross_entropy(
-            src_logits.transpose((0, 2, 1)).cast('float32'),
-            target_classes,
-            weight=self.empty_weight,
-            axis=1,
-            use_softmax=True,
-            ignore_index=255)
+        loss_ce = F.cross_entropy(src_logits.transpose(
+            (0, 2, 1)).cast('float32'),
+                                  target_classes,
+                                  weight=self.empty_weight,
+                                  axis=1,
+                                  use_softmax=True,
+                                  ignore_index=255)
         losses = {"loss_ce": loss_ce}
         return losses

@@ -352,11 +361,10 @@ def loss_masks(self, outputs, targets, indices, num_masks):
         target_masks = paddle.cast(target_masks, src_masks.dtype)
         target_masks = target_masks[tgt_idx]
 
-        src_masks = F.interpolate(
-            src_masks[:, None],
-            size=target_masks.shape[-2:],
-            mode="bilinear",
-            align_corners=False)
+        src_masks = F.interpolate(src_masks[:, None],
+                                  size=target_masks.shape[-2:],
+                                  mode="bilinear",
+                                  align_corners=False)
         src_masks = paddle.flatten(src_masks[:, 0], 1)
 
         target_masks = paddle.flatten(target_masks, 1)
@@ -398,8 +406,10 @@ def forward(self, logits, targets):
             padded_masks[:, :gt_masks.shape[1], :gt_masks.shape[2]] = gt_masks
 
             targets_cpt.append({
-                "labels": targets['gt_classes'][target_per_image_idx, ...],
-                "masks": padded_masks
+                "labels":
+                targets['gt_classes'][target_per_image_idx, ...],
+                "masks":
+                padded_masks
             })
 
         targets = []
@@ -411,12 +421,12 @@ def forward(self, logits, targets):
                 start_idx = int(invalid_indices[0].numpy())
             else:
                 start_idx = len(item['labels'])
-            index = paddle.cast(
-                paddle.to_tensor([i for i in range(start_idx)]), 'int64')
-            item['labels'] = paddle.gather(
-                item['labels'], index, axis=0)  # [n] n<150
-            item['masks'] = paddle.gather(
-                item["masks"], index, axis=0)  # [n,512,512]
+            index = paddle.cast(paddle.to_tensor([i for i in range(start_idx)]),
+                                'int64')
+            item['labels'] = paddle.gather(item['labels'], index,
+                                           axis=0)  # [n] n<150
+            item['masks'] = paddle.gather(item["masks"], index,
+                                          axis=0)  # [n,512,512]
             targets.append(item)

logits_without_aux = {
@@ -432,8 +442,8 @@ def forward(self, logits, targets):
 
         if dist.get_world_size() > 1:
             dist.all_reduce(num_masks)
-        num_masks = paddle.clip(
-            num_masks / dist.get_world_size(), min=1).detach().numpy()[0]
+        num_masks = paddle.clip(num_masks / dist.get_world_size(),
+                                min=1).detach().numpy()[0]
 
         losses = {}
         for loss in self.losses: