Commit

allow for grouped l2norm (more than one hypersphere)
lucidrains committed Oct 10, 2024
1 parent 422ae64 commit e508997
Showing 2 changed files with 37 additions and 20 deletions.
55 changes: 36 additions & 19 deletions nGPT_pytorch/nGPT.py
@@ -41,17 +41,27 @@ def l2norm(
     t,
     dim = -1,
     norm_eps = 0.,
-    eps = None
+    eps = None,
+    groups = 1
 ):
     eps = default(eps, 1e-5 if t.dtype == torch.float16 else 1e-10)

+    if groups > 1:
+        t = t.chunk(groups, dim = dim)
+        t = torch.stack(t)
+
     if norm_eps == 0.:
-        return F.normalize(t, dim = dim, p = 2, eps = eps)
+        out = F.normalize(t, dim = dim, p = 2, eps = eps)
+    else:
+        norm = t.norm(dim = dim, keepdim = True)
+        target_norm = norm.detach().clamp(min = 1. - norm_eps, max = 1. + norm_eps)
+        divisor = norm / target_norm
+        out = t / divisor.clamp(min = eps)

-    norm = t.norm(dim = dim, keepdim = True)
-    target_norm = norm.detach().clamp(min = 1. - norm_eps, max = 1. + norm_eps)
-    divisor = norm / target_norm
-    return t / divisor.clamp(min = eps)
+    if groups > 1:
+        out = torch.cat([*out], dim = dim)
+
+    return out

 # scale

@@ -75,13 +85,14 @@ def forward(self):
 # for use with parametrize

 class L2Norm(Module):
-    def __init__(self, dim = -1, norm_eps = 0.):
+    def __init__(self, dim = -1, norm_eps = 0., groups = 1):
         super().__init__()
         self.dim = dim
         self.norm_eps = norm_eps
+        self.groups = groups

     def forward(self, t):
-        return l2norm(t, dim = self.dim, norm_eps = self.norm_eps)
+        return l2norm(t, dim = self.dim, norm_eps = self.norm_eps, groups = self.groups)

 class NormLinear(Module):
     def __init__(
@@ -90,13 +101,14 @@ def __init__(
         dim_out,
         norm_dim_in = True,
         parametrize = True,
-        norm_eps = 0.
+        norm_eps = 0.,
+        groups = 1
     ):
         super().__init__()
         self.linear = nn.Linear(dim, dim_out, bias = False)

         self.parametrize = parametrize
-        self.l2norm = L2Norm(dim = -1 if norm_dim_in else 0, norm_eps = norm_eps)
+        self.l2norm = L2Norm(dim = -1 if norm_dim_in else 0, norm_eps = norm_eps, groups = groups)

         if parametrize:
             register_parametrization(
@@ -143,13 +155,14 @@ def __init__(
             enable_math = True,
             enable_mem_efficient = True
         ),
-        norm_eps = 0.
+        norm_eps = 0.,
+        num_hyperspheres = 1
     ):
         super().__init__()
         self.causal = causal

-        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps)
-        self.l2norm = partial(l2norm, norm_eps = norm_eps)
+        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps, groups = num_hyperspheres)
+        self.l2norm = partial(l2norm, norm_eps = norm_eps, groups = num_hyperspheres)

         dim_sqrt = dim ** 0.5
         self.dim_sqrt = dim_sqrt
@@ -237,10 +250,11 @@ def __init__(
         s_hidden_scale = 1.,
         s_gate_init = 1.,
         s_gate_scale = 1.,
-        norm_eps = 0.
+        norm_eps = 0.,
+        num_hyperspheres = 1
     ):
         super().__init__()
-        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps)
+        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps, groups = num_hyperspheres)

         self.dim = dim
         dim_inner = int(dim * expand_factor * 2 / 3)
@@ -278,6 +292,7 @@ def __init__(
         ce_ignore_index = -1,
         manual_norm_weights = False,
         tied_embedding = False,
+        num_hyperspheres = 1,
         causal = True,
         # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
         alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
@@ -301,8 +316,8 @@ def __init__(
         norm_eps = 0. # greater than 0 allows the norm to be around (1. - norm_eps) to (1. + norm_eps)
     ):
         super().__init__()
-        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps)
-        self.l2norm = partial(l2norm, norm_eps = norm_eps)
+        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps, groups = num_hyperspheres)
+        self.l2norm = partial(l2norm, norm_eps = norm_eps, groups = num_hyperspheres)

         self.dim = dim
         self.causal = causal
@@ -350,7 +365,8 @@ def __init__(
                 s_qk_init = s_qk_init_,
                 s_qk_scale = s_qk_scale_,
                 flash_kwargs = attn_flash_kwargs,
-                norm_eps = norm_eps
+                norm_eps = norm_eps,
+                num_hyperspheres = num_hyperspheres
             )

             ff = FeedForward(
@@ -361,7 +377,8 @@ def __init__(
                 s_hidden_scale = s_ff_hidden_scale_,
                 s_gate_init = s_ff_gate_init_,
                 s_gate_scale = s_ff_gate_scale_,
-                norm_eps = norm_eps
+                norm_eps = norm_eps,
+                num_hyperspheres = num_hyperspheres
             )

             attn_interp_factor = Scale(
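For reference, a minimal usage sketch of the new grouped behavior (not part of the commit; it assumes l2norm can be imported from the nGPT_pytorch.nGPT module edited above, and the tensor shapes are illustrative). With groups = 2 and the default dim = -1, the last dimension is split into two equal chunks, each chunk is projected onto its own unit hypersphere, and the chunks are concatenated back, so each half has unit norm while the full vector has norm sqrt(2).

import torch
from nGPT_pytorch.nGPT import l2norm  # assumed import path for the module changed in this commit

x = torch.randn(3, 512)

# groups = 2 -> split the 512 features into two 256-dim chunks,
# l2-normalize each chunk independently, then concatenate them back
out = l2norm(x, groups = 2)

left, right = out.chunk(2, dim = -1)
print(left.norm(dim = -1))   # ~1.0 per row
print(right.norm(dim = -1))  # ~1.0 per row
print(out.norm(dim = -1))    # ~1.414 per row, i.e. sqrt(2) from two unit-norm halves
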
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.1.1"
+version = "0.1.2"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
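At the model level, the new num_hyperspheres argument is simply threaded down to every NormLinear and l2norm call. A rough end-to-end sketch follows (hyperparameters are illustrative, the constructor and return_loss usage follow the repository README, and dim is assumed to be divisible by num_hyperspheres so every group gets an equal slice of the feature dimension).

import torch
from nGPT_pytorch import nGPT

model = nGPT(
    num_tokens = 256,
    dim = 512,
    depth = 4,
    num_hyperspheres = 2  # keep activations on two 256-dim hyperspheres instead of one 512-dim sphere
)

ids = torch.randint(0, 256, (1, 1024))

loss = model(ids, return_loss = True)
loss.backward()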

