From 1c77e98ee3c01e3cb7858202fbbf29385894f03f Mon Sep 17 00:00:00 2001
From: Zifu Wang
Date: Fri, 11 Oct 2024 10:41:25 +0200
Subject: [PATCH] Modify Dice loss

---
 training/loss_fns.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/training/loss_fns.py b/training/loss_fns.py
index d281b1a9..4b7d3833 100644
--- a/training/loss_fns.py
+++ b/training/loss_fns.py
@@ -9,6 +9,7 @@
 
 import torch
 import torch.distributed
+import torch.linalg as LA
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -20,6 +21,9 @@ def dice_loss(inputs, targets, num_objects, loss_on_multimask=False):
     """
     Compute the DICE loss, similar to generalized IOU for masks
 
+    Reference:
+    Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels.
+    Wang, Z. et. al. MICCAI 2023.
     Args:
         inputs: A float tensor of arbitrary shape.
                 The predictions for each example.
@@ -38,11 +42,11 @@
         # flatten spatial dimension while keeping multimask channel dimension
         inputs = inputs.flatten(2)
         targets = targets.flatten(2)
-        numerator = 2 * (inputs * targets).sum(-1)
     else:
         inputs = inputs.flatten(1)
-        numerator = 2 * (inputs * targets).sum(1)
     denominator = inputs.sum(-1) + targets.sum(-1)
+    difference = LA.vector_norm(inputs - targets, ord=1, dim=-1)
+    numerator = denominator - difference
     loss = 1 - (numerator + 1) / (denominator + 1)
     if loss_on_multimask:
         return loss / num_objects