This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

WIP: huggingface tokenizer and Neural LM training pipeline. #139

Status: Open. Wants to merge 28 commits into base: master.

28 commits:
f038e60  hugginface tokenizer and Neural LM training pipeline. (glynpu, Mar 25, 2021)
e9482d2  draft of class LMDataset (glynpu, Mar 29, 2021)
135bfdb  a dummy implementation of LMDataset (glynpu, Mar 29, 2021)
88e0d49  collate function of NNLM (glynpu, Mar 30, 2021)
27b1863  add scripts to process word piece lexicons. (csukuangfj, Mar 30, 2021)
212b79b  Merge pull request #2 from csukuangfj/fangjun-rnnlm (glynpu, Mar 30, 2021)
47bf358  trainer (glynpu, Mar 30, 2021)
d8aaabd  generate lexicon (glynpu, Mar 30, 2021)
c44f99d  check text length in dataset.py (glynpu, Mar 30, 2021)
b13954d  remove shuf/comm commands (glynpu, Mar 30, 2021)
775d477  beta version of training pipeline (glynpu, Mar 30, 2021)
3b83338  Merge pull request #1 from glynpu/lyg_dev (glynpu, Mar 30, 2021)
d415ed0  remove unused file (glynpu, Mar 30, 2021)
4937232  add dependency and fix known bugs (glynpu, Apr 1, 2021)
61863db  fix various bugs (glynpu, Apr 2, 2021)
d4dccae  compute word_ppl from token_ppl (glynpu, Apr 2, 2021)
a4d5f1b  add results.md (glynpu, Apr 3, 2021)
53e2d1e  compute word_ppl from token_ppl (glynpu, Apr 3, 2021)
b226a3a  support yaml configuration (glynpu, Apr 9, 2021)
89ece61  update results with nvocab=5000 (glynpu, Apr 9, 2021)
c3f8811  fix reviews (glynpu, Apr 9, 2021)
d1b803b  fixed reviews (glynpu, Apr 9, 2021)
c45d31f  support multi-gpu training with ddp (glynpu, Apr 10, 2021)
1d38c21  n-best rescoring result with 8-layer transformer lm (glynpu, Apr 14, 2021)
f6914cd  Merge remote-tracking branch 'dan/master' into nnlm (glynpu, Apr 20, 2021)
d847b28  filter train data by length to increase batch_size (glynpu, Apr 20, 2021)
52300df  use Noam optimizer (glynpu, Apr 20, 2021)
e61a9d1  add rescore scripts (glynpu, Apr 20, 2021)
42 changes: 42 additions & 0 deletions egs/librispeech/asr/nnlm/local/common.py
@@ -0,0 +1,42 @@
#!/usr/bin/env python3

# Copyright (c) 2020 Xiaomi Corporation (author: Liyong Guo)
# Apache 2.0

# modified from https://github.com/k2-fsa/snowfall/blob/master/snowfall/common.py to save/load non-Acoustic Model
import logging
import os
import torch

from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

Pathlike = Union[str, Path]
Info = Union[dict, None]
Collaborator: This is equivalent to Info = Optional[dict].
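A minimal sketch (not part of the diff) verifying the equivalence the reviewer notes; the name InfoAlt is invented purely for illustration:

from typing import Optional, Union

Info = Union[dict, None]
InfoAlt = Optional[dict]   # Optional[X] is just shorthand for Union[X, None]
assert Info == InfoAlt     # both annotations describe the same type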



def load_checkpoint(filename: Pathlike,
                    model: torch.nn.Module,
                    info: Info = None) -> Dict[str, Any]:
    logging.info('load checkpoint from {}'.format(filename))

    checkpoint = torch.load(filename, map_location='cpu')

    model.load_state_dict(checkpoint['state_dict'])

    return checkpoint


def save_checkpoint(filename: Pathlike,
                    model: torch.nn.Module,
                    info: Info = None) -> None:
    if not os.path.exists(os.path.dirname(filename)):
        Path(os.path.dirname(filename)).mkdir(parents=True, exist_ok=True)
    logging.info(f'Save checkpoint to {filename}')
    checkpoint = {
        'state_dict': model.state_dict(),
    }
    if info is not None:
        checkpoint.update(info)

    torch.save(checkpoint, filename)
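A brief usage sketch (not part of the diff) showing how the optional info dict travels with the weights; the Linear model and file name are placeholders:

model = torch.nn.Linear(8, 8)
save_checkpoint('exp-nnlm/models/epoch_0.pt', model, info={'epoch': 0})
checkpoint = load_checkpoint('exp-nnlm/models/epoch_0.pt', model)
print(checkpoint['epoch'])  # -> 0, restored alongside 'state_dict'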
42 changes: 10 additions & 32 deletions egs/librispeech/asr/nnlm/local/dataset.py
@@ -3,10 +3,10 @@
# Copyright (c) 2020 Xiaomi Corporation (author: Liyong Guo)
# Apache 2.0

import time
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from typing import List
from util import convert_tokens_to_ids

import numpy as np
import os
@@ -37,35 +37,22 @@ def __call__(self, batch: List[List[int]]):

class LMDataset(Dataset):

def __init__(self, text_file: str, lexicon):
def __init__(self, text_file: str):
Collaborator: Can you describe the format of text_file?
Contributor Author: fixed

'''Dataset to load Language Model train/dev text data

Args:
text_file: text file, text for one utt per line.
'''
self.lexicon = lexicon
assert os.path.exists(
text_file), "text_file: {} does not exist, please check that."
text_file
), "text_file: {} does not exist, please check that.".format(text_file)
self.data = []
with open(text_file, 'r') as f:
# a line represent a piece of text, e.g.
# DELAWARE IS NOT AFRAID OF DOGS
for line in f:
# import pdb
# pdb.set_trace()
text = line.strip().lower().split()
# print(text)
if len(text) == 0:
continue
word_id = convert_tokens_to_ids(text, self.lexicon.word2id)
if len(word_id) == 0:
continue
word_id = torch.from_numpy(np.array(word_id, dtype="int32"))

token_id = self.lexicon.word_seq_to_word_piece_seq(word_id)
# token_id format:
# <bos_id> token_id token_id token_id *** <eos_id>
if len(token_id) >= 2:
for idx, line in enumerate(f):
Collaborator: idx is never used.
Contributor Author: fixed

token_id = [int(i) for i in line.strip().split()]
# TODO(Liyong Guo): add bos_id and eos_id to each piece of example
# then each valid example should be longer than 2
if len(token_id) > 2:
self.data.append(token_id)
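Note (not part of the diff): judging from the tokenizer script further below and the __main__ block of this file, text_file is assumed to contain one utterance per line as space-separated token ids, e.g.

355 127 794 4824 346 370
1330 1898

so each line parses directly via [int(i) for i in line.strip().split()].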

def __len__(self):
@@ -74,18 +61,9 @@ def __len__(self):
def __getitem__(self, idx):
return self.data[idx]

def text2id(self, text: List[str]) -> List[int]:
# A dummy implementation
return [i for i in range(len(text))]

def text_id2token_id(self, text_id: List[int]) -> List[int]:
# A dummy implementation
return [i for i in range(len(text_id))]


if __name__ == '__main__':
# train_file = "./data/nnlm/text/librispeech.txt"
dev_file = "./data/nnlm/text/dev.txt"
dev_file = "./data/nnlm/text/dev.txt.tokens"
dataset = LMDataset(dev_file)
collate_func = CollateFunc()
data_loader = DataLoader(dataset,
39 changes: 30 additions & 9 deletions egs/librispeech/asr/nnlm/local/generate_lexicon.py
@@ -4,6 +4,7 @@
# Apache 2.0

import argparse
import collections
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers import decoders
@@ -29,17 +30,41 @@ def get_args():


def generate_tokens(args):
''' Extract symbols and there corresponding ids from a tokenizer,
Collaborator: typo: the corresponding.
Contributor Author: fixed

and save as tokens.txt.
An example file looks like:
a 1
Collaborator: Does an ID start from 0 or is 0 reserved for a special token?

Contributor Author: Not yet. Right now index 0 is occupied by [unk]. The head of a real tokens.txt is:

[unk] 0
' 1
a 2
b 3
c 4
...
patty 9994
neatly 9995
stormy 9996
daddy 9997
##enon 9998
remarkably 9999

I will check whether there is a way to reserve index 0 with the huggingface tokenizer.

b 2
c 3
...
it 100
sh 101

'''

tokenizer = Tokenizer.from_file(args.tokenizer_path)
symbols = tokenizer.get_vocab()
tokens_file = '{}/tokens.txt'.format(args.lexicon_path)
tokens_f = open(tokens_file, 'w')
for idx, sym in enumerate(symbols):
tokens_f.write('{} {}\n'.format(sym.lower(), idx))
id2sym = dict((v, k.lower()) for k, v in symbols.items())
Collaborator:

    id2sym = {idx: sym.lower() for sym, idx in symbols.items()}

is much clearer.

Contributor Author: fixed

for idx in range(len(symbols)):
Collaborator: Is it required that the resulting file has its second column listed in increasing order?
Otherwise, it does not need to create another intermediate variable id2sym.
We can iterate over symbols directly.

Contributor Author: Just to ensure that the ids are contiguous. And an ordered token list looks nice.
The result is not sorted if we iterate over symbols directly; the output of

    for k, v in symbols.items():
        print(k.lower(), v)

looks like the following (quite disordered):

##ark 335
##umes 3822
vain 3593
eastern 4515
next 1372
knowing 4454
##jo 2789
western 3987
garden 1387
tree 1348

assert idx in id2sym
tokens_f.write('{} {}\n'.format(id2sym[idx], idx))

tokens_f.close()
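A minimal sketch (not part of this PR) touching on the two review questions above: reserving the lowest ids for special tokens and writing tokens.txt in increasing id order. It assumes the tokenizers package and a plain-text training file train.txt; whether the trainer really assigns the lowest ids to special_tokens in the order given should be verified against the tokenizers documentation.

from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordPieceTrainer

tokenizer = Tokenizer(WordPiece(unk_token='[UNK]'))
tokenizer.pre_tokenizer = Whitespace()

# Assumption: special tokens are added to the vocabulary first, in the order
# given, so '<eps>' would take id 0 and '[UNK]' id 1.
trainer = WordPieceTrainer(vocab_size=1000, special_tokens=['<eps>', '[UNK]'])
tokenizer.train(['train.txt'], trainer)

# Writing the vocab sorted by id keeps the second column of tokens.txt
# contiguous and increasing without an intermediate id2sym dict.
symbols = tokenizer.get_vocab()
with open('tokens.txt', 'w') as f:
    for sym, idx in sorted(symbols.items(), key=lambda kv: kv[1]):
        f.write('{} {}\n'.format(sym.lower(), idx))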


def generate_lexicon(args, words):
''' Tokenize every word in words.txt and save as lexicon.txt.
Each line represents a word and its tokenized representation, i.e. a sequence of tokens. A word and its tokens are separated by a tab.

An example file looks like:

abbreviating abb ##re ##via ##ting
abbreviation abb ##re ##via ##t ##ion
abbreviations abb ##re ##via ##t ##ions

'''
special_words = [
'<eps>', '!SIL', '<SPOKEN_NOISE>', '<UNK>', '<s>', '</s>', '#0'
]
@@ -48,7 +73,8 @@ def generate_lexicon(args, words):
tokenizer = Tokenizer.from_file(args.tokenizer_path)
tokenizer.decoder = decoders.WordPiece()
for word in words:
if word not in special_words:
if not (word.upper() in special_words or
word.lower() in special_words):
output = tokenizer.encode(word)
tokens = ' '.join(output.tokens)
else:
@@ -60,16 +86,11 @@ def load_words(args):
def load_words(args):
words = []
tokens_file = '{}/words.txt'.format(args.lexicon_path)
# special_words = [
# '<eps>', '!SIL', '<SPOKEN_NOISE>', '<UNK>', '<s>', '</s>', '#0'
# ]
# special_words = []

with open(tokens_file) as f:
for line in f:
arr = line.strip().split()
# if arr[0] not in special_words:
words.append(arr[0])
words.append(arr[0].lower())

return words

26 changes: 17 additions & 9 deletions egs/librispeech/asr/nnlm/local/huggingface_tokenizer.py
@@ -41,8 +41,8 @@
def train_tokenizer(train_files, save_path, vocab_size):
if os.path.exists(save_path):
logging.warning(
"{} already exists. Please check that.".format(save_path))
return
"{} already exists. Backing up that.".format(save_path))
shutil.move(save_path, '{}'.format(save_path))
else:
Path(os.path.dirname(save_path)).mkdir(parents=True, exist_ok=True)

@@ -52,34 +52,42 @@ def train_tokenizer(train_files, save_path, vocab_size):
tokenizer.pre_tokenizer = Whitespace()

# default vocab_size=30000
# here set vocab_size=1000 for accelerating
trainer = WordPieceTrainer(vocab_size=vocab_size, special_tokens=['[UNK]'])
tokenizer.train(train_files, trainer)
tokenizer.save(save_path)


def tokenize_text(test_file, tokenizer_path):
'''
tokenize text
input format looks like:
BOY IS BETTER UNBORN THAN
BRAVE OFFICER


output format looks like:
355 127 794 4824 346 370
1330 1898
'''
if not os.path.exists(tokenizer_path):
logging.warning(
"Tokenizer {} does not exist. Please check that.".format(
tokenizer_path))
logging.warning("Tokenizer {} does not exist.".format(tokenizer_path))
return
tokenizer = Tokenizer.from_file(tokenizer_path)
tokenizer.decoder = decoders.WordPiece()
tokenized_file = "{}.tokens".format(test_file)
# tokenized_ids = "{}.ids".format(test_file)
if os.path.exists(tokenized_file):
logging.warning(
"The input file seems already tokenized. Buckupping previous result"
)
shutil.copyfile(tokenized_file, "{}.bk".format(tokenized_file))
shutil.move(tokenized_file, "{}.bk".format(tokenized_file))
logging.warning("Tokenizing {}.".format(test_file))
fout = open(tokenized_file, 'w')
with open(test_file) as f:
for line in f:
line = line.strip()
output = tokenizer.encode(line)
fout.write(" ".join(output.tokens) + '\n')
if len(output.ids) > 0:
fout.write(' '.join([str(i) for i in output.ids]) + '\n')

fout.close()

34 changes: 27 additions & 7 deletions egs/librispeech/asr/nnlm/local/trainer.py
@@ -7,6 +7,8 @@
import math
import torch

from common import load_checkpoint, save_checkpoint


# references:
# https://github.com/Hiroshiba/pytorch-trainer/blob/master/pytorch_trainer/training/trainer.py
@@ -26,7 +28,9 @@ def __init__(self,
batch_size=1,
epoch=0,
num_epochs=10,
log_interval=10,
clip=0.25,
log_interval=100,
model_dir="exp-nnlm/models/",
writer=None):
self.device = device
self.model = model
@@ -41,6 +45,8 @@ def __init__(self,
self.iterations = 0
self.writer = writer
self.log_interval = log_interval
self.clip = clip
self.model_dir = model_dir

def run(self):
for epoch in range(self.num_epochs):
@@ -49,13 +55,17 @@

if self.dev_data_loader is not None:
self.eval()
save_checkpoint("{}/epoch_{}.pt".format(self.model_dir, epoch),
self.model)

self.epoch += 1

def train(self):
self.model.train()
total_loss = 0
num_total_batch = len(self.train_data_loader)
for batch_idx, batch in enumerate(self.train_data_loader):
self.optimizer.zero_grad()
batch_input, batch_target = batch
batch_input = batch_input.to(self.device)
batch_target = batch_target.to(self.device)
@@ -65,18 +75,28 @@
prediction = batch_output.view(-1, self.ntokens)
target = torch.flatten(batch_target.transpose(0, 1))
loss = self.criterion(prediction, target)
self.optimizer.zero_grad()
loss.backward()

torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
self.optimizer.step()

self.writer.add_scalar('train_loss', loss, self.iterations)

self.iterations += 1
if batch_idx % self.log_interval == 0:
total_loss += loss.item()
if batch_idx % self.log_interval == 0 and batch_idx > 0:
cur_loss = total_loss / self.log_interval
log_str = 'TRAIN Batch {}/{} loss {:.6f} ppl {:.6f} at epoch {}'.format(
batch_idx, num_total_batch, loss.item(),
math.exp(loss.item()), self.epoch)
batch_idx, num_total_batch, cur_loss, math.exp(cur_loss),
self.epoch)
logging.info(log_str)
total_loss = 0.0
if batch_idx % 10000 == 0 and batch_idx > 0:
save_checkpoint(
"./exp/nn-lm/models/epoch_{}-batch_{}.pt".format(
self.epoch, batch_idx), self.model)

@torch.no_grad()
def eval(self):
self.model.eval()
total_loss = 0.0
@@ -91,9 +111,9 @@ def eval(self):
prediction = batch_output.view(-1, self.ntokens)
target = torch.flatten(batch_target.transpose(0, 1))
loss = self.criterion(prediction, target)
total_loss += loss * self.batch_size
total_loss += loss

loss = total_loss / (num_total_batch * self.batch_size)
loss = total_loss / num_total_batch
ppl = math.exp(loss)
self.writer.add_scalar('dev_ppl', ppl, self.epoch)
log_str = 'dev loss is {:.6f} and ppl {:.6f} at epoch {}'.format(