Add MMS FA Bundle (#3521)
Summary:
Port the MMS FA model from the tutorial to the library, along with a post-processing module.

Pull Request resolved: #3521

Reviewed By: huangruizhe

Differential Revision: D48038285

Pulled By: mthrok

fbshipit-source-id: 571cf0fceaaab4790983be2719f1a85805b814f5
mthrok authored and facebook-github-bot committed Aug 7, 2023
1 parent 30668af commit 5e211d6
Showing 7 changed files with 399 additions and 13 deletions.
49 changes: 42 additions & 7 deletions docs/source/_templates/autosummary/bundle_class.rst
@@ -5,8 +5,13 @@

.. autoclass:: {{ fullname }}()

{%- if name in ["RNNTBundle.FeatureExtractor", "RNNTBundle.TokenProcessor"] %}
{%- set support_classes = [] %}
{%- if name in ["RNNTBundle.FeatureExtractor", "RNNTBundle.TokenProcessor", "Wav2Vec2FABundle.Tokenizer"] %}
{%- set methods = ["__call__"] %}
{%- elif name == "Wav2Vec2FABundle.Aligner" %}
{%- set attributes = [] %}
{%- set methods = ["__call__"] %}
{%- set support_classes = ["Token"] %}
{%- elif name == "Tacotron2TTSBundle.TextProcessor" %}
{%- set attributes = ["tokens"] %}
{%- set methods = ["__call__"] %}
@@ -21,12 +26,17 @@
{%- set methods = ["__call__"] %}
{% endif %}

..
ATTRIBUTES
{%- if attributes %}

Properties
----------

{%- endif %}

{%- for item in attributes %}
{%- if not item.startswith('_') %}

{{ item | underline("-") }}
{{ item | underline("~") }}

.. container:: py attribute

@@ -35,17 +45,42 @@
{%- endif %}
{%- endfor %}

..
METHODS
{%- if methods %}

Methods
-------

{%- endif %}

{%- for item in methods %}
{%- if item != "__init__" %}

{{item | underline("-") }}
{{item | underline("~") }}

.. container:: py attribute

   .. automethod:: {{[fullname, item] | join('.')}}

{%- endif %}
{%- endfor %}

{%- if support_classes %}

Support Structures
------------------

{%- endif %}

{%- for item in support_classes %}

{% set components = item.split('.') %}

{{ components[-1] | underline("~") }}

.. container:: py attribute

   .. autoclass:: {{[fullname, item] | join('.')}}
      :members:


{%- endfor %}
32 changes: 32 additions & 0 deletions docs/source/pipelines.rst
@@ -142,6 +142,38 @@ Pretrained Models
HUBERT_ASR_LARGE
HUBERT_ASR_XLARGE

wav2vec 2.0 / HuBERT - Forced Alignment
---------------------------------------

Interface
~~~~~~~~~

``Wav2Vec2FABundle`` bundles a pre-trained model and its associated dictionary. Additionally, it supports appending a ``star`` token dimension. A minimal usage sketch follows the class listing below.

.. image:: https://download.pytorch.org/torchaudio/doc-assets/pipelines-wav2vec2fabundle.png

.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst

Wav2Vec2FABundle
Wav2Vec2FABundle.Tokenizer
Wav2Vec2FABundle.Aligner

.. rubric:: Tutorials using ``Wav2Vec2FABundle``

.. minigallery:: torchaudio.pipelines.Wav2Vec2FABundle
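
A minimal end-to-end sketch (the audio path and transcript below are
placeholders, and the model weights are downloaded on first use)::

    import torch
    import torchaudio

    bundle = torchaudio.pipelines.MMS_FA

    model = bundle.get_model()
    tokenizer = bundle.get_tokenizer()
    aligner = bundle.get_aligner()

    waveform, sample_rate = torchaudio.load("speech.wav")  # placeholder path
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

    transcript = "the quick brown fox".split()  # must be normalized (lowercase)
    with torch.inference_mode():
        emission, _ = model(waveform)
    token_spans = aligner(emission[0], tokenizer(transcript))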

Pretrained Models
~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst

MMS_FA

.. _Tacotron2:

9 changes: 9 additions & 0 deletions docs/source/refs.bib
@@ -570,3 +570,12 @@ @incollection{45611
URL = {https://arxiv.org/abs/1609.09430},
booktitle = {International Conference on Acoustics, Speech and Signal Processing (ICASSP)}
}

@misc{pratap2023scaling,
title={Scaling Speech Technology to 1,000+ Languages},
author={Vineel Pratap and Andros Tjandra and Bowen Shi and Paden Tomasello and Arun Babu and Sayani Kundu and Ali Elkahky and Zhaoheng Ni and Apoorv Vyas and Maryam Fazel-Zarandi and Alexei Baevski and Yossi Adi and Xiaohui Zhang and Wei-Ning Hsu and Alexis Conneau and Michael Auli},
year={2023},
eprint={2305.13516},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
4 changes: 4 additions & 0 deletions torchaudio/pipelines/__init__.py
@@ -18,6 +18,7 @@
HUBERT_BASE,
HUBERT_LARGE,
HUBERT_XLARGE,
MMS_FA,
VOXPOPULI_ASR_BASE_10K_DE,
VOXPOPULI_ASR_BASE_10K_EN,
VOXPOPULI_ASR_BASE_10K_ES,
@@ -41,6 +42,7 @@
WAV2VEC2_XLSR_300M,
Wav2Vec2ASRBundle,
Wav2Vec2Bundle,
Wav2Vec2FABundle,
WAVLM_BASE,
WAVLM_BASE_PLUS,
WAVLM_LARGE,
@@ -51,6 +53,7 @@
__all__ = [
"Wav2Vec2Bundle",
"Wav2Vec2ASRBundle",
"Wav2Vec2FABundle",
"WAV2VEC2_BASE",
"WAV2VEC2_LARGE",
"WAV2VEC2_LARGE_LV60K",
@@ -77,6 +80,7 @@
"HUBERT_XLARGE",
"HUBERT_ASR_LARGE",
"HUBERT_ASR_XLARGE",
"MMS_FA",
"WAVLM_BASE",
"WAVLM_BASE_PLUS",
"WAVLM_LARGE",
85 changes: 85 additions & 0 deletions torchaudio/pipelines/_wav2vec2/aligner.py
@@ -0,0 +1,85 @@
from abc import ABC, abstractmethod
from typing import Dict, List

import torch
import torchaudio.functional as F
from torch import Tensor
from torchaudio.functional import TokenSpan


class ITokenizer(ABC):
    @abstractmethod
    def __call__(self, transcript: List[str]) -> List[List[int]]:
        """Tokenize the given transcript (list of words).

        .. note::

           The transcript must be normalized.

        Args:
            transcript (list of str): Transcript (list of words).

        Returns:
            (list of list of int): List of token sequences.
        """


class Tokenizer(ITokenizer):
    def __init__(self, dictionary: Dict[str, int]):
        self.dictionary = dictionary

    def __call__(self, transcript: List[str]) -> List[List[int]]:
        return [[self.dictionary[c] for c in word] for word in transcript]
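
# A minimal sketch of the tokenizer, assuming a toy dictionary that maps
# characters to indices in the emission's token dimension:
#
#   >>> tokenizer = Tokenizer({"a": 0, "b": 1, "c": 2})
#   >>> tokenizer(["ab", "ca"])
#   [[0, 1], [2, 0]]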


def _align_emission_and_tokens(emission: Tensor, tokens: List[int]):
    device = emission.device
    emission = emission.unsqueeze(0)
    targets = torch.tensor([tokens], dtype=torch.int32, device=device)

    aligned_tokens, scores = F.forced_align(emission, targets, 0)

    scores = scores.exp()  # convert back to probability
    aligned_tokens, scores = aligned_tokens[0], scores[0]  # remove batch dimension
    return aligned_tokens, scores


class IAligner(ABC):
    @abstractmethod
    def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
        """Generate a list of time-stamped token sequences.

        Args:
            emission (Tensor): Sequence of token probability distributions.
                Shape: `(time, tokens)`.
            tokens (list of integer sequences): Tokenized transcript.
                Output from :py:class:`Wav2Vec2FABundle.Tokenizer`.

        Returns:
            (list of TokenSpan sequences): Tokens with time stamps and scores.
        """


def _unflatten(list_, lengths):
    assert len(list_) == sum(lengths)
    i = 0
    ret = []
    for l in lengths:
        ret.append(list_[i : i + l])
        i += l
    return ret


def _flatten(nested_list):
    return [item for list_ in nested_list for item in list_]
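
# Illustration of the round trip these helpers perform:
#
#   >>> _flatten([[1, 2], [3]])
#   [1, 2, 3]
#   >>> _unflatten([1, 2, 3], [2, 1])
#   [[1, 2], [3]]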


class Aligner(IAligner):
    def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
        if emission.ndim != 2:
            raise ValueError(f"The input emission must be 2D. Found: {emission.shape}")

        # forced_align expects log-probabilities; normalize the raw emission.
        emission = torch.log_softmax(emission, dim=-1)
        aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens))
        spans = F.merge_tokens(aligned_tokens, scores)
        return _unflatten(spans, [len(ts) for ts in tokens])
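
# A usage sketch for the post-processing (``waveform`` and ``sample_rate``
# are assumed to come from the caller; they are not defined in this module):
#
#   >>> aligner = Aligner()
#   >>> word_spans = aligner(emission, tokens)  # one TokenSpan list per word
#   >>> ratio = waveform.size(1) / emission.size(0)  # samples per frame
#   >>> for spans in word_spans:
#   ...     start_sec = ratio * spans[0].start / sample_rate
#   ...     end_sec = ratio * spans[-1].end / sample_rate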
