diff --git a/aria/vllm/aria.py b/aria/vllm/aria.py index 0b2be5a..f6abe09 100644 --- a/aria/vllm/aria.py +++ b/aria/vllm/aria.py @@ -28,7 +28,6 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.distributed import ( - get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, @@ -39,24 +38,20 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput, SamplingMetadata -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, - ParallelLMHead, - VocabParallelEmbedding, -) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.llama import ( LlamaAttention, LlamaDecoderLayer, - LlamaForCausalLM, LlamaMLP, LlamaModel, RMSNorm, ) from vllm.model_executor.models.utils import ( - PPMissingLayer, + AutoWeightsLoader, + WeightsMapper, make_layers, + maybe_prefix, merge_multimodal_embeddings, ) from vllm.model_executor.utils import set_weight_attrs @@ -70,17 +65,11 @@ from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of -from aria.model.configuration_aria import AriaConfig -from aria.model.projector import AriaProjector -from aria.model.vision_encoder import AriaVisionModel +from .projector import AriaProjector +from .vision_encoder import AriaVisionModel logger = logging.get_logger(__name__) -_KEYS_TO_MODIFY_MAPPING = { - "language_model.lm_head": "lm_head", - "language_model.model": "language_model", -} - class AriaMoELMConfig(LlamaConfig): """ @@ -156,36 +145,37 @@ def __init__(self, config: AriaMoELMConfig): ) ) ) - set_weight_attrs(self.router_weight, {"weight_loader": self.weight_loader}) - set_weight_attrs(self.w1, {"weight_loader": self.weight_loader}) - set_weight_attrs(self.w2, {"weight_loader": self.weight_loader}) + set_weight_attrs( + self.router_weight, {"weight_loader": self._weight_loader_for_router} + ) + set_weight_attrs(self.w1, {"weight_loader": self._weight_loader_for_w1}) + set_weight_attrs(self.w2, {"weight_loader": self._weight_loader_for_w2}) - def weight_loader( - self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str + def _weight_loader_for_router( + self, param: nn.Parameter, loaded_weight: torch.Tensor ): - if shard_id == "router": - param.data.copy_(loaded_weight) - elif shard_id == "w1": - if self.tp_size > 1: - # the shape of loaded_weight is (num_experts, hidden_size, 2 * moe_intermediate_size) - up, gate = loaded_weight.chunk(2, dim=-1) - up_current_rank = up.chunk(self.tp_size, dim=-1)[self.tp_rank] - gate_current_rank = gate.chunk(self.tp_size, dim=-1)[self.tp_rank] - up_and_gate = torch.cat( - [up_current_rank, gate_current_rank], dim=-1 - ).transpose(1, 2) - param.data.copy_(up_and_gate) - else: - param.data.copy_(loaded_weight.transpose(1, 2)) + param.data.copy_(loaded_weight) + + def _weight_loader_for_w1(self, param: nn.Parameter, loaded_weight: torch.Tensor): + # the shape of loaded_weight is (num_experts, hidden_size, 2 * moe_intermediate_size) + if self.tp_size > 1: + up, gate = loaded_weight.chunk(2, dim=-1) + up_current_rank = up.chunk(self.tp_size, dim=-1)[self.tp_rank] + gate_current_rank = gate.chunk(self.tp_size, dim=-1)[self.tp_rank] + up_and_gate = torch.cat( + [up_current_rank, gate_current_rank], dim=-1 + ).transpose(1, 2) + param.data.copy_(up_and_gate) else: - if self.tp_size > 1: - # the shape of loaded_weight is (num_experts, moe_intermediate_size, hidden_size) - down_current_rank = loaded_weight.chunk(self.tp_size, dim=1)[ - self.tp_rank - ] - param.data.copy_(down_current_rank.transpose(1, 2)) - else: - param.data.copy_(loaded_weight.transpose(1, 2)) + param.data.copy_(loaded_weight.transpose(1, 2)) + + def _weight_loader_for_w2(self, param: nn.Parameter, loaded_weight: torch.Tensor): + # the shape of loaded_weight is (num_experts, moe_intermediate_size, hidden_size) + if self.tp_size > 1: + down_current_rank = loaded_weight.chunk(self.tp_size, dim=1)[self.tp_rank] + param.data.copy_(down_current_rank.transpose(1, 2)) + else: + param.data.copy_(loaded_weight.transpose(1, 2)) def forward(self, hidden_states): router_output = torch.nn.functional.linear(hidden_states, self.router_weight) @@ -328,39 +318,18 @@ class AriaMoELMModel(LlamaModel): config (LlamaConfig): Configuration object for the model. """ - def __init__( - self, - config: LlamaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prefix: str = "", - ) -> None: - nn.Module.__init__(self) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config # FIXME(zhoufan): this is a hack to avoid the error: AttributeError: 'AriaMoELMModel' object has no attribute 'do_not_compile'. self.do_not_compile = True - self.config = config - self.padding_idx = config.pad_token_id - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size - if get_pp_group().is_first_rank or ( - config.tie_word_embeddings and get_pp_group().is_last_rank - ): - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - quant_config=quant_config, - ) - else: - self.embed_tokens = PPMissingLayer() + self.layers = None + self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: MoEDecoderLayer( @@ -371,112 +340,9 @@ def __init__( ), prefix=f"{prefix}.layers", ) - if get_pp_group().is_last_rank: - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - else: - self.norm = PPMissingLayer() - - -class AriaMoELMForCausalLM(LlamaForCausalLM): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - "embed_tokens", - "lm_head", - ] - embedding_modules = { - "embed_tokens": "input_embeddings", - "lm_head": "output_embeddings", - } - embedding_padding_modules = ["lm_head"] - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - "gate_proj": ("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), - } - # Mistral/Llama models can also be loaded with --load-format mistral - # from consolidated.safetensors checkpoints - mistral_mapping = { - "layers": "model.layers", - "attention": "self_attn", - "wq": "q_proj", - "wk": "k_proj", - "wv": "v_proj", - "wo": "o_proj", - "attention_norm": "input_layernorm", - "feed_forward": "mlp", - "w1": "gate_proj", - "w2": "down_proj", - "w3": "up_proj", - "ffn_norm": "post_attention_layernorm", - "tok_embeddings": "model.embed_tokens", - "output": "lm_head", - "norm": "model.norm", - } - - def __init__( - self, - config: LlamaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - ) -> None: - nn.Module.__init__(self) - - self.config = config - self.lora_config = lora_config - self.model = AriaMoELMModel( - config, cache_config, quant_config, lora_config=lora_config, prefix="model" - ) - if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size - ), - quant_config=quant_config, - ) - if config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - - logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale - ) - self.sampler = Sampler() - else: - self.lm_head = PPMissingLayer() - -def build_mm_projector(config: AriaConfig): +def build_mm_projector(config): """ Builds and returns an AriaProjector instance based on the provided configuration. @@ -699,7 +565,6 @@ def input_processor(ctx, llm_inputs): ) -# adapted from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_multimodal_tokens) @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_aria) @INPUT_REGISTRY.register_input_processor(input_processor) @@ -718,7 +583,6 @@ def __init__( ): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config # prepare the image_size to tokens mapping for the image preprocess, see input_processor @@ -735,7 +599,8 @@ def __init__( self.multi_modal_projector = build_mm_projector(config) self.vocab_size = config.text_config.vocab_size self.language_model = AriaMoELMModel( - config.text_config, cache_config, quant_config + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=maybe_prefix(prefix, "language_model.model"), ) self.pad_token_id = ( self.config.pad_token_id if self.config.pad_token_id is not None else -1 @@ -773,11 +638,10 @@ def forward( torch.bfloat16 ) pixel_mask = pixel_mask.view(-1, *pixel_mask.shape[-2:]) - image_outputs, image_attn_mask = self.vision_tower( + selected_image_feature, image_attn_mask = self.vision_tower( pixel_values, pixel_mask=pixel_mask, ) - selected_image_feature = image_outputs.last_hidden_state image_features = self.multi_modal_projector( selected_image_feature, attn_mask=image_attn_mask @@ -814,37 +678,17 @@ def sample( return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # only doing this for language model part for now. - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ("experts.router_weight", "router.weight", "router"), - ("experts.w1", "experts.fc1.weight", "w1"), - ("experts.w2", "experts.fc2.weight", "w2"), - ] - params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - shard_id = None - # Because we used the origin hf vit and vision projector, we cound keep the weight in the sharded shape. - # Only for the language model part needs to adjust the weight loading. - if "language_model" in name: - for param_name, weight_name, _shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - shard_id = _shard_id - break - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - if shard_id is not None: - weight_loader(param, loaded_weight, shard_id) - else: - weight_loader(param, loaded_weight) + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "language_model.model": "language_model", + "language_model.lm_head": "lm_head", + }, + orig_to_new_suffix={ + "experts.fc1.weight": "experts.w1", + "experts.fc2.weight": "experts.w2", + "router.weight": "experts.router_weight", + }, + ) + + loader = AutoWeightsLoader(self) + loader.load_weights(weights, mapper=hf_to_vllm_mapper) diff --git a/aria/vllm/projector.py b/aria/vllm/projector.py new file mode 100644 index 0000000..3b14525 --- /dev/null +++ b/aria/vllm/projector.py @@ -0,0 +1,170 @@ +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from transformers.activations import ACT2FN + + +class FFN(nn.Module): + """ + Feed-Forward Network module. + + Args: + embed_dim (int): Input embedding dimension. + ff_dim (int): Hidden dimension of the feed-forward network. + output_dim (int): Output dimension. + """ + + def __init__(self, embed_dim, ff_dim, output_dim): + super().__init__() + self.linear_in = nn.Linear(embed_dim, ff_dim, bias=False) + self.linear_out = nn.Linear(ff_dim, output_dim, bias=False) + self.act = ACT2FN["gelu_new"] + + def forward(self, hidden_states): + hidden_states = self.act(self.linear_in(hidden_states)) + hidden_states = self.linear_out(hidden_states) + return hidden_states + + +class CrossAttention(nn.Module): + """ + Cross-Attention module. + + Args: + kv_dim (int): Dimension of key and value. + embed_dim (int): Embedding dimension. + num_heads (int): Number of attention heads. + drop_out_rate (float): Dropout rate. Default is 0. + """ + + def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0): + super().__init__() + self.num_heads = num_heads + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.k_proj = nn.Linear(kv_dim, embed_dim, bias=False) + self.v_proj = nn.Linear(kv_dim, embed_dim, bias=False) + + self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + self.linear = nn.Linear(embed_dim, embed_dim) + self.dropout = nn.Dropout(drop_out_rate) + + self.layer_norm = nn.LayerNorm(embed_dim) + self.ln_kv = nn.LayerNorm(kv_dim) + + def forward(self, x, hidden_states, attn_mask=None, add_residual=False): + """ + Forward pass of the CrossAttention module. + + Args: + x (torch.Tensor): Input tensor for key and value. + hidden_states (torch.Tensor): Input tensor for query. + attn_mask (torch.Tensor, optional): Attention mask. Default is None. + add_residual (bool): Whether to add residual connection. Default is False. + + Returns: + torch.Tensor: Output tensor after cross-attention. + """ + normed_hidden_states = self.layer_norm(hidden_states) + query = self.q_proj(normed_hidden_states).permute(1, 0, 2) + + x = self.ln_kv(x) + key = self.k_proj(x).permute(1, 0, 2) + value = self.v_proj(x).permute(1, 0, 2) + + attn_output, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask) + + attn_output = attn_output.permute(1, 0, 2) + + if add_residual: + attn_output = hidden_states + self.dropout(self.linear(attn_output)) + else: + attn_output = self.dropout(self.linear(attn_output)) + + return attn_output + + +class AriaProjector(nn.Module): + """ + A projection module with one cross attention layer and one FFN layer, which projects ViT's outputs into MoE's inputs. + + Args: + patch_to_query_dict (dict): Maps patch numbers to their corresponding query numbers, + e.g., {1225: 128, 4900: 256}. This allows for different query sizes based on image resolution. + embed_dim (int): Embedding dimension. + num_heads (int): Number of attention heads. + kv_dim (int): Dimension of key and value. + ff_dim (int): Hidden dimension of the feed-forward network. + output_dim (int): Output dimension. + norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm. + + Outputs: + A tensor with the shape of (batch_size, query_number, output_dim) + """ + + def __init__( + self, + patch_to_query_dict, + embed_dim, + num_heads, + kv_dim, + ff_dim, + output_dim, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.patch_to_query_dict = patch_to_query_dict + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.query = nn.Parameter( + torch.zeros(max(patch_to_query_dict.values()), self.embed_dim) + ) + + trunc_normal_(self.query, std=0.02) + + self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads) + + self.ln_ffn = norm_layer(embed_dim) + self.ffn = FFN(embed_dim, ff_dim, output_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x, attn_mask=None): + """ + Forward pass of the Projector module. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, num_patches, kv_dim). + attn_mask (torch.Tensor, optional): Attention mask. Default is None. + + Returns: + torch.Tensor: Output tensor of shape (batch_size, query_number, output_dim). + """ + bs = x.shape[0] + queries = self.query.unsqueeze(0).repeat(bs, 1, 1) + + query_num = self.patch_to_query_dict.get(x.shape[1], None) + assert ( + query_num is not None + ), f"Query number for {x.shape[1]} patches is not provided" + + queries = queries[:, :query_num, :] + + if attn_mask is not None: + attn_mask = attn_mask.repeat_interleave(self.num_heads, 0) + attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1) + + attention_out = self.cross_attn(x, queries, attn_mask=attn_mask) + + out = self.ffn(self.ln_ffn(attention_out)) + + return out diff --git a/aria/vllm/vision_encoder.py b/aria/vllm/vision_encoder.py new file mode 100644 index 0000000..9423027 --- /dev/null +++ b/aria/vllm/vision_encoder.py @@ -0,0 +1,94 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from transformers.models.idefics2.configuration_idefics2 import Idefics2VisionConfig +from vllm.config import QuantizationConfig +from vllm.model_executor.models.idefics2_vision_model import Idefics2VisionTransformer + + +class AriaVisionConfig(Idefics2VisionConfig): + model_type = "aria_vision_model" + + +class IdentityOp(torch.nn.Module): + """ + An identity operation that returns the input unchanged. + + This can be used as a placeholder or to maintain architectural consistency + when a specific operation is not needed. + """ + + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + +class AriaVisionTransformer(Idefics2VisionTransformer): + def __init__( + self, + config: AriaVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, quant_config, prefix) + self.post_layernorm = IdentityOp() + + +class AriaVisionModel(nn.Module): + config_class = AriaVisionConfig + + def __init__( + self, + config: AriaVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + prefix: str = "", + ) -> None: + super().__init__() + + self.vision_model = AriaVisionTransformer( + config, + quant_config, + prefix=f"{prefix}.vision_model", + ) + + def forward( + self, + pixel_values: torch.Tensor, + pixel_mask: Optional[torch.BoolTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.BoolTensor]]: + patch_attention_mask = self._create_patch_attention_mask(pixel_mask) + + vit_oup = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + image_atts = self._create_image_attention_mask(patch_attention_mask) + + return vit_oup, image_atts + + def _create_patch_attention_mask(self, pixel_mask): + if pixel_mask is None: + return None + + patches_subgrid = pixel_mask.unfold( + dimension=1, + size=self.vision_model.config.patch_size, + step=self.vision_model.config.patch_size, + ).unfold( + dimension=2, + size=self.vision_model.config.patch_size, + step=self.vision_model.config.patch_size, + ) + return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + def _create_image_attention_mask(self, patch_attention_mask): + if patch_attention_mask is None: + return None + + flattened_mask = patch_attention_mask.flatten(1) + return torch.logical_not(flattened_mask)