diff --git a/llama/model.py b/llama/model.py
index e388c03..27872c2 100644
--- a/llama/model.py
+++ b/llama/model.py
@@ -23,7 +23,7 @@ class ModelArgs:
     n_heads: int = 32
     n_kv_heads: Optional[int] = None
     vocab_size: int = -1
-    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
+    multiple_of: int = 256  # Make SwiGLU hidden layer size multiple of large power of 2
     ffn_dim_multiplier: Optional[float] = None
     norm_eps: float = 1e-5
     rope_theta: float = 500000
@@ -50,7 +50,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device, dtype=torch.float32)
     freqs = torch.outer(t, freqs)
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # Complex64
     return freqs_cis
 
 
@@ -168,7 +168,7 @@ def forward(
         keys = self.cache_k[:bsz, : start_pos + seqlen]
         values = self.cache_v[:bsz, : start_pos + seqlen]
 
-        # repeat k/v heads if n_kv_heads < n_heads
+        # Repeat k/v heads if n_kv_heads < n_heads
         keys = repeat_kv(
             keys, self.n_rep
         )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
@@ -200,7 +200,7 @@ def __init__(
     ):
         super().__init__()
         hidden_dim = int(2 * hidden_dim / 3)
-        # custom dim factor multiplier
+        # Custom dim factor multiplier
        if ffn_dim_multiplier is not None:
             hidden_dim = int(ffn_dim_multiplier * hidden_dim)
         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)