Skip to content

Commit

Permalink
Mark weights unbacked
Browse files Browse the repository at this point in the history
Summary: This is to avoid recompilations caused by the shape changes of `_weights` in KJT.

Differential Revision: D66342695
  • Loading branch information
Microve authored and facebook-github-bot committed Nov 22, 2024
1 parent 7ae70cd commit 8c76fdd
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion torchrec/pt2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,16 @@ def kjt_for_pt2_tracing(

values = kjt.values().long()
torch._dynamo.decorators.mark_unbacked(values, 0)
weights = kjt.weights_or_none()
if weights is not None:
weights = weights.float()
torch._dynamo.decorators.mark_unbacked(weights, 0)

return KeyedJaggedTensor(
keys=kjt.keys(),
values=values,
lengths=lengths,
weights=kjt.weights_or_none(),
weights=weights,
stride=stride if not is_vb else None,
stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None,
inverse_indices=inverse_indices,
Expand Down

0 comments on commit 8c76fdd

Please sign in to comment.