Skip to content
This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

Commit

Permalink
add attribute phones to den_lats.
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Feb 23, 2021
1 parent fa0a0cd commit f6d3dcb
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
5 changes: 5 additions & 0 deletions snowfall/training/compute_expected_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ def compute_expected_times_per_phone(mbr_lats: k2.Fsa,
True, True).exp(),
min_col_index=0)

# TODO(fangjun): this check is for test, will remove it
sum_per_row = torch.sparse.sum(pathframe_to_pathphone, dim=1).to_dense()
expected_sum_per_row = torch.ones_like(sum_per_row)
assert torch.allclose(sum_per_row, expected_sum_per_row)

frame_idx = torch.arange(paths_shape.num_elements()) - k2.index(
path_starts, paths_shape.row_ids(1))

Expand Down
7 changes: 4 additions & 3 deletions snowfall/training/mmi_mbr_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,10 @@ def compile(self, texts: Iterable[str],
assert P.device == self.device
P_with_self_loops = k2.add_epsilon_self_loops(P)

ctc_topo_P = k2.intersect(self.ctc_topo_inv,
P_with_self_loops,
treat_epsilons_specially=False).invert()
ctc_topo_P = k2.compose(self.ctc_topo,
P_with_self_loops,
treat_epsilons_specially=False,
inner_labels='phones')
ctc_topo_P = k2.arc_sort(ctc_topo_P)

num_graphs = self.build_num_graphs(texts)
Expand Down

0 comments on commit f6d3dcb

Please sign in to comment.