
Major refactor to support new architectures #261

Draft · wants to merge 53 commits into main

Conversation

@i-gao (Collaborator) commented Sep 16, 2023

New models

  • All models inherit from a VLM class. See documentation in src/vlm.py (a minimal composition sketch follows this list):
    """
    Generic vision-language model (VLM) class.
    A VLM consists of four components:
        1. A vision encoder that extracts features from pixels, e.g. CLIP
            input: (B, T_img, F, C, H, W)
            output: (B, T_img, F, v, d)
        2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head
            input: (B, T_img, F, v, d)
            output: (B, T_img, n, d)
        3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence
        4. A language model
    """
  • Models are further split into those that inherit from VLMWithCrossAttention (dense xattn to fuse vision + language, Flamingo-style) vs. VLMWithLanguageStream (insert vision tokens into the language stream, Kosmos-style)
  • BLIP, Kosmos, and Flamingo implemented
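
A minimal sketch of how the four components above compose in a forward pass, using the shapes from the docstring; the class, method, and attribute names here are illustrative assumptions rather than the actual src/vlm.py API:

    import torch.nn as nn

    class ToyVLM(nn.Module):
        """Illustrative composition of the four VLM components (names are assumptions)."""
        def __init__(self, vision_encoder, vision_tokenizer, lang_model):
            super().__init__()
            self.vision_encoder = vision_encoder      # 1. pixels -> features, e.g. CLIP
            self.vision_tokenizer = vision_tokenizer  # 2. features -> n visual tokens, e.g. Perceiver
            self.lang_model = lang_model              # 3+4. fusion method + language model

        def forward(self, vision_x, lang_x, attention_mask=None):
            # vision_x: (B, T_img, F, C, H, W)
            B, T_img, F = vision_x.shape[:3]
            feats = self.vision_encoder(vision_x.flatten(0, 2))   # (B*T_img*F, v, d)
            feats = feats.view(B, T_img, F, *feats.shape[1:])     # (B, T_img, F, v, d)
            vis_tokens = self.vision_tokenizer(feats)             # (B, T_img, n, d)
            # Fusion: cross-attend to vis_tokens (Flamingo-style) or splice them
            # into the language stream (Kosmos-style), then run the language model.
            return self.lang_model(lang_x, vision_tokens=vis_tokens,
                                   attention_mask=attention_mask)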

FSDP Updates

  • FSDP rewritten to reflect recent updates to PyTorch nightly, which allow for mixed frozen/unfrozen params within an FSDP wrapper and for frozen params to reshard during the backward pass (see the sketch below).
  • GPU power utilization for the Flamingo model with a facebook/opt-6.7B backbone is now > 70%, compared to ~30% previously (measured with gradient checkpointing).
  • This also removes the need to untie LM embeddings, and seems to address the issues with backward resharding shapes.
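
For reference, the PyTorch-side feature that enables mixed frozen/unfrozen parameters within a single FSDP wrapper is use_orig_params=True. The sketch below shows one way the wrapping could look; the auto-wrap policy and dtypes are placeholders, not necessarily what this branch uses:

    import functools
    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

    def wrap_with_fsdp(model):
        # use_orig_params=True lets frozen params (e.g. vision encoder, LM backbone)
        # and trainable params live in the same FSDP unit, so embeddings no longer
        # need to be untied and frozen params can reshard after the backward pass.
        wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e7))
        return FSDP(
            model,
            auto_wrap_policy=wrap_policy,
            use_orig_params=True,
            mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
            device_id=torch.cuda.current_device(),
        )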

Training code refactor

  • train_one_epoch now accepts a list of datasets and executes the same loss function on all of them. This lets us decide which datasets to train on at runtime (e.g. just LAION) and makes adding datasets more flexible. To train on a dataset, set the --{dataset_name}_shards arg (e.g. --laion_shards); if it is None, we do not train on that dataset (i.e., LAION is skipped).
  • train_one_epoch also now accepts a loss function decided at runtime (a rough sketch follows this list). Losses are found in train/losses.py. Currently, only next-token prediction is implemented, but this allows us to work on adding contrastive-generative losses.
  • Most of the FSDP / DeepSpeed code has been moved to train/distributed.py in an attempt to streamline train/train.py.
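
A rough sketch of the interface described above; the exact signature and the dataset/loader attributes are assumptions, not the branch's actual code:

    def train_one_epoch(model, datasets, compute_loss, optimizer, lr_scheduler):
        """datasets: only those whose --{dataset_name}_shards arg was set (e.g. just LAION).
        compute_loss: a callable chosen at runtime from train/losses.py,
        currently next-token prediction."""
        loaders = [iter(d.dataloader) for d in datasets]
        num_batches = min(d.dataloader.num_batches for d in datasets)
        for step in range(num_batches):
            optimizer.zero_grad()
            for loader in loaders:
                batch = next(loader)
                loss = compute_loss(model, batch)  # same loss applied to every dataset
                loss.backward()
            optimizer.step()
            lr_scheduler.step()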

Steps before merging

  • Replicate OF-3B as a sanity check
  • Test training with deepspeed
  • Check embedding logic (cc @anas-awadalla)
  • Check whether anything in eval code needs to change; I think the image caching for classification needs to be completely rewritten now.
  • Merge in deepspeed eval code from the deepspeed branch
  • Update released weights or modify this code to be compatible with released weights (e.g., I have switched to using lang_model instead of lang_encoder; this will not play well with the released weights; we need to decide what to do about the embeddings).
  • Merge in updated / modular eval code from wilds branch

Steps after merging

  • Add support for image generation components, e.g. Emu-style
  • Add support for additional losses, e.g. CoCA-style
  • Add back in additional datasets that were being explored previously on this branch (only mmc4 and laion are in the code for now)

@anas-awadalla (Collaborator) commented Sep 16, 2023

Some other todos I want to add to this:

  • DeepSpeed updates from that branch
  • Cast logits to fp32 for pure bf16/fp16 (to avoid loss spikes)
  • Add z-loss (a sketch of these last two items follows this list)
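
For the last two items, a minimal sketch of what the loss-side change could look like; standard next-token prediction is assumed, and the z-loss weight is a placeholder:

    import torch
    import torch.nn.functional as F

    def next_token_loss_with_zloss(logits, labels, z_loss_weight=1e-4):
        # Cast logits to fp32 before the cross-entropy so pure bf16/fp16 runs
        # do not accumulate precision error that shows up as loss spikes.
        logits = logits.float()
        shift_logits = logits[:, :-1].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        ce = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )
        # z-loss (PaLM-style): penalize log(Z)^2 so the softmax normalizer stays near 1.
        log_z = torch.logsumexp(shift_logits, dim=-1)
        return ce + z_loss_weight * (log_z ** 2).mean()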

@i-gao linked an issue on Sep 20, 2023 that may be closed by this pull request
anas-awadalla and others added 2 commits September 20, 2023 19:40
* fix padding side when generating

* clean up language stream forward pass (less for looping)

* expose BLIP model

* fixes for forward pass without images

* restore for looping
@liyongqi67

I have a keen interest in exploring the latest features. To that end, I've integrated the deepspeed-related code into the current main branch of OpenFlamingo, including functions like get_deepspeed_config(). During my testing, the code runs smoothly with deepspeed_stage = 2 and shows a significant efficiency improvement compared to FSDP. However, when I configured deepspeed_stage = 3, an error was raised during loss backpropagation:

    model.backward(divided_loss_laion)
  File "/home/yongqi/miniconda3/envs/openflamingo/lib/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/yongqi/miniconda3/envs/openflamingo/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 1923, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/home/yongqi/miniconda3/envs/openflamingo/lib/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/yongqi/miniconda3/envs/openflamingo/lib/python3.9/site-packages/deepspeed/runtime/zero/stage3.py", line 2080, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/home/yongqi/miniconda3/envs/openflamingo/lib/python3.9/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/home/yongqi/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/yongqi/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: The size of tensor a (0) must match the size of tensor b (8192) at non-singleton dimension 1

Do you have any idea about this, or have you encountered this problem while developing the new version?

@anas-awadalla (Collaborator)

You said you integrated “deepspeed-related code into the current main branch of Openflamingo”. Have you tried using this branch as is? The integration is basically complete but we are doing more testing to be certain. I will also test out stage 3 again to make sure we haven’t missed anything.

@liyongqi67 commented Oct 2, 2023

You said you integrated “deepspeed-related code into the current main branch of Openflamingo”. Have you tried using this branch as is? The integration is basically complete but we are doing more testing to be certain. I will also test out stage 3 again to make sure we haven’t missed anything.

I did not run this branch directly, since I have developed my project on top of the main branch; I just copied the deepspeed-related code from this branch into my code. The error is very strange: 1) stage 2 works, but stage 3 reports the error; 2) the error occurs during the loss backward pass, which rarely raises errors; 3) I cannot tell which tensor has size 0 as reported. If you have no idea about this, I will have to run my code with DeepSpeed stage 2. Thanks!

@liyongqi67

You said you integrated “deepspeed-related code into the current main branch of Openflamingo”. Have you tried using this branch as is? The integration is basically complete but we are doing more testing to be certain. I will also test out stage 3 again to make sure we haven’t missed anything.

I tried this branch, and it works well for the training part. I also tested the evaluation part of the "Merge wilds mllm" branch. Unfortunately, there are some bugs; I reported two of them in #266.

Successfully merging this pull request may close these issues:

  • fsdp Error report
  • Can I train OpenFlamingo without LAION dataset?
  • support training on only LAION 2B