forked from magic282/cnndm_acl18
-
Notifications
You must be signed in to change notification settings - Fork 0
/
find_oracle_para.py
114 lines (98 loc) · 3.51 KB
/
find_oracle_para.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import sys
import itertools
import gc
import math
import datetime
import multiprocessing
from PyRouge.Rouge.Rouge import Rouge
from Document import Document
# Largest combination size (number of sentences) that solve_one will try
# before giving up with 'Exceed MAX_COMB_L'.
MAX_COMB_L = 5
# Upper bound on how many combinations of a single size solve_one will
# enumerate before giving up with 'Exceed MAX_COMB_NUM'.
MAX_COMB_NUM = 100000
def c_n_x(n, x):
    """Return the binomial coefficient C(n, x) as an exact integer.

    Args:
        n: Size of the set being chosen from.
        x: Number of elements to choose.

    Returns:
        The number of ways to choose ``x`` items out of ``n``. Returns 0 when
        ``x`` is negative or larger than ``n`` (no such selection exists).
    """
    if x < 0 or x > n:
        return 0
    # C(n, x) == C(n, n - x); iterate over the smaller of the two so the
    # loop below is as short as possible.
    x = min(x, n - x)
    res = 1
    for i in range(1, x + 1):
        # After this step res == C(n - x + i, i), so the floor division is
        # always exact and intermediate values stay small.
        res = res * (n - x + i) // i
    return res
def solve_one(document, rouge):
    """Find the sentence combination with the best ROUGE-2 F1 for a document.

    Sentences whose ROUGE-2 recall against the summary is zero are dropped.
    The remaining candidate indices are combined exhaustively, one
    combination size at a time, and each combination is scored by ROUGE-2 F1.
    The search stops when a size fails to improve on the best score found so
    far, or when the MAX_COMB_L / MAX_COMB_NUM limits are exceeded.

    Args:
        document: Object exposing ``doc_len``, ``summary_len``, ``doc_sents``
            and ``summary_sents``. ``None`` (the placeholder load_data emits
            for empty lines) is treated like an empty document.
        rouge: Scorer exposing ``compute_rouge(references, hypotheses)`` whose
            result is read as ``result['rouge-2']['r'][0]`` (recall) and
            ``result['rouge-2']['f'][0]`` (F1).

    Returns:
        A ``(combination, score)`` pair. ``combination`` is a tuple of
        sentence indices into ``document.doc_sents``, or ``None`` when no
        combination scored above zero; ``score`` is the matching ROUGE-2 F1
        (0 when nothing was found).
    """
    if document is None or document.doc_len == 0 or document.summary_len == 0:
        return None, 0

    # Keep only sentences that share at least one bigram with the summary.
    candidates = []
    for idx, sent in enumerate(document.doc_sents):
        scores = rouge.compute_rouge([document.summary_sents], [sent])
        if scores['rouge-2']['r'][0] > 0:
            candidates.append(idx)

    all_best_l = 1
    all_best_score = 0
    all_best_comb = None
    # The upper bound is inclusive so that the combination made of *all*
    # candidates is tried too; otherwise a document with a single candidate
    # sentence would never be scored at all.
    for size in range(1, len(candidates) + 1):
        if size > MAX_COMB_L:
            print('Exceed MAX_COMB_L')
            break
        # c_n_x returns an exact int, so a plain comparison is sufficient.
        if c_n_x(len(candidates), size) > MAX_COMB_NUM:
            print('Exceed MAX_COMB_NUM')
            break

        l_best_score = 0
        l_best_choice = None
        for comb in itertools.combinations(candidates, size):
            c_string = [document.doc_sents[idx] for idx in comb]
            rouge_scores = rouge.compute_rouge([document.summary_sents], [c_string])
            rouge_bigram_f1 = rouge_scores['rouge-2']['f'][0]
            if rouge_bigram_f1 > l_best_score:
                l_best_score = rouge_bigram_f1
                l_best_choice = comb

        if l_best_score > all_best_score:
            all_best_l = size
            all_best_score = l_best_score
            all_best_comb = l_best_choice
        elif size > all_best_l:
            # A larger size did not help; stop growing the combination.
            break
    return all_best_comb, all_best_score
def solve(documents, output_file):
    """Run solve_one sequentially over documents and write one line per result.

    Each output line has the form ``'<combination>\\t <score>'``.

    Args:
        documents: Iterable of document objects accepted by solve_one;
            ``None`` entries are written as ``None\\t 0``.
        output_file: Path of the UTF-8 text file to (over)write.
    """
    # Built the same way worker() builds it; the original referenced an
    # undefined name and omitted the scorer argument to solve_one.
    rouge = Rouge(use_ngram_buf=True)
    with open(output_file, 'w', encoding='utf-8', buffering=1) as writer:
        for idx, doc in enumerate(documents):
            if idx % 50 == 0:
                # Periodic progress timestamp; also drop the scorer's n-gram
                # buffer (presumably a cache -- confirm against PyRouge) and
                # collect garbage to keep memory bounded.
                print(datetime.datetime.now())
                rouge.ngram_buf = {}
                gc.collect()
            comb = (None, 0) if doc is None else solve_one(doc, rouge)
            writer.write('{0}\t {1}'.format(comb[0], comb[1]) + '\n')
def load_data(src_file, tgt_file):
    """Read parallel source/target files into a list of Document objects.

    The two files are read line by line in lockstep (stopping at the shorter
    one). Each line is stripped and split on the ``##SENT##`` marker. A pair
    in which either side is blank after stripping yields a ``None``
    placeholder so that list positions still line up with file lines.

    Args:
        src_file: Path to the UTF-8 file with one source document per line.
        tgt_file: Path to the UTF-8 file with one summary per line.

    Returns:
        A list with one entry per line pair: a ``Document`` or ``None``.
    """
    separator = '##SENT##'
    loaded = []
    with open(src_file, 'r', encoding='utf-8') as src_reader:
        with open(tgt_file, 'r', encoding='utf-8') as tgt_reader:
            for raw_src, raw_tgt in zip(src_reader, tgt_reader):
                article = raw_src.strip()
                summary = raw_tgt.strip()
                if article and summary:
                    entry = Document(article.split(separator),
                                     summary.split(separator))
                else:
                    entry = None
                loaded.append(entry)
    return loaded
def worker(doc):
    """Solve a single document inside a pool process.

    Args:
        doc: A document object accepted by solve_one, or ``None`` for the
            placeholder load_data emits when a line is empty.

    Returns:
        The ``(combination, score)`` pair from solve_one, or ``(None, 0)``
        for a ``None`` placeholder.
    """
    if doc is None:
        # Nothing to score; avoid dereferencing the placeholder.
        return None, 0
    # A fresh scorer per document, as in the original code.
    rouge = Rouge(use_ngram_buf=True)
    return solve_one(doc, rouge)
def main(src_file, tgt_file, outfile_name):
    """Compute results for every document in parallel and write them out.

    Args:
        src_file: Path to the source-document file (one document per line).
        tgt_file: Path to the summary file (one summary per line).
        outfile_name: Path of the UTF-8 output file; one
            ``'<combination>\\t <score>'`` line is written per document, in
            input order.
    """
    docs = load_data(src_file, tgt_file)
    # Use roughly 90% of the available cores, but always at least one.
    n_procs = max(1, int(multiprocessing.cpu_count() * 0.9))
    # The context manager guarantees the output file is closed even when a
    # worker raises.
    with open(outfile_name, 'w', encoding='utf-8', buffering=1) as writer:
        with multiprocessing.Pool(n_procs) as pool:
            all_results = pool.map(worker, docs)
        for comb in all_results:
            writer.write('{0}\t {1}'.format(comb[0], comb[1]) + '\n')
if __name__ == "__main__":
    # Three positional arguments are required: source file, target file and
    # output file. Extra arguments are ignored, as before.
    if len(sys.argv) < 4:
        sys.exit('usage: {0} SRC_FILE TGT_FILE OUT_FILE'.format(sys.argv[0]))
    main(sys.argv[1], sys.argv[2], sys.argv[3])