Commit 7cf886f

Add human-eval

gushiqiao committed Nov 29, 2024
1 parent 24e461f commit 7cf886f

Showing 10 changed files with 256 additions and 57 deletions.
35 changes: 35 additions & 0 deletions configs/quantization/methods/RTN/rtn_w_a_kv_human_eval.yml
@@ -0,0 +1,35 @@
base:
    seed: &seed 42
model:
    type: model_type
    path: model path
    torch_dtype: auto
eval:
    eval_pos: [pretrain, fake_quant]
    type: code
    name: human_eval
    res_path: ./human_eval/
    # For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False".
    # For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True".
    bs: 1
    format_tabs: True
    inference_per_block: False
quant:
    method: RTN
    weight:
        bit: 8
        symmetric: True
        granularity: per_channel
        group_size: -1
    act:
        bit: 8
        symmetric: True
        granularity: per_token
    kvcache:
        method: Naive
        bit: 8
        symmetric: True
        granularity: per_token
save:
    save_fake: False
    save_path: /path/to/save/
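The eval block above is exactly what `llmc/__main__.py` (next file) keys on: `type` must be `code` and `name` must be `human_eval` for the new HumanEval path to be taken. A minimal sanity-check sketch — the config path is the file above, and the use of PyYAML is an assumption for illustration:

```python
# Sketch: load the config above and verify it will route to the HumanEval branch.
import yaml

with open('configs/quantization/methods/RTN/rtn_w_a_kv_human_eval.yml') as f:
    cfg = yaml.safe_load(f)

# Mirrors the dispatch condition added in llmc/__main__.py below.
assert cfg['eval']['type'] == 'code' and cfg['eval']['name'] == 'human_eval'
print(cfg['eval']['eval_pos'])  # ['pretrain', 'fake_quant']
```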
19 changes: 17 additions & 2 deletions llmc/__main__.py
@@ -15,8 +15,8 @@
from llmc.compression.quantization import *
from llmc.compression.sparsification import *
from llmc.data import BaseDataset
-from llmc.eval import (AccuracyEval, PerplexityEval, TokenConsistencyEval,
-                       VLMEval)
+from llmc.eval import (AccuracyEval, HumanEval, PerplexityEval,
+                       TokenConsistencyEval, VLMEval)
from llmc.models import *
from llmc.utils import (check_config, mkdirs, print_important_package_version,
seed_all, update_autoawq_quant_config,
@@ -49,6 +49,9 @@ def main(config):
elif config.eval.type == 'img_txt':
acc_eval = VLMEval(config_for_eval)
eval_list.append(acc_eval)
elif config.eval.type == 'code' and config.eval.name == 'human_eval':
human_eval = HumanEval(model.get_tokenizer(), config_for_eval)
eval_list.append(human_eval)
else:
ppl_eval = PerplexityEval(model.get_tokenizer(), config_for_eval)
eval_list.append(ppl_eval)
@@ -62,6 +65,10 @@ def main(config):
for vlm_eval in eval_list:
results = vlm_eval.eval(model)
logger.info(f'{config.eval.name} results : {results}')
elif config.eval.type == 'code' and config.eval.name == 'human_eval':
for human_eval in eval_list:
results = human_eval.eval(model, eval_pos='pretrain')
logger.info(f'{config.eval.name} results : {results}')
else:
for ppl_eval in eval_list:
ppl = ppl_eval.eval(model)
@@ -122,6 +129,10 @@ def main(config):
for vlm_eval in eval_list:
results = vlm_eval.eval(model)
logger.info(f'{config.eval.name} results : {results}')
elif config.eval.type == 'code' and config.eval.name == 'human_eval':
for human_eval in eval_list:
results = human_eval.eval(model, eval_pos='transformed')
logger.info(f'{config.eval.name} results : {results}')
else:
for ppl_eval in eval_list:
ppl = ppl_eval.eval(model)
@@ -150,6 +161,10 @@ def main(config):
for vlm_eval in eval_list:
results = vlm_eval.eval(model)
logger.info(f'{config.eval.name} results : {results}')
elif config.eval.type == 'code' and config.eval.name == 'human_eval':
for human_eval in eval_list:
results = human_eval.eval(model, eval_pos='fake_quant')
logger.info(f'{config.eval.name} results : {results}')
else:
for ppl_eval in eval_list:
ppl = ppl_eval.eval(model)
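The same four-line branch now appears at all three evaluation points (`pretrain`, `transformed`, `fake_quant`). A possible consolidation — a sketch only, not part of this commit — would thread `eval_pos` through one helper:

```python
# Sketch (not in the commit): unify the per-eval-type dispatch repeated above.
# HumanEval is the class this commit adds; logger is loguru's, as in __main__.py.
from loguru import logger

from llmc.eval import HumanEval


def run_evals(eval_list, model, eval_pos, config):
    for evaluator in eval_list:
        if isinstance(evaluator, HumanEval):
            results = evaluator.eval(model, eval_pos=eval_pos)
        else:
            results = evaluator.eval(model)
        logger.info(f'{config.eval.name} results : {results}')
```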
16 changes: 11 additions & 5 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -231,11 +231,17 @@ def set_quant_config(self):
# set kv cache quant config
if 'kvcache' in self.quant_config:
self.quant_config['kvcache']['static'] = self.act_static
-            self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
-                self.quant_type, self.quant_config['kvcache'],
-                self.model.model_config.num_hidden_layers, self.config.calib.n_samples,
-                self.config.calib.bs
-            )
+            if self.act_static:
+                self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
+                    self.quant_type, self.quant_config['kvcache'],
+                    self.model.model_config.num_hidden_layers, self.config.calib.n_samples,
+                    self.config.calib.bs
+                )
+            else:
+                self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
+                    self.quant_type, self.quant_config['kvcache'],
+                    self.model.model_config.num_hidden_layers
+                )
self.quant_kvcache = True
self.model.kvcache_buffer.append(self.kv_module)
else:
4 changes: 2 additions & 2 deletions llmc/compression/quantization/kvquant.py
@@ -9,7 +9,7 @@

@KV_REGISTRY.register('Naive')
class NaiveQuantKVCache(DynamicCache):
-    def __init__(self, quant_type, kvquant_cfg, num_hidden_layers, num_samples, bsz):
+    def __init__(self, quant_type, kvquant_cfg, num_hidden_layers, num_samples=128, bsz=1):
super().__init__()

assert kvquant_cfg.granularity in ['per_token', 'per_tensor', 'per_group']
Expand Down Expand Up @@ -216,7 +216,7 @@ def get_qparams(self, tensor):
)
return scales, zeros, qmin, qmax

-    def get_seq_length(self, layer_idx):
+    def get_seq_length(self, layer_idx=0):
if len(self._quantized_key_cache) <= layer_idx:
return 0
return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
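With the new defaults, the dynamic (non-static) branch in `base_blockwise_quantization.py` above can construct the cache without any calibration arguments. A minimal usage sketch — the `quant_type` value and the config fields are assumptions for illustration, and `EasyDict` stands in for llmc's attribute-accessible config objects:

```python
# Sketch: constructing the KV cache both ways. Values are hypothetical.
from easydict import EasyDict

from llmc.compression.quantization.kvquant import NaiveQuantKVCache

kv_cfg = EasyDict({'method': 'Naive', 'bit': 8, 'symmetric': True,
                   'granularity': 'per_token', 'static': False})

# Dynamic quantization: num_samples/bsz fall back to their new defaults (128, 1).
cache = NaiveQuantKVCache('fake_quant', kv_cfg, num_hidden_layers=32)

# Static quantization still passes the calibration metadata explicitly.
cache_static = NaiveQuantKVCache('fake_quant', kv_cfg, 32, num_samples=256, bsz=4)
```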
1 change: 1 addition & 0 deletions llmc/eval/__init__.py
@@ -1,4 +1,5 @@
from .eval_acc import AccuracyEval
from .eval_code import HumanEval
from .eval_ppl import PerplexityEval
from .eval_token_consist import TokenConsistencyEval
from .eval_vlm import VLMEval
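`eval_code.py` itself is not expanded in this view. Based on its call sites (constructed with a tokenizer and config, invoked as `eval(model, eval_pos=...)`) and on OpenAI's `human-eval` package, a plausible sketch follows — only `read_problems`, `write_jsonl`, and `evaluate_functional_correctness` are known package APIs; every other name is an assumption, not the commit's actual code:

```python
# Hypothetical sketch of llmc/eval/eval_code.py -- NOT the commit's actual code.
import os

from human_eval.data import read_problems, write_jsonl
from human_eval.evaluation import evaluate_functional_correctness


class HumanEval:
    def __init__(self, tokenizer, config):
        self.tokenizer = tokenizer
        self.res_path = config.eval.res_path
        self.format_tabs = config.eval.get('format_tabs', False)
        # {task_id: {'prompt', 'entry_point', 'test', ...}}
        self.problems = read_problems()

    def generate(self, model, prompt, max_new_tokens=512):
        # Assumed generation path: llmc models wrap a HF model as `model.model`
        # (see `model.model.eval()` in eval_base.py below).
        inputs = self.tokenizer(prompt, return_tensors='pt').to(model.model.device)
        out = model.model.generate(**inputs, max_new_tokens=max_new_tokens,
                                   do_sample=False)
        return self.tokenizer.decode(out[0][inputs.input_ids.shape[1]:],
                                     skip_special_tokens=True)

    def eval(self, model, eval_pos=None):
        samples = []
        for task_id, task in self.problems.items():
            prompt = task['prompt']
            if self.format_tabs:
                # Some code models complete tab-indented prompts better.
                prompt = prompt.replace('    ', '\t')
            completion = self.generate(model, prompt)
            samples.append(dict(task_id=task_id, completion=completion))
        # Tag outputs with eval_pos so pretrain / transformed / fake_quant runs
        # land in separate files under res_path.
        out_file = os.path.join(self.res_path, f'samples_{eval_pos}.jsonl')
        write_jsonl(out_file, samples)
        # Executes each completion against the task's unit tests; returns pass@k.
        return evaluate_functional_correctness(out_file)
```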
111 changes: 65 additions & 46 deletions llmc/eval/eval_base.py
@@ -1,9 +1,11 @@
import gc
import os
from concurrent.futures import ThreadPoolExecutor

import torch
import torch.nn as nn
from datasets import load_dataset, load_from_disk
from human_eval.data import read_problems
from loguru import logger


@@ -12,70 +14,79 @@ def __init__(self, tokenizer, config):
        self.tokenizer = tokenizer
        # eval_cfg
        eval_cfg = config.eval
+        self.model_type = config.model.type
        logger.info(f'eval_cfg : {eval_cfg}')
        self.dataset = eval_cfg['name']
        assert self.dataset in [
            'wikitext2',
            'c4',
            'ptb',
            'custom',
-        ], 'Ppl eval only support wikitext2, c4, ptb dataset now.'
-        self.seq_len = eval_cfg['seq_len']
+            'human_eval'
+        ], 'Ppl eval only supports wikitext2, c4, ptb, custom and human_eval datasets now.'
+        self.seq_len = eval_cfg.get('seq_len', None)
        self.bs = eval_cfg['bs']
        self.path = eval_cfg.get('path', None)
-        self.download = eval_cfg['download']
+        self.download = eval_cfg.get('download', False)
        self.load_from_txt = eval_cfg.get('load_from_txt', False)
        self.inference_per_block = eval_cfg.get('inference_per_block', False)
        self.testenc = self.build_data()
+        self.res_path = eval_cfg.get('res_path', None)
+        if self.dataset == 'human_eval':
+            assert self.res_path is not None, 'Please set res_path in eval_cfg.'
+            os.makedirs(self.res_path, exist_ok=True)
+        self.format_tabs = eval_cfg.get('format_tabs', False)

    @torch.no_grad()
    def build_data(self):
        # load data
-        if self.download:
-            if self.dataset == 'wikitext2':
-                testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
-            elif self.dataset == 'c4':
-                testdata = load_dataset(
-                    'allenai/c4',
-                    data_files={
-                        'validation': 'en/c4-validation.00000-of-00008.json.gz'
-                    },
-                    split='validation',
-                )
-            elif self.dataset == 'ptb':
-                testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
-        else:
-            if not self.load_from_txt:
-                assert self.path, 'Please set path in eval_cfg.'
-                testdata = load_from_disk(self.path)
-            else:
-                """Load dataset from your custom txt file.
-                Each line in the txt file represents one input text data.
-                """
-                assert self.path.endswith('.txt')
-                logger.info(f'eval dataset path: {self.path}')
-                with open(self.path, 'r') as fp:
-                    lines = fp.readlines()
-                testdata = []
-                for line in lines:
-                    testdata.append(line.strip())
-        # encode data
-        if self.dataset == 'wikitext2':
-            testenc = self.tokenizer('\n\n'.join(testdata['text']), return_tensors='pt')
-        elif self.dataset == 'c4':
-            testenc = self.tokenizer(
-                ' '.join(testdata[:1100]['text']), return_tensors='pt'
-            )
-            testenc.input_ids = testenc.input_ids[:, : (256 * self.seq_len)]
-        elif self.dataset == 'ptb':
-            testenc = self.tokenizer(
-                ' '.join(testdata['sentence']), return_tensors='pt'
-            )
-        elif self.dataset == 'custom':
-            testenc = self.tokenizer(
-                '\n'.join(testdata), return_tensors='pt'
-            )
+        if self.dataset == 'human_eval':
+            testenc = read_problems()
+        else:
+            if self.download:
+                if self.dataset == 'wikitext2':
+                    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
+                elif self.dataset == 'c4':
+                    testdata = load_dataset(
+                        'allenai/c4',
+                        data_files={
+                            'validation': 'en/c4-validation.00000-of-00008.json.gz'
+                        },
+                        split='validation',
+                    )
+                elif self.dataset == 'ptb':
+                    testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
+            else:
+                if not self.load_from_txt:
+                    assert self.path, 'Please set path in eval_cfg.'
+                    testdata = load_from_disk(self.path)
+                else:
+                    """Load dataset from your custom txt file.
+                    Each line in the txt file represents one input text data.
+                    """
+                    assert self.path.endswith('.txt')
+                    logger.info(f'eval dataset path: {self.path}')
+                    with open(self.path, 'r') as fp:
+                        lines = fp.readlines()
+                    testdata = []
+                    for line in lines:
+                        testdata.append(line.strip())
+            # encode data
+            if self.dataset == 'wikitext2':
+                testenc = self.tokenizer('\n\n'.join(testdata['text']), return_tensors='pt')
+            elif self.dataset == 'c4':
+                testenc = self.tokenizer(
+                    ' '.join(testdata[:1100]['text']), return_tensors='pt'
+                )
+                testenc.input_ids = testenc.input_ids[:, : (256 * self.seq_len)]
+            elif self.dataset == 'ptb':
+                testenc = self.tokenizer(
+                    ' '.join(testdata['sentence']), return_tensors='pt'
+                )
+            elif self.dataset == 'custom':
+                testenc = self.tokenizer(
+                    '\n'.join(testdata), return_tensors='pt'
+                )
        return testenc

@torch.no_grad()
@@ -102,7 +113,7 @@ def register_hooks(self, model):
return handles

@torch.no_grad()
-    def eval(self, model_llmc, model_org=None):
+    def eval(self, model_llmc, model_org=None, eval_pos=None):
handles, handles_org = [], []
if self.inference_per_block:
handles = self.register_hooks(model_llmc)
@@ -118,7 +129,12 @@ def eval(self, model_llmc, model_org=None):

model_org.model.eval()

-        eval_res = self.eval_func(model_org, model_llmc, self.testenc, self.seq_len, self.bs)
+        eval_res = self.eval_func(model_org,
+                                  model_llmc,
+                                  self.testenc,
+                                  self.seq_len,
+                                  self.bs,
+                                  eval_pos)
if self.inference_per_block:
for h in handles + handles_org:
h.remove()
@@ -130,3 +146,6 @@
gc.collect()
torch.cuda.empty_cache()
return eval_res

def post_process(self, testenc):
pass
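In the `human_eval` branch of `build_data` above, `testenc` is no longer a tokenized tensor but the raw problem dictionary from OpenAI's `human-eval` package, so subclasses must treat it accordingly. Its shape, for reference (assuming the package is installed):

```python
# What read_problems() returns (human-eval package API).
from human_eval.data import read_problems

problems = read_problems()
task = problems['HumanEval/0']
print(sorted(task.keys()))
# ['canonical_solution', 'entry_point', 'prompt', 'task_id', 'test']
# 'prompt' is the signature + docstring the model must complete.
```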