allow for norming the input within nTransformer
lucidrains committed Oct 17, 2024
1 parent f64fba1 commit cb7185c
Showing 2 changed files with 7 additions and 2 deletions.
7 changes: 6 additions & 1 deletion nGPT_pytorch/nTransformer.py
@@ -317,6 +317,7 @@ def __init__(
         norm_eps = 0. # greater than 0 allows the norm to be around (1. - norm_eps) to (1. + norm_eps)
     ):
         super().__init__()
+        self.l2norm = partial(l2norm, norm_eps = norm_eps, groups = num_hyperspheres)

         self.dim = dim
         self.causal = causal
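
The `self.l2norm` partial stored here is what the new `norm_input` path reuses in `forward`. As a rough guide, a helper with the `norm_eps` and `groups` semantics described in the inline comment could look like the sketch below; the real `l2norm` lives elsewhere in nGPT_pytorch, and the exact eps handling here is an assumption:

    import torch
    import torch.nn.functional as F

    def l2norm(t, norm_eps = 0., groups = 1, dim = -1):
        # with groups > 1, treat the feature dim as `groups` concatenated hyperspheres
        if groups > 1:
            t = torch.stack(t.chunk(groups, dim = dim))

        if norm_eps == 0.:
            # hard projection onto the unit hypersphere
            out = F.normalize(t, p = 2, dim = dim)
        else:
            # soft projection: only pull the norm into (1 - norm_eps, 1 + norm_eps)
            norm = t.norm(dim = dim, keepdim = True)
            target = norm.detach().clamp(min = 1. - norm_eps, max = 1. + norm_eps)
            out = t * target / norm.clamp(min = 1e-10)

        if groups > 1:
            out = torch.cat(tuple(out), dim = dim)

        return out
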
@@ -404,8 +405,12 @@ def forward(
         self,
         tokens,
         mask = None,
+        norm_input = False
     ):

+        if norm_input:
+            tokens = self.l2norm(tokens)
+
         for attn, ff, attn_alpha, ff_alpha in self.layers:

             attn_out = l2norm(attn(tokens, mask = mask))
@@ -427,5 +432,5 @@ def forward(

 x = torch.randn(1, 1024, 512)

-embed = transformer(x)
+embed = transformer(x, norm_input = True)
 assert x.shape == embed.shape
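
With the flag in place, callers can pass raw (unnormalized) embeddings and have them projected onto the hypersphere before the first layer. A usage sketch follows; the constructor arguments (`dim`, `depth`) and the import path are assumptions modeled on the test at the bottom of the diff, not a verified signature:

    import torch
    from nGPT_pytorch.nTransformer import nTransformer  # assumed import path

    transformer = nTransformer(dim = 512, depth = 4)  # assumed constructor args

    x = torch.randn(1, 1024, 512)  # raw, unnormalized token embeddings

    # norm_input = True l2-normalizes x before the residual layers
    embed = transformer(x, norm_input = True)
    assert x.shape == embed.shape
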
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.1.9"
+version = "0.1.10"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
