This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

Implement n-best decoding without an LM. (#213)
csukuangfj authored Jun 15, 2021
1 parent ac77916 commit 187ae11
Showing 1 changed file with 148 additions and 4 deletions.
152 changes: 148 additions & 4 deletions egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
@@ -6,6 +6,38 @@
# 2021 University of Chinese Academy of Sciences (author: Han Zhu)
# Apache 2.0

# Usage of this script:
'''
# Without LM rescoring
## Use n-best decoding
./mmi_att_transformer_decode.py \
  --use-lm-rescoring=0 \
  --num-paths=100 \
  --max-duration=300

## Use 1-best decoding
./mmi_att_transformer_decode.py \
  --use-lm-rescoring=0 \
  --num-paths=1 \
  --max-duration=300

# With LM rescoring
## Use whole lattice
./mmi_att_transformer_decode.py \
  --use-lm-rescoring=1 \
  --num-paths=-1 \
  --max-duration=300

## Use n-best list
./mmi_att_transformer_decode.py \
  --use-lm-rescoring=1 \
  --num-paths=100 \
  --max-duration=300
'''

import argparse
import k2
import logging
@@ -41,6 +73,110 @@
from snowfall.training.mmi_graph import create_bigram_phone_lm
from snowfall.training.mmi_graph import get_phone_symbols

def nbest_decoding(lats: k2.Fsa, num_paths: int):
    '''
    (Ideas of this function are from Dan.)

    It implements something like CTC prefix beam search using n-best lists.

    The basic idea is to first extract n-best paths from the given lattice,
    build word sequences from these paths, and compute the total scores
    of these word sequences in the log-semiring. The one with the max score
    is used as the decoding output.
    '''
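    # Note: total scores are computed in the log-semiring below, so the
    # score of a word sequence sums over all lattice paths (alignments)
    # that produce it; that is why duplicate word sequences are merged
    # before scoring.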

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
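    # For illustration only (hypothetical values), with 2 seqs and
    # 2 paths per seq, paths could look like:
    #   [ [ [0 1 3] [0 2 3] ]  [ [4 5] [4 6] ] ]
    # where each innermost sublist holds arc indexes into `lats`.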

    # word_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    word_seqs = k2.index(lats.aux_labels, paths)
    # Note: the above operation supports also the case when
    # lats.aux_labels is a ragged tensor. In that case,
    # `remove_axis=True` is used inside the pybind11 binding code,
    # so the resulting `word_seqs` still has 3 axes, like `paths`.
    # The 3 axes are [seq][path][word]

    # Remove epsilons and -1 from word_seqs
    word_seqs = k2.ragged.remove_values_leq(word_seqs, 0)
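    # E.g. (hypothetical values) a sublist [5 0 0 9 -1] becomes [5 9]
    # after this step.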

    # Remove repeated sequences to avoid redundant computation later.
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seqs.num_elements()
    unique_word_seqs, _, new2old = k2.ragged.unique_sequences(
        word_seqs, need_num_repeats=False, need_new2old_indexes=True)
    # Note: unique_word_seqs still has the same axes as word_seqs
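    # E.g. (hypothetical values) if one seq of word_seqs is
    # [ [5 9] [5 9] [8] ], the corresponding seq of unique_word_seqs
    # keeps two paths, and new2old maps each kept path back to its
    # original path index, e.g. [0 2].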

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path
    # belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)
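    # E.g. (hypothetical values) if seq 0 has 2 unique paths and seq 1
    # has 1 unique path, then path_to_seq_map is [0 0 1].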

    # Remove the seq axis.
    # Now unique_word_seqs has only two axes [path][word]
    unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0)

    # word_fsas is an FsaVec with axes [path][state][arc]
    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)
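    # The epsilon self-loops are needed because k2.intersect_device does
    # not treat epsilons specially; the self-loops let word_fsas match
    # the epsilon (word ID 0) arcs present in the inverted lattice below.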

    # lats has phone IDs as labels and word IDs as aux_labels.
    # inv_lats has word IDs as labels and phone IDs as aux_labels
    inv_lats = k2.invert(lats)
    inv_lats = k2.arc_sort(inv_lats)  # no-op if inv_lats is already arc-sorted

    path_lats = k2.intersect_device(inv_lats,
                                    word_fsas_with_epsilon_loops,
                                    b_to_a_map=path_to_seq_map,
                                    sorted_match_a=True)
    # path_lats has word IDs as labels and phone IDs as aux_labels
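    # b_to_a_map assigns each word FSA (in the 2nd arg) to the lattice
    # in inv_lats it should be intersected with, so every unique path is
    # scored against the lattice of the seq it came from.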

    path_lats = k2.top_sort(k2.connect(path_lats.to('cpu')).to(lats.device))

    tot_scores = path_lats.get_tot_scores(use_double_scores=True,
                                          log_semiring=True)
    # RaggedFloat currently supports float32 only.
    # We may bind Ragged<double> as RaggedDouble if needed.
    ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
                                       tot_scores.to(torch.float32))

    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
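    # argmax_indexes contains linear indexes into the flattened scores,
    # e.g. (hypothetical values) [ [-5.2 -4.8] [-3.1] ] -> [1 2].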

    # Since we invoked `k2.ragged.unique_sequences`, which reorders
    # paths within a seq, we use `new2old` here to convert
    # argmax_indexes to indexes into `paths`.
    #
    # Use k2.index here since argmax_indexes' dtype is torch.int32
    best_path_indexes = k2.index(new2old, argmax_indexes)

    paths_2axes = k2.ragged.remove_axis(paths, 0)

    # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos]
    best_paths = k2.index(paths_2axes, best_path_indexes)

    # labels is a k2.RaggedInt with 2 axes [path][phone_id]
    # Note that it contains -1s.
    labels = k2.index(lats.labels.contiguous(), best_paths)

    labels = k2.ragged.remove_values_eq(labels, -1)

    # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so
    # aux_labels is also a k2.RaggedInt with 2 axes
    aux_labels = k2.index(lats.aux_labels, best_paths.values())

    best_path_fsas = k2.linear_fsa(labels)
    best_path_fsas.aux_labels = aux_labels

    return best_path_fsas


def decode_one_batch(batch: Dict[str, Any],
                     model: AcousticModel,
@@ -79,8 +215,7 @@ def decode_one_batch(batch: Dict[str, Any],
        If False and if `G` is not None, then `num_paths` must be positive
        and it will use n-best list for LM rescoring.
      num_paths:
-       Used only if `G` is not None and use_whole_lattice is False.
-       It specifies the size of n-best list for LM rescoring.
+       It specifies the size of `n` in n-best list decoding.
      G:
        The LM. If it is None, no rescoring is used.
        Otherwise, LM rescoring is used.
@@ -123,9 +258,14 @@ def decode_one_batch(batch: Dict[str, Any],
    lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, output_beam_size, 30, 10000)

    if G is None:
-       best_paths = k2.shortest_path(lattices, use_double_scores=True)
+       if num_paths > 1:
+           best_paths = nbest_decoding(lattices, num_paths)
+           key = f'no_rescore-{num_paths}'
+       else:
+           key = 'no_rescore'
+           best_paths = k2.shortest_path(lattices, use_double_scores=True)
        hyps = get_texts(best_paths, indices)
-       return {'no_rescore': hyps}
+       return {key: hyps}

    lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
@@ -405,6 +545,10 @@ def main():
    else:
        logging.debug('Decoding without LM rescoring')
        G = None
        if num_paths > 1:
            logging.debug(f'Use n-best list decoding, n is {num_paths}')
        else:
            logging.debug('Use 1-best decoding')

logging.debug("convert HLG to device")
HLG = HLG.to(device)
