-
Notifications
You must be signed in to change notification settings - Fork 4
/
validate.py
104 lines (83 loc) · 3.48 KB
/
validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from collections import defaultdict
from typing import Tuple
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d
from sklearn import metrics as skmetrics
from util import remove_x_axis_duplicates
def compute_meta_auc(result_data: pd.DataFrame,
                     prefix: str = '',
                     max_allowable_complexity: int = 30,
                     max_start_complexity: int = 10) -> pd.DataFrame:
    """Compute the area under each estimator's metric-vs-complexity curve
    ("AUC of AUCs") over the complexity region shared by all eligible curves.

    Parameters
    ----------
    result_data
        One row per (estimator, complexity) point, indexed by estimator name.
        Must contain a 'mean_complexity' column; every column whose name
        contains 'mean' is treated as a metric curve to integrate.
    prefix
        Currently unused — kept for interface compatibility. (A
        '{prefix}_mean_complexity' x-column was apparently intended; see the
        commented-out line below.)
    max_allowable_complexity
        complexity score under which a model is considered interpretable;
        also upper-bounds the integration region
    max_start_complexity
        min complexity of curves included in the AUC-of-AUC comparison must
        be below this value

    Returns
    -------
    pd.DataFrame
        One row per estimator with '<metric>_auc' columns (0 for estimators
        whose curve is ineligible) plus the shared-region bounds
        'mean_complexity_lb' / 'mean_complexity_ub'.
    """
    # x_column = f'{prefix}_mean_complexity'
    x_column = 'mean_complexity'
    compute_columns = result_data.columns[result_data.columns.str.contains('mean')]
    estimators = np.unique(result_data.index)

    # collect each estimator's curve, sorted by complexity
    xs = np.empty(len(estimators), dtype=object)
    ys = xs.copy()
    for i, est in enumerate(estimators):
        est_result_df = result_data[result_data.index.str.fullmatch(est)]
        complexities_unsorted = est_result_df[x_column]
        complexity_sort_indices = complexities_unsorted.argsort()
        # argsort yields positions, not index labels — use .iloc (plain []
        # relied on a deprecated positional fallback, removed in pandas 3.0)
        xs[i] = complexities_unsorted.iloc[complexity_sort_indices].values
        ys[i] = est_result_df.iloc[complexity_sort_indices][compute_columns].values

    # filter out curves which start too complex (xs[i] is sorted, so the
    # first element is the curve's minimum complexity)
    starts_simple_enough = np.array([x[0] < max_start_complexity for x in xs])

    # find overlapping complexity region for roc-of-roc comparison
    meta_auc_lb = max(x[0] for x in xs)
    endpts = np.array([x[-1] for x in xs])
    meta_auc_ub = min(endpts[endpts > meta_auc_lb])
    meta_auc_ub = min(meta_auc_ub, max_allowable_complexity)

    # handle non-overlapping curves: eligible only if the curve starts early
    # enough AND extends past the shared lower bound
    eligible = starts_simple_enough & (endpts > meta_auc_lb)

    # compute AUC of interpolated curves in overlap region
    meta_aucs = defaultdict(list)
    for i in range(len(xs)):
        for c, col in enumerate(compute_columns):
            if eligible[i]:
                x, y = remove_x_axis_duplicates(xs[i], ys[i][:, c])
                f_curve = interp1d(x, y)
                x_interp = np.linspace(meta_auc_lb, meta_auc_ub, 100)
                y_interp = f_curve(x_interp)
                auc_value = np.trapz(y_interp, x=x_interp)
            else:
                auc_value = 0
            meta_aucs[col + '_auc'].append(auc_value)

    meta_auc_df = pd.DataFrame(meta_aucs, index=estimators)
    meta_auc_df[f'{x_column}_lb'] = meta_auc_lb
    meta_auc_df[f'{x_column}_ub'] = meta_auc_ub
    return meta_auc_df
def get_best_accuracy(ytest, yscore):
    """Best achievable accuracy over all score thresholds.

    For each unique value ``thr`` of ``yscore``, binarize predictions as
    ``yscore > thr``, score against ``ytest``, and return the maximum
    accuracy found.
    """
    candidate_thresholds = np.unique(yscore)
    return np.max([skmetrics.accuracy_score(ytest, yscore > thr)
                   for thr in candidate_thresholds])
def make_best_spec_high_sens_scorer(min_sensitivity: float = 0.98):
    """Build a scorer that returns the best specificity subject to a
    sensitivity floor.

    The returned callable sweeps every unique score threshold, binarizes
    predictions as ``yscore > thr``, and reports the highest specificity
    among thresholds whose sensitivity is at least ``min_sensitivity``
    (0 if no threshold qualifies).
    """
    def get_best_spec_high_sens(ytest, yscore):
        best_specificity = 0
        for threshold in np.unique(yscore):
            predictions = yscore > threshold
            tn, fp, fn, tp = skmetrics.confusion_matrix(ytest, predictions).ravel()
            specificity = tn / (tn + fp)
            sensitivity = tp / (tp + fn)
            if sensitivity >= min_sensitivity:
                best_specificity = max(best_specificity, specificity)
        return best_specificity
    return get_best_spec_high_sens