Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torchrec change for dynamic embedding #2533

Open
wants to merge 1 commit into
base: release/v0.7.0
Choose a base branch
from

Conversation

kanghui0204
Copy link

@kanghui0204 kanghui0204 commented Nov 4, 2024

Hi TorchREC experts,

We would like to try incorporating NVIDIA HKV into the existing TorchREC workflow to extend TorchREC's capabilities for model-parallel dynamic embedding.

We aim to integrate HKV dynamic embedding as a new type of embedding table into the TorchREC workflow. To avoid disrupting the original TorchREC code, we have designed some code for registering new embedding tables, which will help us and other users to better register a customized embedding table into the TorchREC workflow. Our modifications mainly target the following two parts:

  1. Registering a new customized compute table during the creation of the embedding table and lookup, and accepting its customized parameters.

  2. Since the range of indices for dynamic embedding is unlimited, we need the input distribution to perform round-robin distribution.(Our current PR serves as a reference. For example, in the input dist section, we have only modified the RW code. However, it is necessary to support all sharding types, such as TWRW)

Our code is based on v0.7, and it can be easily migrated to the latest code. We are initiating this PR as a reference for further discussions with you. We hope to support a high-performance dynamic embedding feature.

@facebook-github-bot
Copy link
Contributor

Hi @kanghui0204!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@dstaay-fb
Copy link
Contributor

Thanks for proposal;

RE [1]: it would help to put together a toy example of what your doing potentially; so we can see how you intend to use this api; ideally to point you could create a few multi-gpu tests (with appropriate mocking if needed etc).

RE [2]: So its not well documented, but we actually can support round robin based RW sharding today; its utilized in ZCH workflows (bucketization strategy is % world_size). Basically if you pass in RwSparseFeaturesDist(.., feature_hash_dim = [0,....0]) this will trigger this logic. This calls into FBGEMM block_bucketization kernels. Coincidently just added logic in this area, take a look at tests in PR: #2538 - specifically the case we set input_hash_size=0 on ZCH modules for full behavior (albeit a different use case).

@kanghui0204
Copy link
Author

kanghui0204 commented Nov 7, 2024

Hi @dstaay-fb thank you very much for quickly reply!

RE1: I will prepare a example for you as a reference as soon as possible.
RE2: Sorry , I didn't find the test for input_hash_size=0 on ZCH modules in PR2538,

Do you mean that setting the hash size of each table to 0 will make the block_bucketize_sparse_features in FBGEMM switch from contiguous block partitioning to round-robin partitioning? It looks like we need to modify the information of sharding_infos input to BaseRwEmbeddingSharding(https://github.com/dstaay-fb/torchrec/blob/export-D62483238/torchrec/distributed/sharding/rw_sharding.py#L115), is that correct?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants