-
Notifications
You must be signed in to change notification settings - Fork 7
/
tool.py
61 lines (60 loc) · 2.42 KB
/
tool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.nn as nn
import torchmetrics as tm
import torch.nn.functional as F
class METRICS:
    """Bundle of binary-classification metrics evaluated with torchmetrics.

    Every metric is built with ``task='binary'`` and moved to ``device``;
    inputs are moved to the same device before evaluation. Each metric is
    reset right after it is evaluated, so every result describes only the
    batch passed in and no predictions are retained between calls.

    Threshold convention: the thresholded metrics (F1, recall, precision,
    MCC, balanced accuracy) treat ``pred > threshold`` as positive, which is
    how the torchmetrics binary metrics apply ``threshold``.
    """

    def __init__(self, device='cpu'):
        self.device = device
        self.auroc = tm.AUROC(task='binary').to(device)
        self.auprc = tm.AveragePrecision(task='binary').to(device)
        self.roc = tm.ROC(task='binary').to(device)
        self.prc = tm.PrecisionRecallCurve(task='binary').to(device)
        self.rec = tm.Recall(task='binary').to(device)
        self.prec = tm.Precision(task='binary').to(device)
        self.f1 = tm.F1Score(task='binary').to(device)
        self.mcc = tm.MatthewsCorrCoef(task='binary').to(device)
        self.stat = tm.StatScores(task='binary').to(device)
        # Balanced accuracy from the [tp, fp, tn, fn, support] vector that
        # StatScores returns, using whatever threshold is set on self.stat.
        self.bacc = lambda x, y: self._balanced_accuracy(
            *self._evaluate(self.stat, x, y)
        )

    def to(self, pred, y):
        """Return ``(pred, y)`` moved to ``self.device``."""
        return pred.to(self.device), y.to(self.device)

    @staticmethod
    def _evaluate(metric, pred, y):
        """Return ``metric``'s value for this batch only, then clear its state.

        Calling a torchmetrics metric returns the batch value but also merges
        the batch into the metric's running state. For the curve metrics that
        state is the list of all predictions seen, so without a reset it would
        grow on every call.
        """
        try:
            return metric(pred, y)
        finally:
            metric.reset()

    @staticmethod
    def _balanced_accuracy(tp, fp, tn, fn, *_):
        """Mean of recall ``tp/(tp+fn)`` and specificity ``tn/(tn+fp)``.

        Yields NaN when the batch has no positives or no negatives.
        """
        return (tp / (tp + fn) + tn / (tn + fp)) / 2

    @staticmethod
    def _best_f1_threshold(prec, rec, thresholds):
        """Return the threshold that maximises F1 along a precision-recall curve.

        ``prec``/``rec`` carry one more entry than ``thresholds``; that final
        entry has no threshold and is dropped before the argmax. Curve points
        count predictions ``>= t`` as positive, whereas the thresholded
        metrics use ``pred > threshold``. The best ``t`` is therefore nudged
        down to the next representable float, so that ``pred > result``
        selects exactly the samples with ``pred >= t``.
        """
        f1 = (2 * prec * rec / (prec + rec)).nan_to_num(0)[:-1]
        best = thresholds[torch.argmax(f1)]
        return torch.nextafter(best, torch.full_like(best, float('-inf')))

    def calc_thresh(self, pred, y):
        """Return the F1-optimal decision threshold for ``(pred, y)``.

        The result follows the ``pred > threshold`` convention, so it can be
        passed directly as ``threshold`` to ``__call__``.
        """
        pred, y = self.to(pred, y)
        prec, rec, thresholds = self._evaluate(self.prc, pred, y)
        return self._best_f1_threshold(prec, rec, thresholds)

    def calc_prc(self, pred, y):
        """Return AUROC/AUPRC as floats plus the raw PR and ROC curves.

        ``'prc'`` is ``[recall, precision, thresholds]``. The last
        recall/precision entry, which has no matching threshold, is dropped
        so all three align. ``'roc'`` is ``[fpr, tpr, thresholds]``.
        """
        pred, y = self.to(pred, y)
        auroc = self._evaluate(self.auroc, pred, y)
        prec, rec, th1 = self._evaluate(self.prc, pred, y)
        auprc = self._evaluate(self.auprc, pred, y)
        fpr, tpr, th2 = self._evaluate(self.roc, pred, y)
        return {
            'AUROC': auroc.cpu().item(), 'AUPRC': auprc.cpu().item(),
            'prc': [rec[:-1], prec[:-1], th1], 'roc': [fpr, tpr, th2],
        }

    def __call__(self, pred, y, threshold=None):
        """Evaluate every metric on ``(pred, y)`` and return a dict of floats.

        Args:
            pred: predicted scores.
            y: binary targets.
            threshold: decision threshold (positives are ``pred > threshold``).
                When ``None``, the F1-optimal threshold for this batch is used
                (see ``calc_thresh``).

        Returns:
            Dict with AUROC, AUPRC, RECALL, PRECISION, F1, MCC, BACC and the
            threshold that was applied.
        """
        pred, y = self.to(pred, y)
        auroc = self._evaluate(self.auroc, pred, y)
        prec, rec, thresholds = self._evaluate(self.prc, pred, y)
        auprc = self._evaluate(self.auprc, pred, y)
        if threshold is None:
            threshold = self._best_f1_threshold(prec, rec, thresholds)
        # as_tensor avoids the copy-construct warning torch.tensor() emits for
        # tensor input and keeps an existing tensor's device.
        threshold = torch.as_tensor(threshold)
        # NOTE(review): relies on the torchmetrics binary metrics reading
        # their `threshold` attribute at update time; confirm on upgrades.
        for metric in (self.f1, self.rec, self.mcc, self.stat, self.prec):
            metric.threshold = threshold
        f1 = self._evaluate(self.f1, pred, y)
        rec = self._evaluate(self.rec, pred, y)
        mcc = self._evaluate(self.mcc, pred, y)
        bacc = self.bacc(pred, y)
        prec = self._evaluate(self.prec, pred, y)
        return {
            'AUROC': auroc.cpu().item(), 'AUPRC': auprc.cpu().item(),
            'RECALL': rec.cpu().item(), 'PRECISION': prec.cpu().item(),
            'F1': f1.cpu().item(), 'MCC': mcc.cpu().item(),
            'BACC': bacc.cpu().item(), 'threshold': threshold.cpu().item(),
        }