enable llava on torchchat #1183

Open
wants to merge 9 commits into base: main
187 changes: 185 additions & 2 deletions torchchat/cli/convert_hf_checkpoint.py
@@ -7,10 +7,13 @@
import os
import re
import sys
import glob
from pathlib import Path
from typing import Optional
from typing import Any, Dict, Optional

import torch
import safetensors.torch
import shutil

from torchchat.model import TransformerArgs

@@ -21,9 +24,176 @@

from torchchat.model import ModelArgs

Contributor suggestion (add a section comment):
"""
Llava Conversion Code
"""

Contributor comment: Code comment blocks to help us move things around later.

def remap_llava_checkpoint(llava_ckpt):
Contributor comment: Was this written in-house?

Author reply: I'm not quite following your question. This function is consumed by convert_llava_checkpoint to produce the remapped checkpoint; I made it a separate function to simplify the logic.

def _translate_state_dict_for_vision_model(hf_state_dict) -> Dict[str, Any]:
translated_state_dict = {}
hf_weight_prefix = "vision_model."
name_mapping = {
f"{hf_weight_prefix}embeddings.class_embedding": "encoder.cls_token_embedding.weight",
f"{hf_weight_prefix}embeddings.position_embedding.weight": "encoder.token_pos_embedding.positional_embedding",
f"{hf_weight_prefix}embeddings.patch_embedding.weight": "encoder.conv.weight",
f"{hf_weight_prefix}pre_layrnorm.weight": "encoder.ln_pre.weight",
f"{hf_weight_prefix}pre_layrnorm.bias": "encoder.ln_pre.bias",
f"{hf_weight_prefix}post_layernorm.weight": "encoder.ln_post.weight",
f"{hf_weight_prefix}post_layernorm.bias": "encoder.ln_post.bias",
}
patterns = [
(
rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.self_attn\.(k|q|v)_proj\.(weight|bias)",
lambda match: f"encoder.layers.{match.group(1)}.attn.{match.group(2)}_proj.{match.group(3)}",
),
(
rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.(weight|bias)",
lambda match: f"encoder.layers.{match.group(1)}.attn.output_proj.{match.group(2)}",
),
(
rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.mlp\.fc(1|2)\.(weight|bias)",
lambda match: f"encoder.layers.{match.group(1)}.mlp.w{match.group(2)}.{match.group(3)}",
),
(
rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.layer_norm1\.(weight|bias)",
lambda match: f"encoder.layers.{match.group(1)}.sa_norm.{match.group(2)}",
),
(
rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.layer_norm2\.(weight|bias)",
lambda match: f"encoder.layers.{match.group(1)}.mlp_norm.{match.group(2)}",
),
]
for pattern, replacement in patterns:
for key in list(hf_state_dict.keys()):
if re.match(pattern, key):
new_key = re.sub(pattern, replacement, key)
name_mapping[key] = new_key
temp_state_dict = {}
for k, v in hf_state_dict.items():
new_k = name_mapping.get(k, k)
if "in_proj_weight" in new_k or "in_proj_bias" in new_k:
if new_k not in temp_state_dict:
temp_state_dict[new_k] = {"q": None, "k": None, "v": None}
if "q_proj" in k:
temp_state_dict[new_k]["q"] = v
elif "k_proj" in k:
temp_state_dict[new_k]["k"] = v
elif "v_proj" in k:
temp_state_dict[new_k]["v"] = v
else:
temp_state_dict[new_k] = v
for k, v in temp_state_dict.items():
if isinstance(v, dict):
translated_state_dict[k] = torch.cat([v["q"], v["k"], v["v"]], dim=0)
else:
translated_state_dict[k] = v
return translated_state_dict

def _translate_state_dict_for_text_model(hf_state_dict) -> Dict[str, Any]:
key_map = {
r"model.layers.([0-9]+).self_attn.q_proj.": r"decoder.layers.\1.attention.wq.",
r"model.layers.([0-9]+).self_attn.k_proj.": r"decoder.layers.\1.attention.wk.",
r"model.layers.([0-9]+).self_attn.v_proj.": r"decoder.layers.\1.attention.wv.",
r"model.layers.([0-9]+).self_attn.o_proj.": r"decoder.layers.\1.attention.wo.",
r"model.layers.([0-9]+).input_layernorm.": r"decoder.layers.\1.attention_norm.",
r"model.layers.([0-9]+).mlp.gate_proj.": r"decoder.layers.\1.feed_forward.w1.",
r"model.layers.([0-9]+).mlp.down_proj.": r"decoder.layers.\1.feed_forward.w2.",
r"model.layers.([0-9]+).mlp.up_proj.": r"decoder.layers.\1.feed_forward.w3.",
r"model.layers.([0-9]+).post_attention_layernorm.": r"decoder.layers.\1.ffn_norm.",
r"model.norm.": r"decoder.norm.",
# r"model.embed_tokens.": r"tok_embeddings.", # load separately
r"lm_head.": r"decoder.output.",
}
new_state_dict = {}
def get_new_key(old_key: str) -> str:
for old_pattern, replacement in key_map.items():
if (new_key := re.sub(old_pattern, replacement, old_key)) != old_key:
return new_key
return old_key
for old_key in hf_state_dict.keys():
new_key = get_new_key(old_key)
new_state_dict[new_key] = hf_state_dict[old_key]
return new_state_dict

def _translate_state_dict_for_mm_projector_model(hf_state_dict) -> Dict[str, Any]:
new_state_dict = {}
for old_key in hf_state_dict.keys():
new_key = "mm_projector." + old_key
new_state_dict[new_key] = hf_state_dict[old_key]
return new_state_dict

def split_checkpoint(llava_ckpt):
language_model_ckpt = {}
multi_modal_ckpt = {}
vision_tower_ckpt = {}
for key, value in llava_ckpt.items():
if key.startswith("language_model"):
language_model_ckpt[key[len("language_model") + 1:]] = value
elif key.startswith("multi_modal_projector"):
multi_modal_ckpt[key[len("multi_modal_projector") + 1:]] = value
elif key.startswith("vision_tower"):
vision_tower_ckpt[key[len("vision_tower") + 1:]] = value
return language_model_ckpt, multi_modal_ckpt, vision_tower_ckpt
language_model_ckpt, multi_modal_ckpt, vision_tower_ckpt = split_checkpoint(llava_ckpt)
remapped_state_dict = {
"tok_embeddings.weight": language_model_ckpt.pop("model.embed_tokens.weight"),
}
remapped_state_dict.update(_translate_state_dict_for_text_model(language_model_ckpt))
remapped_state_dict.update(_translate_state_dict_for_vision_model(vision_tower_ckpt))
remapped_state_dict.update(_translate_state_dict_for_mm_projector_model(multi_modal_ckpt))
return remapped_state_dict
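
To make the remapping concrete, here is a short trace of one text-model weight through split_checkpoint and _translate_state_dict_for_text_model. The key names come straight from the mappings above; the snippet is an illustrative sketch, not part of the diff.

import re

hf_key = "language_model.model.layers.0.self_attn.q_proj.weight"

# split_checkpoint strips the "language_model." prefix ...
stripped = hf_key[len("language_model") + 1:]  # "model.layers.0.self_attn.q_proj.weight"

# ... and the text-model key_map rewrites it for torchchat's decoder.
new_key = re.sub(r"model.layers.([0-9]+).self_attn.q_proj.",
                 r"decoder.layers.\1.attention.wq.", stripped)
assert new_key == "decoder.layers.0.attention.wq.weight"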


@torch.inference_mode()
def convert_llava_checkpoint(
*,
model_dir: Optional[Path] = None,
) -> None:

"""
Process safetensor files from a specific directory structure and save the remapped model.

Args:
model_dir (str): Base directory containing the model subdirectories.
"""

def _get_llava_files_with_pattern(pattern):
pattern = os.path.join(model_dir, f"models--llava-hf--llava-1.5-7b-hf/snapshots/*/{pattern}")
return glob.glob(pattern)

# get all safetensor files in the model directory
safetensor_files = _get_llava_files_with_pattern("*.safetensors")

if not safetensor_files:
raise ValueError("No safetensor files found.")

merged_weights = {}

# Merge safetensor files into a whole
for file in safetensor_files:
# Load weights from the current file
part_weights = safetensors.torch.load_file(file)

# Iterate over each weight in the current file
for key, value in part_weights.items():
if key in merged_weights:
# If the key already exists, concatenate tensors
merged_weights[key] = torch.cat((merged_weights[key], value), dim=0)
else:
# If the key does not exist, add it to the dictionary
merged_weights[key] = value

# Remap the checkpoint and save it as pth
remapped_weights = remap_llava_checkpoint(merged_weights)
model_path = model_dir / "model.pth"
torch.save(remapped_weights, model_path)

# copy tokenizer
tokenizer_files = _get_llava_files_with_pattern("tokenizer.model")
assert len(tokenizer_files) == 1, "Should get only one tokenizer file, but got {}".format(tokenizer_files)

tokenizer_path = model_dir / "tokenizer.model"
shutil.copy(tokenizer_files[0], tokenizer_path)
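
After conversion, a quick sanity check on the produced checkpoint can confirm the remapped prefixes. This is a hedged sketch only; the model_dir path is hypothetical and the expected prefixes are taken from the remapping code above.

from pathlib import Path
import torch

model_dir = Path("checkpoints/llava-hf/llava-1.5-7b-hf")  # hypothetical location used above
state_dict = torch.load(model_dir / "model.pth")

# Keys should now use torchchat-style prefixes produced by remap_llava_checkpoint.
assert "tok_embeddings.weight" in state_dict
assert any(k.startswith("decoder.layers.") for k in state_dict)
assert any(k.startswith("encoder.") for k in state_dict)
assert any(k.startswith("mm_projector.") for k in state_dict)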


Contributor suggestion (add a section comment):
"""
Text-Only Conversion Code
"""

@torch.inference_mode()
def convert_hf_checkpoint(
def convert_text_only_hf_checkpoint(
*,
model_dir: Optional[Path] = None,
model_name: Optional[str] = None,
@@ -132,6 +302,19 @@ def permute(w, n_heads):
os.remove(file)


@torch.inference_mode()
def convert_hf_checkpoint(
*,
model_dir: Optional[Path] = None,
model_name: Optional[str] = None,
remove_bin_files: bool = False,
):
if "llava" in model_name:
convert_llava_checkpoint(model_dir=model_dir)
else:
convert_text_only_hf_checkpoint(model_dir=model_dir, model_name=model_name, remove_bin_files=remove_bin_files)
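
A short dispatch sketch for the new wrapper (illustrative only; the model_dir value is hypothetical):

from pathlib import Path
from torchchat.cli.convert_hf_checkpoint import convert_hf_checkpoint

# A model_name containing "llava" routes to convert_llava_checkpoint;
# anything else falls through to the original text-only conversion.
convert_hf_checkpoint(
    model_dir=Path("checkpoints/llava-hf/llava-1.5-7b-hf"),  # hypothetical path
    model_name="llava-1.5-7b-hf",
    remove_bin_files=False,  # only used on the text-only path in this diff
)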


if __name__ == "__main__":
import argparse

3 changes: 2 additions & 1 deletion torchchat/cli/download.py
@@ -33,8 +33,9 @@ def _download_hf_snapshot(
local_dir=artifact_dir,
local_dir_use_symlinks=False,
token=hf_token,
ignore_patterns="*safetensors*",
ignore_patterns=None if "llava" in model_config.name else "*safetensors*",
)

except HTTPError as e:
if e.response.status_code == 401: # Missing HuggingFace CLI login.
print(
82 changes: 62 additions & 20 deletions torchchat/generate.py
@@ -36,6 +36,7 @@
from torchchat.model import Model, ModelType
from torchchat.utils.build_utils import device_sync, set_precision
from torchchat.utils.device_info import get_device_info
from torchchat.utils.preprocessors import llava_image_preprocess

# torchtune model definition dependencies
from torchtune.data import Message
@@ -357,8 +358,13 @@ def prefill(

if batch is not None:
# TODO: Verify sequential prefill works with multimodal models
logits = model(**batch)[:, -1]
return tune_sample(logits, 0, 500)
logits = model(**batch)
if model.config.model_type == ModelType.Llava:
context_len, logits = logits[0], logits[1][:, -1]
return context_len, tune_sample(logits, 0, 500)
else:
logits = logits[:, -1]
return tune_sample(logits, 0, 500)
elif sequential_prefill:
for i in range(width):
x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
@@ -622,6 +628,13 @@ def generate(
sequential_prefill=sequential_prefill,
**sampling_kwargs,
)

# For llava with image input, we need to extract the next position id from the prefill result
if batch and self.model.config.model_type == ModelType.Llava:
context_len, next_token = next_token
else:
context_len, next_token = T, next_token
Contributor suggestion: replace
context_len, next_token = T, next_token
with
context_len = T


if is_speculative:
self.prefill(
draft_model,
@@ -636,7 +649,7 @@
# max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)

input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
input_pos = torch.tensor([start_pos + context_len], device=device, dtype=torch.int)
accept_counts = [0] * (
speculate_k + 1
) # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
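
Where this matters for llava: the decode loop starts counting positions from whatever length the multimodal prefill actually produced (context_len), not from the raw token count T. A back-of-the-envelope sketch with made-up numbers; the real context_len comes out of the model's forward pass, and how many positions the image occupies depends on the vision encoder, which this diff does not show.

import torch

start_pos = 0
pre_image_tokens = 10    # text tokens before the <image> placeholder (made up)
image_positions = 576    # assumed number of vision-feature positions (made up)
post_image_tokens = 15   # text tokens after the placeholder (made up)

context_len = pre_image_tokens + image_positions + post_image_tokens  # 601
input_pos = torch.tensor([start_pos + context_len], dtype=torch.int)
# The first decoded token is written at position 601 rather than at the text-only length.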
@@ -726,27 +739,56 @@ def chat(

if generator_args.image_prompts is not None:
print("Image prompts", generator_args.image_prompts)

# Support for just the first image prompt for now
images = [Image.open(generator_args.image_prompts[0])]
messages = [
Message(
role="user",
content=[
{"type": "image", "content": images[0]},
{"type": "text", "content": generator_args.prompt},
],
eot=True,
),
Message(role="assistant", content=""),
]

transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
data = transform({"messages": messages}, inference=True)
batch = padded_collate([data], self.builder_args.device)
batch.pop("mask")
encoded = batch["tokens"]
assert len(images) == 1, "Only one image prompt is supported for now"

# TODO: update the encoded variable for multi-modality models to include image tokens.
Contributor comment: Can you explain this to me?

if self.model.config.model_type == ModelType.Flamingo:
messages = [
Message(
role="user",
content=[
{"type": "image", "content": images[0]},
{"type": "text", "content": generator_args.prompt},
],
eot=True,
),
Message(role="assistant", content=""),
]

transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
data = transform({"messages": messages}, inference=True)
batch = padded_collate([data], self.builder_args.device)
batch.pop("mask")
encoded = batch["tokens"]
elif self.model.config.model_type == ModelType.Llava:
#TODO: double check the tokenizer.
def find_subtensor(tensor, target):
Contributor comment: Please add type hints.

target_len = len(target)
for i in range(len(tensor) - target_len + 1):
if torch.all(tensor[i:i+target_len] == target):
return i
return -1
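
A type-annotated variant of the helper, prompted by the review comment above (a sketch only, not part of the diff):

def find_subtensor(tensor: torch.Tensor, target: torch.Tensor) -> int:
    """Return the start index of 1-D `target` inside 1-D `tensor`, or -1 if absent."""
    target_len = len(target)
    for i in range(len(tensor) - target_len + 1):
        if torch.all(tensor[i : i + target_len] == target):
            return i
    return -1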

input_ids = self.encode_tokens(generator_args.prompt, bos=True, device=self.builder_args.device)
image_token_indices = self.encode_tokens("<image>", device=self.builder_args.device)[1:]
index = find_subtensor(input_ids, image_token_indices)

if index == -1:
raise ValueError("Image token not found in prompt")

batch = {
"tokens": input_ids[:index].unsqueeze(0),
"encoder_input": llava_image_preprocess(images[0], device=self.builder_args.device, dtype=self.builder_args.precision),
Contributor comment: I might be misunderstanding batch, but it looks like the batch variable itself isn't used, just the entries of batch, especially encoder_input.
Author reply (@Gasoonjia, Sep 24, 2024): The whole batch is forwarded into the llava model, so everything here is used by llava's forward function.

"post_tokens": input_ids[index + len(image_token_indices) :].unsqueeze(0),
}

# cannot get the actual encoded image features before model inference; use a pseudo placeholder
pseudo_vision_encoded = torch.zeros(1, 624).to(device=self.builder_args.device, dtype=self.builder_args.precision)
encoded = torch.cat([batch["tokens"].view(1, -1), pseudo_vision_encoded, batch["post_tokens"].view(1, -1)], dim=-1).view(-1)

else:
encoded = self.encode_tokens(
generator_args.prompt, bos=True, device=self.builder_args.device
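
For context on the review thread above about how batch is consumed: a minimal sketch of the call the generator effectively makes for llava, as implied by the prefill change earlier in this diff (logits = model(**batch), which for ModelType.Llava returns a (context_len, logits) pair). The exact forward signature lives in torchchat/model.py and is not shown here, so treat the keyword names as an assumption mirroring the batch dict built above.

# Hedged sketch; keyword names mirror the batch dict built in chat().
context_len, logits = model(
    tokens=batch["tokens"],                # text tokens before the <image> placeholder
    encoder_input=batch["encoder_input"],  # preprocessed image tensor
    post_tokens=batch["post_tokens"],      # text tokens after the placeholder
)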