fix for siglip, llava, and lr decay
anas-awadalla committed Feb 24, 2024
1 parent 1e75320 commit feba465
Showing 6 changed files with 40 additions and 75 deletions.
19 changes: 10 additions & 9 deletions open_flamingo/src/factory.py
@@ -61,9 +61,11 @@ def create_model_and_transforms(
     )
     vision_encoder.visual.output_tokens = True
     vision_encoder = vision_encoder.visual
-    vis_hidden_dim = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][
-        "width"
-    ]
+    vision_encoder_config = open_clip.get_model_config(clip_vision_encoder_path)
+    if "SigLIP" in clip_vision_encoder_path:  # SigLIP models have a different config format
+        vis_hidden_dim = vision_encoder_config["embed_dim"]
+    else:
+        vis_hidden_dim = vision_encoder_config["vision_cfg"]["width"]
 
     # load tokenizer and ensure there is a pad token
     text_tokenizer = AutoTokenizer.from_pretrained(
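Note: the SigLIP part of this fix is entirely in the config lookup above. A minimal sketch of that selection logic, assuming (as the hunk does) that open_clip's SigLIP configs expose embed_dim at the top level while standard CLIP ViT configs nest the visual width under vision_cfg:

import open_clip

def get_vis_hidden_dim(clip_vision_encoder_path: str) -> int:
    # open_clip.get_model_config returns the model's config dict by name
    cfg = open_clip.get_model_config(clip_vision_encoder_path)
    if "SigLIP" in clip_vision_encoder_path:
        # SigLIP configs carry no vision_cfg width; use the top-level embed_dim instead
        return cfg["embed_dim"]
    return cfg["vision_cfg"]["width"]

# e.g. get_vis_hidden_dim("ViT-SO400M-14-SigLIP-384") vs. get_vis_hidden_dim("ViT-L-14")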
@@ -145,6 +147,9 @@ def _infer_decoder_layers_attr_name(model):
     "gptneoxforcausallm": "gpt_neox.layers",
     "mpt": "transformer.blocks",
     "mosaicgpt": "transformer.blocks",
+    "gemma": "model.layers",
+    "phi": "model.layers",
+    "minicpm": "model.layers",
 }
 
 
@@ -194,9 +199,5 @@ def check_embedding_fns(lang_model):
 
 
 def has_fn(model, fn_name):
-    """Try to call the fn_name function on the model"""
-    try:
-        getattr(model, fn_name)()
-        return True
-    except:
-        return False
+    """Check if model has a function fn_name"""
+    return callable(getattr(model, fn_name, None))
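Note: the rewritten has_fn no longer invokes the method just to see whether it exists. A self-contained illustration of the difference, using a hypothetical class (not from the repo):

class Dummy:
    def get_input_embeddings(self):
        # a method that raises (or has side effects) when called
        raise RuntimeError("should not be called by a capability check")

m = Dummy()
# old check: getattr(m, "get_input_embeddings")() raised inside the try, so has_fn returned False
# new check: only tests that the attribute exists and is callable, without calling it
print(callable(getattr(m, "get_input_embeddings", None)))  # True
print(callable(getattr(m, "no_such_fn", None)))             # False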
8 changes: 7 additions & 1 deletion open_flamingo/src/llava.py
@@ -31,11 +31,17 @@ def __init__(
             "media_token": "<image>",
         }
         lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
+
+        if vision_encoder.__class__.__name__ == "TimmModel":
+            grid_size = vision_encoder.trunk.patch_embed.grid_size
+        else:
+            grid_size = vision_encoder.grid_size
+
         super().__init__(
             vision_encoder=vision_encoder,
             vision_tokenizer=LinearPatchProjection(dim_visual=vis_feature_dim,
                                                    dim_out=lang_embedding_dim,
-                                                   num_patches=vision_encoder.grid_size[0] * vision_encoder.grid_size[1]),
+                                                   num_patches=grid_size[0] * grid_size[1]),
             lang_model=lang_model,
             initial_tokenizer_len=initial_tokenizer_len,
             gradient_checkpointing=gradient_checkpointing,
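Note: the LLaVA fix is about where the patch grid lives on the vision tower. The attribute paths below are the ones the hunk relies on: open_clip's own ViT towers expose grid_size directly, while SigLIP towers come through open_clip as timm models wrapped in a TimmModel, where the grid sits on trunk.patch_embed. A small sketch:

def get_patch_grid(vision_encoder):
    # timm-backed towers (e.g. SigLIP loaded via open_clip) wrap the ViT in .trunk
    if vision_encoder.__class__.__name__ == "TimmModel":
        return vision_encoder.trunk.patch_embed.grid_size
    # open_clip's VisionTransformer exposes the grid directly
    return vision_encoder.grid_size

# The LLaVA projector is then sized as num_patches = grid[0] * grid[1];
# e.g. a tower reporting a 27 x 27 grid gives 729 visual tokens.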
7 changes: 5 additions & 2 deletions open_flamingo/src/vlm.py
@@ -184,10 +184,13 @@ def _encode_vision_x(self, vision_x: torch.Tensor):
         """
         assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
         b, T, F = vision_x.shape[:3]
 
         vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
         with torch.no_grad():
-            vision_x = self.vision_encoder(vision_x)[1]  # OpenCLIP returns tuples
+            if self.vision_encoder.__class__.__name__ == "TimmModel":
+                vision_x = self.vision_encoder.trunk.forward_features(vision_x)
+            else:
+                vision_x = self.vision_encoder(vision_x)[1]  # OpenCLIP returns tuples
         vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
         return vision_x
 
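Note: both branches produce patch features of shape (b*T*F, v, d); only the call differs (timm's forward_features vs. the second element of the OpenCLIP tuple). A shape walk-through with dummy tensors and illustrative sizes:

import torch
from einops import rearrange

b, T, F, c, h, w = 2, 1, 1, 3, 224, 224        # batch, images, frames, channels, height, width
vision_x = torch.randn(b, T, F, c, h, w)

flat = rearrange(vision_x, "b T F c h w -> (b T F) c h w")    # (2, 3, 224, 224)
# the vision encoder maps (b*T*F, c, h, w) -> (b*T*F, v, d); pretend v=256 patches, d=1024
feats = torch.randn(b * T * F, 256, 1024)
out = rearrange(feats, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
print(out.shape)                                              # torch.Size([2, 1, 1, 256, 1024])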
7 changes: 4 additions & 3 deletions open_flamingo/train/data.py
@@ -68,14 +68,15 @@ def preprocess_laion_image(sample, image_processor):
     return rearrange(sample, "(b t f) c h w -> b t f c h w", t=1, f=1)
 
 
-def preprocess_laion_text(sample, tokenizer, max_tokens=32):
+def preprocess_laion_text(sample, tokenizer, max_tokens=256):
     """
     Preprocess text for LAION. Applied to a batch of captions.
-    Captions are truncated to 32 tokens by default.
+    Captions are truncated to 256 tokens by default.
     """
     tokenizer.padding_side = "right"
     sample = [
-        (f"<image>{s.strip()}<|endofchunk|>{tokenizer.eos_token}") for s in sample
+        # (f"<image>{s.strip()}<|endofchunk|>{tokenizer.eos_token}") for s in sample
+        (f"<image>{s.strip()}{tokenizer.eos_token}") for s in sample
     ]
     text = tokenizer(
         sample,
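Note: two things change for LAION text: captions may now run to 256 tokens, and single-image captions drop the <|endofchunk|> marker before EOS. A usage sketch with an illustrative caption; the tokenizer kwargs after "sample," sit outside this hunk, so the call below is an assumption, not the repo's exact code:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
tokenizer.padding_side = "right"

captions = ["a dog playing in the snow  "]
sample = [f"<image>{s.strip()}{tokenizer.eos_token}" for s in captions]  # no <|endofchunk|>
text = tokenizer(
    sample,
    max_length=256,       # was 32 before this commit
    truncation=True,
    padding="longest",
    return_tensors="pt",
)
print(text["input_ids"].shape)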
72 changes: 13 additions & 59 deletions open_flamingo/train/train.py
@@ -3,7 +3,6 @@
 import os
 import torch
 import wandb
-import deepspeed
 import functools
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -16,16 +15,13 @@
     world_info_from_env,
     get_fsdp_config,
     get_fsdp_checkpoint_config,
-    get_deepspeed_config,
 )
 from open_flamingo.train.train_utils import (
     train_one_epoch,
     random_seed,
-    load_deepspeed_checkpoint,
     find_most_recent_checkpoint,
     load_checkpoint,
     save_checkpoint,
-    save_deepspeed_checkpoint,
 )
 from open_flamingo.train.losses import (
     SUPPORTED_LOSSES,
@@ -44,8 +40,8 @@ def main():
     parser.add_argument(
         "--model_family", default="flamingo", type=str, choices=SUPPORTED_MODEL_FAMILIES
     )
-    parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
-    parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
+    parser.add_argument("--vision_encoder_path", default="ViT-SO400M-14-SigLIP-384", type=str)
+    parser.add_argument("--vision_encoder_pretrained", default="webli", type=str)
     parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str)
     parser.add_argument(
         "--tokenizer_path",
@@ -73,7 +69,7 @@ def main():
     parser.add_argument(
         "--resume_from_checkpoint",
         type=str,
-        help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states. if there exists a checkpoint in the dir named run_name, we will resume from that checkpoint by default. If using deepspeed this should be a directory, not a file.",
+        help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states. if there exists a checkpoint in the dir named run_name, we will resume from that checkpoint by default.",
         default=None,
     )
     parser.add_argument(
@@ -187,20 +183,6 @@ def main():
         "--fsdp_sharding_strategy", default="full", type=str, choices=["full", "hybrid"]
     )
 
-    # deepspeed args
-    parser.add_argument(
-        "--deepspeed",
-        default=False,
-        action="store_true",
-        help="Use deepspeed for distributed training.",
-    )
-    parser.add_argument(
-        "--deepspeed_stage",
-        default=2,
-        type=int,
-        help="DeepSpeed distributed training stage. 1: ZeRO-1 (optimizer sharding), 2: ZeRO-2 (optimizer + gradient sharding), 3: ZeRO-3 (optimizer + gradient + model sharding)",
-    )
-
     # wandb args
     parser.add_argument("--report_to_wandb", default=False, action="store_true")
     parser.add_argument(
@@ -251,16 +233,10 @@ def main():
     if args.save_checkpoints_to_wandb and not args.report_to_wandb:
         raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")
 
-    if args.fsdp and args.deepspeed:
-        raise ValueError("Select either FSDP or deepspeed for distributed training.")
-
     if args.fsdp:
-        print(
-            "Warning: FSDP is experimental and not fully tested. Preference should be given to Deepspeed."
-        )
         assert (
-            "dev" in torch.__version__ and torch.__version__ > "2.0.1"
-        ), "FSDP requires torch nightly > 2.0.1"
+            torch.__version__ > "2.0.1"
+        ), "FSDP requires torch > 2.0.1"
 
     # Set up distributed training
     args.local_rank, args.rank, args.world_size = world_info_from_env()
@@ -269,13 +245,7 @@ def main():
     if args.offline:
         os.environ["WANDB_MODE"] = "offline"
         os.environ["TRANSFORMERS_OFFLINE"] = "1"
-    if args.deepspeed:
-        torch.cuda.set_device(args.local_rank)
-        deepspeed.init_distributed()
-        ds_config = get_deepspeed_config(args)
-        device_id = args.local_rank
-    else:
-        device_id = init_distributed_device(args)
+    device_id = init_distributed_device(args)
 
     random_seed(args.seed)
 
@@ -316,8 +286,8 @@ def main():
         args.resume_from_checkpoint = find_most_recent_checkpoint(args)
 
     if (
-        args.resume_from_checkpoint is not None and not args.deepspeed
-    ):  # deepspeed handles checkpoint loading
+        args.resume_from_checkpoint is not None
+    ):
         resume_from_epoch, checkpoint = load_checkpoint(args, model)
     else:
         resume_from_epoch = 0
@@ -327,7 +297,6 @@ def main():
         model.init_gradient_checkpointing()
 
     # Initialize FSDP / DDP, and ensure the model is on GPU
-    # Deepspeed is initialized later
     if args.fsdp:
         auto_wrap_policy = functools.partial(
             lambda_auto_wrap_policy, lambda_fn=model.get_fsdp_lambda_fn()
@@ -336,7 +305,7 @@ def main():
         distributed_model = FSDP(
             model, auto_wrap_policy=auto_wrap_policy, **wrapper_kwargs
         )
-    elif not args.deepspeed:
+    else:
         model = model.to(device_id)
         distributed_model = DDP(model, device_ids=[device_id])
 
@@ -351,7 +320,7 @@ def main():
     )
 
     # load optimizer checkpoint
-    if args.resume_from_checkpoint is not None and not args.deepspeed:
+    if args.resume_from_checkpoint is not None:
         osd = checkpoint["optimizer_state_dict"]
         if args.fsdp:
             FSDP.set_state_dict_type(
@@ -370,7 +339,7 @@ def main():
     ]
     total_training_steps = (
         getattr(args, f"train_num_samples_{datasets_to_train_on[0]}")
-        // getattr(args, f"batch_size_{datasets_to_train_on[0]}")
+        // (getattr(args, f"batch_size_{datasets_to_train_on[0]}") * args.gradient_accumulation_steps * args.world_size)
     ) * args.num_epochs
 
     if args.rank == 0:
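Note: this hunk is the lr-decay fix from the commit title. The scheduler horizon was previously computed from per-device, per-sample counts, so with gradient accumulation and multiple workers the decay stretched far past the real number of optimizer steps. Optimizer steps per epoch are samples divided by the effective global batch (per-device batch × gradient accumulation × world size). A worked example with illustrative numbers:

# Worked example of the corrected step count (numbers are illustrative, not from the repo).
train_num_samples = 10_000_000   # samples per epoch for the first dataset
batch_size = 32                  # per-device batch size
gradient_accumulation_steps = 4
world_size = 8                   # number of GPUs
num_epochs = 1

# before this commit: step count ignored grad accumulation and world size
old_total_steps = (train_num_samples // batch_size) * num_epochs                          # 312_500
# after this commit: step count reflects the effective global batch size
new_total_steps = (
    train_num_samples // (batch_size * gradient_accumulation_steps * world_size)
) * num_epochs                                                                             # 9_765
print(old_total_steps, new_total_steps)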
@@ -395,21 +364,9 @@ def main():
     )
 
     # load lr scheduler checkpoint
-    if args.resume_from_checkpoint is not None and not args.deepspeed:
+    if args.resume_from_checkpoint is not None:
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
 
-    if args.deepspeed:
-        distributed_model, optimizer, _, lr_scheduler = deepspeed.initialize(
-            model=model,
-            optimizer=optimizer,
-            args=args,
-            config=ds_config,
-            lr_scheduler=lr_scheduler,
-            dist_init_required=True,
-        )
-        if args.resume_from_checkpoint is not None:
-            resume_from_epoch = load_deepspeed_checkpoint(args, distributed_model)
-
     # Initialize the loss fn
     loss_fn = get_loss_fn(args.loss)
 
@@ -435,10 +392,7 @@ def main():
             wandb=wandb,
         )
 
-        if args.deepspeed:
-            save_deepspeed_checkpoint(distributed_model, epoch, args)
-        else:
-            save_checkpoint(distributed_model, optimizer, lr_scheduler, epoch, args)
+        save_checkpoint(distributed_model, optimizer, lr_scheduler, epoch, args)
 
 
 if __name__ == "__main__":
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,6 +1,6 @@
 einops
 einops-exts
-transformers==4.28.1
+transformers
 torch>=2.0.1
 pillow
 open_clip_torch>=2.16.0
