From a6bbeaf37242af576f854d33dda84c22ebd6867e Mon Sep 17 00:00:00 2001 From: Steve Korshakov Date: Wed, 20 Nov 2024 16:12:30 -0800 Subject: [PATCH] fix: fixing batching in MMS_FA --- src/torchaudio/pipelines/_wav2vec2/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchaudio/pipelines/_wav2vec2/utils.py b/src/torchaudio/pipelines/_wav2vec2/utils.py index e690e8103c..65a7a6a2c6 100644 --- a/src/torchaudio/pipelines/_wav2vec2/utils.py +++ b/src/torchaudio/pipelines/_wav2vec2/utils.py @@ -38,7 +38,7 @@ def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[ if self.apply_log_softmax: output = torch.nn.functional.log_softmax(output, dim=-1) if self.append_star: - star_dim = torch.zeros((1, output.size(1), 1), dtype=output.dtype, device=output.device) + star_dim = torch.zeros((output.size(0), output.size(1), 1), dtype=output.dtype, device=output.device) output = torch.cat((output, star_dim), dim=-1) return output, output_lengths