Skip to content

Commit

Permalink
Fix FA bundle
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Aug 7, 2023
1 parent 90143e9 commit 6cd3268
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions torchaudio/pipelines/_wav2vec2/aligner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ def __call__(self, transcript: List[str]) -> List[List[int]]:
return [[self.dictionary[c] for c in word] for word in transcript]


-def _align_emission_and_tokens(emission: Tensor, tokens: List[int]):
+def _align_emission_and_tokens(emission: Tensor, tokens: List[int], blank: int = 0):
device = emission.device
emission = emission.unsqueeze(0)
targets = torch.tensor([tokens], dtype=torch.int32, device=device)

-aligned_tokens, scores = F.forced_align(emission, targets, 0)
+aligned_tokens, scores = F.forced_align(emission, targets, blank=blank)

scores = scores.exp() # convert back to probability
aligned_tokens, scores = aligned_tokens[0], scores[0] # remove batch dimension
Expand Down Expand Up @@ -75,11 +75,14 @@ def _flatten(nested_list):


class Aligner(IAligner):
+def __init__(self, blank):
+self.blank = blank
+
def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
if emission.ndim != 2:
raise ValueError(f"The input emission must be 2D. Found: {emission.shape}")

emission = torch.log_softmax(emission, dim=-1)
-aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens))
+aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens), self.blank)
spans = F.merge_tokens(aligned_tokens, scores)
return _unflatten(spans, [len(ts) for ts in tokens])
2 changes: 1 addition & 1 deletion torchaudio/pipelines/_wav2vec2/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1650,7 +1650,7 @@ def get_aligner(self) -> Aligner:
Returns:
Aligner
"""
-return aligner.Aligner()
+return aligner.Aligner(blank=0)


MMS_FA = Wav2Vec2FABundle(
Expand Down

0 comments on commit 6cd3268

Please sign in to comment.