-
Notifications
You must be signed in to change notification settings - Fork 2
/
metrics.py
103 lines (87 loc) · 3.74 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Commonly used segmentation metrics
import torch
import numpy as np
def predict(X, threshold):
    '''Binarise model probabilities at a threshold.

    Args:
        X: sigmoid output of the model, as a NumPy array, array-like, or torch
            tensor. Tensors are detached and moved to the CPU first, so tensors
            that require grad or live on a GPU are accepted. X is not modified.
        threshold: values strictly greater than this become 1, others 0.

    Returns:
        A uint8 NumPy array with the same shape as X.
    '''
    if torch.is_tensor(X):
        # NumPy cannot read GPU tensors or tensors attached to autograd.
        X = X.detach().cpu().numpy()
    # The comparison already allocates a new array, so no defensive copy is needed.
    return (np.asarray(X) > threshold).astype('uint8')
def metric(probability, truth, threshold=0.5, reduction='none'):
    '''Calculate the dice score of positive and negative images separately.

    probability and truth must be torch tensors whose first dimension is the
    batch; all remaining dimensions are flattened per sample.

    An image is "negative" when its ground truth has no value > 0.5. Its dice
    is 1 if nothing is predicted (no probability > threshold) and 0 otherwise.
    For "positive" images the usual dice 2|P*T| / (|P| + |T|) is used.

    Args:
        probability: predicted probabilities (e.g. sigmoid outputs).
        truth: ground-truth masks; binarised at 0.5.
        threshold: probability above which a pixel counts as predicted.
        reduction: unused; kept for backward compatibility.

    Returns:
        (dice, dice_neg, dice_pos, num_neg, num_pos): dice is the mean over
        all images; dice_neg / dice_pos are the means over negative / positive
        images (0.0 when the batch has none of that kind); num_neg / num_pos
        are the number of negative / positive images.

    Raises:
        ValueError: if probability and truth do not have the same number of
            elements per sample.
    '''
    batch_size = len(truth)
    with torch.no_grad():
        # reshape (unlike view) also works for non-contiguous tensors.
        probability = probability.reshape(batch_size, -1)
        truth = truth.reshape(batch_size, -1)
        if probability.shape != truth.shape:
            raise ValueError(
                'probability and truth shapes differ after flattening: '
                f'{tuple(probability.shape)} vs {tuple(truth.shape)}'
            )
        p = (probability > threshold).float()
        t = (truth > 0.5).float()
        t_sum = t.sum(-1)
        p_sum = p.sum(-1)

        # t holds 0/1 values, so every image is exactly one of the two kinds.
        neg_mask = t_sum == 0
        pos_mask = ~neg_mask

        dice_neg = (p_sum[neg_mask] == 0).float()
        # Denominator is >= 1 for positive images, so no division by zero.
        dice_pos = 2 * (p * t).sum(-1)[pos_mask] / (p + t).sum(-1)[pos_mask]
        dice = torch.cat([dice_pos, dice_neg])

        num_neg = int(neg_mask.sum().item())
        num_pos = int(pos_mask.sum().item())
        # An empty group would give a NaN mean; report 0.0 instead.
        dice_neg = dice_neg.mean().item() if num_neg else 0.0
        dice_pos = dice_pos.mean().item() if num_pos else 0.0
        dice = dice.mean().item()
    return dice, dice_neg, dice_pos, num_neg, num_pos
class Meter:
    '''Keep track of dice and IoU scores batch by batch throughout an epoch.'''

    def __init__(self, phase, epoch, threshold=0.8):
        '''Create an empty meter.

        Args:
            phase: stored for reference; not used by the computations.
            epoch: stored for reference; not used by the computations.
            threshold: probability cut-off used both for dice (metric) and
                for binarising predictions before IoU (predict).
        '''
        self.phase = phase
        self.epoch = epoch
        self.base_threshold = threshold
        self.base_dice_scores = []
        self.dice_neg_scores = []
        self.dice_pos_scores = []
        self.iou_scores = []

    def update(self, targets, outputs):
        '''Score one batch and append its results.

        Args:
            targets: ground-truth mask tensor for the batch.
            outputs: model output tensor; sigmoid is applied here, so these
                are presumably logits.
        '''
        # Detach so bookkeeping stays off the autograd graph and the NumPy
        # conversions below work even for tensors that require grad.
        probs = torch.sigmoid(outputs.detach())
        dice, dice_neg, dice_pos, _, _ = metric(probs, targets, self.base_threshold)
        self.base_dice_scores.append(dice)
        self.dice_pos_scores.append(dice_pos)
        self.dice_neg_scores.append(dice_neg)
        # NumPy can only read CPU tensors.
        preds = predict(probs.cpu().numpy(), self.base_threshold)
        labels = targets.detach().cpu().numpy()
        iou = compute_iou_batch(preds, labels, classes=[1])
        self.iou_scores.append(iou)

    def get_metrics(self):
        '''Average the accumulated scores.

        Returns:
            ([dice, dice_neg, dice_pos], iou): plain means of the per-batch
            dice values, and the NaN-ignoring mean of the per-batch IoUs.
        '''
        dice = np.mean(self.base_dice_scores)
        dice_neg = np.mean(self.dice_neg_scores)
        dice_pos = np.mean(self.dice_pos_scores)
        dices = [dice, dice_neg, dice_pos]
        iou = np.nanmean(self.iou_scores)
        return dices, iou
def epoch_log(phase, epoch, epoch_loss, meter, start):
    '''Print a one-line summary of an epoch and return its headline scores.

    Only epoch_loss and meter are used; phase, epoch and start are accepted
    for call-site compatibility.

    Returns:
        (dice, iou) as reported by meter.get_metrics().
    '''
    (dice, dice_neg, dice_pos), iou = meter.get_metrics()
    summary = (
        "Loss: %0.4f | IoU: %0.4f | dice: %0.4f | dice_neg: %0.4f | dice_pos: %0.4f"
        % (epoch_loss, iou, dice, dice_neg, dice_pos)
    )
    print(summary)
    return dice, iou
def compute_ious(pred, label, classes, ignore_index=255, only_present=True):
    '''Compute per-class IoU for one predicted mask against its ground truth.

    Args:
        pred: predicted mask (array-like). It is not modified.
        label: ground-truth mask with the same shape as pred.
        classes: iterable of class values to score.
        ignore_index: label value marking pixels to ignore. The prediction is
            set to 0 at those pixels, so they cannot count as false positives
            for non-zero classes.
        only_present: if True, a class absent from label contributes NaN.

    Returns:
        A list with one IoU per scored class. Classes whose prediction/label
        union is empty are skipped. Returns [1] if nothing was scored.
    '''
    label = np.asarray(label)
    # Work on a copy: assigning into the caller's array would silently
    # corrupt their predictions.
    pred = np.array(pred, copy=True)
    pred[label == ignore_index] = 0
    ious = []
    for c in classes:
        label_c = label == c
        if only_present and np.sum(label_c) == 0:
            ious.append(np.nan)
            continue
        pred_c = pred == c
        intersection = np.logical_and(pred_c, label_c).sum()
        union = np.logical_or(pred_c, label_c).sum()
        if union != 0:
            ious.append(intersection / union)
    return ious if ious else [1]
def compute_iou_batch(outputs, labels, classes=None):
    '''Compute the mean IoU over a batch of predicted and ground-truth masks.

    Args:
        outputs: batch of predicted masks (array-like or torch tensor),
            iterated along the first axis.
        labels: batch of ground-truth masks with the same layout as outputs.
        classes: class values to score. Defaults to [1], the foreground value
            produced by the 0/1 binarisation in this module.

    Returns:
        The NaN-ignoring mean of the per-image (NaN-ignoring) mean IoUs.
    '''
    if classes is None:
        # Iterating None inside compute_ious would raise TypeError.
        classes = [1]
    # Tensors may live on the GPU or require grad; NumPy needs plain CPU data.
    if torch.is_tensor(outputs):
        outputs = outputs.detach().cpu()
    if torch.is_tensor(labels):
        labels = labels.detach().cpu()
    # np.array copies, so the caller's predictions are never mutated downstream.
    preds = np.array(outputs)
    labels = np.array(labels)
    ious = [
        np.nanmean(compute_ious(pred, label, classes))
        for pred, label in zip(preds, labels)
    ]
    return np.nanmean(ious)