From 117f39fffd0ff87c65b9c008f69d5eb5c2061656 Mon Sep 17 00:00:00 2001 From: Shuai Yang Date: Fri, 22 Nov 2024 09:29:28 -0800 Subject: [PATCH] Mark weights unbacked (#2583) Summary: This is to avoid recompilations caused by the shape changes of `_weights` in KJT. Differential Revision: D66342695 --- torchrec/pt2/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchrec/pt2/utils.py b/torchrec/pt2/utils.py index e62a9a6a4..55accff68 100644 --- a/torchrec/pt2/utils.py +++ b/torchrec/pt2/utils.py @@ -75,12 +75,15 @@ def kjt_for_pt2_tracing( values = kjt.values().long() torch._dynamo.decorators.mark_unbacked(values, 0) + weights = kjt.weights_or_none() + if weights is not None: + torch._dynamo.decorators.mark_unbacked(weights, 0) return KeyedJaggedTensor( keys=kjt.keys(), values=values, lengths=lengths, - weights=kjt.weights_or_none(), + weights=weights, stride=stride if not is_vb else None, stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None, inverse_indices=inverse_indices,