This repository has been archived by the owner on Oct 13, 2022. It is now read-only.
WIP: Compute expected times per pathphone_idx. #106
Open
csukuangfj wants to merge 37 commits into k2-fsa:master from csukuangfj:expected-times
Commits
36c56ee WIP: Compute expected times per pathphone_idx. (csukuangfj)
fa0a0cd add normalization. (csukuangfj)
f6d3dcb add attribute `phones` to `den_lats`. (csukuangfj)
aeaee32 Replace expected times for self-loops with average times of neighboring (csukuangfj)
3fc6ef2 print total_occupation per pathphone_idx (csukuangfj)
2c98f3d print some debug output. (csukuangfj)
44e77f0 add a toy example to display the `phones` attribute. (csukuangfj)
6800f60 add decoding_graphs.phones (csukuangfj)
f50e4f4 add test scripts. (csukuangfj)
6eba4ac Initialize P.scores randomly. (csukuangfj)
2feca7c Compute embeddings from expected times. (csukuangfj)
b3b9fc7 avoid division by zero. (csukuangfj)
35bd60f append phone indexes and expected times to the embeddings. (csukuangfj)
6475200 Fix a bug in computing expected times for epsilon self-loops. (csukuangfj)
838f8ae Add TDNN model for the second pass. (csukuangfj)
218bf8e pad embeddings of different paths to the same length. (csukuangfj)
3b855ef Also return `path_to_seq_map` in `compute_embeddings`. (csukuangfj)
f6d718e add training script for embeddings (not working now) (csukuangfj)
fdfa412 Finish the neural network part for the second pass. (csukuangfj)
1d2c5d6 compute tot_scores for the second pass. (csukuangfj)
5412d32 Finish computing total scores from 1st and 2nd pass. (csukuangfj)
50cf49b add decoding script. (csukuangfj)
efc542b Merge remote-tracking branch 'dan/master' into expected-times (csukuangfj)
0155dc4 Support saving to/loading from checkpoints for the second pass model. (csukuangfj)
da2a80d Visualize first & second pass obj separately. (csukuangfj)
04100f0 disable sorting in the decode script. (csukuangfj)
12ec856 refactoring. (csukuangfj)
8caa8ba Support decoding with the second pass model. (csukuangfj)
231ffdb add more comments to the second pass training code after review. (csukuangfj)
741a448 add an extra layer to the first pass model for computing embeddings. (csukuangfj)
83a5c89 Place the extra layer before LSTMs in the first pass model. (csukuangfj)
2648b42 Use the second pass model for rescoring. (csukuangfj)
8e47582 Support `num_repeats` in rescoring. (csukuangfj)
34f202a top_sort word_lats before invoking get_tot_scores. (csukuangfj)
b1978f3 Rescore with posteriors using the 2nd-pass lattice. (csukuangfj)
9a514b5 print the log-probs of the reference input phones. (csukuangfj)
cda1c92 Replace expected time with duration and remove EOS in embeddings. (csukuangfj)
@@ -0,0 +1,144 @@

    #!/usr/bin/env python3
    #
    # Copyright (c) 2021 Xiaomi Corp. (author: Fangjun Kuang)
    #

    import k2
    import torch


    def _create_phone_fsas(phone_seqs: k2.RaggedInt) -> k2.Fsa:
        '''
        Args:
          phone_seqs:
            It contains two axes with elements being phone IDs.
            The last element of each sub-list is -1.
        Returns:
          Return an FsaVec representing the phone seqs.
        '''
        assert phone_seqs.num_axes() == 2
        phone_seqs = k2.ragged.remove_values_eq(phone_seqs, -1)
        return k2.linear_fsa(phone_seqs)


    def compute_expected_times_per_phone(mbr_lats: k2.Fsa,
                                         ctc_topo: k2.Fsa,
                                         dense_fsa_vec: k2.DenseFsaVec,
                                         use_double_scores=True,
                                         num_paths=100) -> torch.Tensor:
        '''Compute expected times per phone in a n-best list.

        See the following comments for more information:

          - `<https://github.com/k2-fsa/snowfall/issues/96>`_
          - `<https://github.com/k2-fsa/k2/issues/641>`_

        Args:
          mbr_lats:
            An FsaVec.
          ctc_topo:
            The return value of :func:`build_ctc_topo`.
          dense_fsa_vec:
            It contains nnet_output.
          use_double_scores:
            True to use `double` in :func:`k2.random_paths`; false to use `float`.
          num_paths:
            Number of random paths to draw in :func:`k2.random_paths`.
        Returns:
          A 1-D torch.Tensor contains the expected times per pathphone_idx.
        '''
        lats = mbr_lats
        assert len(lats.shape) == 3
        assert hasattr(lats, 'phones')

        # paths will be k2.RaggedInt with 3 axes: [seq][path][arc_pos],
        # containing arc_idx012
        paths = k2.random_paths(lats,
                                use_double_scores=use_double_scores,
                                num_paths=num_paths)

        # phone_seqs will be k2.RaggedInt like paths, but containing phones
        # (and final -1's, and 0's for epsilon)
        phone_seqs = k2.index(lats.phones, paths)

        # Remove epsilons from `phone_seqs`
        phone_seqs = k2.ragged.remove_values_eq(phone_seqs, 0)

        # Remove repeated sequences from `phone_seqs`
        #
        # TODO(fangjun): `num_repeats` is currently not used
        phone_seqs, num_repeats = k2.ragged.unique_sequences(phone_seqs, True)

        # Remove the 1st axis from `phone_seqs` (that corresponds to `seq`) and
        # keep it for later, we'll be processing paths separately.
        seq_to_path_shape = k2.ragged.get_layer(phone_seqs.shape(), 0)
        path_to_seq_map = seq_to_path_shape.row_ids(1)

        phone_seqs = k2.ragged.remove_axis(phone_seqs, 0)

        # now compile decoding graphs corresponding to `phone_seqs` by constructing
        # fsas from them (remember we already have the final -1's!) and composing
        # with ctc_topo.
        phone_fsas = _create_phone_fsas(phone_seqs)
        phone_fsas = k2.add_epsilon_self_loops(phone_fsas)

        # Set an attribute called pathphone_idx, which corresponds to the arc-index
        # in `phone_fsas` with self-loops.
        # Each phone has an index but there are blanks between them and at the start
        # and end.
        phone_fsas.pathphone_idx = torch.arange(phone_fsas.arcs.num_elements(),
                                                dtype=torch.int32)

        # Now extract the sets of paths from the lattices corresponding to each of
        # those n-best phone sequences; these will effectively be lattices with one
        # path but alternative alignments.
        path_decoding_graphs = k2.compose(ctc_topo,
                                          phone_fsas,
                                          treat_epsilons_specially=False)

        paths_lats = k2.intersect_dense(path_decoding_graphs,
                                        dense_fsa_vec,
                                        output_beam=10.0,
                                        a_to_b_map=path_to_seq_map,
                                        seqframe_idx_name='seqframe_idx')

        # by seq we mean the original sequence indexes, by path we mean the indexes
        # of the n-best paths; path_to_seq_map maps from path-index to seq-index.
        seqs_shape = dense_fsa_vec.dense_fsa_vec.shape()

        # paths_shape will also be a k2.RaggedInt with 2 axes
        paths_shape, _ = k2.ragged.index(seqs_shape,
                                         path_to_seq_map,
                                         need_value_indexes=False)

        seq_starts = seqs_shape.row_splits(1)[:-1]
        path_starts = paths_shape.row_splits(1)[:-1]

        # We can map from seqframe_idx for paths, to seqframe_idx for seqs,
        # by adding path_offsets. path_offsets is indexed by path-index.
        path_offsets = path_starts - k2.index(seq_starts, path_to_seq_map)

        # assign new attribute 'pathframe_idx' that combines path and frame.
        paths_lats_arc2path = k2.index(paths_lats.arcs.shape().row_ids(1),
                                       paths_lats.arcs.shape().row_ids(2))

        paths_lats.pathframe_idx = paths_lats.seqframe_idx + k2.index(
            path_offsets, paths_lats_arc2path)

        pathframe_to_pathphone = k2.create_sparse(rows=paths_lats.pathframe_idx,
                                                  cols=paths_lats.pathphone_idx,
                                                  values=paths_lats.get_arc_post(
                                                      True, True).exp(),
                                                  min_col_index=0)

        frame_idx = torch.arange(paths_shape.num_elements()) - k2.index(
            path_starts, paths_shape.row_ids(1))

        # TODO(fangjun): we can swap `rows` and `cols`
        # while creating `pathframe_to_pathphone` so that
        # `t()` can be omitted here.
        weighted_occupation = torch.sparse.mm(
            pathframe_to_pathphone.t(),
            frame_idx.unsqueeze(-1).to(pathframe_to_pathphone.dtype))

        return weighted_occupation.squeeze()

Review comments on the `pathframe_to_pathphone = k2.create_sparse(...)` line:

I am not sure whether this is correct.

Can you perhaps print out `paths_lats.pathphone_idx` for the two decoding graphs and see if there is an obvious difference, e.g. perhaps one has a lot more zeros or -1's in it?
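To make the last step of `compute_expected_times_per_phone` concrete, here is a dependency-free sketch of what the sparse product accumulates: for each `pathphone_idx`, the sum over arcs of arc posterior times frame index. It uses neither k2 nor torch. The function name `weighted_frame_sums`, the toy arcs, and the final division by total occupation are illustrative assumptions: the commit list mentions normalization, but the function in this diff returns the unnormalized sums. Frame indexes in the toy data are assumed to be already relative to the start of each path.

```python
from collections import defaultdict


def weighted_frame_sums(arcs):
    """Accumulate posterior-weighted frame indexes per pathphone_idx.

    `arcs` is a list of (frame_idx, pathphone_idx, posterior) triples, a
    hypothetical stand-in for the arcs of `paths_lats`.
    Returns two dicts keyed by pathphone_idx: the weighted sums (the
    analogue of `weighted_occupation`) and the total posterior mass.
    """
    weighted = defaultdict(float)
    occupation = defaultdict(float)
    for frame_idx, pathphone_idx, posterior in arcs:
        weighted[pathphone_idx] += posterior * frame_idx
        occupation[pathphone_idx] += posterior
    return dict(weighted), dict(occupation)


# Toy data: 3 frames and 2 pathphone indexes.
toy_arcs = [
    (0, 0, 1.0),
    (1, 0, 0.5),
    (1, 1, 0.5),
    (2, 1, 1.0),
]
weighted, occupation = weighted_frame_sums(toy_arcs)

# Assumed normalization: divide by total occupation, skipping zero mass.
expected_times = {
    idx: weighted[idx] / occupation[idx]
    for idx in weighted if occupation[idx] > 0
}
print(expected_times)
```

In the real code the same accumulation is done in one call by multiplying the transposed sparse matrix with the column vector of frame indexes.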
I just found that `paths` is empty when `lats` is `mbr_lats` and device is cuda. Will look into it.
OK. BTW did you already look at the decoding graph that you used to generate `den_lats`, and make sure that it has more nonzero labels than nonzero phones?
Yes, I've added two print statements in this pull request (snowfall/snowfall/training/mmi_mbr_graph.py, lines 162 to 168 in f50e4f4), and in the output `den.phones` contains more 0s than `den.labels`.
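The check discussed in this thread, comparing how many zeros appear in the `phones` attribute versus `labels`, can be sketched without k2 as follows. The two lists are made-up stand-ins for the per-arc attributes and `count_zeros` is a hypothetical helper, not part of snowfall; with real tensors one would count on the attributes themselves.

```python
def count_zeros(values):
    """Return how many entries equal 0 (the epsilon symbol)."""
    return sum(1 for v in values if v == 0)


# Made-up per-arc attributes; -1 marks the arc entering the final state.
labels = [3, 0, 5, 5, 7, -1]
phones = [3, 0, 0, 5, 0, -1]

num_zero_labels = count_zeros(labels)
num_zero_phones = count_zeros(phones)
print(f'zeros in labels: {num_zero_labels}, zeros in phones: {num_zero_phones}')
```

More zeros in `phones` than in `labels` is the same thing as more nonzero labels than nonzero phones, since both attributes have one entry per arc.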