FlashAttention Triton error on the MosaicBERT models other than base #441
Hi @Taytay, thanks for your comment! This code was written in the early days of Triton FlashAttention, and we are in the process of updating to FlashAttention 2 with support for ALiBi (see #440, which updates to PyTorch 2). The config value of `alibi_starting_size` was understandably confusing for a lot of people, so I have updated the default in the Hugging Face config.
Your code should now work with the updated config:

```python
import transformers
from transformers import AutoModelForMaskedLM, pipeline

config = transformers.BertConfig.from_pretrained('mosaicml/mosaic-bert-base-seqlen-2048')
config.alibi_starting_size = 2048  # maximum sequence length, updated to 2048 from the config default
mlm = AutoModelForMaskedLM.from_pretrained('mosaicml/mosaic-bert-base-seqlen-2048', trust_remote_code=True, config=config)
mlm.to("cuda")
tokenizer = transformers.AutoTokenizer.from_pretrained('bert-base-uncased')  # not defined in the original snippet; MosaicBERT uses the standard BERT tokenizer
classifier = pipeline('fill-mask', model=mlm, tokenizer=tokenizer, device="cuda")
classifier("I [MASK] to the store yesterday.")
```
Thank you for the thorough explanation! I tried training yesterday and ran into some more errors, so #440 is especially welcome! (I have a model I need to train on a large number of tokens, so the perf improvements are going to be particularly helpful.)
When I try to run MosaicBERT like this:
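A minimal reconstruction of the kind of call that triggers it (the model name is illustrative; any non-base checkpoint with the default config reproduces the problem):

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline

model_name = 'mosaicml/mosaic-bert-base-seqlen-2048'  # any non-base MosaicBERT checkpoint

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # MosaicBERT uses the standard BERT tokenizer
mlm = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)

classifier = pipeline('fill-mask', model=mlm, tokenizer=tokenizer, device="cuda")
classifier("I [MASK] to the store yesterday.")
```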
I get a FlashAttention Triton error.
This appears to have been fixed a few days ago by @jacobfulano in the mosaic-bert-base repo:
https://huggingface.co/mosaicml/mosaic-bert-base/blob/ed2a544063a892b78823cba2858d1e098c0e6012/config.json
It looks like that removes FlashAttention? Does that mean that the speed increase from FA is also removed?
Here's how I fixed it in the meantime, in case someone else Googles their way to this issue:
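A minimal sketch of the workaround. It assumes, based on the config change linked above, that the Triton FlashAttention path is only taken when attention dropout is 0.0, so setting a nonzero `attention_probs_dropout_prob` forces the model onto the standard PyTorch attention fallback:

```python
import transformers
from transformers import AutoModelForMaskedLM

model_name = 'mosaicml/mosaic-bert-base-seqlen-2048'  # illustrative; any affected checkpoint

config = transformers.BertConfig.from_pretrained(model_name)
# Assumption: nonzero attention dropout disables the Triton FlashAttention
# path, so the model falls back to (slower) PyTorch attention.
config.attention_probs_dropout_prob = 0.1

mlm = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True, config=config)
```

The trade-off is exactly the question above: the PyTorch fallback gives up the FlashAttention speedup.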