This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

Commit 8e47582: Support num_repeats in rescoring.

csukuangfj committed Mar 17, 2021
1 parent 2648b42
Showing 6 changed files with 55 additions and 42 deletions.
egs/librispeech/asr/simple_v1/mmi_bigram_embeddings_decode.py (5 additions, 6 deletions)

@@ -16,6 +16,7 @@
 
 from lhotse import CutSet
 from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
+from lhotse.utils import fix_random_seed
 from snowfall.common import find_first_disambig_symbol
 from snowfall.common import get_texts
 from snowfall.common import invert_permutation

@@ -85,8 +86,7 @@ def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
             # lattices = k2.intersect_dense(LG, dense_fsa_vec, 10.0)
             best_paths = k2.shortest_path(lattices, use_double_scores=True)
         else:
-            # FIXME(fangjun): increase num_paths and fix `num_repeats` in rescore.py
-            paths = get_paths(lats=lattices, num_paths=1)
+            paths = get_paths(lats=lattices, num_paths=10)
             word_fsas, seq_to_path_shape = get_word_fsas(lattices, paths)
             replicated_lats = k2.index(lattices, seq_to_path_shape.row_ids(1))
             word_lats = k2.compose(replicated_lats,

@@ -184,10 +184,7 @@ def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
                     float(num_cuts) / tot_num_cuts * 100))
 
         num_cuts += len(texts)
-
-        # FIXME(fangjun): remove it
-        if batch_idx == 50:
-            break
+        if batch_idx == 140: break
 
     return results

@@ -276,6 +273,8 @@ def main():
     else:
        setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')
 
+    fix_random_seed(42)
+
     logging.info(f'enable second pass model for decoding: {args.enable_second_pass_decoding}')
 
     # load L, G, symbol_table
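
This file raises num_paths from 1 to 10 and seeds the RNGs at startup. The sampling behind get_paths draws paths from the lattice at random, which is presumably why the seed is fixed in the same commit: it is what makes two decoding runs produce the same n-best lists. Below is a minimal sketch of that property, assuming only that lhotse's fix_random_seed seeds the Python, NumPy and torch generators; torch.randn stands in for the randomized path-sampling step.

import torch
from lhotse.utils import fix_random_seed

# Two runs started from the same seed see identical RNG state, so any
# random sampling inside them yields the same result.
fix_random_seed(42)
draw1 = torch.randn(10)  # stand-in for randomized n-best path sampling

fix_random_seed(42)
draw2 = torch.randn(10)

assert torch.equal(draw1, draw2)
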
@@ -162,7 +162,7 @@ def get_objf(batch: Dict,
     # Now compute the tot_scores for the second pass
     #
     # TODO(fangjun): We probably need to split it into a separate function
-    padded_embeddings, len_per_path, path_to_seq, num_repeats = compute_embeddings(
+    padded_embeddings, len_per_path, path_to_seq, num_repeats, _ = compute_embeddings(
         den_lats,
         graph_compiler.ctc_topo,
         dense_fsa_vec_2nd,
snowfall/common.py (1 addition, 1 deletion)

@@ -21,7 +21,7 @@
 def setup_logger(log_filename: Pathlike, log_level: str = 'info', use_console: bool = True) -> None:
     now = datetime.now()
     date_time = now.strftime('%Y-%m-%d-%H-%M-%S')
-    log_filename = '{}-{}'.format(log_filename, date_time)
+    log_filename = '{}-{}.txt'.format(log_filename, date_time)
     os.makedirs(os.path.dirname(log_filename), exist_ok=True)
     formatter = '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s'
     level = logging.ERROR
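
A small side change: log files now carry a .txt extension. A quick sketch of the resulting name (the path below is illustrative, not from the repo):

from datetime import datetime

# After this change setup_logger produces, e.g.,
# exp/log/log-decode-2021-03-17-10-20-30.txt
# instead of the extensionless exp/log/log-decode-2021-03-17-10-20-30.
log_filename = 'exp/log/log-decode'  # illustrative path
date_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
print('{}-{}.txt'.format(log_filename, date_time))
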
snowfall/decoding/rescore.py (15 additions, 8 deletions)

@@ -117,7 +117,9 @@ def rescore(lats: k2.Fsa,
     # 0 <= path_to_seq[i] < num_seqs
     #
     # num_repeats is a k2.RaggedInt with two axes [seq][path_multiplicities]
-    padded_embeddings, len_per_path, path_to_seq, num_repeats = compute_embeddings_from_phone_seqs(
+    #
+    # CAUTION: Paths within a seq are reordered due to `k2.ragged.unique_sequences`.
+    padded_embeddings, len_per_path, path_to_seq, num_repeats, new2old = compute_embeddings_from_phone_seqs(
         phone_seqs=phone_seqs,
         ctc_topo=ctc_topo,
         dense_fsa_vec=dense_fsa_vec,

@@ -142,31 +144,36 @@
     indices2 = torch.argsort(len_per_path, descending=True)
     second_pass_supervision_segments = second_pass_supervision_segments[
         indices2]
-    # note that path_to_seq is not changed!
-    # no need to modify second_pass_out
+    # Note that path_to_seq is not changed!
+    # No need to modify second_pass_out
 
+    num_repeats_float = k2.ragged.RaggedFloat(
+        num_repeats.shape(),
+        num_repeats.values().to(torch.float32))
+    path_weight = k2.ragged.normalize_scores(num_repeats_float,
+                                             use_log=False).values
+    path_weight = path_weight[indices2]
 
     second_pass_dense_fsa_vec = k2.DenseFsaVec(
         second_pass_out, second_pass_supervision_segments)
 
     second_pass_lattices = k2.intersect_dense_pruned(
-        decoding_graph, second_pass_dense_fsa_vec, 20.0, 7.0, 30, 10000)
+        decoding_graph, second_pass_dense_fsa_vec, 20.0, 7.0, 30, 20000)
 
     # second_pass_lattices = k2.intersect_dense(decoding_graph,
     #                                           second_pass_dense_fsa_vec, 10.0)
 
     tot_scores = second_pass_lattices.get_tot_scores(
         log_semiring=True, use_double_scores=use_double_scores)
 
     inverted_indices2 = invert_permutation(indices2)
     tot_scores_2nd = tot_scores[inverted_indices2]
-    # now tot_scores_2nd[i] corresponds to path_i
+    # Now tot_scores_2nd[i] corresponds to sorted_path_i
+    # `sorted` here is due to k2.ragged.unique_sequences.
+    # We have to use `new2old` to map it to the original unsorted path
 
-    # FIXME(fangjun): Handle the case when num_repeats contains entries > 1
-    tot_scores = tot_scores_1st + tot_scores_2nd
+    # Note that path_weight was not reordered
+    tot_scores = tot_scores_1st
+    tot_scores[new2old.long()] += tot_scores_2nd * path_weight
     ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
                                        tot_scores.to(torch.float32))
     argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
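
The score combination above is the heart of this commit. The toy PyTorch sketch below (not the snowfall code; every tensor and number is made up for illustration) mimics the arithmetic for one utterance that yielded four sampled paths, of which three are unique and the first unique path was sampled twice. Ragged ops are replaced by plain tensors, and the per-seq normalization of normalize_scores(..., use_log=False) reduces to dividing by the total count because there is a single seq.

import torch

# First-pass total scores, one per *sampled* path (4 paths).
tot_scores_1st = torch.tensor([1.0, 2.0, 3.0, 1.0])

# Second-pass total scores, one per *unique* path (3 survive dedup).
tot_scores_2nd = torch.tensor([0.5, 0.1, 0.9])

# Multiplicity of each unique path; this is what num_repeats carries.
num_repeats = torch.tensor([2.0, 1.0, 1.0])

# Per-seq linear normalization: each unique path is weighted by how
# often it was sampled.
path_weight = num_repeats / num_repeats.sum()  # [0.50, 0.25, 0.25]

# new2old[i] is the index, in the original sampled order, of the path
# that unique path i came from.
new2old = torch.tensor([0, 2, 3])

# The commit's combination rule: start from the first-pass scores and
# add the multiplicity-weighted second-pass scores at mapped positions.
tot_scores = tot_scores_1st.clone()  # the real code updates in place
tot_scores[new2old.long()] += tot_scores_2nd * path_weight
print(tot_scores)  # tensor([1.2500, 2.0000, 3.0250, 1.2250])

The best path per utterance is then the argmax over each seq's sublist of these combined scores, as done with ragged_tot_scores above.
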
snowfall/training/compute_embeddings.py (27 additions, 22 deletions)

@@ -73,7 +73,7 @@ def compute_expected_times(
         dense_fsa_vec: k2.DenseFsaVec,
         use_double_scores: bool = True,
         debug: bool = False
-) -> Tuple[torch.Tensor, k2.Fsa, torch.Tensor, k2.RaggedInt]:
+) -> Tuple[torch.Tensor, k2.Fsa, torch.Tensor, k2.RaggedInt, torch.Tensor]:
     '''
     Args:
       phone_seqs:

@@ -96,11 +96,13 @@
       - phone_fsas, an FsaVec with indexes [path][phones]
       - path_to_seq_map, 1-D torch.Tensor
       - num_repeats, a k2.RaggedInt with 2 axes [path][multiplicities]
+      - new2old, a 1-D torch.Tensor, see :func:`k2.ragged.unique_sequences`
     '''
     device = ctc_topo.device
 
     if phone_seqs.num_axes() == 3:
-        phone_seqs, num_repeats = k2.ragged.unique_sequences(phone_seqs, True)
+        phone_seqs, num_repeats, new2old = k2.ragged.unique_sequences(
+            phone_seqs, True, True)
 
     # Remove the 1st axis from `phone_seqs` (that corresponds to `seq`) and
     # keep it for later; we'll be processing paths separately.

@@ -128,6 +130,7 @@
 
         seq_to_path_shape = num_repeats_shape
         path_to_seq_map = seq_to_path_shape.row_ids(1)  # an identity map
+        new2old = path_to_seq_map
 
     # now compile decoding graphs corresponding to `phone_seqs` by constructing
     # fsas from them (remember we already have the final -1's!) and composing

@@ -242,7 +245,7 @@ def compute_expected_times(
     # TODO(fangjun): do we need to support `torch.int32` for the indexing
     expected_times[first_epsilon_offset[:-1].long()] = 0
 
-    return expected_times, phone_fsas, path_to_seq_map, num_repeats
+    return expected_times, phone_fsas, path_to_seq_map, num_repeats, new2old
 
 
 def compute_embeddings_from_nnet_output(expected_times: torch.Tensor,

@@ -358,8 +361,8 @@ def compute_embeddings_from_phone_seqs(
         dense_fsa_vec: k2.DenseFsaVec,
         max_phone_id: int,
         use_double_scores: bool = True,
-        debug: bool = True
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, k2.RaggedInt]:
+        debug: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.
+                                     Tensor, k2.RaggedInt, torch.Tensor]:
     '''
     Args:
       phone_seqs:

@@ -384,7 +387,7 @@
       - num_repeats, a ragged tensor of type k2.RaggedInt with 2
        axes [path][multiplicities]
     '''
-    expected_times, phone_fsas, path_to_seq_map, num_repeats = compute_expected_times(  # noqa
+    expected_times, phone_fsas, path_to_seq_map, num_repeats, new2old = compute_expected_times(  # noqa
         phone_seqs=phone_seqs,
         ctc_topo=ctc_topo,
         dense_fsa_vec=dense_fsa_vec,

@@ -420,19 +423,19 @@
     padded_embeddings = torch.nn.utils.rnn.pad_sequence(embeddings_per_path,
                                                         batch_first=True)
 
-    return padded_embeddings.to(
-        torch.float32), len_per_path.cpu(), path_to_seq_map, num_repeats
+    return padded_embeddings.to(torch.float32), len_per_path.cpu(
+    ), path_to_seq_map, num_repeats, new2old
 
 
-def compute_embeddings(
-        lats: k2.Fsa,
-        ctc_topo: k2.Fsa,
-        dense_fsa_vec: k2.DenseFsaVec,
-        max_phone_id: int,
-        use_double_scores: bool = True,
-        num_paths: int = 3,
-        debug: bool = False
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, k2.RaggedInt]:
+def compute_embeddings(lats: k2.Fsa,
+                       ctc_topo: k2.Fsa,
+                       dense_fsa_vec: k2.DenseFsaVec,
+                       max_phone_id: int,
+                       use_double_scores: bool = True,
+                       num_paths: int = 3,
+                       debug: bool = False
+                      ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, k2.
+                                 RaggedInt, torch.Tensor]:
     '''
     Args:
       lats:

@@ -476,8 +479,8 @@ def compute_embeddings_deprecated(
         max_phone_id: int,
         use_double_scores=True,
         num_paths=100,
-        debug=False
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, k2.RaggedInt]:
+        debug=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, k2.
+                              RaggedInt, torch.Tensor]:
     '''Compute embeddings for an n-best list.
 
     See the following comments for more information:

@@ -508,6 +511,7 @@
         before padding
       - path_to_seq_map, its shape is (num_paths,)
       - num_repeats (k2.RaggedInt)
+      - new2old, a 1-D torch.Tensor, see :func:`k2.ragged.unique_sequences`
     '''
     device = lats.device
     assert len(lats.shape) == 3

@@ -541,7 +545,8 @@
 
     # Remove repeated sequences from `phone_seqs`
     #
-    phone_seqs, num_repeats = k2.ragged.unique_sequences(phone_seqs, True)
+    phone_seqs, num_repeats, new2old = k2.ragged.unique_sequences(
+        phone_seqs, True, True)
 
     # Remove the 1st axis from `phone_seqs` (that corresponds to `seq`) and
     # keep it for later, we'll be processing paths separately.

@@ -781,5 +786,5 @@ def compute_embeddings_deprecated(
 
     # It used `double` for `get_arc_post`, but the network input requires
     # torch.float32
-    return padded_embeddings.to(
-        torch.float32), len_per_path.cpu(), path_to_seq_map, num_repeats
+    return padded_embeddings.to(torch.float32), len_per_path.cpu(
+    ), path_to_seq_map, num_repeats, new2old
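
For intuition about the new2old values threaded through these functions, here is a pure-Python model (hypothetical toy code, not k2's implementation) of what k2.ragged.unique_sequences returns for a single seq. One difference, called out by the new CAUTION comment in rescore.py: the real op may reorder the unique paths within a seq, which is exactly why new2old is needed; this toy keeps first-occurrence order for readability.

from typing import Dict, List, Tuple


def unique_sequences_toy(
        seqs: List[List[int]]) -> Tuple[List[List[int]], List[int], List[int]]:
    '''Toy, single-seq model of k2.ragged.unique_sequences.

    Returns (unique_seqs, num_repeats, new2old): num_repeats[i] counts how
    often unique_seqs[i] occurred in `seqs`, and new2old[i] is the index of
    one of its occurrences in the original order.
    '''
    unique_seqs: List[List[int]] = []
    num_repeats: List[int] = []
    new2old: List[int] = []
    seen: Dict[Tuple[int, ...], int] = {}
    for old_idx, seq in enumerate(seqs):
        key = tuple(seq)
        if key in seen:
            num_repeats[seen[key]] += 1
        else:
            seen[key] = len(unique_seqs)
            unique_seqs.append(seq)
            num_repeats.append(1)
            new2old.append(old_idx)
    return unique_seqs, num_repeats, new2old


# Four sampled phone paths for one seq; the first two are identical.
paths = [[1, 2, 3], [1, 2, 3], [4, 5], [1, 6]]
print(unique_sequences_toy(paths))
# -> ([[1, 2, 3], [4, 5], [1, 6]], [2, 1, 1], [0, 2, 3])
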
snowfall/training/test_compute_embeddings.py (6 additions, 4 deletions)

@@ -89,7 +89,7 @@ def main():
     den_lats = k2.intersect_dense(den_graph, dense_fsa_vec, 10.0)
 
     print('-' * 10, 'den_lats', '-' * 10)
-    den_padded_embeddings, den_len_per_path, den_path_to_seq, den_num_repeats = compute_embeddings(
+    den_padded_embeddings, den_len_per_path, den_path_to_seq, den_num_repeats, den_new2old = compute_embeddings(
         den_lats,
         graph_compiler.ctc_topo,
         dense_fsa_vec,

@@ -116,7 +116,7 @@
 
     print('den', den_num_repeats)
 
-    den_padded_embeddings2, den_len_per_path2, den_path_to_seq2, den_num_repeats2 = compute_embeddings_deprecated(
+    den_padded_embeddings2, den_len_per_path2, den_path_to_seq2, den_num_repeats2, den_new2old2 = compute_embeddings_deprecated(
         den_lats,
         graph_compiler.ctc_topo,
         dense_fsa_vec,

@@ -128,9 +128,10 @@
     assert torch.allclose(den_len_per_path, den_len_per_path2)
     assert torch.allclose(den_path_to_seq, den_path_to_seq2)
     assert str(den_num_repeats) == str(den_num_repeats2)
+    assert torch.allclose(den_new2old, den_new2old2)
 
     print('-' * 10, 'mbr_lats', '-' * 10)
-    mbr_padded_embeddings, mbr_len_per_path, mbr_path_to_seq, mbr_num_repeats = compute_embeddings(
+    mbr_padded_embeddings, mbr_len_per_path, mbr_path_to_seq, mbr_num_repeats, mbr_new2old = compute_embeddings(
         mbr_lats,
         graph_compiler.ctc_topo,
         dense_fsa_vec,

@@ -157,7 +158,7 @@
     assert mbr_padded_embeddings.requires_grad is True
     assert mbr_padded_embeddings.dtype == torch.float32
 
-    mbr_padded_embeddings2, mbr_len_per_path2, mbr_path_to_seq2, mbr_num_repeats2 = compute_embeddings_deprecated(
+    mbr_padded_embeddings2, mbr_len_per_path2, mbr_path_to_seq2, mbr_num_repeats2, mbr_new2old2 = compute_embeddings_deprecated(
         mbr_lats,
         graph_compiler.ctc_topo,
         dense_fsa_vec,

@@ -168,6 +169,7 @@
     assert torch.allclose(mbr_len_per_path, mbr_len_per_path2)
     assert torch.allclose(mbr_path_to_seq, mbr_path_to_seq2)
     assert str(mbr_num_repeats) == str(mbr_num_repeats2)
+    assert torch.allclose(mbr_new2old, mbr_new2old2)
 
     print('mbr', mbr_num_repeats)
     print(mbr_padded_embeddings.sum())
