Weight decay API maybe unintuitive #241
Maybe another canonical example would be to use L2 only on selected modules and only selected parameters:

```python
for m in model.modules():
    if isinstance(m, nn.Linear):
        m.weight.weight_decay = 0.01
...
```

Also see some related PyTorch discussion: …

Also see this insightful code from Karpathy: https://github.com/karpathy/minGPT/blob/3ed14b2cec0dfdad3f4b2831f2b4a86d11aef150/mingpt/model.py#L136

```python
def configure_optimizers(self, train_config):
    """
    This long function is unfortunately doing something very simple and is being very defensive:
    We are separating out all parameters of the model into two buckets: those that will experience
    weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
    We are then returning the PyTorch optimizer object.
    """

    # separate out all parameters to those that will and won't experience regularizing weight decay
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear, )
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in self.named_modules():
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

            if pn.endswith('bias'):
                # all biases will not be decayed
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                # weights of whitelist modules will be weight decayed
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                # weights of blacklist modules will NOT be weight decayed
                no_decay.add(fpn)

    # special case the position embedding parameter in the root GPT module as not decayed
    no_decay.add('pos_emb')

    # validate that we considered every parameter
    ...
```

Further, see: https://stats.stackexchange.com/questions/576463/why-not-perform-weight-decay-on-layernorm-embedding
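The elided rest of that function returns the PyTorch optimizer built from the two buckets. As a rough sketch of that standard pattern (not necessarily the original code; it assumes `train_config` carries `weight_decay` and `learning_rate`), it continues like:

```python
# Sketch: turn the decay / no_decay name sets into per-parameter-group weight decay.
param_dict = {pn: p for pn, p in self.named_parameters()}
optim_groups = [
    {"params": [param_dict[pn] for pn in sorted(decay)],
     "weight_decay": train_config.weight_decay},
    {"params": [param_dict[pn] for pn in sorted(no_decay)],
     "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate)
return optimizer
```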
+1 for this approach. Sometimes (e.g. fine-tuning) we would like to specify an initial model that should be tuned but not moved away from too far. So we could define the baseline as …
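As a plain-PyTorch illustration of that idea (just a sketch, not an existing API here; `model` is the model being fine-tuned): keep a frozen copy of the initial model and penalize the distance to it, instead of decaying towards zero:

```python
import copy
import torch

# Hypothetical sketch: penalize the distance to a frozen baseline model
# instead of decaying towards zero.
baseline = copy.deepcopy(model)
for p in baseline.parameters():
    p.requires_grad_(False)

def baseline_l2_penalty(model, baseline, coeff=1e-4):
    penalty = 0.0
    for p, p0 in zip(model.parameters(), baseline.parameters()):
        penalty = penalty + coeff * (p - p0).pow(2).sum()
    return penalty

# In the training step:
# loss = task_loss + baseline_l2_penalty(model, baseline)
```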
Via rwth-i6/returnn#1214 and #90, we can also do this in a very generic way, which allows the user to do some logic like …
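Just to illustrate the kind of per-parameter user logic meant here (a plain-PyTorch sketch, not the actual RETURNN mechanism from the linked issues):

```python
import torch

# Illustrative sketch: a user-supplied predicate decides per parameter
# whether weight decay applies, and we build optimizer groups from it.
def wants_decay(name: str, param: torch.nn.Parameter) -> bool:
    # e.g. decay only weight matrices, never biases, norm scales or embeddings
    return param.ndim >= 2 and "norm" not in name and "embed" not in name

decay_params = [p for n, p in model.named_parameters() if wants_decay(n, p)]
other_params = [p for n, p in model.named_parameters() if not wants_decay(n, p)]
optimizer = torch.optim.AdamW(
    [{"params": decay_params, "weight_decay": 0.01},
     {"params": other_params, "weight_decay": 0.0}],
    lr=1e-3)
```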
Note that in PyTorch, … So, should we maybe change our …?
I'm not sure if we really recommend it like that anywhere, but I think it's natural to write code like this:
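(A sketch of the kind of code meant, assuming the per-parameter `weight_decay` attribute from the first example above:)

```python
# Sketch (assumed, not the original snippet): blanket weight decay on every parameter.
for param in model.parameters():
    param.weight_decay = 0.01
```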
I noticed, this has several problems:

- For `LayerNorm`, `WeightNorm` etc., the `scale` parameter is initialized at 1. Any decay should move it towards 1 and not towards 0. (Right?) In Lingvo, you actually find (here) that weight norm is reparameterized as `(1 + g)` instead of just `g`, to avoid this problem.
- A `decay_center` or so, and the constraint would not be … (see the sketch below).
- `Parameter.ignore_weight_decay` on the returnn-common side, and if that is enabled (via the module, such as `LayerNorm`), it ignores any writes to `weight_decay`.

Many of the arguments are to actually allow for the simple code above. Or maybe we don't want to allow such simple code? But how exactly would the canonical example of weight decay applied on some generic network look like then?
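To make the `decay_center` idea concrete, here is a hypothetical sketch (the attribute names are only illustrative, not an existing API): the L2 penalty pulls each parameter towards its own center, e.g. 1 for a LayerNorm scale, instead of towards 0:

```python
import torch

def weight_decay_loss(model, default_coeff=0.0):
    """Hypothetical sketch: L2-decay each parameter towards its decay_center (default 0)."""
    loss = 0.0
    for param in model.parameters():
        coeff = getattr(param, "weight_decay", default_coeff)
        center = getattr(param, "decay_center", 0.0)  # e.g. 1.0 for a LayerNorm scale
        if coeff:
            loss = loss + coeff * (param - center).pow(2).sum()
    return loss

# A module like LayerNorm could then mark its scale parameter:
#   self.scale.weight_decay = 0.01
#   self.scale.decay_center = 1.0
```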