diff --git a/egs/librispeech/asr/simple_v1/mmi_bigram_embeddings_decode.py b/egs/librispeech/asr/simple_v1/mmi_bigram_embeddings_decode.py
index 2f5da604..a0684c3c 100755
--- a/egs/librispeech/asr/simple_v1/mmi_bigram_embeddings_decode.py
+++ b/egs/librispeech/asr/simple_v1/mmi_bigram_embeddings_decode.py
@@ -16,6 +16,7 @@
 from lhotse import CutSet
 from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
+from lhotse.utils import fix_random_seed
 
 from snowfall.common import find_first_disambig_symbol
 from snowfall.common import get_texts
 from snowfall.common import invert_permutation
@@ -85,8 +86,7 @@ def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
             # lattices = k2.intersect_dense(LG, dense_fsa_vec, 10.0)
             best_paths = k2.shortest_path(lattices, use_double_scores=True)
         else:
-            # FIXME(fangjun): increase num_paths and fix `num_repeats` in rescore.py
-            paths = get_paths(lats=lattices, num_paths=1)
+            paths = get_paths(lats=lattices, num_paths=10)
             word_fsas, seq_to_path_shape = get_word_fsas(lattices, paths)
             replicated_lats = k2.index(lattices, seq_to_path_shape.row_ids(1))
             word_lats = k2.compose(replicated_lats,
@@ -184,10 +184,7 @@
                     float(num_cuts) / tot_num_cuts * 100))
 
         num_cuts += len(texts)
-
-        # FIXME(fangjun): remove it
-        if batch_idx == 50:
-            break
+        if batch_idx == 140: break
 
     return results
 
@@ -276,6 +273,8 @@ def main():
     else:
         setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')
 
+    fix_random_seed(42)
+
     logging.info(f'enable second pass model for decoding: {args.enable_second_pass_decoding}')
 
     # load L, G, symbol_table
diff --git a/egs/librispeech/asr/simple_v1/mmi_bigram_embeddings_train.py b/egs/librispeech/asr/simple_v1/mmi_bigram_embeddings_train.py
index 4651884d..93257e5b 100755
--- a/egs/librispeech/asr/simple_v1/mmi_bigram_embeddings_train.py
+++ b/egs/librispeech/asr/simple_v1/mmi_bigram_embeddings_train.py
@@ -162,7 +162,7 @@ def get_objf(batch: Dict,
         # Now compute the tot_scores for the second pass
         #
         # TODO(fangjun): We probably need to split it into a separate function
-        padded_embeddings, len_per_path, path_to_seq, num_repeats = compute_embeddings(
+        padded_embeddings, len_per_path, path_to_seq, num_repeats, _ = compute_embeddings(
            den_lats,
            graph_compiler.ctc_topo,
            dense_fsa_vec_2nd,
diff --git a/snowfall/common.py b/snowfall/common.py
index db49334d..e9c6a594 100755
--- a/snowfall/common.py
+++ b/snowfall/common.py
@@ -21,7 +21,7 @@ def setup_logger(log_filename: Pathlike, log_level: str = 'info',
                  use_console: bool = True) -> None:
     now = datetime.now()
     date_time = now.strftime('%Y-%m-%d-%H-%M-%S')
-    log_filename = '{}-{}'.format(log_filename, date_time)
+    log_filename = '{}-{}.txt'.format(log_filename, date_time)
     os.makedirs(os.path.dirname(log_filename), exist_ok=True)
     formatter = '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s'
     level = logging.ERROR
diff --git a/snowfall/decoding/rescore.py b/snowfall/decoding/rescore.py
index def5849e..80b30d5b 100644
--- a/snowfall/decoding/rescore.py
+++ b/snowfall/decoding/rescore.py
@@ -117,7 +117,9 @@
     # 0 <= path_to_seq[i] < num_seqs
     #
     # num_repeats is a k2.RaggedInt with two axes [seq][path_multiplicities]
-    padded_embeddings, len_per_path, path_to_seq, num_repeats = compute_embeddings_from_phone_seqs(
+    #
+    # CAUTION: Paths within a seq are reordered due to `k2.ragged.unique_sequences`.
+    padded_embeddings, len_per_path, path_to_seq, num_repeats, new2old = compute_embeddings_from_phone_seqs(
         phone_seqs=phone_seqs,
         ctc_topo=ctc_topo,
         dense_fsa_vec=dense_fsa_vec,
@@ -142,31 +144,36 @@
     indices2 = torch.argsort(len_per_path, descending=True)
     second_pass_supervision_segments = second_pass_supervision_segments[
         indices2]
-    # note that path_to_seq is not changed!
-    # no need to modify second_pass_out
+    # Note that path_to_seq is not changed!
+    # No need to modify second_pass_out
 
     num_repeats_float = k2.ragged.RaggedFloat(
         num_repeats.shape(),
         num_repeats.values().to(torch.float32))
     path_weight = k2.ragged.normalize_scores(num_repeats_float,
                                              use_log=False).values
-    path_weight = path_weight[indices2]
 
     second_pass_dense_fsa_vec = k2.DenseFsaVec(
         second_pass_out, second_pass_supervision_segments)
 
     second_pass_lattices = k2.intersect_dense_pruned(
-        decoding_graph, second_pass_dense_fsa_vec, 20.0, 7.0, 30, 10000)
+        decoding_graph, second_pass_dense_fsa_vec, 20.0, 7.0, 30, 20000)
+
+    # second_pass_lattices = k2.intersect_dense(decoding_graph,
+    #                                           second_pass_dense_fsa_vec, 10.0)
 
     tot_scores = second_pass_lattices.get_tot_scores(
         log_semiring=True,
         use_double_scores=use_double_scores)
 
     inverted_indices2 = invert_permutation(indices2)
     tot_scores_2nd = tot_scores[inverted_indices2]
-    # now tot_scores_2nd[i] corresponds to path_i
+    # Now tot_scores_2nd[i] corresponds to sorted_path_i
+    # `sorted` here is due to k2.ragged.unique_sequences.
+    # We have to use `new2old` to map it to the original unsorted path
 
-    # FIXME(fangjun): Handle the case when num_repeats contains entries > 1
-    tot_scores = tot_scores_1st + tot_scores_2nd
+    # Note that path_weight was not reordered
+    tot_scores = tot_scores_1st
+    tot_scores[new2old.long()] += tot_scores_2nd * path_weight
 
     ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
                                        tot_scores.to(torch.float32))
 
     argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
diff --git a/snowfall/training/compute_embeddings.py b/snowfall/training/compute_embeddings.py
index ad784aec..2b6847e8 100644
--- a/snowfall/training/compute_embeddings.py
+++ b/snowfall/training/compute_embeddings.py
@@ -73,7 +73,7 @@ def compute_expected_times(
         dense_fsa_vec: k2.DenseFsaVec,
         use_double_scores: bool = True,
         debug: bool = False
-) -> Tuple[torch.Tensor, k2.Fsa, torch.Tensor, k2.RaggedInt]:
+) -> Tuple[torch.Tensor, k2.Fsa, torch.Tensor, k2.RaggedInt, torch.Tensor]:
     '''
     Args:
       phone_seqs:
@@ -96,11 +96,13 @@
      - phone_fsas, an FsaVec with indexes [path][phones]
      - path_to_seq_map, 1-D torch.Tensor
      - num_repeats, a k2.RaggedInt with 2 axes [path][multiplicities]
+     - new2old, a 1-D torch.Tensor, see :func:`k2.ragged.unique_sequences`
     '''
     device = ctc_topo.device
 
     if phone_seqs.num_axes() == 3:
-        phone_seqs, num_repeats = k2.ragged.unique_sequences(phone_seqs, True)
+        phone_seqs, num_repeats, new2old = k2.ragged.unique_sequences(
+            phone_seqs, True, True)
 
         # Remove the 1st axis from `phone_seqs` (that corresponds to `seq`) and
         # keep it for later; we'll be processing paths separately.
@@ -128,6 +130,7 @@
         seq_to_path_shape = num_repeats_shape
 
     path_to_seq_map = seq_to_path_shape.row_ids(1)  # an identity map
+    new2old = path_to_seq_map
 
     # now compile decoding graphs corresponding to `phone_seqs` by constructing
     # fsas from them (remember we already have the final -1's!) and composing
@@ -242,7 +245,7 @@
     # TODO(fangjun): do we need to support `torch.int32` for the indexing
     expected_times[first_epsilon_offset[:-1].long()] = 0
 
-    return expected_times, phone_fsas, path_to_seq_map, num_repeats
+    return expected_times, phone_fsas, path_to_seq_map, num_repeats, new2old
 
 
 def compute_embeddings_from_nnet_output(expected_times: torch.Tensor,
@@ -358,8 +361,8 @@ def compute_embeddings_from_phone_seqs(
         dense_fsa_vec: k2.DenseFsaVec,
         max_phone_id: int,
         use_double_scores: bool = True,
-        debug: bool = True
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, k2.RaggedInt]:
+        debug: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.
+                                     Tensor, k2.RaggedInt, torch.Tensor]:
     '''
     Args:
       phone_seqs:
@@ -384,7 +387,7 @@
      - num_repeats, a ragged tensor of type k2.RaggedInt with 2 axes
        [path][multiplicities]
     '''
-    expected_times, phone_fsas, path_to_seq_map, num_repeats = compute_expected_times(  # noqa
+    expected_times, phone_fsas, path_to_seq_map, num_repeats, new2old = compute_expected_times(  # noqa
        phone_seqs=phone_seqs,
        ctc_topo=ctc_topo,
        dense_fsa_vec=dense_fsa_vec,
@@ -420,19 +423,19 @@
     padded_embeddings = torch.nn.utils.rnn.pad_sequence(embeddings_per_path,
                                                         batch_first=True)
 
-    return padded_embeddings.to(
-        torch.float32), len_per_path.cpu(), path_to_seq_map, num_repeats
+    return padded_embeddings.to(torch.float32), len_per_path.cpu(
+    ), path_to_seq_map, num_repeats, new2old
 
 
-def compute_embeddings(
-    lats: k2.Fsa,
-    ctc_topo: k2.Fsa,
-    dense_fsa_vec: k2.DenseFsaVec,
-    max_phone_id: int,
-    use_double_scores: bool = True,
-    num_paths: int = 3,
-    debug: bool = False
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, k2.RaggedInt]:
+def compute_embeddings(lats: k2.Fsa,
+                       ctc_topo: k2.Fsa,
+                       dense_fsa_vec: k2.DenseFsaVec,
+                       max_phone_id: int,
+                       use_double_scores: bool = True,
+                       num_paths: int = 3,
+                       debug: bool = False
+                       ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, k2.
+                                  RaggedInt, torch.Tensor]:
     '''
     Args:
       lats:
@@ -476,8 +479,8 @@ def compute_embeddings_deprecated(
         max_phone_id: int,
         use_double_scores=True,
         num_paths=100,
-        debug=False
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, k2.RaggedInt]:
+        debug=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, k2.
+                              RaggedInt, torch.Tensor]:
     '''Compute embeddings for an n-best list.
 
     See the following comments for more information:
@@ -508,6 +511,7 @@
        before padding
      - path_to_seq_map, its shape is (num_paths,)
      - num_repeats (k2.RaggedInt)
+     - new2old, a 1-D torch.Tensor, see :func:`k2.ragged.unique_sequences`
     '''
     device = lats.device
     assert len(lats.shape) == 3
@@ -541,7 +545,8 @@
 
     # Remove repeated sequences from `phone_seqs`
     #
-    phone_seqs, num_repeats = k2.ragged.unique_sequences(phone_seqs, True)
+    phone_seqs, num_repeats, new2old = k2.ragged.unique_sequences(
+        phone_seqs, True, True)
 
     # Remove the 1st axis from `phone_seqs` (that corresponds to `seq`) and
     # keep it for later, we'll be processing paths separately.
@@ -781,5 +786,5 @@ def compute_embeddings_deprecated(
 
     # It used `double` for `get_arc_post`, but the network input requires
     # torch.float32
-    return padded_embeddings.to(
-        torch.float32), len_per_path.cpu(), path_to_seq_map, num_repeats
+    return padded_embeddings.to(torch.float32), len_per_path.cpu(
+    ), path_to_seq_map, num_repeats, new2old
diff --git a/snowfall/training/test_compute_embeddings.py b/snowfall/training/test_compute_embeddings.py
index c55e97cc..66b680b9 100755
--- a/snowfall/training/test_compute_embeddings.py
+++ b/snowfall/training/test_compute_embeddings.py
@@ -89,7 +89,7 @@ def main():
     den_lats = k2.intersect_dense(den_graph, dense_fsa_vec, 10.0)
 
     print('-' * 10, 'den_lats', '-' * 10)
-    den_padded_embeddings, den_len_per_path, den_path_to_seq, den_num_repeats = compute_embeddings(
+    den_padded_embeddings, den_len_per_path, den_path_to_seq, den_num_repeats, den_new2old = compute_embeddings(
        den_lats,
        graph_compiler.ctc_topo,
        dense_fsa_vec,
@@ -116,7 +116,7 @@
 
     print('den', den_num_repeats)
 
-    den_padded_embeddings2, den_len_per_path2, den_path_to_seq2, den_num_repeats2 = compute_embeddings_deprecated(
+    den_padded_embeddings2, den_len_per_path2, den_path_to_seq2, den_num_repeats2, den_new2old2 = compute_embeddings_deprecated(
        den_lats,
        graph_compiler.ctc_topo,
        dense_fsa_vec,
@@ -128,9 +128,10 @@
     assert torch.allclose(den_len_per_path, den_len_per_path2)
     assert torch.allclose(den_path_to_seq, den_path_to_seq2)
     assert str(den_num_repeats) == str(den_num_repeats2)
+    assert torch.allclose(den_new2old, den_new2old2)
 
     print('-' * 10, 'mbr_lats', '-' * 10)
-    mbr_padded_embeddings, mbr_len_per_path, mbr_path_to_seq, mbr_num_repeats = compute_embeddings(
+    mbr_padded_embeddings, mbr_len_per_path, mbr_path_to_seq, mbr_num_repeats, mbr_new2old = compute_embeddings(
        mbr_lats,
        graph_compiler.ctc_topo,
        dense_fsa_vec,
@@ -157,7 +158,7 @@
     assert mbr_padded_embeddings.requires_grad is True
     assert mbr_padded_embeddings.dtype == torch.float32
 
-    mbr_padded_embeddings2, mbr_len_per_path2, mbr_path_to_seq2, mbr_num_repeats2 = compute_embeddings_deprecated(
+    mbr_padded_embeddings2, mbr_len_per_path2, mbr_path_to_seq2, mbr_num_repeats2, mbr_new2old2 = compute_embeddings_deprecated(
        mbr_lats,
        graph_compiler.ctc_topo,
        dense_fsa_vec,
@@ -168,6 +169,7 @@
     assert torch.allclose(mbr_len_per_path, mbr_len_per_path2)
     assert torch.allclose(mbr_path_to_seq, mbr_path_to_seq2)
     assert str(mbr_num_repeats) == str(mbr_num_repeats2)
+    assert torch.allclose(mbr_new2old, mbr_new2old2)
 
     print('mbr', mbr_num_repeats)
     print(mbr_padded_embeddings.sum())
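
Note on the score combination in rescore.py: the second-pass total scores exist only for the unique paths kept by k2.ragged.unique_sequences (which also reorders paths within a seq), so `new2old` is used to scatter them back onto the original path indexing before they are added, weighted by the normalized repeat counts, to the first-pass scores. The snippet below is a minimal standalone sketch of that arithmetic in plain PyTorch; all tensor values are invented for illustration, and the per-seq argmax that rescore.py performs with k2.RaggedFloat and k2.ragged.argmax_per_sublist is replaced here by a simple loop.

# Toy sketch of the rescoring combination (values are made up, not from the diff).
import torch

# 6 original paths (3 per seq); suppose paths 0 and 1 of seq 0 are duplicates,
# so unique_sequences keeps 5 unique paths.
tot_scores_1st = torch.tensor([7.0, 7.0, 5.0, 6.0, 4.0, 8.0])   # per original path
tot_scores_2nd = torch.tensor([2.0, 3.0, 1.0, 2.5, 0.5])        # per unique path
new2old = torch.tensor([0, 2, 3, 4, 5], dtype=torch.int32)       # unique path i -> original path
path_weight = torch.tensor([2 / 3, 1 / 3, 1 / 3, 1 / 3, 1 / 3])  # normalized num_repeats per seq

# Scatter the weighted second-pass scores back onto the original path indexing
# (rescore.py writes: tot_scores = tot_scores_1st; tot_scores[new2old.long()] += ...).
tot_scores = tot_scores_1st.clone()
tot_scores[new2old.long()] += tot_scores_2nd * path_weight

# Pick the best original path per seq; rescore.py does this with
# k2.RaggedFloat(seq_to_path_shape, ...) and k2.ragged.argmax_per_sublist.
seq_to_path = torch.tensor([0, 0, 0, 1, 1, 1])
for seq in range(2):
    idx = (seq_to_path == seq).nonzero(as_tuple=True)[0]
    best = idx[tot_scores[idx].argmax()]
    print(f'seq {seq}: best original path {int(best)}, score {float(tot_scores[best]):.3f}')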