Fix FA bundle
mthrok committed Aug 8, 2023
1 parent 3f98fb9 commit 1435365
Showing 3 changed files with 38 additions and 34 deletions.
9 changes: 6 additions & 3 deletions torchaudio/pipelines/_wav2vec2/aligner.py
@@ -32,12 +32,12 @@ def __call__(self, transcript: List[str]) -> List[List[int]]:
         return [[self.dictionary[c] for c in word] for word in transcript]


-def _align_emission_and_tokens(emission: Tensor, tokens: List[int]):
+def _align_emission_and_tokens(emission: Tensor, tokens: List[int], blank: int = 0):
     device = emission.device
     emission = emission.unsqueeze(0)
     targets = torch.tensor([tokens], dtype=torch.int32, device=device)

-    aligned_tokens, scores = F.forced_align(emission, targets, 0)
+    aligned_tokens, scores = F.forced_align(emission, targets, blank=blank)

     scores = scores.exp()  # convert back to probability
     aligned_tokens, scores = aligned_tokens[0], scores[0]  # remove batch dimension
@@ -75,11 +75,14 @@ def _flatten(nested_list):


 class Aligner(IAligner):
+    def __init__(self, blank):
+        self.blank = blank
+
     def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
         if emission.ndim != 2:
             raise ValueError(f"The input emission must be 2D. Found: {emission.shape}")

         emission = torch.log_softmax(emission, dim=-1)
-        aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens))
+        aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens), self.blank)
         spans = F.merge_tokens(aligned_tokens, scores)
         return _unflatten(spans, [len(ts) for ts in tokens])
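
The change above stops hardcoding the CTC blank index and threads it through explicitly. A minimal sketch of the same flow using the public torchaudio APIs the helper wraps (the emission and token ids below are made-up placeholders):

import torch
import torchaudio.functional as F

# Fake (1, T, C) emission: T frames, C classes, blank at index 0.
emission = torch.randn(1, 100, 29).log_softmax(dim=-1)
# Placeholder transcript token ids (must not contain the blank id).
targets = torch.tensor([[5, 12, 12, 8]], dtype=torch.int32)

# The call the fixed helper makes, with the blank index passed explicitly.
aligned_tokens, scores = F.forced_align(emission, targets, blank=0)

# Collapse per-frame labels into TokenSpan ranges, as Aligner.__call__ does.
spans = F.merge_tokens(aligned_tokens[0], scores[0].exp(), blank=0)
for span in spans:
    print(span.token, span.start, span.end, span.score)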
26 changes: 13 additions & 13 deletions torchaudio/pipelines/_wav2vec2/impl.py
@@ -1,4 +1,3 @@
-import copy
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple

@@ -93,7 +92,7 @@ def get_model(self, *, dl_kwargs=None) -> Module:
         state_dict = self._get_state_dict(dl_kwargs)
         model.load_state_dict(state_dict)
         if self._normalize_waveform:
-            model = utils._apply_input_layer_norm(model)
+            model = utils._extend_model(model, normalize_waveform=True)
         model.eval()
         return model

@@ -1587,11 +1586,6 @@ def get_labels(self, star: Optional[str] = "*", blank: str = "-") -> Tuple[str,
         labels = super().get_labels(blank=blank)
         return labels if star is None else (*labels, star)

-    def _get_params_with_star(self):
-        params = copy.deepcopy(self._params)
-        params["aux_num_out"] += 1
-        return params
-
     def get_model(self, with_star: bool = True, *, dl_kwargs=None) -> Module:
         """Construct the model and load the pretrained weight.
@@ -1605,13 +1599,19 @@ def get_model(self, with_star: bool = True, *, dl_kwargs=None) -> Module:
         Returns:
             Variation of :py:class:`~torchaudio.models.Wav2Vec2Model`.
+
+        .. note::
+           The model created with this method returns probabilities in the log domain
+           (i.e. :py:func:`torch.nn.functional.log_softmax` is applied), whereas
+           the other Wav2Vec2 models return logits.
         """
-        params = self._get_params_with_star() if with_star else self._params
-        model = utils._get_model(self._model_type, params)
-        state_dict = utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis, with_star)
+        model = utils._get_model(self._model_type, self._params)
+        state_dict = utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis)
         model.load_state_dict(state_dict)
         if self._normalize_waveform:
-            model = utils._apply_input_layer_norm(model)
+            model = utils._extend_model(
+                model, normalize_waveform=self._normalize_waveform, apply_log_softmax=True, append_star=with_star
+            )
         model.eval()
         return model
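
To illustrate the note added to the docstring above: the model this method returns emits log-probabilities, not logits, so exponentiating yields a valid per-frame distribution. A sketch, assuming a torchaudio build that ships MMS_FA; the waveform is random noise, and with_star=True would additionally append one unnormalized <star> column:

import torch
import torchaudio

bundle = torchaudio.pipelines.MMS_FA
model = bundle.get_model(with_star=False)

with torch.inference_mode():
    waveform = torch.randn(1, bundle.sample_rate)  # one second of noise
    emission, _ = model(waveform)

probs = emission.exp()    # log-probabilities -> probabilities
print(probs.sum(dim=-1))  # each frame sums to ~1.0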

@@ -1650,7 +1650,7 @@ def get_aligner(self) -> Aligner:
         Returns:
             Aligner
         """
-        return aligner.Aligner()
+        return aligner.Aligner(blank=0)


MMS_FA = Wav2Vec2FABundle(
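
Taken together, the bundle methods compose into the forced-alignment pipeline below. A hedged end-to-end sketch ("speech.wav" and the transcript are placeholders; the calls mirror the methods shown in this diff):

import torch
import torchaudio

bundle = torchaudio.pipelines.MMS_FA
model = bundle.get_model()        # log-prob emission with <star> appended
tokenizer = bundle.get_tokenizer()
aligner = bundle.get_aligner()    # Aligner(blank=0), per the change above

waveform, sample_rate = torchaudio.load("speech.wav")  # placeholder file
waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

transcript = "hello world".split()
with torch.inference_mode():
    emission, _ = model(waveform)
    # Aligner expects a 2D (frames, classes) emission; drop the batch dim.
    word_spans = aligner(emission[0], tokenizer(transcript))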
37 changes: 19 additions & 18 deletions torchaudio/pipelines/_wav2vec2/utils.py
@@ -24,13 +24,23 @@ class _Wav2Vec2Model(nn.Module):
     This is used for layer normalization at the input
     """

-    def __init__(self, model: Wav2Vec2Model):
+    def __init__(self, model: Wav2Vec2Model, normalize_waveform: bool, apply_log_softmax: bool, append_star: bool):
         super().__init__()
         self.model = model
+        self.normalize_waveform = normalize_waveform
+        self.apply_log_softmax = apply_log_softmax
+        self.append_star = append_star

     def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
-        waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
-        return self.model(waveforms, lengths)
+        if self.normalize_waveform:
+            waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
+        output, output_lengths = self.model(waveforms, lengths)
+        if self.apply_log_softmax:
+            output = torch.nn.functional.log_softmax(output, dim=-1)
+        if self.append_star:
+            star_dim = torch.zeros((1, output.size(1), 1), dtype=output.dtype, device=output.device)
+            output = torch.cat((output, star_dim), dim=-1)
+        return output, output_lengths

     @torch.jit.export
     def extract_features(
@@ -39,13 +49,14 @@ def extract_features(
         lengths: Optional[Tensor] = None,
         num_layers: Optional[int] = None,
     ) -> Tuple[List[Tensor], Optional[Tensor]]:
-        waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
+        if self.normalize_waveform:
+            waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
         return self.model.extract_features(waveforms, lengths, num_layers)


-def _apply_input_layer_norm(module):
-    """Add extra layer_norm to the model"""
-    return _Wav2Vec2Model(module)
+def _extend_model(module, normalize_waveform, apply_log_softmax=False, append_star=False):
+    """Add extra transformations to the model"""
+    return _Wav2Vec2Model(module, normalize_waveform, apply_log_softmax, append_star)


def _remove_aux_axes(state_dict, axes):
Expand All @@ -65,23 +76,13 @@ def _remove_aux_axes(state_dict, axes):
state_dict[key] = torch.stack([mat[i] for i in range(mat.size(0)) if i not in axes])


-def _add_star_dim(state_dict):
-    w, b = state_dict["aux.weight"], state_dict["aux.bias"]
-    zeros = torch.zeros((1, w.size(1)), device=w.device, dtype=w.dtype)
-    state_dict["aux.weight"] = torch.cat((zeros, w), dim=0)
-    ones = torch.ones((1,), device=b.device, dtype=b.dtype)
-    state_dict["aux.bias"] = torch.cat((b, ones), dim=0)
-
-
-def _get_state_dict(url, dl_kwargs, remove_axes=None, add_star=False):
+def _get_state_dict(url, dl_kwargs, remove_axes=None):
     if not url.startswith("https"):
         url = f"https://download.pytorch.org/torchaudio/models/{url}"
     dl_kwargs = {} if dl_kwargs is None else dl_kwargs
     state_dict = load_state_dict_from_url(url, **dl_kwargs)
     if remove_axes:
         _remove_aux_axes(state_dict, remove_axes)
-    if add_star:
-        _add_star_dim(state_dict)
     return state_dict


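
The append_star branch added to forward() replaces the removed _add_star_dim state-dict surgery: rather than grafting an extra output row onto aux.weight/aux.bias before loading (which, in the removed code, put the star row at one end of the weight and the other end of the bias), the wrapper appends a zero column (log-probability 0, i.e. probability 1) to the emission at run time, leaving the checkpoint untouched. A self-contained sketch of the trick with made-up shapes:

import torch

logits = torch.randn(1, 50, 28)  # (batch, frames, labels), made-up sizes

# As in the new forward(): log_softmax first, then append the <star> column.
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
star = torch.zeros((1, log_probs.size(1), 1), dtype=log_probs.dtype, device=log_probs.device)
emission = torch.cat((log_probs, star), dim=-1)

print(emission.shape)  # torch.Size([1, 50, 29]); last class is <star>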
