
Create good Conformer baselines #233

Open
albertz opened this issue Nov 2, 2022 · 13 comments

@albertz (Member) commented Nov 2, 2022

On some internal data, and maybe also Switchboard and Librispeech.

Using nn.Conformer. Make it somewhat more standard where possible, and only deviate from that when it makes sense.

Also compare it to earlier Conformer recipes, and earlier BLSTM recipes. Make sure the conditions are sane for comparison, e.g. same number of epochs.

Once we have that, we should also change the nn.Conformer defaults to something reasonable.

I think our earlier Conformer recipes (there are several variants floating around in our group...) are somewhat non-standard:

  • Check the frontend. Sometimes we use a BLSTM, sometimes convolutions. Our conv-based frontends also differ from what is standard, although the standard frontend probably uses dimensions that are too high, so that is maybe one thing to deviate from. Also see Conformer frontend should fix dimensions, be more standard #219, and the sketch after this list.
  • Our earlier Conformer recipes used the old-style relative pos encoding, while the standard Conformer uses the rel pos enc from Transformer-XL. This is already implemented and is the default in nn.Conformer, but we never really compared it systematically, and nn.Conformer is not really well tested yet either. See our wiki on relative positional encoding for further references.
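
For illustration, here is a rough PyTorch sketch of what the "standard" (ESPnet-style) conv frontend looks like: two 3x3 convolutions with stride 2 (4x time subsampling), with the channel count equal to the model dimension, which is the "too high dimensions" point above. The details (kernel sizes, dims) are assumptions for illustration, not our recipe code.

    import torch
    from torch import nn

    class StandardConvFrontend(nn.Module):
        """Sketch of an ESPnet-style Conv2d subsampling frontend (assumed details)."""

        def __init__(self, in_feat: int = 80, out_dim: int = 256):
            super().__init__()
            self.conv = nn.Sequential(
                nn.Conv2d(1, out_dim, kernel_size=3, stride=2), nn.ReLU(),
                nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=2), nn.ReLU(),
            )
            # Flatten channels x reduced feature axis into the model dim.
            reduced_feat = ((in_feat - 1) // 2 - 1) // 2
            self.out = nn.Linear(out_dim * reduced_feat, out_dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (batch, time, features), e.g. log-mel filterbanks
            x = self.conv(x.unsqueeze(1))  # -> (batch, out_dim, time', feat')
            b, c, t, f = x.size()
            return self.out(x.transpose(1, 2).reshape(b, t, c * f))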

References, and related:

@braddockcg commented:

I'll get right on it. I may be diverted to other tasks, but will try to find the time.

@albertz (Member, Author) commented Nov 2, 2022

Note that I'm also working on this myself in parallel. But there are so many different things to test here that this should not be a problem. Specifically, my current setting is a BPE-based monotonic transducer (RNA-like) on Switchboard, and I compare some old Conformer config vs. nn.Conformer (with different settings) vs. some BLSTM. I assume you would start with a hybrid NN-HMM? Or maybe a CTC-based model?

@albertz (Member, Author) commented Nov 2, 2022

I noticed that nn.SelfAttention is a bit different from SelfAttentionLayer: SelfAttentionLayer does not have biases for the qkv and proj linear projections, while nn.SelfAttention currently does. I opened a separate issue for that: #234.
I also added the option with_bias now, so you can play around with it.
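
For illustration, a minimal sketch of constructing the module with the old SelfAttentionLayer behavior (no biases); the exact constructor arguments used here (in_dim, proj_dim, key_dim_total, value_dim_total, num_heads) are assumptions, check the actual nn.SelfAttention signature:

    from returnn_common import nn

    in_dim = nn.FeatureDim("model", 256)  # hypothetical dim, for illustration only

    self_att = nn.SelfAttention(
        in_dim=in_dim,
        proj_dim=in_dim,
        key_dim_total=in_dim,
        value_dim_total=in_dim,
        num_heads=8,
        with_bias=False,  # newly added option; matches the old SelfAttentionLayer behavior
    )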

@albertz (Member, Author) commented Nov 2, 2022

We should also check the param init, and also look at other frameworks' code, e.g. here in Fairseq:
https://github.com/facebookresearch/fairseq/blob/b4001184f49ed0e20d619b54bb3d43088fabf990/fairseq/modules/multihead_attention.py#L168-L177
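
Paraphrased from memory (see the link above for the actual code), the init there is roughly a scaled Xavier init on the q/k/v projections when they share the same dim:

    import math
    import torch.nn as nn

    # Rough paraphrase of fairseq's MultiheadAttention.reset_parameters (not verbatim):
    def reset_parameters(self):
        if self.qkv_same_dim:
            gain = 1 / math.sqrt(2)  # scaled init, reportedly converges better
            nn.init.xavier_uniform_(self.q_proj.weight, gain=gain)
            nn.init.xavier_uniform_(self.k_proj.weight, gain=gain)
            nn.init.xavier_uniform_(self.v_proj.weight, gain=gain)
        else:
            nn.init.xavier_uniform_(self.q_proj.weight)
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)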

albertz self-assigned this Nov 2, 2022
@albertz (Member, Author) commented Nov 2, 2022

On rel pos enc, I'm collecting some overview here: https://github.com/rwth-i6/returnn_common/wiki/Relative-positional-encoding
Also see #235.
Also see the docstring of RelPosSelfAttention.
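
For reference, the Transformer-XL-style attention score decomposes as follows (this is where the pos_bias_u / pos_bias_v parameters that come up later in this thread enter):

    A_{i,j} = q_i^\top k_j                      % content-content
            + q_i^\top W_{k,R} R_{i-j}          % content-position
            + u^\top k_j                        % global content bias (pos_bias_u)
            + v^\top W_{k,R} R_{i-j}            % global position bias (pos_bias_v)

with R_{i-j} the sinusoidal relative positional encoding.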

@braddockcg commented:

@albertz a few questions:

  • You mention "our earlier Conformer recipes (there are several variants floating around in our group...)". Where can I find them? Which are the best starting points?
  • When you specify "more standard", what implementation is "standard"? The "official" Conformer implementation is at https://github.com/pengzhiliang/Conformer; should that be our baseline?
  • Did you want me to also run the same dataset on the "official" Conformer implementation, for comparison with the RETURNN one?
  • You mention running "on some internal data" - please point me to it.
  • What GPU resources are available? Should I be running on the AppTek cluster?

@albertz (Member, Author) commented Nov 7, 2022

You mention "our earlier Conformer recipes (there are several variants floating around in our group...)". Where can I find them? Which are the best starting points?

You will find a couple of recipes on returnn-experiments, for example:

https://github.com/rwth-i6/returnn-experiments/blob/master/2022-swb-conformer-hybrid-sat/table_1_and_2/ln_instead_of_bn.config#L125

I also have an adapted variant where I embed this old net dict in returnn-common here:

https://github.com/rwth-i6/i6_experiments/blob/main/users/zeyer/experiments/exp2022_07_21_transducer/exp_fs_base/old_nick_att_conformer_lrs2.py

I was also able to fully replicate this config in pure returnn-common now. That is, I wrote a checkpoint converter script and verified that I get exactly the same outputs after every layer:

https://github.com/rwth-i6/i6_experiments/blob/main/users/zeyer/experiments/exp2022_07_21_transducer/exp_fs_base/conformer_import_old_nick_att_conformer_lrs2.py

@albertz (Member, Author) commented Nov 7, 2022

> When you specify "more standard", what implementation is "standard"? The "official" Conformer implementation is at https://github.com/pengzhiliang/Conformer; should that be our baseline?

By "more standard", I mean following what the original paper says, and what the most popular implementations do.

The repo you linked is totally off-topic and unrelated here.

This is the Conformer paper: https://arxiv.org/abs/2005.08100

There is no official public implementation. It was implemented by Google and the code is private. Although maybe you can find some derivative of it now in Lingvo (https://github.com/tensorflow/lingvo)? It might be a good idea to check that.

I mostly refer to the ESPnet implementation, which is, as far as I know, the most popular and most widely used public implementation. Starting points:

https://github.com/espnet/espnet/tree/master/egs2/librispeech/asr1
https://github.com/espnet/espnet/blob/b008ac7d58e9ced1a9f8c89cc85ee69d9e9461ab/espnet/nets/pytorch_backend/conformer/convolution.py
https://github.com/espnet/espnet/blob/b008ac7d58e9ced1a9f8c89cc85ee69d9e9461ab/espnet/nets/pytorch_backend/conformer/encoder_layer.py
https://github.com/espnet/espnet/blob/a65cc78de7e18c867f4be5fc0b9b695875c78c70/espnet/nets/pytorch_backend/transformer/attention.py
https://github.com/espnet/espnet/blob/b008ac7d58e9ced1a9f8c89cc85ee69d9e9461ab/espnet/nets/pytorch_backend/transformer/embedding.py
https://github.com/espnet/espnet/blob/b008ac7d58e9ced1a9f8c89cc85ee69d9e9461ab/espnet2/asr/encoder/conformer_encoder.py
https://github.com/espnet/espnet/blob/master/egs2/librispeech/asr1/conf/tuning/train_asr_conformer10_hop_length160.yaml

The implementation in Fairseq is also based on the ESPnet implementation.

@albertz (Member, Author) commented Nov 7, 2022

> Did you want me to also run the same dataset on the "official" Conformer implementation, for comparison with the RETURNN one?

I haven't thought too much about this yet. Maybe it makes sense if you really want to see the exact differences.

The first step to really know that we implemented exactly the same model (or rather, that we can configure our model such that it exactly matches some ESPnet variant) is to import the model parameters and verify that we get exactly the same outputs, layer by layer. This is some work, but mostly straightforward and very systematic. It usually already leads to lots of interesting insights into differences that we were not really aware of before.
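
For illustration, a minimal sketch of that kind of check; the function and variable names here are hypothetical, assuming you can dump the per-layer activations of both models for the same input batch:

    import numpy as np

    def compare_layer_outputs(outputs_ref: dict, outputs_ours: dict, atol: float = 1e-5):
        """Compare dicts mapping layer name -> numpy array of activations,
        computed on the same input after importing the reference checkpoint."""
        for name, ref in outputs_ref.items():
            ours = outputs_ours[name]
            assert ref.shape == ours.shape, f"{name}: shape {ref.shape} vs {ours.shape}"
            max_diff = np.abs(ref - ours).max()
            status = "OK" if max_diff <= atol else "MISMATCH"
            print(f"{name}: max abs diff {max_diff:.2e} [{status}]")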

The next step is to check that we get similar training behavior, in terms of learning curves and all the numbers. This is trickier because it cannot really be checked systematically anymore. For this, you need to match the param init scheme, the optimizer settings, other regularization settings, etc. We never really did this, but it would actually be a very interesting experiment, because some people have observed big differences in training behavior.

@albertz (Member, Author) commented Nov 7, 2022

> You mention running "on some internal data" - please point me to it.

Better ask Eugen (@curufinwe) what datasets he recommends.

> What GPU resources are available? Should I be running on the AppTek cluster?

Also better ask Eugen about this.

@albertz (Member, Author) commented Nov 7, 2022

Btw, you can see all my recent Conformer experiments here: https://github.com/rwth-i6/i6_experiments/tree/main/users/zeyer/experiments/exp2022_07_21_transducer/exp_fs_base

Basically I create a new file there for every experiment I run.

Also see the readme there with some notes.

@albertz (Member, Author) commented Nov 13, 2022

I noticed that we do not have dropout after the self-attention:

    # MHSA
    x_mhsa_ln = self.self_att_layer_norm(x_ffn1_out)
    x_mhsa = self.self_att(x_mhsa_ln, axis=spatial_dim)
    x_mhsa_out = x_mhsa + x_ffn1_out

This is different from the standard Transformer, and also different from the paper.
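
For comparison, the standard formulation applies dropout to the self-attention output before the residual add, i.e. roughly like the following sketch (the nn.dropout call and its arguments here are an assumption, not the actual nn.ConformerEncoderLayer code):

    # MHSA, with dropout before the residual connection (as in the paper)
    x_mhsa_ln = self.self_att_layer_norm(x_ffn1_out)
    x_mhsa = self.self_att(x_mhsa_ln, axis=spatial_dim)
    x_mhsa = nn.dropout(x_mhsa, self.dropout, axis=self.out_dim)  # missing in our current code
    x_mhsa_out = x_mhsa + x_ffn1_out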

@albertz (Member, Author) commented Nov 16, 2022

Regarding param init: in ESPnet, there is almost no code for this at all, meaning it uses the PyTorch defaults. Mostly this goes via Linear, but also via Conv1d for the convolution module. Only in the rel pos self-attention, it uses:

        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

for those biases (pos_bias_u, pos_bias_v), which are created directly there.
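
For reference, the PyTorch default for Linear (paraphrased from torch.nn.Linear.reset_parameters; details may differ between PyTorch versions) is roughly:

    import math
    from torch import nn

    def pytorch_default_linear_init(linear: nn.Linear) -> None:
        """Re-apply what nn.Linear does by default (paraphrased sketch)."""
        # Equivalent to uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) for the weight:
        nn.init.kaiming_uniform_(linear.weight, a=math.sqrt(5))
        if linear.bias is not None:
            bound = 1 / math.sqrt(linear.in_features) if linear.in_features > 0 else 0
            nn.init.uniform_(linear.bias, -bound, bound)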
