qwen2-vl 2b 4-bit always getting OOM, yet llama3.2 11b works! #1326

Open
mehamednews opened this issue Nov 22, 2024 · 1 comment

@mehamednews

Qwen2-VL has always been memory hungry compared to the other vision models, and even with Unsloth it still OOMs, while the larger Llama 3.2 11B works fine.
I'm using a dataset with high-resolution images (~1200px); running with the LaTeX dataset did work with Qwen.
Not sure if this can be fixed.
Any help would be appreciated.

Here's the code I'm using (replacing Llama 3.2 with Qwen fails):

import json
from unsloth import FastVisionModel, is_bf16_supported  # use FastLanguageModel for text-only LLMs
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

model, tokenizer = FastVisionModel.from_pretrained(
    "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
    load_in_4bit=True,  # Use 4bit to reduce memory use. False for 16bit LoRA.
    use_gradient_checkpointing="unsloth",  # True or "unsloth" for long context
)

model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers=True,  # False if not finetuning vision layers
    finetune_language_layers=True,  # False if not finetuning language layers
    finetune_attention_modules=True,  # False if not finetuning attention layers
    finetune_mlp_modules=True,  # False if not finetuning MLP layers
    r=16,  # The larger, the higher the accuracy, but might overfit
    lora_alpha=16,  # Recommended alpha == r at least
    lora_dropout=0,
    bias="none",
    random_state=3407,
    use_rslora=False,  # We support rank stabilized LoRA
    loftq_config=None,  # And LoftQ
    # target_modules = "all-linear", # Optional now! Can specify a list if needed
)

# Load the JSONL dataset
dataset_path = "./label-dataset-train.jsonl"
dataset = []
with open(dataset_path, "r") as f:
    for line in f:
        sample = json.loads(line)
        # if len(sample["images"]) > 1:
        #     continue
        conversation = [
            {
                "role": "user",
                "content": [{"type": "text", "text": sample["query"]}, *[{"type": "image", "image": img} for img in sample["images"]]],
            },
            {"role": "assistant", "content": [{"type": "text", "text": sample["response"]}]},
        ]
        dataset.append({"messages": conversation})

converted_dataset = dataset
print(len(dataset))


FastVisionModel.for_training(model)  # Enable for training!

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=UnslothVisionDataCollator(model, tokenizer),  # Must use!
    train_dataset=converted_dataset,
    args=SFTConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,
        warmup_steps=10,
        max_steps=50,
        # num_train_epochs=1,  # Set this instead of max_steps for full training runs
        learning_rate=2e-4,
        fp16=not is_bf16_supported(),
        bf16=is_bf16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
        report_to="none",  # For Weights and Biases
        # You MUST put the below items for vision finetuning:
        remove_unused_columns=False,
        dataset_text_field="",
        dataset_kwargs={"skip_prepare_dataset": True},
        dataset_num_proc=4,
        max_seq_length=2048,
    ),
)

trainer_stats = trainer.train()

@WizKnight

Hey @mehamednews :), Qwen2-VL uses more memory than Llama 3.2 because of its architecture and the way it processes images: it handles images at (near) native resolution, so high-resolution inputs are turned into many more visual tokens.
Since you're working with high-resolution images, try experimenting with a couple of things:

  1. Downsample the image resolution to ~512px or ~256px. This can significantly reduce memory usage (see the sketch after this list).

  2. Increase gradient_accumulation_steps while keeping per_device_train_batch_size at 1; this gives a larger effective batch size without holding more samples in memory at once.
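
A minimal sketch of option 1, assuming the entries in sample["images"] are file paths or already-loaded PIL images (the exact JSONL format is an assumption); it caps the longest side at 512px before the conversation is built:

from PIL import Image

MAX_SIDE = 512  # cap on the longest image side; ~512px is usually enough to test

def downsample(img, max_side=MAX_SIDE):
    # Accepts a file path or a PIL image (assumption about the JSONL contents)
    if isinstance(img, str):
        img = Image.open(img)
    w, h = img.size
    scale = max_side / max(w, h)
    if scale >= 1.0:
        return img  # already small enough
    return img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)

# In the dataset loop, wrap each image before adding it to the conversation:
#   [{"type": "image", "image": downsample(img)} for img in sample["images"]]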

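Separately, if the tokenizer that FastVisionModel.from_pretrained returns for Qwen2-VL wraps the Hugging Face Qwen2VLProcessor (an assumption here), its image processor exposes min_pixels / max_pixels limits that bound how much each image is resized and therefore how many visual tokens it produces; lowering max_pixels is another way to bound memory:

# Assumption: `tokenizer` behaves like a Hugging Face Qwen2VLProcessor and exposes
# an `image_processor` with min_pixels / max_pixels attributes.
if hasattr(tokenizer, "image_processor"):
    tokenizer.image_processor.max_pixels = 512 * 512   # upper bound on pixels per image
    tokenizer.image_processor.min_pixels = 256 * 256   # optional lower bound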