This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

WIP: huggingface tokenizer and Neural LM training pipeline. #139

Open · wants to merge 28 commits into base: master
Changes from all commits (28 commits)
f038e60  hugginface tokenizer and Neural LM training pipeline. (glynpu, Mar 25, 2021)
e9482d2  draft of class LMDataset (glynpu, Mar 29, 2021)
135bfdb  a dummy implementation of LMDataset (glynpu, Mar 29, 2021)
88e0d49  collate function of NNLM (glynpu, Mar 30, 2021)
27b1863  add scripts to process word piece lexicons. (csukuangfj, Mar 30, 2021)
212b79b  Merge pull request #2 from csukuangfj/fangjun-rnnlm (glynpu, Mar 30, 2021)
47bf358  trainer (glynpu, Mar 30, 2021)
d8aaabd  generate lexicon (glynpu, Mar 30, 2021)
c44f99d  check text length in dataset.py (glynpu, Mar 30, 2021)
b13954d  remove shuf/comm commands (glynpu, Mar 30, 2021)
775d477  beta version of training pipeline (glynpu, Mar 30, 2021)
3b83338  Merge pull request #1 from glynpu/lyg_dev (glynpu, Mar 30, 2021)
d415ed0  remove unused file (glynpu, Mar 30, 2021)
4937232  add dependency and fix known bugs (glynpu, Apr 1, 2021)
61863db  fix various bugs (glynpu, Apr 2, 2021)
d4dccae  compute word_ppl from token_ppl (glynpu, Apr 2, 2021)
a4d5f1b  add results.md (glynpu, Apr 3, 2021)
53e2d1e  compute word_ppl from token_ppl (glynpu, Apr 3, 2021)
b226a3a  support yaml configuration (glynpu, Apr 9, 2021)
89ece61  update results with nvocab=5000 (glynpu, Apr 9, 2021)
c3f8811  fix reviews (glynpu, Apr 9, 2021)
d1b803b  fixed reviews (glynpu, Apr 9, 2021)
c45d31f  support multi-gpu training with ddp (glynpu, Apr 10, 2021)
1d38c21  n-best rescoring result with 8-layer transformer lm (glynpu, Apr 14, 2021)
f6914cd  Merge remote-tracking branch 'dan/master' into nnlm (glynpu, Apr 20, 2021)
d847b28  filter train data by length to increase batch_size (glynpu, Apr 20, 2021)
52300df  use Noam optimizer (glynpu, Apr 20, 2021)
e61a9d1  add rescore scripts (glynpu, Apr 20, 2021)
14 changes: 14 additions & 0 deletions .flake8
@@ -0,0 +1,14 @@
[flake8]
show-source=true
statistics=true
max-line-length=80
exclude =
  .git,

ignore =
  # E127 continuation line over-indented for visual indent
  E127,
  # F401, imported but unused
  F401,
  # W504, line break after binary operator
  W504,
Empty file.
72 changes: 72 additions & 0 deletions egs/librispeech/asr/nnlm/compute_word_ppl.py
@@ -0,0 +1,72 @@
#!/usr/bin/env python3

# Copyright (c) 2020 Xiaomi Corporation (author: Liyong Guo)
# Apache 2.0

# Reference:
# https://github.com/espnet/espnet/blob/master/espnet/lm/pytorch_backend/lm.py
# https://github.com/mobvoi/wenet/blob/main/wenet/bin/train.py
import argparse

import logging
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import sys
import yaml

sys.path.insert(0, './local/')

from common import load_checkpoint
from evaluator import Evaluator
# from model import TransformerModel
from pathlib import Path
from typing import List, Dict


def get_args():
    parser = argparse.ArgumentParser(
        description='compute token/word perplexity of a text file')
    parser.add_argument('--config',
                        help='config file',
                        default='conf/lm_small_transformer.yaml')
    parser.add_argument('--vocab_size', type=int, default=5000)
    parser.add_argument('--model',
                        type=str,
                        default='exp-nnlm/models/epoch_30.pt',
                        help='full path of the model to load')
    parser.add_argument('--tokenizer_path',
                        type=str,
                        default='exp-nnlm/tokenizer-librispeech.json')
    parser.add_argument('--txt_file',
                        type=str,
                        default='data/nnlm/text/dev.txt')

    args = parser.parse_args()

    return args


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    # Set random seed
    torch.manual_seed(2021)

    # device = torch.device("cuda", args.local_rank)
    device = torch.device('cpu')
    logging.info(f'Using device: {device}')

    evaluator = Evaluator(device=device,
                          model_path=args.model,
                          config_file=args.config,
                          tokenizer_path=args.tokenizer_path)
    evaluator.compute_ppl(txt_file=args.txt_file)


if __name__ == '__main__':
    main()
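The Evaluator class is defined elsewhere in this PR, so the actual perplexity computation is not visible in this diff. As a rough illustration of what the "compute word_ppl from token_ppl" commits refer to: both perplexities are exponentials of the same total negative log-likelihood, only normalized by different counts, so one converts to the other using just the token and word counts. The helper name below is hypothetical:

def word_ppl_from_token_ppl(token_ppl: float, num_tokens: int,
                            num_words: int) -> float:
    # token_ppl = exp(total_nll / num_tokens)
    # word_ppl  = exp(total_nll / num_words)
    # hence word_ppl = token_ppl ** (num_tokens / num_words)
    return token_ppl ** (num_tokens / num_words)


# e.g. 1.2M BPE tokens covering 1.0M words at token_ppl = 50
# gives word_ppl = 50 ** 1.2, roughly 109
print(word_ppl_from_token_ppl(50.0, 1_200_000, 1_000_000))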
45 changes: 45 additions & 0 deletions egs/librispeech/asr/nnlm/conf/lm_small_transformer.yaml
@@ -0,0 +1,45 @@

gpu: 1
tensorboard_dir: 'exp-nnlm/tensorobard'

# network architecture equivalent configuration to
# https://github.com/pytorch/examples/blob/master/word_language_model/main.py
model_module: transformer
transformer_conf:
  embed_unit: 200
  attention_heads: 8
  nlayers: 16
  linear_units: 2048
  dropout: 0.2

shared_conf:
  ntoken: 5003

# Now using Noam optimizer and tuning configuration
# optimizer_conf:
#   # for Adam
#   lr: 0.0003
#   weight_decay: 0.001
#   # for SGD
#   # lr: 0.01
#   # weight_decay: 0.001

trainer_conf:
  num_epochs: 60
  clip: 0.25
  model_dir: './exp-nnlm/models/'


dataset_conf:
  train_token: 'data/nnlm/text/librispeech.txt.tokens'
  dev_token: 'data/nnlm/text/dev.txt.tokens'

dataloader_conf:
  train:
    batch_size: 256
    num_workers: 0
    drop_last: True
  dev:
    batch_size: 20
    num_workers: 0
    drop_last: False
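The commented-out optimizer_conf above reflects the switch to the Noam optimizer (commit 52300df, "use Noam optimizer"). The training script is not part of this file, so the following is only a conventional sketch of that warmup schedule wrapped around Adam, using embed_unit as the model dimension; none of these names come from the PR:

import torch


def noam_lr(step: int, d_model: int = 200, warmup: int = 4000,
            factor: float = 1.0) -> float:
    # lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    step = max(step, 1)
    return factor * d_model ** -0.5 * min(step ** -0.5,
                                          step * warmup ** -1.5)


model = torch.nn.Linear(200, 200)  # stand-in for the transformer LM
optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9)
for step in range(1, 11):
    for group in optimizer.param_groups:
        group['lr'] = noam_lr(step)
    # ... forward/backward would go here ...
    optimizer.step()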
41 changes: 41 additions & 0 deletions egs/librispeech/asr/nnlm/conf/lm_transformer.yaml
@@ -0,0 +1,41 @@
# modified from:
# https://github.com/espnet/espnet/blob/master/egs/librispeech/asr1/conf/tuning/lm_transformer.yaml

gpu: 1
tensorboard_dir: 'exp-nnlm/tensorobard'

# network architecture
model_module: transformer
transformer_conf:
  embed_unit: 128
  attention_heads: 8
  nlayers: 16
  linear_units: 2048
  dropout: 0.2

shared_conf:
  ntoken: 5003

optimizer_conf:
  lr: 0.02
  weight_decay: 0.005

trainer_conf:
  num_epochs: 50
  clip: 0.25
  model_dir: './exp-nnlm/models/'


dataset_conf:
  train_token: 'data/nnlm/text/librispeech.txt.tokens'
  dev_token: 'data/nnlm/text/dev.txt.tokens'

dataloader_conf:
  train:
    batch_size: 60
    num_workers: 10
    drop_last: True
  dev:
    batch_size: 60
    num_workers: 10
    drop_last: False
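Neither config is read anywhere in this diff. A minimal sketch of how the dataset_conf and dataloader_conf sections could be consumed, assuming LMDataset and CollateFunc from local/dataset.py and that ./local is importable; this is not the PR's train script:

import yaml
from torch.utils.data import DataLoader

from dataset import LMDataset, CollateFunc

with open('conf/lm_transformer.yaml') as f:
    cfg = yaml.safe_load(f)

ntoken = cfg['shared_conf']['ntoken']
train_set = LMDataset(cfg['dataset_conf']['train_token'], ntoken)
collate_fn = CollateFunc(pad_index=ntoken - 1)
train_loader = DataLoader(train_set,
                          collate_fn=collate_fn,
                          **cfg['dataloader_conf']['train'])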
42 changes: 42 additions & 0 deletions egs/librispeech/asr/nnlm/local/common.py
@@ -0,0 +1,42 @@
#!/usr/bin/env python3

# Copyright (c) 2020 Xiaomi Corporation (author: Liyong Guo)
# Apache 2.0

# modified from https://github.com/k2-fsa/snowfall/blob/master/snowfall/common.py
# to save/load a non-acoustic model
import logging
import os
import torch

from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

Pathlike = Union[str, Path]
Info = Optional[dict]


def load_checkpoint(filename: Pathlike,
                    model: torch.nn.Module,
                    info: Info = None) -> Dict[str, Any]:
    logging.info('load checkpoint from {}'.format(filename))

    checkpoint = torch.load(filename, map_location='cpu')

    model.load_state_dict(checkpoint['state_dict'])

    return checkpoint


def save_checkpoint(filename: Pathlike,
                    model: torch.nn.Module,
                    info: Info = None) -> None:
    if not os.path.exists(os.path.dirname(filename)):
        Path(os.path.dirname(filename)).mkdir(parents=True, exist_ok=True)
    logging.info(f'Save checkpoint to {filename}')
    # model.module assumes the model is wrapped, e.g. in DistributedDataParallel
    checkpoint = {
        'state_dict': model.module.state_dict(),
    }
    if info is not None:
        checkpoint.update(info)

    torch.save(checkpoint, filename)
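A usage sketch of the two helpers above: save_checkpoint() reads model.module.state_dict(), so it expects a wrapped model (DDP in the actual training script; DataParallel is used here only so the example runs in a single process), while load_checkpoint() loads into a bare module. The demo path and the Linear stand-in are arbitrary:

import torch

from common import load_checkpoint, save_checkpoint

model = torch.nn.Linear(8, 8)            # stand-in for the transformer LM
wrapped = torch.nn.DataParallel(model)   # provides the .module attribute

save_checkpoint('exp-nnlm/models/demo.pt', wrapped, info={'epoch': 0})

fresh = torch.nn.Linear(8, 8)
checkpoint = load_checkpoint('exp-nnlm/models/demo.pt', fresh)
print(checkpoint['epoch'])               # extra info is returned as well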
101 changes: 101 additions & 0 deletions egs/librispeech/asr/nnlm/local/dataset.py
@@ -0,0 +1,101 @@
#!/usr/bin/env python3

# Copyright (c) 2020 Xiaomi Corporation (author: Liyong Guo)
# Apache 2.0

import time
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from typing import List

import numpy as np
import os
import torch


class CollateFunc(object):
    '''Collate function for LMDataset
    '''

    def __init__(self, pad_index=None):
        # pad_index should be identical to ignore_index of torch.nn.NLLLoss
        # and padding_idx in torch.nn.Embedding
        self.pad_index = pad_index

    def __call__(self, batch: List[List[int]]):
        '''batch contains token ids.
        batch can be viewed as a ragged 2-d array; each row represents
        a tokenized text, whose format is:
            <bos_id> token_id token_id token_id *** <eos_id>
        '''
        # data_pad: [batch_size, max_seq_len]
        # rows have different lengths, so they are padded with pad_index
        # up to max_seq_len, the maximum length in the current batch
        data_pad = pad_sequence(
            [torch.from_numpy(np.array(x)).long() for x in batch], True,
            self.pad_index)
        data_pad = data_pad.t().contiguous()
        # xs_pad, ys_pad: [max_seq_len - 1, batch_size]
        # ys_pad is xs_pad shifted by one position (next-token targets)
        xs_pad = data_pad[:-1, :]
        ys_pad = data_pad[1:, :]
        return xs_pad, ys_pad


class LMDataset(Dataset):

    def __init__(self, token_file: str, ntoken: int):
        '''Dataset to load Language Model train/dev text data

        Args:
            token_file: each line is a tokenized text and looks like:
                token_id token_id *** token_id token_id

                A real example is:

                485 135 974 255 1220 33 35 377
                2130 1960

            ntoken: vocabulary size, including <bos>/<eos>/<pad>.

        When a line is loaded, <bos_id>/<eos_id> are added to compose
        the input/target pair.
        '''
        self.bos_id = ntoken - 3
        self.eos_id = ntoken - 2
        self.pad_index = ntoken - 1
        assert os.path.exists(
            token_file
        ), "token_file: {} does not exist, please check that.".format(
            token_file)
        self.data = []
        with open(token_file, 'r') as f:
            for line in f:
                token_id = [int(i) for i in line.strip().split()]
                # Empty lines exist in librispeech.txt. Disregard them.
                if len(token_id) == 0:
                    continue
                # https://github.com/espnet/espnet/blob/master/espnet/lm/lm_utils.py#L179
                # add bos_id and eos_id to each piece of example
                token_id.insert(0, self.bos_id)
                token_id.append(self.eos_id)
                self.data.append(token_id)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


if __name__ == '__main__':
    dev_file = "./data/nnlm/text/dev.txt.tokens"
    # ntoken matches shared_conf.ntoken in conf/*.yaml
    ntoken = 5003
    dataset = LMDataset(dev_file, ntoken)
    collate_func = CollateFunc(pad_index=ntoken - 1)
    data_loader = DataLoader(dataset,
                             batch_size=2,
                             shuffle=True,
                             num_workers=0,
                             collate_fn=collate_func)
    for i, batch in enumerate(data_loader):
        xs, ys = batch
        print(xs)
        print(ys)
        print(batch)
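The pad_index comment in CollateFunc is the key contract: the same id must serve as padding_idx in the embedding and as ignore_index in the loss, so padded positions contribute nothing to the gradient. A minimal sketch of that wiring, not the PR's trainer, with ntoken = 5003 as in shared_conf (so <bos> = 5000, <eos> = 5001, <pad> = 5002) and a Linear layer standing in for the transformer LM:

import torch
import torch.nn as nn

ntoken = 5003
pad_index = ntoken - 1

embedding = nn.Embedding(ntoken, 200, padding_idx=pad_index)
projection = nn.Linear(200, ntoken)
criterion = nn.NLLLoss(ignore_index=pad_index)

# A fake batch shaped like CollateFunc output, [max_seq_len - 1, batch_size]:
# sentence 1 is <bos> 10 11 <eos>, sentence 2 is <bos> 20 <eos> plus padding.
xs_pad = torch.tensor([[5000, 5000], [10, 20], [11, 5001]])
ys_pad = torch.tensor([[10, 20], [11, 5001], [5001, 5002]])

log_probs = torch.log_softmax(projection(embedding(xs_pad)), dim=-1)
# Target positions equal to pad_index (5002) are ignored by the loss.
loss = criterion(log_probs.view(-1, ntoken), ys_pad.view(-1))
print(loss.item())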