
Commit

Merge pull request #77 from rhymes-ai/upgrade
upgrade vllm to the latest version
xffxff authored Nov 20, 2024
2 parents 947ec22 + 9d2850f commit 9c5139f
Showing 4 changed files with 26 additions and 20 deletions.
6 changes: 5 additions & 1 deletion aria/model/configuration_aria.py
@@ -73,7 +73,11 @@ def __init__(
         self.image_token_index = image_token_index
 
         attn_implementation = kwargs.pop("attn_implementation", None)
-        self._attn_implementation = attn_implementation
+
+        # Set the default attention implementation to flash_attention_2 if not specified
+        self._attn_implementation = (
+            "flash_attention_2" if attn_implementation is None else attn_implementation
+        )
 
         # Convert the keys and values of projector_patch_to_query_dict to integers
         # This ensures consistency even if they were provided as strings
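The hunk above makes `flash_attention_2` the default whenever the caller does not pass `attn_implementation`. A standalone sketch of that kwargs-fallback pattern, using an invented helper name rather than the project's config class:

```python
# Standalone sketch of the fallback introduced above; resolve_attn_implementation
# is an invented helper name, not part of the Aria codebase.
def resolve_attn_implementation(**kwargs) -> str:
    attn_implementation = kwargs.pop("attn_implementation", None)
    # Default to flash_attention_2 when nothing was specified explicitly.
    return "flash_attention_2" if attn_implementation is None else attn_implementation


assert resolve_attn_implementation() == "flash_attention_2"
assert resolve_attn_implementation(attn_implementation="eager") == "eager"
```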
25 changes: 16 additions & 9 deletions aria/vllm/aria.py
@@ -26,14 +26,14 @@
 from transformers import LlamaConfig
 from transformers.utils import logging
 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
+from vllm.config import CacheConfig, LoRAConfig, VllmConfig
 from vllm.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.inputs import INPUT_REGISTRY, LLMInputs
+from vllm.inputs import INPUT_REGISTRY, token_inputs
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -337,6 +337,10 @@ def __init__(
         prefix: str = "",
     ) -> None:
         nn.Module.__init__(self)
+
+        # FIXME(zhoufan): this is a hack to avoid the error: AttributeError: 'AriaMoELMModel' object has no attribute 'do_not_compile'.
+        self.do_not_compile = True
+
         self.config = config
         self.padding_idx = config.pad_token_id
         lora_vocab = (
@@ -679,18 +683,19 @@ def input_processor(ctx, llm_inputs):
     # TODO: Supports dynamic image size support
     setattr(model_config.multimodal_config, "max_image_size", max(max_image_size))
 
-    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
+    new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
         tokenizer,
         llm_inputs.get("prompt"),
         llm_inputs["prompt_token_ids"],
         placeholder_token_id=hf_config.image_token_index,
         repeat_count=image_feature_sizes,
     )
 
-    return LLMInputs(
-        prompt=new_prompt,
+    return token_inputs(
         prompt_token_ids=new_token_ids,
+        prompt=new_prompt,
         multi_modal_data=multi_modal_data,
+        # multi_modal_placeholders={"image": ranges},
     )
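The return statement above swaps vLLM's old `LLMInputs` dataclass for the `token_inputs(...)` helper. A minimal sketch of that return shape in isolation, with toy parameters standing in for Aria's real image-token expansion; the keyword names mirror the diff, and the helper's wider signature is an assumption about this vLLM version rather than a documented contract:

```python
# Minimal sketch: returning token_inputs from a vLLM input processor,
# mirroring the keyword names used in the diff above. Values are passed
# through unchanged; the real Aria processor first expands the image
# placeholder tokens with repeat_and_pad_placeholder_tokens.
from vllm.inputs import token_inputs


def toy_input_processor(prompt: str, prompt_token_ids: list[int], multi_modal_data: dict):
    return token_inputs(
        prompt_token_ids=prompt_token_ids,
        prompt=prompt,
        multi_modal_data=multi_modal_data,
    )
```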


@@ -708,12 +713,14 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(
         self,
-        config: AriaConfig,
-        multimodal_config: MultiModalConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
     ):
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
 
         # prepare the image_size to tokens mapping for the image preprocess, see input_processor
         setattr(
             config,
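The constructor now follows the newer vLLM convention of receiving a single `VllmConfig` and unpacking the sub-configs it needs. A compact sketch of that pattern with a toy model class; the class name and stored attributes are illustrative, and only the unpacking mirrors the diff:

```python
# Sketch of the VllmConfig-style constructor pattern used above.
# "ToyMultiModalModel" is an invented name; the attribute accesses
# (model_config.hf_config, cache_config, quant_config) mirror the diff.
import torch.nn as nn
from vllm.config import VllmConfig


class ToyMultiModalModel(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        # The sub-configs that the old signature received as separate arguments
        # are now pulled out of the single VllmConfig object.
        self.config = vllm_config.model_config.hf_config
        self.cache_config = vllm_config.cache_config
        self.quant_config = vllm_config.quant_config
```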
7 changes: 1 addition & 6 deletions docs/inference.md
@@ -93,17 +93,12 @@ pip install -e .[vllm]
 from PIL import Image
 from transformers import AutoTokenizer
 from vllm import LLM, ModelRegistry, SamplingParams
-from vllm.model_executor.models import _MULTIMODAL_MODELS
 
 from aria.vllm.aria import AriaForConditionalGeneration
 
 ModelRegistry.register_model(
     "AriaForConditionalGeneration", AriaForConditionalGeneration
 )
-_MULTIMODAL_MODELS["AriaForConditionalGeneration"] = (
-    "aria",
-    "AriaForConditionalGeneration",
-)
 
 def main():
@@ -147,7 +142,7 @@ def main():
                 Image.open("assets/princess2.jpg"),
             ],
             "max_image_size": 980, # [Optional] The max image patch size, default `980`
-            "split_image": True, # [Optional] whether to split the images, default `False`
+            "split_image": False, # [Optional] whether to split the images, default `False`
         },
     },
     sampling_params=SamplingParams(max_tokens=200, top_k=1, stop=["<|im_end|>"]),
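Taken together, the docs/inference.md changes mean that registering the out-of-tree model through `ModelRegistry.register_model` alone is now sufficient; the private `_MULTIMODAL_MODELS` table no longer needs to be patched. A small sketch of the simplified registration with a sanity check; `get_supported_archs()` is an assumption about this vLLM version's `ModelRegistry` API:

```python
# Sketch: register Aria as an out-of-tree vLLM model and confirm the
# architecture is visible to the registry. get_supported_archs() is an
# assumption about this vLLM version's ModelRegistry API.
from vllm import ModelRegistry

from aria.vllm.aria import AriaForConditionalGeneration

ModelRegistry.register_model(
    "AriaForConditionalGeneration", AriaForConditionalGeneration
)

assert "AriaForConditionalGeneration" in ModelRegistry.get_supported_archs()
```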
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -7,13 +7,13 @@ authors = [
 readme = "README.md"
 requires-python = ">=3.10"
 dependencies = [
-    "torch==2.4.0",
-    "torchvision==0.19.0",
+    "torch==2.5.1",
+    "torchvision==0.20.1",
     "accelerate==0.34.1",
     "deepspeed==0.15.0",
     "peft==0.12.0",
     "sentencepiece==0.2.0",
-    "transformers==4.45.0",
+    "transformers==4.46.3",
     "trl==0.9.6",
     "pillow==10.4.0",
     "wandb==0.18.1",
@@ -32,7 +32,7 @@ dev = [
     "pytest==8.3.3",
 ]
 vllm = [
-    "vllm==0.6.2"
+    "vllm==0.6.4.post1"
 ]
 grouped_gemm = [
     "grouped_gemm==0.1.6"
