Support for targets and ignore in Sparsity Compressors #182

Open
wants to merge 3 commits into base: main

Changes from all commits
@@ -18,7 +18,7 @@
import os
import re
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union

import compressed_tensors
import torch
@@ -39,6 +39,7 @@
apply_quantization_config,
load_pretrained_quantization,
)
from compressed_tensors.quantization.lifecycle import expand_targets
from compressed_tensors.quantization.utils import (
is_module_quantized,
iter_named_leaf_modules,
@@ -276,8 +277,13 @@ def compress(
)

if self.sparsity_compressor is not None:
sparse_compression_targets: Set[str] = expand_targets(
model=model,
targets=self.sparsity_config.targets,
ignore=self.sparsity_config.ignore,
)
compressed_state_dict = self.sparsity_compressor.compress(
compressed_state_dict
compressed_state_dict, compression_targets=sparse_compression_targets
)

# HACK: Override the dtype_byte_size function in transformers to
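Usage sketch (illustrative, not part of this diff): how the new flow resolves sparsity targets before handing the state dict to the sparsity compressor. The model, sparsity_compressor, and the targets/ignore values below are hypothetical stand-ins for what ModelCompressor reads from its sparsity config.

    from compressed_tensors.quantization.lifecycle import expand_targets

    # Resolve which leaf modules the sparsity compressor should touch
    sparse_compression_targets = expand_targets(
        model=model,               # e.g. AutoModelForCausalLM.from_pretrained(...)
        targets=["Linear"],        # hypothetical sparsity_config.targets
        ignore=["re:.*lm_head"],   # hypothetical sparsity_config.ignore
    )

    # Parameters whose module prefix is not in the target set are passed through unchanged
    compressed_state_dict = sparsity_compressor.compress(
        model.state_dict(),
        compression_targets=sparse_compression_targets,
    )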
38 changes: 34 additions & 4 deletions src/compressed_tensors/compressors/sparse_compressors/base.py
@@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Dict, Generator, Tuple
from typing import Dict, Generator, Optional, Set, Tuple

from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
@@ -59,18 +59,27 @@ class BaseSparseCompressor(BaseCompressor):
:param config: config specifying compression parameters
"""

def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
def compress(
self,
model_state: Dict[str, Tensor],
compression_targets: Optional[Set[str]] = None,
) -> Dict[str, Tensor]:
"""
Compresses a dense state dict using bitmask compression

:param model_state: state dict of uncompressed model
:param compression_targets: optional set of layer prefixes to compress; if None,
compress all layers (for backwards compatibility)
Member
What are we holding backwards compatibility with? Ideally this should default to only compressing models where we detect the 50% sparsity threshold.

Member Author
For older configs, we will not have targets; this handles those cases.

For the newer flow, compression_targets will only contain modules that are over 50% sparse once this lands: vllm-project/llm-compressor#822

:return: compressed state dict
"""
compressed_dict = {}
_LOGGER.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)
for name, value in tqdm(model_state.items(), desc="Compressing model"):
if not self.should_compress(name, compression_targets):
compressed_dict[name] = value
continue
compression_data = self.compress_weight(name, value)
for key in compression_data.keys():
if key in compressed_dict:
@@ -97,8 +106,10 @@ def decompress(
:param device: device to load decompressed weights onto
:return: iterator for generating decompressed weights
"""
weight_mappings = get_nested_weight_mappings(
path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
weight_mappings, other_params = get_nested_weight_mappings(
path_to_model_or_tensors,
self.COMPRESSION_PARAM_NAMES,
return_other_params=True,
)
for weight_name in weight_mappings.keys():
weight_data = {}
@@ -108,3 +119,22 @@
weight_data[param_name] = f.get_tensor(full_name)
decompressed = self.decompress_weight(weight_data)
yield weight_name, decompressed

for other_name, safe_path in other_params.items():
with safe_open(safe_path, framework="pt", device=device) as f:
value = f.get_tensor(other_name)
yield other_name, value

@staticmethod
def should_compress(name: str, targets: Optional[Set[str]] = None) -> bool:
"""
Check if a parameter should be compressed

:param name: name of the parameter
:param targets: set of layer prefixes to compress
:return: whether or not the parameter should be compressed
"""
if targets is None:
return name.endswith(".weight")

return name.endswith(".weight") and name[: -(len(".weight"))] in targets
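Usage sketch (illustrative, not part of this diff): should_compress is a staticmethod, so the filtering rule can be exercised directly. The import path below assumes the module layout shown in the file header above.

    from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor

    name = "model.layers.0.self_attn.q_proj.weight"

    # No targets: every ".weight" parameter is compressed (legacy configs without targets)
    assert BaseSparseCompressor.should_compress(name)

    # With targets: only ".weight" parameters whose module prefix is targeted
    assert BaseSparseCompressor.should_compress(name, {"model.layers.0.self_attn.q_proj"})
    assert not BaseSparseCompressor.should_compress(name, {"model.layers.1.self_attn.q_proj"})

    # Non-weight parameters are never compressed
    assert not BaseSparseCompressor.should_compress("model.layers.0.self_attn.q_proj.bias")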
26 changes: 25 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/apply.py
@@ -18,7 +18,7 @@
from copy import deepcopy
from typing import Dict, Iterable, List, Optional
from typing import OrderedDict as OrderedDictType
from typing import Union
from typing import Set, Union

import torch
from compressed_tensors.config import CompressionFormat
@@ -56,6 +56,7 @@
"apply_quantization_config",
"apply_quantization_status",
"find_name_or_class_matches",
"expand_targets",
]

from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -280,6 +281,29 @@ def find_name_or_class_matches(
return matches


def expand_targets(
model: Module, targets: Iterable[str], ignore: Iterable[str]
) -> Set[str]:
"""
Finds all targets (module names) in the model that match the given targets list and are not matched by the ignore list

Note: Targets must be regexes, layer types, or full layer names

:param model: model to search for targets in
:param targets: list of targets to search for
:param ignore: list of targets to ignore
:return: set of all targets that match the given targets and should
not be ignored
"""
current_targets = set()
for name, module in iter_named_leaf_modules(model):
if find_name_or_class_matches(
name, module, targets
) and not find_name_or_class_matches(name, module, ignore):
current_targets.add(name)
return current_targets


def _find_matches(
value: str, targets: Iterable[str], check_contains: bool = False
) -> List[str]:
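Usage sketch (illustrative, not part of this diff): the three target forms the docstring allows (layer type, regex, and full layer name), applied to a throwaway model. The module names "0" and "1" come from torch.nn.Sequential.

    import torch
    from compressed_tensors.quantization.lifecycle import expand_targets

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

    expand_targets(model, targets=["Linear"], ignore=[])     # by layer type -> {"0", "1"}
    expand_targets(model, targets=["re:.*1$"], ignore=[])    # by regex -> {"1"}
    expand_targets(model, targets=["0", "1"], ignore=["0"])  # by full name, minus ignore -> {"1"}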
31 changes: 25 additions & 6 deletions src/compressed_tensors/utils/safetensors_load.py
@@ -16,7 +16,7 @@
import os
import re
import struct
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple, Union

from safetensors import safe_open
from torch import Tensor
@@ -34,6 +34,9 @@
"is_quantization_param",
]

WEIGHT_MAPPING_TYPE = Dict[str, str]
NESTED_WEIGHT_MAPPING_TYPE = Dict[str, WEIGHT_MAPPING_TYPE]


def get_safetensors_folder(
pretrained_model_name_or_path: str, cache_dir: Optional[str] = None
@@ -176,8 +179,10 @@ def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]:


def get_nested_weight_mappings(
model_path: str, params_to_nest: List[str]
) -> Dict[str, Dict[str, str]]:
model_path: str, params_to_nest: List[str], return_other_params: bool = False
) -> Union[
NESTED_WEIGHT_MAPPING_TYPE, Tuple[NESTED_WEIGHT_MAPPING_TYPE, WEIGHT_MAPPING_TYPE]
]:
"""
Takes a path to a state dict saved in safetensors format and returns a nested
mapping from uncompressed parameterized layer names to the file locations of each
@@ -193,22 +198,36 @@ def get_nested_weight_mappings(
This generalizes to cases where the model is split into multiple safetensors files

:param model_path: path to safetensors state dict, must contain either a single
safetensors file or multiple files with an index
:return: nested mapping of parameterized layer name to file location
safetensors file or multiple files with an index
:param return_other_params: if True, return a second dictionary containing the
remaining parameters that were not matched to the nested parameters
:return: nested mapping of parameterized layer name to file location if
return_other_params is False, else a tuple containing the nested mapping
and a mapping of the remaining parameters that were not matched to
the nested parameters
"""
weight_mappings = get_weight_mappings(model_path)
other_params = {}

nested_weight_mappings = {}
for key in weight_mappings.keys():
matched = False
for param_name in params_to_nest:
maybe_match = match_param_name(key, param_name)
if maybe_match is not None:
dense_param = maybe_match
if dense_param not in nested_weight_mappings:
nested_weight_mappings[dense_param] = {}
matched = True
nested_weight_mappings[dense_param][param_name] = weight_mappings[key]
if not matched:
other_params[key] = weight_mappings[key]

return nested_weight_mappings
return (
nested_weight_mappings
if not return_other_params
else (nested_weight_mappings, other_params)
)


def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]:
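Usage sketch (illustrative, not part of this diff): the new return_other_params flag splits the flat mapping into nested compression parameters and everything that did not match. The path and params_to_nest values below are hypothetical; real callers pass the compressor's COMPRESSION_PARAM_NAMES.

    from compressed_tensors.utils import get_nested_weight_mappings

    nested, other = get_nested_weight_mappings(
        "/path/to/compressed-model",      # folder with *.safetensors (plus optional index)
        params_to_nest=["shape", "compressed", "bitmask"],
        return_other_params=True,
    )
    # nested: {"<dense param name>": {"compressed": "<file>", "shape": "<file>", ...}, ...}
    # other:  {"<unmatched param name>": "<file>", ...}  # e.g. params of layers left uncompressed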
27 changes: 27 additions & 0 deletions tests/test_quantization/lifecycle/test_apply.py
@@ -27,11 +27,20 @@
from compressed_tensors.quantization.lifecycle import (
apply_quantization_config,
apply_quantization_status,
expand_targets,
)
from compressed_tensors.quantization.utils import iter_named_leaf_modules
from transformers import AutoModelForCausalLM


@pytest.fixture
def model():
return AutoModelForCausalLM.from_pretrained(
"Xenova/llama2.c-stories15M",
torch_dtype="auto",
)


def test_target_prioritization():
# tests that the config_groups are applied in the correct order
# of priority, where exact layer name > regex > module name
@@ -272,3 +281,21 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning):
assert len(caplog.text) > 0
else:
assert len(caplog.text) == 0


@pytest.mark.parametrize(
"targets, ignore, expected",
[
# ignore all
(["Linear"], ["Linear"], set()),
# ignore subset
(
["re:model.layers.[01].self_attn.q_proj"],
["re:model.layers.1.self_attn.q_proj"],
set(["model.layers.0.self_attn.q_proj"]),
),
],
)
def test_expand_targets(model, targets, ignore, expected):
actual_targets = expand_targets(model, targets, ignore)
assert actual_targets == expected