diff --git a/egs/librispeech/asr/simple_v1/espnet_utils/__init__.py b/egs/librispeech/asr/simple_v1/espnet_utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/egs/librispeech/asr/simple_v1/espnet_utils/asr.py b/egs/librispeech/asr/simple_v1/espnet_utils/asr.py
new file mode 100644
index 00000000..cc17878b
--- /dev/null
+++ b/egs/librispeech/asr/simple_v1/espnet_utils/asr.py
@@ -0,0 +1,109 @@
+#!/usr/bin/env python3
+
+# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
+# Apache 2.0
+
+import torch
+
+from espnet_utils.common import load_espnet_model_config
+from espnet_utils.common import rename_state_dict, combine_qkv
+from espnet_utils.frontened import Fbank
+from espnet_utils.frontened import GlobalMVN
+from espnet_utils.numericalizer import SpmNumericalizer
+from snowfall.models.conformer import Conformer
+
+# (old_pattern, new_pattern) pairs used to map ESPnet checkpoint keys onto
+# the snowfall Conformer's parameter names.
+_ESPNET_ENCODER_KEY_TO_SNOWFALL_KEY = [
+    ('frontend.logmel.melmat', 'frontend.melmat'),
+    ('encoder.embed.out.0.weight', 'encoder.embed.out.weight'),
+    ('encoder.embed.out.0.bias', 'encoder.embed.out.bias'),
+    (r'(encoder.encoders.)(\d+)(.self_attn.)linear_out([\s\S*])',
+     r'\1\2\3out_proj\4'),
+    (r'(encoder.encoders.)(\d+)', r'\1layers.\2'),
+    (r'(encoder.encoders.layers.)(\d+)(.feed_forward.)(w_1)',
+     r'\1\2.feed_forward.0'),
+    (r'(encoder.encoders.layers.)(\d+)(.feed_forward.)(w_2)',
+     r'\1\2.feed_forward.3'),
+    (r'(encoder.encoders.layers.)(\d+)(.feed_forward_macaron.)(w_1)',
+     r'\1\2.feed_forward_macaron.0'),
+    (r'(encoder.encoders.layers.)(\d+)(.feed_forward_macaron.)(w_2)',
+     r'\1\2.feed_forward_macaron.3'),
+    (r'(encoder.embed.)([\s\S*])', r'encoder.encoder_embed.\2'),
+    (r'(encoder.encoders.)([\s\S*])', r'encoder.encoder.\2'),
+    (r'(ctc.ctc_lo.)([\s\S*])', r'encoder.encoder_output_layer.1.\2'),
+]
+
+
+class ESPnetASRModel(torch.nn.Module):
+
+    def __init__(
+        self,
+        frontend: Fbank,
+        normalize: GlobalMVN,
+        encoder: Conformer,
+    ):
+        super().__init__()
+        self.frontend = frontend
+        self.normalize = normalize
+        self.encoder = encoder
+
+    def encode(self, speech: torch.Tensor,
+               speech_lengths: torch.Tensor) -> torch.Tensor:
+        feats, feats_lengths = self.frontend(speech, speech_lengths)
+
+        feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        feats = feats.permute(0, 2, 1)
+
+        nnet_output, _, _ = self.encoder(feats)
+        nnet_output = nnet_output.permute(2, 0, 1)
+        return nnet_output
+
+    @classmethod
+    def build_model(cls, asr_train_config, asr_model_file, device):
+        args = load_espnet_model_config(asr_train_config)
+        # args.frontend_conf, e.g. {'fs': '16k', 'hop_length': 256, 'n_fft': 512}
+        frontend = Fbank(**args.frontend_conf)
+        normalize = GlobalMVN(**args.normalize_conf)
+        encoder = Conformer(num_features=80,
+                            num_classes=len(args.token_list),
+                            subsampling_factor=4,
+                            d_model=512,
+                            nhead=8,
+                            dim_feedforward=2048,
+                            num_encoder_layers=12,
+                            cnn_module_kernel=31,
+                            num_decoder_layers=0,
+                            is_espnet_structure=True)
+
+        model = ESPnetASRModel(
+            frontend=frontend,
+            normalize=normalize,
+            encoder=encoder,
+        )
+
+        state_dict = torch.load(asr_model_file, map_location=device)
+
+        state_dict = {
+            k: v for k, v in state_dict.items() if not k.startswith('decoder')
+        }
+
+        # Layers 0..11, i.e. all 12 encoder layers.
+        combine_qkv(state_dict, num_encoder_layers=11)
+        rename_state_dict(rename_patterns=_ESPNET_ENCODER_KEY_TO_SNOWFALL_KEY,
+                          state_dict=state_dict)
+
+        model.load_state_dict(state_dict, strict=False)
+        model = model.to(torch.device(device))
+
+        numericalizer = SpmNumericalizer(tokenizer_type='spm',
+                                         tokenizer_file=args.bpemodel,
+                                         token_list=args.token_list,
+                                         unk_symbol='<unk>')
+        return model, numericalizer
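The rename table above is applied key by key with re.sub. A minimal illustration
of one rule (the checkpoint key below is representative, not taken from a real
checkpoint):

    import re

    # One rule from _ESPNET_ENCODER_KEY_TO_SNOWFALL_KEY:
    old_pattern, new_pattern = (r'(encoder.encoders.)(\d+)', r'\1layers.\2')
    key = 'encoder.encoders.3.self_attn.linear_q.weight'
    print(re.sub(old_pattern, new_pattern, key))
    # encoder.encoders.layers.3.self_attn.linear_q.weight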
diff --git a/egs/librispeech/asr/simple_v1/espnet_utils/common.py b/egs/librispeech/asr/simple_v1/espnet_utils/common.py
new file mode 100644
index 00000000..d794a5c2
--- /dev/null
+++ b/egs/librispeech/asr/simple_v1/espnet_utils/common.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+
+# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
+# Apache 2.0
+
+import argparse
+import re
+import yaml
+
+from typing import List, Tuple, Dict
+from pathlib import Path
+
+import torch
+
+
+def load_espnet_model_config(config_file):
+    config_file = Path(config_file)
+    with config_file.open("r", encoding="utf-8") as f:
+        args = yaml.safe_load(f)
+    return argparse.Namespace(**args)
+
+
+def rename_state_dict(rename_patterns: List[Tuple[str, str]],
+                      state_dict: Dict[str, torch.Tensor]):
+    # Rename state-dict keys so that an espnet model can be loaded.
+    if rename_patterns is not None:
+        for old_pattern, new_pattern in rename_patterns:
+            old_keys = [
+                k for k in state_dict if re.match(old_pattern, k) is not None
+            ]
+            for k in old_keys:
+                v = state_dict.pop(k)
+                new_k = re.sub(old_pattern, new_pattern, k)
+                state_dict[new_k] = v
+
+
+def combine_qkv(state_dict: Dict[str, torch.Tensor], num_encoder_layers=11):
+    # Fuse the separate q/k/v projections of encoder layers
+    # 0..num_encoder_layers (inclusive) into a single in_proj, as used by
+    # snowfall's attention implementation.
+    for layer in range(num_encoder_layers + 1):
+        q_w = state_dict[f'encoder.encoders.{layer}.self_attn.linear_q.weight']
+        k_w = state_dict[f'encoder.encoders.{layer}.self_attn.linear_k.weight']
+        v_w = state_dict[f'encoder.encoders.{layer}.self_attn.linear_v.weight']
+        q_b = state_dict[f'encoder.encoders.{layer}.self_attn.linear_q.bias']
+        k_b = state_dict[f'encoder.encoders.{layer}.self_attn.linear_k.bias']
+        v_b = state_dict[f'encoder.encoders.{layer}.self_attn.linear_v.bias']
+
+        for param_type in ['weight', 'bias']:
+            for layer_type in ['q', 'k', 'v']:
+                key_to_remove = f'encoder.encoders.{layer}.self_attn.linear_{layer_type}.{param_type}'
+                state_dict.pop(key_to_remove)
+
+        in_proj_weight = torch.cat([q_w, k_w, v_w], dim=0)
+        in_proj_bias = torch.cat([q_b, k_b, v_b], dim=0)
+        key_weight = f'encoder.encoders.{layer}.self_attn.in_proj.weight'
+        state_dict[key_weight] = in_proj_weight
+        key_bias = f'encoder.encoders.{layer}.self_attn.in_proj.bias'
+        state_dict[key_bias] = in_proj_bias
diff --git a/egs/librispeech/asr/simple_v1/espnet_utils/frontened.py b/egs/librispeech/asr/simple_v1/espnet_utils/frontened.py
new file mode 100644
index 00000000..ae3e9242
--- /dev/null
+++ b/egs/librispeech/asr/simple_v1/espnet_utils/frontened.py
@@ -0,0 +1,229 @@
+#!/usr/bin/env python3
+
+# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
+# Apache 2.0
+
+import humanfriendly
+import librosa
+import numpy as np
+import torch
+
+from pathlib import Path
+from typeguard import check_argument_types
+from typing import Optional, Tuple, Union
+
+
+# Modified from:
+# https://github.com/espnet/espnet/blob/08feae5bb93fa8f6dcba66760c8617a4b5e39d70/espnet/nets/pytorch_backend/frontends/feature_transform.py#L135
+class GlobalMVN(torch.nn.Module):
+    """Apply global mean and variance normalization
+
+    TODO(kamo): Make this class portable somehow
+
+    Args:
+        stats_file: npy file
+        norm_means: Apply mean normalization
+        norm_vars: Apply var normalization
+        eps:
+    """
+
+    def __init__(
+        self,
+        stats_file: Union[Path, str],
+        norm_means: bool = True,
+        norm_vars: bool = True,
+        eps: float = 1.0e-20,
+    ):
+        assert check_argument_types()
+        super().__init__()
+        self.norm_means = norm_means
self.norm_vars = norm_vars + self.eps = eps + stats_file = Path(stats_file) + + self.stats_file = stats_file + stats = np.load(stats_file) + if isinstance(stats, np.ndarray): + # Kaldi like stats + count = stats[0].flatten()[-1] + mean = stats[0, :-1] / count + var = stats[1, :-1] / count - mean * mean + else: + # New style: Npz file + count = stats["count"] + sum_v = stats["sum"] + sum_square_v = stats["sum_square"] + mean = sum_v / count + var = sum_square_v / count - mean * mean + std = np.sqrt(np.maximum(var, eps)) + + self.register_buffer("mean", torch.from_numpy(mean)) + self.register_buffer("std", torch.from_numpy(std)) + + def forward( + self, + x: torch.Tensor, + ilens: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward function + + Args: + x: (B, L, ...) + ilens: (B,) + """ + if ilens is None: + ilens = x.new_full([x.size(0)], x.size(1)) + norm_means = self.norm_means + norm_vars = self.norm_vars + self.mean = self.mean.to(x.device, x.dtype) + self.std = self.std.to(x.device, x.dtype) + + # feat: (B, T, D) + if norm_means: + if x.requires_grad: + x = x - self.mean + else: + x -= self.mean + + if norm_vars: + x /= self.std + + return x, ilens + + +# Modified from: +# https://github.com/espnet/espnet/blob/08feae5bb93fa8f6dcba66760c8617a4b5e39d70/espnet2/layers/stft.py#L14:7 +class Stft(torch.nn.Module): + + def __init__( + self, + n_fft: int = 512, + win_length: int = None, + hop_length: int = 128, + window: Optional[str] = "hann", + center: bool = True, + normalized: bool = False, + onesided: bool = True, + ): + super().__init__() + self.n_fft = n_fft + if win_length is None: + self.win_length = n_fft + else: + self.win_length = win_length + self.hop_length = hop_length + self.center = center + self.normalized = normalized + self.onesided = onesided + if window is not None and not hasattr(torch, f"{window}_window"): + raise ValueError(f"{window} window is not implemented") + self.window = window + + def forward( + self, + input: torch.Tensor, + ilens: torch.Tensor = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """STFT forward function. 
+ + Args: + input: (Batch, Nsamples) or (Batch, Nsample, Channels) + ilens: (Batch) + Returns: + output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2) + + """ + bs = input.size(0) + + if self.window is not None: + window_func = getattr(torch, f"{self.window}_window") + window = window_func(self.win_length, + dtype=input.dtype, + device=input.device) + else: + window = None + output = torch.stft( + input, + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + center=self.center, + window=window, + normalized=self.normalized, + onesided=self.onesided, + ) + output = output.transpose(1, 2) + + if self.center: + pad = self.win_length // 2 + ilens = ilens + 2 * pad + + olens = (ilens - self.win_length) // self.hop_length + 1 + + return output, olens + + +# Modified from: +# https://github.com/espnet/espnet/blob/08feae5bb93fa8f6dcba66760c8617a4b5e39d70/espnet2/asr/frontend/default.py#L19 +class Fbank(torch.nn.Module): + """ + + Stft -> Power-spec -> Mel-Fbank + """ + + def __init__( + self, + fs: Union[int, str] = 16000, + n_fft: int = 512, + win_length: int = None, + hop_length: int = 128, + window: Optional[str] = "hann", + center: bool = True, + normalized: bool = False, + onesided: bool = True, + n_mels: int = 80, + fmin: int = None, + fmax: int = None, + ): + super().__init__() + if isinstance(fs, str): + fs = humanfriendly.parse_size(fs) + + self.stft = Stft( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + center=center, + window=window, + normalized=normalized, + onesided=onesided, + ) + + fmin = 0 if fmin is None else fmin + fmax = fs / 2 if fmax is None else fmax + _mel_options = dict( + sr=fs, + n_fft=n_fft, + n_mels=n_mels, + fmin=fmin, + fmax=fmax, + ) + + # _mel_options = {'sr': 16000, 'n_fft': 512, 'n_mels': 80, 'fmin': 0, 'fmax': 8000.0, 'htk': False} + melmat = librosa.filters.mel(**_mel_options) + + self.register_buffer("melmat", torch.from_numpy(melmat.T).float()) + + def forward( + self, input: torch.Tensor, + input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + input_stft, feats_lens = self.stft(input, input_lengths) + input_stft = torch.complex(input_stft[..., 0], input_stft[..., 1]) + + input_power = input_stft.real**2 + input_stft.imag**2 + + mel_feat = torch.matmul(input_power, self.melmat) + mel_feat = torch.clamp(mel_feat, min=1e-10) + + input_feats = mel_feat.log() + + return input_feats, feats_lens diff --git a/egs/librispeech/asr/simple_v1/espnet_utils/load_lm_model.py b/egs/librispeech/asr/simple_v1/espnet_utils/load_lm_model.py new file mode 100644 index 00000000..ab510bef --- /dev/null +++ b/egs/librispeech/asr/simple_v1/espnet_utils/load_lm_model.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong) +# Apache 2.0 + +import argparse +import re +from typing import Dict, List, Tuple, Union +from pathlib import Path + +import torch +import yaml + +from snowfall.models.lm_transformer import TransformerLM +from espnet_utils.common import rename_state_dict + +_ESPNET_TRANSFORMER_LM_KEY_TO_SNOWFALL_KEY = [ + (r'([\s\S]*).feed_forward.w_1', r'\1.linear1'), + (r'([\s\S]*).feed_forward.w_2', r'\1.linear2'), + (r'([\s\S]*).encoder.embed([\s\S]*)', r'\1.input_embed\2'), + (r'(lm.encoder.encoders.)(\d+)', r'\1layers.\2'), + (r'(lm.)([\s\S]*)', r'\2'), +] + + +def load_espnet_model( + config: Dict, + model_file: Union[Path, str], +): + """This method is used to load LM model downloaded from espnet model zoo. 
+
+    Args:
+        config: LM configuration dict (the lm_conf saved when training).
+        model_file: The model file saved when training.
+
+    """
+    model = TransformerLM(**config)
+
+    assert model_file is not None, "model file doesn't exist"
+    state_dict = torch.load(model_file)
+
+    rename_state_dict(
+        rename_patterns=_ESPNET_TRANSFORMER_LM_KEY_TO_SNOWFALL_KEY,
+        state_dict=state_dict)
+    model.load_state_dict(state_dict)
+
+    return model
+
+
+def build_lm_model_from_file(config=None,
+                             model_file=None,
+                             model_type='espnet'):
+    if model_type == 'espnet':
+        return load_espnet_model(config, model_file)
+    elif model_type == 'snowfall':
+        raise NotImplementedError('Snowfall models are not supported yet')
+    else:
+        raise ValueError(f'Unsupported model type {model_type}')
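For context, a minimal usage sketch of build_lm_model_from_file; the paths are
placeholders, and the lm_conf handling mirrors EspnetNNLMEvaluator.build_model
below:

    import copy
    from espnet_utils.common import load_espnet_model_config
    from espnet_utils.load_lm_model import build_lm_model_from_file

    train_args = load_espnet_model_config('exp/.../lm/config.yaml')  # placeholder
    lm_config = copy.deepcopy(train_args.lm_conf)
    lm_config['vocab_size'] = len(train_args.token_list)
    lm = build_lm_model_from_file(config=lm_config,
                                  model_file='exp/.../lm/valid.loss.ave_10best.pth',
                                  model_type='espnet')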
diff --git a/egs/librispeech/asr/simple_v1/espnet_utils/nnlm_evaluator.py b/egs/librispeech/asr/simple_v1/espnet_utils/nnlm_evaluator.py
new file mode 100644
index 00000000..4c00c4f1
--- /dev/null
+++ b/egs/librispeech/asr/simple_v1/espnet_utils/nnlm_evaluator.py
@@ -0,0 +1,113 @@
+import argparse
+import copy
+import os
+import yaml
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+from espnet_utils.common import load_espnet_model_config
+from espnet_utils.text_dataset import DatasetOption, TextFileDataIterator, AuxlabelDataIterator, AbsLMDataIterator
+from espnet_utils.load_lm_model import build_lm_model_from_file
+from espnet_utils.numericalizer import get_numericalizer
+from snowfall.models.lm_transformer import TransformerLM
+
+# TODO(Liyong Guo): types that may need to be supported: ['text', 'token', 'token_id']
+_TYPES_SUPPORTED = ['text_file', 'auxlabel']
+
+
+def _validate_input_type(input_type: Optional[str] = None):
+    # A valid input_type must be assigned by the caller.
+    assert input_type is not None
+    assert input_type in _TYPES_SUPPORTED
+
+
+@dataclass(frozen=True)
+class PPLResult:
+    nlls: List[float]
+    ntokens: int
+    nwords: int
+
+    @property
+    def total_nll(self):
+        return sum(self.nlls)
+
+    @property
+    def token_ppl(self):
+        return np.exp(self.total_nll / self.ntokens)
+
+    @property
+    def word_ppl(self):
+        return np.exp(self.total_nll / self.nwords)
+
+
+class NNLMEvaluator(object):
+
+    @torch.no_grad()
+    def nll(self, text_source):
+        nlls = []
+        total_ntokens = 0
+        total_nwords = 0
+        for xs_pad, target_pad, word_lens, token_lens in self.dataset(
+                text_source):
+            xs_pad = xs_pad.to(self.device)
+            target_pad = target_pad.to(self.device)
+            nll = self.lm.nll(xs_pad, target_pad, token_lens)
+            nll = nll.detach().cpu().numpy().sum(1)
+            nlls.extend(nll)
+            total_ntokens += sum(token_lens)
+            total_nwords += sum(word_lens)
+        ppl_result = PPLResult(nlls=nlls,
+                               ntokens=total_ntokens,
+                               nwords=total_nwords)
+        return ppl_result
+
+
+@dataclass
+class EspnetNNLMEvaluator(NNLMEvaluator):
+    lm: TransformerLM
+    dataset: AbsLMDataIterator
+    device: Union[str, torch.device]
+
+    @classmethod
+    def build_model(cls,
+                    lm_train_config,
+                    lm_model_file,
+                    device='cpu',
+                    input_type='text_file',
+                    batch_size=32,
+                    numericalizer=None):
+        _validate_input_type(input_type)
+        train_args = load_espnet_model_config(lm_train_config)
+
+        lm_config = copy.deepcopy(train_args.lm_conf)
+        lm_config['vocab_size'] = len(train_args.token_list)
+
+        model = build_lm_model_from_file(config=lm_config,
+                                         model_file=lm_model_file,
+                                         model_type='espnet')
+        model.to(device)
+
+        if numericalizer is None:
+            numericalizer = get_numericalizer(
+                tokenizer_type='spm',
+                tokenizer_file=train_args.bpemodel,
+                token_list=train_args.token_list)
+        dataset_option = DatasetOption(input_type=input_type,
+                                       preprocessor=numericalizer)
+
+        if input_type == 'text_file':
+            dataset = TextFileDataIterator(dataset_option)
+        elif input_type == 'auxlabel':
+            dataset = AuxlabelDataIterator(dataset_option,
+                                           numericalizer=numericalizer)
+
+        evaluator = EspnetNNLMEvaluator(lm=model,
+                                        dataset=dataset,
+                                        device=device)
+        return evaluator
diff --git a/egs/librispeech/asr/simple_v1/espnet_utils/numericalizer.py b/egs/librispeech/asr/simple_v1/espnet_utils/numericalizer.py
new file mode 100644
index 00000000..4f2ce503
--- /dev/null
+++ b/egs/librispeech/asr/simple_v1/espnet_utils/numericalizer.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python3
+
+# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
+# Apache 2.0
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Iterable, List, Optional, Union
+from pathlib import Path
+
+import numpy as np
+
+from torchtext.data.functional import load_sp_model
+
+
+class NumericalizerMixin(ABC):
+
+    def _assign_special_symbols(self):
+        # <sos> and <eos> share the same index in models downloaded from the
+        # espnet model zoo, where they appear as a single '<sos/eos>' token.
+        assert '<sos/eos>' in self.token2idx \
+            or ('<sos>' in self.token2idx and '<eos>' in self.token2idx)
+        assert '<unk>' in self.token2idx
+        self.sos_idx = self.token2idx[
+            '<sos/eos>'] if '<sos/eos>' in self.token2idx else self.token2idx[
+                '<sos>']
+        self.eos_idx = self.token2idx[
+            '<sos/eos>'] if '<sos/eos>' in self.token2idx else self.token2idx[
+                '<eos>']
+        self.unk_idx = self.token2idx['<unk>']
+
+
+@dataclass
+class SpmNumericalizer(NumericalizerMixin):
+
+    def __init__(self,
+                 tokenizer_type,
+                 tokenizer_file,
+                 token_list,
+                 unk_symbol='<unk>'):
+        assert tokenizer_type == 'spm'
+        self.tokenizer_file = tokenizer_file
+        self.token_list = token_list
+        self._token2idx = None
+        self._tokenizer = None
+        self._assign_special_symbols()
+
+    @property
+    def tokenizer(self):
+        if self._tokenizer is None:
+            self._tokenizer = load_sp_model(self.tokenizer_file)
+        return self._tokenizer
+
+    def text2tokens(self, line: str) -> List[str]:
+        return self.tokenizer.EncodeAsPieces(line)
+
+    def tokens2text(self, tokens: Iterable[str]) -> str:
+        return self.tokenizer.DecodePieces(list(tokens))
+
+    @property
+    def token2idx(self):
+        if self._token2idx is None:
+            self._token2idx = {}
+            for idx, token in enumerate(self.token_list):
+                if token in self._token2idx:
+                    raise RuntimeError(f'Symbol "{token}" is duplicated')
+                self._token2idx[token] = idx
+
+        return self._token2idx
+
+    def ids2tokens(self, integers: Union[np.ndarray,
+                                         Iterable[int]]) -> List[str]:
+        if isinstance(integers, np.ndarray) and integers.ndim != 1:
+            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
+        return [self.token_list[i] for i in integers]
+
+    def __call__(self, text: str) -> List[int]:
+        tokens = self.text2tokens(text)
+        token_idxs = [self.sos_idx] + [
+            self.token2idx.get(token, self.unk_idx) for token in tokens
+        ] + [self.eos_idx]
+        return token_idxs
+
+
+def get_numericalizer(
+    tokenizer_type,
+    tokenizer_file,
+    token_list,
+):
+    if tokenizer_type == 'spm':
+        numericalizer = SpmNumericalizer(tokenizer_type=tokenizer_type,
+                                         tokenizer_file=tokenizer_file,
+                                         token_list=token_list)
+    elif tokenizer_type == 'huggingface':
+        raise NotImplementedError(f'{tokenizer_type} is to be supported')
+    else:
+        raise ValueError(f'Unsupported tokenizer type {tokenizer_type}')
+
+    return numericalizer
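A round-trip sketch of SpmNumericalizer, assuming a trained SentencePiece model
and a matching token list (the file name and token_list are placeholders):

    numericalizer = SpmNumericalizer(tokenizer_type='spm',
                                     tokenizer_file='bpe.model',
                                     token_list=token_list,
                                     unk_symbol='<unk>')
    ids = numericalizer('HELLO WORLD')            # [sos_idx, ..., eos_idx]
    tokens = numericalizer.ids2tokens(ids[1:-1])  # strip <sos>/<eos>
    text = numericalizer.tokens2text(tokens)      # 'HELLO WORLD'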
diff --git a/egs/librispeech/asr/simple_v1/espnet_utils/text_dataset.py b/egs/librispeech/asr/simple_v1/espnet_utils/text_dataset.py
new file mode 100644
index 00000000..04f9f131
--- /dev/null
+++ b/egs/librispeech/asr/simple_v1/espnet_utils/text_dataset.py
@@ -0,0 +1,139 @@
+#!/usr/bin/env python3
+
+# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
+# Apache 2.0
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Optional, Union
+
+import k2
+import numpy as np
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+from espnet_utils.numericalizer import SpmNumericalizer
+
+AnyPreProcessor = Union['SpmNumericalizer']
+
+
+def auxlabel_to_word(word_seqs: k2.RaggedInt, numericalizer=None) -> List[str]:
+    assert numericalizer is not None
+    utts = []
+    token_ints = k2.ragged.to_list(word_seqs)
+    for token_int in token_ints:
+        token = numericalizer.ids2tokens(token_int)
+        text = numericalizer.tokens2text(token)
+        utts.append(text)
+
+    return utts
+
+
+class CollateFunc(object):
+    '''Collate function for LMDataset
+    '''
+
+    def __init__(self, pad_index=None):
+        # pad_index should be identical to ignore_index of torch.nn.NLLLoss
+        # and padding_idx in torch.nn.Embedding
+        self.pad_index = pad_index
+
+    def __call__(self, batch: List[List[int]]):
+        '''
+        batch is a ragged 2-d array; each row represents a tokenized text
+        whose format is:
+            <sos> token_id token_id ... token_id <eos>
+        '''
+        # data_pad: [batch_size, max_seq_len]
+        # max_seq_len == len(max(batch, key=len))
+        data_pad = pad_sequence(
+            [torch.from_numpy(np.array(x)).long() for x in batch], True,
+            self.pad_index)
+        data_pad = data_pad.contiguous()
+        xs_pad = data_pad[:, :-1].contiguous()
+        ys_pad = data_pad[:, 1:].contiguous()
+        # xs_pad/ys_pad: [batch_size, max_seq_len - 1]
+        # (-1 for removing <eos> from inputs and <sos> from targets)
+        return xs_pad, ys_pad
+
+
+@dataclass
+class DatasetOption:
+    preprocessor: AnyPreProcessor
+    input_type: Optional[str] = 'text_file'
+    batch_size: int = 32
+    pad_value: int = 0
+
+
+@dataclass
+class AbsLMDataIterator(ABC):
+    preprocessor: AnyPreProcessor
+    input_type: Optional[str] = 'text_file'
+    batch_size: int = 32
+    pad_value: int = 0
+    words_txt: Optional[Path] = None
+    _collate_fn = None
+
+    @property
+    def collate_fn(self):
+        if self._collate_fn is None:
+            self._collate_fn = CollateFunc(self.pad_value)
+        return self._collate_fn
+
+    def _reset_container(self):
+        self.token_ids_list = []
+        self.token_lens = []
+        self.word_lens = []
+
+    @abstractmethod
+    def _text_generator(self, text_source):
+        pass
+
+    def __call__(self, text_source):
+        """
+        Args:
+            text_source: a text file path or word_seqs
+        """
+        self._reset_container()
+        for text in self._text_generator(text_source):
+            self.word_lens.append(len(text.split()) + 1)  # +1 for <eos>
+
+            token_ids = self.preprocessor(text)
+            self.token_ids_list.append(token_ids)
+            self.token_lens.append(len(token_ids) - 1)  # -1 to remove <sos>
+
+            if len(self.token_ids_list) == self.batch_size:
+                xs_pad, ys_pad = self.collate_fn(self.token_ids_list)
+
+                yield xs_pad, ys_pad, self.word_lens, self.token_lens
+                self._reset_container()
+
+        if len(self.token_ids_list) != 0:
+            xs_pad, ys_pad = self.collate_fn(self.token_ids_list)
+            yield xs_pad, ys_pad, self.word_lens, self.token_lens
+            self._reset_container()
+
+
+class TextFileDataIterator(AbsLMDataIterator):
+
+    def __init__(self, dataset_option):
+        super().__init__(**(dataset_option.__dict__))
+
+    def _text_generator(self, text_file):
+        with open(text_file, 'r') as f:
+            for text in f:
+                text = text.strip().split(maxsplit=1)[1]
+                yield text
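TextFileDataIterator above consumes Kaldi-style transcripts, one
"utt-id transcript" pair per line. A usage sketch (the file path is a
placeholder, and numericalizer is built as in get_numericalizer):

    # data/test_clean/text:
    #   1089-134686-0000 HE HOPED THERE WOULD BE STEW FOR DINNER
    opt = DatasetOption(preprocessor=numericalizer, input_type='text_file')
    for xs_pad, ys_pad, word_lens, token_lens in \
            TextFileDataIterator(opt)('data/test_clean/text'):
        ...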
class AuxlabelDataIterator(AbsLMDataIterator):
+
+    def __init__(self, dataset_option, numericalizer):
+        super().__init__(**(dataset_option.__dict__))
+        self.numericalizer = numericalizer
+
+    def _text_generator(self, word_seqs):
+        # word_seqs --> text
+        texts = auxlabel_to_word(word_seqs, self.numericalizer)
+        for text in texts:
+            yield text
diff --git a/egs/librispeech/asr/simple_v1/nnlm_nbest_rescore.sh b/egs/librispeech/asr/simple_v1/nnlm_nbest_rescore.sh
new file mode 100644
index 00000000..9ef33c8f
--- /dev/null
+++ b/egs/librispeech/asr/simple_v1/nnlm_nbest_rescore.sh
@@ -0,0 +1,48 @@
+#!/usr/bin/env bash
+
+# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
+# Apache 2.0
+
+# Example of transformer LM n-best rescoring with espnet pretrained models.
+
+set -eou pipefail
+
+stage=0
+
+if [ $stage -le 0 ]; then
+  # check that the test manifests are already generated
+  echo "check data preparation"
+  for test_set in test-clean test-other; do
+    if [ ! -f exp/data/cuts_${test_set}.json.gz ]; then
+      echo "Refer to ./run.sh to generate manifest files, i.e. exp/data/*.gz."
+      exit 1
+    fi
+  done
+fi
+
+if [ $stage -le 1 ]; then
+  # Download espnet pretrained models.
+  # The original link of these models is:
+  # https://zenodo.org/record/4604066#.YKtNrqgzZPY
+  # which is accessible with espnet utils.
+  # They are ported to the following link for users who don't have espnet dependencies.
+  if [ ! -d snowfall_model_zoo ]; then
+    echo "About to download pretrained models."
+    git clone https://huggingface.co/GuoLiyong/snowfall_model_zoo
+    ln -sf snowfall_model_zoo/exp/kamo-naoyuki/ exp/
+  fi
+  echo "Pretrained models are ready."
+fi
+
+if [ $stage -le 2 ]; then
+  echo "Start to recognize."
+  export CUDA_VISIBLE_DEVICES=3
+  model_path=exp/kamo-naoyuki/librispeech_asr_train_asr_conformer6_n_fft512_hop_length256_raw_en_bpe5000_scheduler_confwarmup_steps40000_optim_conflr0.0025_sp_valid.acc.ave/
+  python3 tokenizer_ctc_att_transformer_decode.py \
+    --num_paths 100 \
+    --asr_train_config $model_path/config.yaml \
+    --asr_model_file $model_path/valid.acc.ave_10best.pth \
+    --lm_train_config $model_path/lm/config.yaml \
+    --lm_model_file $model_path/lm/valid.loss.ave_10best.pth
+fi
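For reference, stage 2 assumes the layout below under exp/ after the stage 1
symlink (paths abbreviated from the model_path used above):

    exp/kamo-naoyuki/librispeech_.../config.yaml                   # ASR training config
    exp/kamo-naoyuki/librispeech_.../valid.acc.ave_10best.pth      # ASR model weights
    exp/kamo-naoyuki/librispeech_.../lm/config.yaml                # LM training config
    exp/kamo-naoyuki/librispeech_.../lm/valid.loss.ave_10best.pth  # LM weights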
diff --git a/egs/librispeech/asr/simple_v1/tokenizer_ctc_att_transformer_decode.py b/egs/librispeech/asr/simple_v1/tokenizer_ctc_att_transformer_decode.py
new file mode 100755
index 00000000..2131b9be
--- /dev/null
+++ b/egs/librispeech/asr/simple_v1/tokenizer_ctc_att_transformer_decode.py
@@ -0,0 +1,211 @@
+#!/usr/bin/env python3
+
+# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
+# Apache 2.0
+
+import argparse
+import logging
+import os
+import random
+
+from pathlib import Path
+from typing import Union
+
+import k2
+import numpy as np
+import torch
+
+from kaldialign import edit_distance
+from lhotse import load_manifest
+from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
+from lhotse.dataset.input_strategies import AudioSamples
+
+from espnet_utils.asr import ESPnetASRModel
+from espnet_utils.nnlm_evaluator import EspnetNNLMEvaluator
+from snowfall.common import store_transcripts
+from snowfall.common import write_error_stats
+from snowfall.decoding.lm_rescore import decode_with_lm_rescoring
+from snowfall.training.ctc_graph import build_ctc_topo
+
+
+def decode(dataloader: torch.utils.data.DataLoader,
+           model: torch.nn.Module,
+           device: Union[str, torch.device],
+           ctc_topo: k2.Fsa,
+           G=None,
+           evaluator=None,
+           numericalizer=None,
+           num_paths=-1):
+    tot_num_cuts = len(dataloader.dataset.cuts)
+    num_cuts = 0
+    results = []
+    for batch_idx, batch in enumerate(dataloader):
+        assert isinstance(batch, dict), type(batch)
+        speech = batch['inputs'].squeeze()
+        lengths = batch['supervisions']['num_samples']
+        # The input is the raw audio signal.
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+
+        # data: (Nsamples,) -> (1, Nsamples)
+        speech = speech.unsqueeze(0)
+        speech = speech.to(torch.device(device))
+        lengths = lengths.to(torch.device(device))
+
+        nnet_output = model.encode(speech=speech, speech_lengths=lengths)
+        nnet_output = nnet_output.detach()
+
+        blank_bias = -1.0
+        nnet_output[:, :, 0] += blank_bias
+
+        supervision_segments = torch.tensor([[0, 0, nnet_output.shape[1]]],
+                                            dtype=torch.int32)
+
+        with torch.no_grad():
+            dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
+
+        output_beam_size = 8
+        lattices = k2.intersect_dense_pruned(ctc_topo, dense_fsa_vec, 20.0,
+                                             output_beam_size, 30, 10000)
+
+        use_whole_lattice = False
+        best_paths = decode_with_lm_rescoring(
+            lattices,
+            G,
+            evaluator,
+            num_paths=num_paths,
+            use_whole_lattice=use_whole_lattice)
+
+        token_int = list(
+            filter(
+                lambda x: x not in
+                [-1, 0, numericalizer.sos_idx, numericalizer.eos_idx],
+                best_paths.aux_labels.cpu().numpy()))
+
+        token = numericalizer.ids2tokens(token_int)
+
+        text = numericalizer.tokens2text(token)
+
+        ref = batch['supervisions']['text']
+        for i in range(len(ref)):
+            hyp_words = text.split(' ')
+            ref_words = ref[i].split(' ')
+            results.append((ref_words, hyp_words))
+        if batch_idx % 10 == 0:
+            logging.info(
+                'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format(
+                    batch_idx, num_cuts, tot_num_cuts,
+                    float(num_cuts) / tot_num_cuts * 100))
+        num_cuts += 1
+    return results
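A note on supervision_segments in decode() above: k2.DenseFsaVec expects one
row per supervision, (sequence_index, start_frame, num_frames). With a single
cut per batch, the whole encoder output is one segment:

    # nnet_output: (1, T, num_classes); one sequence covering frames [0, T)
    supervision_segments = torch.tensor([[0, 0, nnet_output.shape[1]]],
                                        dtype=torch.int32)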
def get_parser():
+    parser = argparse.ArgumentParser(
+        description="ASR Decoding with model from espnet model zoo",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    parser.add_argument("--seed", type=int, default=2021, help="Random seed")
+
+    group = parser.add_argument_group("The model configuration related")
+    group.add_argument("--asr_train_config", type=str, required=True)
+    group.add_argument("--asr_model_file", type=str, required=True)
+    group.add_argument('--lm_train_config', type=str, required=True)
+    group.add_argument('--lm_model_file', type=str, required=True)
+    group.add_argument('--num_paths', type=int, required=True)
+
+    return parser
+
+
+def main():
+    parser = get_parser()
+    logging.basicConfig(level=logging.DEBUG)
+    args = parser.parse_args()
+    asr_train_config = args.asr_train_config
+    asr_model_file = args.asr_model_file
+    seed = args.seed
+
+    device = "cuda"
+
+    # 1. Set the random seed
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.random.manual_seed(seed)
+
+    asr_model, numericalizer = ESPnetASRModel.build_model(
+        asr_train_config, asr_model_file, device)
+
+    asr_model.eval()
+
+    phone_ids_with_blank = [i for i in range(len(numericalizer.token_list))]
+
+    exp_dir = Path('exp/')
+    ctc_path = exp_dir / 'ctc_topo.pt'
+
+    if not os.path.exists(ctc_path):
+        logging.info("Generating ctc topo...")
+        ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))
+        torch.save(ctc_topo.as_dict(), ctc_path)
+    else:
+        logging.info("Loading pre-compiled ctc topo fst")
+        d_ctc_topo = torch.load(ctc_path)
+        ctc_topo = k2.Fsa.from_dict(d_ctc_topo)
+    ctc_topo = ctc_topo.to(device)
+
+    evaluator = EspnetNNLMEvaluator.build_model(args.lm_train_config,
+                                                args.lm_model_file,
+                                                device=device,
+                                                input_type='auxlabel',
+                                                numericalizer=numericalizer)
+    evaluator.lm.eval()
+    feature_dir = Path('exp/data')
+
+    test_sets = ['test-clean', 'test-other']
+    for test_set in test_sets:
+        cuts_test = load_manifest(feature_dir / f'cuts_{test_set}.json.gz')
+        sampler = SingleCutSampler(cuts_test, max_cuts=1)
+
+        test = K2SpeechRecognitionDataset(cuts_test,
+                                          input_strategy=AudioSamples())
+        test_dl = torch.utils.data.DataLoader(test,
+                                              batch_size=None,
+                                              sampler=sampler)
+        results = decode(dataloader=test_dl,
+                         model=asr_model,
+                         device=device,
+                         ctc_topo=ctc_topo,
+                         evaluator=evaluator,
+                         numericalizer=numericalizer,
+                         num_paths=args.num_paths)
+
+        recog_path = exp_dir / f'recogs-{test_set}.txt'
+        store_transcripts(path=recog_path, texts=results)
+        logging.info(f'The transcripts are stored in {recog_path}')
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = exp_dir / f'errs-{test_set}.txt'
+        with open(errs_filename, 'w') as f:
+            write_error_stats(f, test_set, results)
+        logging.info('Wrote detailed error stats to {}'.format(errs_filename))
+
+        dists = [edit_distance(r, h) for r, h in results]
+        errors = {
+            key: sum(dist[key] for dist in dists)
+            for key in ['sub', 'ins', 'del', 'total']
+        }
+        total_words = sum(len(ref) for ref, _ in results)
+        # Print a Kaldi-like message:
+        # %WER 2.62 [ 1380 / 52576, 176 ins, 106 del, 1098 sub ]
+        logging.info(
+            f'[{test_set}] %WER {errors["total"] / total_words:.2%} '
+            f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, '
+            f'{errors["del"]} del, {errors["sub"]} sub ]')
+
+
+if __name__ == "__main__":
+    main()
diff --git a/snowfall/decoding/lm_rescore.py b/snowfall/decoding/lm_rescore.py
index 1bfff333..87db9576 100644
--- a/snowfall/decoding/lm_rescore.py
+++ b/snowfall/decoding/lm_rescore.py
@@ -69,7 +69,7 @@ def compute_am_scores(lats: k2.Fsa, word_fsas_with_epsilon_loops: k2.Fsa,
     '''
     device = lats.device
     assert len(lats.shape) == 3
-    assert hasattr(lats, 'lm_scores')
+    # assert hasattr(lats, 'lm_scores')
 
     # k2.compose() currently does not support b_to_a_map. To avoid
     # replicating `lats`, we use k2.intersect_device here.
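The guard introduced in the next hunk reflects the arc-score decomposition used
in this file: when a first-pass LM has been composed into the lattice, every
arc score is am_score + lm_score, so AM-only scores are recovered by
subtraction; in the neural-LM path no lm_scores attribute is attached and the
arc scores are already AM-only:

    # scores = am_scores + lm_scores   (lattice rescored with a first-pass LM)
    # am_scores = scores - lm_scores   (hence the subtraction below)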
@@ -94,8 +94,9 @@ def compute_am_scores(lats: k2.Fsa, word_fsas_with_epsilon_loops: k2.Fsa,
     # NOTE: `k2.connect` and `k2.top_sort` support only CPU at present
     am_path_lats = k2.top_sort(k2.connect(am_path_lats.to('cpu'))).to(device)
 
-    # The `scores` of every arc consists of `am_scores` and `lm_scores`
-    am_path_lats.scores = am_path_lats.scores - am_path_lats.lm_scores
+    if hasattr(am_path_lats, 'lm_scores'):
+        # The `scores` of every arc consist of `am_scores` and `lm_scores`
+        am_path_lats.scores = am_path_lats.scores - am_path_lats.lm_scores
 
     am_scores = am_path_lats.get_tot_scores(True, True)
 
@@ -103,7 +104,7 @@ def compute_am_scores(lats: k2.Fsa, word_fsas_with_epsilon_loops: k2.Fsa,
 
 
 @torch.no_grad()
-def rescore_with_n_best_list(lats: k2.Fsa, G: k2.Fsa,
+def rescore_with_n_best_list(lats: k2.Fsa, G: k2.Fsa, evaluator: None,
                              num_paths: int) -> k2.Fsa:
     '''Decode using n-best list with LM rescoring.
 
@@ -127,15 +128,19 @@ def rescore_with_n_best_list(lats: k2.Fsa, G: k2.Fsa,
       An FsaVec representing the best decoding path for each sequence
       in the lattice.
     '''
+    assert G is not None or evaluator is not None, \
+        'Neither G nor a neural LM is available!'
+    # TODO(Guo Liyong): figure out how to combine n-gram and neural language models
+    assert not (G is not None and evaluator is not None), \
+        'Both G and a neural LM are given; please pass only one of them'
     device = lats.device
     assert len(lats.shape) == 3
     assert hasattr(lats, 'aux_labels')
-    assert hasattr(lats, 'lm_scores')
 
-    assert G.shape == (1, None, None)
-    assert G.device == device
-    assert hasattr(G, 'aux_labels') is False
+    if G is not None:
+        assert G.shape == (1, None, None)
+        assert G.device == device
+        assert hasattr(G, 'aux_labels') is False
 
     # First, extract `num_paths` paths for each sequence.
     # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
@@ -187,12 +192,16 @@ def rescore_with_n_best_list(lats: k2.Fsa, G: k2.Fsa,
 
     # Now compute lm_scores
     b_to_a_map = torch.zeros_like(path_to_seq_map)
-    lm_path_lats = _intersect_device(G,
-                                     word_fsas_with_epsilon_loops,
-                                     b_to_a_map=b_to_a_map,
-                                     sorted_match_a=True)
-    lm_path_lats = k2.top_sort(k2.connect(lm_path_lats.to('cpu'))).to(device)
-    lm_scores = lm_path_lats.get_tot_scores(True, True)
+    if G is not None:
+        lm_path_lats = _intersect_device(G,
+                                         word_fsas_with_epsilon_loops,
+                                         b_to_a_map=b_to_a_map,
+                                         sorted_match_a=True)
+        lm_path_lats = k2.top_sort(k2.connect(lm_path_lats.to('cpu'))).to(device)
+        lm_scores = lm_path_lats.get_tot_scores(True, True)
+    elif evaluator is not None:
+        ppl_result = evaluator.nll(unique_word_seqs)
+        lm_scores = -torch.tensor(ppl_result.nlls).to(am_scores.device)
 
     tot_scores = am_scores + lm_scores
 
@@ -297,7 +306,7 @@ def rescore_with_whole_lattice(lats: k2.Fsa,
 
 
 @torch.no_grad()
-def decode_with_lm_rescoring(lats: k2.Fsa, G: k2.Fsa, num_paths: int,
+def decode_with_lm_rescoring(lats: k2.Fsa, G: k2.Fsa, evaluator: None, num_paths: int,
                              use_whole_lattice: bool) -> k2.Fsa:
     '''Decode using n-best list with LM rescoring.
@@ -328,4 +337,4 @@ def decode_with_lm_rescoring(lats: k2.Fsa, G: k2.Fsa, num_paths: int, if use_whole_lattice: return rescore_with_whole_lattice(lats, G) else: - return rescore_with_n_best_list(lats, G, num_paths) + return rescore_with_n_best_list(lats, G, evaluator, num_paths) diff --git a/snowfall/models/conformer.py b/snowfall/models/conformer.py index 4bf921b2..48ab0a42 100644 --- a/snowfall/models/conformer.py +++ b/snowfall/models/conformer.py @@ -36,7 +36,8 @@ def __init__(self, num_features: int, num_classes: int, subsampling_factor: int d_model: int = 256, nhead: int = 4, dim_feedforward: int = 2048, num_encoder_layers: int = 12, num_decoder_layers: int = 6, dropout: float = 0.1, cnn_module_kernel: int = 31, - normalize_before: bool = True, vgg_frontend: bool = False) -> None: + normalize_before: bool = True, vgg_frontend: bool = False, + is_espnet_structure: bool = False) -> None: super(Conformer, self).__init__(num_features=num_features, num_classes=num_classes, subsampling_factor=subsampling_factor, d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, @@ -44,8 +45,12 @@ def __init__(self, num_features: int, num_classes: int, subsampling_factor: int self.encoder_pos = RelPositionalEncoding(d_model, dropout) - encoder_layer = ConformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, cnn_module_kernel, normalize_before) + encoder_layer = ConformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, cnn_module_kernel, normalize_before, is_espnet_structure) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + self.normalize_before = normalize_before + self.is_espnet_structure = is_espnet_structure + if self.normalize_before and self.is_espnet_structure: + self.after_norm = nn.LayerNorm(d_model) def encode(self, x: Tensor, supervisions: Optional[Dict] = None) -> Tuple[Tensor, Optional[Tensor]]: """ @@ -65,6 +70,9 @@ def encode(self, x: Tensor, supervisions: Optional[Dict] = None) -> Tuple[Tensor mask = encoder_padding_mask(x.size(0), supervisions) mask = mask.to(x.device) if mask != None else None x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) + if self.normalize_before and self.is_espnet_structure: + x = x.permute(1, 0, 2) + x = self.after_norm(x) return x, mask @@ -90,9 +98,10 @@ class ConformerEncoderLayer(nn.Module): """ def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, - cnn_module_kernel: int = 31, normalize_before: bool = True) -> None: + cnn_module_kernel: int = 31, normalize_before: bool = True, + is_espnet_structure=False) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -319,7 +328,8 @@ class RelPositionMultiheadAttention(nn.Module): >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) """ - def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.) -> None: + def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0., + is_espnet_structure: bool = False) -> None: super(RelPositionMultiheadAttention, self).__init__() self.embed_dim = embed_dim self.num_heads = num_heads @@ -338,6 +348,7 @@ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.) 
-> None:
         self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
 
         self._reset_parameters()
+        self.is_espnet_structure = is_espnet_structure
 
     def _reset_parameters(self) -> None:
         nn.init.xavier_uniform_(self.in_proj.weight)
@@ -538,7 +549,8 @@ def multi_head_attention_forward(self, query: Tensor,
                 _b = _b[_start:]
             v = nn.functional.linear(value, _w, _b)
 
-        q = q * scaling
+        if not self.is_espnet_structure:
+            q = q * scaling
 
         if attn_mask is not None:
             assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
@@ -596,7 +608,10 @@ def multi_head_attention_forward(self, query: Tensor,
         matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))  # (batch, head, time1, 2*time1-1)
         matrix_bd = self.rel_shift(matrix_bd)
 
-        attn_output_weights = (matrix_ac + matrix_bd)  # (batch, head, time1, time2)
+        if not self.is_espnet_structure:
+            attn_output_weights = (matrix_ac + matrix_bd)  # (batch, head, time1, time2)
+        else:
+            attn_output_weights = (matrix_ac + matrix_bd) * scaling  # (batch, head, time1, time2)
 
         attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
diff --git a/snowfall/models/lm_transformer.py b/snowfall/models/lm_transformer.py
new file mode 100644
index 00000000..97314e68
--- /dev/null
+++ b/snowfall/models/lm_transformer.py
@@ -0,0 +1,267 @@
+from typing import Any
+from typing import List
+from typing import Tuple
+
+import math
+import numpy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from snowfall.models.transformer import generate_square_subsequent_mask
+from snowfall.models.transformer import TransformerEncoderLayer
+
+
+# modified from:
+# https://github.com/espnet/espnet/blob/dab2092bc9c8e184c48cc6e603037333bd97dcd1/espnet/nets/pytorch_backend/nets_utils.py#L64
+def make_pad_mask(lengths):
+    """Make mask tensor containing indices of padded part.
+
+    Args:
+        lengths (LongTensor or List): Batch of lengths (B,).
+
+    Returns:
+        Tensor: Mask tensor containing indices of padded part.
+
+    Examples:
+        With only lengths.
+
+        >>> lengths = [5, 3, 2]
+        >>> make_pad_mask(lengths)
+        masks = [[0, 0, 0, 0, 0],
+                 [0, 0, 0, 1, 1],
+                 [0, 0, 1, 1, 1]]
+    """
+    if not isinstance(lengths, list):
+        lengths = lengths.tolist()
+
+    maxlen = int(max(lengths))
+    bs = int(len(lengths))
+    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
+    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
+    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+    return mask
+
+
+class MultiHeadedAttention(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+    """
+
+    def __init__(self, n_head, n_feat, dropout_rate):
+        """Construct a MultiHeadedAttention object."""
+        super(MultiHeadedAttention, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_q = nn.Linear(n_feat, n_feat)
+        self.linear_k = nn.Linear(n_feat, n_feat)
+        self.linear_v = nn.Linear(n_feat, n_feat)
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(self, query, key, value):
+        """Transform query, key and value.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+        """
+        n_batch = query.size(0)
+        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
+        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
+        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q, k, v
+
+    def forward_attention(self, value, scores, mask):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            min_value = float(
+                numpy.finfo(torch.tensor(
+                    0, dtype=scores.dtype).numpy().dtype).min)
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0)  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores,
+                                      dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
+                                                 self.h * self.d_k)
+             )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self,
+                query,
+                key,
+                value,
+                key_padding_mask=None,
+                attn_mask=None):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            attn_mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+            key_padding_mask: Unused; kept for interface compatibility with
+                nn.MultiheadAttention.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+ + """ + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, attn_mask), None + + +class LMEncoder(nn.Module): + + def __init__( + self, + attention_dim=512, + attention_heads=8, + attention_dropout_rate=0.0, + num_blocks=16, + dim_feedforward=2048, + normalize_before=True, + ): + super().__init__() + + self.normalize_before = normalize_before + + encoder_layer = TransformerEncoderLayer( + d_model=attention_dim, + custom_attn=MultiHeadedAttention(attention_heads, attention_dim, + attention_dropout_rate), + nhead=attention_heads, + dim_feedforward=dim_feedforward, + normalize_before=True, + dropout=attention_dropout_rate, + ) + + self.encoders = nn.TransformerEncoder(encoder_layer, num_blocks, None) + + if self.normalize_before: + self.after_norm = nn.LayerNorm(attention_dim) + + def forward(self, xs, masks): + # xs: [batch_size, max_seq_len] + # masks: [1, max_seq_len, max_seq_len], looks like + # tensor([[[ True, False, False, ..., False, False, False], + # [ True, True, False, ..., False, False, False], + # [ True, True, True, ..., False, False, False], + # ..., + # [ True, True, True, ..., True, False, False], + # [ True, True, True, ..., True, True, False], + # [ True, True, True, ..., True, True, True]]]) + + import numpy as np + np.save('xs_before_encoders', xs.cpu().numpy()) + xs = self.encoders(xs, masks) + np.save('xs_after_encoders', xs.cpu().numpy()) + if self.normalize_before: + xs = self.after_norm(xs) + return xs + + +class TransformerLM(nn.Module): + + def __init__( + self, + vocab_size: int, + pos_enc: str = None, + embed_unit: int = 128, + att_unit: int = 512, + head: int = 8, + unit: int = 2048, + layer: int = 16, + dropout_rate: float = 0.0, + ignore_id: int = 0, + ): + super().__init__() + + self.sos = vocab_size - 1 + self.eos = vocab_size - 1 + self.ignore_id = ignore_id + + self.embed = nn.Embedding(vocab_size, embed_unit) + self.input_embed = nn.Sequential( + nn.Linear(embed_unit, att_unit), + nn.LayerNorm(att_unit), + nn.Dropout(dropout_rate), + nn.ReLU(), + ) + self.encoder = LMEncoder(attention_dim=att_unit, + attention_heads=head, + num_blocks=layer, + dim_feedforward=unit) + self.decoder = nn.Linear(att_unit, vocab_size) + + def forward( + self, + input: torch.Tensor, + ) -> Tuple[torch.Tensor, None]: + import numpy as np + # input: [batch_size, max_seq_len] + x = self.embed(input) + np.save('embed_x', x.cpu().numpy()) + x = self.input_embed(x) + np.save('input_embed_x', x.cpu().numpy()) + mask = (generate_square_subsequent_mask( + input.shape[-1]) == 0).unsqueeze(0).to(x.device) + h = self.encoder(x, mask) + np.save('h', h.cpu().numpy()) + y = self.decoder(h) + # y: [batch_size, max_seq_len, vocab_size] + return y + + def nll(self, xs_pad, target_pad, token_lens): + # xs_pad/ys_pad: [batch_size, max_seq_len] + # max_seq_len == max(len([ token token token ... token]) + # == max(len([token token token ... 
+
+    def nll(self, xs_pad, target_pad, token_lens):
+        # xs_pad/ys_pad: [batch_size, max_seq_len]
+        # max_seq_len == max(len([<sos> token token token ... token]))
+        #             == max(len([token token token ... token <eos>]))
+        y = self.forward(xs_pad)
+        # nll: (batch_size * max_seq_len,)
+        nll = F.cross_entropy(y.view(-1, y.shape[-1]),
+                              target_pad.view(-1),
+                              reduction="none")
+        # zero out the padded positions
+        nll.masked_fill_(
+            make_pad_mask(token_lens).to(nll.device).view(-1), 0.0)
+
+        # nll: (batch_size * max_seq_len,) -> (batch_size, max_seq_len)
+        nll = nll.view(xs_pad.size(0), -1)
+        return nll
diff --git a/snowfall/models/transformer.py b/snowfall/models/transformer.py
index 22583257..f97fd428 100644
--- a/snowfall/models/transformer.py
+++ b/snowfall/models/transformer.py
@@ -187,9 +187,15 @@ class TransformerEncoderLayer(nn.Module):
     """
 
     def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
-                 activation: str = "relu", normalize_before: bool = True) -> None:
+                 activation: str = "relu", normalize_before: bool = True,
+                 custom_attn=None) -> None:
         super(TransformerEncoderLayer, self).__init__()
-        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
+
+        if custom_attn is not None:
+            self.self_attn = custom_attn
+        else:
+            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
+
         # Implementation of Feedforward model
         self.linear1 = nn.Linear(d_model, dim_feedforward)
         self.dropout = nn.Dropout(dropout)
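Finally, a usage sketch of the new custom_attn hook (constructor arguments
follow the definitions above; the dimensions are illustrative):

    from snowfall.models.lm_transformer import MultiHeadedAttention
    from snowfall.models.transformer import TransformerEncoderLayer

    attn = MultiHeadedAttention(n_head=8, n_feat=512, dropout_rate=0.0)
    layer = TransformerEncoderLayer(d_model=512, nhead=8,
                                    dim_feedforward=2048,
                                    custom_attn=attn)
    # With custom_attn=None the layer falls back to nn.MultiheadAttention.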