This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

WIP: Compute expected times per pathphone_idx. #106

Open: wants to merge 37 commits into base: master

Commits (37)
36c56ee  WIP: Compute expected times per pathphone_idx. (csukuangfj, Feb 22, 2021)
fa0a0cd  add normalization. (csukuangfj, Feb 23, 2021)
f6d3dcb  add attribute `phones` to `den_lats`. (csukuangfj, Feb 23, 2021)
aeaee32  Replace expected times for self-loops with average times of neighboring (csukuangfj, Feb 23, 2021)
3fc6ef2  print total_occupation per pathphone_idx (csukuangfj, Feb 23, 2021)
2c98f3d  print some debug output. (csukuangfj, Feb 23, 2021)
44e77f0  add a toy example to display the `phones` attribute. (csukuangfj, Feb 23, 2021)
6800f60  add decoding_graphs.phones (csukuangfj, Feb 23, 2021)
f50e4f4  add test scripts. (csukuangfj, Feb 23, 2021)
6eba4ac  Initialize P.scores randomly. (csukuangfj, Feb 23, 2021)
2feca7c  Compute embeddings from expected times. (csukuangfj, Feb 24, 2021)
b3b9fc7  avoid division by zero. (csukuangfj, Feb 24, 2021)
35bd60f  append phone indexes and expected times to the embeddings. (csukuangfj, Feb 25, 2021)
6475200  Fix a bug in computing expected times for epsilon self-loops. (csukuangfj, Mar 3, 2021)
838f8ae  Add TDNN model for the second pass. (csukuangfj, Mar 3, 2021)
218bf8e  pad embeddings of different paths to the same length. (csukuangfj, Mar 3, 2021)
3b855ef  Also return `path_to_seq_map` in `compute_embeddings`. (csukuangfj, Mar 3, 2021)
f6d718e  add training script for embeddings (not working now) (csukuangfj, Mar 3, 2021)
fdfa412  Finish the neural network part for the second pass. (csukuangfj, Mar 4, 2021)
1d2c5d6  compute tot_scores for the second pass. (csukuangfj, Mar 5, 2021)
5412d32  Finish computing total scores from 1st and 2nd pass. (csukuangfj, Mar 5, 2021)
50cf49b  add decoding script. (csukuangfj, Mar 5, 2021)
efc542b  Merge remote-tracking branch 'dan/master' into expected-times (csukuangfj, Mar 8, 2021)
0155dc4  Support saving to/loading from checkpoints for the second pass model. (csukuangfj, Mar 8, 2021)
da2a80d  Visualize first & second pass obj separately. (csukuangfj, Mar 8, 2021)
04100f0  disable sorting in the decode script. (csukuangfj, Mar 8, 2021)
12ec856  refactoring. (csukuangfj, Mar 8, 2021)
8caa8ba  Support decoding with the second pass model. (csukuangfj, Mar 9, 2021)
231ffdb  add more comments to the second pass training code after review. (csukuangfj, Mar 9, 2021)
741a448  add an extra layer to the first pass model for computing embeddings. (csukuangfj, Mar 9, 2021)
83a5c89  Place the extra layer before LSTMs in the first pass model. (csukuangfj, Mar 11, 2021)
2648b42  Use the second pass model for rescoring. (csukuangfj, Mar 16, 2021)
8e47582  Support `num_repeats` in rescoring. (csukuangfj, Mar 17, 2021)
34f202a  top_sort word_lats before invoking get_tot_scores. (csukuangfj, Mar 18, 2021)
b1978f3  Rescore with posteriors using the 2nd-pass lattice. (csukuangfj, Mar 19, 2021)
9a514b5  print the log-probs of the reference input phones. (csukuangfj, Mar 19, 2021)
cda1c92  Replace expected time with duration and remove EOS in embeddings. (csukuangfj, Mar 23, 2021)
144 changes: 144 additions & 0 deletions snowfall/training/compute_expected_times.py
@@ -0,0 +1,144 @@
#!/usr/bin/env python3
#
# Copyright (c) 2021 Xiaomi Corp. (author: Fangjun Kuang)
#

import k2
import torch


def _create_phone_fsas(phone_seqs: k2.RaggedInt) -> k2.Fsa:
'''
Args:
phone_seqs:
It contains two axes with elements being phone IDs.
The last element of each sub-list is -1.
Returns:
Return an FsaVec representing the phone seqs.
'''
assert phone_seqs.num_axes() == 2
phone_seqs = k2.ragged.remove_values_eq(phone_seqs, -1)
return k2.linear_fsa(phone_seqs)


def compute_expected_times_per_phone(mbr_lats: k2.Fsa,
ctc_topo: k2.Fsa,
dense_fsa_vec: k2.DenseFsaVec,
use_double_scores=True,
num_paths=100) -> torch.Tensor:
'''Compute expected times per phone in an n-best list.

See the following comments for more information:

- `<https://github.com/k2-fsa/snowfall/issues/96>`_
- `<https://github.com/k2-fsa/k2/issues/641>`_

Args:
mbr_lats:
An FsaVec.
ctc_topo:
The return value of :func:`build_ctc_topo`.
dense_fsa_vec:
It contains nnet_output.
use_double_scores:
True to use `double` in :func:`k2.random_paths`; False to use `float`.
num_paths:
Number of random paths to draw in :func:`k2.random_paths`.
Returns:
A 1-D torch.Tensor containing the expected times per pathphone_idx.
'''
lats = mbr_lats
assert len(lats.shape) == 3
assert hasattr(lats, 'phones')

# paths will be k2.RaggedInt with 3 axes: [seq][path][arc_pos],
# containing arc_idx012
paths = k2.random_paths(lats,
    use_double_scores=use_double_scores,
    num_paths=num_paths)

Review comment (csukuangfj, Collaborator, Author):
I just found that `paths` is empty when `lats` is `mbr_lats` and the device
is CUDA. Will look into it.

Reply (Contributor):
OK. BTW did you already look at the decoding graph that you used to generate
den_lats, and make sure that it has more nonzero labels than nonzero phones?

Reply (csukuangfj, Collaborator, Author):
Yes, I've added two print statements in this pull request:

    den = k2.index_fsa(ctc_topo_P_vec, indexes)
    print('den.phones', den.phones.shape, 'nnz',
          torch.count_nonzero(den.phones))
    print('den.labels', den.labels.shape, 'nnz',
          torch.count_nonzero(den.labels))

and the output is

    den.phones torch.Size([30446]) nnz tensor(29928)
    den.labels torch.Size([30446]) nnz tensor(30100)

so den.phones contains more 0s than den.labels.

# phone_seqs will be k2.RaggedInt like paths, but containing phones
# (and final -1's, and 0's for epsilon)
phone_seqs = k2.index(lats.phones, paths)

# Remove epsilons from `phone_seqs`
phone_seqs = k2.ragged.remove_values_eq(phone_seqs, 0)

# Remove repeated sequences from `phone_seqs`
#
# TODO(fangjun): `num_repeats` is currently not used
phone_seqs, num_repeats = k2.ragged.unique_sequences(phone_seqs, True)

# Remove the 1st axis from `phone_seqs` (that corresponds to `seq`) and
# keep it for later, we'll be processing paths separately.
seq_to_path_shape = k2.ragged.get_layer(phone_seqs.shape(), 0)
path_to_seq_map = seq_to_path_shape.row_ids(1)

phone_seqs = k2.ragged.remove_axis(phone_seqs, 0)
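The row_splits/row_ids relationship used above (`seq_to_path_shape.row_ids(1)` maps each surviving path to its seq) can be sketched in plain Python with made-up sizes:

```python
def row_ids_from_row_splits(row_splits):
    """Expand row_splits into row_ids: entry i of the output is the
    index of the row that element i belongs to."""
    ids = []
    for row, (start, end) in enumerate(zip(row_splits, row_splits[1:])):
        ids.extend([row] * (end - start))
    return ids

# Toy example: seq 0 keeps 3 unique paths, seq 1 keeps 2.
row_splits = [0, 3, 5]
path_to_seq_map = row_ids_from_row_splits(row_splits)
# -> [0, 0, 0, 1, 1]
```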

# now compile decoding graphs corresponding to `phone_seqs` by constructing
# fsas from them (remember we already have the final -1's!) and composing
# with ctc_topo.
phone_fsas = _create_phone_fsas(phone_seqs)
phone_fsas = k2.add_epsilon_self_loops(phone_fsas)

# Set an attribute called pathphone_idx, which corresponds to the arc-index
# in `phone_fsas` with self-loops.
# Each phone has an index but there are blanks between them and at the start
# and end.
phone_fsas.pathphone_idx = torch.arange(phone_fsas.arcs.num_elements(),
dtype=torch.int32)

# Now extract the sets of paths from the lattices corresponding to each of
# those n-best phone sequences; these will effectively be lattices with one
# path but alternative alignments.
path_decoding_graphs = k2.compose(ctc_topo,
phone_fsas,
treat_epsilons_specially=False)

paths_lats = k2.intersect_dense(path_decoding_graphs,
dense_fsa_vec,
output_beam=10.0,
a_to_b_map=path_to_seq_map,
seqframe_idx_name='seqframe_idx')

# by seq we mean the original sequence indexes, by path we mean the indexes
# of the n-best paths; path_to_seq_map maps from path-index to seq-index.
seqs_shape = dense_fsa_vec.dense_fsa_vec.shape()

# paths_shape will also be a k2.RaggedInt with 2 axes
paths_shape, _ = k2.ragged.index(seqs_shape,
path_to_seq_map,
need_value_indexes=False)

seq_starts = seqs_shape.row_splits(1)[:-1]
path_starts = paths_shape.row_splits(1)[:-1]

# We can map from seqframe_idx for paths, to seqframe_idx for seqs,
# by adding path_offsets. path_offsets is indexed by path-index.
path_offsets = path_starts - k2.index(seq_starts, path_to_seq_map)
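A small numeric check of the offset arithmetic above, in plain Python with hypothetical sizes (two seqs with 4 and 3 frames; three paths, the first two drawn from seq 0):

```python
seq_row_splits = [0, 4, 7]           # like seqs_shape.row_splits(1)
seq_starts = seq_row_splits[:-1]     # [0, 4]
path_to_seq_map = [0, 0, 1]

# paths_shape repeats each seq's frame count once per path: 4, 4, 3.
path_row_splits = [0, 4, 8, 11]
path_starts = path_row_splits[:-1]   # [0, 4, 8]

path_offsets = [p - seq_starts[s]
                for p, s in zip(path_starts, path_to_seq_map)]
# path_offsets == [0, 4, 4]: adding the offset to a seqframe_idx
# renumbers it into the concatenated per-path frame numbering.
```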

# assign new attribute 'pathframe_idx' that combines path and frame.
paths_lats_arc2path = k2.index(paths_lats.arcs.shape().row_ids(1),
paths_lats.arcs.shape().row_ids(2))

paths_lats.pathframe_idx = paths_lats.seqframe_idx + k2.index(
path_offsets, paths_lats_arc2path)
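The arc-to-path map above is built by composing two row_ids maps (arc to state, then state to FSA). A toy plain-Python sketch of that composition, with made-up shapes:

```python
# 5 states spread over 2 path-FSAs, 6 arcs spread over the 5 states.
state_to_path = [0, 0, 1, 1, 1]      # like arcs.shape().row_ids(1)
arc_to_state = [0, 0, 1, 2, 3, 4]    # like arcs.shape().row_ids(2)

# k2.index(state_to_path, arc_to_state) is just a gather:
arc_to_path = [state_to_path[s] for s in arc_to_state]
# -> [0, 0, 0, 1, 1, 1]
```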

pathframe_to_pathphone = k2.create_sparse(rows=paths_lats.pathframe_idx,
    cols=paths_lats.pathphone_idx,
    values=paths_lats.get_arc_post(True, True).exp(),
    min_col_index=0)

Review comment (csukuangfj, Collaborator, Author):
I am not sure whether this is correct.

Reply (Contributor):
Can you perhaps print out paths_lats.pathphone_idx for the two decoding
graphs and see if there is an obvious difference, e.g. perhaps one has a
lot more zeros or -1's in it?

frame_idx = torch.arange(paths_shape.num_elements()) - k2.index(
path_starts, paths_shape.row_ids(1))

# TODO(fangjun): we can swap `rows` and `cols`
# while creating `pathframe_to_pathphone` so that
# `t()` can be omitted here.
weighted_occupation = torch.sparse.mm(
pathframe_to_pathphone.t(),
frame_idx.unsqueeze(-1).to(pathframe_to_pathphone.dtype))

return weighted_occupation.squeeze()
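The final sparse matmul accumulates, for each pathphone_idx, the sum over frames of arc posterior times frame index. A dense plain-Python sketch of the same reduction, with made-up posteriors (3 frames, 3 pathphone indexes); note that turning this into an actual expected time still requires dividing by the total occupation per pathphone_idx, which the commit history addresses separately ("add normalization"):

```python
# posts[t][p]: posterior mass on pathphone p at frame t (toy numbers).
posts = [
    [0.9, 0.1, 0.0],   # frame 0
    [0.2, 0.8, 0.0],   # frame 1
    [0.0, 0.3, 0.7],   # frame 2
]
frame_idx = [0, 1, 2]

num_frames = len(posts)
num_phones = len(posts[0])

# Equivalent of pathframe_to_pathphone.t() @ frame_idx:
weighted_occupation = [
    sum(posts[t][p] * frame_idx[t] for t in range(num_frames))
    for p in range(num_phones)
]
# Total posterior mass per pathphone_idx:
tot_occupation = [
    sum(posts[t][p] for t in range(num_frames))
    for p in range(num_phones)
]
# Normalized expected time per pathphone_idx:
expected_times = [w / t for w, t in zip(weighted_occupation,
                                        tot_occupation)]
```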