
Merge pull request #78 from rhymes-ai/vllm
Refactor(vllm): to prepare for merging it to upstream
xffxff authored Nov 20, 2024
2 parents 9c5139f + 92ea030 commit 083b5a8
Showing 3 changed files with 324 additions and 216 deletions.
276 changes: 60 additions & 216 deletions aria/vllm/aria.py
@@ -28,7 +28,6 @@
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
@@ -39,24 +38,20 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput, SamplingMetadata
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaMLP,
LlamaModel,
RMSNorm,
)
from vllm.model_executor.models.utils import (
PPMissingLayer,
AutoWeightsLoader,
WeightsMapper,
make_layers,
maybe_prefix,
merge_multimodal_embeddings,
)
from vllm.model_executor.utils import set_weight_attrs
@@ -70,17 +65,11 @@
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from aria.model.configuration_aria import AriaConfig
from aria.model.projector import AriaProjector
from aria.model.vision_encoder import AriaVisionModel
from .projector import AriaProjector
from .vision_encoder import AriaVisionModel

logger = logging.get_logger(__name__)

_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}


class AriaMoELMConfig(LlamaConfig):
"""
@@ -156,36 +145,37 @@ def __init__(self, config: AriaMoELMConfig):
)
)
)
set_weight_attrs(self.router_weight, {"weight_loader": self.weight_loader})
set_weight_attrs(self.w1, {"weight_loader": self.weight_loader})
set_weight_attrs(self.w2, {"weight_loader": self.weight_loader})
set_weight_attrs(
self.router_weight, {"weight_loader": self._weight_loader_for_router}
)
set_weight_attrs(self.w1, {"weight_loader": self._weight_loader_for_w1})
set_weight_attrs(self.w2, {"weight_loader": self._weight_loader_for_w2})

def weight_loader(
self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str
def _weight_loader_for_router(
self, param: nn.Parameter, loaded_weight: torch.Tensor
):
if shard_id == "router":
param.data.copy_(loaded_weight)
elif shard_id == "w1":
if self.tp_size > 1:
# the shape of loaded_weight is (num_experts, hidden_size, 2 * moe_intermediate_size)
up, gate = loaded_weight.chunk(2, dim=-1)
up_current_rank = up.chunk(self.tp_size, dim=-1)[self.tp_rank]
gate_current_rank = gate.chunk(self.tp_size, dim=-1)[self.tp_rank]
up_and_gate = torch.cat(
[up_current_rank, gate_current_rank], dim=-1
).transpose(1, 2)
param.data.copy_(up_and_gate)
else:
param.data.copy_(loaded_weight.transpose(1, 2))
param.data.copy_(loaded_weight)

def _weight_loader_for_w1(self, param: nn.Parameter, loaded_weight: torch.Tensor):
# the shape of loaded_weight is (num_experts, hidden_size, 2 * moe_intermediate_size)
if self.tp_size > 1:
up, gate = loaded_weight.chunk(2, dim=-1)
up_current_rank = up.chunk(self.tp_size, dim=-1)[self.tp_rank]
gate_current_rank = gate.chunk(self.tp_size, dim=-1)[self.tp_rank]
up_and_gate = torch.cat(
[up_current_rank, gate_current_rank], dim=-1
).transpose(1, 2)
param.data.copy_(up_and_gate)
else:
if self.tp_size > 1:
# the shape of loaded_weight is (num_experts, moe_intermediate_size, hidden_size)
down_current_rank = loaded_weight.chunk(self.tp_size, dim=1)[
self.tp_rank
]
param.data.copy_(down_current_rank.transpose(1, 2))
else:
param.data.copy_(loaded_weight.transpose(1, 2))
param.data.copy_(loaded_weight.transpose(1, 2))

def _weight_loader_for_w2(self, param: nn.Parameter, loaded_weight: torch.Tensor):
# the shape of loaded_weight is (num_experts, moe_intermediate_size, hidden_size)
if self.tp_size > 1:
down_current_rank = loaded_weight.chunk(self.tp_size, dim=1)[self.tp_rank]
param.data.copy_(down_current_rank.transpose(1, 2))
else:
param.data.copy_(loaded_weight.transpose(1, 2))

def forward(self, hidden_states):
router_output = torch.nn.functional.linear(hidden_states, self.router_weight)
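
Aside: the refactor above replaces the single shard-aware `weight_loader` with one loader per parameter (`_weight_loader_for_router`, `_weight_loader_for_w1`, `_weight_loader_for_w2`). Below is a minimal, self-contained sketch of the tensor-parallel slicing that `_weight_loader_for_w1` performs; the sizes are made up for illustration and are not Aria's actual configuration.

```python
import torch

# Toy dimensions for illustration only (not Aria's real configuration).
num_experts, hidden_size, moe_intermediate_size = 4, 8, 16
tp_size, tp_rank = 2, 0

# Checkpoint layout: (num_experts, hidden_size, 2 * moe_intermediate_size),
# with the up projection in the first half of the last dim and the gate
# projection in the second half, matching the names used in the loader above.
loaded_weight = torch.randn(num_experts, hidden_size, 2 * moe_intermediate_size)

# Split the fused up/gate halves, keep only this rank's slice of each,
# then transpose into the (num_experts, out_features, in_features) layout
# that the parameter stores.
up, gate = loaded_weight.chunk(2, dim=-1)
up_rank = up.chunk(tp_size, dim=-1)[tp_rank]
gate_rank = gate.chunk(tp_size, dim=-1)[tp_rank]
shard = torch.cat([up_rank, gate_rank], dim=-1).transpose(1, 2)

assert shard.shape == (num_experts, 2 * moe_intermediate_size // tp_size, hidden_size)
```

Each rank ends up storing only its slice of the fused up/gate projection, already transposed into the layout the parameter expects.
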
@@ -328,39 +318,18 @@ class AriaMoELMModel(LlamaModel):
config (LlamaConfig): Configuration object for the model.
"""

def __init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

# FIXME(zhoufan): this is a hack to avoid the error: AttributeError: 'AriaMoELMModel' object has no attribute 'do_not_compile'.
self.do_not_compile = True

self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
if get_pp_group().is_first_rank or (
config.tie_word_embeddings and get_pp_group().is_last_rank
):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
else:
self.embed_tokens = PPMissingLayer()
self.layers = None

self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MoEDecoderLayer(
@@ -371,112 +340,9 @@ def __init__(
),
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()


class AriaMoELMForCausalLM(LlamaForCausalLM):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}

# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping = {
"layers": "model.layers",
"attention": "self_attn",
"wq": "q_proj",
"wk": "k_proj",
"wv": "v_proj",
"wo": "o_proj",
"attention_norm": "input_layernorm",
"feed_forward": "mlp",
"w1": "gate_proj",
"w2": "down_proj",
"w3": "up_proj",
"ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens",
"output": "lm_head",
"norm": "model.norm",
}

def __init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
nn.Module.__init__(self)

self.config = config
self.lora_config = lora_config

self.model = AriaMoELMModel(
config, cache_config, quant_config, lora_config=lora_config, prefix="model"
)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size
),
quant_config=quant_config,
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight

logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size, logit_scale
)
self.sampler = Sampler()
else:
self.lm_head = PPMissingLayer()


def build_mm_projector(config: AriaConfig):
def build_mm_projector(config):
"""
Builds and returns an AriaProjector instance based on the provided configuration.
@@ -699,7 +565,6 @@ def input_processor(ctx, llm_inputs):
)


# adapted from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_multimodal_tokens)
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_aria)
@INPUT_REGISTRY.register_input_processor(input_processor)
@@ -718,7 +583,6 @@ def __init__(
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

# Prepare the image_size-to-tokens mapping used by the image preprocessing; see input_processor
@@ -735,7 +599,8 @@ def __init__(
self.multi_modal_projector = build_mm_projector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AriaMoELMModel(
config.text_config, cache_config, quant_config
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model.model"),
)
self.pad_token_id = (
self.config.pad_token_id if self.config.pad_token_id is not None else -1
@@ -773,11 +638,10 @@ def forward(
torch.bfloat16
)
pixel_mask = pixel_mask.view(-1, *pixel_mask.shape[-2:])
image_outputs, image_attn_mask = self.vision_tower(
selected_image_feature, image_attn_mask = self.vision_tower(
pixel_values,
pixel_mask=pixel_mask,
)
selected_image_feature = image_outputs.last_hidden_state

image_features = self.multi_modal_projector(
selected_image_feature, attn_mask=image_attn_mask
@@ -814,37 +678,17 @@ def sample(
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Only doing this for the language model part for now.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
("experts.router_weight", "router.weight", "router"),
("experts.w1", "experts.fc1.weight", "w1"),
("experts.w2", "experts.fc2.weight", "w2"),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
shard_id = None
# Because we used the original HF ViT and vision projector, we could keep those weights in the sharded shape.
# Only the language model part needs to adjust the weight loading.
if "language_model" in name:
for param_name, weight_name, _shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
shard_id = _shard_id
break

param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
if shard_id is not None:
weight_loader(param, loaded_weight, shard_id)
else:
weight_loader(param, loaded_weight)
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"language_model.model": "language_model",
"language_model.lm_head": "lm_head",
},
orig_to_new_suffix={
"experts.fc1.weight": "experts.w1",
"experts.fc2.weight": "experts.w2",
"router.weight": "experts.router_weight",
},
)

loader = AutoWeightsLoader(self)
loader.load_weights(weights, mapper=hf_to_vllm_mapper)
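
Aside: to make the new loading path concrete, here is a minimal sketch of the key renaming that `hf_to_vllm_mapper` encodes, written with plain string handling rather than vLLM's actual `WeightsMapper`/`AutoWeightsLoader` machinery; the checkpoint key in the usage line is hypothetical.

```python
# Illustrative only: approximate the prefix/suffix renaming that
# hf_to_vllm_mapper encodes, using plain string handling.
ORIG_TO_NEW_PREFIX = {
    "language_model.model": "language_model",
    "language_model.lm_head": "lm_head",
}
ORIG_TO_NEW_SUFFIX = {
    "experts.fc1.weight": "experts.w1",
    "experts.fc2.weight": "experts.w2",
    "router.weight": "experts.router_weight",
}

def rename(hf_key: str) -> str:
    # Rewrite the leading module path first, then the trailing parameter name.
    for old, new in ORIG_TO_NEW_PREFIX.items():
        if hf_key.startswith(old):
            hf_key = new + hf_key[len(old):]
            break
    for old, new in ORIG_TO_NEW_SUFFIX.items():
        if hf_key.endswith(old):
            hf_key = hf_key[: -len(old)] + new
            break
    return hf_key

# Hypothetical checkpoint key, for illustration:
print(rename("language_model.model.layers.0.mlp.experts.fc1.weight"))
# -> language_model.layers.0.mlp.experts.w1
```

With the names rewritten this way, `AutoWeightsLoader` can dispatch each tensor to the `weight_loader` attached to the target parameter (falling back to the default loader), which mirrors what the removed hand-written loop did explicitly.
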