support intern-omni for different language parts
helloyongyang committed Nov 29, 2024
1 parent 42334b1 commit 44c3e2f
Showing 1 changed file with 106 additions and 2 deletions.
108 changes: 106 additions & 2 deletions llmc/models/internomni.py
@@ -1,6 +1,9 @@
from typing import Optional

import librosa
import torch
from loguru import logger
from transformers import GenerationConfig

try:
from internvl.conversation import get_conv_template
@@ -15,7 +18,6 @@

from llmc.utils.registry_factory import MODEL_REGISTRY

-from .internlm2 import InternLM2
from .internvl2 import load_image


@@ -37,8 +39,110 @@ def load_audio(audio_file, audio_processor):
    return audio_input


@torch.no_grad()
def generate_patch_for_internvl_qwen2(
    self,
    pixel_values: torch.FloatTensor,
    input_ids: torch.FloatTensor,
    attention_mask: torch.LongTensor,
    visual_features: Optional[torch.FloatTensor] = None,
    audio_values: Optional[torch.FloatTensor] = None,
    audio_len_after_cnn: Optional[bool] = None,
    audio_token_num: Optional[bool] = None,
    generation_config: Optional[GenerationConfig] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **generate_kwargs,
) -> torch.LongTensor:

    assert self.img_context_token_id is not None
    assert self.audio_context_token_id is not None

    vit_embeds = None
    if visual_features is not None:
        vit_embeds = visual_features
    elif pixel_values is not None:
        vit_embeds = self.extract_feature(pixel_values)

    input_embeds = self.language_model.get_input_embeddings()(input_ids)
    B, N, C = input_embeds.shape
    input_embeds = input_embeds.reshape(B * N, C)

    input_ids = input_ids.reshape(B * N)

    if vit_embeds is not None:
        selected = (input_ids == self.img_context_token_id)
        input_embeds[selected] = vit_embeds.reshape(-1, C)

    if audio_values is not None and audio_len_after_cnn is not None and audio_token_num is not None:
        audio_embeds = self.extract_audio_feature(audio_values, audio_len_after_cnn)
        output_audios = []
        for i in range(len(audio_token_num)):
            token_num = int(audio_token_num[i].item())
            audio = audio_embeds[i][:token_num]
            output_audios.append(audio)
        output_audios = torch.cat(output_audios, dim=0)
        selected = (input_ids == self.audio_context_token_id)
        input_embeds[selected] = output_audios.reshape(-1, C)

    input_embeds = input_embeds.reshape(B, N, C)

    outputs = self.language_model.generate(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
        generation_config=generation_config,
        output_hidden_states=output_hidden_states,
        use_cache=True,
        **generate_kwargs,
    )

    return outputs


@MODEL_REGISTRY
class InternOmni():
    def __new__(cls, config, device_map=None, use_cache=False):
        avlm_model_config = InternVLChatAudioConfig.from_pretrained(
            config.model.path, trust_remote_code=True
        )
        language_part = avlm_model_config.llm_config.model_type
        logger.warning(f'InternOmni language_part: {language_part}')
        if language_part == 'internlm2':
            from .internlm2 import InternLM2

            class NewClass(InternOmniSharedBehavior, InternLM2):
                def __init__(self, config, device_map=None, use_cache=False):
                    super().__init__(config, device_map, use_cache)
        elif language_part == 'qwen2':
            from .qwen2 import Qwen2

            class NewClass(InternOmniSharedBehavior, Qwen2):
                def __init__(self, config, device_map=None, use_cache=False):
                    super().__init__(config, device_map, use_cache)
                    setattr(
                        self.avlm_model,
                        'generate',
                        generate_patch_for_internvl_qwen2.__get__(self.avlm_model),
                    )
        elif language_part == 'phi3':
            from .phi3 import Phi3

            class NewClass(InternOmniSharedBehavior, Phi3):
                def __init__(self, config, device_map=None, use_cache=False):
                    super().__init__(config, device_map, use_cache)
        elif language_part == 'llama':
            from .llama import Llama

            class NewClass(InternOmniSharedBehavior, Llama):
                def __init__(self, config, device_map=None, use_cache=False):
                    super().__init__(config, device_map, use_cache)
        else:
            raise Exception(f'Not support for language_part: {language_part}')
        return NewClass(config, device_map, use_cache)


@MODEL_REGISTRY
-class InternOmni(InternLM2):
class InternOmniSharedBehavior():
    def __init__(self, config, device_map=None, use_cache=False):
        super().__init__(config, device_map, use_cache)

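Note on the patched generate above: it follows the standard multimodal embedding-injection pattern. The text token embeddings are flattened to shape (B * N, C), every position holding the image or audio context placeholder token is overwritten with the corresponding ViT or audio-encoder feature, and the result is passed to language_model.generate(inputs_embeds=...). A minimal standalone sketch of that scatter step, using toy shapes and a hypothetical placeholder id rather than the real img_context_token_id:

import torch

IMG_CONTEXT_ID = 92546                    # hypothetical placeholder token id
B, N, C = 1, 4, 8                         # toy batch size, sequence length, hidden size

input_ids = torch.tensor([[1, IMG_CONTEXT_ID, IMG_CONTEXT_ID, 2]])
input_embeds = torch.randn(B, N, C)       # embeddings looked up from the LLM embedding table
vit_embeds = torch.randn(2, C)            # one visual feature per placeholder position

flat_embeds = input_embeds.reshape(B * N, C)
flat_ids = input_ids.reshape(B * N)
selected = (flat_ids == IMG_CONTEXT_ID)   # boolean mask of placeholder positions
flat_embeds[selected] = vit_embeds        # scatter the visual features into the sequence
input_embeds = flat_embeds.reshape(B, N, C)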

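Note on the setattr call in the qwen2 branch: calling __get__ on a plain function binds it to one object via the descriptor protocol, so generate_patch_for_internvl_qwen2 replaces generate on that single avlm_model instance (with self supplied automatically) while other instances keep their original method. A self-contained illustration of the pattern with toy classes, not llmc code:

class Model:
    def generate(self):
        return 'original generate'


def patched_generate(self):
    return 'patched generate'


m = Model()
# Bind the free function to this one instance; self is passed automatically.
setattr(m, 'generate', patched_generate.__get__(m))

assert m.generate() == 'patched generate'          # this instance is patched
assert Model().generate() == 'original generate'   # other instances are untouched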

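Note on InternOmni.__new__: it acts as a factory that reads llm_config.model_type from the pretrained config and instantiates a class mixing InternOmniSharedBehavior with the matching language-model wrapper (InternLM2, Qwen2, Phi3 or Llama), so the shared audio-visual handling is reused while backbone-specific behaviour comes from the chosen base. A rough sketch of the same runtime-mixin idea with placeholder classes, illustrative only and not the llmc classes:

class SharedBehavior:
    def describe(self):
        return f'shared front end + {type(self).__mro__[2].__name__} backbone'


class InternLM2Stub:
    pass


class Qwen2Stub:
    pass


BACKBONES = {'internlm2': InternLM2Stub, 'qwen2': Qwen2Stub}


def build_model(model_type):
    backbone = BACKBONES[model_type]           # pick the base class from the config string

    class Combined(SharedBehavior, backbone):  # same mixin-first ordering as NewClass
        pass

    return Combined()


print(build_model('qwen2').describe())  # shared front end + Qwen2Stub backbone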