This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

WIP: Compute expected times per pathphone_idx. #106

Open. Wants to merge 37 commits into base: master.

Conversation


@csukuangfj csukuangfj commented Feb 22, 2021

Closes #96

@danpovey Do you have any idea how to test the code? Also, I am not sure how the return value is used.

I am using mbr_lats instead of den_lats, as I find that the phone_seqs from mbr_lats contain more zeros than those from den_lats.

I don't quite understand the normalization step from #96 (comment), so it is not done in the code.

Multiply the posteriors by the frame_idx, then sum over the columns of the sparse matrix of posteriors to get the total (frame_idx * occupation_prob) for each seqphone_idx, and divide by the total occupation_prob for each seqphone_idx.
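The computation described above can be sketched in plain Python; the (pathphone_idx, frame_idx, weight) triples below are hypothetical stand-ins for the entries of the sparse posterior matrix, not the PR's actual data structures:

```python
# Sketch of the weighted-average time computation described above, using
# plain Python in place of k2's sparse matrices. Each entry is a
# (pathphone_idx, frame_idx, occupation_prob) triple (illustrative names).
from collections import defaultdict

def expected_times(entries):
    """entries: iterable of (pathphone_idx, frame_idx, occupation_prob)."""
    weighted = defaultdict(float)  # sum of frame_idx * occupation_prob
    total = defaultdict(float)     # sum of occupation_prob
    for phone, frame, weight in entries:
        weighted[phone] += frame * weight
        total[phone] += weight
    # Divide per pathphone_idx; the normalization is a no-op when the
    # per-row totals are bound to be 1.
    return {p: weighted[p] / total[p] for p in weighted}

entries = [(0, 10, 0.5), (0, 12, 0.5), (1, 20, 1.0)]
print(expected_times(entries))  # {0: 11.0, 1: 20.0}
```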

paths_lats.pathframe_idx = paths_lats.seqframe_idx + k2.index(
    path_offsets, paths_lats_arc2path)

pathframe_to_pathphone = k2.create_sparse(rows=paths_lats.pathframe_idx,
Collaborator Author

I am not sure whether this is correct.

Contributor

Can you perhaps print out paths_lats.pathphone_idx for the two decoding graphs and see if there is an obvious difference, e.g. perhaps one has a lot more zeros or -1's in it?

Collaborator Author

@danpovey
Contributor

Great! Let me look at this tomorrow.
The normalization step should not be needed if you used the 'phones' attribute with no repeats to create the sparse matrix.
Instead of sum(weight * frame_index), it would be something like sum(weight * frame_index) / sum(weight) if there were normalization. It is only necessary if the denominator is not bound to be 1.

@danpovey
Contributor

BTW for testing: just making sure that the computed average times are monotonic and not greater than the length of the utterances would be a good start.
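Dan's suggested sanity check could look like the following sketch; `times` and `num_frames` are hypothetical names, not part of the PR:

```python
# A minimal sanity check along the lines suggested above: the expected
# times for one sequence should be (roughly) non-decreasing and should
# never exceed the number of frames in the utterance.
def check_expected_times(times, num_frames, tol=0.0):
    assert all(0.0 <= t <= num_frames for t in times), "time out of range"
    for prev, cur in zip(times, times[1:]):
        assert cur >= prev - tol, f"not monotonic: {prev} -> {cur}"

# A well-behaved sequence passes silently.
check_expected_times([0.5, 3.2, 7.9, 100.0], num_frames=100)
```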

@danpovey
Contributor

And for how the times are used:

  • We'll probably use the durations as part of an embedding.
  • The times themselves will be used to create embeddings from the output of the 1st-pass network, by interpolating between adjacent frames (whichever are closest to the average times).
  • We'll need to compute times for the positions between the phones (and between the first and last phones, and the start/end of the file), too. The positions/blanks between phones will also have frames corresponding to them.
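The interpolation idea in the second bullet can be sketched as follows; plain Python lists stand in for the first-pass network output, and all names are illustrative:

```python
# Sketch of interpolating between the two adjacent first-pass output
# frames to embed a phone at a fractional expected time t.
# `frames` is a hypothetical list of per-frame vectors.
def embed_at_time(frames, t):
    lo = min(int(t), len(frames) - 1)   # frame just below t
    hi = min(lo + 1, len(frames) - 1)   # frame just above t
    frac = t - lo                       # interpolation weight
    return [(1.0 - frac) * a + frac * b for a, b in zip(frames[lo], frames[hi])]

frames = [[0.0, 0.0], [1.0, 2.0], [2.0, 4.0]]
print(embed_at_time(frames, 1.5))  # [1.5, 3.0]
```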

@csukuangfj
Collaborator Author

I just checked the return value, shown below.

Assume there are two sequences and each has 1000 frames.

It is not strictly monotonic for the first 50 pathphone_idx. The last entry is 1000, which is equal to the number of input frames
in the sequence.

[Screenshot: Screen Shot 2021-02-23 at 11 19 17 AM]

@danpovey
Contributor

I think this might be due to normalization problems. Can you print the row-sums of the weights to check that they sum to 1.0?

@csukuangfj
Collaborator Author

can you print the row-sums of the weights to check that they sum to 1.0?

An assertion is added to check this and it passes.

sum_per_row = torch.sparse.sum(pathframe_to_pathphone, dim=1).to_dense()
expected_sum_per_row = torch.ones_like(sum_per_row)
assert torch.allclose(sum_per_row, expected_sum_per_row)

If I replace mbr_lats with den_lats, then the first 50 pathphone_idx's seem to be monotonic. Please see the results below. (NOTE: It uses different random data from the previous screenshot.)

[Screenshot: Screen Shot 2021-02-23 at 2 20 50 PM]

@csukuangfj
Collaborator Author

Will replace the expected times for the even pathphone_idxs, which belong to the epsilon self-loops, with the average expected times of the two neighboring phones.
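That replacement could be sketched like this, under the assumption that even indices are the epsilon self-loop positions; the function name and boundary handling are illustrative, not the PR's actual code:

```python
# Sketch: even positions (epsilon self-loops) get the average of the two
# neighboring phones' expected times; boundary epsilons fall back to
# their single neighbor.
def smooth_epsilon_times(times):
    out = list(times)
    for i in range(0, len(times), 2):  # even indices are epsilons
        left = times[i - 1] if i > 0 else None
        right = times[i + 1] if i + 1 < len(times) else None
        if left is not None and right is not None:
            out[i] = 0.5 * (left + right)
        elif right is not None:
            out[i] = right
        elif left is not None:
            out[i] = left
    return out

print(smooth_epsilon_times([0.0, 4.0, 9.0, 10.0, 0.0]))
# [4.0, 4.0, 7.0, 10.0, 10.0]
```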

@danpovey
Contributor

RE "If I replace mbr_lats with den_lats, then the first 50 pathphone_idx's also seem to be monotonic"... for me this is a bit strange. I think we should try to find what is the difference between the two graphs that causes this difference.
There should be no constraint that an epsilon between phones should always be taken, so IMO it should be possible that for the epsilon positions, we would sometimes get total-counts that are less than one. Maybe not for every minibatch, but in general. It's very strange if not.

@csukuangfj
Collaborator Author

we would sometimes get total-counts that are less than one

Is total-counts the same as total-occupation?

@csukuangfj
Collaborator Author

The total_occupation[:50] values for den_lats and mbr_lats are given below:

[Screenshot: Screen Shot 2021-02-23 at 3 26 37 PM]

[Screenshot: Screen Shot 2021-02-23 at 3 26 56 PM]

@csukuangfj
Collaborator Author

csukuangfj commented Feb 23, 2021

Differences between den_lats and mbr_lats

phones

  • The phones attribute of den_graph is from k2.compose(ctc_topo, P_with_self_loops)
  • The phones attribute of decoding_graph is from k2.compose(ctc_topo, LG)

intersect

  • den_lats is generated from k2.intersect_dense(den_graph, dense_fsa_vec, 10.0)
  • mbr_lats is generated from
mbr_lats = k2.intersect_dense_pruned(decoding_graph,
                                     dense_fsa_vec,
                                     20.0,
                                     7.0,
                                     30,
                                     10000,
                                     seqframe_idx_name='seqframe_idx')

# TODO(fangjun): remove print
print('total_occupation[:50]\n', total_occupation[:50])

expected_times = weighted_occupation.squeeze() / total_occupation
Contributor

be aware that division by zero is a possibility here, for epsilons.

Collaborator Author

Thanks, will fix it by adding an EPS to the denominator.
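A minimal sketch of that guard, assuming a small constant EPS is acceptable for positions whose total occupation is zero (names illustrative):

```python
# Guard the division by adding a tiny EPS to the denominator, so epsilon
# positions with zero total occupation do not blow up.
EPS = 1e-10

def safe_expected_time(weighted_occupation, total_occupation):
    return weighted_occupation / (total_occupation + EPS)

print(safe_expected_time(0.0, 0.0))   # 0.0 rather than a division by zero
print(safe_expected_time(11.0, 1.0))  # ~11.0; EPS barely perturbs the result
```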

@danpovey
Contributor

I don't see, in the current code at least in my branch I'm working on, anywhere where the 'phones' attribute is set in den_graph.

ctc_topo_P = k2.intersect(self.ctc_topo_inv,
                          P_with_self_loops,
                          treat_epsilons_specially=False).invert()
ctc_topo_P = k2.compose(self.ctc_topo,
Collaborator Author

I don't see, in the current code at least in my branch I'm working on, anywhere where the 'phones' attribute is set in den_graph.

@danpovey

It is set here.

Collaborator Author

@csukuangfj csukuangfj Feb 23, 2021

The latest k2 is needed.
k2-fsa/k2#670 prevents overwriting num_graph's phones attribute.

Contributor

Looking into it, I suspect something was reversed somewhere.

phone_seqs = k2.index(lats.phones, paths)

# Remove epsilons from `phone_seqs`
print('before removing 0', phone_seqs.shape().row_splits(2))
Collaborator Author

@csukuangfj csukuangfj Feb 23, 2021

It shows that phone_seqs from mbr_lats contains more 0s, which are removed.

[Screenshot: Screen Shot 2021-02-23 at 4 14 40 PM]

[Screenshot: Screen Shot 2021-02-23 at 4 14 54 PM]

Contributor

can you verify that lats.phones is a Tensor and not a _k2.RaggedInt, in both cases?

Collaborator Author

Yes, I confirm that they are both 1-D tensors.

num = k2.compose(ctc_topo_P,
num_graphs_with_self_loops,
treat_epsilons_specially=False,
inner_labels='phones')
treat_epsilons_specially=False)
Collaborator Author

Removing inner_labels here so that num inherits the phones attribute from
ctc_topo_P. But it does not affect the result since it uses only mbr_lats and den_lats.

num_graphs_with_self_loops,
treat_epsilons_specially=False,
inner_labels='phones')
print('num2.phones\n', num2.phones)
Collaborator Author

@csukuangfj csukuangfj Feb 23, 2021

@danpovey

The outputs are

ctc_topo_P.phones
 tensor([ 0,  1,  2,  3,  0,  0,  2,  3, -1,  0,  0,  1,  3, -1,  0,  0,  1,  2,
        -1,  0,  1,  2,  3, -1,  0,  1,  2,  3, -1,  0,  1,  2,  3, -1],
       dtype=torch.int32)
num1.phones
 tensor([ 0,  3,  0,  0,  1,  2,  0,  1,  2,  0,  0, -1,  0,  0,  1,  0, -1,  0,
         1], dtype=torch.int32)
num2.phones
 tensor([ 0,  3,  0,  0,  1,  2,  0,  1,  2,  0,  0, -1,  0,  0,  1,  0, -1,  0,
         1], dtype=torch.int32)
decoding_graph.phones
 tensor([ 0,  3,  0,  0,  1,  2,  0,  1,  2,  0,  0, -1,  0,  0,  1,  0, -1,  0,
         1], dtype=torch.int32)

ctc_topo_P.phones is equivalent to den_graph.phones.

mbr_lats.phones is from decoding_graph.phones, though there is no G here.


It shows that ctc_topo_P has more phones than num.

Although both ctc_topo_P.phones and num1.phones have the same
number of 0s, i.e., 10, num1.phones has a higher percentage of 0s
since it has fewer phones.

Also, ctc_topo_P.phones contains more -1s.

@@ -0,0 +1,97 @@
#!/usr/bin/env python3
Collaborator Author

@csukuangfj csukuangfj Feb 23, 2021

@danpovey
Here is the test script. I cannot find any problems in the script.

According to the output, there are no repeated phones in the phone_seqs associated with the n-best paths for both mbr_lats and den_lats.

We can see from the output that a significant portion of the entries in the phone_seqs from mbr_lats are 0s.


# paths will be k2.RaggedInt with 3 axes: [seq][path][arc_pos],
# containing arc_idx012
paths = k2.random_paths(lats,
Collaborator Author

I just found that paths is empty when lats is mbr_lats and device is cuda.

Will look into it.

Contributor

OK. BTW did you already look at the decoding graph that you used to generate den_lats, and make sure that it has more nonzero labels than nonzero phones?

Collaborator Author

Yes, I've added two print statements in this pull-request:

den = k2.index_fsa(ctc_topo_P_vec, indexes)
print('den.phones', den.phones.shape, 'nnz',
      torch.count_nonzero(den.phones))
print('den.labels', den.labels.shape, 'nnz',
      torch.count_nonzero(den.labels))

and the output is

den.phones torch.Size([30446]) nnz tensor(29928)
den.labels torch.Size([30446]) nnz tensor(30100)

den.phones contains more 0s than den.labels.

@csukuangfj
Collaborator Author

Here are the WERs with/without rescoring.

num_paths=10 is used to sample 10 paths from the first pass lattice and rescore them with the second pass model.
A larger value of num_paths results in a failure in k2.intersect_dense_pruned due to limited bits in the hash implementation.

Now the WER of the second pass with rescoring is comparable with that of the first pass.

## Without second pass in decoding
2021-03-17 18:30:33,065 INFO [mmi_bigram_embeddings_decode.py:393] %WER 12.29% [4426 / 36021, 718 ins, 433 del, 3275 sub ]

## With second pass in decoding (Use rescoring)
2021-03-17 18:38:21,445 INFO [mmi_bigram_embeddings_decode.py:393] %WER 12.39% [4462 / 36021, 732 ins, 410 del, 3320 sub ]

## With second pass in decoding (NO rescoring)
2021-03-17 19:23:30,956 INFO [mmi_bigram_embeddings_decode.py:390] %WER 17.21% [6199 / 36021, 739 ins, 854 del, 4606 sub ]

Only the first 140 batches are used to compute the WER, because the following assertion from k2 fails:
https://github.com/k2-fsa/k2/blob/171ddc17509c0de5e9f7dcc4efeed9c712830233/k2/csrc/fsa_utils.cu#L699

K2_CHECK_GT(dest_state, state_idx01);

The following describes the headings in the WER results above.

  • Without second pass in decoding

    This uses only the first pass model and calls k2.shortest_path to decode.

  • With second pass in decoding (Use rescoring)

    This samples 10 paths from the first pass decoding lattice using k2.random_path and rescores them with
    the second pass model. The sampled path with the highest tot_score after rescoring is used for decoding.

  • With second pass in decoding (NO rescoring)

    It passes the best path from the first pass (computed with k2.shortest_path) to the second pass model
    and gets another best path from the second pass lattice with k2.shortest_path for decoding.


second_pass_dense_fsa_vec = k2.DenseFsaVec(
second_pass_out, second_pass_supervision_segments)

second_pass_lattices = k2.intersect_dense_pruned(
decoding_graph, second_pass_dense_fsa_vec, 20.0, 7.0, 30, 10000)
decoding_graph, second_pass_dense_fsa_vec, 20.0, 7.0, 30, 20000)
Collaborator Author

This line SOMETIMES fails the following check
https://github.com/k2-fsa/k2/blob/171ddc17509c0de5e9f7dcc4efeed9c712830233/k2/csrc/fsa_utils.cu#L699

K2_CHECK_GT(dest_state, state_idx01);

I cannot find the reason since it does not always happen. @danpovey Do you have any suggestions?

Contributor

That is essentially asserting that the FSA is top-sorted and acyclic. We'd have to consider where the FSA came from... what was the call stack?

Collaborator Author

Will try to get the call stack.

@danpovey
Contributor

OK, good... in the case "With second pass in decoding (Use rescoring)", it would probably make the most sense to have the total score be some linear combination of two things:
(i) the total probability of that path in the 1st-pass lattice
(ii) the posterior of that path in the 2nd-pass lattice.

(i) can be represented as the tot-score of (that path intersected with the 1st-pass lattice).
(ii) can be represented as a difference of tot-scores, namely: tot score of (that path intersected with its own 2nd-pass lattice) minus the tot-score of (that path's entire 2nd-pass lattice).
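The proposed combination can be sketched as follows; the variable names and the 0.5/0.5 weights are assumptions for illustration, not the PR's final choice:

```python
# Sketch of the scoring rule proposed above. tot_1st is the path's
# tot-score in the 1st-pass lattice; tot_2nd_path and tot_2nd_all are the
# tot-scores of (path intersected with its own 2nd-pass lattice) and of
# the entire 2nd-pass lattice, respectively.
def combined_score(tot_1st, tot_2nd_path, tot_2nd_all, alpha=0.5, beta=0.5):
    posterior_2nd = tot_2nd_path - tot_2nd_all  # (ii): log-posterior of the path
    return alpha * tot_1st + beta * posterior_2nd

print(combined_score(-120.0, -350.0, -340.0))  # 0.5*(-120) + 0.5*(-10) = -65.0
```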

@danpovey
Contributor

Oh, and let me know what the issue was with the hash. The size of that hash is extremely large, so it is likely a bug in the code somewhere rather than a collision.

@csukuangfj
Collaborator Author

Oh, and let me know what the issue was with the hash. The size of that hash is extremely large, so it is likely a bug in the code somewhere rather than a collision.

A larger value of num_paths will trigger the following check (I've reduced max_frames):
https://github.com/k2-fsa/k2/blob/171ddc17509c0de5e9f7dcc4efeed9c712830233/k2/csrc/intersect_dense_pruned.cu#L697

      K2_CHECK_EQ(cur_frame->arcs.NumElements() >> shift, 0) <<
          "Too many arcs to store in hash; try smaller NUM_KEY_BITS (would "
          "require code change) or reduce max_states or minibatch size.";

@danpovey
Contributor

OK, this is a different part of the code from what I had in mind. I want to see whether it's in the branch using 40 or 32 bits.

@danpovey
Contributor

It may actually be visible from the printed log, if it prints template args.

@csukuangfj
Collaborator Author

(ii) can be represented as a difference of tot-scores, namely: tot score of (that path intersected with its own 2nd-pass lattice) minus the tot-score of (that path's entire 2nd-pass lattice).

Thanks, will implement it.

word_lats = k2.compose(replicated_lats,
                       word_fsas,
                       treat_epsilons_specially=False)
tot_scores_1st = word_lats.get_tot_scores(use_double_scores=True,
Collaborator Author

From #106 (comment)

It is this line that sometimes causes the check failure from https://github.com/k2-fsa/k2/blob/171ddc17509c0de5e9f7dcc4efeed9c712830233/k2/csrc/fsa_utils.cu#L699

K2_CHECK_GT(dest_state, state_idx01);

Here is the call stack from the Python side

Traceback (most recent call last):
  File "./mmi_bigram_embeddings_decode.py", line 404, in <module>
    main()
  File "./mmi_bigram_embeddings_decode.py", line 370, in main
    results = decode(dataloader=test_dl,
  File "./mmi_bigram_embeddings_decode.py", line 97, in decode
    tot_scores_1st = word_lats.get_tot_scores(use_double_scores=True,
  File "/root/fangjun/open-source/k2/k2/python/k2/fsa.py", line 598, in get_tot_scores
    tot_scores = k2.autograd._GetTotScoresFunction.apply(
  File "/root/fangjun/open-source/k2/k2/python/k2/autograd.py", line 49, in forward
    tot_scores = fsas._get_tot_scores(use_double_scores=use_double_scores,
  File "/root/fangjun/open-source/k2/k2/python/k2/fsa.py", line 577, in _get_tot_scores
    forward_scores = self._get_forward_scores(use_double_scores,
  File "/root/fangjun/open-source/k2/k2/python/k2/fsa.py", line 526, in _get_forward_scores
    state_batches=self._get_state_batches(),
  File "/root/fangjun/open-source/k2/k2/python/k2/fsa.py", line 433, in _get_state_batches
    cache[name] = _k2.get_state_batches(self.arcs, transpose=True)

Regarding

That is essentially asserting that the FSA is top-sorted and acyclic. We'd have to consider where the FSA came from

The FsaVec word_lats is from k2.compose(first_pass_lats, word_fsas),
where word_fsas is from

word_fsas = k2.linear_fsa(word_seqs)
word_fsas_with_epsilons = k2.add_epsilon_self_loops(word_fsas)
return word_fsas_with_epsilons, seq_to_path_shape

After printing the properties of the problematic word_lats, it shows

"Valid|Nonempty|MaybeAccessible"

which is different from the normal word_lats with the following properties

"Valid|Nonempty|TopSorted|TopSortedAndAcyclic|MaybeAccessible"

I am looking into it.

@danpovey
Contributor

OK. We should have caught that error earlier, at the point when it was clear that the properties were not as expected. Perhaps the properties should have been passed into the C++ function, or were passed and were not checked.
I think that composition algorithm can generate non-top-sorted output; we may need to manually top-sort it afterward.

@csukuangfj
Collaborator Author

After top-sorting the word_lats, it's able to decode the whole dataset. The WERs are listed below.
You can see that the WER of the second pass with rescoring is comparable with that of the first pass.

## Without second pass in decoding
2021-03-18 13:06:04,988 INFO [mmi_bigram_embeddings_decode.py:394] %WER 12.16% [6394 / 52576, 1072 ins, 625 del, 4697 sub ]

## With second pass in decoding (Use rescoring)
2021-03-18 13:37:36,049 INFO [mmi_bigram_embeddings_decode.py:396] %WER 12.21% [6418 / 52576, 1103 ins, 579 del, 4736 sub ]

## With second pass in decoding (NO rescoring)
2021-03-18 13:21:22,276 INFO [mmi_bigram_embeddings_decode.py:394] %WER 17.01% [8943 / 52576, 1068 ins, 1181 del, 6694 sub ]

Trying to implement the following:

OK, good... in the case "With second pass in decoding (Use rescoring)", it would probably make the most sense to have the total score be some linear combination of two things:

@danpovey
Contributor

danpovey commented Mar 18, 2021 via email

@csukuangfj
Collaborator Author

(ii) can be represented as a difference of tot-scores, namely: tot score of (that path intersected with its own 2nd-pass lattice) minus the tot-score of (that path's entire 2nd-pass lattice).

For "path intersected with its own 2nd-pass lattice", I am using k2.compose(second_pass_lats, word_fsas),
and sometimes the tot_score of some Fsa in the resulting FsaVec is -inf. Moreover, the tot_scores of all paths belonging to a seq may all be -inf; argmax_per_sublist returns -1 in this case. My current approach is to always select
the first path in the case of -1.

I've double-checked that there are no empty Fsas and the sum of all scores is neither -inf nor nan.
Don't know why tot_scores would return -inf in this case.

@danpovey
Contributor

(ii) can be represented as a difference of tot-scores, namely: tot score of (that path intersected with its own 2nd-pass lattice) minus the tot-score of (that path's entire 2nd-pass lattice).

For path intersected with its own 2nd-pass lattice, I am using k2.compose(second_pass_lats, word_fsas)
and sometimes the tot_score of some Fsa in the resulting FsaVec is -inf.

Since the 2nd-pass lattice is generated with pruning, I suppose this is expected to happen sometimes.

Moreover, the tot_scores of all paths belonging to a seq may all be -inf; argmax_per_sublist returns -1 in this case. My current approach is to always select
the first path in the case of -1.

OK... so for each path, its probability for "itself" is zero because of pruning. This seems a bit unusual. It should be possible to print out the "diagonal" probability sequence, that is, at each position, the probability of the reference label. I'm curious whether a particular position in that sequence is extremely low, e.g. at the start or end?

I've double-checked that there are no empty Fsas and the sum of all scores is neither -inf nor nan.
Don't know why tot_scores would return -inf in this case.

@danpovey
Contributor

By "the probability of the reference label", I mean something that it should be possible to look up in the DenseFsaVec or its associated Tensor containing the scores.

@csukuangfj
Collaborator Author

It should be possible to print out the "diagonal" probability sequence, that is, at each position, the probability of the reference label.

Do you mean to print out a 2-d matrix, in which the rows represent frame_id, the cols represent phones, and the matrix
contains nnet_out_2nd_pass[a_given_path_id, frame_id, phones]? Here the phones are limited to those that appear on the path.
For example, if the path consists of phone_seq [c, a, t], then the cols should be [blank, c, a, t].

@danpovey
Contributor

danpovey commented Mar 19, 2021 via email

import torch


def get_log_probs(phone_fsas: k2.Fsa, nnet_output: torch.Tensor,
Collaborator Author

@danpovey

If I understand correctly, this is how to compute the log-prob for the reference input phone labels. Correct me if I am wrong.

The input phone_fsas contains epsilon self-loops and nnet_output
is the output of the second pass model after log-softmax.

@@ -191,10 +202,31 @@ def rescore(lats: k2.Fsa,
tot_scores_2nd_num = reorded_lats.get_tot_scores(
use_double_scores=True, log_semiring=True)

for k in [0, 1, 2, 30, 40, 50]:
Collaborator Author

Output log of this for loop is attached below:

log-decode-second-2021-03-19-21-25-09.txt

Part of it is listed below:

2021-03-19 21:25:30,611 INFO [rescore.py:209] 
path: 0
tot_scores: -inf
log_probs:[ [ -86.4828 -2.01945 -11.1086 -133.909 -8.51447 -4.66334  .....  -6.37545 -2.73227 -45.2354 -5.78217 ] ]

2021-03-19 21:25:30,612 INFO [rescore.py:209] 
path: 1
tot_scores: -inf
log_probs:[ [ -87.2116 -1.99367 -8.34123 -134.502 -8.31313 ....   -2.12719 -6.3754 -2.41306 -27.2158 -5.17843 ] ]

2021-03-19 21:25:30,612 INFO [rescore.py:209] 
path: 2
tot_scores: -inf
log_probs:[ [ -87.5116 -2.04247 -8.2759 -134.616 -8.62277 ..... -5.99715 -2.10805 -5.69009 -1.49818 -44.5129 -5.55875 ] ]

2021-03-19 21:25:30,613 INFO [rescore.py:209] 
path: 30
tot_scores: -358.8836602834303
log_probs:[ [ -166.046 -3.19678 -5.97864 -219.513 -6.48385  ...  -1.07081 -4.60576 -2.2935 -0.218605 -5.64143 ] ]

2021-03-19 21:25:30,613 INFO [rescore.py:209] 
path: 40
tot_scores: -401.60242305657346
log_probs:[ [ -112.147 -2.33504 -3824.3 -177.408 -9.14423 -3678.67 -6.93939 .... -2.56483 -1.64079 -4.76889 ] ]

2021-03-19 21:25:30,613 INFO [rescore.py:209] 
path: 50
tot_scores: -398.0919782588134
log_probs:[ [ -138.47 -3.37409 -5.20867 -219.942 -7.50989 -1.4931  ....   -5.48529 -2.04127 -2.47733 -4.61414 ] ]

I am not sure whether these log-probs look reasonable.

Contributor

The numbers I was expecting would be quite close to zero, and even closer at odd positions (or maybe even, i.e. where there are epsilons). I.e., I mean the posterior of the "reference phone" at each position (it's not really the reference; it's the sequence we use for alignment).

Contributor

I was expecting you'd get it by indexing second_pass_dense_fsa_vec with some kind of tensor that's related to the reference phones. Or perhaps you could just take the sum over a particular axis, of (second_pass_dense_fsa_vec * phone_one_hot_input).
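The one-hot suggestion can be sketched in plain Python, with nested lists standing in for the DenseFsaVec's score tensor; all names here are illustrative:

```python
# Sketch of the second suggestion above: multiply the per-frame score
# matrix by a one-hot encoding of the reference phone and sum over the
# symbol axis, yielding the score of the reference label at each frame.
def reference_scores(scores, ref_labels):
    """scores: [num_frames][num_symbols]; ref_labels: [num_frames]."""
    out = []
    for row, label in zip(scores, ref_labels):
        one_hot = [1.0 if i == label else 0.0 for i in range(len(row))]
        out.append(sum(s * h for s, h in zip(row, one_hot)))
    return out

scores = [[-0.1, -2.3, -4.0], [-3.0, -0.2, -1.5]]
print(reference_scores(scores, [0, 1]))  # [-0.1, -0.2]
```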

for idx, row in enumerate(this_fsa_nnet_output):
    if idx >= len_this_fsa:
        break
    this_prob.append(row[labels[idx]].item())
Contributor

Should there be a +1 here? I thought there was a shift, to map -1 to 0.

Contributor

Also, I'm not sure if we may have a problem at the start of the sequence, due to there being no epsilon there.
I'm not sure where phone_fsas comes from, i.e. whether there were epsilons between each phone before adding the epsilon self-loops (I assume not), and whether there are epsilons at the start and/or the end in the sequence given as input to the network.

Collaborator Author

Should there be a +1 here? I thought there was a shift, to map -1 to 0.

The input to the second pass model shifts -1 (i.e., EOS) to 0. The output of the second pass model contains
only blank + phone_ids, no EOS.

I should have skipped the last label -1 of every path since there is no corresponding output for it.

Collaborator Author

I'm not sure where phone_fsas comes from, i.e. whether there were epsilons between each phone before adding the epsilon self-loops (I assume not)

The phone_fsas is created following k2-fsa/k2#641 (comment)
For example, if the phone_seqs is [c, a, t], then phone_fsas is [0, c, 0, a, 0, t, 0, -1].

Before adding the epsilon self-loop, there were no epsilons between phones. It follows
the comment in k2-fsa/k2#641 (comment)

     phone_seqs = k2.index(mbr_lats.phones, paths)
     # Remove epsilons from `phone_seqs`
     phone_seqs = k2.ragged.remove_values_eq(phone_seqs, 0)

and whether there are epsilons at the start and/or the end in the sequence given as input to the network.

For the second pass network, there is an epsilon at the start of each sequence and there is an EOS at
the end of each sequence.
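The construction described above (for phone_seqs [c, a, t], the phone_fsas symbols are [0, c, 0, a, 0, t, 0, -1]) can be sketched as follows; this mirrors the example, not the PR's actual FSA-building code:

```python
# Sketch: interleave epsilons (0) around the phones and terminate with
# -1 (the final-arc symbol), as in the [0, c, 0, a, 0, t, 0, -1] example.
def interleave_epsilons(phones):
    out = [0]
    for p in phones:
        out.extend([p, 0])
    out.append(-1)
    return out

C, A, T = 3, 1, 20  # hypothetical phone ids
print(interleave_epsilons([C, A, T]))  # [0, 3, 0, 1, 0, 20, 0, -1]
```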

Contributor

Mm, there must be some shift for which those probabilities are all close to zero. Perhaps if you print out the best-path label sequences, including epsilons, from the second pass, it will be clear what it is doing, e.g. does it look like [0 nonzero 0 nonzero ...]?

Collaborator Author

does it look like [ 0 nonzero 0 nonzero ... ?]

Thanks, will check it.

@danpovey
Contributor

danpovey commented Mar 19, 2021 via email

@csukuangfj
Collaborator Author

But it's possible that the network may end up
learning a shift, i.e. left by one or right by one.

Shall we discard EOS for the second pass model, as there is no such symbol in the first pass network?
I found that the current second pass model has about 1.38 times as many parameters as the first pass (16193087/11676242=1.38). I suspect that there is not enough data to train the second pass network.

@danpovey
Contributor

BTW, since it seems this is hard to get to work, if you feel like it you could work on a simpler idea.
In training time we'd take the best-path alignment and using some kind of RNN or masked attention
we'd predict the next label in that best-path. (We'd probably take in the output of the 1st network
as an input to that). The sequence length here is the same as the original sequence length.
In test time the way this would work at least initially, is we'd run this on n-best lists obtained from the
1st-pass decoding and use the scores to decide which of the n-best paths to keep. There are more accurate
decoding methods we could look into later.

@csukuangfj
Collaborator Author

I've made the following changes to the second pass network:
(1) Replace expected_time with duration in the embeddings
(2) Remove EOS in embeddings
(3) Reduce the model size of the second pass network (# parameters: from 16193087 to 361943). Its left/right context is also increased from 8 to 40. The training/valid objf is lower than before and the training is still ongoing.


if you feel like it you could work on a simpler idea.

Yes, I would like to do it. Looking into it.

Successfully merging this pull request may close these issues:

getting times per phone in n-best list