diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py
index ea257b3a7b95a7..004d945c1274c9 100644
--- a/src/transformers/models/llama/configuration_llama.py
+++ b/src/transformers/models/llama/configuration_llama.py
@@ -78,7 +78,7 @@ class LlamaConfig(PretrainedConfig):
         tie_word_embeddings(`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
         rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
+            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
             strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
             is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
             `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
@@ -122,6 +122,7 @@ def __init__(
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_scaling=None,
+        rope_theta=10000,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -143,6 +144,7 @@ def __init__(
         self.use_cache = use_cache
         self.rope_scaling = rope_scaling
         self._rope_scaling_validation()
+        self.rope_theta = rope_theta

         super().__init__(
             pad_token_id=pad_token_id,
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 5709fc6a778110..170a3e1ab41905 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -243,6 +243,7 @@ def __init__(self, config: LlamaConfig):
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.pretraining_tp = config.pretraining_tp
         self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta

         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -257,21 +258,25 @@ def _init_rope(self):
         if self.config.rope_scaling is None:
-            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+            self.rotary_emb = LlamaRotaryEmbedding(
+                self.head_dim, max_position_embeddings=self.max_position_embeddings,
+                base=self.rope_theta
+            )
         else:
             scaling_type = self.config.rope_scaling["type"]
             scaling_factor = self.config.rope_scaling["factor"]
             if scaling_type == "linear":
                 self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
-                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
+                    self.head_dim, max_position_embeddings=self.max_position_embeddings,
+                    base=self.rope_theta, scaling_factor=scaling_factor
                 )
             elif scaling_type == "dynamic":
                 self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
-                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
+                    self.head_dim, max_position_embeddings=self.max_position_embeddings,
+                    base=self.rope_theta, scaling_factor=scaling_factor
                 )
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
-
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
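
For reference, a minimal usage sketch (not part of the patch) showing how the new `rope_theta` option is intended to flow from `LlamaConfig` into the attention layers once this diff is applied. The model sizes below are arbitrary toy values chosen for illustration.

# Minimal sketch, assuming the diff above is applied to transformers.
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    vocab_size=1000,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
    rope_theta=1_000_000,  # overrides the new default of 10000
)
model = LlamaForCausalLM(config)

# Each LlamaAttention now stores config.rope_theta and passes it as `base`
# to its rotary embedding (see `_init_rope` in the diff above).
print(model.model.layers[0].self_attn.rope_theta)  # 1000000

Because `base` is now forwarded in all three branches of `_init_rope`, the same `rope_theta` value is respected whether or not `rope_scaling` (linear or dynamic) is configured.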