make all scaling hyperparameters configurable for completeness
lucidrains committed Oct 10, 2024
1 parent a34355d commit c58093f
Showing 3 changed files with 97 additions and 16 deletions.
110 changes: 95 additions & 15 deletions nGPT_pytorch/nGPT.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from functools import partial
 
 import torch
@@ -19,6 +21,11 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def cast_tuple(t, length = 1):
+    out = t if isinstance(t, tuple) else ((t,) * length)
+    assert len(out) == length
+    return out
+
 def l2norm(t, dim = -1):
     return F.normalize(t, dim = dim, p = 2)
 
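For context, and not part of the diff itself: the new cast_tuple helper is what lets every scale hyperparameter be passed either as a single value shared across layers or as an explicit per-layer tuple. A minimal sketch of its behavior (the depth value below is just for illustration):

# illustrative only - mirrors the cast_tuple helper added in the hunk above
def cast_tuple(t, length = 1):
    out = t if isinstance(t, tuple) else ((t,) * length)
    assert len(out) == length
    return out

depth = 4
print(cast_tuple(1., depth))                 # (1.0, 1.0, 1.0, 1.0) - a scalar is broadcast to every layer
print(cast_tuple((0.5, 1., 1., 2.), depth))  # a per-layer tuple passes through, its length must equal depth
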
@@ -101,7 +108,9 @@ def __init__(
         dim_head = 64,
         heads = 8,
         norm_qk = True,
-        manual_norm_weights = False
+        manual_norm_weights = False,
+        s_qk_init = 1.,
+        s_qk_scale = None
     ):
         super().__init__()
         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
@@ -167,7 +176,11 @@ def __init__(
         dim,
         *,
         expand_factor = 4,
-        manual_norm_weights = False
+        manual_norm_weights = False,
+        s_hidden_init = 1.,
+        s_hidden_scale = 1.,
+        s_gate_init = 1.,
+        s_gate_scale = 1.
     ):
         super().__init__()
         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
@@ -178,8 +191,8 @@ def __init__(
         self.to_hidden = NormLinear_(dim, dim_inner)
         self.to_gate = NormLinear_(dim, dim_inner)
 
-        self.hidden_scale = Scale(dim_inner)
-        self.gate_scale = Scale(dim_inner)
+        self.hidden_scale = Scale(dim_inner, s_hidden_init, s_hidden_scale)
+        self.gate_scale = Scale(dim_inner, s_gate_init, s_gate_scale)
 
         self.to_out = NormLinear_(dim_inner, dim, norm_dim_in = False)
 
@@ -206,31 +219,98 @@ def __init__(
         attn_norm_qk = True, # they say the query/key normalization is optional
         ff_expand_factor = 4.,
         ce_ignore_index = -1,
-        residual_lerp_scale_init = None,
         manual_norm_weights = False,
-        tied_embedding = False
+        tied_embedding = False,
+        # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
+        alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
+        s_logit_init: float = 1.,
+        s_logit_scale: float | None = None,
+        alpha_attn_init: float | tuple[float, ...] | None = None,
+        alpha_attn_scale: float | tuple[float, ...] | None = None,
+        alpha_ff_init: float | tuple[float, ...] | None = None,
+        alpha_ff_scale: float | tuple[float, ...] | None = None,
+        s_qk_init: float | tuple[float, ...] = 1.,
+        s_qk_scale: float | tuple[float, ...] | None = None,
+        s_ff_hidden_init: float | tuple[float, ...] = 1.,
+        s_ff_hidden_scale: float | tuple[float, ...] = 1.,
+        s_ff_gate_init: float | tuple[float, ...] = 1.,
+        s_ff_gate_scale: float | tuple[float, ...] = 1.
     ):
         super().__init__()
         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
 
         self.dim = dim
-        residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)
+        alpha_init = default(alpha_init, 1. / depth)
 
         self.token_embed = NormLinear_(dim, num_tokens)
 
         self.layers = ModuleList([])
 
-        for _ in range(depth):
-            self.layers.append(ModuleList([
-                Attention(dim, dim_head = dim_head, heads = heads, norm_qk = attn_norm_qk, manual_norm_weights = manual_norm_weights),
-                FeedForward(dim, expand_factor = ff_expand_factor, manual_norm_weights = manual_norm_weights),
-                Scale(dim, residual_lerp_scale_init, dim ** -0.5),
-                Scale(dim, residual_lerp_scale_init, dim ** -0.5),
-            ]))
+        scale_hparams = (
+            alpha_attn_init,
+            alpha_attn_scale,
+            alpha_ff_init,
+            alpha_ff_scale,
+            s_qk_init,
+            s_qk_scale,
+            s_ff_hidden_init,
+            s_ff_hidden_scale,
+            s_ff_gate_init,
+            s_ff_gate_scale
+        )
+
+        scale_hparams = tuple(cast_tuple(hparam, depth) for hparam in scale_hparams)
+
+        for (
+            alpha_attn_init_,
+            alpha_attn_scale_,
+            alpha_ff_init_,
+            alpha_ff_scale_,
+            s_qk_init_,
+            s_qk_scale_,
+            s_ff_hidden_init_,
+            s_ff_hidden_scale_,
+            s_ff_gate_init_,
+            s_ff_gate_scale_
+        ) in zip(*scale_hparams):
+
+            attn = Attention(
+                dim,
+                dim_head = dim_head,
+                heads = heads,
+                norm_qk = attn_norm_qk,
+                manual_norm_weights = manual_norm_weights,
+                s_qk_init = s_qk_init_,
+                s_qk_scale = s_qk_scale_,
+            )
+
+            ff = FeedForward(
+                dim,
+                expand_factor = ff_expand_factor,
+                manual_norm_weights = manual_norm_weights,
+                s_hidden_init = s_ff_hidden_init_,
+                s_hidden_scale = s_ff_hidden_scale_,
+                s_gate_init = s_ff_gate_init_,
+                s_gate_scale = s_ff_gate_scale_
+            )
+
+            attn_interp_factor = Scale(
+                dim,
+                default(alpha_attn_init_, alpha_init),
+                default(alpha_attn_scale_, dim ** -0.5)
+            )
+
+            ff_interp_factor = Scale(
+                dim,
+                default(alpha_ff_init_, alpha_init),
+                default(alpha_ff_scale_, dim ** -0.5)
+            )
+
+            self.layers.append(ModuleList([attn, ff, attn_interp_factor, ff_interp_factor]))
 
         self.to_logits = NormLinear_(dim, num_tokens) if not tied_embedding else None
 
-        self.logit_scale = Scale(num_tokens, 1., dim ** -0.5)
+        self.logit_scale = Scale(num_tokens, s_logit_init, default(s_logit_scale, dim ** -0.5))
 
         self.ignore_index = ce_ignore_index
 
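To show how the newly exposed knobs fit together (an illustrative sketch only, not part of this commit): every scale hyperparameter of the nGPT constructor can now be given either as a single float applied to all layers or as a per-layer tuple of length depth, which cast_tuple validates. The import path and the num_tokens value below are assumptions, not taken from the diff.

# hypothetical usage of the newly configurable scale hyperparameters
from nGPT_pytorch import nGPT  # assumes the package re-exports nGPT at the top level

model = nGPT(
    num_tokens = 256,                # assumed vocab size, e.g. for byte-level enwik8
    dim = 512,
    depth = 8,
    alpha_attn_init = (0.05,) * 8,   # per-layer tuple, length must equal depth
    alpha_ff_init = 0.05,            # single float, broadcast to all 8 layers
    s_qk_init = 1.,
    s_ff_hidden_scale = 1.,
    s_logit_scale = 512 ** -0.5      # explicit value; left as None it defaults to dim ** -0.5
)

Unspecified alpha_attn_init / alpha_ff_init fall back to alpha_init, which itself defaults to 1. / depth, so existing configurations keep their previous behavior.
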
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.0.11"
+version = "0.0.12"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
1 change: 1 addition & 0 deletions train.py
@@ -92,6 +92,7 @@ def base_decoding(
     dim = 512,
     depth = 8,
     manual_norm_weights = True,
+    tied_embedding = True
 ).to(device)
 
 # prepare enwik8 data