add ShortGPT #20

Merged: 2 commits, Aug 2, 2024
30 changes: 30 additions & 0 deletions configs/sparsification/ShortGPT/shortgpt.yml
@@ -0,0 +1,30 @@
base:
seed: &seed 42
model:
type: Llama
path: model path
torch_dtype: auto
calib:
name: pileval
download: False
path: calib data path
n_samples: 128
bs: -1
seq_len: 512
preproc: general
seed: *seed
eval:
eval_pos: [transformed]
name: [wikitext2, c4]
download: False
path: eval data path
seq_len: 2048
sparse:
method: ShortGPT
weight:
n_prune_layers: 9
save:
save_trans: True
save_fp: False
save_lightllm: False
save_path: ./save
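
For orientation, the keys that matter most for ShortGPT are sparse.method and sparse.weight.n_prune_layers (the number of transformer blocks to delete); the "model path", "calib data path", and "eval data path" entries are placeholders to be filled in. A minimal standalone sketch of reading these fields, assuming only PyYAML and that the file sits at the path shown above:

import yaml

# Load the config added in this PR (path as in the diff above); the
# placeholder paths must be replaced before a real run.
with open('configs/sparsification/ShortGPT/shortgpt.yml') as f:
    cfg = yaml.safe_load(f)

# Number of blocks ShortGPT will prune, and where the pruned model is saved.
print(cfg['sparse']['method'])                    # ShortGPT
print(cfg['sparse']['weight']['n_prune_layers'])  # 9
print(cfg['save']['save_path'])                   # ./save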
2 changes: 1 addition & 1 deletion docs/en/source/advanced/sparsification.md
@@ -1,6 +1,6 @@
# Model Sparsification

The llmc is currently gradually supporting sparse methods, having already implemented Magnitude and Wanda, and will support more algorithms in the future.
The llmc is currently gradually supporting sparse methods, having already implemented Magnitude, Wanda, and ShortGPT, and will support more algorithms in the future.

Here is a sample of Wanda's settings:

2 changes: 1 addition & 1 deletion docs/zh_cn/source/advanced/sparsification.md
@@ -1,6 +1,6 @@
# 模型稀疏化

llmc目前正在逐渐支持稀疏化方法,目前已经实现了Magnitude和Wanda,将在未来支持更多的算法
llmc目前正在逐渐支持稀疏化方法,目前已经实现了Magnitude,Wanda和ShortGPT,将在未来支持更多的算法

以下是Wanda的设置样例:

6 changes: 1 addition & 5 deletions llmc/__main__.py
@@ -47,7 +47,6 @@ def main(config):
for ppl_eval in eval_list:
ppl = ppl_eval.eval(model)
logger.info(f'{ppl_eval.dataset} ppl : {ppl}')
sparsification = None
if not config.get('calib', False):
blockwise_opt = ALGO_REGISTRY[config.quant.method](
model, quant_config=config.quant, input=None, config=config
@@ -61,20 +60,17 @@
gc.collect()
torch.cuda.empty_cache()
if not config.get('sparse', False):
sparsification = False
blockwise_opt = ALGO_REGISTRY[config.quant.method](
model, config.quant, model.get_first_block_input(), config
)
else:
sparsification = True
blockwise_opt = ALGO_REGISTRY[config.sparse.method](
model, config.sparse, model.get_first_block_input(), config
)
blockwise_opt.run_block_loop()

if 'eval' in config and 'transformed' in config.eval.eval_pos:
if not sparsification:
blockwise_opt.deploy('origin_float')
blockwise_opt.deploy('origin_float')
for ppl_eval in eval_list:
ppl = ppl_eval.eval(model)
logger.info(f'{ppl_eval.dataset} ppl : {ppl}')
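
Both branches resolve the algorithm class by name through ALGO_REGISTRY[...], so registering ShortGPT only requires decorating the new class and importing it. A hypothetical, minimal sketch of how such a registry can work; the real llmc.utils.registry_factory may be implemented differently:

# Hypothetical minimal registry, illustrating how ALGO_REGISTRY['ShortGPT']
# can resolve a method name from the YAML config to the decorated class.
class Registry:
    def __init__(self):
        self._classes = {}

    def __call__(self, cls):          # used as a decorator: @ALGO_REGISTRY
        self._classes[cls.__name__] = cls
        return cls

    def __getitem__(self, name):      # used as ALGO_REGISTRY[config.sparse.method]
        return self._classes[name]

ALGO_REGISTRY = Registry()

@ALGO_REGISTRY
class ShortGPT:
    pass

print(ALGO_REGISTRY['ShortGPT'])      # <class '__main__.ShortGPT'>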
1 change: 1 addition & 0 deletions llmc/compression/sparsification/__init__.py
@@ -1,4 +1,5 @@
from .base_blockwise_sparsification import BaseBlockwiseSparsification
from .magnitude import Magnitude
from .shortgpt import ShortGPT
from .sparse import Sparser
from .wanda import Wanda
57 changes: 37 additions & 20 deletions llmc/compression/sparsification/base_blockwise_sparsification.py
@@ -5,6 +5,8 @@
import torch
from loguru import logger

from llmc.utils import copy_files

from ..blockwise_optimization import BlockwiseOpt
from .sparse import Sparser

@@ -18,15 +20,15 @@ def block_init(self, block):
pass

def set_sparsity_config(self):
if (
'sparsity_out' in self.sparsity_config
and self.sparsity_config['sparsity_out']
):
if 'sparsity_out' in self.sparsity_config and self.sparsity_config[
'sparsity_out'
]:
self.sparsity_out = True
else:
self.sparsity_out = False
logger.info(f'use sparsity_out {self.sparsity_out}')
self.sparser = Sparser(**self.sparsity_config['weight'])

self.sparser = Sparser(self.sparsity_config['weight'])

def block_forward(self, block, input_data=None):
output = []
@@ -35,10 +37,9 @@ def block_forward(self, block, input_data=None):

for i in range(len(input_data)):
input_data[i] = input_data[i].to(device=next(block.parameters()).device)
if (
'attention_mask' in self.input['kwargs'][i]
and self.input['kwargs'][i]['attention_mask'] is not None
):
if 'attention_mask' in self.input[
'kwargs'
][i] and self.input['kwargs'][i]['attention_mask'] is not None:
self.input['kwargs'][i]['attention_mask'] = self.input['kwargs'][i][
'attention_mask'
].cuda()
@@ -47,10 +48,10 @@
output.append(out)
return output

def block_opt(self, block, idx):
def block_opt(self, block):
block = block.cuda()
named_linears = self.model.get_block_linears(block)
# logger.info(f"named_linears: {named_linears}")
logger.info(f'named_linears: {named_linears}')
input_feat = defaultdict(list)
handles = []
self.block_init(block)
@@ -72,7 +73,7 @@ def block_opt(self, block, idx):
h.remove()
torch.cuda.empty_cache()

self.block_transform(block, input_feat, idx, self.input['kwargs'])
self.block_transform(block, input_feat, self.input['kwargs'])

if self.sparsity_out:
self.input['data'] = self.block_forward(block)
@@ -82,8 +83,8 @@
gc.collect()
torch.cuda.empty_cache()

def block_transform(self, block, input_feat, idx, block_kwargs):
logger.info(f'Start transform the {idx+1}-th block')
def block_transform(self, block, input_feat, block_kwargs):
logger.info(f'Start transform the {self.block_idx+1}-th block')
subsets = self.model.get_subsets_in_block(block)
for index, subset in enumerate(subsets):
if not self.filter_subset(subset):
@@ -101,19 +102,35 @@ def block_transform(self, block, input_feat, idx, block_kwargs):
prev_op,
input_name,
inspect_module,
subset_kwargs,
idx,
subset_kwargs
)
logger.info(f'End transform the {idx+1}-th block')
logger.info(f'End transform the {self.block_idx+1}-th block')

def filter_subset(self, subset):
return True

# todo
@torch.no_grad()
def deploy(self):
def deploy(self, deploy_format):
logger.info('-- deploy_sparsity_model start --')
logger.info(f'sparsity_config : {self.sparsity_config}')

# self.model.replace_module_all(module, params_dict)
logger.info('-- deploy_sparsity_model done --')

@torch.no_grad()
def copy_tokenizer(self, path):
for substring in self.config.save.get('tokenizer_file_substring', ['token']):
copy_files(self.config.model.path, path, substring)
logger.info('copy tokenizer done --')

@torch.no_grad()
def save_model(self, path):
if self.config.model.type == 'Llava':
self.model.llava_model.language_model = self.model.get_model()
self.model.llava_model.save_pretrained(path)
logger.info('save model done --')
self.copy_tokenizer(path)
copy_files(self.config.model.path, path, 'preprocessor_config')
else:
self.model.get_model().save_pretrained(path)
logger.info('save model done --')
self.copy_tokenizer(path)
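
The new save_model/copy_tokenizer path leans on a copy_files helper imported from llmc.utils, called as copy_files(source_dir, destination_dir, filename_substring). Inferred purely from those call sites, a plausible sketch of such a helper is shown below; the actual llmc implementation may differ:

import os
import shutil

# Hypothetical sketch of copy_files(src_dir, dst_dir, substring), inferred
# only from how it is called above; llmc's real helper may differ.
def copy_files(src_dir, dst_dir, substring):
    os.makedirs(dst_dir, exist_ok=True)
    for name in os.listdir(src_dir):
        # e.g. substring='token' matches tokenizer.json, tokenizer_config.json, ...
        if substring in name and os.path.isfile(os.path.join(src_dir, name)):
            shutil.copy(os.path.join(src_dir, name), os.path.join(dst_dir, name))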
3 changes: 1 addition & 2 deletions llmc/compression/sparsification/magnitude.py
@@ -19,8 +19,7 @@ def subset_transform(
prev_op,
input_name,
inspect_module,
subset_kwargs,
idx,
subset_kwargs
):
layers = list(layers_dict.values())
for layer in layers:
114 changes: 114 additions & 0 deletions llmc/compression/sparsification/shortgpt.py
@@ -0,0 +1,114 @@
import gc
import json
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
from loguru import logger
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.models.mistral.modeling_mistral import MistralRMSNorm

from llmc.utils import copy_files
from llmc.utils.registry_factory import ALGO_REGISTRY

from .base_blockwise_sparsification import BaseBlockwiseSparsification


@ALGO_REGISTRY
class ShortGPT(BaseBlockwiseSparsification):
def __init__(self, model, sparsity_config, input, config):
super().__init__(model, sparsity_config, input, config)

def block_opt(self, block):
block = block.cuda()

output_feat = self.block_forward(block)
torch.cuda.empty_cache()
self.block_transform(self.input['data'], output_feat)
self.input['data'] = output_feat

def block_transform(self, input_feat, output_feat):
logger.info(f'Start transform the {self.block_idx+1}-th block')
self.subset_transform(
input_feat,
output_feat
)

@torch.no_grad()
def compute_bi(
self,
input_feat: torch.Tensor,
output_feat: torch.Tensor
):
_, _, d = input_feat.shape
input_feat = input_feat.reshape(-1, d)
output_feat = output_feat.reshape(-1, d)

norm_input = input_feat.norm(dim=-1, keepdim=True)
norm_output = output_feat.norm(dim=-1, keepdim=True)

sim = (input_feat @ output_feat.T) / (norm_input * norm_output)
sim = sim.diagonal().nan_to_num(nan=0.5)

return 1 - sim

@torch.no_grad()
def subset_transform(
self,
input_feat,
output_feat
):
# calculate BI score
if self.sparser.importances is None:
self.sparser.importances = np.zeros(len(self.blocks))
self.sparser.importances[self.block_idx] = self.compute_bi(
input_feat[0], output_feat[0]
).sum().cpu().item()

@torch.no_grad()
def remove_layers(
self,
layers_to_remove: Optional[List[int]] = []
):
if not layers_to_remove and self.sparser.n_prune_layers:
layers_to_remove = np.argsort(
np.array(self.sparser.importances)
)[:self.sparser.n_prune_layers].tolist()

for idx in sorted(layers_to_remove, reverse=True):
try:
del self.blocks[idx]
except IndexError:
logger.info(f'layer {idx} does not exist')
return layers_to_remove

@torch.no_grad()
def deploy(self, deploy_format):
logger.info(f'After compute, BI scores are {self.sparser.importances}')
logger.info('-- deploy_sparsity_model start --')
logger.info(f'sparsity_config : {self.sparsity_config}')
logger.info('-- begin remove layers --')
layers_to_remove = self.remove_layers()
logger.info(f'remove layers: {layers_to_remove}')
logger.info('-- deploy_sparsity_model done --')

@torch.no_grad()
def save_model(self, path):
if self.config.model.type == 'Llava':
self.model.llava_model.language_model = self.model.get_model()
self.model.llava_model.save_pretrained(path)
logger.info('save model done --')
self.copy_tokenizer(path)
copy_files(self.config.model.path, path, 'preprocessor_config')
else:
self.model.get_model().save_pretrained(path)
config_file = path + '/config.json'

logger.info('save model done --')
self.copy_tokenizer(path)
with open(config_file, 'r') as file:
config_new = json.load(file)
config_new['num_hidden_layers'] = len(self.blocks)
with open(config_file, 'w') as file:
json.dump(config_new, file, indent=4)
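
The heart of ShortGPT is the Block Influence (BI) score computed in compute_bi above: for every transformer block, BI is 1 minus the cosine similarity between the block's input and output hidden states, summed over all tokens, and the n_prune_layers blocks with the lowest scores are deleted in remove_layers. A self-contained sketch of that scoring and selection logic on dummy tensors, using torch's cosine_similarity instead of the full similarity matrix but producing the same per-token values:

# Standalone sketch of ShortGPT's Block Influence scoring on dummy data.
# Functionally mirrors compute_bi/remove_layers above; shapes are illustrative.
import numpy as np
import torch

def block_influence(x_in: torch.Tensor, x_out: torch.Tensor) -> torch.Tensor:
    # x_in, x_out: (batch, seq_len, hidden) activations before/after a block.
    d = x_in.shape[-1]
    x_in = x_in.reshape(-1, d)
    x_out = x_out.reshape(-1, d)
    cos = torch.nn.functional.cosine_similarity(x_in, x_out, dim=-1)
    return 1.0 - cos.nan_to_num(nan=0.5)  # higher = block changes its input more

n_blocks, n_prune = 8, 3
importances = np.zeros(n_blocks)
for i in range(n_blocks):
    x = torch.randn(2, 16, 64)
    y = x + 0.1 * (i + 1) * torch.randn_like(x)  # stand-in for a block forward
    importances[i] = block_influence(x, y).sum().item()

# Blocks with the smallest BI change their input the least and are pruned first.
layers_to_remove = np.argsort(importances)[:n_prune].tolist()
print('BI scores:', importances, 'prune:', layers_to_remove)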
10 changes: 7 additions & 3 deletions llmc/compression/sparsification/sparse.py
@@ -1,5 +1,9 @@
class Sparser:
def __init__(self, sparsity, **kwargs):
self.sparsity = sparsity
def __init__(self, sparsity_constraint, **kwargs):
if 'sparsity' in sparsity_constraint:
self.sparsity = sparsity_constraint['sparsity']
self.W_mask = None
elif 'n_prune_layers' in sparsity_constraint:
self.n_prune_layers = sparsity_constraint['n_prune_layers']
self.importances = None
self.kwargs = kwargs
self.W_mask = None
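
After this change, Sparser is constructed from the config's weight sub-dict and carries either a sparsity ratio with a weight-mask slot (Magnitude/Wanda) or an n_prune_layers count with an importances buffer (ShortGPT). A small illustration of the two construction paths, assuming the class exactly as shown in this diff:

from llmc.compression.sparsification.sparse import Sparser

# Magnitude / Wanda style: a target sparsity ratio and a (not yet built) mask.
ratio_sparser = Sparser({'sparsity': 0.5})
print(ratio_sparser.sparsity, ratio_sparser.W_mask)             # 0.5 None

# ShortGPT style: how many blocks to prune, plus a buffer for BI scores
# (n_prune_layers: 9 matches the shortgpt.yml added in this PR).
layer_sparser = Sparser({'n_prune_layers': 9})
print(layer_sparser.n_prune_layers, layer_sparser.importances)  # 9 None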
3 changes: 1 addition & 2 deletions llmc/compression/sparsification/wanda.py
@@ -38,8 +38,7 @@ def subset_transform(
prev_op,
input_name,
inspect_module,
subset_kwargs,
idx,
subset_kwargs
):
layers = list(layers_dict.values())
for layer in layers:
15 changes: 15 additions & 0 deletions scripts/run_shortgpt_llama.sh
@@ -0,0 +1,15 @@
#!/bin/bash

gpu_id=0
export CUDA_VISIBLE_DEVICES=$gpu_id

llmc=llmc_path
export PYTHONPATH=$llmc:$PYTHONPATH

task_name=llm_quant_exp

nohup \
python -m llmc --config ../configs/sparsification/ShortGPT/shortgpt.yml \
> ${task_name}.log 2>&1 &

echo $! > ${task_name}.pid
15 changes: 15 additions & 0 deletions scripts/run_wanda_llama.sh
@@ -0,0 +1,15 @@
#!/bin/bash

gpu_id=0
export CUDA_VISIBLE_DEVICES=$gpu_id

llmc=llmc_path
export PYTHONPATH=$llmc:$PYTHONPATH

task_name=llm_quant_exp

nohup \
python -m llmc --config ../configs/sparsification/Wand/wanda.yml \
> ${task_name}.log 2>&1 &

echo $! > ${task_name}.pid