Commit 8023931
just copy paste and do a version with plain transformer for RL purposes
lucidrains committed Oct 17, 2024
1 parent a79389b commit 8023931
Showing 4 changed files with 442 additions and 4 deletions.
nGPT_pytorch/__init__.py (2 additions, 0 deletions)
@@ -7,3 +7,5 @@
     Attention,
     nGPT
 )
+
+from nGPT_pytorch.nTransformer import nTransformer
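With this export, nTransformer becomes importable from the package root. Below is a minimal usage sketch, assuming (based on the commit message describing a copy-pasted plain-transformer version of nGPT) that the constructor mirrors nGPT's non-token keyword arguments and that the forward pass maps a (batch, seq, dim) feature tensor to a tensor of the same shape; the argument names and shapes are assumptions, not taken from the shown diff:

```python
import torch
from nGPT_pytorch import nTransformer  # exposed by the new __init__.py line above

# hypothetical constructor arguments, assumed to mirror nGPT's non-token kwargs
model = nTransformer(
    dim = 512,
    depth = 8,
    dim_head = 64,
    heads = 8
)

# e.g. a batch of RL state embeddings rather than token ids
features = torch.randn(2, 256, 512)

out = model(features)  # assumed to return features of shape (2, 256, 512)
```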
nGPT_pytorch/nGPT.py (8 additions, 3 deletions)
@@ -306,9 +306,9 @@ class nGPT(Module):
     def __init__(
         self,
         *,
-        num_tokens,
         dim,
         depth,
+        num_tokens = None,
         dim_head = 64,
         heads = 8,
         attn_norm_qk = True, # they say the query/key normalization is optional
@@ -347,7 +347,12 @@ def __init__(
         self.causal = causal
         alpha_init = default(alpha_init, 1. / depth)

-        self.token_embed = NormLinear_(dim, num_tokens)
+        # allow for plain stack of attention and feedforward, for trying to use in a different setting
+
+        only_transformer = not exists(num_tokens)
+        self.only_transformer = only_transformer
+
+        self.token_embed = NormLinear_(dim, num_tokens) if not only_transformer else None

         self.layers = ModuleList([])

@@ -421,7 +426,7 @@ def __init__(

             self.layers.append(ModuleList([attn_with_residual, ff_with_residual]))

-        self.to_logits = NormLinear_(dim, num_tokens) if not tied_embedding else None
+        self.to_logits = NormLinear_(dim, num_tokens) if not (tied_embedding or only_transformer) else None

         self.logit_scale = Scale(num_tokens, s_logit_init, default(s_logit_scale, dim ** -0.5))

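Taken together, these edits let nGPT be constructed without a vocabulary: when num_tokens is omitted, only_transformer is set and the token embedding (and, with the gating above, the logits projection) is skipped, leaving a plain stack of normalized attention and feedforward blocks. A minimal sketch of both modes follows, assuming the forward pass (in the part of the diff not shown here) accepts a float feature tensor of shape (batch, seq, dim) when no vocabulary is configured; the num_tokens-free call is illustrative and not confirmed by the shown hunks:

```python
import torch
from nGPT_pytorch import nGPT

# standard language-model mode: vocabulary given, token embedding and logits head are built
lm = nGPT(num_tokens = 256, dim = 512, depth = 8)
token_ids = torch.randint(0, 256, (2, 1024))
logits = lm(token_ids)        # expected shape (2, 1024, 256)

# plain-transformer mode (assumed): no vocabulary, token_embed / to_logits are None,
# so the stack would operate directly on dim-sized features, e.g. for RL
encoder = nGPT(dim = 512, depth = 8)
features = torch.randn(2, 64, 512)
hidden = encoder(features)    # assumed shape (2, 64, 512)
```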
