-
Notifications
You must be signed in to change notification settings - Fork 0
/
decode_line_multi_lm.py
112 lines (98 loc) · 3.66 KB
/
decode_line_multi_lm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright 2016 Stanford University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import re
import os
import sys
import time
import numpy as np
from os.path import join as pjoin
from multiprocessing import Pool
from util_lm_kenlm import score_sent, initialize
from levenshtein import align_pair, align
# Run configuration. main() overwrites every one of these from sys.argv
# before calling decode().
folder_data = data_dir = out_dir = ''  # data root, dataset subdir (joined onto the root), output subdir
lm_dir = lm_name = ''  # LM directory under <folder_data>/lm/char, and its tag used in output file names
dev = ''  # split name: decode() reads <dev>.x.txt and <dev>.y.txt
start, end = 0, -1  # decode() processes line indices in range(start, end)
def remove_nonascii(text):
    """Replace every non-ASCII character (code point above 0x7F) with a space.

    ASCII characters pass through unchanged, so the result has the same length
    as ``text``.
    """
    return ''.join(ch if ord(ch) < 0x80 else ' ' for ch in text)
def rank_sent(pool, sents):
    """Score candidate sentences with the LM in parallel and return the best.

    Before scoring, non-ASCII characters in each candidate are replaced with
    spaces. The returned sentence is the *original* (unmodified) candidate.

    Args:
        pool: a multiprocessing pool; ``pool.map(score_sent, ...)`` is called
            on ``(index, sentence)`` pairs and must yield ``(index, score)``.
        sents: list of candidate sentence strings.

    Returns:
        ``(best_sent, best_prob, probs)``.
        - ``probs[i]`` is ``10 ** -score_i`` for candidate ``i``.
        - ``best_sent`` / ``best_prob`` belong to the candidate with the largest
          such value, i.e. the smallest score.
        - If ``sents`` is empty or every score is NaN, the result is
          ``('', -1, probs)``.
        NOTE(review): ``score`` is presumably a negative log10 LM probability
        (making ``10 ** -score`` a probability); confirm against util_lm_kenlm.
    """
    new_sents = [remove_nonascii(ele) for ele in sents]
    probs = np.full(len(sents), -1.0)
    results = pool.map(score_sent, zip(np.arange(len(sents)), new_sents))
    best_tid = None
    best_score = None
    for tid, score in results:
        # Float base: np.power(10, -k) with integer k raises ValueError.
        probs[tid] = np.power(10.0, -score)
        if np.isnan(score):
            continue  # a NaN probability never wins a comparison; skip it
        # 10 ** -score is strictly decreasing in score, so the most probable
        # candidate has the smallest score. Comparing raw scores avoids
        # spurious ties when 10 ** -score saturates to 0.0 or inf for long
        # sentences. Strict '<' keeps the first candidate on a genuine tie.
        if best_tid is None or score < best_score:
            best_tid, best_score = tid, score
    if best_tid is None:
        return '', -1, probs
    return sents[best_tid], probs[best_tid], probs
def _rerank_line(pool, line, cur_truth, lowercase):
    """Rerank one input line; return the (ec-file line, output-file line) pair."""
    # Only the first 20 tab-separated fields are considered; blank ones are dropped.
    sents = line.strip('\n').split('\t')[:20]
    sents = [ele.strip() for ele in sents if ele.strip()]
    if not sents:
        # No usable candidate: report the full reference length as the distance.
        return str(len(cur_truth)) + '\t' + str(len(cur_truth)) + '\n', '\n'
    if lowercase:
        sents = [ele.lower() for ele in sents]
    best_sent, _, _ = rank_sent(pool, sents)
    best_dis = align(cur_truth, best_sent.lower())
    return str(best_dis) + '\t' + str(len(cur_truth)) + '\n', best_sent + '\n'


def decode():
    """Rerank candidate sentences for lines [start, end) of the ``dev`` split.

    Inputs, all under ``<folder_data>/<data_dir>``:
      - ``<dev>.x.txt``: one line per example, holding tab-separated candidate
        sentences.
      - ``<dev>.y.txt``: the matching reference line. References are stripped
        and lowercased before use.

    For each line, the candidates are ranked with the LM in
    ``<folder_data>/lm/char/<lm_dir>``. If ``lm_dir`` contains ``'low'``, the
    candidates are lowercased first. Two files are written under
    ``<folder_data>/<data_dir>/<out_dir>``:
      - ``<dev>.<lm_name>.o.txt.<start>_<end>``: the chosen candidate per line.
      - ``<dev>.<lm_name>.ec.txt.<start>_<end>``: ``<align result>\\t<reference
        length>`` per line. The align result is presumably an edit distance.

    Elapsed time is printed every 100 line ids. An ``end`` past the last
    input line is clamped to the file length.

    Side effect: rebinds the global ``data_dir`` to ``pjoin(folder_data, data_dir)``.
    """
    global folder_data, data_dir, out_dir, lm_dir, dev, start, end, lm_name
    data_dir = pjoin(folder_data, data_dir)
    folder_out = pjoin(data_dir, out_dir)
    if not os.path.exists(folder_out):
        os.makedirs(folder_out)
    tic = time.time()
    with open(pjoin(data_dir, dev + '.x.txt'), 'r') as f_:
        lines = f_.readlines()
    with open(pjoin(data_dir, dev + '.y.txt'), 'r') as f_:
        truths = [ele.strip().lower() for ele in f_]
    prefix = dev + '.' + str(lm_name) + '.'
    suffix = '.' + str(start) + '_' + str(end)
    ec_path = pjoin(folder_out, prefix + 'ec.txt' + suffix)
    out_path = pjoin(folder_out, prefix + 'o.txt' + suffix)
    lowercase = 'low' in lm_dir
    # Load the LM exactly once, in this process, *before* the pool is created.
    # Under the 'fork' start method the workers inherit it instead of each
    # loading a private copy. This presumes initialize() keeps the model in
    # util_lm_kenlm module state — confirm. Under 'spawn' the workers would
    # need Pool(..., initializer=initialize, initargs=(lm_path,)) instead.
    initialize(pjoin(folder_data, 'lm/char', lm_dir))
    pool = Pool(100)
    try:
        with open(ec_path, 'w') as f_o, open(out_path, 'w') as f_b:
            # Clamp so a shard whose `end` runs past EOF finishes cleanly.
            for line_id in range(start, min(end, len(lines))):
                ec_line, best_line = _rerank_line(
                    pool, lines[line_id], truths[line_id], lowercase)
                f_o.write(ec_line)
                f_b.write(best_line)
                if line_id % 100 == 0:
                    toc = time.time()
                    print(toc - tic)
                    tic = time.time()
    finally:
        # All map() calls are synchronous, so nothing is pending here.
        pool.terminate()
        pool.join()
def main():
    """Fill the module-level run configuration from sys.argv, then run decode().

    Expected arguments, in order:
        FOLDER_DATA DATA_DIR OUT_DIR LM_DIR LM_NAME DEV START END
    START and END must be integers. decode() processes line ids in
    range(START, END).
    """
    global folder_data, data_dir, out_dir, lm_dir, dev, start, end, lm_name
    if len(sys.argv) < 9:
        prog = sys.argv[0] if sys.argv else 'decode_line_multi_lm.py'
        # Exit with a usage message instead of a bare IndexError traceback.
        sys.exit('usage: python %s FOLDER_DATA DATA_DIR OUT_DIR LM_DIR '
                 'LM_NAME DEV START END' % prog)
    folder_data = sys.argv[1]
    data_dir = sys.argv[2]
    out_dir = sys.argv[3]
    lm_dir = sys.argv[4]
    lm_name = sys.argv[5]
    dev = sys.argv[6]
    start = int(sys.argv[7])
    end = int(sys.argv[8])
    decode()


if __name__ == "__main__":
    main()