-
Notifications
You must be signed in to change notification settings - Fork 531
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable HF SpeedMonitor #997
base: main
Are you sure you want to change the base?
Conversation
Oops, some of these changes are for our internal use. Will remove them from here. |
Hey @rlrs, thanks for the contribution! I didn't know about this PyTorch flop counter! We'll want to do a bit of testing to make sure that this reports the correct number and doesn't cause any issues with (1) speed (2) memory usage or (3) bad interactions with distributed training strategies like FSDP. What testing of this have you been able to do yourself? |
Apologies for the lack of explanation or tests, I rushed this a bit. So far I've used this with Mistral 7B, comparing against the standard Transformer Math One uncertainty I have is how the FLOP counter interacts with non-PyTorch constructs like Flash Attention. I suspect that it might be necessary to register such code manually in order to get the correct result. If so, it might be silently underreporting FLOPs right now. |
No worries @rlrs! If you're able to do some testing (and add some unit tests) that would be great! Otherwise we'll look into it when we get a chance and appreciate the suggestion! |
Enable SpeedMonitor on HF models by using PyTorch FlopCounterMode to calculate model FLOPs.