WIP: huggingface tokenizer and Neural LM training pipeline. #139
base: master
@@ -0,0 +1,54 @@
import os
from io import open
import torch


class Dictionary(object):

    def __init__(self):
        self.word2idx = {}
        self.idx2word = []
        # Index 0 is reserved for the padding token.
        self.idx2word.append('<PAD>')
        self.word2idx['<PAD>'] = 0

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):

    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.tokens'))
        self.valid = self.tokenize(os.path.join(path, 'valid.tokens'))
        self.test = self.tokenize(os.path.join(path, 'test.tokens'))

    def tokenize(self, path):
        """Tokenizes a text file into a list of 1-D int64 tensors, one per line."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r', encoding="utf8") as f:
            idss = []
            for line in f:
                words = line.split() + ['<eos>']
                ids = []
                for word in words:
                    ids.append(self.dictionary.word2idx[word])
                idss.append(torch.tensor(ids).type(torch.int64))

        return idss
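A minimal usage sketch of the corpus loader above (a sketch only: the module name data.py, the data directory, and the batchify helper are assumptions, not part of this diff). It shows how the per-line id tensors returned by tokenize() could be flattened and reshaped into (num_steps, batch_size) mini-batches for LM training.

import torch
from data import Corpus  # assumption: the file above is saved as data.py

# Assumption: ./data/lm_train contains train.tokens, valid.tokens and test.tokens
# produced by the tokenizer script in this PR.
corpus = Corpus('./data/lm_train')
print('vocab size (including <PAD>):', len(corpus.dictionary))

# corpus.train is a list of 1-D int64 tensors, one per line, each ending in <eos>.
stream = torch.cat(corpus.train)


def batchify(data, batch_size):
    # Trim the token stream so it divides evenly, then lay it out as
    # (num_steps, batch_size) columns, the layout the RNN LM below expects.
    num_steps = data.size(0) // batch_size
    data = data[:num_steps * batch_size]
    return data.view(batch_size, -1).t().contiguous()


train_data = batchify(stream, batch_size=20)
print(train_data.shape)  # (num_steps, 20)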
@@ -0,0 +1,42 @@
#!/usr/bin/env python3

# Copyright (c) 2020 Xiaomi Corporation (author: Liyong Guo)
# Apache 2.0

import os
import logging
from google_drive_downloader import GoogleDriveDownloader as gdd
from pathlib import Path

# librispeech-lm-norm.txt is 4G.
# train_960_text is 48M; it is the combined transcript text of
# {train_clean_360, train_clean_100, train_other_500}.
# Here only train_960_text is used to verify the whole pipeline.
# A copy of train_960_text: "https://drive.google.com/file/d/1AgP4wTqbfp12dv4fJmjKXHdOf8eOtp_A/view?usp=sharing"
# local_path: "/ceph-ly/open-source/snowfall/egs/librispeech/asr/simple_v1/data/local/lm_train/train_960_text"


def download_librispeech_train_960_text():
    train_960_text = "./data/lm_train/librispeech_train_960_text"
    if not os.path.exists(train_960_text):
        Path(os.path.dirname(train_960_text)).mkdir(parents=True,
                                                    exist_ok=True)

        logging.info("Downloading train_960_text of librispeech.")
        gdd.download_file_from_google_drive(
            file_id='1AgP4wTqbfp12dv4fJmjKXHdOf8eOtp_A',
            dest_path=train_960_text,
            unzip=False)
    else:
        logging.info(
            "train_960_text of librispeech is already downloaded. "
            "You may want to check it.")


def main():
    logging.getLogger().setLevel(logging.INFO)

    download_librispeech_train_960_text()


if __name__ == '__main__':
    main()
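A small sanity-check sketch (assuming the download above succeeded and left the file at the default path used in the script); it simply confirms the transcript text is present before tokenizer training.

import os

# Default path used by download_librispeech_train_960_text() above.
path = "./data/lm_train/librispeech_train_960_text"
assert os.path.exists(path), "run the download script first"

with open(path, encoding="utf8") as f:
    num_lines = sum(1 for _ in f)
print("number of transcript lines:", num_lines)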
@@ -0,0 +1,98 @@
#!/usr/bin/env python3

# Copyright (c) 2020 Xiaomi Corporation (author: Liyong Guo)
# Apache 2.0

# reference: https://huggingface.co/docs/tokenizers/python/latest/quicktour.html
import argparse
import logging
import os
import shutil
from pathlib import Path
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers import normalizers
from tokenizers.normalizers import Lowercase, NFD, StripAccents
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordPieceTrainer
from tokenizers import decoders


def get_args():
    parser = argparse.ArgumentParser(
        description='train and tokenize with huggingface tokenizer')
    parser.add_argument('--train-file',
                        type=str,
                        help="""file to train the tokenizer on""")
    parser.add_argument('--vocab-size',
                        type=int,
                        default=1000,
                        help="""number of tokens in the tokenizer vocabulary""")
    parser.add_argument('--tokenizer-path',
                        type=str,
                        help="path to save or load the tokenizer")
    parser.add_argument('--test-file',
                        type=str,
                        help="""file to be tokenized""")
    args = parser.parse_args()
    return args


def train_tokenizer(train_files, save_path, vocab_size):
    if os.path.exists(save_path):
        logging.warning(
            "{} already exists. Please check that.".format(save_path))
        return
    else:
        Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)

    tokenizer = Tokenizer(WordPiece(unk_token='[UNK]'))
    tokenizer.normalizer = normalizers.Sequence(
        [NFD(), Lowercase(), StripAccents()])
    tokenizer.pre_tokenizer = Whitespace()

    # The default vocab_size is 30000; vocab_size=1000 is used here to speed
    # up training while verifying the pipeline.
    trainer = WordPieceTrainer(vocab_size=vocab_size, special_tokens=['[UNK]'])
    tokenizer.train(train_files, trainer)
    tokenizer.save(save_path)


def tokenize_text(test_file, tokenizer_path):
    if not os.path.exists(tokenizer_path):
        logging.warning(
            "Tokenizer {} does not exist. Please check that.".format(
                tokenizer_path))
        return
    tokenizer = Tokenizer.from_file(tokenizer_path)
    tokenizer.decoder = decoders.WordPiece()
    tokenized_file = "{}.tokens".format(test_file)
    if os.path.exists(tokenized_file):
        logging.warning(
            "The input file seems to be tokenized already. Backing up the previous result."
        )
        shutil.copyfile(tokenized_file, "{}.bk".format(tokenized_file))
    logging.warning("Tokenizing {}.".format(test_file))
    fout = open(tokenized_file, 'w')
    with open(test_file) as f:
        for line in f:
            line = line.strip()
            output = tokenizer.encode(line)
            fout.write(" ".join(output.tokens) + '\n')

    fout.close()


def main():
    args = get_args()
    if args.train_file is not None:
        train_files = [args.train_file]
        train_tokenizer(train_files, args.tokenizer_path, args.vocab_size)
Review comment: Methods like these (…) are a candidate for future work in snowfall: actually, this whole script could easily be re-used across recipes had we added a mechanism for auto-registering scripts in PATH (can be done via setup.py).
    if args.test_file is not None:
        tokenize_text(args.test_file, args.tokenizer_path)


if __name__ == '__main__':
    main()
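A brief sketch of using the trained tokenizer outside this script (the tokenizer path is an illustrative assumption; pass whatever was given to --tokenizer-path when training). It reloads the saved WordPiece tokenizer with the huggingface tokenizers API and round-trips one sentence, which is essentially what tokenize_text() above does per line.

from tokenizers import Tokenizer
from tokenizers import decoders

# Assumption: the tokenizer was trained and saved via, e.g.,
#   --train-file ./data/lm_train/librispeech_train_960_text
#   --vocab-size 1000 --tokenizer-path ./data/lm_train/tokenizer.json
tokenizer = Tokenizer.from_file('./data/lm_train/tokenizer.json')
tokenizer.decoder = decoders.WordPiece()

enc = tokenizer.encode("nothing was to be done but to put about")
print(enc.tokens)                 # word pieces; continuation pieces are prefixed with '##'
print(enc.ids)                    # integer ids a neural LM would consume
print(tokenizer.decode(enc.ids))  # the WordPiece decoder merges pieces back into words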
@@ -0,0 +1,154 @@
import math
Review comment: We should get in the habit of acknowledging where we got files from, if they were copied from elsewhere...
Reply: No problem. I will add a reference to every file. For now, all references are added together in run.sh.
import torch
import torch.nn as nn
import torch.nn.functional as F


class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False):
        super(RNNModel, self).__init__()
        self.ntoken = ntoken
        self.drop = nn.Dropout(dropout)
        # Index 0 is reserved for the <PAD> token (see Dictionary above).
        self.encoder = nn.Embedding(ntoken, ninp, padding_idx=0)
        if rnn_type in ['LSTM', 'GRU']:
            self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
            except KeyError:
                raise ValueError("""An invalid option for `--model` was supplied,
                                 options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""")
            self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            if nhid != ninp:
                raise ValueError('When using the tied flag, nhid must be equal to emsize')
            self.decoder.weight = self.encoder.weight

        self.init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, input, hidden):
Review comment: It would be nice to have the dimensions commented here, e.g. is it (batch_size, num_steps)?
Reply: Fixed.
        # input: (num_steps, batch_size) token ids;
        # hidden: (nlayers, batch_size, nhid), a tuple of two such tensors for LSTM.
        emb = self.drop(self.encoder(input))    # (num_steps, batch_size, ninp)
        output, hidden = self.rnn(emb, hidden)  # (num_steps, batch_size, nhid)
        output = self.drop(output)
        decoded = self.decoder(output)          # (num_steps, batch_size, ntoken)
        decoded = decoded.view(-1, self.ntoken)
        return F.log_softmax(decoded, dim=1), hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        if self.rnn_type == 'LSTM':
            return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                    weight.new_zeros(self.nlayers, bsz, self.nhid))
        else:
            return weight.new_zeros(self.nlayers, bsz, self.nhid)


# Temporarily leave the PositionalEncoding module here. It will be moved somewhere else.
class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens
    in the sequence. The positional encodings have the same dimension as the embeddings,
    so that the two can be summed. Here, we use sine and cosine functions of different
    frequencies.

    .. math::
        \text{PosEncoder}(pos, 2i)   = \sin(pos / 10000^{2i / d_{\text{model}}})
        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{\text{model}}})

    where pos is the word position and i is the embedding index.

    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Inputs of forward function

        Args:
            x: the sequence fed to the positional encoder model (required).

        Shape:
            x: [sequence length, batch size, embed dim]
Review comment: It would be great if you got into the habit of writing more documentation. You're saying that the input is of shape …
Reply: Fixed.
            output: [sequence length, batch size, embed dim]

        Examples:
            >>> output = pos_encoder(x)
        """

        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class TransformerModel(nn.Module):
    """Container module with an encoder, a transformer encoder, and a decoder."""

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        try:
            from torch.nn import TransformerEncoder, TransformerEncoderLayer
        except ImportError:
            raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or lower.')
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        # Causal mask: -inf above the diagonal, 0 elsewhere, so position i can
        # only attend to positions <= i.
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, src, has_mask=True):
        # src: (num_steps, batch_size) token ids.
        if has_mask:
            device = src.device
            if self.src_mask is None or self.src_mask.size(0) != len(src):
                mask = self._generate_square_subsequent_mask(len(src)).to(device)
                self.src_mask = mask
        else:
            self.src_mask = None

        src = self.encoder(src) * math.sqrt(self.ninp)  # (num_steps, batch_size, ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)                   # (num_steps, batch_size, ntoken)
        return F.log_softmax(output, dim=-1)
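A quick shape-check sketch for the two models (a sketch only: the module name model.py and the hyper-parameters are assumptions, not part of this diff). It follows the (num_steps, batch_size) input layout documented above.

import torch
from model import RNNModel, TransformerModel  # assumption: the file above is saved as model.py

ntoken, ninp, nhid, nlayers, nhead = 1000, 200, 200, 2, 2
num_steps, batch_size = 35, 20
dummy = torch.randint(0, ntoken, (num_steps, batch_size))  # (num_steps, batch_size)

rnn_lm = RNNModel('LSTM', ntoken, ninp, nhid, nlayers)
hidden = rnn_lm.init_hidden(batch_size)
log_probs, hidden = rnn_lm(dummy, hidden)
print(log_probs.shape)  # (num_steps * batch_size, ntoken)

transformer_lm = TransformerModel(ntoken, ninp, nhead, nhid, nlayers)
out = transformer_lm(dummy)
print(out.shape)        # (num_steps, batch_size, ntoken)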
Review comment: Could you add some documentation describing how the environment is set up? I assume that you have run "pip install tokenizers" beforehand.
Reply: No problem. A Readme.md will be added.