-
Notifications
You must be signed in to change notification settings - Fork 16
/
logger.py
101 lines (85 loc) · 3.85 KB
/
logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
from collections import defaultdict
class Logger(object):
    """Collect (train, valid, test) result triples over several runs.

    Adapted from https://github.com/snap-stanford/ogb/
    """

    def __init__(self, runs, info=None):
        """Create empty result histories for ``runs`` runs.

        Args:
            runs: Number of runs; valid run indices are ``0 .. runs - 1``.
            info: Arbitrary object kept on ``self.info``; this class never
                reads it.
        """
        self.info = info
        self.results = [[] for _ in range(runs)]

    def add_result(self, run, result):
        """Append one result triple to the history of ``run``.

        Args:
            run: Zero-based run index.
            result: Sequence of exactly three values, reported by
                ``print_statistics`` as (train, valid, test).

        Raises:
            AssertionError: If ``result`` does not have three entries or
                ``run`` is out of range.
        """
        assert len(result) == 3
        assert 0 <= run < len(self.results)
        self.results[run].append(result)

    def print_statistics(self, run=None):
        """Print summary statistics for one run or across all runs.

        All values are multiplied by 100 before printing. "Final" values
        are those logged at the entry with the highest valid score.

        Args:
            run: Zero-based run index, or ``None`` to summarise every run.

        Returns:
            ``None`` when ``run`` is given. Otherwise a pair of 1-D tensors
            ``(best_valid, final_test)`` with one entry per run that has at
            least one logged result.
        """
        if run is not None:
            result = 100 * torch.tensor(self.results[run])
            best_epoch = result[:, 1].argmax().item()
            print(f'Run {run + 1:02d}:')
            print(f'Highest Train: {result[:, 0].max():.2f}')
            print(f'Highest Valid: {result[:, 1].max():.2f}')
            print(f' Final Train: {result[best_epoch, 0]:.2f}')
            print(f' Final Test: {result[best_epoch, 2]:.2f}')
            return None

        # Reduce each run on its own instead of stacking all runs into one
        # tensor: runs may have logged different numbers of entries (or none
        # at all), which a single rectangular tensor cannot represent.
        best_results = []
        for run_results in self.results:
            if not run_results:
                continue
            r = 100 * torch.tensor(run_results)
            best_epoch = r[:, 1].argmax()
            best_results.append((
                r[:, 0].max().item(),
                r[:, 1].max().item(),
                r[best_epoch, 0].item(),
                r[best_epoch, 2].item(),
            ))

        # reshape keeps the (n_runs, 4) layout even when no run has results.
        best_result = torch.tensor(best_results).reshape(-1, 4)

        print('All runs:')
        r = best_result[:, 0]
        print(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}')
        r = best_result[:, 1]
        print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}')
        r = best_result[:, 2]
        print(f' Final Train: {r.mean():.2f} ± {r.std():.2f}')
        r = best_result[:, 3]
        print(f' Final Test: {r.mean():.2f} ± {r.std():.2f}')

        return best_result[:, 1], best_result[:, 3]
class SimpleLogger(object):
    """Store result tuples per (run, argument tuple) and summarise them.

    Adapted from https://github.com/CUAI/CorrectAndSmooth
    """

    def __init__(self, desc, param_names, num_values=2):
        """Set up an empty log.

        Args:
            desc: Description printed in the header line of ``display``.
            param_names: Names of the parameters; each logged ``args`` tuple
                must have the same length.
            num_values: Number of values expected in every logged result.
        """
        self.results = defaultdict(dict)
        self.param_names = tuple(param_names)
        self.used_args = list()
        self.desc = desc
        self.num_values = num_values

    def add_result(self, run, args, values):
        """Takes run=int, args=tuple, value=tuple(float)

        Stores ``values`` under ``self.results[run][args]`` (overwriting any
        previous entry) and remembers ``args`` in first-seen order.

        Raises:
            AssertionError: If ``args`` or ``values`` has the wrong length.
        """
        assert(len(args) == len(self.param_names))
        assert(len(values) == self.num_values)
        self.results[run][args] = values
        if args not in self.used_args:
            self.used_args.append(args)

    def get_best(self, top_k=1):
        """Return the ``top_k`` argument tuples ranked by their last value.

        For each argument tuple the values are averaged over all runs that
        logged it; the mean of the last value is the ranking score
        (higher is better).

        Returns:
            List of at most ``top_k`` argument tuples, best first.
        """
        scored = []
        for arg_tuple in self.used_args:
            per_run = [
                run_results[arg_tuple]
                for run_results in self.results.values()
                if arg_tuple in run_results
            ]
            per_run = torch.tensor(per_run) * 100
            scored.append((arg_tuple, per_run.mean(dim=0)[-1].item()))
        ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
        return [arg_tuple for arg_tuple, _ in ranked[:top_k]]

    def prettyprint(self, x):
        """Format floats with two decimals; use ``str`` for anything else."""
        if isinstance(x, float):
            return '%.2f' % x
        return str(x)

    def display(self, args=None):
        """Print mean ± std (times 100) of every value per argument tuple.

        Args:
            args: Iterable of argument tuples to show; defaults to every
                tuple logged so far.

        Returns:
            The scaled results tensor of the last displayed argument tuple,
            or ``None`` when there was nothing to display.
        """
        disp_args = self.used_args if args is None else args
        if len(disp_args) > 1:
            print(f'{self.desc} {self.param_names}, {len(self.results)} runs')

        # Defined up front so an empty ``disp_args`` returns None instead of
        # raising UnboundLocalError at the return statement.
        results = None
        for arg_tuple in disp_args:
            results = [
                run_results[arg_tuple]
                for run_results in self.results.values()
                if arg_tuple in run_results
            ]
            results = torch.tensor(results) * 100
            results_mean = results.mean(dim=0)
            results_std = results.std(dim=0)
            res_str = f'{results_mean[0]:.2f} ± {results_std[0]:.2f}'
            for i in range(1, self.num_values):
                # Pair each mean with its own std (index i, not a fixed 1).
                res_str += f' -> {results_mean[i]:.2f} ± {results_std[i]:.2f}'
            print(f'Args {[self.prettyprint(x) for x in arg_tuple]}: {res_str}')
        if len(disp_args) > 1:
            print()
        return results