From 8c76fddf0636f78fcf29e4313a313b7d49564d7f Mon Sep 17 00:00:00 2001 From: Shuai Yang Date: Thu, 21 Nov 2024 21:21:22 -0800 Subject: [PATCH] Mark weights unbacked Summary: This is to avoid recompilations caused by the shape changes of `_weights` in KJT. Differential Revision: D66342695 --- torchrec/pt2/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchrec/pt2/utils.py b/torchrec/pt2/utils.py index e62a9a6a4..65e4bb35c 100644 --- a/torchrec/pt2/utils.py +++ b/torchrec/pt2/utils.py @@ -75,12 +75,16 @@ def kjt_for_pt2_tracing( values = kjt.values().long() torch._dynamo.decorators.mark_unbacked(values, 0) + weights = kjt.weights_or_none() + if weights is not None: + weights = weights.float() + torch._dynamo.decorators.mark_unbacked(weights, 0) return KeyedJaggedTensor( keys=kjt.keys(), values=values, lengths=lengths, - weights=kjt.weights_or_none(), + weights=weights, stride=stride if not is_vb else None, stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None, inverse_indices=inverse_indices,