-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluator.py
117 lines (97 loc) · 4.41 KB
/
evaluator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_score, recall_score
from matplotlib import pyplot as plt
from torchvision.utils import save_image
from pathlib import Path
import pandas as pd
import seaborn as sn
import os
from logger import get_logger
# Module-level logger obtained from the project's `logger` helper module.
# NOTE(review): nothing in SatEvaluator uses it (plot_training_info calls
# print()); consider routing messages through it or removing it.
logger = get_logger("Evaluator logger")
class SatEvaluator():
    """Accumulate binary-classification results and turn them into metrics/plots.

    Typical use: call ``record_preds_gt`` and the loss recorders once per
    batch, then query the ``evaluate_*`` methods or write plots with
    ``plot_training_info`` / ``plot_confmat``. ``save_FP_FN`` dumps
    misclassified images for inspection.

    Labels are assumed to be the integers 0 and 1; ``pos_label`` picks which
    of them is the positive class for F1/precision/recall, the confusion-matrix
    plot and the FP/FN split.
    """

    def __init__(self, device='cuda:0', pos_label=0, save_dir="results"):
        """Set up empty accumulators and create the FP/FN output directories.

        Args:
            device: torch device on which the accumulated tensors are kept.
            pos_label: label value of the positive class (0 or 1).
            save_dir: root directory for plots and FP/FN images.
        """
        self.device = device
        self.pos_label = pos_label
        self.save_dir = save_dir
        # Accumulators are (N, 1) column tensors that grow by one batch per
        # record_preds_gt() call.
        self.total_preds = torch.empty(size=(0, 1), device=device)
        self.total_gt = torch.empty(size=(0, 1), device=device)
        self.train_losses = []
        self.test_losses = []
        # FP-FN analysis: the counters number the saved images. reset() does
        # not clear them, so later calls keep numbering upward instead of
        # overwriting earlier files.
        self.FP_counter = 0
        self.FN_counter = 0
        self.FP_save_dir = os.path.join(save_dir, "fp-fn-analysis", "FP")
        self.FN_save_dir = os.path.join(save_dir, "fp-fn-analysis", "FN")
        # parents=True also creates save_dir itself, which the plot methods rely on.
        Path(self.FP_save_dir).mkdir(parents=True, exist_ok=True)
        Path(self.FN_save_dir).mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------ helpers

    def _neg_label(self):
        """Return the negative-class label, assuming binary {0, 1} labels."""
        return 1 - self.pos_label

    @staticmethod
    def _as_column(values, device):
        """Return ``values`` detached, moved to ``device`` and shaped (N, 1)."""
        # Detach so accumulated predictions never keep an autograd graph alive.
        values = values.detach().to(device)
        if values.dim() == 1:
            values = values.unsqueeze(1)
        return values

    @staticmethod
    def _as_numpy(values):
        """Flatten an accumulated tensor into a 1-D CPU NumPy array for sklearn."""
        return values.detach().cpu().numpy().ravel()

    # ---------------------------------------------------------------- recording

    def record_preds_gt(self, preds, gt):
        """Append one batch of predicted and ground-truth labels.

        Args:
            preds: predicted labels as a tensor of shape (B,) or (B, 1).
            gt: ground-truth labels as a tensor of shape (B,) or (B, 1).

        Batches are stored in the order they are recorded. Inputs are moved
        to the accumulators' device before concatenation.
        """
        self.total_preds = torch.cat(
            (self.total_preds, self._as_column(preds, self.total_preds.device)))
        self.total_gt = torch.cat(
            (self.total_gt, self._as_column(gt, self.total_gt.device)))

    def record_train_loss(self, train_loss):
        """Append a training loss value (tensor scalars are converted via .item())."""
        if torch.is_tensor(train_loss):
            train_loss = train_loss.item()
        self.train_losses.append(train_loss)

    def record_test_loss(self, test_loss):
        """Append a validation loss value (tensor scalars are converted via .item())."""
        if torch.is_tensor(test_loss):
            test_loss = test_loss.item()
        self.test_losses.append(test_loss)

    # ------------------------------------------------------------------ metrics

    def evaluate_accuracy(self):
        """Return the accuracy over all recorded predictions."""
        return accuracy_score(self._as_numpy(self.total_gt), self._as_numpy(self.total_preds))

    def evaluate_f1(self):
        """Return the F1 score of the ``pos_label`` class over all recorded predictions."""
        return f1_score(self._as_numpy(self.total_gt), self._as_numpy(self.total_preds),
                        pos_label=self.pos_label)

    def evaluate_confmat(self, labels=None):
        """Return the row-normalized confusion matrix (rows = true class).

        Args:
            labels: optional row/column label order forwarded to sklearn.
                ``None`` keeps sklearn's default (sorted labels present in
                the data).
        """
        return confusion_matrix(self._as_numpy(self.total_gt), self._as_numpy(self.total_preds),
                                labels=labels, normalize='true')

    def evaluate_precision(self):
        """Return the precision of the ``pos_label`` class."""
        return precision_score(self._as_numpy(self.total_gt), self._as_numpy(self.total_preds),
                               pos_label=self.pos_label)

    def evaluate_recall(self):
        """Return the recall of the ``pos_label`` class."""
        return recall_score(self._as_numpy(self.total_gt), self._as_numpy(self.total_preds),
                            pos_label=self.pos_label)

    # -------------------------------------------------------------------- plots

    @staticmethod
    def _plot_loss_curve(position, losses, title):
        """Draw one loss curve into subplot ``position`` of the current figure."""
        plt.subplot(position)
        plt.plot(losses, 'b')
        plt.xlabel("Iteration number")
        plt.ylabel("Cross Entropy Loss")
        plt.title(title)
        if losses:
            # Log scale is applied only once at least one value was recorded.
            plt.yscale('log')
        plt.grid(True)

    def plot_training_info(self):
        """Save training and validation loss curves to ``save_dir/trainval_curves.jpg``."""
        plt.figure(figsize=(6.4, 7.5))
        self._plot_loss_curve(211, self.train_losses, "Training loss")
        self._plot_loss_curve(212, self.test_losses, "Validation loss")
        plt.tight_layout()
        save_path = os.path.join(self.save_dir, "trainval_curves.jpg")
        print(f"Saving the graphs to {save_path}")
        plt.savefig(save_path)
        plt.close('all')

    def plot_confmat(self):
        """Save a heatmap of the normalized confusion matrix to ``save_dir/confmat.jpg``."""
        # Explicit order so the first row/column really is the positive class
        # for either pos_label, and the matrix stays 2x2 even if a class is
        # missing from the recorded data.
        labels = [self.pos_label, self._neg_label()]
        confmat_df = pd.DataFrame(self.evaluate_confmat(labels=labels),
                                  index=["Actual positive", "Actual negative"],
                                  columns=["Predicted positive", "Predicted negative"])
        plt.figure()
        sn.heatmap(confmat_df, annot=True, cmap="Blues")
        plt.savefig(os.path.join(self.save_dir, "confmat.jpg"))
        plt.close('all')

    # -------------------------------------------------------------------- state

    def reset(self):
        """Clear accumulated predictions, labels and losses (FP/FN counters are kept)."""
        self.total_preds = torch.empty(size=(0, 1), device=self.device)
        self.total_gt = torch.empty(size=(0, 1), device=self.device)
        self.train_losses = []
        self.test_losses = []

    def save_FP_FN(self, preds, labels_batch, images_batch):
        """Save the misclassified images of one batch for error analysis.

        A false positive is predicted ``pos_label`` with a different true
        label; a false negative is the reverse. Images are written as
        ``image_<counter>.jpg`` into ``FP_save_dir`` / ``FN_save_dir``.

        Args:
            preds: per-sample predicted labels.
            labels_batch: per-sample ground-truth labels.
            images_batch: per-sample image tensors accepted by ``save_image``.
        """
        pos = self.pos_label
        for pred, label, img in zip(preds, labels_batch, images_batch):
            if pred == pos and label != pos:
                save_image(img, os.path.join(self.FP_save_dir, f"image_{self.FP_counter}.jpg"))
                self.FP_counter += 1
            elif pred != pos and label == pos:
                save_image(img, os.path.join(self.FN_save_dir, f"image_{self.FN_counter}.jpg"))
                self.FN_counter += 1