-
Notifications
You must be signed in to change notification settings - Fork 225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
enable llava on torchchat #1183
base: main
Are you sure you want to change the base?
Changes from all commits
f52007e
32d969e
9e4350d
72d7b96
dfe37b8
1834696
8ecc2fa
a70d7b5
937e7ed
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -7,10 +7,13 @@ | |||||||||
import os | ||||||||||
import re | ||||||||||
import sys | ||||||||||
import glob | ||||||||||
from pathlib import Path | ||||||||||
from typing import Optional | ||||||||||
from typing import Any, Dict, Optional | ||||||||||
|
||||||||||
import torch | ||||||||||
import safetensors.torch | ||||||||||
import shutil | ||||||||||
|
||||||||||
from torchchat.model import TransformerArgs | ||||||||||
|
||||||||||
|
@@ -21,9 +24,176 @@ | |||||||||
|
||||||||||
from torchchat.model import ModelArgs | ||||||||||
|
||||||||||
def remap_llava_checkpoint(llava_ckpt): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Was this written inhouse? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not pretty following your question. |
||||||||||
def _translate_state_dict_for_vision_model(hf_state_dict) -> Dict[str, Any]: | ||||||||||
translated_state_dict = {} | ||||||||||
hf_weight_prefix = "vision_model." | ||||||||||
name_mapping = { | ||||||||||
f"{hf_weight_prefix}embeddings.class_embedding": "encoder.cls_token_embedding.weight", | ||||||||||
f"{hf_weight_prefix}embeddings.position_embedding.weight": "encoder.token_pos_embedding.positional_embedding", | ||||||||||
f"{hf_weight_prefix}embeddings.patch_embedding.weight": "encoder.conv.weight", | ||||||||||
f"{hf_weight_prefix}pre_layrnorm.weight": "encoder.ln_pre.weight", | ||||||||||
f"{hf_weight_prefix}pre_layrnorm.bias": "encoder.ln_pre.bias", | ||||||||||
f"{hf_weight_prefix}post_layernorm.weight": "encoder.ln_post.weight", | ||||||||||
f"{hf_weight_prefix}post_layernorm.bias": "encoder.ln_post.bias", | ||||||||||
} | ||||||||||
patterns = [ | ||||||||||
( | ||||||||||
rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.self_attn\.(k|q|v)_proj\.(weight|bias)", | ||||||||||
lambda match: f"encoder.layers.{match.group(1)}.attn.{match.group(2)}_proj.{match.group(3)}", | ||||||||||
), | ||||||||||
( | ||||||||||
rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.(weight|bias)", | ||||||||||
lambda match: f"encoder.layers.{match.group(1)}.attn.output_proj.{match.group(2)}", | ||||||||||
), | ||||||||||
( | ||||||||||
rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.mlp\.fc(1|2)\.(weight|bias)", | ||||||||||
lambda match: f"encoder.layers.{match.group(1)}.mlp.w{match.group(2)}.{match.group(3)}", | ||||||||||
), | ||||||||||
( | ||||||||||
rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.layer_norm1\.(weight|bias)", | ||||||||||
lambda match: f"encoder.layers.{match.group(1)}.sa_norm.{match.group(2)}", | ||||||||||
), | ||||||||||
( | ||||||||||
rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.layer_norm2\.(weight|bias)", | ||||||||||
lambda match: f"encoder.layers.{match.group(1)}.mlp_norm.{match.group(2)}", | ||||||||||
), | ||||||||||
] | ||||||||||
for pattern, replacement in patterns: | ||||||||||
for key in list(hf_state_dict.keys()): | ||||||||||
if re.match(pattern, key): | ||||||||||
new_key = re.sub(pattern, replacement, key) | ||||||||||
name_mapping[key] = new_key | ||||||||||
temp_state_dict = {} | ||||||||||
for k, v in hf_state_dict.items(): | ||||||||||
new_k = name_mapping.get(k, k) | ||||||||||
if "in_proj_weight" in new_k or "in_proj_bias" in new_k: | ||||||||||
if new_k not in temp_state_dict: | ||||||||||
temp_state_dict[new_k] = {"q": None, "k": None, "v": None} | ||||||||||
if "q_proj" in k: | ||||||||||
temp_state_dict[new_k]["q"] = v | ||||||||||
elif "k_proj" in k: | ||||||||||
temp_state_dict[new_k]["k"] = v | ||||||||||
elif "v_proj" in k: | ||||||||||
temp_state_dict[new_k]["v"] = v | ||||||||||
else: | ||||||||||
temp_state_dict[new_k] = v | ||||||||||
for k, v in temp_state_dict.items(): | ||||||||||
if isinstance(v, dict): | ||||||||||
translated_state_dict[k] = torch.cat([v["q"], v["k"], v["v"]], dim=0) | ||||||||||
else: | ||||||||||
translated_state_dict[k] = v | ||||||||||
return translated_state_dict | ||||||||||
|
||||||||||
def _translate_state_dict_for_text_model(hf_state_dict) -> Dict[str, Any]: | ||||||||||
key_map = { | ||||||||||
r"model.layers.([0-9]+).self_attn.q_proj.": r"decoder.layers.\1.attention.wq.", | ||||||||||
r"model.layers.([0-9]+).self_attn.k_proj.": r"decoder.layers.\1.attention.wk.", | ||||||||||
r"model.layers.([0-9]+).self_attn.v_proj.": r"decoder.layers.\1.attention.wv.", | ||||||||||
r"model.layers.([0-9]+).self_attn.o_proj.": r"decoder.layers.\1.attention.wo.", | ||||||||||
r"model.layers.([0-9]+).input_layernorm.": r"decoder.layers.\1.attention_norm.", | ||||||||||
r"model.layers.([0-9]+).mlp.gate_proj.": r"decoder.layers.\1.feed_forward.w1.", | ||||||||||
r"model.layers.([0-9]+).mlp.down_proj.": r"decoder.layers.\1.feed_forward.w2.", | ||||||||||
r"model.layers.([0-9]+).mlp.up_proj.": r"decoder.layers.\1.feed_forward.w3.", | ||||||||||
r"model.layers.([0-9]+).post_attention_layernorm.": r"decoder.layers.\1.ffn_norm.", | ||||||||||
r"model.norm.": r"decoder.norm.", | ||||||||||
# r"model.embed_tokens.": r"tok_embeddings.", # load separately | ||||||||||
r"lm_head.": r"decoder.output.", | ||||||||||
} | ||||||||||
new_state_dict = {} | ||||||||||
def get_new_key(old_key: str) -> str: | ||||||||||
for old_pattern, replacement in key_map.items(): | ||||||||||
if (new_key := re.sub(old_pattern, replacement, old_key)) != old_key: | ||||||||||
return new_key | ||||||||||
return old_key | ||||||||||
for old_key in hf_state_dict.keys(): | ||||||||||
new_key = get_new_key(old_key) | ||||||||||
new_state_dict[new_key] = hf_state_dict[old_key] | ||||||||||
return new_state_dict | ||||||||||
|
||||||||||
def _translate_state_dict_for_mm_projector_model(hf_state_dict) -> Dict[str, Any]: | ||||||||||
new_state_dict = {} | ||||||||||
for old_key in hf_state_dict.keys(): | ||||||||||
new_key = "mm_projector." + old_key | ||||||||||
new_state_dict[new_key] = hf_state_dict[old_key] | ||||||||||
return new_state_dict | ||||||||||
|
||||||||||
def split_checkpoint(llava_ckpt): | ||||||||||
language_model_ckpt = {} | ||||||||||
multi_modal_ckpt = {} | ||||||||||
vision_tower_ckpt = {} | ||||||||||
for key, value in llava_ckpt.items(): | ||||||||||
if key.startswith("language_model"): | ||||||||||
language_model_ckpt[key[len("language_model") + 1:]] = value | ||||||||||
elif key.startswith("multi_modal_projector"): | ||||||||||
multi_modal_ckpt[key[len("multi_modal_projector") + 1:]] = value | ||||||||||
elif key.startswith("vision_tower"): | ||||||||||
vision_tower_ckpt[key[len("vision_tower") + 1:]] = value | ||||||||||
return language_model_ckpt, multi_modal_ckpt, vision_tower_ckpt | ||||||||||
language_model_ckpt, multi_modal_ckpt, vision_tower_ckpt = split_checkpoint(llava_ckpt) | ||||||||||
remapped_state_dict = { | ||||||||||
"tok_embeddings.weight": language_model_ckpt.pop("model.embed_tokens.weight"), | ||||||||||
} | ||||||||||
remapped_state_dict.update(_translate_state_dict_for_text_model(language_model_ckpt)) | ||||||||||
remapped_state_dict.update(_translate_state_dict_for_vision_model(vision_tower_ckpt)) | ||||||||||
remapped_state_dict.update(_translate_state_dict_for_mm_projector_model(multi_modal_ckpt)) | ||||||||||
return remapped_state_dict | ||||||||||
|
||||||||||
|
||||||||||
@torch.inference_mode | ||||||||||
def convert_llava_checkpoint( | ||||||||||
*, | ||||||||||
model_dir: Optional[Path] = None, | ||||||||||
) -> None: | ||||||||||
|
||||||||||
""" | ||||||||||
Process safetensor files from a specific directory structure and save the remapped model. | ||||||||||
|
||||||||||
Args: | ||||||||||
model_dir (str): Base directory containing the model subdirectories. | ||||||||||
""" | ||||||||||
|
||||||||||
def _get_llava_files_with_pattern(pattern): | ||||||||||
pattern = os.path.join(model_dir, f"models--llava-hf--llava-1.5-7b-hf/snapshots/*/{pattern}") | ||||||||||
return glob.glob(pattern) | ||||||||||
|
||||||||||
# get all safetensor files in the model directory | ||||||||||
safetensor_files = _get_llava_files_with_pattern("*.safetensors") | ||||||||||
|
||||||||||
if not safetensor_files: | ||||||||||
raise ValueError("No safetensor files found.") | ||||||||||
|
||||||||||
merged_weights = {} | ||||||||||
|
||||||||||
# Merge safetensor files into a whole | ||||||||||
for file in safetensor_files: | ||||||||||
# Load weights from the current file | ||||||||||
part_weights = safetensors.torch.load_file(file) | ||||||||||
|
||||||||||
# Iterate over each weight in the current file | ||||||||||
for key, value in part_weights.items(): | ||||||||||
if key in merged_weights: | ||||||||||
# If the key already exists, concatenate tensors | ||||||||||
merged_weights[key] = torch.cat((merged_weights[key], value), dim=0) | ||||||||||
else: | ||||||||||
# If the key does not exist, add it to the dictionary | ||||||||||
merged_weights[key] = value | ||||||||||
|
||||||||||
# Remap the checkpoint and save it as pth | ||||||||||
remapped_weights = remap_llava_checkpoint(merged_weights) | ||||||||||
model_path = model_dir / "model.pth" | ||||||||||
torch.save(remapped_weights, model_path) | ||||||||||
|
||||||||||
# copy tokenizer | ||||||||||
tokenizer_files = _get_llava_files_with_pattern("tokenizer.model") | ||||||||||
assert len(tokenizer_files) == 1, "Should get only one tokenizer file, but got {}".format(tokenizer_files) | ||||||||||
|
||||||||||
tokenizer_path = model_dir / "tokenizer.model" | ||||||||||
shutil.copy(tokenizer_files[0], tokenizer_path) | ||||||||||
|
||||||||||
|
||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
@torch.inference_mode() | ||||||||||
def convert_hf_checkpoint( | ||||||||||
def convert_text_only_hf_checkpoint( | ||||||||||
*, | ||||||||||
model_dir: Optional[Path] = None, | ||||||||||
model_name: Optional[str] = None, | ||||||||||
|
@@ -132,6 +302,19 @@ def permute(w, n_heads): | |||||||||
os.remove(file) | ||||||||||
|
||||||||||
|
||||||||||
@torch.inference_mode() | ||||||||||
def convert_hf_checkpoint( | ||||||||||
*, | ||||||||||
model_dir: Optional[Path] = None, | ||||||||||
model_name: Optional[str] = None, | ||||||||||
remove_bin_files: bool = False, | ||||||||||
): | ||||||||||
if "llava" in model_name: | ||||||||||
convert_llava_checkpoint(model_dir=model_dir) | ||||||||||
else: | ||||||||||
convert_text_only_hf_checkpoint(model_dir=model_dir, model_name=model_name, remove_bin_files=remove_bin_files) | ||||||||||
|
||||||||||
|
||||||||||
if __name__ == "__main__": | ||||||||||
import argparse | ||||||||||
|
||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -36,6 +36,7 @@ | |||||
from torchchat.model import Model, ModelType | ||||||
from torchchat.utils.build_utils import device_sync, set_precision | ||||||
from torchchat.utils.device_info import get_device_info | ||||||
from torchchat.utils.preprocessors import llava_image_preprocess | ||||||
|
||||||
# torchtune model definition dependencies | ||||||
from torchtune.data import Message | ||||||
|
@@ -357,8 +358,13 @@ def prefill( | |||||
|
||||||
if batch is not None: | ||||||
# TODO: Verify sequential prefill works with multimodal models | ||||||
logits = model(**batch)[:, -1] | ||||||
return tune_sample(logits, 0, 500) | ||||||
logits = model(**batch) | ||||||
if model.config.model_type == ModelType.Llava: | ||||||
context_len, logits = logits[0], logits[1][:, -1] | ||||||
return context_len, tune_sample(logits, 0, 500) | ||||||
else: | ||||||
logits = logits[:, -1] | ||||||
return tune_sample(logits, 0, 500) | ||||||
elif sequential_prefill: | ||||||
for i in range(width): | ||||||
x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1) | ||||||
|
@@ -622,6 +628,13 @@ def generate( | |||||
sequential_prefill=sequential_prefill, | ||||||
**sampling_kwargs, | ||||||
) | ||||||
|
||||||
# For llava with image input, we need to extract next pos id from prefill result | ||||||
if batch and self.model.config.model_type == ModelType.Llava: | ||||||
context_len, next_token = next_token | ||||||
else: | ||||||
context_len, next_token = T, next_token | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
if is_speculative: | ||||||
self.prefill( | ||||||
draft_model, | ||||||
|
@@ -636,7 +649,7 @@ def generate( | |||||
# max_new_tokens <= 2 means we are effectively not calling decode_n_tokens(). | ||||||
callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2) | ||||||
|
||||||
input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int) | ||||||
input_pos = torch.tensor([start_pos + context_len], device=device, dtype=torch.int) | ||||||
accept_counts = [0] * ( | ||||||
speculate_k + 1 | ||||||
) # creates array of [0, 0, 0, ...] that is speculate_k + 1 long | ||||||
|
@@ -726,27 +739,56 @@ def chat( | |||||
|
||||||
if generator_args.image_prompts is not None: | ||||||
print("Image prompts", generator_args.image_prompts) | ||||||
|
||||||
# Support for just the first image prompt for now | ||||||
images = [Image.open(generator_args.image_prompts[0])] | ||||||
messages = [ | ||||||
Message( | ||||||
role="user", | ||||||
content=[ | ||||||
{"type": "image", "content": images[0]}, | ||||||
{"type": "text", "content": generator_args.prompt}, | ||||||
], | ||||||
eot=True, | ||||||
), | ||||||
Message(role="assistant", content=""), | ||||||
] | ||||||
|
||||||
transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path)) | ||||||
data = transform({"messages": messages}, inference=True) | ||||||
batch = padded_collate([data], self.builder_args.device) | ||||||
batch.pop("mask") | ||||||
encoded = batch["tokens"] | ||||||
assert len(images) == 1, "Only one image prompt is supported for now" | ||||||
|
||||||
#TODO: updated encoded variable for multi-modality models to include image tokens. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you explain this to me? |
||||||
if self.model.config.model_type == ModelType.Flamingo: | ||||||
messages = [ | ||||||
Message( | ||||||
role="user", | ||||||
content=[ | ||||||
{"type": "image", "content": images[0]}, | ||||||
{"type": "text", "content": generator_args.prompt}, | ||||||
], | ||||||
eot=True, | ||||||
), | ||||||
Message(role="assistant", content=""), | ||||||
] | ||||||
|
||||||
transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path)) | ||||||
data = transform({"messages": messages}, inference=True) | ||||||
batch = padded_collate([data], self.builder_args.device) | ||||||
batch.pop("mask") | ||||||
encoded = batch["tokens"] | ||||||
elif self.model.config.model_type == ModelType.Llava: | ||||||
#TODO: double check the tokenizer. | ||||||
def find_subtensor(tensor, target): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typehints |
||||||
target_len = len(target) | ||||||
for i in range(len(tensor) - target_len + 1): | ||||||
if torch.all(tensor[i:i+target_len] == target): | ||||||
return i | ||||||
return -1 | ||||||
|
||||||
input_ids = self.encode_tokens(generator_args.prompt, bos=True, device=self.builder_args.device) | ||||||
image_token_indices = self.encode_tokens("<image>", device=self.builder_args.device)[1:] | ||||||
index = find_subtensor(input_ids, image_token_indices) | ||||||
|
||||||
if index == -1: | ||||||
raise ValueError("Image token not found in prompt") | ||||||
|
||||||
batch = { | ||||||
"tokens": input_ids[:index].unsqueeze(0), | ||||||
"encoder_input": llava_image_preprocess(images[0], device=self.builder_args.device, dtype=self.builder_args.precision), | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I might be misunderstanding batch, but it looks like the batch variable isn't used? Especially encoder_input There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the whole batch will be forwarded into llava model, so everything here is used by llava's forward function |
||||||
"post_tokens": input_ids[index + len(image_token_indices) :].unsqueeze(0), | ||||||
} | ||||||
|
||||||
# can not get actual encoded image feature before model inference; pseudo one | ||||||
pseudo_vision_encoded = torch.zeros(1, 624).to(device=self.builder_args.device, dtype=self.builder_args.precision) | ||||||
encoded = torch.cat([batch["tokens"].view(1, -1), pseudo_vision_encoded, batch["post_tokens"].view(1, -1)], dim=-1).view(-1) | ||||||
|
||||||
else: | ||||||
encoded = self.encode_tokens( | ||||||
generator_args.prompt, bos=True, device=self.builder_args.device | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code comment blocks to help us move things around later