Summary: Port the MMS FA model from tutorial to the library with post-processing module.

Pull Request resolved: #3521
Reviewed By: huangruizhe
Differential Revision: D48038285
Pulled By: mthrok
fbshipit-source-id: 571cf0fceaaab4790983be2719f1a85805b814f5
1 parent 30668af · commit 5e211d6
Showing 7 changed files with 399 additions and 13 deletions.
@@ -0,0 +1,85 @@
from abc import ABC, abstractmethod
from typing import Dict, List

import torch
import torchaudio.functional as F
from torch import Tensor
from torchaudio.functional import TokenSpan

class ITokenizer(ABC):
    @abstractmethod
    def __call__(self, transcript: List[str]) -> List[List[int]]:
        """Tokenize the given transcript (list of words).

        .. note::

           The transcript must be normalized.

        Args:
            transcript (list of str): Transcript (list of words).

        Returns:
            (list of list of int): List of token sequences, one per word.
        """

class Tokenizer(ITokenizer):
    def __init__(self, dictionary: Dict[str, int]):
        self.dictionary = dictionary

    def __call__(self, transcript: List[str]) -> List[List[int]]:
        # Map each character of each word to its index in the dictionary.
        return [[self.dictionary[c] for c in word] for word in transcript]

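# Usage sketch (illustrative, not part of the committed file): with a
# hypothetical character-to-index dictionary, each word maps to one token
# sequence. Index 0 is left free for the blank token used by forced_align
# below.
#
#     tokenizer = Tokenizer({"a": 1, "b": 2, "c": 3})
#     tokenizer(["ab", "c"])  # -> [[1, 2], [3]]
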
def _align_emission_and_tokens(emission: Tensor, tokens: List[int]):
    device = emission.device
    # forced_align expects batched inputs, so add a batch dimension of 1.
    emission = emission.unsqueeze(0)
    targets = torch.tensor([tokens], dtype=torch.int32, device=device)

    # Token ID 0 is reserved for the CTC blank label.
    aligned_tokens, scores = F.forced_align(emission, targets, blank=0)

    scores = scores.exp()  # convert back to probability
    aligned_tokens, scores = aligned_tokens[0], scores[0]  # remove batch dimension
    return aligned_tokens, scores

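# Shape sketch (illustrative): for a log-softmaxed emission of shape
# (time, num_tokens), forced alignment yields one token ID and one score
# per frame.
#
#     emission = torch.randn(100, 29).log_softmax(dim=-1)
#     aligned, scores = _align_emission_and_tokens(emission, [3, 4, 5])
#     # aligned.shape == scores.shape == (100,); blank frames carry ID 0.
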
class IAligner(ABC):
    @abstractmethod
    def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
        """Generate a list of time-stamped token sequences.

        Args:
            emission (Tensor): Sequence of token probability distributions.
                Shape: `(time, tokens)`.
            tokens (list of integer sequences): Tokenized transcript.
                Output from :py:class:`Wav2Vec2FABundle.Tokenizer`.

        Returns:
            (list of TokenSpan sequences): Tokens with time stamps and scores.
        """

def _unflatten(list_, lengths):
    # Split a flat list back into sub-lists of the given lengths.
    assert len(list_) == sum(lengths)
    i = 0
    ret = []
    for l in lengths:
        ret.append(list_[i : i + l])
        i += l
    return ret


def _flatten(nested_list):
    # Concatenate a list of lists into a single flat list.
    return [item for list_ in nested_list for item in list_]

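# Round-trip sketch (illustrative): _unflatten inverts _flatten given the
# original per-word lengths.
#
#     tokens = [[3, 4], [5], [6, 7, 8]]
#     flat = _flatten(tokens)                     # [3, 4, 5, 6, 7, 8]
#     _unflatten(flat, [len(t) for t in tokens])  # [[3, 4], [5], [6, 7, 8]]
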
class Aligner(IAligner):
    def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
        if emission.ndim != 2:
            raise ValueError(f"The input emission must be 2D. Found: {emission.shape}")

        emission = torch.log_softmax(emission, dim=-1)
        # Align against the flattened transcript, merge per-frame labels into
        # token-level spans, then regroup the spans word-by-word.
        aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens))
        spans = F.merge_tokens(aligned_tokens, scores)
        return _unflatten(spans, [len(ts) for ts in tokens])
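
Taken together, a minimal end-to-end sketch of how this post-processing module is meant to be used. It assumes the bundle-level API that the docstrings above reference (torchaudio.pipelines.MMS_FA / Wav2Vec2FABundle, added alongside this module) and a hypothetical input file speech.wav; treat it as an illustration, not the committed code.

import torch
import torchaudio
import torchaudio.functional as F

bundle = torchaudio.pipelines.MMS_FA   # assumed bundle name
model = bundle.get_model()             # acoustic model producing emissions
tokenizer = bundle.get_tokenizer()     # a Tokenizer as defined above
aligner = bundle.get_aligner()         # an Aligner as defined above

waveform, sr = torchaudio.load("speech.wav")  # hypothetical input
waveform = F.resample(waveform, sr, bundle.sample_rate)

transcript = "hello world".split()
with torch.inference_mode():
    emission, _ = model(waveform)      # emission: (batch, time, num_tokens)

# One list of TokenSpan per word, with frame indices and scores.
token_spans = aligner(emission[0], tokenizer(transcript))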