This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

WIP: huggingface tokenizer and Neural LM training pipeline. #139

Open
wants to merge 28 commits into base: master
Commits (28)
f038e60
huggingface tokenizer and Neural LM training pipeline.
glynpu Mar 25, 2021
e9482d2
draft of class LMDataset
glynpu Mar 29, 2021
135bfdb
a dummy implementation of LMDataset
glynpu Mar 29, 2021
88e0d49
collate function of NNLM
glynpu Mar 30, 2021
27b1863
add scripts to process word piece lexicons.
csukuangfj Mar 30, 2021
212b79b
Merge pull request #2 from csukuangfj/fangjun-rnnlm
glynpu Mar 30, 2021
47bf358
trainer
glynpu Mar 30, 2021
d8aaabd
generate lexicon
glynpu Mar 30, 2021
c44f99d
check text length in dataset.py
glynpu Mar 30, 2021
b13954d
remove shuf/comm commands
glynpu Mar 30, 2021
775d477
beta version of training pipeline
glynpu Mar 30, 2021
3b83338
Merge pull request #1 from glynpu/lyg_dev
glynpu Mar 30, 2021
d415ed0
remove unused file
glynpu Mar 30, 2021
4937232
add dependency and fix known bugs
glynpu Apr 1, 2021
61863db
fix various bugs
glynpu Apr 2, 2021
d4dccae
compute word_ppl from token_ppl
glynpu Apr 2, 2021
a4d5f1b
add results.md
glynpu Apr 3, 2021
53e2d1e
compute word_ppl from token_ppl
glynpu Apr 3, 2021
b226a3a
support yaml configuration
glynpu Apr 9, 2021
89ece61
update results with nvocab=5000
glynpu Apr 9, 2021
c3f8811
fix reviews
glynpu Apr 9, 2021
d1b803b
fixed reviews
glynpu Apr 9, 2021
c45d31f
support multi-gpu training with ddp
glynpu Apr 10, 2021
1d38c21
n-best rescoring result with 8-layer transformer lm
glynpu Apr 14, 2021
f6914cd
Merge remote-tracking branch 'dan/master' into nnlm
glynpu Apr 20, 2021
d847b28
filter train data by length to increase batch_size
glynpu Apr 20, 2021
52300df
use Noam optimizer
glynpu Apr 20, 2021
e61a9d1
add rescore scripts
glynpu Apr 20, 2021
54 changes: 54 additions & 0 deletions egs/librispeech/asr/nnlm/local/data.py
@@ -0,0 +1,54 @@
import os
from io import open
import torch


class Dictionary(object):

def __init__(self):
self.word2idx = {}
self.idx2word = []
self.idx2word.append('<PAD>')
self.word2idx['<PAD>'] = 0

def add_word(self, word):
if word not in self.word2idx:
self.idx2word.append(word)
self.word2idx[word] = len(self.idx2word) - 1
return self.word2idx[word]

def __len__(self):
return len(self.idx2word)


class Corpus(object):

def __init__(self, path):
self.dictionary = Dictionary()
self.train = self.tokenize(os.path.join(path, 'train.tokens'))
self.valid = self.tokenize(os.path.join(path, 'valid.tokens'))
self.test = self.tokenize(os.path.join(path, 'test.tokens'))

def tokenize(self, path):
"""Tokenizes a text file."""
assert os.path.exists(path)
# Add words to the dictionary
with open(path, 'r', encoding="utf8") as f:
for line in f:
words = line.split() + ['<eos>']
for word in words:
self.dictionary.add_word(word)

# Tokenize file content
with open(path, 'r', encoding="utf8") as f:
idss = []
for line in f:
words = line.split() + ['<eos>']
ids = []
for word in words:
ids.append(self.dictionary.word2idx[word])
idss.append(torch.tensor(ids).type(torch.int64))

return idss
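
A minimal usage sketch of the Corpus class above (illustration only: the data directory name is an assumption, and it must already contain train.tokens, valid.tokens and test.tokens):

from data import Corpus

corpus = Corpus('./data/lm_train')   # builds the dictionary over all three splits
print(len(corpus.dictionary))        # vocabulary size, with '<PAD>' reserved at index 0
print(corpus.train[0])               # first training line as a 1-D int64 tensor of token ids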
42 changes: 42 additions & 0 deletions egs/librispeech/asr/nnlm/local/download_lm_train_data.py
@@ -0,0 +1,42 @@
#!/usr/bin/env python3

# Copyright (c) 2020 Xiaomi Corporation (author: Liyong Guo)
# Apache 2.0

import os
import logging
from google_drive_downloader import GoogleDriveDownloader as gdd
from pathlib import Path

# librispeech-lm-norm.txt is 4G
# train_960_text is 48M; it is the combined transcripts of {train_clean_100, train_clean_360, train_other_500}
# here only train_960_text is used to verify the whole pipeline
# A copy of train_960_text: "https://drive.google.com/file/d/1AgP4wTqbfp12dv4fJmjKXHdOf8eOtp_A/view?usp=sharing"
# local_path: "/ceph-ly/open-source/snowfall/egs/librispeech/asr/simple_v1/data/local/lm_train/train_960_text"


def download_librispeech_train_960_text():
train_960_text = "./data/lm_train/librispeech_train_960_text"
if not os.path.exists(train_960_text):
Path(os.path.dirname(train_960_text)).mkdir(parents=True,
exist_ok=True)

logging.info("downloading train_960_text of librispeech.")
gdd.download_file_from_google_drive(
file_id='1AgP4wTqbfp12dv4fJmjKXHdOf8eOtp_A',
dest_path=train_960_text,
unzip=False)
else:
        logging.info(
            "train_960_text of librispeech is already downloaded. You may want to check it."
        )


def main():
logging.getLogger().setLevel(logging.INFO)

download_librispeech_train_960_text()


if __name__ == '__main__':
main()
98 changes: 98 additions & 0 deletions egs/librispeech/asr/nnlm/local/huggingface_tokenizer.py
@@ -0,0 +1,98 @@
#!/usr/bin/env python3

# Copyright (c) 2020 Xiaomi Corporation (author: Liyong Guo)
# Apache 2.0

# reference: https://huggingface.co/docs/tokenizers/python/latest/quicktour.html
import argparse
import logging
import os
import shutil
from pathlib import Path
from tokenizers import Tokenizer
Collaborator:
Could you add some documentation describing how the environment is set up?
I assume that you have run pip install tokenizers beforehand.

Contributor Author:
No problem. A README.md will be added.

from tokenizers.models import WordPiece
from tokenizers import normalizers
from tokenizers.normalizers import Lowercase, NFD, StripAccents
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordPieceTrainer
from tokenizers import decoders


def get_args():
parser = argparse.ArgumentParser(
description='train and tokenize with huggingface tokenizer')
parser.add_argument('--train-file',
type=str,
help="""file to train tokenizer""")
parser.add_argument('--vocab-size',
type=int,
default=1000,
help="""number of tokens of the tokenizer""")
parser.add_argument('--tokenizer-path',
type=str,
help="path to save or load tokenizer")
parser.add_argument('--test-file',
type=str,
help="""file to be tokenized""")
args = parser.parse_args()
return args


def train_tokenizer(train_files, save_path, vocab_size):
if os.path.exists(save_path):
logging.warning(
"{} already exists. Please check that.".format(save_path))
return
else:
Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)

tokenizer = Tokenizer(WordPiece(unk_token='[UNK]'))
tokenizer.normalizer = normalizers.Sequence(
[NFD(), Lowercase(), StripAccents()])
tokenizer.pre_tokenizer = Whitespace()

# default vocab_size=30000
    # here vocab_size=1000 is used to speed up training
trainer = WordPieceTrainer(vocab_size=vocab_size, special_tokens=['[UNK]'])
tokenizer.train(train_files, trainer)
tokenizer.save(save_path)


def tokenize_text(test_file, tokenizer_path):
if not os.path.exists(tokenizer_path):
logging.warning(
"Tokenizer {} does not exist. Please check that.".format(
tokenizer_path))
return
tokenizer = Tokenizer.from_file(tokenizer_path)
tokenizer.decoder = decoders.WordPiece()
tokenized_file = "{}.tokens".format(test_file)
# tokenized_ids = "{}.ids".format(test_file)
if os.path.exists(tokenized_file):
logging.warning(
"The input file seems already tokenized. Buckupping previous result"
)
shutil.copyfile(tokenized_file, "{}.bk".format(tokenized_file))
logging.warning("Tokenizing {}.".format(test_file))
fout = open(tokenized_file, 'w')
with open(test_file) as f:
for line in f:
line = line.strip()
output = tokenizer.encode(line)
fout.write(" ".join(output.tokens) + '\n')

fout.close()


def main():
args = get_args()
if args.train_file is not None:
train_files = [args.train_file]
train_tokenizer(train_files, args.tokenizer_path, args.vocab_size)
Collaborator:
Methods like these (train_tokenizer, tokenize_text) would be good candidates to put into the "library" part of snowfall, so anybody can import them easily from all the recipes.

Candidate for future work in snowfall: this whole script could easily be re-used across recipes if we added a mechanism for auto-registering scripts in PATH (this can be done via setup.py).


if args.test_file is not None:
tokenize_text(args.test_file, args.tokenizer_path)


if __name__ == '__main__':
main()
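
The two helpers above can also be driven directly from Python, in the spirit of the library-reuse suggestion in the thread above. A hypothetical sketch (the data and tokenizer paths are assumptions, and the tokenizers package must be installed, e.g. via pip install tokenizers):

from huggingface_tokenizer import train_tokenizer, tokenize_text

train_tokenizer(train_files=['./data/lm_train/librispeech_train_960_text'],
                save_path='./data/lm_train/tokenizer-librispeech.json',
                vocab_size=1000)
tokenize_text(test_file='./data/lm_train/librispeech_train_960_text',
              tokenizer_path='./data/lm_train/tokenizer-librispeech.json')
# writes ./data/lm_train/librispeech_train_960_text.tokens, one tokenized line per input line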
154 changes: 154 additions & 0 deletions egs/librispeech/asr/nnlm/local/model.py
@@ -0,0 +1,154 @@
import math
Contributor:
We should get in the habit of acknowledging where we got files from, if they were copied from elsewhere...

Contributor Author:
No problem. I will add a reference to every file. For now, all references are collected together in run.sh.

import torch
import torch.nn as nn
import torch.nn.functional as F

class RNNModel(nn.Module):
"""Container module with an encoder, a recurrent module, and a decoder."""

def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False):
super(RNNModel, self).__init__()
self.ntoken = ntoken
self.drop = nn.Dropout(dropout)
self.encoder = nn.Embedding(ntoken, ninp, padding_idx=0)
if rnn_type in ['LSTM', 'GRU']:
self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
else:
try:
nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
except KeyError:
raise ValueError( """An invalid option for `--model` was supplied,
options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""")
self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
self.decoder = nn.Linear(nhid, ntoken)

# Optionally tie weights as in:
# "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
# https://arxiv.org/abs/1608.05859
# and
# "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
# https://arxiv.org/abs/1611.01462
if tie_weights:
if nhid != ninp:
raise ValueError('When using the tied flag, nhid must be equal to emsize')
self.decoder.weight = self.encoder.weight

self.init_weights()

self.rnn_type = rnn_type
self.nhid = nhid
self.nlayers = nlayers

def init_weights(self):
initrange = 0.1
nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

def forward(self, input, hidden):
        # input:  (seq_len, batch_size) token ids
        # emb:    (seq_len, batch_size, ninp) after embedding lookup and dropout
        # output: (seq_len, batch_size, nhid) from the recurrent layer
Contributor:
It would be nice to have the dimensions commented here, e.g. is it (batch_size, num_steps)?

Contributor Author:
fixed

emb = self.drop(self.encoder(input))
output, hidden = self.rnn(emb, hidden)
output = self.drop(output)
decoded = self.decoder(output)
decoded = decoded.view(-1, self.ntoken)
return F.log_softmax(decoded, dim=1), hidden

def init_hidden(self, bsz):
weight = next(self.parameters())
if self.rnn_type == 'LSTM':
return (weight.new_zeros(self.nlayers, bsz, self.nhid),
weight.new_zeros(self.nlayers, bsz, self.nhid))
else:
return weight.new_zeros(self.nlayers, bsz, self.nhid)

# Temporarily leave PositionalEncoding module here. Will be moved somewhere else.
class PositionalEncoding(nn.Module):
r"""Inject some information about the relative or absolute position of the tokens
in the sequence. The positional encodings have the same dimension as
the embeddings, so that the two can be summed. Here, we use sine and cosine
functions of different frequencies.
.. math::
\text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
\text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
        \text{where pos is the word position and i is the embed idx}
Args:
d_model: the embed dim (required).
dropout: the dropout value (default=0.1).
max_len: the max. length of the incoming sequence (default=5000).
Examples:
>>> pos_encoder = PositionalEncoding(d_model)
"""

def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)

pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)

def forward(self, x):
r"""Inputs of forward function
Args:
x: the sequence fed to the positional encoder model (required).
Shape:
x: [sequence length, batch size, embed dim]
Collaborator:
It would be great if you got into the habit of writing more documentation.

You're saying that the input has shape [seq_len, batch_size, embedding_dim], but you are using batch_first when invoking pad_sequence in dataset.py. This may explain why the training is not converging.

Contributor Author:
fixed. (A small shape sketch follows this method.)

output: [sequence length, batch size, embed dim]
Examples:
>>> output = pos_encoder(x)
"""

x = x + self.pe[:x.size(0), :]
return self.dropout(x)
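
A minimal sketch of the shape issue raised in the thread above (sizes are arbitrary): pad_sequence with batch_first=True produces (batch, seq_len), whereas PositionalEncoding adds self.pe[:x.size(0)] along the first dimension and therefore expects a (seq_len, batch, d_model) input:

import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.randint(0, 1000, (n,)) for n in (7, 5, 3)]  # three variable-length id sequences
padded_bf = pad_sequence(seqs, batch_first=True)          # shape (3, 7): (batch, seq_len)
padded_sf = pad_sequence(seqs, batch_first=False)         # shape (7, 3): (seq_len, batch)
assert padded_sf.shape == (7, 3)
# After the embedding layer, the seq-first batch becomes (seq_len, batch, ninp),
# which is the layout PositionalEncoding and nn.TransformerEncoder expect here.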

class TransformerModel(nn.Module):
"""Container module with an encoder, a recurrent or transformer module, and a decoder."""

def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
super(TransformerModel, self).__init__()
try:
from torch.nn import TransformerEncoder, TransformerEncoderLayer
        except ImportError:
raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or lower.')
self.model_type = 'Transformer'
self.src_mask = None
self.pos_encoder = PositionalEncoding(ninp, dropout)
encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
self.encoder = nn.Embedding(ntoken, ninp)
self.ninp = ninp
self.decoder = nn.Linear(ninp, ntoken)

self.init_weights()

def _generate_square_subsequent_mask(self, sz):
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask

def init_weights(self):
initrange = 0.1
nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

def forward(self, src, has_mask=True):
if has_mask:
device = src.device
if self.src_mask is None or self.src_mask.size(0) != len(src):
mask = self._generate_square_subsequent_mask(len(src)).to(device)
self.src_mask = mask
else:
self.src_mask = None

src = self.encoder(src) * math.sqrt(self.ninp)
src = self.pos_encoder(src)
output = self.transformer_encoder(src, self.src_mask)
output = self.decoder(output)
return F.log_softmax(output, dim=-1)
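
A hypothetical smoke test for the two models defined in this file; the hyperparameters and sizes below are made up for illustration and are not the values used in the recipe:

import torch
from model import RNNModel, TransformerModel

ntoken, ninp, nhead, nhid, nlayers = 1000, 256, 4, 512, 2
seq_len, batch_size = 16, 8
src = torch.randint(0, ntoken, (seq_len, batch_size))  # (seq_len, batch) token ids

model = TransformerModel(ntoken, ninp, nhead, nhid, nlayers, dropout=0.1)
log_probs = model(src)                                 # (seq_len, batch, ntoken) log-probabilities
assert log_probs.shape == (seq_len, batch_size, ntoken)

rnn = RNNModel('LSTM', ntoken, ninp, nhid, nlayers, dropout=0.1)
hidden = rnn.init_hidden(batch_size)
out, hidden = rnn(src, hidden)                         # out: (seq_len * batch, ntoken)
assert out.shape == (seq_len * batch_size, ntoken)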