This is a PyTorch implementation of iLLaMA, proposed in our paper "Adapting LLaMA Decoder to Vision Transformer".
Figure 1: Left: iLLaMA architecture. Right: our design roadmap. Colored and gray bars represent the results of the tiny and base regimes, with the red line depicting the training loss of the tiny regime. iLLaMA strives to process visual tokens using standard LLaMA components, e.g., causal self-attention. The proposed PS [cls] and soft mask strategy help overcome training challenges.
Figure 2: (a) mask in causal self-attention. (b) mask in causal self-attention with our post-sequence class token (PS [cls]) method. (c) modified causal mask.
Figure 3: (a) The soft mask gradually transitions from a bi-directional mask into a causal mask during training, following a constant or linear schedule. (b) Ablation results for training loss and test accuracy.
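For intuition, below is a minimal PyTorch sketch of how such a soft mask schedule could be built, interpolating from a bi-directional mask to a causal one. The function names and the exact interpolation are illustrative assumptions, not the paper's implementation; the method additionally appends the PS [cls] token after all patch tokens so that, under a causal mask, it can still attend to every image token.

```python
import torch

def causal_mask(n: int) -> torch.Tensor:
    # Lower-triangular mask: token i may only attend to tokens j <= i.
    return torch.tril(torch.ones(n, n))

def soft_mask(n: int, epoch: int, total_epochs: int, schedule: str = "linear") -> torch.Tensor:
    """Illustrative soft mask: interpolate from a bi-directional (all-ones)
    mask to a causal mask over training.

    With the linear schedule, alpha grows from 0 to 1, so the attention
    allowed on "future" tokens decays smoothly to 0.
    """
    bi = torch.ones(n, n)
    causal = causal_mask(n)
    if schedule == "linear":
        alpha = min(epoch / max(total_epochs - 1, 1), 1.0)
    else:  # "constant": hard switch to the causal mask after a fixed point (assumed here: halfway)
        alpha = 0.0 if epoch < total_epochs // 2 else 1.0
    return alpha * causal + (1.0 - alpha) * bi

# One possible use: multiply the (softmaxed or pre-softmax) attention weights
# element-wise by this mask; how the paper applies it exactly is not shown here.
```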
PyTorch and timm 0.5.4 (`pip install timm==0.5.4`).
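A quick sanity check for the environment (illustrative only; the assert simply enforces the pinned timm version):

```python
import torch
import timm

print(torch.__version__)   # any recent PyTorch release
print(timm.__version__)    # expected: 0.5.4
assert timm.__version__ == "0.5.4", "this codebase is pinned to timm 0.5.4"
```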
Data preparation: prepare ImageNet with the following folder structure; you can extract ImageNet with this script.
│imagenet/
├──train/
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ ├── ......
│ ├── ......
├──val/
│ ├── n01440764
│ │ ├── ILSVRC2012_val_00000293.JPEG
│ │ ├── ILSVRC2012_val_00002138.JPEG
│ │ ├── ......
│ ├── ......
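With this layout, the dataset can be loaded directly by a folder-based loader. Below is a minimal torchvision sketch with placeholder paths; the actual pipeline in main.py builds its transforms with timm.

```python
from torchvision import datasets, transforms

# Simple evaluation-style transform (illustrative; training uses timm's augmentation).
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

train_set = datasets.ImageFolder("/path/to/imagenet/train", transform=transform)
val_set = datasets.ImageFolder("/path/to/imagenet/val", transform=transform)
print(len(train_set.classes))  # 1000 ImageNet classes
```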
Model | Pre-training dataset | Resolution | Params | MACs | Top-1 Acc (%) |
---|---|---|---|---|---|
illama_tiny | - | 224 | 5.7M | 1.3G | 75.0 |
illama_small | - | 224 | 21.9M | 4.6G | 79.9 |
illama_base | - | 224 | 86.3M | 17.6G | 81.6 |
illama_base | - | 384 | 86.3M | 55.5G | 83.0 |
illama_base | ImageNet-21K | 224 | 86.3M | 17.6G | 83.6 |
illama_base | ImageNet-21K | 384 | 86.3M | 55.5G | 85.0 |
illama_large | ImageNet-21K | 224 | 310.2M | 62.8G | 84.8 |
illama_large | ImageNet-21K | 384 | 310.2M | 194.7G | 86.0 |
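To cross-check the parameter counts in the table, a model can be instantiated and its parameters counted. This sketch assumes the iLLaMA variants are registered with timm by this repo's `models` module (as in ConvNeXt-style codebases); the `import models` line is that assumption.

```python
import timm
import models  # noqa: F401  (assumed to register the illama_* variants with timm)

model = timm.create_model("illama_tiny", num_classes=1000)
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")  # ~5.7M for illama_tiny
```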
To evaluate models at 224 resolution, run:
MODEL=illama_tiny
RESUME='/your/path/to/model.pth'
python -m torch.distributed.launch --nproc_per_node=2 main.py \
--model $MODEL --eval true \
--data_path $root_imagenet \
--resume $RESUME
To evaluate models at 384 resolution, run:
MODEL=illama_base
RESUME='/your/path/to/model.pth'
python -m torch.distributed.launch --nproc_per_node=2 main_soft_fthr.py \
--model $MODEL --input_size 384 --eval true \
--data_path $root_imagenet \
--resume $RESUME
We use a batch size of 4096 by default with 8 GPUs. To train iLLaMA-T on ImageNet-1K, run:
bash scripts/train_illama_tiny_in1k.sh
Training scripts for other models are provided in scripts.
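As a rough illustration of how the default effective batch size of 4096 decomposes across 8 GPUs (variable names and the per-GPU values below are hypothetical; check the training scripts for the actual flags):

```python
num_gpus = 8
per_gpu_batch_size = 128   # assumed value for illustration
update_freq = 4            # gradient accumulation steps (assumed)

effective_batch_size = num_gpus * per_gpu_batch_size * update_freq
assert effective_batch_size == 4096
```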
We use the weight selection method to select weights from LLaMA2-7B.
python llama2/weight_selection.py
Then we use the selected weights to initialize our iLLaMA-T/S/B.
bash scripts/train_illama_tiny_from_llama2.sh
Training scripts for other models are provided in scripts.
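Conceptually, weight selection initializes a small model from a larger pretrained one by picking a subset of layers and truncating each weight tensor to the smaller shape. The sketch below is a rough, hedged illustration of that idea with made-up function names; the actual implementation is in llama2/weight_selection.py.

```python
import torch

def uniform_layer_indices(n_teacher_layers: int, n_student_layers: int) -> list[int]:
    # Pick evenly spaced teacher layers, one per student layer.
    step = n_teacher_layers / n_student_layers
    return [int(i * step) for i in range(n_student_layers)]

def truncate_to(shape, tensor: torch.Tensor) -> torch.Tensor:
    # Keep the leading slice along every dimension so the tensor matches the student's shape.
    slices = tuple(slice(0, s) for s in shape)
    return tensor[slices].clone()

def select_weights(teacher_state: dict, student_state: dict, layer_map: dict) -> dict:
    # layer_map maps student parameter names to the chosen teacher parameter names.
    selected = {}
    for name, param in student_state.items():
        teacher_name = layer_map.get(name, name)
        if teacher_name in teacher_state:
            selected[name] = truncate_to(param.shape, teacher_state[teacher_name])
        else:
            selected[name] = param  # fall back to the student's own initialization
    return selected
```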
@article{wang2024adapting,
title={Adapting LLaMA Decoder to Vision Transformer},
author={Wang, Jiahao and Shao, Wenqi and Chen, Mengzhao and Wu, Chengyue and Liu, Yong and Zhang, Kaipeng and Zhang, Songyang and Chen, Kai and Luo, Ping},
journal={arXiv preprint arXiv:2404.06773},
year={2024}
}
Our implementation is based on pytorch-image-models, llama, dropout, ConvNeXt, weight-selection, and MambaOut.