This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

WIP: Compute expected times per pathphone_idx. #106

Open
wants to merge 37 commits into base: master

Commits (37)
36c56ee
WIP: Compute expected times per pathphone_idx.
csukuangfj Feb 22, 2021
fa0a0cd
add normalization.
csukuangfj Feb 23, 2021
f6d3dcb
add attribute `phones` to `den_lats`.
csukuangfj Feb 23, 2021
aeaee32
Replace expected times for self-loops with average times of neighboring
csukuangfj Feb 23, 2021
3fc6ef2
print total_occupation per pathphone_idx
csukuangfj Feb 23, 2021
2c98f3d
print some debug output.
csukuangfj Feb 23, 2021
44e77f0
add a toy example to display the `phones` attribute.
csukuangfj Feb 23, 2021
6800f60
add decoding_graphs.phones
csukuangfj Feb 23, 2021
f50e4f4
add test scripts.
csukuangfj Feb 23, 2021
6eba4ac
Initialize P.scores randomly.
csukuangfj Feb 23, 2021
2feca7c
Compute embeddings from expected times.
csukuangfj Feb 24, 2021
b3b9fc7
avoid division by zero.
csukuangfj Feb 24, 2021
35bd60f
append phone indexes and expected times to the embeddings.
csukuangfj Feb 25, 2021
6475200
Fix a bug in computing expected times for epsilon self-loops.
csukuangfj Mar 3, 2021
838f8ae
Add TDNN model for the second pass.
csukuangfj Mar 3, 2021
218bf8e
pad embeddings of different paths to the same length.
csukuangfj Mar 3, 2021
3b855ef
Also return `path_to_seq_map` in `compute_embeddings`.
csukuangfj Mar 3, 2021
f6d718e
add training script for embeddings (not working now)
csukuangfj Mar 3, 2021
fdfa412
Finish the neural network part for the second pass.
csukuangfj Mar 4, 2021
1d2c5d6
compute tot_scores for the second pass.
csukuangfj Mar 5, 2021
5412d32
Finish computing total scores from 1st and 2nd pass.
csukuangfj Mar 5, 2021
50cf49b
add decoding script.
csukuangfj Mar 5, 2021
efc542b
Merge remote-tracking branch 'dan/master' into expected-times
csukuangfj Mar 8, 2021
0155dc4
Support saving to/loading from checkpoints for the second pass model.
csukuangfj Mar 8, 2021
da2a80d
Visualize first & second pass obj separately.
csukuangfj Mar 8, 2021
04100f0
disable sorting in the decode script.
csukuangfj Mar 8, 2021
12ec856
refactoring.
csukuangfj Mar 8, 2021
8caa8ba
Support decoding with the second pass model.
csukuangfj Mar 9, 2021
231ffdb
add more comments to the second pass training code after review.
csukuangfj Mar 9, 2021
741a448
add an extra layer to the first pass model for computing embeddings.
csukuangfj Mar 9, 2021
83a5c89
Place the extra layer before LSTMs in the first pass model.
csukuangfj Mar 11, 2021
2648b42
Use the second pass model for rescoring.
csukuangfj Mar 16, 2021
8e47582
Support `num_repeats` in rescoring.
csukuangfj Mar 17, 2021
34f202a
top_sort word_lats before invoking get_tot_scores.
csukuangfj Mar 18, 2021
b1978f3
Rescore with posteriors using the 2nd-pass lattice.
csukuangfj Mar 19, 2021
9a514b5
print the log-probs of the reference input phones.
csukuangfj Mar 19, 2021
cda1c92
Replace expected time with duration and remove EOS in embeddings.
csukuangfj Mar 23, 2021
407 changes: 407 additions & 0 deletions egs/librispeech/asr/simple_v1/mmi_bigram_embeddings_decode.py

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions egs/librispeech/asr/simple_v1/mmi_bigram_embeddings_decode.sh
@@ -0,0 +1,9 @@
#!/usr/bin/env bash

set -xe

for epoch in $(seq 0 9); do
python3 ./mmi_bigram_embeddings_decode.py --epoch ${epoch} --enable_second_pass_decoding 0

python3 ./mmi_bigram_embeddings_decode.py --epoch ${epoch} --enable_second_pass_decoding 1
done
770 changes: 770 additions & 0 deletions egs/librispeech/asr/simple_v1/mmi_bigram_embeddings_train.py

Large diffs are not rendered by default.

28 changes: 20 additions & 8 deletions egs/librispeech/asr/simple_v1/mmi_bigram_train.py
@@ -226,7 +226,6 @@ def train_one_epoch(dataloader: torch.utils.data.DataLoader,
P.set_scores_stochastic_(model.module.P_scores)
else:
P.set_scores_stochastic_(model.P_scores)
assert P.is_cpu
assert P.requires_grad is True

curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
@@ -311,6 +310,13 @@ def main():
setup_dist(rank=args.local_rank, world_size=args.world_size)
fix_random_seed(42)

if not torch.cuda.is_available():
logging.error('No GPU detected!')
sys.exit(-1)

device_id = args.local_rank
device = torch.device('cuda', device_id)

start_epoch = 0
num_epochs = 10
use_adam = True
@@ -336,7 +342,8 @@ def main():
graph_compiler = MmiTrainingGraphCompiler(
L_inv=L_inv,
phones=phone_symbol_table,
words=word_symbol_table
words=word_symbol_table,
device=device
)
phone_ids = get_phone_symbols(phone_symbol_table)
P = create_bigram_phone_lm(phone_ids)
@@ -401,19 +408,24 @@ def main():
num_workers=1
)

if not torch.cuda.is_available():
logging.error('No GPU detected!')
sys.exit(-1)

logging.info("About to create model")
device_id = args.local_rank
device = torch.device('cuda', device_id)
model = TdnnLstm1b(num_features=40,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
subsampling_factor=3)
model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

model.to(device)
P = P.to(device)
if args.world_size > 1:
logging.info('Using DistributedDataParallel in training. '
'The reported loss, num_frames, etc. for training steps include '
'only the batches seen in the master process (the actual loss '
'includes batches from all GPUs, and the actual num_frames is '
f'approx. {args.world_size}x larger.')
# For now do not sync BatchNorm across GPUs due to NCCL hanging in all_gather...
# model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)

describe(model)

if use_adam:
1 change: 1 addition & 0 deletions egs/librispeech/asr/simple_v1/prepare.py
@@ -2,6 +2,7 @@

# Copyright (c) 2020 Xiaomi Corporation (authors: Junbo Zhang, Haowen Qiu)
# Apache 2.0
import argparse
import os
import subprocess
import sys
8 changes: 4 additions & 4 deletions snowfall/common.py
@@ -21,7 +21,7 @@
def setup_logger(log_filename: Pathlike, log_level: str = 'info', use_console: bool = True) -> None:
now = datetime.now()
date_time = now.strftime('%Y-%m-%d-%H-%M-%S')
log_filename = '{}-{}'.format(log_filename, date_time)
log_filename = '{}-{}.txt'.format(log_filename, date_time)
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
formatter = '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s'
level = logging.ERROR
@@ -256,9 +256,9 @@ def str2bool(v):
raise argparse.ArgumentTypeError('Boolean value expected.')


def describe(model: torch.nn.Module):
def describe(model: torch.nn.Module, title: str = ''):
logging.info('=' * 80)
logging.info('Model parameters summary:')
logging.info(f'{title} Model parameters summary:')
logging.info('=' * 80)
total = 0
for name, param in model.named_parameters():
@@ -303,7 +303,7 @@ def get_texts(best_paths: k2.Fsa, indices: Optional[torch.Tensor] = None) -> Lis


def invert_permutation(indices: torch.Tensor) -> torch.Tensor:
ans = torch.zeros(indices.shape, device=indices.device, dtype=torch.long)
ans = torch.empty_like(indices, dtype=torch.long)
ans[indices] = torch.arange(0, indices.shape[0], device=indices.device)
return ans

258 changes: 258 additions & 0 deletions snowfall/decoding/rescore.py
@@ -0,0 +1,258 @@
# Copyright (c) 2021 Xiaomi Corp. (author: Fangjun Kuang)

import logging
from typing import Tuple
import k2
import k2.fsa_properties as fsa_properties
import torch

from snowfall.common import invert_permutation
from snowfall.training.compute_embeddings import compute_embeddings_from_phone_seqs
from snowfall.training.compute_embeddings import create_phone_fsas
from snowfall.training.compute_embeddings import generate_nbest_list_phone_seqs
from snowfall.decoding.util import get_log_probs


def get_paths(lats: k2.Fsa, num_paths: int,
use_double_scores: bool = True) -> k2.RaggedInt:
'''Return an n-best list **sampled** from the given lattice.

Args:
lats:
An FsaVec, e.g., the decoding output from the 1st pass.
num_paths:
It is the `n` in `n-best`.
use_double_scores:
True to use double precision in :func:`k2.random_paths`;
False to use single precision.
Returns:
A ragged tensor with 3 axes: [seq][path][arc_pos] .
'''
assert len(lats.shape) == 3

# paths will be k2.RaggedInt with 3 axes: [seq][path][arc_pos],
# containing arc_idx012
paths = k2.random_paths(lats,
use_double_scores=use_double_scores,
num_paths=num_paths)
return paths
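# Illustrative usage (a sketch; `lats` here stands for the 1st-pass FsaVec):
#   paths = get_paths(lats, num_paths=100)
#   assert paths.num_axes() == 3  # [seq][path][arc_pos], holding arc_idx012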


def get_word_fsas(lats: k2.Fsa, paths: k2.RaggedInt) -> Tuple[k2.Fsa, k2.RaggedShape]:
'''
Args:
lats:
An FsaVec, e.g., from the 1st pass decoding.
paths:
Return value of :func:`get_paths`.
Returns:
A tuple (word_fsas_with_epsilons, seq_to_path_shape), where
word_fsas_with_epsilons are the word FSAs of the paths with epsilon
self-loops added.
'''
assert len(lats.shape) == 3
assert hasattr(lats, 'aux_labels')

# word_seqs will be k2.RaggedInt like paths, but containing words
# (and final -1's, and 0's for epsilon)
word_seqs = k2.index(lats.aux_labels, paths)

# Remove epsilons and -1 from `word_seqs`
word_seqs = k2.ragged.remove_values_leq(word_seqs, 0)

seq_to_path_shape = k2.ragged.get_layer(word_seqs.shape(), 0)
path_to_seq_map = seq_to_path_shape.row_ids(1)

word_seqs = k2.ragged.remove_axis(word_seqs, 0)

word_fsas = k2.linear_fsa(word_seqs)

word_fsas_with_epsilons = k2.add_epsilon_self_loops(word_fsas)
return word_fsas_with_epsilons, seq_to_path_shape


@torch.no_grad()
def rescore(lats: k2.Fsa,
paths: k2.RaggedInt,
word_fsas: k2.Fsa,
tot_scores_1st: torch.Tensor,
seq_to_path_shape: k2.RaggedShape,
ctc_topo: k2.Fsa,
decoding_graph: k2.Fsa,
dense_fsa_vec: k2.DenseFsaVec,
second_pass_model: torch.nn.Module,
max_phone_id: int,
use_double_scores: bool = True):
'''
Args:
lats:
Lattice from the 1st pass decoding with indexes [seq][state][arc].
paths:
An FsaVec returned by :func:`get_paths`.
word_fsas:
An FsaVec returned by :func:`get_word_fsas`.
tot_scores_1st:
Total scores of the paths from the 1st pass.
ctc_topo:
The return value of :func:`build_ctc_topo`.
decoding_graph:
An Fsa.
dense_fsa_vec:
It contains output from the first pass for computing embeddings.
Note that the output is not processed by log-softmax.
second_pass_model:
Model of the second pass.
max_phone_id:
The largest phone ID; used when computing embeddings.
use_double_scores:
True to use double precision in :func:`k2.Fsa.get_tot_scores`;
False to use single precision.
Returns:
Return the best_paths of each seq after rescoring.
'''
device = lats.device
assert hasattr(lats, 'phones')
assert paths.num_axes() == 3

# phone_seqs will be k2.RaggedInt like paths, but containing phones
# (and final -1's, and 0's for epsilon)
phone_seqs = k2.index(lats.phones, paths)

# Remove epsilons from `phone_seqs`
phone_seqs = k2.ragged.remove_values_eq(phone_seqs, 0)

# padded_embeddings is a 3-D tensor with shape (N, T, C)
#
# len_per_path is a 1-D tensor with shape (N,)
# len_per_path.shape[0] == N
# 0 < len_per_path[i] <= T
#
# path_to_seq is a 1-D tensor with shape (N,)
# path_to_seq.shape[0] == N
# 0 <= path_to_seq[i] < num_seqs
#
# num_repeats is a k2.RaggedInt with two axes [seq][path_multiplicities]
#
# CAUTION: Paths within a seq are reordered due to `k2.ragged.unique_sequences`.
padded_embeddings, len_per_path, path_to_seq, num_repeats, new2old = compute_embeddings_from_phone_seqs(
phone_seqs=phone_seqs,
ctc_topo=ctc_topo,
dense_fsa_vec=dense_fsa_vec,
max_phone_id=max_phone_id)
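# new2old is a 1-D tensor that maps each of the deduplicated paths back to
# its index among the originally sampled paths; it is used below to reorder
# `word_fsas` and to scatter the 2nd-pass scores into `tot_scores`.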

# padded_embeddings is of shape [num_paths, max_phone_seq_len, num_features]
# i.e., [N, T, C]
padded_embeddings = padded_embeddings.permute(0, 2, 1)
# now padded_embeddings is [N, C, T]

second_pass_out = second_pass_model(padded_embeddings)

# second_pass_out is of shape [N, C, T]
second_pass_out = second_pass_out.permute(0, 2, 1)
# now second_pass_out is of shape [N, T, C]

if True:
phone_seqs, _, _ = k2.ragged.unique_sequences(phone_seqs, True, True)
phone_seqs = k2.ragged.remove_axis(phone_seqs, 0)
phone_fsas = create_phone_fsas(phone_seqs)
phone_fsas = k2.add_epsilon_self_loops(phone_fsas)

probs = get_log_probs(phone_fsas, second_pass_out, len_per_path)
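
# Each row of second_pass_supervision_segments is
# (fsa_index, start_frame, num_frames), the layout expected by k2.DenseFsaVec;
# every path starts at frame 0 and spans len_per_path frames.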

second_pass_supervision_segments = torch.stack(
(torch.arange(len_per_path.numel(), dtype=torch.int32),
torch.zeros_like(len_per_path), len_per_path),
dim=1)

indices2 = torch.argsort(len_per_path, descending=True)
second_pass_supervision_segments = second_pass_supervision_segments[
indices2]
# Note that path_to_seq is not changed!
# No need to modify second_pass_out
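
# Turn the repeat counts from k2.ragged.unique_sequences into per-path
# weights; normalize_scores with use_log=False makes the weights of the
# unique paths within each seq sum to 1.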

num_repeats_float = k2.ragged.RaggedFloat(
num_repeats.shape(),
num_repeats.values().to(torch.float32))
path_weight = k2.ragged.normalize_scores(num_repeats_float,
use_log=False).values

second_pass_dense_fsa_vec = k2.DenseFsaVec(
second_pass_out, second_pass_supervision_segments)
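
# Pruned intersection with the decoding graph. The positional arguments,
# in k2.intersect_dense_pruned order, are search_beam=20.0, output_beam=10.0,
# min_active_states=300 and max_active_states=10000.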

second_pass_lattices = k2.intersect_dense_pruned(
decoding_graph, second_pass_dense_fsa_vec, 20.0, 10.0, 300, 10000)

# The number of FSAs in the second_pass_lattices may not
# be equal to the number of paths since repeated paths are removed
# by k2.ragged.unique_sequences

inverted_indices2 = invert_permutation(indices2)

second_pass_lattices = k2.index(
second_pass_lattices,
inverted_indices2.to(torch.int32).to(device))
# now second_pass_lattices corresponds to the reordered paths
# (due to k2.ragged.unique_sequences)

if True:
reordered_word_fsas = k2.index(word_fsas, new2old)

reordered_lats = k2.compose(second_pass_lattices,
reordered_word_fsas,
treat_epsilons_specially=False)

if reordered_lats.properties & fsa_properties.TOPSORTED_AND_ACYCLIC != fsa_properties.TOPSORTED_AND_ACYCLIC:
reordered_lats = k2.top_sort(k2.connect(
reordered_lats.to('cpu'))).to(device)

# note that some entries in `tot_scores_2nd_num` are -inf !!!
tot_scores_2nd_num = reordered_lats.get_tot_scores(
use_double_scores=True, log_semiring=True)

for k in [0, 1, 2, 30, 40, 50]:
Comment from csukuangfj (collaborator, PR author):

The output log of this for loop is attached below:

log-decode-second-2021-03-19-21-25-09.txt

Part of it is listed as follows:

2021-03-19 21:25:30,611 INFO [rescore.py:209] 
path: 0
tot_scores: -inf
log_probs:[ [ -86.4828 -2.01945 -11.1086 -133.909 -8.51447 -4.66334  .....  -6.37545 -2.73227 -45.2354 -5.78217 ] ]

2021-03-19 21:25:30,612 INFO [rescore.py:209] 
path: 1
tot_scores: -inf
log_probs:[ [ -87.2116 -1.99367 -8.34123 -134.502 -8.31313 ....   -2.12719 -6.3754 -2.41306 -27.2158 -5.17843 ] ]

2021-03-19 21:25:30,612 INFO [rescore.py:209] 
path: 2
tot_scores: -inf
log_probs:[ [ -87.5116 -2.04247 -8.2759 -134.616 -8.62277 ..... -5.99715 -2.10805 -5.69009 -1.49818 -44.5129 -5.55875 ] ]

2021-03-19 21:25:30,613 INFO [rescore.py:209] 
path: 30
tot_scores: -358.8836602834303
log_probs:[ [ -166.046 -3.19678 -5.97864 -219.513 -6.48385  ...  -1.07081 -4.60576 -2.2935 -0.218605 -5.64143 ] ]

2021-03-19 21:25:30,613 INFO [rescore.py:209] 
path: 40
tot_scores: -401.60242305657346
log_probs:[ [ -112.147 -2.33504 -3824.3 -177.408 -9.14423 -3678.67 -6.93939 .... -2.56483 -1.64079 -4.76889 ] ]

2021-03-19 21:25:30,613 INFO [rescore.py:209] 
path: 50
tot_scores: -398.0919782588134
log_probs:[ [ -138.47 -3.37409 -5.20867 -219.942 -7.50989 -1.4931  ....   -5.48529 -2.04127 -2.47733 -4.61414 ] ]

I am not sure whether these log-probs look reasonable.

Reply from a contributor:

The numbers I was expecting would be quite close to zero, and even closer at odd positions (or maybe even ones, i.e. where there are epsilons). I.e., I mean the posterior of the "reference phone" at each position (it's not really the reference, it's the sequence we use for alignment).

Follow-up from the contributor:

I was expecting you'd get it by indexing second_pass_dense_fsa_vec with some kind of tensor that's related to the reference phones. Or perhaps you could just take the sum over a particular axis of (second_pass_dense_fsa_vec * phone_one_hot_input).
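A minimal sketch of one way to read that suggestion (not part of this PR), using `second_pass_out`, the (N, T, C) tensor that feeds second_pass_dense_fsa_vec, before log-softmax; `ref_phones` is a hypothetical (N, T) int64 tensor holding the alignment phone index at each position, padded with 0 beyond `len_per_path`:

import torch

# Per-position log-posterior of the reference (alignment) phone.
log_post = torch.nn.functional.log_softmax(second_pass_out, dim=-1)        # (N, T, C)
one_hot = torch.nn.functional.one_hot(ref_phones, num_classes=log_post.shape[-1])
ref_log_probs = (log_post * one_hot.to(log_post.dtype)).sum(dim=-1)        # (N, T)
# Zero out the padding positions beyond each path's length.
mask = torch.arange(ref_log_probs.shape[1], device=ref_log_probs.device)[None, :] \
    < len_per_path[:, None]
ref_log_probs = ref_log_probs * mask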

pk, _ = k2.ragged.index(probs, torch.tensor([k],
dtype=torch.int32))
assert pk.num_elements() == len_per_path[k]
logging.info(
f'\npath: {k}\ntot_scores: {tot_scores_2nd_num[k]}\nlog_probs:{str(pk)}'
)

tot_scores_2nd_den = second_pass_lattices.get_tot_scores(
log_semiring=True, use_double_scores=use_double_scores)

tot_scores_2nd = tot_scores_2nd_num - tot_scores_2nd_den

# print(
# 'word',
# reordered_word_fsas.arcs.row_splits(1)[1:] -
# reordered_word_fsas.arcs.row_splits(1)[:-1])
# print(
# reordered_lats.arcs.row_splits(1)[1:] -
# reordered_lats.arcs.row_splits(1)[:-1])
print('2 num', tot_scores_2nd_num)
print('2 den', tot_scores_2nd_den)

import sys
sys.exit(0)
else:
tot_scores_2nd = second_pass_lattices.get_tot_scores(
use_double_scores=True, log_semiring=True)

# Now tot_scores_2nd[i] corresponds to sorted_path_i
# `sorted` here is due to k2.ragged.unique_sequences.
# We have to use `new2old` to map it to the original unsorted path

# Note that path_weight was not reordered
tot_scores = tot_scores_1st
tot_scores[new2old.long()] += tot_scores_2nd * path_weight
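# Each unique path's weighted 2nd-pass score is added onto the 1st-pass score
# of the original sampled path it came from (mapped back through new2old).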
ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
tot_scores.to(torch.float32))
argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
print(argmax_indexes)
# argmax_indexes may contain -1. This case happens
# when a sublist contains all -inf
argmax_indexes = torch.clamp(argmax_indexes, min=0)

paths = k2.ragged.remove_axis(paths, 0)

best_paths = k2.index(paths, argmax_indexes)
labels = k2.index(lats.labels.contiguous(), best_paths)
aux_labels = k2.index(lats.aux_labels, best_paths.values())
labels = k2.ragged.remove_values_eq(labels, -1)
best_paths = k2.linear_fsa(labels)
best_paths.aux_labels = aux_labels

return best_paths