# score_lm.py
import os
import collections
import codecs

import numpy as np
from pandas import read_csv
from pathlib import Path

# Directory holding ptb.{train,valid,test}.txt, next to this script.
PTB_PATH = Path(__file__).with_name("PTB")


def _read_words(filename):
    """Read a PTB-style text file, replacing each newline with an "<eos>" token."""
    with open(filename, "r") as f:
        return f.read().replace("\n", "<eos>").split()


def _build_vocab(filename):
    """Build word<->id mappings, ordering words by frequency, then alphabetically."""
    data = _read_words(filename)
    counter = collections.Counter(data)
    # Sort by descending count, breaking ties alphabetically.
    count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
    words, _ = list(zip(*count_pairs))
    word_to_id = dict(zip(words, range(len(words))))
    id_to_word = {v: k for k, v in word_to_id.items()}
    return word_to_id, id_to_word
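
# Illustration (not part of the original script): for a file containing
# "a a b \n", _build_vocab returns word_to_id == {"a": 0, "<eos>": 1, "b": 2},
# since "a" has the highest count and "<eos>" wins the alphabetical tie-break
# against "b" ("<" sorts before lowercase letters).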


def _file_to_word_ids(filename, word_to_id):
    data = _read_words(filename)
    # Words absent from the training vocabulary are silently dropped.
    return [word_to_id[word] for word in data if word in word_to_id]


def load_dataset(data_path=PTB_PATH):
    """Load the PTB splits as lists of word ids (vocabulary built from train)."""
    # Default to the bundled PTB directory; the former default of None would
    # crash os.path.join below.
    train_path = os.path.join(data_path, "ptb.train.txt")
    dev_path = os.path.join(data_path, "ptb.valid.txt")
    test_path = os.path.join(data_path, "ptb.test.txt")
    word_to_id, id_to_word = _build_vocab(train_path)
    train_data = _file_to_word_ids(train_path, word_to_id)
    dev_data = _file_to_word_ids(dev_path, word_to_id)
    test_data = _file_to_word_ids(test_path, word_to_id)
    return train_data, dev_data, test_data, word_to_id, id_to_word
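
# Hypothetical usage sketch (assumes the PTB files sit in the "PTB" directory
# next to this script, as PTB_PATH implies):
#
#   train_ids, dev_ids, test_ids, word_to_id, id_to_word = load_dataset()
#   first_token = id_to_word[train_ids[0]]  # back to a string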


def save_preds(preds, preds_fname):
    """
    Save classifier predictions in a format appropriate for scoring.

    Each element of `preds` becomes one tab-separated row; `load_preds`
    expects the first row to be the column header.
    """
    with codecs.open(preds_fname, 'w') as outp:
        for vals in preds:
            print(*vals, sep='\t', file=outp)
    print('Predictions saved to %s' % preds_fname)


def load_preds(preds_fname, compressed=False):
    """
    Load classifier predictions in the format produced by `save_preds`.
    """
    c = 'gzip' if compressed else 'infer'

    def read_col(name):
        # Reading one column at a time keeps peak memory low (for a
        # low-resource server).
        return read_csv(preds_fname, sep='\t', compression=c, usecols=[name])[name]

    prevs = list(read_col("prev"))
    true_probs = np.float32(read_col("true_prob"))
    true_ranks = np.int32(read_col("true_rank"))
    kl_uniform = np.float32(read_col("kl_uniform"))
    kl_unigram = np.float32(read_col("kl_unigram"))
    return prevs, true_probs, true_ranks, kl_uniform, kl_unigram
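
# The expected on-disk layout is a tab-separated file whose header names the
# five columns read above; the values here are purely illustrative:
#
#   prev    true_prob   true_rank   kl_uniform  kl_unigram
#   aer     0.000120    5341        1.23        0.98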


def compute_perplexity(probs):
    # Perplexity is the exponential of the mean negative log-probability.
    return np.exp(-np.log(probs).sum() / len(probs))
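
# Worked example (illustrative): probs = [0.25, 0.25] gives
# exp(-(ln 0.25 + ln 0.25) / 2) = exp(ln 4) = 4.0, i.e. the model is as
# uncertain as a uniform choice among four words.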


def compute_hit_k(ranks, k=10):
    # Fraction of tokens whose true word the model ranked in its top k.
    return float(np.mean(ranks < k))
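
# Illustrative example, assuming zero-based ranks as the `ranks < k` test
# implies: ranks = np.array([0, 3, 12, 100]) gives compute_hit_k(ranks) == 0.5,
# since two of the four true words fall in the top 10.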


def compute_average_rank(ranks):
    return np.mean(ranks)


def compute_average_kl(kl_divergence):
    return np.mean(kl_divergence)


def score_preds(preds_path, ptb_path=PTB_PATH, compressed=False):
    data = load_preds(preds_path, compressed=compressed)
    received_text, probs, ranks, kl_uniform, kl_unigram = data
    with open(os.path.join(ptb_path, "ptb.train.txt"), "r") as f:
        train_text = f.read().strip().replace("\n", "<eos>")
    with open(os.path.join(ptb_path, "ptb.valid.txt"), "r") as f:
        dev_text = f.read().strip().replace("\n", "<eos>")
    with open(os.path.join(ptb_path, "ptb.test.txt"), "r") as f:
        test_text = f.read().strip().replace("\n", "<eos>")
    # Each entry: (split name, reference text, token count of the split).
    ptb_dataset = [
        ('train', train_text, train_text.count(" ") + 1),
        ('valid', dev_text, dev_text.count(" ") + 1),
        ('test', test_text, test_text.count(" ") + 1),
    ]
    scores = dict()
    for name, text, len_text in ptb_dataset:
        # Check that the received tokens reproduce this PTB split exactly.
        if ' '.join(received_text[:len_text]) != text:
            raise ValueError(f'Received text does not match the PTB {name} text')
        # Compute the metrics over this split's slice of the predictions.
        perplexity = compute_perplexity(probs[:len_text])
        hit_k = compute_hit_k(ranks[:len_text])
        avg_rank = compute_average_rank(ranks[:len_text])
        avg_kl_uniform = compute_average_kl(kl_uniform[:len_text])
        avg_kl_unigram = compute_average_kl(kl_unigram[:len_text])
        scores[name] = {
            'perplexity': perplexity,
            'hit@10': hit_k,
            'avg_rank': avg_rank,
            'avg_kl_uniform': avg_kl_uniform,
            'avg_kl_unigram': avg_kl_unigram,
        }
        # Consume this split's rows so the next split starts at index 0.
        probs = probs[len_text:]
        ranks = ranks[len_text:]
        kl_uniform = kl_uniform[len_text:]
        kl_unigram = kl_unigram[len_text:]
        received_text = received_text[len_text:]
    return scores
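

# Minimal usage sketch (the predictions filename is hypothetical):
if __name__ == "__main__":
    scores = score_preds("preds.tsv")
    for split, metrics in scores.items():
        print(split, metrics)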