
Weight decay API maybe unintuitive #241

Open
albertz opened this issue Nov 11, 2022 · 4 comments

albertz commented Nov 11, 2022

I'm not sure if we really recommend it like that anywhere, but I think it's natural to write code like this:

for p in net.parameters():
  p.weight_decay = 0.0001

I noticed that this has several problems:

  • What about auxiliary parameters? You probably don't want weight decay on them. Same goes for any integer or boolean parameters.
    • I think actually it would be ignored by RETURNN, so maybe it's not a problem? Or we could also just ignore it silently on returnn-common side to allow for such code?
  • Some variables maybe should not be decayed:
    • In LayerNorm, WeightNorm, etc., the scale parameter is initialized at 1. Any decay should move it towards 1 and not towards 0. (Right?) In Lingvo, you actually find (here) that weight norm is reparameterized as (1 + g) instead of just g to avoid this problem.
      • We could rewrite any such code to also use such reparameterization. Which is maybe a good thing but maybe not?
      • We could add some additional information, like decay_center or so, and the constraint would not be $w^2$ but $(w-c)^2$ instead, such that any configured weight decay would move the parameter towards the configured center (see the sketch after this list). This would need some extra implementation also on the RETURNN side.
      • We could also add some flag Parameter.ignore_weight_decay on returnn-common side, and if that is enabled (via the module such as LayerNorm), it ignores any writes to weight_decay.
    • I'm not sure if a decay on biases is good or not.
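
For illustration, here is a minimal sketch of what such a decay-towards-center penalty could compute, written as plain PyTorch-style code; the per-parameter weight_decay and decay_center attributes are the hypothetical API discussed above, not something that exists yet:

    import torch

    def weight_decay_loss(params, default_decay=0.0):
        """Sum of decay * (w - c)^2 over all parameters, where c is an optional
        per-parameter decay center (0.0, i.e. plain L2 decay, if not set)."""
        total = torch.zeros(())
        for p in params:
            decay = getattr(p, "weight_decay", default_decay)  # hypothetical attribute
            if not decay:
                continue
            center = getattr(p, "decay_center", 0.0)  # hypothetical attribute
            total = total + decay * ((p - center) ** 2).sum()
        return total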

Many of the arguments above are about still allowing for the simple code at the top. Or maybe we don't want to allow such simple code? But how exactly would the canonical example of weight decay applied to some generic network look then?

albertz added this to the first-release milestone Nov 11, 2022
albertz mentioned this issue Nov 11, 2022

albertz commented Nov 11, 2022

Maybe another canonical example would be to use L2 only on selected modules and only on selected parameters:

for m in model.modules():
  if isinstance(m, nn.Linear):
    m.weight.weight_decay = 0.01
  ...

Also see some related PyTorch discussion:
https://discuss.pytorch.org/t/weight-decay-only-for-weights-of-nn-linear-and-nn-conv/114348

Also see this insightful code from Karpathy: https://github.com/karpathy/minGPT/blob/3ed14b2cec0dfdad3f4b2831f2b4a86d11aef150/mingpt/model.py#L136

    def configure_optimizers(self, train_config):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        ...
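
The elided rest of that function turns the two buckets into PyTorch optimizer parameter groups. Roughly, as a sketch of the standard pattern (not a verbatim quote from minGPT; the train_config field names are assumptions):

        # sketch: map the collected names back to parameter objects and build two groups
        param_dict = {pn: p for pn, p in self.named_parameters()}
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(decay)],
             "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(no_decay)],
             "weight_decay": 0.0},
        ]
        return torch.optim.AdamW(optim_groups, lr=train_config.learning_rate)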

Further, see:

https://stats.stackexchange.com/questions/576463/why-not-perform-weight-decay-on-layernorm-embedding
https://discuss.pytorch.org/t/weight-decay-in-the-optimizers-is-a-bad-idea-especially-with-batchnorm/16994
https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L65
karpathy/minGPT#24 (comment)

albertz added a commit to rwth-i6/i6_experiments that referenced this issue Nov 12, 2022
michelwi commented

We could add some additional information, like decay_center or so, and the constraint would not be $w^2$ but $(w-c)^2$ instead, such that any configured weight decay would move it towards the configured center. This would need some extra implementation also on RETURNN side.

+1 for this approach. Sometimes (e.g. fine-tuning) we would like to specify an initial model that should be tuned but not moved away from too far. So we could define the baseline as decay_center and decay back towards it.
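
A minimal sketch of that fine-tuning use case, in PyTorch-style pseudocode; the per-parameter weight_decay and decay_center attributes are the hypothetical API from above, and pretrained_net / net are placeholder names:

    # Sketch: decay the fine-tuned parameters back towards the pretrained baseline.
    baseline = {name: p.detach().clone() for name, p in pretrained_net.named_parameters()}

    for name, p in net.named_parameters():
        p.weight_decay = 0.0001           # hypothetical attribute, as discussed above
        p.decay_center = baseline[name]   # hypothetical attribute: decay towards the baseline value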

albertz commented Nov 14, 2022

Via rwth-i6/returnn#1214 and #90, we can also do this in a very generic way, which allows the user to do some logic like decay_center, but also potentially any other logic.

albertz commented Nov 14, 2022

What about auxiliary parameters? You probably don't want weight decay on them. Same goes for any integer or boolean parameters.

Note that in PyTorch, net.parameters() really only returns the trainable model parameters. Auxiliary parameters are not called "parameters" but "buffers", and are handled via register_buffer in PyTorch.

So, should we maybe change our net.parameters() to exclude auxiliary parameters? Or make an explicit flag for it, like include_auxiliary, which is maybe False by default?
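
For comparison, a small self-contained PyTorch example showing that buffers registered via register_buffer do not show up in parameters():

    import torch

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.zeros(10))      # trainable parameter
            self.register_buffer("running_mean", torch.zeros(10))  # auxiliary state, not trainable

    m = MyModule()
    print([name for name, _ in m.named_parameters()])  # ['weight']
    print([name for name, _ in m.named_buffers()])     # ['running_mean']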

Atticus1806 pushed a commit to rwth-i6/i6_experiments that referenced this issue Nov 17, 2022