diff --git a/nGPT_pytorch/nGPT.py b/nGPT_pytorch/nGPT.py
index 2ed6e2f..a684c02 100644
--- a/nGPT_pytorch/nGPT.py
+++ b/nGPT_pytorch/nGPT.py
@@ -41,17 +41,27 @@ def l2norm(
     t,
     dim = -1,
     norm_eps = 0.,
-    eps = None
+    eps = None,
+    groups = 1
 ):
     eps = default(eps, 1e-5 if t.dtype == torch.float16 else 1e-10)
 
+    if groups > 1:
+        t = t.chunk(groups, dim = dim)
+        t = torch.stack(t)
+
     if norm_eps == 0.:
-        return F.normalize(t, dim = dim, p = 2, eps = eps)
+        out = F.normalize(t, dim = dim, p = 2, eps = eps)
+    else:
+        norm = t.norm(dim = dim, keepdim = True)
+        target_norm = norm.detach().clamp(min = 1. - norm_eps, max = 1. + norm_eps)
+        divisor = norm / target_norm
+        out = t / divisor.clamp(min = eps)
+
+    if groups > 1:
+        out = torch.cat([*out], dim = dim)
 
-    norm = t.norm(dim = dim, keepdim = True)
-    target_norm = norm.detach().clamp(min = 1. - norm_eps, max = 1. + norm_eps)
-    divisor = norm / target_norm
-    return t / divisor.clamp(min = eps)
+    return out
 
 # scale
 
@@ -75,13 +85,14 @@ def forward(self):
 # for use with parametrize
 
 class L2Norm(Module):
-    def __init__(self, dim = -1, norm_eps = 0.):
+    def __init__(self, dim = -1, norm_eps = 0., groups = 1):
         super().__init__()
         self.dim = dim
         self.norm_eps = norm_eps
+        self.groups = groups
 
     def forward(self, t):
-        return l2norm(t, dim = self.dim, norm_eps = self.norm_eps)
+        return l2norm(t, dim = self.dim, norm_eps = self.norm_eps, groups = self.groups)
 
 class NormLinear(Module):
     def __init__(
@@ -90,13 +101,14 @@ def __init__(
         dim_out,
         norm_dim_in = True,
         parametrize = True,
-        norm_eps = 0.
+        norm_eps = 0.,
+        groups = 1
     ):
         super().__init__()
         self.linear = nn.Linear(dim, dim_out, bias = False)
 
         self.parametrize = parametrize
-        self.l2norm = L2Norm(dim = -1 if norm_dim_in else 0, norm_eps = norm_eps)
+        self.l2norm = L2Norm(dim = -1 if norm_dim_in else 0, norm_eps = norm_eps, groups = groups)
 
         if parametrize:
             register_parametrization(
@@ -143,13 +155,14 @@ def __init__(
             enable_math = True,
             enable_mem_efficient = True
         ),
-        norm_eps = 0.
+        norm_eps = 0.,
+        num_hyperspheres = 1
     ):
         super().__init__()
         self.causal = causal
 
-        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps)
-        self.l2norm = partial(l2norm, norm_eps = norm_eps)
+        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps, groups = num_hyperspheres)
+        self.l2norm = partial(l2norm, norm_eps = norm_eps, groups = num_hyperspheres)
 
         dim_sqrt = dim ** 0.5
         self.dim_sqrt = dim_sqrt
@@ -237,10 +250,11 @@ def __init__(
         s_hidden_scale = 1.,
         s_gate_init = 1.,
         s_gate_scale = 1.,
-        norm_eps = 0.
+        norm_eps = 0.,
+        num_hyperspheres = 1
     ):
         super().__init__()
-        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps)
+        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps, groups = num_hyperspheres)
 
         self.dim = dim
         dim_inner = int(dim * expand_factor * 2 / 3)
@@ -278,6 +292,7 @@ def __init__(
         ce_ignore_index = -1,
         manual_norm_weights = False,
         tied_embedding = False,
+        num_hyperspheres = 1,
         causal = True,
         # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
         alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
@@ -301,8 +316,8 @@ def __init__(
         norm_eps = 0. # greater than 0 allows the norm to be around (1. - norm_eps) to (1. + norm_eps)
     ):
         super().__init__()
-        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps)
-        self.l2norm = partial(l2norm, norm_eps = norm_eps)
+        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps, groups = num_hyperspheres)
+        self.l2norm = partial(l2norm, norm_eps = norm_eps, groups = num_hyperspheres)
 
         self.dim = dim
         self.causal = causal
@@ -350,7 +365,8 @@ def __init__(
                 s_qk_init = s_qk_init_,
                 s_qk_scale = s_qk_scale_,
                 flash_kwargs = attn_flash_kwargs,
-                norm_eps = norm_eps
+                norm_eps = norm_eps,
+                num_hyperspheres = num_hyperspheres
             )
 
             ff = FeedForward(
@@ -361,7 +377,8 @@ def __init__(
                 s_hidden_scale = s_ff_hidden_scale_,
                 s_gate_init = s_ff_gate_init_,
                 s_gate_scale = s_ff_gate_scale_,
-                norm_eps = norm_eps
+                norm_eps = norm_eps,
+                num_hyperspheres = num_hyperspheres
             )
 
             attn_interp_factor = Scale(
diff --git a/pyproject.toml b/pyproject.toml
index fa68cac..f9ccb75 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.1.1"
+version = "0.1.2"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
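
Below is a minimal sketch (not part of the diff) of what the new `groups` argument to `l2norm`, surfaced on the model as `num_hyperspheres`, computes in the common `dim = -1`, `norm_eps = 0.` case. `grouped_l2norm` is a hypothetical standalone helper that mirrors the chunk, normalize, re-concatenate path added above:

import torch
import torch.nn.functional as F

def grouped_l2norm(t, groups = 1, dim = -1):
    # split the feature dimension into `groups` equal chunks and
    # unit-normalize each chunk independently, so each vector lives
    # on a product of `groups` hyperspheres instead of a single one
    chunks = t.chunk(groups, dim = dim)
    normed = [F.normalize(chunk, p = 2, dim = dim) for chunk in chunks]
    return torch.cat(normed, dim = dim)

t = torch.randn(2, 8)
out = grouped_l2norm(t, groups = 2)

print(out[..., :4].norm(dim = -1)) # norms ~1. (first hypersphere)
print(out[..., 4:].norm(dim = -1)) # norms ~1. (second hypersphere)

With the default `groups = 1` (and `num_hyperspheres = 1`) everywhere in the diff, this reduces to plain l2 normalization, so existing configs behave as before; passing e.g. `num_hyperspheres = 2` to the model constructor threads the grouping through every `NormLinear` and activation norm via the `partial`s shown above.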