Add merge_tokens / TokenSpan (#3535)
Summary:
This commit adds `merge_tokens` function which removes repeated tokens from CTC token sequences returned from `forced_align`.

Resolving repeated tokens is a necessary and nearly universal post-processing step, so it makes sense to have such a helper function in torchaudio.
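
A minimal sketch of the intended usage (the token IDs and scores below are made up for illustration; blank index 0 matches the function's default):

    import torch
    import torchaudio.functional as F

    # Hypothetical unbatched CTC alignment: blank is 0; tokens 1 and 2 each
    # repeat over consecutive frames and are merged into single spans.
    tokens = torch.tensor([1, 1, 0, 2, 2, 2, 0, 0, 3])
    scores = torch.tensor([0.9, 0.8, 0.1, 0.7, 0.6, 0.5, 0.1, 0.1, 0.95])

    spans = F.merge_tokens(tokens, scores, blank=0)
    # Expected result (scores are averaged over each span, up to float precision):
    # [TokenSpan(token=1, start=0, end=2, score=0.85),
    #  TokenSpan(token=2, start=3, end=6, score=0.6),
    #  TokenSpan(token=3, start=8, end=9, score=0.95)]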

Pull Request resolved: #3535

Reviewed By: huangruizhe

Differential Revision: D48111202

Pulled By: mthrok

fbshipit-source-id: 25354bfa210aa5c03f8c1d3e201f253ca3761b24
mthrok authored and facebook-github-bot committed Aug 7, 2023
1 parent cd80976 commit 30668af
Showing 4 changed files with 136 additions and 0 deletions.
9 changes: 9 additions & 0 deletions docs/source/functional.rst
@@ -32,7 +32,16 @@ Utility
   preemphasis
   deemphasis
   speed

Forced Alignment
----------------

.. autosummary::
   :toctree: generated
   :nosignatures:

   forced_align
   merge_tokens
   TokenSpan


Filtering
62 changes: 62 additions & 0 deletions test/torchaudio_unittest/functional/functional_impl.py
@@ -1220,6 +1220,68 @@ def test_forced_align_fail(self, targets_dtype):
        with self.assertRaisesRegex(RuntimeError, r"blank must be within \[0, num classes\)"):
            hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)

    def _assert_tokens(self, first, second):
        assert len(first) == len(second)

        for f, s in zip(first, second):
            self.assertEqual(f.token, s.token)
            self.assertEqual(f.score, s.score)
            self.assertEqual(f.start, s.start)
            self.assertEqual(f.end, s.end)

    @parameterized.expand(
        [
            ([], [], []),
            ([F.TokenSpan(1, 0, 1, 1.0)], [1], [1.0]),
            ([F.TokenSpan(1, 0, 2, 0.5)], [1, 1], [0.4, 0.6]),
            ([F.TokenSpan(1, 0, 3, 0.6)], [1, 1, 1], [0.5, 0.6, 0.7]),
            ([F.TokenSpan(1, 0, 1, 0.8), F.TokenSpan(2, 1, 2, 0.9)], [1, 2], [0.8, 0.9]),
            ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 1, 3, 0.5)], [1, 2, 2], [1.0, 0.4, 0.6]),
            ([F.TokenSpan(1, 0, 1, 0.8), F.TokenSpan(1, 2, 3, 1.0)], [1, 0, 1], [0.8, 0.9, 1.0]),
            ([F.TokenSpan(1, 0, 1, 0.8), F.TokenSpan(2, 2, 3, 1.0)], [1, 0, 2], [0.8, 0.9, 1.0]),
            ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(1, 2, 4, 0.5)], [1, 0, 1, 1], [1.0, 0.1, 0.4, 0.6]),
            ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 2, 4, 0.5)], [1, 0, 2, 2], [1.0, 0.1, 0.4, 0.6]),
            ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(1, 3, 4, 0.4)], [1, 0, 0, 1], [1.0, 0.9, 0.7, 0.4]),
            ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 3, 4, 0.4)], [1, 0, 0, 2], [1.0, 0.9, 0.7, 0.4]),
            ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(1, 3, 5, 0.5)], [1, 0, 0, 1, 1], [1.0, 0.9, 0.8, 0.6, 0.4]),
            ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 3, 5, 0.5)], [1, 0, 0, 2, 2], [1.0, 0.9, 0.8, 0.6, 0.4]),
            ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 2, 3, 0.5)], [1, 1, 2], [1.0, 0.8, 0.5]),
            ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 3, 4, 0.7)], [1, 1, 0, 1], [1.0, 0.8, 0.1, 0.7]),
            ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 3, 4, 0.7)], [1, 1, 0, 2], [1.0, 0.8, 0.1, 0.7]),
            ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 3, 5, 0.4)], [1, 1, 0, 1, 1], [1.0, 0.8, 0.1, 0.5, 0.3]),
            ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 3, 5, 0.4)], [1, 1, 0, 2, 2], [1.0, 0.8, 0.1, 0.5, 0.3]),
            ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 4, 5, 0.3)], [1, 1, 0, 0, 1], [1.0, 0.8, 0.1, 0.5, 0.3]),
            ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 4, 5, 0.3)], [1, 1, 0, 0, 2], [1.0, 0.8, 0.1, 0.5, 0.3]),
            (
                [F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 4, 6, 0.2)],
                [1, 1, 0, 0, 1, 1],
                [1.0, 0.8, 0.6, 0.5, 0.3, 0.1],
            ),
            (
                [F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 4, 6, 0.2)],
                [1, 1, 0, 0, 2, 2],
                [1.0, 0.8, 0.6, 0.5, 0.3, 0.1],
            ),
        ]
    )
    def test_merge_repeated_tokens(self, expected, tokens, scores):
        scores_ = torch.tensor(scores, dtype=torch.float32, device=self.device)
        tokens_ = torch.tensor(tokens, dtype=torch.int64, device=self.device)
        spans = F.merge_tokens(tokens_, scores_, blank=0)
        print(tokens_, scores_)
        self._assert_tokens(spans, expected)

        # Append blanks at the beginning and at the end.
        for num_prefix, num_suffix in itertools.product([0, 1, 2], repeat=2):
            tokens_ = ([0] * num_prefix) + tokens + ([0] * num_suffix)
            scores_ = ([0.1] * num_prefix) + scores + ([0.1] * num_suffix)
            tokens_ = torch.tensor(tokens_, dtype=torch.int64, device=self.device)
            scores_ = torch.tensor(scores_, dtype=torch.float32, device=self.device)
            expected_ = [F.TokenSpan(s.token, s.start + num_prefix, s.end + num_prefix, s.score) for s in expected]
            print(tokens_, scores_)
            spans = F.merge_tokens(tokens_, scores_, blank=0)
            self._assert_tokens(spans, expected_)


class FunctionalCPUOnly(TestBaseMixin):
    def test_melscale_fbanks_no_warning_high_n_freq(self):
4 changes: 4 additions & 0 deletions torchaudio/functional/__init__.py
@@ -43,6 +43,7 @@
    mask_along_axis,
    mask_along_axis_iid,
    melscale_fbanks,
    merge_tokens,
    mu_law_decoding,
    mu_law_encoding,
    mvdr_weights_rtf,
@@ -59,6 +60,7 @@
    spectral_centroid,
    spectrogram,
    speed,
    TokenSpan,
)

__all__ = [
@@ -94,6 +96,8 @@
    "filtfilt",
    "flanger",
    "forced_align",
    "merge_tokens",
    "TokenSpan",
    "gain",
    "highpass_biquad",
    "lfilter",
61 changes: 61 additions & 0 deletions torchaudio/functional/functional.py
@@ -4,6 +4,7 @@
import tempfile
import warnings
from collections.abc import Sequence
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
@@ -53,6 +54,8 @@
    "preemphasis",
    "deemphasis",
    "forced_align",
    "TokenSpan",
    "merge_tokens",
]


@@ -2566,3 +2569,61 @@ def forced_align(

    paths, scores = torch.ops.torchaudio.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
    return paths, scores


@dataclass
class TokenSpan:
    """TokenSpan()
    Token with timestamps and score. Returned by :py:func:`merge_tokens`.
    """

    token: int
    """The token"""
    start: int
    """The start time (inclusive) in the emission time axis."""
    end: int
    """The end time (exclusive) in the emission time axis."""
    score: float
    """The score of this token."""

    def __len__(self) -> int:
        """Returns the length of the time span."""
        return self.end - self.start


def merge_tokens(tokens: Tensor, scores: Tensor, blank: int = 0) -> List[TokenSpan]:
    """Removes repeated tokens and blank tokens from the given CTC token sequence.

    Args:
        tokens (Tensor): Alignment tokens (unbatched) returned from :py:func:`forced_align`.
            Shape: `(time, )`.
        scores (Tensor): Alignment scores (unbatched) returned from :py:func:`forced_align`.
            Shape: `(time, )`. When computing the token-level score, the given frame-level
            scores are averaged across the corresponding time span.

    Returns:
        list of TokenSpan

    Example:
        >>> aligned_tokens, scores = forced_align(emission, targets, input_lengths, target_lengths)
        >>> token_spans = merge_tokens(aligned_tokens[0], scores[0])
    """
    if tokens.ndim != 1 or scores.ndim != 1:
        raise ValueError("`tokens` and `scores` must be 1D Tensor.")
    if len(tokens) != len(scores):
        raise ValueError("`tokens` and `scores` must be the same length.")

    t_prev = blank
    i = start = -1
    spans = []
    for t, token in enumerate(tokens):
        if token != t_prev:
            # The token changed: close the span of the previous non-blank token,
            # averaging the frame-level scores over its time span.
            if t_prev != blank:
                spans.append(TokenSpan(t_prev.item(), start, t, scores[start:t].mean().item()))
            # A new non-blank token starts at this frame.
            if token != blank:
                i += 1
                start = t
            t_prev = token
    # Close the trailing span if the sequence does not end with a blank.
    if t_prev != blank:
        spans.append(TokenSpan(t_prev.item(), start, len(tokens), scores[start:].mean().item()))
    return spans
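
For context, a hedged sketch of consuming the returned spans, continuing the docstring example; the 20 ms frame duration is an assumption for illustration and in practice depends on the acoustic model's frame rate:

    # `aligned_tokens` and `scores` as in the docstring example above.
    spans = merge_tokens(aligned_tokens[0], scores[0])
    frame_duration = 0.02  # assumed: 20 ms per emission frame (model-dependent)
    for span in spans:
        start_s, end_s = span.start * frame_duration, span.end * frame_duration
        # len(span) gives the span length in frames via TokenSpan.__len__.
        print(f"token={span.token} [{start_s:.2f}s, {end_s:.2f}s) frames={len(span)} score={span.score:.2f}")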
