Move alignment code to separate submodule (#3536)
Summary: Pull Request resolved: #3536

Reviewed By: huangruizhe

Differential Revision: D48120170

Pulled By: mthrok

fbshipit-source-id: dec7575db07734490099b35a8bfc854252952c6e
mthrok authored and facebook-github-bot committed Aug 7, 2023
1 parent 5e211d6 commit 90143e9
Showing 3 changed files with 132 additions and 131 deletions.
4 changes: 1 addition & 3 deletions torchaudio/functional/__init__.py
@@ -1,3 +1,4 @@
from ._alignment import forced_align, merge_tokens, TokenSpan
from .filtering import (
allpass_biquad,
band_biquad,
@@ -35,15 +36,13 @@
detect_pitch_frequency,
edit_distance,
fftconvolve,
forced_align,
griffinlim,
inverse_spectrogram,
linear_fbanks,
loudness,
mask_along_axis,
mask_along_axis_iid,
melscale_fbanks,
merge_tokens,
mu_law_decoding,
mu_law_encoding,
mvdr_weights_rtf,
@@ -60,7 +59,6 @@
spectral_centroid,
spectrogram,
speed,
TokenSpan,
)

__all__ = [
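With this change the three alignment names are re-exported from the new `_alignment` submodule, so the public import path is unchanged. A minimal sanity check, not part of this commit, assuming a torchaudio build that includes it:

from torchaudio.functional import TokenSpan, forced_align, merge_tokens

# The names still resolve under torchaudio.functional; only the defining module moves.
print(forced_align.__module__)  # expected: torchaudio.functional._alignment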
131 changes: 131 additions & 0 deletions torchaudio/functional/_alignment.py
@@ -0,0 +1,131 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import Tensor
from torchaudio._extension import fail_if_no_align

__all__ = []


@fail_if_no_align
def forced_align(
log_probs: Tensor,
targets: Tensor,
input_lengths: Optional[Tensor] = None,
target_lengths: Optional[Tensor] = None,
blank: int = 0,
) -> Tuple[Tensor, Tensor]:
r"""Align a CTC label sequence to an emission.
.. devices:: CPU CUDA
.. properties:: TorchScript
Args:
log_probs (Tensor): log probability of CTC emission output.
Tensor of shape `(B, T, C)`, where `B` is the batch size, `T` is the input length,
`C` is the number of characters in alphabet including blank.
targets (Tensor): Target sequence. Tensor of shape `(B, L)`,
where `L` is the target length.
input_lengths (Tensor or None, optional):
Lengths of the inputs (each value must be <= `T`). 1-D Tensor of shape `(B,)`.
target_lengths (Tensor or None, optional):
Lengths of the targets. 1-D Tensor of shape `(B,)`.
blank (int, optional): The index of the blank symbol in CTC emission. (Default: 0)
Returns:
Tuple(Tensor, Tensor):
Tensor: Label for each time step in the alignment path computed using forced alignment.
Tensor: Log probability scores of the labels for each time step.
Note:
The sequence length of `log_probs` must satisfy:
.. math::
L_{\text{log\_probs}} \ge L_{\text{label}} + N_{\text{repeat}}
where :math:`N_{\text{repeat}}` is the number of consecutively repeated tokens.
For example, in str `"aabbc"`, the number of repeats is `2`.
Note:
The current version only supports ``batch_size==1``.
"""
if blank in targets:
raise ValueError(f"targets Tensor shouldn't contain blank index. Found {targets}.")
if torch.max(targets) >= log_probs.shape[-1]:
raise ValueError("targets values must be less than the CTC dimension")

if input_lengths is None:
batch_size, length = log_probs.size(0), log_probs.size(1)
input_lengths = torch.full((batch_size,), length, dtype=torch.int64, device=log_probs.device)
if target_lengths is None:
batch_size, length = targets.size(0), targets.size(1)
target_lengths = torch.full((batch_size,), length, dtype=torch.int64, device=targets.device)

# For TorchScript compatibility
assert input_lengths is not None
assert target_lengths is not None

paths, scores = torch.ops.torchaudio.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
return paths, scores


@dataclass
class TokenSpan:
"""TokenSpan()
Token with time stamps and score. Returned by :py:func:`merge_tokens`.
"""

token: int
"""The token"""
start: int
"""The start time (inclusive) in emission time axis."""
end: int
"""The end time (exclusive) in emission time axis."""
score: float
"""The score of the this token."""

def __len__(self) -> int:
"""Returns the time span"""
return self.end - self.start


def merge_tokens(tokens: Tensor, scores: Tensor, blank: int = 0) -> List[TokenSpan]:
"""Removes repeated tokens and blank tokens from the given CTC token sequence.
Args:
tokens (Tensor): Alignment tokens (unbatched) returned from :py:func:`forced_align`.
Shape: `(time, )`.
scores (Tensor): Alignment scores (unbatched) returned from :py:func:`forced_align`.
Shape: `(time, )`. When computing the token-level score, the given score is averaged
across the corresponding time span.
Returns:
list of TokenSpan
Example:
>>> aligned_tokens, scores = forced_align(emission, targets, input_lengths, target_lengths)
>>> token_spans = merge_tokens(aligned_tokens[0], scores[0])
"""
if tokens.ndim != 1 or scores.ndim != 1:
raise ValueError("`tokens` and `scores` must be 1D Tensor.")
if len(tokens) != len(scores):
raise ValueError("`tokens` and `scores` must be the same length.")

t_prev = blank
i = start = -1
spans = []
for t, token in enumerate(tokens):
if token != t_prev:
if t_prev != blank:
spans.append(TokenSpan(t_prev.item(), start, t, scores[start:t].mean().item()))
if token != blank:
i += 1
start = t
t_prev = token
if t_prev != blank:
spans.append(TokenSpan(t_prev.item(), start, len(tokens), scores[start:].mean().item()))
return spans
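For context, a brief usage sketch of the relocated APIs (illustrative only, not part of the diff; the tensor shapes and target values are made up, and the call assumes a torchaudio build with alignment support, since `forced_align` is guarded by `fail_if_no_align`):

import torch
from torchaudio.functional import forced_align, merge_tokens

num_frames, num_classes = 10, 5  # T frames, C classes; class 0 is the blank symbol
emission = torch.randn(1, num_frames, num_classes).log_softmax(-1)  # (B=1, T, C) log-probabilities
targets = torch.tensor([[1, 2, 3]])  # (B=1, L); must not contain the blank index

# input_lengths / target_lengths default to the full T and L when omitted.
aligned_tokens, scores = forced_align(emission, targets, blank=0)  # both of shape (1, T)

# merge_tokens expects the unbatched (T,) tensors.
for span in merge_tokens(aligned_tokens[0], scores[0]):
    print(span.token, span.start, span.end, len(span), round(span.score, 3))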
128 changes: 0 additions & 128 deletions torchaudio/functional/functional.py
@@ -4,13 +4,11 @@
import tempfile
import warnings
from collections.abc import Sequence
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torchaudio
from torch import Tensor
from torchaudio._extension import fail_if_no_align
from torchaudio._internal.module_utils import deprecated

from .filtering import highpass_biquad, treble_biquad
@@ -53,9 +51,6 @@
"speed",
"preemphasis",
"deemphasis",
"forced_align",
"TokenSpan",
"merge_tokens",
]


@@ -2504,126 +2499,3 @@ def deemphasis(waveform, coeff: float = 0.97) -> torch.Tensor:
a_coeffs = torch.tensor([1.0, -coeff], dtype=waveform.dtype, device=waveform.device)
b_coeffs = torch.tensor([1.0, 0.0], dtype=waveform.dtype, device=waveform.device)
return torchaudio.functional.lfilter(waveform, a_coeffs=a_coeffs, b_coeffs=b_coeffs)


@fail_if_no_align
def forced_align(
log_probs: Tensor,
targets: Tensor,
input_lengths: Optional[Tensor] = None,
target_lengths: Optional[Tensor] = None,
blank: int = 0,
) -> Tuple[Tensor, Tensor]:
r"""Align a CTC label sequence to an emission.
.. devices:: CPU CUDA
.. properties:: TorchScript
Args:
log_probs (Tensor): log probability of CTC emission output.
Tensor of shape `(B, T, C)`, where `B` is the batch size, `T` is the input length,
`C` is the number of characters in alphabet including blank.
targets (Tensor): Target sequence. Tensor of shape `(B, L)`,
where `L` is the target length.
input_lengths (Tensor or None, optional):
Lengths of the inputs (each value must be <= `T`). 1-D Tensor of shape `(B,)`.
target_lengths (Tensor or None, optional):
Lengths of the targets. 1-D Tensor of shape `(B,)`.
blank (int, optional): The index of the blank symbol in CTC emission. (Default: 0)
Returns:
Tuple(Tensor, Tensor):
Tensor: Label for each time step in the alignment path computed using forced alignment.
Tensor: Log probability scores of the labels for each time step.
Note:
The sequence length of `log_probs` must satisfy:
.. math::
L_{\text{log\_probs}} \ge L_{\text{label}} + N_{\text{repeat}}
where :math:`N_{\text{repeat}}` is the number of consecutively repeated tokens.
For example, in str `"aabbc"`, the number of repeats is `2`.
Note:
The current version only supports ``batch_size==1``.
"""
if blank in targets:
raise ValueError(f"targets Tensor shouldn't contain blank index. Found {targets}.")
if torch.max(targets) >= log_probs.shape[-1]:
raise ValueError("targets values must be less than the CTC dimension")

if input_lengths is None:
batch_size, length = log_probs.size(0), log_probs.size(1)
input_lengths = torch.full((batch_size,), length, dtype=torch.int64, device=log_probs.device)
if target_lengths is None:
batch_size, length = targets.size(0), targets.size(1)
target_lengths = torch.full((batch_size,), length, dtype=torch.int64, device=targets.device)

# For TorchScript compatibility
assert input_lengths is not None
assert target_lengths is not None

paths, scores = torch.ops.torchaudio.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
return paths, scores


@dataclass
class TokenSpan:
"""TokenSpan()
Token with time stamps and score. Returned by :py:func:`merge_tokens`.
"""

token: int
"""The token"""
start: int
"""The start time (inclusive) in emission time axis."""
end: int
"""The end time (exclusive) in emission time axis."""
score: float
"""The score of the this token."""

def __len__(self) -> int:
"""Returns the time span"""
return self.end - self.start


def merge_tokens(tokens: Tensor, scores: Tensor, blank: int = 0) -> List[TokenSpan]:
"""Removes repeated tokens and blank tokens from the given CTC token sequence.
Args:
tokens (Tensor): Alignment tokens (unbatched) returned from :py:func:`forced_align`.
Shape: `(time, )`.
scores (Tensor): Alignment scores (unbatched) returned from :py:func:`forced_align`.
Shape: `(time, )`. When computing the token-level score, the given score is averaged
across the corresponding time span.
Returns:
list of TokenSpan
Example:
>>> aligned_tokens, scores = forced_align(emission, targets, input_lengths, target_lengths)
>>> token_spans = merge_tokens(aligned_tokens[0], scores[0])
"""
if tokens.ndim != 1 or scores.ndim != 1:
raise ValueError("`tokens` and `scores` must be 1D Tensor.")
if len(tokens) != len(scores):
raise ValueError("`tokens` and `scores` must be the same length.")

t_prev = blank
i = start = -1
spans = []
for t, token in enumerate(tokens):
if token != t_prev:
if t_prev != blank:
spans.append(TokenSpan(t_prev.item(), start, t, scores[start:t].mean().item()))
if token != blank:
i += 1
start = t
t_prev = token
if t_prev != blank:
spans.append(TokenSpan(t_prev.item(), start, len(tokens), scores[start:].mean().item()))
return spans
