diff --git a/docs/source/pipelines.rst b/docs/source/pipelines.rst
index f2e88035521..9bb68592b9f 100644
--- a/docs/source/pipelines.rst
+++ b/docs/source/pipelines.rst
@@ -142,6 +142,36 @@ Pretrained Models
    HUBERT_ASR_LARGE
    HUBERT_ASR_XLARGE
 
+wav2vec 2.0 / HuBERT - Forced Alignment
+---------------------------------------
+
+Interface
+~~~~~~~~~
+
+``Wav2Vec2FABundle`` bundles a pre-trained model and its associated dictionary. Additionally, it supports appending a ``star`` token dimension.
+
+.. image:: https://download.pytorch.org/torchaudio/doc-assets/pipelines-wav2vec2asrbundle.png
+
+.. autosummary::
+   :toctree: generated
+   :nosignatures:
+   :template: autosummary/bundle_class.rst
+
+   Wav2Vec2FABundle
+
+.. rubric:: Tutorials using ``Wav2Vec2FABundle``
+
+.. minigallery:: torchaudio.pipelines.Wav2Vec2FABundle
+
+Pretrained Models
+~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+   :toctree: generated
+   :nosignatures:
+   :template: autosummary/bundle_data.rst
+
+   MMS_FA
 
 .. _Tacotron2:
 
diff --git a/docs/source/refs.bib b/docs/source/refs.bib
index bac17ee6285..3853bfa919a 100644
--- a/docs/source/refs.bib
+++ b/docs/source/refs.bib
@@ -570,3 +570,12 @@ @incollection{45611
  URL = {https://arxiv.org/abs/1609.09430},
  booktitle = {International Conference on Acoustics, Speech and Signal Processing (ICASSP)}
 }
+
+@misc{pratap2023scaling,
+  title={Scaling Speech Technology to 1,000+ Languages},
+  author={Vineel Pratap and Andros Tjandra and Bowen Shi and Paden Tomasello and Arun Babu and Sayani Kundu and Ali Elkahky and Zhaoheng Ni and Apoorv Vyas and Maryam Fazel-Zarandi and Alexei Baevski and Yossi Adi and Xiaohui Zhang and Wei-Ning Hsu and Alexis Conneau and Michael Auli},
+  year={2023},
+  eprint={2305.13516},
+  archivePrefix={arXiv},
+  primaryClass={cs.CL}
+}
diff --git a/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py b/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py
index 6f78b0e5d36..f1b27eeb533 100644
--- a/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py
+++ b/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py
@@ -37,12 +37,6 @@
 from torchaudio.functional import forced_align
 
 
-######################################################################
-#
-
-SAMPLE_RATE = 16000
-
-
 ######################################################################
 #
 # Here we define utility functions for computing the frame-level
@@ -161,7 +155,7 @@ def plot_emission(emission):
 #
 # utility function for plotting word alignments
 
-def plot_alignments(waveform, emission, segments, word_segments, sample_rate=SAMPLE_RATE):
+def plot_alignments(waveform, emission, segments, word_segments, sample_rate):
     fig, ax = plt.subplots()
     ax.specgram(waveform[0], Fs=sample_rate)
     xlim = ax.get_xlim()
@@ -187,7 +181,7 @@ def plot_alignments(waveform, emission, segments, word_segments, sample_rate=SAM
 #
 # utility function for playing audio segments.
 
-def display_segment(i, waveform, word_segments, num_frames, sample_rate=SAMPLE_RATE):
+def display_segment(i, waveform, word_segments, num_frames, sample_rate):
     ratio = waveform.size(1) / num_frames
     word = word_segments[i]
     x0 = int(ratio * word.start)
@@ -207,92 +201,20 @@ def display_segment(i, waveform, word_segments, num_frames, sample_rate=SAMPLE_R
 # order to verify the alignment quality. Here we first load the model and dictionary.
 #
 
-from torchaudio.models import wav2vec2_model
-
-model = wav2vec2_model(
-    extractor_mode="layer_norm",
-    extractor_conv_layer_config=[
-        (512, 10, 5),
-        (512, 3, 2),
-        (512, 3, 2),
-        (512, 3, 2),
-        (512, 3, 2),
-        (512, 2, 2),
-        (512, 2, 2),
-    ],
-    extractor_conv_bias=True,
-    encoder_embed_dim=1024,
-    encoder_projection_dropout=0.0,
-    encoder_pos_conv_kernel=128,
-    encoder_pos_conv_groups=16,
-    encoder_num_layers=24,
-    encoder_num_heads=16,
-    encoder_attention_dropout=0.0,
-    encoder_ff_interm_features=4096,
-    encoder_ff_interm_dropout=0.1,
-    encoder_dropout=0.0,
-    encoder_layer_norm_first=True,
-    encoder_layer_drop=0.1,
-    aux_num_out=31,
-)
-
+from torchaudio.pipelines import MMS_FA
 
-model.load_state_dict(
-    torch.hub.load_state_dict_from_url(
-        "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt"
-    )
-)
-model.eval()
+bundle = MMS_FA
+model = bundle.get_model(with_star=False)
 model.to(device)
+dictionary = bundle.get_dict(star=None)
 
 
 def get_emission(waveform):
     with torch.inference_mode():
-        # NOTE: this step is essential
-        waveform = torch.nn.functional.layer_norm(waveform, waveform.shape)
         emission, _ = model(waveform)
     return torch.log_softmax(emission, dim=-1)
 
 
-# Construct the dictionary
-# '@' represents the OOV token
-# <pad> and </s> are fairseq's legacy tokens, which're not used.
-# <star> token is omitted as we do not use it in this tutorial
-dictionary = {
-    "<blank>": 0,
-    "<pad>": 1,
-    "</s>": 2,
-    "@": 3,
-    "a": 4,
-    "i": 5,
-    "e": 6,
-    "n": 7,
-    "o": 8,
-    "u": 9,
-    "t": 10,
-    "s": 11,
-    "r": 12,
-    "m": 13,
-    "k": 14,
-    "l": 15,
-    "d": 16,
-    "g": 17,
-    "h": 18,
-    "y": 19,
-    "b": 20,
-    "p": 21,
-    "w": 22,
-    "c": 23,
-    "v": 24,
-    "j": 25,
-    "z": 26,
-    "f": 27,
-    "'": 28,
-    "q": 29,
-    "x": 30,
-}
-
-
 ######################################################################
 # Before aligning the speech with transcripts, we need to make sure
 # the transcripts are already romanized.
Here are the BASH commands @@ -341,7 +263,9 @@ def get_emission(waveform): ###################################################################### # -waveform, _ = torchaudio.load(speech_file, frame_offset=int(0.5 * SAMPLE_RATE), num_frames=int(2.5 * SAMPLE_RATE)) +waveform, _ = torchaudio.load( + speech_file, frame_offset=int(0.5 * bundle.sample_rate), num_frames=int(2.5 * bundle.sample_rate) +) emission = get_emission(waveform.to(device)) num_frames = emission.size(1) @@ -352,48 +276,48 @@ def get_emission(waveform): segments, word_segments = compute_alignments(text_normalized, dictionary, emission) -plot_alignments(waveform, emission, segments, word_segments) +plot_alignments(waveform, emission, segments, word_segments, bundle.sample_rate) ###################################################################### # -display_segment(0, waveform, word_segments, num_frames) +display_segment(0, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(1, waveform, word_segments, num_frames) +display_segment(1, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(2, waveform, word_segments, num_frames) +display_segment(2, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(3, waveform, word_segments, num_frames) +display_segment(3, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(4, waveform, word_segments, num_frames) +display_segment(4, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(5, waveform, word_segments, num_frames) +display_segment(5, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(6, waveform, word_segments, num_frames) +display_segment(6, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(7, waveform, word_segments, num_frames) +display_segment(7, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # Chinese @@ -430,52 +354,52 @@ def get_emission(waveform): segments, word_segments = compute_alignments(text_normalized, dictionary, emission) -plot_alignments(waveform, emission, segments, word_segments) +plot_alignments(waveform, emission, segments, word_segments, bundle.sample_rate) ###################################################################### # -display_segment(0, waveform, word_segments, num_frames) +display_segment(0, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(1, waveform, word_segments, num_frames) +display_segment(1, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(2, waveform, word_segments, num_frames) +display_segment(2, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(3, waveform, word_segments, num_frames) +display_segment(3, waveform, 
word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(4, waveform, word_segments, num_frames) +display_segment(4, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(5, waveform, word_segments, num_frames) +display_segment(5, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(6, waveform, word_segments, num_frames) +display_segment(6, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(7, waveform, word_segments, num_frames) +display_segment(7, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(8, waveform, word_segments, num_frames) +display_segment(8, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### @@ -493,7 +417,7 @@ def get_emission(waveform): ###################################################################### # -waveform, _ = torchaudio.load(speech_file, num_frames=int(4.5 * SAMPLE_RATE)) +waveform, _ = torchaudio.load(speech_file, num_frames=int(4.5 * bundle.sample_rate)) emission = get_emission(waveform.to(device)) num_frames = emission.size(1) @@ -504,47 +428,47 @@ def get_emission(waveform): segments, word_segments = compute_alignments(text_normalized, dictionary, emission) -plot_alignments(waveform, emission, segments, word_segments) +plot_alignments(waveform, emission, segments, word_segments, bundle.sample_rate) ###################################################################### # -display_segment(0, waveform, word_segments, num_frames) +display_segment(0, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(1, waveform, word_segments, num_frames) +display_segment(1, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(2, waveform, word_segments, num_frames) +display_segment(2, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(3, waveform, word_segments, num_frames) +display_segment(3, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(4, waveform, word_segments, num_frames) +display_segment(4, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(5, waveform, word_segments, num_frames) +display_segment(5, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(6, waveform, word_segments, num_frames) +display_segment(6, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(7, waveform, word_segments, num_frames) +display_segment(7, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # Portuguese @@ -561,7 +485,9 @@ 
def get_emission(waveform): ###################################################################### # -waveform, _ = torchaudio.load(speech_file, frame_offset=int(SAMPLE_RATE), num_frames=int(4.6 * SAMPLE_RATE)) +waveform, _ = torchaudio.load( + speech_file, frame_offset=int(bundle.sample_rate), num_frames=int(4.6 * bundle.sample_rate) +) emission = get_emission(waveform.to(device)) num_frames = emission.size(1) @@ -572,52 +498,52 @@ def get_emission(waveform): segments, word_segments = compute_alignments(text_normalized, dictionary, emission) -plot_alignments(waveform, emission, segments, word_segments) +plot_alignments(waveform, emission, segments, word_segments, bundle.sample_rate) ###################################################################### # -display_segment(0, waveform, word_segments, num_frames) +display_segment(0, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(1, waveform, word_segments, num_frames) +display_segment(1, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(2, waveform, word_segments, num_frames) +display_segment(2, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(3, waveform, word_segments, num_frames) +display_segment(3, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(4, waveform, word_segments, num_frames) +display_segment(4, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(5, waveform, word_segments, num_frames) +display_segment(5, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(6, waveform, word_segments, num_frames) +display_segment(6, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(7, waveform, word_segments, num_frames) +display_segment(7, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(8, waveform, word_segments, num_frames) +display_segment(8, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # Italian @@ -634,7 +560,7 @@ def get_emission(waveform): ###################################################################### # -waveform, _ = torchaudio.load(speech_file, num_frames=int(4 * SAMPLE_RATE)) +waveform, _ = torchaudio.load(speech_file, num_frames=int(4 * bundle.sample_rate)) emission = get_emission(waveform.to(device)) num_frames = emission.size(1) @@ -645,37 +571,37 @@ def get_emission(waveform): segments, word_segments = compute_alignments(text_normalized, dictionary, emission) -plot_alignments(waveform, emission, segments, word_segments) +plot_alignments(waveform, emission, segments, word_segments, bundle.sample_rate) ###################################################################### # -display_segment(0, waveform, word_segments, num_frames) +display_segment(0, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # 
-display_segment(1, waveform, word_segments, num_frames) +display_segment(1, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(2, waveform, word_segments, num_frames) +display_segment(2, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(3, waveform, word_segments, num_frames) +display_segment(3, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(4, waveform, word_segments, num_frames) +display_segment(4, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # -display_segment(5, waveform, word_segments, num_frames) +display_segment(5, waveform, word_segments, num_frames, bundle.sample_rate) ###################################################################### # Conclusion diff --git a/torchaudio/pipelines/__init__.py b/torchaudio/pipelines/__init__.py index 267526d4464..efec1f3521e 100644 --- a/torchaudio/pipelines/__init__.py +++ b/torchaudio/pipelines/__init__.py @@ -18,6 +18,7 @@ HUBERT_BASE, HUBERT_LARGE, HUBERT_XLARGE, + MMS_FA, VOXPOPULI_ASR_BASE_10K_DE, VOXPOPULI_ASR_BASE_10K_EN, VOXPOPULI_ASR_BASE_10K_ES, @@ -41,6 +42,7 @@ WAV2VEC2_XLSR_300M, Wav2Vec2ASRBundle, Wav2Vec2Bundle, + Wav2Vec2FABundle, WAVLM_BASE, WAVLM_BASE_PLUS, WAVLM_LARGE, @@ -51,6 +53,7 @@ __all__ = [ "Wav2Vec2Bundle", "Wav2Vec2ASRBundle", + "Wav2Vec2FABundle", "WAV2VEC2_BASE", "WAV2VEC2_LARGE", "WAV2VEC2_LARGE_LV60K", @@ -77,6 +80,7 @@ "HUBERT_XLARGE", "HUBERT_ASR_LARGE", "HUBERT_ASR_XLARGE", + "MMS_FA", "WAVLM_BASE", "WAVLM_BASE_PLUS", "WAVLM_LARGE", diff --git a/torchaudio/pipelines/_wav2vec2/impl.py b/torchaudio/pipelines/_wav2vec2/impl.py index 6a8faf1127a..29e3b7b4679 100644 --- a/torchaudio/pipelines/_wav2vec2/impl.py +++ b/torchaudio/pipelines/_wav2vec2/impl.py @@ -1,5 +1,6 @@ +import copy from dataclasses import dataclass -from typing import Any, Dict, Tuple +from typing import Any, Dict, Optional, Tuple from torch.nn import Module @@ -146,7 +147,7 @@ def get_labels( *, blank: str = "-", ) -> Tuple[str, ...]: - """The output class labels (only applicable to fine-tuned bundles) + """The output class labels. The first is blank token, and it is customizable. @@ -159,8 +160,8 @@ def get_labels( the output class labels. Example - >>> import torchaudio - >>> torchaudio.models.HUBERT_ASR_LARGE.get_labels() + >>> from torchaudio.pipelines import HUBERT_ASR_LARGE as bundle + >>> bundle.get_labels() ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z') """ # noqa: E501 return (blank, *self._labels) @@ -1518,3 +1519,159 @@ def _get_state_dict(self, dl_kwargs): Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for usage details. """ # noqa: E501 + + +@dataclass +class Wav2Vec2FABundle(Wav2Vec2ASRBundle): + """Data class that bundles associated information to use pretrained :py:class:`~torchaudio.models.Wav2Vec2Model` for forced alignment. + + This class provides interfaces for instantiating the pretrained model along with + the information necessary to retrieve pretrained weights and additional data + to be used with the model. + + Torchaudio library instantiates objects of this class, each of which represents + a different pretrained model. 
Client code should access pretrained models via these
+    instances.
+
+    Please see below for the usage and the available values.
+
+    Example - Forced Alignment
+        >>> import torchaudio
+        >>>
+        >>> bundle = torchaudio.pipelines.MMS_FA
+        >>>
+        >>> # Build the model and load pretrained weight.
+        >>> model = bundle.get_model()
+        Downloading:
+        100%|███████████████████████████████| 1.18G/1.18G [00:05<00:00, 216MB/s]
+        >>>
+        >>> # Resample audio to the expected sampling rate
+        >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
+        >>>
+        >>> # Estimate the frame-wise probability distribution over tokens
+        >>> emission, _ = model(waveform)
+        >>>
+        >>> # Generate frame-wise alignment
+        >>> alignment, scores = torchaudio.functional.forced_align(
+        >>>     emission, targets, input_lengths, target_lengths, blank=0)
+        >>>
+    """ # noqa: E501
+
+    def get_labels(self, star: Optional[str] = "<star>", blank: str = "<blank>") -> Tuple[str, ...]:
+        """Get the labels corresponding to the feature dimension of emission.
+
+        The first is the blank token, and it is customizable.
+
+        Args:
+            star (str or None, optional): Change or disable the star token. (default: ``"<star>"``)
+            blank (str, optional): Change the blank token. (default: ``"<blank>"``)
+
+        Returns:
+            Tuple[str, ...]:
+            The tuple of strings representing the output class labels, with the
+            blank token first (and the star token last unless disabled).
+
+        Example
+            >>> from torchaudio.pipelines import MMS_FA as bundle
+            >>> bundle.get_labels()
+            ('<blank>', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x', '<star>')
+            >>> bundle.get_labels(star=None)
+            ('<blank>', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x')
+        """ # noqa: E501
+        labels = super().get_labels(blank=blank)
+        return labels if star is None else (*labels, star)
+
+    def _get_params_with_star(self):
+        params = copy.deepcopy(self._params)
+        params["aux_num_out"] += 1
+        return params
+
+    def get_model(self, with_star: bool = True, *, dl_kwargs=None) -> Module:
+        """Construct the model and load the pretrained weight.
+
+        The weight file is downloaded from the internet and cached with
+        :func:`torch.hub.load_state_dict_from_url`
+
+        Args:
+            with_star (bool, optional): If enabled, the last dimension of the output layer is
+                extended by one, which corresponds to the `star` token.
+            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+        Returns:
+            Variation of :py:class:`~torchaudio.models.Wav2Vec2Model`.
+        """
+        params = self._get_params_with_star() if with_star else self._params
+        model = utils._get_model(self._model_type, params)
+        state_dict = utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis, with_star)
+        model.load_state_dict(state_dict)
+        if self._normalize_waveform:
+            model = utils._apply_input_layer_norm(model)
+        model.eval()
+        return model
+
+    def get_dict(self, star: Optional[str] = "<star>", blank: str = "<blank>") -> Dict[str, int]:
+        """Get the mapping from token to index (in the emission feature dimension)
+
+        Args:
+            star (str or None, optional): Change or disable the star token. (default: ``"<star>"``)
+            blank (str, optional): Change the blank token. (default: ``"<blank>"``)
+
+        Returns:
+            Dict[str, int]:
+            The mapping from each token to its index in the emission feature
+            dimension.
+
+        Example
+            >>> from torchaudio.pipelines import MMS_FA as bundle
+            >>> bundle.get_dict()
+            {'<blank>': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27, '<star>': 28}
+            >>> bundle.get_dict(star=None)
+            {'<blank>': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27}
+        """ # noqa: E501
+        return {k: i for i, k in enumerate(self.get_labels(star=star, blank=blank))}
+
+
+MMS_FA = Wav2Vec2FABundle(
+    "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt",
+    {
+        "extractor_mode": "layer_norm",
+        "extractor_conv_layer_config": [
+            (512, 10, 5),
+            (512, 3, 2),
+            (512, 3, 2),
+            (512, 3, 2),
+            (512, 3, 2),
+            (512, 2, 2),
+            (512, 2, 2),
+        ],
+        "extractor_conv_bias": True,
+        "encoder_embed_dim": 1024,
+        "encoder_projection_dropout": 0.0,
+        "encoder_pos_conv_kernel": 128,
+        "encoder_pos_conv_groups": 16,
+        "encoder_num_layers": 24,
+        "encoder_num_heads": 16,
+        "encoder_attention_dropout": 0.0,
+        "encoder_ff_interm_features": 4096,
+        "encoder_ff_interm_dropout": 0.1,
+        "encoder_dropout": 0.0,
+        "encoder_layer_norm_first": True,
+        "encoder_layer_drop": 0.1,
+        "aux_num_out": 28,
+    },
+    _labels=utils._get_mms_labels(),
+    _sample_rate=16000,
+    _normalize_waveform=True,
+    _model_type="Wav2Vec2",
+)
+MMS_FA.__doc__ = """
+Trained on 31K hours of data in 1,130 languages from *Scaling Speech Technology to 1,000+ Languages* :cite:`pratap2023scaling`.
+
+Published by the authors of *Scaling Speech Technology to 1,000+ Languages* :cite:`pratap2023scaling` under [`CC-BY-NC 4.0 License <https://creativecommons.org/licenses/by-nc/4.0/>`__].
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2FABundle` for usage details.
+
+.. note::
+
+   Unlike other Wav2Vec2 bundles, this model does not have a token for word boundary (like `|`). This makes the post-processing of alignments slightly different.
+""" # noqa: E501 diff --git a/torchaudio/pipelines/_wav2vec2/utils.py b/torchaudio/pipelines/_wav2vec2/utils.py index 0ab459f34ae..69e869208b2 100644 --- a/torchaudio/pipelines/_wav2vec2/utils.py +++ b/torchaudio/pipelines/_wav2vec2/utils.py @@ -65,13 +65,23 @@ def _remove_aux_axes(state_dict, axes): state_dict[key] = torch.stack([mat[i] for i in range(mat.size(0)) if i not in axes]) -def _get_state_dict(url, dl_kwargs, remove_axes=None): +def _add_star_dim(state_dict): + w, b = state_dict["aux.weight"], state_dict["aux.bias"] + zeros = torch.zeros((1, w.size(1)), device=w.device, dtype=w.dtype) + state_dict["aux.weight"] = torch.cat((zeros, w), dim=0) + ones = torch.ones((1,), device=b.device, dtype=b.dtype) + state_dict["aux.bias"] = torch.cat((b, ones), dim=0) + + +def _get_state_dict(url, dl_kwargs, remove_axes=None, add_star=False): if not url.startswith("https"): url = f"https://download.pytorch.org/torchaudio/models/{url}" dl_kwargs = {} if dl_kwargs is None else dl_kwargs state_dict = load_state_dict_from_url(url, **dl_kwargs) if remove_axes: _remove_aux_axes(state_dict, remove_axes) + if add_star: + _add_star_dim(state_dict) return state_dict @@ -301,3 +311,35 @@ def _get_it_labels(): "í", "ï", ) + + +def _get_mms_labels(): + return ( + "a", + "i", + "e", + "n", + "o", + "u", + "t", + "s", + "r", + "m", + "k", + "l", + "d", + "g", + "h", + "y", + "b", + "p", + "w", + "c", + "v", + "j", + "z", + "f", + "'", + "q", + "x", + )