limit LoRA targets
corbt committed Apr 18, 2024
1 parent cd2f63f · commit a71bcf4
Showing 1 changed file with 103 additions and 68 deletions.
171 changes: 103 additions & 68 deletions vllm/model_executor/models/llama.py
@@ -21,6 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""

from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
@@ -29,28 +30,36 @@

from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
-get_tensor_model_parallel_world_size)
+from vllm.distributed import (
+get_tensor_model_parallel_rank,
+get_tensor_model_parallel_world_size,
+)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (LinearMethodBase,
-MergedColumnParallelLinear,
-QKVParallelLinear,
-RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+LinearMethodBase,
+MergedColumnParallelLinear,
+QKVParallelLinear,
+RowParallelLinear,
+)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
-DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+DEFAULT_VOCAB_PADDING_SIZE,
+ParallelLMHead,
+VocabParallelEmbedding,
+)
from vllm.model_executor.model_loader.weight_utils import (
-default_weight_loader, kv_cache_scales_loader)
+default_weight_loader,
+kv_cache_scales_loader,
+)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.utils import is_hip


class LlamaMLP(nn.Module):
-
def __init__(
self,
hidden_size: int,
@@ -60,16 +69,22 @@ def __init__(
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
-hidden_size, [intermediate_size] * 2,
+hidden_size,
+[intermediate_size] * 2,
bias=False,
-linear_method=linear_method)
-self.down_proj = RowParallelLinear(intermediate_size,
-hidden_size,
-bias=False,
-linear_method=linear_method)
+linear_method=linear_method,
+)
+self.down_proj = RowParallelLinear(
+intermediate_size,
+hidden_size,
+bias=False,
+linear_method=linear_method,
+)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
raise ValueError(
f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now."
)
self.act_fn = SiluAndMul()

def forward(self, x):
@@ -80,7 +95,6 @@ def forward(self, x):


class LlamaAttention(nn.Module):
-
def __init__(
self,
hidden_size: int,
@@ -147,11 +161,13 @@ def __init__(
base=rope_theta,
rope_scaling=rope_scaling,
)
-self.attn = Attention(self.num_heads,
-self.head_dim,
-self.scaling,
-num_kv_heads=self.num_kv_heads,
-sliding_window=sliding_window)
+self.attn = Attention(
+self.num_heads,
+self.head_dim,
+self.scaling,
+num_kv_heads=self.num_kv_heads,
+sliding_window=sliding_window,
+)

def forward(
self,
@@ -163,14 +179,12 @@ def forward(
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
-attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
-self.kv_scale)
+attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale)
output, _ = self.o_proj(attn_output)
return output


class LlamaDecoderLayer(nn.Module):
-
def __init__(
self,
config: LlamaConfig,
@@ -180,18 +194,21 @@ def __init__(
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
max_position_embeddings = getattr(
config, "max_position_embeddings", 8192
)
sliding_window = getattr(config, "sliding_window", None)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
config, "bias", False
)
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads),
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
@@ -205,10 +222,12 @@ def __init__(
hidden_act=config.hidden_act,
linear_method=linear_method,
)
-self.input_layernorm = RMSNorm(config.hidden_size,
-eps=config.rms_norm_eps)
-self.post_attention_layernorm = RMSNorm(config.hidden_size,
-eps=config.rms_norm_eps)
+self.input_layernorm = RMSNorm(
+config.hidden_size, eps=config.rms_norm_eps
+)
+self.post_attention_layernorm = RMSNorm(
+config.hidden_size, eps=config.rms_norm_eps
+)

def forward(
self,
@@ -224,7 +243,8 @@ def forward(
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
-hidden_states, residual)
+hidden_states, residual
+)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
@@ -234,13 +254,13 @@

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
-hidden_states, residual)
+hidden_states, residual
+)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual


class LlamaModel(nn.Module):
-
def __init__(
self,
config: LlamaConfig,
@@ -250,19 +270,24 @@ def __init__(
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
-lora_vocab = (lora_config.lora_extra_vocab_size *
-(lora_config.max_loras or 1)) if lora_config else 0
+lora_vocab = (
+(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+if lora_config
+else 0
+)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
-self.layers = nn.ModuleList([
-LlamaDecoderLayer(config, linear_method)
-for _ in range(config.num_hidden_layers)
-])
+self.layers = nn.ModuleList(
+[
+LlamaDecoderLayer(config, linear_method)
+for _ in range(config.num_hidden_layers)
+]
+)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -316,11 +341,8 @@ class LlamaForCausalLM(nn.Module):
"embed_tokens",
"lm_head",
]
-embedding_modules = {
-"embed_tokens": "input_embeddings",
-"lm_head": "output_embeddings",
-}
-embedding_padding_modules = ["lm_head"]
+embedding_modules = {}
+embedding_padding_modules = []

def __init__(
self,
@@ -342,12 +364,14 @@ def __init__(
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
-if not lora_config else lora_config.lora_vocab_padding_size,
+if not lora_config
+else lora_config.lora_vocab_padding_size,
)

logit_scale = getattr(config, "logit_scale", 1.0)
-self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-config.vocab_size, logit_scale)
+self.logits_processor = LogitsProcessor(
+self.unpadded_vocab_size, config.vocab_size, logit_scale
+)
self.sampler = Sampler()

def forward(
@@ -357,14 +381,17 @@ def forward(
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
-hidden_states = self.model(input_ids, positions, kv_caches,
-attn_metadata)
+hidden_states = self.model(
+input_ids, positions, kv_caches, attn_metadata
+)
return hidden_states

-def compute_logits(self, hidden_states: torch.Tensor,
-sampling_metadata: SamplingMetadata) -> torch.Tensor:
-logits = self.logits_processor(self.lm_head.weight, hidden_states,
-sampling_metadata)
+def compute_logits(
+self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata
+) -> torch.Tensor:
+logits = self.logits_processor(
+self.lm_head.weight, hidden_states, sampling_metadata
+)
return logits

def sample(
@@ -388,12 +415,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
if (
"rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name
):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
-for (param_name, weight_name, shard_id) in stacked_params_mapping:
+for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
@@ -409,8 +438,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)

# If this function is called, it should always initialize KV cache scale
@@ -420,9 +450,12 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
-quantization_param_path, tp_rank, tp_size,
-self.config.num_hidden_layers,
-self.config.__class__.model_type):
+quantization_param_path,
+tp_rank,
+tp_size,
+self.config.num_hidden_layers,
+self.config.__class__.model_type,
+):
layer_self_attn = self.model.layers[layer_idx].self_attn

if is_hip():
Expand All @@ -434,5 +467,7 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.kv_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
raise RuntimeError(
"Self attention has no KV cache scaling "
"factor attribute!"
)
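Aside from the Black-style reformatting, the substantive change in this commit is that `LlamaForCausalLM.embedding_modules` and `embedding_padding_modules` are emptied, which matches the commit title by narrowing what the LoRA machinery is told about embedding modules for this model. The snippet below is a minimal, hypothetical sketch of how such a mapping could gate embedding-specific LoRA handling; the helper `is_embedding_lora_target` and the way it is used are illustrative assumptions, not vLLM's actual loader logic.

```python
from typing import Dict

# Shapes mirror the class attributes changed in the diff; the "before" values
# are taken from the removed lines, the "after" value from the added lines.
embedding_modules_before: Dict[str, str] = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}
embedding_modules_after: Dict[str, str] = {}  # as of this commit


def is_embedding_lora_target(module_name: str, embedding_modules: Dict[str, str]) -> bool:
    """Illustrative check: a LoRA loader might consult a mapping like this to
    decide whether a module gets embedding-specific handling (extra-vocab rows,
    lm_head padding)."""
    return any(module_name.endswith(key) for key in embedding_modules)


for mapping in (embedding_modules_before, embedding_modules_after):
    print(is_embedding_lora_target("model.embed_tokens", mapping))
# -> True with the old mapping, False with the emptied one.
```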
