Merge pull request #5 from togethercomputer/support-code-llama
llama: support RoPE theta for codellama
RyanLucchese authored Aug 24, 2023
2 parents d0027d5 + 33a236b commit 41d4bf2
Showing 2 changed files with 12 additions and 5 deletions.
src/transformers/models/llama/configuration_llama.py (4 changes: 3 additions & 1 deletion)
@@ -78,7 +78,7 @@ class LlamaConfig(PretrainedConfig):
         tie_word_embeddings(`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
         rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
+            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
             strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
             is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
             `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
@@ -122,6 +122,7 @@ def __init__(
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_scaling=None,
+        rope_theta=10000,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -143,6 +144,7 @@ def __init__(
         self.use_cache = use_cache
         self.rope_scaling = rope_scaling
         self._rope_scaling_validation()
+        self.rope_theta = rope_theta
 
         super().__init__(
             pad_token_id=pad_token_id,
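
Usage note (not part of the diff): with this branch installed, the new option is just another `LlamaConfig` keyword. A minimal sketch in Python; the 1_000_000 value is illustrative of the larger base reported for Code Llama checkpoints, not something taken from this commit.

# Sketch only: assumes this fork of transformers is installed.
from transformers import LlamaConfig

# rope_theta defaults to 10000 (the original LLaMA base); long-context
# Code Llama checkpoints are reported to use a larger base such as 1e6.
config = LlamaConfig(rope_theta=1_000_000)
print(config.rope_theta)  # 1000000
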
src/transformers/models/llama/modeling_llama.py (13 changes: 9 additions & 4 deletions)
@@ -243,6 +243,7 @@ def __init__(self, config: LlamaConfig):
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.pretraining_tp = config.pretraining_tp
         self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -257,21 +258,25 @@ def __init__(self, config: LlamaConfig):
 
     def _init_rope(self):
         if self.config.rope_scaling is None:
-            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+            self.rotary_emb = LlamaRotaryEmbedding(
+                self.head_dim, max_position_embeddings=self.max_position_embeddings,
+                base=self.rope_theta
+            )
         else:
             scaling_type = self.config.rope_scaling["type"]
             scaling_factor = self.config.rope_scaling["factor"]
             if scaling_type == "linear":
                 self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
-                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
+                    self.head_dim, max_position_embeddings=self.max_position_embeddings,
+                    base=self.rope_theta, scaling_factor=scaling_factor
                 )
             elif scaling_type == "dynamic":
                 self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
-                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
+                    self.head_dim, max_position_embeddings=self.max_position_embeddings,
+                    base=self.rope_theta, scaling_factor=scaling_factor
                 )
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
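
Background note (not part of the diff): the `base` argument threaded into the rotary embedding classes above is the RoPE theta. A rough sketch of how the base sets the per-dimension rotation frequencies, written against the usual RoPE formulation rather than the exact LlamaRotaryEmbedding code:

import torch

def rope_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
    # inv_freq[i] = 1 / base**(2i / dim) for i = 0, 1, ..., dim//2 - 1
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# A larger base (e.g. 1e6 instead of the 10000 default) yields smaller
# frequencies, i.e. longer wavelengths, which is what lets a model like
# Code Llama attend over a longer usable context.
print(rope_inv_freq(8, base=10000.0))
print(rope_inv_freq(8, base=1_000_000.0))
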
