
Modify the Dice loss #376

Open · wants to merge 1 commit into main
Conversation


@zifuwanggg zifuwanggg commented Oct 11, 2024

The Dice loss in training.loss_fns is modified based on JDTLoss and segmentation_models.pytorch.

The original Dice loss is incompatible with soft labels. For example, with a ground truth value of 0.5 for a single pixel, it is minimized when the predicted value is 1, which is clearly erroneous. To address this, the intersection term is rewritten as $\frac{|x|_1 + |y|_1 - |x-y|_1}{2}$. This reformulation is provably equivalent to the original when the ground truth is binary (i.e., one-hot hard labels). Moreover, since the new version is minimized if and only if the prediction is identical to the ground truth, even when the ground truth contains fractional values, it resolves the issue with soft labels [1, 2].
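The single-pixel failure mode can be checked directly (a minimal sketch; plain scalars stand in for a one-pixel prediction/label pair):

```python
# Single pixel with soft ground truth y = 0.5.
y = 0.5

def dice_old_px(x):
    # Original Dice: intersection as an elementwise product.
    return 2 * x * y / (x + y)

def dice_new_px(x):
    # L1 reformulation of the intersection: (|x| + |y| - |x - y|) / 2.
    return (x + y - abs(x - y)) / (x + y)

# dice_old_px peaks at x = 1 even though the label is 0.5;
# dice_new_px peaks exactly at x = y = 0.5.
for x in (0.25, 0.5, 0.75, 1.0):
    print(f"x={x:.2f}  old={dice_old_px(x):.3f}  new={dice_new_px(x):.3f}")
```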

Although the original SAM/SAM2 models were trained without soft labels, this modification enables soft label training for downstream fine-tuning without changing the existing behavior.

Example

import torch
import torch.linalg as LA
import torch.nn.functional as F

torch.manual_seed(0)

b, c, h, w = 4, 3, 32, 32
dims = (0, 2, 3)

pred = torch.rand(b, c, h, w).softmax(dim=1)
soft_label = torch.rand(b, c, h, w).softmax(dim=1)
hard_label = torch.randint(low=0, high=c, size=(b, h, w))
one_hot_label = F.one_hot(hard_label, c).permute(0, 3, 1, 2).float()

def dice_old(x, y, dims):
    # Original Dice: intersection computed as an elementwise product.
    cardinality = torch.sum(x, dim=dims) + torch.sum(y, dim=dims)
    intersection = torch.sum(x * y, dim=dims)
    return 2 * intersection / cardinality

def dice_new(x, y, dims):
    # Modified Dice: intersection computed via the L1 reformulation
    # (|x|_1 + |y|_1 - |x - y|_1) / 2, which is compatible with soft labels.
    cardinality = torch.sum(x, dim=dims) + torch.sum(y, dim=dims)
    difference = LA.vector_norm(x - y, ord=1, dim=dims)
    intersection = (cardinality - difference) / 2
    return 2 * intersection / cardinality

# Identical on one-hot hard labels; only dice_new behaves correctly on soft labels.
print(dice_old(pred, one_hot_label, dims), dice_new(pred, one_hot_label, dims))
print(dice_old(pred, soft_label, dims), dice_new(pred, soft_label, dims))
print(dice_old(pred, pred, dims), dice_new(pred, pred, dims))

# tensor([0.3345, 0.3310, 0.3317]) tensor([0.3345, 0.3310, 0.3317])
# tensor([0.3321, 0.3333, 0.3350]) tensor([0.8680, 0.8690, 0.8700])
# tensor([0.3487, 0.3502, 0.3544]) tensor([1., 1., 1.])
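For downstream fine-tuning, the reformulated score can be wrapped as a loss (a minimal sketch; the name `soft_dice_loss` and the `eps` smoothing term are illustrative, not the actual training.loss_fns API):

```python
import torch


def soft_dice_loss(pred, target, dims=(0, 2, 3), eps=1e-6):
    """Dice loss using the soft-label-compatible intersection.

    pred, target: float tensors of shape (B, C, H, W); target may be soft.
    """
    cardinality = pred.sum(dim=dims) + target.sum(dim=dims)
    difference = (pred - target).abs().sum(dim=dims)
    intersection = (cardinality - difference) / 2
    dice = (2 * intersection + eps) / (cardinality + eps)
    return 1 - dice.mean()


# Usage: zero loss when the prediction matches the (soft) target exactly.
pred = torch.rand(2, 3, 8, 8).softmax(dim=1)
print(soft_dice_loss(pred, pred))  # ~0
```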

References

[1] Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels. Zifu Wang, Teodora Popordanoska, Jeroen Bertels, Robin Lemmens, Matthew B. Blaschko. MICCAI 2023.

[2] Jaccard Metric Losses: Optimizing the Jaccard Index with Soft Labels. Zifu Wang, Xuefei Ning, Matthew B. Blaschko. NeurIPS 2023.
