diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 4b5359f0d..243a58ee0 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -3380,14 +3380,21 @@ def _kt_unflatten( return KeyedTensor(context[0], context[1], values[0]) +print_flatten_spec_warn = True + + def _kt_flatten_spec(kt: KeyedTensor, spec: TreeSpec) -> List[torch.Tensor]: _keys, _length_per_key = spec.context # please read https://fburl.com/workplace/8bei5iju for more context, # you can also consider use short_circuit_pytree_ebc_regroup with KTRegroupAsDict - logger.warning( - "KT's key order might change from spec from the torch.export, this could have perf impact. " - f"{kt.keys()} vs {_keys}" - ) + global print_flatten_spec_warn + if print_flatten_spec_warn: + logger.warning( + "KT's key order might change from spec from the torch.export, this could have perf impact. " + f"{kt.keys()} vs {_keys}" + ) + print_flatten_spec_warn = False + res = permute_multi_embedding([kt], [_keys]) return [res[0]]