
[low-bit optim] Fix edge cases for FSDP2 integration #1269

Merged: 10 commits merged into pytorch:main on Nov 26, 2024

Conversation

@gau-nernst (Collaborator) commented Nov 12, 2024

This PR fixes two issues that came up in torchtune when fine-tuning Llama3.2-vision.

1. Sometimes there is a strange torch.compile() error with DTensor when the parameter has a .grad field. See pytorch/torchtune#1978 (comment)

This is actually an old issue (#652; see #652 (comment) for more details). However, it does not always happen: our CI test passes on PyTorch 2.6 (though PyTorch 2.5 hits the issue), fine-tuning Llama text (not multimodal) models in torchtune works fine, but fine-tuning Llama3.2-vision hits the error. It's not clear why or how this happens. The error message suggests that torch.compile() attempts dynamic-shape compilation, even though we explicitly pass dynamic=False.

The solution is to call .detach() on the param: the detached tensor shares the same weight storage but no longer carries the .grad field (a short sketch of the pattern is shown after this item). Thanks to this, low-bit optim + FSDP2 now also works on the PyTorch 2.5 CI (previously it didn't).

I can't add a test for this, since I don't know how or when it happens.
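
For reference, here is a minimal sketch of the detach pattern described above (illustrative only, not the exact torchao code; the tensor shapes are made up for the example):

```python
import torch

# Illustrative sketch (not the exact torchao code): .detach() returns a tensor
# that shares the parameter's storage but does not carry the .grad field, so a
# compiled optimizer step can operate on a plain tensor without a grad attached.
param = torch.nn.Parameter(torch.randn(4, 4))
param.grad = torch.randn(4, 4)

p = param.detach()
assert p.data_ptr() == param.data_ptr()  # same underlying weight storage
assert p.grad is None                    # but no .grad on the detached view

# In-place updates through the detached view are reflected in the original param.
p.mul_(0.5)
```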

2. Wrong DTensor creation when there is uneven sharding (i.e. the 1st dim is not divisible by the world size)

Usually we don't have uneven shards for LLMs, so this error never surfaced. However, for ViT it can happen because of pos_embed: in some implementations pos_embed includes the CLS token, so the first dim is num_visual_tokens + 1.

The fix is simple: pass shape (and stride) to DTensor.from_local(); a sketch is shown below. An appropriate test has also been added.
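
For illustration, a hedged sketch of the idea, assuming a distributed run (e.g. launched with torchrun) and the public torch.distributed.tensor API from recent PyTorch (module paths and the exact from_local signature vary across versions); the ViT-like shape and variable names are made up for the example:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

# Assumes this runs under torchrun with a process group initialized.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
mesh = init_device_mesh(device_type, (dist.get_world_size(),))

# ViT-like pos_embed where dim 0 = num_visual_tokens + 1 (CLS token), so it is
# generally not divisible by the world size -> uneven shards.
full = torch.randn(197, 768, device=device_type)

# Each rank holds its (possibly smaller) chunk along dim 0.
local_chunk = full.chunk(mesh.size(), dim=0)[mesh.get_local_rank()]

sharded = DTensor.from_local(
    local_chunk,
    device_mesh=mesh,
    placements=[Shard(0)],
    shape=full.shape,      # pass the global shape explicitly...
    stride=full.stride(),  # ...and the global stride, so uneven shards are handled
)
assert sharded.shape == full.shape
```

Without the explicit shape/stride, from_local infers the global size from the local shard assuming even sharding, which gives the wrong global shape when the shards are uneven.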


pytorch-bot bot commented Nov 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1269

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 38bf355 with merge base 6ff3904:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Nov 12, 2024
@gau-nernst added the topic: bug fix label and removed the bug label Nov 12, 2024
@gau-nernst changed the title from [low-bit optim] Fix strange compiled Adam step + FSDP2 to [low-bit optim] Fix edge cases for FSDP2 integration Nov 13, 2024
@msaroufim requested a review from vkuzo November 14, 2024 04:17
@gau-nernst marked this pull request as ready for review November 14, 2024 06:41
@gau-nernst (Collaborator, Author) commented:

@vkuzo Do you have time to take a look at this PR? Thank you 🙏

@msaroufim merged commit b3493eb into pytorch:main Nov 26, 2024
18 checks passed
@gau-nernst deleted the optim_compile branch November 26, 2024 04:58