add stride into KJT pytree
Summary:
# context
* Previously for a KJT, only the following tensor fields plus `_keys` were stored in the pytree flatten specs; all other arguments/parameters were derived from them.
```
    _fields = [
        "_values",
        "_weights",
        "_lengths",
        "_offsets",
    ]
```
* In particular, the `stride` (an int) of a KJT, which represents its `batch_size`, is computed by `_maybe_compute_stride_kjt`:
```
def _maybe_compute_stride_kjt(
    keys: List[str],
    stride: Optional[int],
    lengths: Optional[torch.Tensor],
    offsets: Optional[torch.Tensor],
    stride_per_key_per_rank: Optional[List[List[int]]],
) -> int:
    if stride is None:
        if len(keys) == 0:
            stride = 0
        elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
            stride = max([sum(s) for s in stride_per_key_per_rank])
        elif offsets is not None and offsets.numel() > 0:
            stride = (offsets.numel() - 1) // len(keys)
        elif lengths is not None:
            stride = lengths.numel() // len(keys)
        else:
            stride = 0
    return stride
```
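The fallback order above can be mimicked in plain Python. This is a hedged sketch: `compute_stride` is a hypothetical helper (not part of torchrec) that takes element counts in place of the tensors, but it follows the same branch order as `_maybe_compute_stride_kjt`:

```python
from typing import List, Optional


def compute_stride(
    keys: List[str],
    stride: Optional[int] = None,
    lengths_numel: Optional[int] = None,
    offsets_numel: Optional[int] = None,
    stride_per_key_per_rank: Optional[List[List[int]]] = None,
) -> int:
    # Mirrors _maybe_compute_stride_kjt, using element counts instead of tensors.
    if stride is not None:
        return stride
    if len(keys) == 0:
        return 0
    if stride_per_key_per_rank:
        # variable batch: the stride is the largest per-key total batch size
        return max(sum(s) for s in stride_per_key_per_rank)
    if offsets_numel:
        # offsets has one leading 0, then one entry per (key, sample) pair
        return (offsets_numel - 1) // len(keys)
    if lengths_numel is not None:
        # lengths has one entry per (key, sample) pair
        return lengths_numel // len(keys)
    return 0


# 2 keys, 6 length entries -> batch size 3
print(compute_stride(["f1", "f2"], lengths_numel=6))
# variable batch: per-key totals 2+3=5 and 1+4=5 -> stride 5
print(compute_stride(["f1", "f2"], stride_per_key_per_rank=[[2, 3], [1, 4]]))
```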
* The previously stored pytree flatten specs are sufficient when the `batch_size` is static; however, this no longer holds in a variable-batch-size scenario, where `stride_per_key_per_rank` is not `None`.
* One example is `dedup_ebc`, where the actual batch_size varies (depending on the dedup data), but the output of the EBC should always use the **true** (static) `stride`.
* During ir_export, the output shape is calculated from the `kjt.stride()` function, which would be incorrect if the pytree specs contained only the `keys`.
* This diff adds the `stride` into the KJT pytree flatten/unflatten functions so that a fakified KJT would have the correct stride value.
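The failure mode described above can be shown with a toy calculation (made-up numbers, plain Python; this mirrors the arithmetic in `_maybe_compute_stride_kjt`, not the real KJT class):

```python
# Toy variable-batch KJT: key "f1" has batch size 2, key "f2" has batch size 4.
keys = ["f1", "f2"]
stride_per_key_per_rank = [[2], [4]]

# The true stride, as _maybe_compute_stride_kjt derives it when the
# per-key strides are available:
true_stride = max(sum(s) for s in stride_per_key_per_rank)  # max(2, 4) == 4

# But stride_per_key_per_rank is NOT in the old flatten spec, so an
# unflattened (e.g. fakified) KJT would fall back to the lengths tensor:
lengths_numel = sum(sum(s) for s in stride_per_key_per_rank)  # 2 + 4 == 6
fallback_stride = lengths_numel // len(keys)  # 6 // 2 == 3, not 4

assert fallback_stride != true_stride  # wrong output shape at ir_export
```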

Differential Revision: D66400821
TroyGarden authored and facebook-github-bot committed Nov 23, 2024
1 parent 7f3b7dc commit 8b5124f
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions torchrec/sparse/jagged_tensor.py
```diff
@@ -3026,13 +3026,17 @@ def dist_init(
 
 def _kjt_flatten(
     t: KeyedJaggedTensor,
-) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
-    return [getattr(t, a) for a in KeyedJaggedTensor._fields], t._keys
+) -> Tuple[List[Optional[torch.Tensor]], Tuple[List[str], int]]:
+    # for variable batch scenario, the stride cannot be computed from lengths/len(keys),
+    # instead, it should be computed from stride_per_key_per_rank, which is not included
+    # in the flatten spec. The stride is needed for the EBC output shape, so we need to
+    # store it in the context.
+    return [getattr(t, a) for a in KeyedJaggedTensor._fields], (t._keys, t.stride())
 
 
 def _kjt_flatten_with_keys(
     t: KeyedJaggedTensor,
-) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], List[str]]:
+) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], Tuple[List[str], int]]:
     values, context = _kjt_flatten(t)
     # pyre can't tell that GetAttrKey implements the KeyEntry protocol
     return [  # pyre-ignore[7]
@@ -3041,9 +3045,11 @@ def _kjt_flatten_with_keys(
 
 
 def _kjt_unflatten(
-    values: List[Optional[torch.Tensor]], context: List[str]  # context is the _keys
+    values: List[Optional[torch.Tensor]],
+    context: Tuple[List[str], int],  # context is (_keys, _stride)
 ) -> KeyedJaggedTensor:
-    return KeyedJaggedTensor(context, *values)
+    keys, stride = context
+    return KeyedJaggedTensor(keys, *values, stride=stride)
 
 
 def _kjt_flatten_spec(
```
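The new flatten/unflatten contract can be sketched as a minimal pure-Python roundtrip (no torch; `ToyKJT` and the `toy_*` helpers are illustrative stand-ins, not the real `KeyedJaggedTensor` or torchrec API):

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


@dataclass
class ToyKJT:
    # stand-in for KeyedJaggedTensor with just the pytree-relevant fields
    _keys: List[str]
    _values: List[float]
    _lengths: Optional[List[int]] = None
    _stride: Optional[int] = None

    def stride(self) -> int:
        if self._stride is not None:
            return self._stride  # explicit (true) stride wins
        assert self._lengths is not None
        return len(self._lengths) // len(self._keys)  # lengths-based fallback


def toy_flatten(t: ToyKJT) -> Tuple[List[Any], Tuple[List[str], int]]:
    # after the diff: the context carries (keys, stride), not just keys
    return [t._values, t._lengths], (t._keys, t.stride())


def toy_unflatten(values: List[Any], context: Tuple[List[str], int]) -> ToyKJT:
    keys, stride = context
    return ToyKJT(keys, values[0], values[1], _stride=stride)


# variable batch: the true stride 4 is pinned explicitly on the KJT
kjt = ToyKJT(["f1", "f2"], [1.0] * 6, [1, 1, 1, 1, 1, 1], _stride=4)
values, ctx = toy_flatten(kjt)
rebuilt = toy_unflatten(values, ctx)
assert rebuilt.stride() == 4  # lengths alone would have given 6 // 2 == 3
```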
