WIP: huggingface tokenizer and Neural LM training pipeline. #139
base: master
Conversation
This commit is mainly about the Hugging Face tokenizer and a draft transformer/RNN based LM training pipeline.
@@ -0,0 +1,154 @@
import math
We should get in the habit of acknowledging where we got files from, if they were copied from elsewhere...
No problem. I will add a reference to every file. For now, all references are added together in run.sh.
These perplexities, are they per word or per token?
per token.
egs/librispeech/asr/nnlm/run.sh
Outdated
lm_train=data/lm_train/
full_text=$lm_train/librispeech_train_960_text
tokenizer=$lm_train/tokenizer-librispeech_train_960.json
if [ $stage -eq 1 ]; then
Should it be $stage -le 1? And also for the following if statements.
yes. "-le" is better. Now "-eq" is used temporarily because it's easier for me to debug stage by stage.
import os
import shutil
from pathlib import Path
from tokenizers import Tokenizer
Could you add some documentation describing how the environment is set up? I assume that you have run pip install tokenizers beforehand.
No problem. A README.md will be added.
egs/librispeech/asr/nnlm/main.py
Outdated
# Save the model if the validation loss is the best we've seen so far.
if not best_val_loss or val_loss < best_val_loss:
    with open(args.save, 'wb') as f:
        torch.save(model, f)
From https://pytorch.org/tutorials/beginner/saving_loading_models.html
The disadvantage of this approach is that the serialized data is bound to the specific classes and the exact directory structure used when the model is saved.
Could you save only the state dict of the model?
Solved as follows:

import logging
import os
from pathlib import Path
from typing import Union

import torch

Pathlike = Union[str, Path]
Info = Union[dict, None]

def save_checkpoint(filename: Pathlike,
                    model: torch.nn.Module,
                    info: Info = None) -> None:
    if not os.path.exists(os.path.dirname(filename)):
        Path(os.path.dirname(filename)).mkdir(parents=True, exist_ok=True)
    logging.info(f'Save checkpoint to {filename}')
    checkpoint = {
        'state_dict': model.state_dict(),
    }
    if info is not None:
        checkpoint.update(info)
    torch.save(checkpoint, filename)
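A matching loader that restores only the state dict might look like this (a minimal sketch; this helper is not shown in the PR):

import torch

def load_checkpoint(filename, model: torch.nn.Module) -> dict:
    # Load on CPU first so the checkpoint can be restored on any device.
    checkpoint = torch.load(filename, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    # Return any extra info (e.g. epoch, loss) stored alongside the weights.
    return {k: v for k, v in checkpoint.items() if k != 'state_dict'}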
egs/librispeech/asr/nnlm/main.py
Outdated
epoch, batch_idx,
len(train_data) // batch_size, lr,
elapsed * 1000 / args.log_interval, cur_loss,
math.exp(cur_loss)))
These perplexities, are they per word or per token?
@danpovey The perplexities are computed as exp(NLL), and the modelling units are tokens, so the PPL is computed with respect to tokens.
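For concreteness, a minimal sketch of the per-token vs. per-word distinction (the counts below are purely illustrative, not measured on LibriSpeech):

import math

total_nll = 50000.0   # sum of negative log-likelihoods over an evaluation set
num_tokens = 12000    # number of subword tokens that were scored
num_words = 9000      # number of words in the same text

ppl_per_token = math.exp(total_nll / num_tokens)
ppl_per_word = math.exp(total_nll / num_words)
# num_words <= num_tokens, so per-word perplexity is >= per-token perplexity
# for the same model and text.
print(ppl_per_token, ppl_per_word)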
@glynpu Do you know what the normal PPL for the LibriSpeech corpus is in terms of tokens?
It would very much depend on the way it was tokenized. It's probably better to divide the total log-prob by the number of words, to get the perplexity per word. I'd guess between about 80 and 200, but that's just a guess.
In our original paper we mention perplexities of 150 and 170.
Referring to the RNN-LM experiment in Kaldi with the LibriSpeech data, I am studying its configuration and hope to get a comparable PPL with the same data this week.
egs/librispeech/asr/nnlm/run.sh
Outdated
num_utts_total=$(wc -l <$full_tokens )
num_valid_test=$(($num_utts_total/${valid_test_fraction}))
set +x
shuf -n $num_valid_test $full_tokens > $valid_test_tokens
Shall we fix the seed for shuf so that the split is reproducible? I think a Python script can do this task equally well and is more maintainable.
Reproducibility is important. Maybe the data separation method of the Kaldi RNNLM can be used in the following experiments:
gunzip -c $text | cut -d ' ' -f2- | awk -v text_dir=$text_dir '{if(NR%2000 == 0) { print >text_dir"/dev.txt"; } else {print;}}' >$text_dir/librispeech.txt
+1 for dropping bash/perl entirely for these sorts of tasks in snowfall.
Yes, probably that modulo method from Kaldi is fine. shuf is not always installed.
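A minimal Python sketch of that modulo-style split (the file names and the modulus of 2000 mirror the Kaldi command quoted above and are only illustrative):

def split_train_dev(in_path: str, train_path: str, dev_path: str,
                    modulo: int = 2000) -> None:
    with open(in_path) as fin, \
         open(train_path, 'w') as ftrain, \
         open(dev_path, 'w') as fdev:
        for line_no, line in enumerate(fin, start=1):
            # Every `modulo`-th line goes to dev, the rest to train,
            # so the split is deterministic and reproducible.
            (fdev if line_no % modulo == 0 else ftrain).write(line)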
egs/librispeech/asr/nnlm/main.py
Outdated
# │ e k q w │
# └ f l r x ┘.
# These columns are treated as independent by the model, which means that the
# dependence of e. g. 'g' on 'f' can not be learned, but allows more efficient
That could be overcome by a data sampling and batching strategy where you iterate over the training text with overlapping windows (50% overlap being the obvious setting, but for larger data a smaller value like 20% would probably work just as well and train faster).
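For illustration, a minimal sketch of such overlapping-window chunking (the window size and overlap value are assumptions, not part of this PR):

from typing import List

def overlapping_windows(token_ids: List[int], window: int = 128,
                        overlap: float = 0.2) -> List[List[int]]:
    # Step forward by (1 - overlap) * window tokens each time, so consecutive
    # chunks share roughly `overlap` of their content.
    stride = max(1, int(window * (1.0 - overlap)))
    chunks = []
    start = 0
    while start < len(token_ids):
        chunks.append(token_ids[start:start + window])
        if start + window >= len(token_ids):
            break
        start += stride
    return chunks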
So is the data treated as one long sequence, rather than a bunch of independent sentences?
I would have thought for ASR applications, the independent-sentences approach might make more sense.
No, the training text is not treated as one long sequence. I have modified the data preparation method so that each piece of text is treated independently. Sorry, I forgot to delete these unrelated original comments.
By the way, I am refactoring the training pipeline according to these reviews. Temporarily, a new dataset class is located here, which handles the training text line by line and then batchifies the pieces independently in CollateFunc.
with open(text_file, 'r') as f:
    # each line represents a piece of text, e.g.
    # DELAWARE IS NOT AFRAID OF DOGS
    for line in f:
        text = line.strip().split()
        assert len(text) > 0
        text_id = self.text2id(text)
        # token_id format:
        # <bos_id> token_id token_id token_id *** <eos_id>
        token_id = self.text_id2token_id(text_id)
        self.data.append(token_id)
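A minimal sketch of what the corresponding collate function could look like (the padding value, batch-first layout, and input/target shift are assumptions, not the exact PR code):

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_lm_batch(batch):
    # batch is a list of token-id lists, each already wrapped in <bos>/<eos>.
    seqs = [torch.tensor(ids, dtype=torch.long) for ids in batch]
    # Input drops the last token, target drops the first, so the model
    # predicts the next token at every position.
    inputs = pad_sequence([s[:-1] for s in seqs], batch_first=True, padding_value=0)
    targets = pad_sequence([s[1:] for s in seqs], batch_first=True, padding_value=0)
    return inputs, targets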
args = get_args()
if args.train_file is not None:
    train_files = [args.train_file]
    train_tokenizer(train_files, args.tokenizer_path, args.vocab_size)
Methods like these (train_tokenizer, tokenize_text) would be good candidates to put into the "library" part of snowfall so anybody can import them easily for all the recipes.
Candidate for future work in snowfall: actually this whole script could be easily re-used across recipes had we added a mechanism for auto-registering scripts in PATH (can be done via setup.py)
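A hypothetical sketch of such auto-registration via setup.py entry points (the package layout and command names below are illustrative, not part of snowfall today):

from setuptools import setup, find_packages

setup(
    name='snowfall',
    packages=find_packages(),
    entry_points={
        'console_scripts': [
            # exposes a `snowfall-train-tokenizer` command on PATH after `pip install -e .`
            'snowfall-train-tokenizer=snowfall.nnlm.tokenizer:main',
        ],
    },
)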
egs/librispeech/asr/nnlm/run.sh
Outdated
num_utts_total=$(wc -l <$full_tokens )
num_valid_test=$(($num_utts_total/${valid_test_fraction}))
set +x
shuf -n $num_valid_test $full_tokens > $valid_test_tokens
+1 for dropping bash/perl entirely for these sorts of tasks in snowfall.
Good work! I will try to read and understand what you are doing.
add scripts to process word piece lexicons.
def __getitem__(self, idx):
    return self.data[idx]

def text2id(self, text: List[str]) -> List[int]:
The following two methods can be removed.
fixed
nn.init.uniform_(self.decoder.weight, -initrange, initrange)

def forward(self, input, hidden):
    # import pdb; pdb.set_trace()
would be nice to have the dimensions commented here, e.g. is it (batch_size, num_steps)?
fixed
Something is not installed... I don't know how easy it is to set things up so they get installed automatically, or at least so the user is told what to install?
A commit to handle this together with other known bugs will be submitted this afternoon.
Scripts to install tokenizers; fix training bugs; port online tokenization to offline tokenization; load/save checkpoint.
@danpovey Added a statement to automatically install dependencies in run.sh.
Now I am still facing some convergence issues. After several epochs, the PPL is stuck around 1000.
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

Pathlike = Union[str, Path]
Info = Union[dict, None]
This is equivalent to Info = Optional[dict]
# token_id format:
# <bos_id> token_id token_id token_id *** <eos_id>
if len(token_id) >= 2:
for idx, line in enumerate(f):
idx is never used.
fixed
@@ -37,35 +37,22 @@ def __call__(self, batch: List[List[int]]):

class LMDataset(Dataset):

    def __init__(self, text_file: str, lexicon):
    def __init__(self, text_file: str):
Can you describe the format of text_file?
fixed
@@ -29,17 +30,41 @@ def get_args():


def generate_tokens(args):
    ''' Extract symbols and there corresponding ids from a tokenizer,
typo: the corresponding.
fixed
tokenizer = Tokenizer.from_file(args.tokenizer_path)
symbols = tokenizer.get_vocab()
tokens_file = '{}/tokens.txt'.format(args.lexicon_path)
tokens_f = open(tokens_file, 'w')
for idx, sym in enumerate(symbols):
    tokens_f.write('{} {}\n'.format(sym.lower(), idx))
id2sym = dict((v, k.lower()) for k, v in symbols.items())
id2sym = {idx: sym.lower() for sym, idx in symbols.items()} is much clearer.
fixed
for idx, sym in enumerate(symbols):
    tokens_f.write('{} {}\n'.format(sym.lower(), idx))
id2sym = dict((v, k.lower()) for k, v in symbols.items())
for idx in range(len(symbols)):
Is it required that the resulting file has its second column listed in increasing order? Otherwise, there is no need to create another intermediate variable id2sym; we can iterate over symbols directly.
Just to ensure that the ids are contiguous. And an ordered token list looks nice.
The result is not sorted if we iterate over symbols directly; the output of:
for k, v in symbols.items():
    print(k.lower(), v)
looks like the following (quite disordered):
'''
##ark 335
##umes 3822
vain 3593
eastern 4515
next 1372
knowing 4454
##jo 2789
western 3987
garden 1387
tree 1348
'''
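If an id-ordered tokens.txt is the goal, a sorted iteration would also avoid the intermediate dict (a sketch; the tokenizer path is illustrative and symbols is the token-to-id dict returned by get_vocab()):

from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file('tokenizer-librispeech_train_960.json')
symbols = tokenizer.get_vocab()   # maps token string -> integer id
with open('tokens.txt', 'w') as tokens_f:
    # Sort by id so the second column is written in increasing order.
    for sym, idx in sorted(symbols.items(), key=lambda kv: kv[1]):
        tokens_f.write('{} {}\n'.format(sym.lower(), idx))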
    output = tokenizer.encode(word)
    tokens = ' '.join(output.tokens)
else:
    tokens = '[unk]'
Is there a difference between [unk] and <UNK>? I find that you're using <UNK> in the above special_words, but [unk] here. BTW: what are special_words for?
The special tokens are a heritage of words.txt (simple_v1/data/lang_nosp/words.txt), whose head is:
<eps> 0
!SIL 1
<SPOKEN_NOISE> 2
<UNK> 3
A 4
...
#0 200004
<s> 200005
</s> 200006
I just want to make sure that every word in words.txt can be tokenized. As those special words are not "real" words, I think mapping them to [unk] is better than tokenizing them with a trained tokenizer.
In short, <UNK>, together with the other special words, is a heritage from the upstream ASR pipeline, and [unk] is a token produced by the Hugging Face tokenizer.
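A minimal sketch of the mapping described above (the special_words set and file paths are assumptions based on this discussion, not the exact PR code):

from tokenizers import Tokenizer

special_words = {'<eps>', '!SIL', '<SPOKEN_NOISE>', '<UNK>', '#0', '<s>', '</s>'}
tokenizer = Tokenizer.from_file('tokenizer-librispeech_train_960.json')

with open('words.txt') as f:
    for line in f:
        word = line.split()[0]
        if word in special_words:
            # Special symbols are not "real" words, so map them to [unk].
            tokens = '[unk]'
        else:
            tokens = ' '.join(tokenizer.encode(word).tokens)
        print(word, tokens)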
egs/librispeech/asr/nnlm/main.py
Outdated
train_data_loader = DataLoader(train_dataset,
                               batch_size=args.batch_size,
                               shuffle=False,
Do we need to set shuffle to True for training?
fixed. Shuffling was disabled only for debugging, to make it easier to trace whether the DataLoader and collate function work as expected.
batch_input, batch_target = batch
batch_input = batch_input.to(self.device)
batch_target = batch_target.to(self.device)
self.model.to(self.device)
Would be great if this to(self.device) is moved out of the loop. It needs to be done only once, e.g., inside the constructor self.__init__.
fixed.
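A minimal sketch of the suggested change (the class and method names are illustrative):

import torch

class Trainer:
    def __init__(self, model: torch.nn.Module, device: torch.device):
        self.device = device
        self.model = model.to(device)   # move the model to the device once, here

    def train_step(self, batch):
        # Only the per-batch tensors still need to be moved inside the loop.
        batch_input, batch_target = batch
        batch_input = batch_input.to(self.device)
        batch_target = batch_target.to(self.device)
        # ... forward pass, loss, backward, optimizer step ...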
Args:
    x: the sequence fed to the positional encoder model (required).
Shape:
    x: [sequence length, batch size, embed dim]
Would be great if you got into the habit of writing more documentation. You're saying that the input is of shape [seq_len, batch_size, embedding_dim], but you are using batch first when invoking pad_sequence in dataset.py. This may explain why the training is not converging.
fixed.
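One way to keep the layouts consistent is to pad with batch_first=True in the collate function and then transpose before modules that expect [seq_len, batch_size, embedding_dim]; a minimal sketch (the tensors and sizes are illustrative):

import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.tensor([5, 7, 9]), torch.tensor([3, 4])]
inputs = pad_sequence(seqs, batch_first=True, padding_value=0)  # [batch_size, seq_len]
inputs = inputs.transpose(0, 1).contiguous()                    # [seq_len, batch_size]
embedding = nn.Embedding(num_embeddings=10, embedding_dim=8)
embedded = embedding(inputs)                     # [seq_len, batch_size, embedding_dim]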
With vocab_size=2000 and 50 epochs, the token PPL is around 80 on train and 119 on dev.
Fixes #132
2021-04-23: use an AM model trained with the full LibriSpeech data.
2021-04-21:
max_norm=5 is better than max_norm=0.25. The training is ongoing.
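For reference, a minimal sketch of where this clipping threshold sits in a standard PyTorch training step (model, optimizer, and loss_fn are placeholders, not the exact PR code):

import torch

def train_step(model, optimizer, loss_fn, batch_input, batch_target):
    optimizer.zero_grad()
    output = model(batch_input)
    loss = loss_fn(output, batch_target)
    loss.backward()
    # Clip the global gradient norm before the optimizer step; max_norm=5.0
    # follows the setting discussed in this thread.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
    optimizer.step()
    return loss.item()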
16 layers trained with the Noam optimizer got a better WER than the previous 8-layer transformers. But with this reference, max_norm=0.25 in clip_grad_norm_ seems TOO SMALL, which may explain why epoch 19 obtains only small gains compared to epoch 3. Now max_norm=5 is used, referring to the ESPnet transformer LM, and results are coming soon.
--------- previous comments ------
This commit is mainly about the Hugging Face tokenizer and
a draft transformer/RNN based LM training pipeline.
They are implemented mainly by referencing the following tutorials: tokenizer and neural LM, which is also referenced by ESPnet.
The current (tokenizer + transformer LM) experiment shows that the PPL can decrease from around 1000 to around 110 within 10 epochs, as shown in the following screenshots.
TODOs:
1. Extend this training pipeline with advanced utils, such as a multi-thread prefetching DataLoader with a proper collate_fn and a TensorBoard summary writer.
2. Evaluation/test parts.
3. Do experiments with the full LibriSpeech data. Currently only 50 MB of training text is used out of around 4 GB.
4. A proper way to integrate the NNLM into the previous ASR decoding pipeline, i.e. the aim of issue #132.
5. Try other network structures.