From a59abcf1d6db405925f749f06bd42a152f4e7145 Mon Sep 17 00:00:00 2001
From: Tianshuo Deng
Date: Thu, 21 Nov 2024 16:24:13 -0800
Subject: [PATCH 1/2] merge dict comprehensions when filtering parameters without grad

---
 model.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/model.py b/model.py
index c698f8b601..d4de95d679 100644
--- a/model.py
+++ b/model.py
@@ -261,10 +261,8 @@ def from_pretrained(cls, model_type, override_args=None):
         return model
 
     def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
-        # start with all of the candidate parameters
-        param_dict = {pn: p for pn, p in self.named_parameters()}
-        # filter out those that do not require grad
-        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
+        # All parameters that require grad.
+        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
         # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
         # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
         decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]

From eabe78b6491adc06548ddf95a295e6996b628991 Mon Sep 17 00:00:00 2001
From: Tianshuo Deng
Date: Thu, 21 Nov 2024 16:26:10 -0800
Subject: [PATCH 2/2] update comments

---
 model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/model.py b/model.py
index d4de95d679..d7bb22a8ba 100644
--- a/model.py
+++ b/model.py
@@ -261,7 +261,7 @@ def from_pretrained(cls, model_type, override_args=None):
         return model
 
     def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
-        # All parameters that require grad.
+        # Filter out parameters that do not require grad.
         param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
         # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
         # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.