Add 24 compressor #167

Open · wants to merge 2 commits into base: add-targets-and-ignore-support
@@ -15,4 +15,5 @@

from .base import *
from .dense import *
from .sparse_24 import *
from .sparse_bitmask import *
92 changes: 92 additions & 0 deletions src/compressed_tensors/compressors/sparse_compressors/sparse_24.py
@@ -0,0 +1,92 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict

from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
from compressed_tensors.config import CompressionFormat, SparsityStructure
from compressed_tensors.utils import (
merge_names,
sparse_semi_structured_from_dense_cutlass,
sparse_semi_structured_to_dense_cutlass,
tensor_follows_mask_structure,
)
from torch import Tensor


@BaseCompressor.register(name=CompressionFormat.sparse_24.value)
class Sparse24Compressor(BaseSparseCompressor):
"""
    Compresses a model with a 2:4 sparsity structure for inference
    with sparse 2:4 kernels for float/float16/bfloat16.
https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/semi_structured.py
"""

COMPRESSION_PARAM_NAMES = ["sparse_24_packed_weight", "meta"]

@staticmethod
def validate_sparsity_structure(name: str, weight: Tensor) -> bool:
"""
Checks if a tensor fits the required 2:4 sparsity structure
:param name: name of the tensor to check
:param weight: tensor to check for sparsity structure
:return: True if all rows match the 2:4 sparsity structure, raises
ValueError otherwise
"""

if not tensor_follows_mask_structure(
weight, mask=SparsityStructure.TWO_FOUR.value
):
raise ValueError(
"Sparse24Compressor is only compatible with weights that have "
f"a 2:4 sparsity structure. Found segments in {name} "
"that do not match the expected structure."
)

return True

def compress_weight(self, name: str, value: Tensor) -> Dict[str, Tensor]:
"""
        Compresses a given weight with a 2:4 sparsity structure.
:param name: name of the tensor in state dict of uncompressed model
:param value: 2:4 sparse tensor to compress
:return: dictionary containing the compressed weight and associated
metadata
"""
weight_suffix = ".weight"
if not name.endswith(weight_suffix):
return {}

prefix = name[: -len(weight_suffix)]
self.validate_sparsity_structure(name=prefix, weight=value)
sparse_24_packed_weight, meta = sparse_semi_structured_from_dense_cutlass(
dense=value
)
return {
merge_names(name, "sparse_24_packed_weight"): sparse_24_packed_weight.cpu(),
merge_names(name, "meta"): meta.cpu(),
}

def decompress_weight(self, weight_data):
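        """
        Decompresses a 2:4 packed weight back to its dense representation
        :param weight_data: dictionary containing the packed weight
            ("sparse_24_packed_weight") and its reordered metadata ("meta")
        :return: dense weight tensor
        """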
assert (
"sparse_24_packed_weight" in weight_data
), "sparse_24_packed_weight not found in weight_data"
assert "meta" in weight_data, "meta not found in weight_data"

return sparse_semi_structured_to_dense_cutlass(
sparse=weight_data["sparse_24_packed_weight"],
meta_reordered=weight_data["meta"],
)
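
For reference, a minimal round-trip sketch of the new compressor (the layer name, shapes, default constructor, and the dot-joined keys produced by merge_names are illustrative assumptions, not part of this diff):

import torch
from compressed_tensors.compressors.sparse_compressors.sparse_24 import (
    Sparse24Compressor,
)

# Build a weight that satisfies 2:4 sparsity: at most two non-zero values
# in every contiguous group of four along each row.
weight = torch.randn(128, 128, dtype=torch.float16)
mask = torch.tensor([0, 0, 1, 1], dtype=torch.float16).tile(128, 128 // 4)
weight = weight * mask

compressor = Sparse24Compressor()  # assumes the default constructor suffices
compressed = compressor.compress_weight("layer.weight", weight)

# decompress_weight expects the bare parameter names as keys, so the merged
# prefix is stripped here (assuming merge_names joins names with a dot).
weight_data = {
    "sparse_24_packed_weight": compressed["layer.weight.sparse_24_packed_weight"],
    "meta": compressed["layer.weight.meta"],
}
restored = compressor.decompress_weight(weight_data)
assert torch.equal(restored, weight)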
1 change: 1 addition & 0 deletions src/compressed_tensors/config/__init__.py
@@ -15,4 +15,5 @@
# flake8: noqa
from .base import *
from .dense import *
from .sparse_24 import *
from .sparse_bitmask import *
1 change: 1 addition & 0 deletions src/compressed_tensors/config/base.py
@@ -26,6 +26,7 @@
class CompressionFormat(Enum):
dense = "dense"
sparse_bitmask = "sparse-bitmask"
sparse_24 = "sparse-24"
int_quantized = "int-quantized"
float_quantized = "float-quantized"
naive_quantized = "naive-quantized"
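
Because the new member's value is its serialized name, the registered string resolves back to the enum member; a quick sanity check (standard Enum behavior, nothing specific to this PR):

from compressed_tensors.config import CompressionFormat

assert CompressionFormat("sparse-24") is CompressionFormat.sparse_24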
37 changes: 37 additions & 0 deletions src/compressed_tensors/config/sparse_24.py
@@ -0,0 +1,37 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from compressed_tensors.config import (
CompressionFormat,
SparsityCompressionConfig,
SparsityStructure,
)


__all__ = ["Sparse24Config"]


@SparsityCompressionConfig.register(name=CompressionFormat.sparse_24.value)
class Sparse24Config(SparsityCompressionConfig):
"""
Configuration for storing a sparse model using 2:4 compression
:param global_sparsity: average sparsity of the entire model
:param sparsity_structure: structure of the sparsity, "2:4"
"""

format: str = CompressionFormat.sparse_24.value
global_sparsity: Optional[float] = 0.0
sparsity_structure: Optional[str] = SparsityStructure.TWO_FOUR.value
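
A small sketch of how the registered config could be constructed, directly or through the registry (load_from_registry is assumed to be available via the shared RegistryMixin, as with the other registered configs):

from compressed_tensors.config import Sparse24Config, SparsityCompressionConfig

# Direct construction
config = Sparse24Config(global_sparsity=0.5)
assert config.format == "sparse-24"
assert config.sparsity_structure == "2:4"

# Via the registry, under the name the class was registered with
config = SparsityCompressionConfig.load_from_registry(
    "sparse-24", global_sparsity=0.5, sparsity_structure="2:4"
)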
19 changes: 15 additions & 4 deletions src/compressed_tensors/utils/semi_structured_conversions.py
@@ -75,6 +75,7 @@ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device
# This function converts dense matrix into sparse semi-structured
# representation, producing "compressed" matrix, in the layout used by
# CUTLASS backend, and corresponding metadata matrix.
# Modified from https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/_semi_structured_conversions.py#L47
def sparse_semi_structured_from_dense_cutlass(dense):
if dense.dim() != 2:
raise RuntimeError(
@@ -85,7 +86,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):
device = dense.device

meta_dtype = torch.int8
if dense.dtype == torch.int8:
if dense.dtype == torch.int8 or dense.dtype == torch.float8_e4m3fn:
meta_dtype = torch.int32
elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
meta_dtype = torch.int16
@@ -165,11 +166,15 @@ def sparse_semi_structured_from_dense_cutlass(dense):
idxs1 = bit2 | (bit3.to(torch.int64) << 1)

if dense.dtype != torch.float:
if dense.dtype == torch.float8_e4m3fn:
dense_4 = dense_4.view(torch.int8)
sparse0 = dense_4.gather(
-1, idxs0.unsqueeze(-1)
) # type: ignore[possibly-undefined]
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
if dense.dtype == torch.float8_e4m3fn:
sparse = sparse.view(torch.float8_e4m3fn)
else:
sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
m, k // 2
@@ -213,6 +218,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):
# reconstructs dense matrix from a pair of "compressed" matrix, given
# in the layout used by CUTLASS backend, and accompanying metadata
# matrix.
# Copied from https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/_semi_structured_conversions.py#L180
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
if sparse.dim() != 2:
raise RuntimeError(
@@ -298,16 +304,21 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
torch.arange(0, 2 * m * k // ksparse, device=device) * 4
).view(-1, 1).repeat(1, 2).view(-1)

dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
sparse_dtype = sparse.dtype if sparse.dtype != torch.float8_e4m3fn else torch.int8
dense = torch.zeros((m * 2 * k,), dtype=sparse_dtype, device=device)
if sparse.dtype != torch.float:
# dense.scatter_(0, dense_offsets, sparse.view(-1))
dense.scatter_(0, dense_offsets, sparse.reshape(-1))
if sparse.dtype == torch.float8_e4m3fn:
dense.scatter_(0, dense_offsets, sparse.view(torch.int8).view(-1))
else:
dense.scatter_(0, dense_offsets, sparse.reshape(-1))
else:
dense.view(torch.half).scatter_(
0, dense_offsets, sparse.view(torch.half).view(-1)
)

return dense.view(m, 2 * k)
result = dense.view(m, 2 * k)
return result.view(sparse.dtype)


def mask_creator(tensor):
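
The fp8 additions above all use the same trick: gather and scatter have no float8_e4m3fn kernels, so the tensor is reinterpreted as int8 (same bit width) for indexing and viewed back afterwards, leaving the underlying bits untouched. A standalone sketch of that reinterpretation (CUDA is assumed, since some fp8 operations are unsupported on CPU in older PyTorch builds):

import torch

x = torch.rand(4, 8, dtype=torch.float16).cuda().to(torch.float8_e4m3fn)
idx = torch.tensor([[0, 2, 5, 7]] * 4, device=x.device)

# view(torch.int8) is a bit-level reinterpretation in both directions,
# so no fp8 arithmetic is ever performed during the gather.
picked = x.view(torch.int8).gather(-1, idx).view(torch.float8_e4m3fn)

assert picked.dtype == torch.float8_e4m3fn
assert picked.shape == (4, 4)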
71 changes: 71 additions & 0 deletions tests/test_utils/test_semi_structured_conversions.py
@@ -0,0 +1,71 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from compressed_tensors.utils.semi_structured_conversions import (
sparse_semi_structured_from_dense_cutlass,
sparse_semi_structured_to_dense_cutlass,
)


def supported_dtypes():
dtypes = [torch.int8, torch.float16, torch.bfloat16]
if torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability()
        if (major, minor) >= (9, 0):  # run fp8 cases only on compute capability >= 9.0
dtypes += [torch.float8_e4m3fn]
return dtypes


def get_random_mat(M, K, dtype):
rand_tensor_dtype = dtype
if dtype in [torch.int8, torch.float8_e4m3fn]:
rand_tensor_dtype = torch.float16
mat = torch.rand(M, K, dtype=rand_tensor_dtype).cuda()
mat = mat.masked_fill_(mat == 0, 1)
return mat.to(dtype)


def generate_pruned_semi_structured_mat(M, K, dtype):
mask = torch.Tensor([0, 0, 1, 1]).tile((M, K // 4)).bool()
rand_tensor_dtype = dtype
if dtype in [torch.int8, torch.float8_e4m3fn]:
rand_tensor_dtype = torch.float16
mat = torch.rand(M, K, dtype=rand_tensor_dtype)
mat = mat.masked_fill_(mat == 0, 1)
if dtype == torch.float8_e4m3fn:
# some float8_e4m3fn operations are not supported on CPU
mat = mat.cuda()
mask = mask.cuda()
mat = mat * mask
return mat.to(dtype)


@pytest.mark.parametrize("dtype", supported_dtypes())
def test_inverse_property_from_dense_then_to_dense(dtype):
M, K = 1024, 1024
dense_matrix = generate_pruned_semi_structured_mat(M, K, dtype)
compressed_matrix, meta = sparse_semi_structured_from_dense_cutlass(dense_matrix)
result = sparse_semi_structured_to_dense_cutlass(compressed_matrix, meta)

assert (
dense_matrix.dtype == result.dtype
), f"Dtype Mis-match: {dense_matrix.dtype} and {result.dtype}"
assert (
dense_matrix.shape == result.shape
), f"Shape Mis-match: {dense_matrix.shape} and {result.shape}"
assert torch.equal(
dense_matrix, result
), f"Failed for dtype: {dense_matrix.dtype} and input: {dense_matrix}"