Option to modify learning rate by param groups

kshitijkg committed Oct 29, 2023
1 parent afbba83 commit b215dae
Showing 6 changed files with 233 additions and 47 deletions.
152 changes: 152 additions & 0 deletions megatron/learning_rates.py
@@ -144,3 +144,155 @@ def load_state_dict(self, sd):

        self.num_iters = sd["num_iters"]
        self.step(self.num_iters)


class GroupedAnnealingLR(object):
    """Anneals the learning rate, optionally with a separate schedule per parameter group."""

    def __init__(
        self,
        optimizer,
        start_lr,
        warmup_iter,
        total_iters,
        decay_style,
        last_iter,
        min_lr=0.0,
        use_checkpoint_lr_scheduler=True,
        override_lr_scheduler=False,
        use_mup=False,
        lr_param_groups_config=None,
    ):

        # Class values.
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.min_lr = min_lr
        self.warmup_iter = warmup_iter
        self.num_iters = last_iter
        self.end_iter = total_iters
        assert self.end_iter > 0
        self.decay_style = decay_style
        self.override_lr_scheduler = override_lr_scheduler
        self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
        self.use_mup = use_mup
        if self.override_lr_scheduler:
            assert not self.use_checkpoint_lr_scheduler, (
                "both override and use-checkpoint are set."
            )
        # Set the learning rate
        self.lr_param_groups_config = lr_param_groups_config
        self.step(self.num_iters)

        print_rank_0("> learning rate decay style: {}".format(self.decay_style))

    def get_lr(self, decay_style, start_lr, min_lr, warmup_iter, end_iter):
        """Learning rate decay functions from:
        https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

        num_iters_ = self.num_iters
        # Warmup.
        if warmup_iter > 0 and self.num_iters <= warmup_iter:
            return float(start_lr) * num_iters_ / warmup_iter

        num_iters_ = num_iters_ - warmup_iter
        if decay_style == "linear":
            lr = start_lr * (end_iter - num_iters_) / end_iter
        elif decay_style == "cosine":
            end_iter_ = end_iter - warmup_iter
            lr = min_lr + (
                (start_lr - min_lr)
                / 2.0
                * (math.cos(math.pi * num_iters_ / end_iter_) + 1)
            )
        elif decay_style == "exponential":
            # exp(-0.693) = 1/2
            lr = start_lr * math.exp(-0.693 * num_iters_ / end_iter)
        else:
            lr = start_lr
        return max(lr, min_lr)

    def step(self, step_num=None):
        """Set lr for all parameter groups."""
        if step_num is None:
            step_num = self.num_iters + 1
        self.num_iters = step_num
        for group in self.optimizer.param_groups:
            group_name = group.get("name")
            if (
                self.lr_param_groups_config is not None
                and group_name in self.lr_param_groups_config
            ):
                # This group has its own schedule.
                group_config = self.lr_param_groups_config[group_name]
                new_lr = self.get_lr(
                    group_config["decay_style"],
                    group_config["start_lr"],
                    group_config["min_lr"],
                    group_config["warmup_iter"],
                    group_config["end_iter"],
                )
            else:
                # Fall back to the scheduler-wide schedule.
                new_lr = self.get_lr(
                    self.decay_style,
                    self.start_lr,
                    self.min_lr,
                    self.warmup_iter,
                    self.end_iter,
                )
            if self.use_mup and "width_mult" in group:
                group["lr"] = new_lr / group["width_mult"]
            else:
                group["lr"] = new_lr

    def state_dict(self):
        state_dict = {
            "start_lr": self.start_lr,
            "warmup_iter": self.warmup_iter,
            "num_iters": self.num_iters,
            "decay_style": self.decay_style,
            "end_iter": self.end_iter,
            "min_lr": self.min_lr,
            "lr_param_groups_config": self.lr_param_groups_config,
        }
        return state_dict

    def _check_and_set(self, cls_value, sd_value, name):
        """Auxiliary function for checking the values in the checkpoint and
        setting them."""
        if self.override_lr_scheduler:
            print_rank_0(" > overriding {} value to {}".format(name, cls_value))
            return cls_value

        if not self.use_checkpoint_lr_scheduler:
            assert cls_value == sd_value, (
                "AnnealingLR: class input value "
                "and checkpoint values for {} do not match".format(name)
            )
        print_rank_0(" > using checkpoint value {} for {}".format(sd_value, name))
        return sd_value

    def load_state_dict(self, sd):

        self.start_lr = self._check_and_set(
            self.start_lr, sd["start_lr"], "learning rate"
        )
        self.min_lr = self._check_and_set(
            self.min_lr, sd["min_lr"], "minimum learning rate"
        )
        self.warmup_iter = self._check_and_set(
            self.warmup_iter, sd["warmup_iter"], "warmup iterations"
        )
        self.end_iter = self._check_and_set(
            self.end_iter, sd["end_iter"], "total number of iterations"
        )
        self.decay_style = self._check_and_set(
            self.decay_style, sd["decay_style"], "decay style"
        )
        if sd["lr_param_groups_config"] is not None:
            if self.lr_param_groups_config is None:
                self.lr_param_groups_config = {}
            for sd_group in sd["lr_param_groups_config"].keys():
                if sd_group not in self.lr_param_groups_config.keys():
                    # group missing locally: take its settings from the checkpoint
                    self.lr_param_groups_config[sd_group] = dict(
                        sd["lr_param_groups_config"][sd_group]
                    )
                self.lr_param_groups_config[sd_group]["start_lr"] = self._check_and_set(
                    self.lr_param_groups_config[sd_group]["start_lr"],
                    sd["lr_param_groups_config"][sd_group]["start_lr"],
                    "learning rate",
                )

        self.num_iters = sd["num_iters"]
        self.step(self.num_iters)
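A minimal usage sketch (not part of this commit) of the new scheduler. The group names, learning rates, and iteration counts below are illustrative assumptions; the per-group dict uses exactly the keys that step() looks up (decay_style, start_lr, min_lr, warmup_iter, end_iter), and the optimizer param groups carry the "name" key that step() matches against. It assumes it runs where megatron is importable.

# Illustrative sketch only -- group names and values are made up.
import torch

from megatron.learning_rates import GroupedAnnealingLR

embedding = torch.nn.Embedding(16, 8)
linear = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(
    [
        {"name": "weight_decay", "params": list(linear.parameters()), "weight_decay": 0.01},
        {"name": "word_embeddings", "params": list(embedding.parameters()), "weight_decay": 0.0},
    ],
    lr=6.0e-4,
)

lr_scheduler = GroupedAnnealingLR(
    optimizer,
    start_lr=6.0e-4,
    warmup_iter=100,
    total_iters=1000,
    decay_style="cosine",
    last_iter=0,
    min_lr=6.0e-5,
    lr_param_groups_config={
        # groups not listed here fall back to the scheduler-wide schedule
        "word_embeddings": {
            "decay_style": "linear",
            "start_lr": 1.0e-4,
            "min_lr": 1.0e-5,
            "warmup_iter": 50,
            "end_iter": 1000,
        },
    },
)

for _ in range(10):
    lr_scheduler.step()  # each configured group gets its own lr; the rest share the default

Groups whose name is not listed in lr_param_groups_config follow the scheduler-wide schedule, matching the behaviour of the existing AnnealingLR.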
21 changes: 13 additions & 8 deletions megatron/logging.py
@@ -173,13 +173,15 @@ def add_to_logging(name):
)

        # write losses, lr, etc. every step
        tb_wandb_log(
            "train/learning_rate",
            learning_rate,
            iteration,
            use_wandb=neox_args.use_wandb,
            tensorboard_writer=neox_args.tensorboard_writer,
        )
        for param_group_key in learning_rate.keys():
            tb_wandb_log(
                f"train/learning_rate_{param_group_key}",
                learning_rate[param_group_key],
                iteration,
                use_wandb=neox_args.use_wandb,
                tensorboard_writer=neox_args.tensorboard_writer,
            )

        for key in loss_dict:
            tb_wandb_log(
                f'train/{key.replace(" ", "_")}',
@@ -298,7 +300,10 @@ def add_to_logging(name):
log_string += " elapsed time per iteration (ms): {:.1f} |".format(
elapsed_time * 1000.0 / neox_args.log_interval
)
log_string += " learning rate: {:.3E} |".format(learning_rate)
for param_group_key in learning_rate.keys():
log_string += " param group {} lr: {:.3E} |".format(
param_group_key, learning_rate[param_group_key]
)
num_iterations = max(
1, neox_args.log_interval - total_loss_dict[skipped_iters_key]
)
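For orientation (not part of the commit): training.py below now passes learning_rate as a dict keyed by param-group name, so the loop above writes one scalar per group instead of a single train/learning_rate value. A tiny standalone sketch with made-up group names shows the metric names that result:

# Illustrative sketch only -- group names and values are made up.
learning_rate = {"weight_decay": 6.0e-4, "no_weight_decay": 6.0e-4, "word_embeddings": 1.0e-4}
for param_group_key in learning_rate.keys():
    # mirrors the tb_wandb_log call above; here we just print the metric name and value
    print(f"train/learning_rate_{param_group_key}", learning_rate[param_group_key])
# train/learning_rate_weight_decay 0.0006
# train/learning_rate_no_weight_decay 0.0006
# train/learning_rate_word_embeddings 0.0001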
2 changes: 1 addition & 1 deletion megatron/model/__init__.py
@@ -16,5 +16,5 @@
# limitations under the License.

from .gpt2_model import GPT2ModelPipe
from .utils import get_params_for_weight_decay_optimization
from .utils import get_param_groups
from .embeddings import SoftEmbedding
89 changes: 56 additions & 33 deletions megatron/model/utils.py
@@ -24,47 +24,70 @@
import torch.distributed as dist


def update_params_for_weight_decay(weight_decay_params: dict, no_weight_decay_params: dict, module_, weight_decay: float):
    """Divide params into with-weight-decay and without-weight-decay groups.
    Layernorms and biases will have no weight decay but the rest will.
    """
    if any(
        [
            isinstance(module_, LayerNorm),
            isinstance(module_, RMSNorm),
            isinstance(module_, ScaleNorm),
        ]
    ) or (
        weight_decay == 0.0
    ):  # also include all parameters here if no weight decay is being done
        no_weight_decay_params["params"].extend(
            [p for p in list(module_._parameters.values()) if p is not None]
        )
    else:
        weight_decay_params["params"].extend(
            [
                p
                for n, p in list(module_._parameters.items())
                if p is not None and n != "bias"
            ]
        )
        no_weight_decay_params["params"].extend(
            [
                p
                for n, p in list(module_._parameters.items())
                if p is not None and n == "bias"
            ]
        )

def get_param_groups(module, neox_args):
    param_groups = {}

    # Defaults
    param_groups["weight_decay"] = {"name": "weight_decay", "params": [], "weight_decay": neox_args.weight_decay}
    param_groups["no_weight_decay"] = {"name": "no_weight_decay", "params": [], "weight_decay": 0.0}

    neox_special_params = neox_args.lr_param_groups_config.keys() if neox_args.lr_param_groups_config else []
    for name, module_ in module.named_modules():
        for special_lr_key in neox_special_params:
            if special_lr_key in name:
                if f"{special_lr_key}_weight_decay" not in param_groups:
                    param_groups[f"{special_lr_key}_weight_decay"] = {"params": [], "name": special_lr_key, "weight_decay": neox_args.weight_decay}
                    param_groups[f"{special_lr_key}_no_weight_decay"] = {"params": [], "name": special_lr_key, "weight_decay": 0.0}
                update_params_for_weight_decay(
                    param_groups[f"{special_lr_key}_weight_decay"],
                    param_groups[f"{special_lr_key}_no_weight_decay"],
                    module_,
                    neox_args.weight_decay,
                )
                break
        else:
            # no special group matched this module: put its params in the default groups
            update_params_for_weight_decay(
                param_groups["weight_decay"],
                param_groups["no_weight_decay"],
                module_,
                neox_args.weight_decay,
            )

    if neox_args.weight_decay == 0.0:
        # only return param groups without weight decay
        # with onebitadam, we want to minimize the calls to compressed_allreduce. Every param group calls it once.
        # to avoid this, only use a single param group when weight decay is off.
        new_param_groups = []
        for param_group in param_groups.keys():
            if "no_weight_decay" in param_group:
                new_param_groups.append(param_groups[param_group])
        return new_param_groups

    # Convert dictionary to list of dictionaries
    return list(reversed([param_groups[key] for key in param_groups]))

def exists(x):
    return x is not None
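A rough sketch (not part of this commit) of how get_param_groups could be exercised. ToyModel and the SimpleNamespace stand-in for neox_args are assumptions; only the attributes the function reads (weight_decay, lr_param_groups_config) are supplied, and config keys are matched as substrings of the names returned by named_modules().

# Illustrative sketch only -- a toy model and a minimal stand-in for neox_args.
from types import SimpleNamespace

import torch

from megatron.model import get_param_groups


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(16, 8)
        self.proj = torch.nn.Linear(8, 8)


neox_args = SimpleNamespace(
    weight_decay=0.01,
    lr_param_groups_config={
        "word_embeddings": {  # substring of the submodule name above
            "decay_style": "cosine",
            "start_lr": 1.0e-4,
            "min_lr": 1.0e-5,
            "warmup_iter": 0,
            "end_iter": 1000,
        },
    },
)

param_groups = get_param_groups(ToyModel(), neox_args)
for group in param_groups:
    # one line per group: its name, its weight decay, and how many tensors it holds
    print(group["name"], group["weight_decay"], len(group["params"]))

Each configured key yields its own decay/no-decay pair of groups named after that key, alongside the default weight_decay and no_weight_decay groups.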
1 change: 1 addition & 0 deletions megatron/neox_arguments/neox_args.py
@@ -471,6 +471,7 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate):
"""
Use checkpoint to set the values of the scheduler (learning rate, warmup iterations, minimum learning rate, maximum number of iterations, and decay style from checkpoint and ignore input arguments.
"""
lr_param_groups_config: dict = None


@dataclass
15 changes: 10 additions & 5 deletions megatron/training.py
@@ -42,12 +42,12 @@
from megatron.model import (
    GPT2ModelPipe,
    SoftEmbedding,
    get_params_for_weight_decay_optimization,
    get_param_groups,
)
from megatron.checkpointing import load_checkpoint, save_checkpoint
from megatron.data.data_utils import build_train_valid_test_data_iterators, build_streaming_train_valid_test_data_iterators
from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
from megatron.learning_rates import AnnealingLR, GroupedAnnealingLR
from megatron.logging import tb_wandb_log, training_log
from megatron.utils import (
    OverflowMonitor,
@@ -541,7 +541,7 @@ def get_optimizer(model, neox_args):
        )
        exit()
    # Build parameter groups (weight decay and non-decay).
    param_groups = get_params_for_weight_decay_optimization(model, neox_args)
    param_groups = get_param_groups(model, neox_args)
    print_rank_0(
        f'Configuring Optimizer type: {neox_args.optimizer_type} with params: {neox_args.optimizer["params"]}'
    )
@@ -672,7 +672,7 @@ def get_learning_rate_scheduler(optimizer, neox_args):
    num_iters = max(1, num_iters)
    init_step = 0
    warmup_iter = neox_args.warmup * num_iters
    lr_scheduler = AnnealingLR(
    lr_scheduler = GroupedAnnealingLR(
        optimizer,
        start_lr=neox_args.lr,
        warmup_iter=warmup_iter,
@@ -683,6 +683,7 @@
        use_checkpoint_lr_scheduler=neox_args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=neox_args.override_lr_scheduler,
        use_mup=neox_args.use_mup,
        lr_param_groups_config=neox_args.lr_param_groups_config,
    )

    return lr_scheduler
@@ -888,7 +889,11 @@ def train(
        # get learning rate (if present) - if doing soft prompt tuning + pipe parallel, you
        # may have no tunable parameters on a specific rank
        if optimizer.param_groups:
            lr = optimizer.param_groups[0].get("lr", 0)
            lr = {}  # optimizer.param_groups[0].get("lr", 0)
            print(len(optimizer.param_groups))
            for param_group in optimizer.param_groups:
                print(param_group["name"])
                lr[param_group["name"]] = param_group.get("lr", 0)
        else:
            lr = 0

