WIP: huggingface tokenizer and Neural LM training pipeline. #139
base: master
Changes from 1 commit
@@ -0,0 +1,42 @@
#!/usr/bin/env python3

# Copyright (c) 2020 Xiaomi Corporation (author: Liyong Guo)
# Apache 2.0

# modified from https://github.com/k2-fsa/snowfall/blob/master/snowfall/common.py to save/load non-Acoustic Model
import logging
import os
import torch

from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

Pathlike = Union[str, Path]
Info = Union[dict, None]


def load_checkpoint(filename: Pathlike,
                    model: torch.nn.Module,
                    info: Info = None) -> Dict[str, Any]:
    logging.info('load checkpoint from {}'.format(filename))

    checkpoint = torch.load(filename, map_location='cpu')

    model.load_state_dict(checkpoint['state_dict'])

    return checkpoint


def save_checkpoint(filename: Pathlike,
                    model: torch.nn.Module,
                    info: Info = None) -> None:
    if not os.path.exists(os.path.dirname(filename)):
        Path(os.path.dirname(filename)).mkdir(parents=True, exist_ok=True)
    logging.info(f'Save checkpoint to {filename}')
    checkpoint = {
        'state_dict': model.state_dict(),
    }
    if info is not None:
        checkpoint.update(info)

    torch.save(checkpoint, filename)
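
For reference, a minimal usage sketch of the two helpers above; the toy model, the output path, and the extra info fields are hypothetical and only illustrate the call pattern:

import torch

# hypothetical tiny model and checkpoint path, using the helpers defined above
model = torch.nn.Linear(10, 10)
save_checkpoint('exp/nnlm/epoch-1.pt', model,
                info={'epoch': 1, 'learning_rate': 1e-3})

# later: restore the weights into a model of the same shape and
# read back the extra fields stored alongside 'state_dict'
checkpoint = load_checkpoint('exp/nnlm/epoch-1.pt', model)
print(checkpoint['epoch'], checkpoint['learning_rate'])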
@@ -3,10 +3,10 @@
# Copyright (c) 2020 Xiaomi Corporation (author: Liyong Guo)
# Apache 2.0

import time
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from typing import List
from util import convert_tokens_to_ids

import numpy as np
import os
@@ -37,35 +37,22 @@ def __call__(self, batch: List[List[int]]):

class LMDataset(Dataset):

    def __init__(self, text_file: str, lexicon):
    def __init__(self, text_file: str):

Review comment: Can you describe the format of ...
Reply: fixed

        '''Dataset to load Language Model train/dev text data

        Args:
            text_file: text file, text for one utt per line.
        '''
        self.lexicon = lexicon
        assert os.path.exists(
            text_file), "text_file: {} does not exist, please check that."
            text_file
        ), "text_file: {} does not exist, please check that.".format(text_file)
        self.data = []
        with open(text_file, 'r') as f:
            # a line represents a piece of text, e.g.
            # DELAWARE IS NOT AFRAID OF DOGS
            for line in f:
                # import pdb
                # pdb.set_trace()
                text = line.strip().lower().split()
                # print(text)
                if len(text) == 0:
                    continue
                word_id = convert_tokens_to_ids(text, self.lexicon.word2id)
                if len(word_id) == 0:
                    continue
                word_id = torch.from_numpy(np.array(word_id, dtype="int32"))

                token_id = self.lexicon.word_seq_to_word_piece_seq(word_id)
                # token_id format:
                # <bos_id> token_id token_id token_id *** <eos_id>
                if len(token_id) >= 2:
            for idx, line in enumerate(f):
                token_id = [int(i) for i in line.strip().split()]
                # TODO(Liyong Guo): add bos_id and eos_id to each piece of example
                # then each valid example should be longer than 2
                if len(token_id) > 2:
                    self.data.append(token_id)

    def __len__(self):
@@ -74,18 +61,9 @@ def __len__(self):
    def __getitem__(self, idx):
        return self.data[idx]

    def text2id(self, text: List[str]) -> List[int]:
        # A dumpy implementation
        return [i for i in range(len(text))]

    def text_id2token_id(self, text_id: List[int]) -> List[int]:
        # A dumpy implementation
        return [i for i in range(len(text_id))]


if __name__ == '__main__':
    # train_file = "./data/nnlm/text/librispeech.txt"
    dev_file = "./data/nnlm/text/dev.txt"
    dev_file = "./data/nnlm/text/dev.txt.tokens"
    dataset = LMDataset(dev_file)
    collate_func = CollateFunc()
    data_loader = DataLoader(dataset,
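
CollateFunc itself is not shown in full in this hunk. As a rough sketch only (not the author's implementation), a padding collate for such variable-length lists of token ids could look like the following; the pad id and batch size are assumptions, and each line of dev.txt.tokens is taken to be a space-separated list of integer token ids, as the __init__ above parses it:

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def pad_collate(batch, pad_id=0):
    # batch is a list of lists of token ids with different lengths;
    # turn each into a LongTensor and right-pad to the longest sequence.
    seqs = [torch.tensor(x, dtype=torch.long) for x in batch]
    return pad_sequence(seqs, batch_first=True, padding_value=pad_id)

# hypothetical usage with the dataset above:
# data_loader = DataLoader(dataset, batch_size=32, shuffle=True,
#                          collate_fn=pad_collate)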
@@ -4,6 +4,7 @@
# Apache 2.0

import argparse
import collections
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers import decoders
@@ -29,17 +30,41 @@ def get_args():


def generate_tokens(args):
    ''' Extract symbols and there corresponding ids from a tokenizer,

Review comment: typo: ...
Reply: fixed

    and save as tokens.txt.
    An example file looks like:
        a 1

Review comment: Does an ID start from 0 or is 0 reserved for a special token?
Reply: Not yet. Right now index 0 is occupied by [unk]. I will check whether
there is a way to reserve index 0 with the huggingface tokenizer.

        b 2
        c 3
        ...
        it 100
        sh 101

    '''

    tokenizer = Tokenizer.from_file(args.tokenizer_path)
    symbols = tokenizer.get_vocab()
    tokens_file = '{}/tokens.txt'.format(args.lexicon_path)
    tokens_f = open(tokens_file, 'w')
    for idx, sym in enumerate(symbols):
        tokens_f.write('{} {}\n'.format(sym.lower(), idx))
    id2sym = dict((v, k.lower()) for k, v in symbols.items())

Review comment: id2sym = {idx: sym.lower() for sym, idx in symbols.items()} is much clearer.
Reply: fixed

    for idx in range(len(symbols)):

Review comment: Is it required that the resulting file has its second column listed in increasing order?
Reply: Just to ensure that the ids are contiguous, and an ordered tokens.txt looks nicer;
iterating over get_vocab() directly produces a quite disordered file.

        assert idx in id2sym
        tokens_f.write('{} {}\n'.format(id2sym[idx], idx))

    tokens_f.close()

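Regarding the review question above about reserving id 0: with the huggingface tokenizers library, special tokens passed to the trainer are added to the vocabulary first and therefore receive the lowest ids. A minimal sketch, assuming a training corpus path, vocabulary size, and set of special symbols chosen only for illustration:

from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordPieceTrainer

# special tokens listed first get the first ids: <eps> -> 0, [UNK] -> 1, ...
tokenizer = Tokenizer(WordPiece(unk_token='[UNK]'))
tokenizer.pre_tokenizer = Whitespace()
trainer = WordPieceTrainer(vocab_size=1000,
                           special_tokens=['<eps>', '[UNK]', '<s>', '</s>'])
tokenizer.train(files=['data/nnlm/text/librispeech.txt'], trainer=trainer)
tokenizer.save('tokenizer.json')
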
def generate_lexicon(args, words):
    ''' Tokenize every word in words.txt and save as lexicon.txt.
    Each line represents a word and its tokenized representation, i.e. a
    sequence of tokens. The word and its tokens are separated by a tab.

    An example file looks like:

        abbreviating abb ##re ##via ##ting
        abbreviation abb ##re ##via ##t ##ion
        abbreviations abb ##re ##via ##t ##ions

    '''
    special_words = [
        '<eps>', '!SIL', '<SPOKEN_NOISE>', '<UNK>', '<s>', '</s>', '#0'
    ]
@@ -48,7 +73,8 @@ def generate_lexicon(args, words):
    tokenizer = Tokenizer.from_file(args.tokenizer_path)
    tokenizer.decoder = decoders.WordPiece()
    for word in words:
        if word not in special_words:
        if not (word.upper() in special_words or
                word.lower() in special_words):
            output = tokenizer.encode(word)
            tokens = ' '.join(output.tokens)
        else:
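
For context, a short sketch of what the encode call above returns; the tokenizer file path here is an assumption, and the exact pieces depend on the trained vocabulary:

from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file('data/nnlm/tokenizer-librispeech.json')  # hypothetical path
output = tokenizer.encode('abbreviating')
print(output.tokens)  # e.g. ['abb', '##re', '##via', '##ting'], as in the lexicon example
print(output.ids)     # the corresponding integer ids from the tokenizer vocabulary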
@@ -60,16 +86,11 @@ def load_words(args):
def load_words(args):
    words = []
    tokens_file = '{}/words.txt'.format(args.lexicon_path)
    # special_words = [
    #     '<eps>', '!SIL', '<SPOKEN_NOISE>', '<UNK>', '<s>', '</s>', '#0'
    # ]
    # special_words = []

    with open(tokens_file) as f:
        for line in f:
            arr = line.strip().split()
            # if arr[0] not in special_words:
            words.append(arr[0])
            words.append(arr[0].lower())

    return words
Review comment (on Info = Union[dict, None] in the checkpoint helper above):
This is equivalent to
    Info = Optional[dict]
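
A small note on that equivalence: Optional[X] is defined as Union[X, None], so the two aliases name the same type; for example:

from typing import Optional, Union

Info = Union[dict, None]
# the same type, spelled with Optional:
Info = Optional[dict]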