diff --git a/torchrec/pt2/utils.py b/torchrec/pt2/utils.py index e62a9a6a4..65e4bb35c 100644 --- a/torchrec/pt2/utils.py +++ b/torchrec/pt2/utils.py @@ -75,12 +75,16 @@ def kjt_for_pt2_tracing( values = kjt.values().long() torch._dynamo.decorators.mark_unbacked(values, 0) + weights = kjt.weights_or_none() + if weights is not None: + weights = weights.float() + torch._dynamo.decorators.mark_unbacked(weights, 0) return KeyedJaggedTensor( keys=kjt.keys(), values=values, lengths=lengths, - weights=kjt.weights_or_none(), + weights=weights, stride=stride if not is_vb else None, stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None, inverse_indices=inverse_indices,