# data_utils.py
import itertools
import re
import string
from collections import Counter

import nltk
import torch
from nltk.corpus import stopwords
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# The NLTK stopword list must be available; fetch it on first use if missing.
try:
    stop_words = set(stopwords.words('english'))
except LookupError:
    nltk.download('stopwords')
    stop_words = set(stopwords.words('english'))


def get_word_vector(vocab, emb='glove'):
    """Build a (len(vocab), 300) embedding matrix whose row i holds the
    pretrained vector for the word with id i in `vocab`."""
    fname = 'glove.6B.300d.txt'
    with open(fname, 'rt', encoding='utf8') as fi:
        full_content = fi.read().strip().split('\n')
    data = {}
    for line in tqdm(full_content, desc='loading glove vocabs...'):
        parts = line.split(' ')
        if parts[0] not in vocab:
            continue
        data[parts[0]] = [float(val) for val in parts[1:]]
    # Random rows are the fallback for words without a pretrained vector.
    # Rows are indexed by vocabulary id, so the matrix lines up with the ids
    # produced by create_vocab and can seed an embedding layer directly.
    w = torch.rand(len(vocab), 300)
    found = 0
    for word, idx in vocab.items():
        if word in data:
            w[idx] = torch.tensor(data[word])
            found += 1
    print('found', found, 'words in', emb)
    return w
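
# Usage sketch (illustrative, not from the original module): because rows are
# ordered by id, the matrix can seed a PyTorch embedding layer, assuming a
# `vocab` built by create_vocab and 'glove.6B.300d.txt' in the working dir:
#   weights = get_word_vector(vocab)                    # (len(vocab), 300)
#   emb_layer = torch.nn.Embedding.from_pretrained(weights, freeze=False)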


def data_preprocessing(text, remove_stopword=True):
    """Lowercase, strip HTML tags and punctuation, optionally drop stopwords,
    and prepend a <cls> token."""
    text = text.lower()
    text = re.sub('<.*?>', '', text)  # drop HTML tags
    text = ''.join([c for c in text if c not in string.punctuation])
    words = text.split()
    if remove_stopword:
        words = [word for word in words if word not in stop_words]
    return '<cls> ' + ' '.join(words)
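
# Example (sketch): data_preprocessing('This is a <b>good</b> movie!')
# returns '<cls> good movie' -- lowercased, tag and punctuation stripped,
# stopwords removed.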


def create_vocab(corpus, vocab_size=30000):
    """Map the most frequent words to integer ids, reserving ids 0/1/2 for
    the <pad>/<unk>/<cls> special tokens."""
    specials = ('<pad>', '<unk>', '<cls>')
    words = itertools.chain.from_iterable(t.split() for t in corpus)
    # Exclude the special tokens from the ranking so the ids stay contiguous
    # (data_preprocessing puts <cls> in every text).
    count_words = Counter(w for w in words if w not in specials)
    print('total count words', len(count_words))
    sorted_words = count_words.most_common()
    # Keep three id slots free for the special tokens.
    v = min(vocab_size - 3, len(sorted_words))
    vocab_to_int = {w: i + 3 for i, (w, c) in enumerate(sorted_words[:v])}
    vocab_to_int['<pad>'] = 0
    vocab_to_int['<unk>'] = 1
    vocab_to_int['<cls>'] = 2
    print('vocab size', len(vocab_to_int))
    return vocab_to_int
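
# Example (sketch): create_vocab(['<cls> good movie', '<cls> bad movie'])
# reserves '<pad>'=0, '<unk>'=1, '<cls>'=2 and numbers the remaining words
# from 3 by descending frequency.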


class Textset(Dataset):
    """Pairs preprocessed texts with labels; texts longer than `max_len`
    tokens are truncated."""

    def __init__(self, text, label, vocab, max_len):
        super().__init__()
        # Truncate at the token level so max_len bounds the sequence length.
        self.x = [' '.join(t.split()[:max_len]) for t in text]
        self.y = label
        self.vocab = vocab

    def collate(self, batch):
        # Pad each sequence in the batch to the longest one; the pad value 0
        # matches the <pad> id.
        x = [torch.tensor(x) for x, y in batch]
        y = [y for x, y in batch]
        x_tensor = pad_sequence(x, batch_first=True)
        y = torch.tensor(y)
        return x_tensor, y

    def convert2id(self, text):
        # Map each word to its id, falling back to <unk> for OOV words.
        return [self.vocab.get(word, self.vocab['<unk>']) for word in text.split()]

    def __getitem__(self, idx):
        text = self.x[idx]
        word_id = self.convert2id(text)
        return word_id, self.y[idx]

    def __len__(self):
        return len(self.x)
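

if __name__ == '__main__':
    # Minimal end-to-end sketch with toy data (an illustration, not part of
    # the original module): preprocess, build a vocab, batch with Textset.
    raw = ['This is a <b>good</b> movie!', 'A terrible, boring film.']
    labels = [1, 0]
    corpus = [data_preprocessing(t) for t in raw]
    vocab = create_vocab(corpus)
    dataset = Textset(corpus, labels, vocab, max_len=64)
    loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collate)
    for x, y in loader:
        print(x.shape, y)  # padded (batch, seq_len) id tensor and labels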