diff --git a/configs/quantization/methods/RTN/rtn_w_a_kv_human_eval.yml b/configs/quantization/methods/RTN/rtn_w_a_kv_human_eval.yml
new file mode 100644
index 00000000..2e80917d
--- /dev/null
+++ b/configs/quantization/methods/RTN/rtn_w_a_kv_human_eval.yml
@@ -0,0 +1,35 @@
+base:
+    seed: &seed 42
+model:
+    type: model_type
+    path: model path
+    torch_dtype: auto
+eval:
+    eval_pos: [pretrain, fake_quant]
+    type: code
+    name: human_eval
+    res_path: ./human_eval/
+    # For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False".
+    # For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True".
+    bs: 1
+    format_tabs: True
+    inference_per_block: False
+quant:
+    method: RTN
+    weight:
+        bit: 8
+        symmetric: True
+        granularity: per_channel
+        group_size: -1
+    act:
+        bit: 8
+        symmetric: True
+        granularity: per_token
+    kvcache:
+        method: Naive
+        bit: 8
+        symmetric: True
+        granularity: per_token
+save:
+    save_fake: False
+    save_path: /path/to/save/
diff --git a/llmc/__main__.py b/llmc/__main__.py
index 6dba1d3b..8d7f6ac5 100644
--- a/llmc/__main__.py
+++ b/llmc/__main__.py
@@ -15,8 +15,8 @@
 from llmc.compression.quantization import *
 from llmc.compression.sparsification import *
 from llmc.data import BaseDataset
-from llmc.eval import (AccuracyEval, PerplexityEval, TokenConsistencyEval,
-                       VLMEval)
+from llmc.eval import (AccuracyEval, HumanEval, PerplexityEval,
+                       TokenConsistencyEval, VLMEval)
 from llmc.models import *
 from llmc.utils import (check_config, mkdirs, print_important_package_version,
                         seed_all, update_autoawq_quant_config,
@@ -49,6 +49,9 @@ def main(config):
         elif config.eval.type == 'img_txt':
             acc_eval = VLMEval(config_for_eval)
             eval_list.append(acc_eval)
+        elif config.eval.type == 'code' and config.eval.name == 'human_eval':
+            human_eval = HumanEval(model.get_tokenizer(), config_for_eval)
+            eval_list.append(human_eval)
         else:
             ppl_eval = PerplexityEval(model.get_tokenizer(), config_for_eval)
             eval_list.append(ppl_eval)
@@ -62,6 +65,10 @@ def main(config):
             for vlm_eval in eval_list:
                 results = vlm_eval.eval(model)
                 logger.info(f'{config.eval.name} results : {results}')
+        elif config.eval.type == 'code' and config.eval.name == 'human_eval':
+            for human_eval in eval_list:
+                results = human_eval.eval(model, eval_pos='pretrain')
+                logger.info(f'{config.eval.name} results : {results}')
         else:
             for ppl_eval in eval_list:
                 ppl = ppl_eval.eval(model)
@@ -122,6 +129,10 @@ def main(config):
             for vlm_eval in eval_list:
                 results = vlm_eval.eval(model)
                 logger.info(f'{config.eval.name} results : {results}')
+        elif config.eval.type == 'code' and config.eval.name == 'human_eval':
+            for human_eval in eval_list:
+                results = human_eval.eval(model, eval_pos='transformed')
+                logger.info(f'{config.eval.name} results : {results}')
         else:
             for ppl_eval in eval_list:
                 ppl = ppl_eval.eval(model)
@@ -150,6 +161,10 @@ def main(config):
             for vlm_eval in eval_list:
                 results = vlm_eval.eval(model)
                 logger.info(f'{config.eval.name} results : {results}')
+        elif config.eval.type == 'code' and config.eval.name == 'human_eval':
+            for human_eval in eval_list:
+                results = human_eval.eval(model, eval_pos='fake_quant')
+                logger.info(f'{config.eval.name} results : {results}')
         else:
             for ppl_eval in eval_list:
                 ppl = ppl_eval.eval(model)
diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py
index 437b121e..e5c5ac55 100644
--- a/llmc/compression/quantization/base_blockwise_quantization.py
+++ b/llmc/compression/quantization/base_blockwise_quantization.py
@@ -231,11 +231,19 @@ def set_quant_config(self):
         # set kv cache quant config
         if 'kvcache' in self.quant_config:
             self.quant_config['kvcache']['static'] = self.act_static
-            self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
-                self.quant_type, self.quant_config['kvcache'],
-                self.model.model_config.num_hidden_layers, self.config.calib.n_samples,
-                self.config.calib.bs
-            )
+            if self.act_static:
+                # static KV quant calibrates its scales offline, so it needs the calib sample count and batch size
+                self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
+                    self.quant_type, self.quant_config['kvcache'],
+                    self.model.model_config.num_hidden_layers, self.config.calib.n_samples,
+                    self.config.calib.bs
+                )
+            else:
+                # dynamic KV quant computes scales at runtime, so configs without a calib section still work
+                self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
+                    self.quant_type, self.quant_config['kvcache'],
+                    self.model.model_config.num_hidden_layers
+                )
             self.quant_kvcache = True
             self.model.kvcache_buffer.append(self.kv_module)
         else:
diff --git a/llmc/compression/quantization/kvquant.py b/llmc/compression/quantization/kvquant.py
index 0a247c73..9a8ada7a 100644
--- a/llmc/compression/quantization/kvquant.py
+++ b/llmc/compression/quantization/kvquant.py
@@ -9,7 +9,7 @@
 
 @KV_REGISTRY.register('Naive')
 class NaiveQuantKVCache(DynamicCache):
-    def __init__(self, quant_type, kvquant_cfg, num_hidden_layers, num_samples, bsz):
+    def __init__(self, quant_type, kvquant_cfg, num_hidden_layers, num_samples=128, bsz=1):
         super().__init__()
 
         assert kvquant_cfg.granularity in ['per_token', 'per_tensor', 'per_group']
@@ -216,7 +216,7 @@ def get_qparams(self, tensor):
         )
         return scales, zeros, qmin, qmax
 
-    def get_seq_length(self, layer_idx):
+    def get_seq_length(self, layer_idx=0):
         if len(self._quantized_key_cache) <= layer_idx:
             return 0
         return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
diff --git a/llmc/eval/__init__.py b/llmc/eval/__init__.py
index 435148c8..7337e70b 100644
--- a/llmc/eval/__init__.py
+++ b/llmc/eval/__init__.py
@@ -1,4 +1,5 @@
 from .eval_acc import AccuracyEval
+from .eval_code import HumanEval
 from .eval_ppl import PerplexityEval
 from .eval_token_consist import TokenConsistencyEval
 from .eval_vlm import VLMEval
diff --git a/llmc/eval/eval_base.py b/llmc/eval/eval_base.py
index 534701be..b6071f98 100644
--- a/llmc/eval/eval_base.py
+++ b/llmc/eval/eval_base.py
@@ -1,9 +1,11 @@
 import gc
+import os
 from concurrent.futures import ThreadPoolExecutor
 
 import torch
 import torch.nn as nn
 from datasets import load_dataset, load_from_disk
+from human_eval.data import read_problems
 from loguru import logger
 
 
@@ -12,6 +14,7 @@ def __init__(self, tokenizer, config):
         self.tokenizer = tokenizer
         # eval_cfg
         eval_cfg = config.eval
+        self.model_type = config.model.type
         logger.info(f'eval_cfg : {eval_cfg}')
         self.dataset = eval_cfg['name']
         assert self.dataset in [
@@ -19,63 +22,72 @@ def __init__(self, tokenizer, config):
             'c4',
             'ptb',
             'custom',
-        ], 'Ppl eval only support wikitext2, c4, ptb dataset now.'
-        self.seq_len = eval_cfg['seq_len']
+            'human_eval'
+        ], 'Eval only supports wikitext2, c4, ptb, custom and human_eval datasets now.'
+        self.seq_len = eval_cfg.get('seq_len', None)
         self.bs = eval_cfg['bs']
         self.path = eval_cfg.get('path', None)
-        self.download = eval_cfg['download']
+        self.download = eval_cfg.get('download', False)
         self.load_from_txt = eval_cfg.get('load_from_txt', False)
         self.inference_per_block = eval_cfg.get('inference_per_block', False)
         self.testenc = self.build_data()
+        self.res_path = eval_cfg.get('res_path', None)
+        if self.dataset == 'human_eval':
+            assert self.res_path is not None, 'Please set res_path in eval_cfg for human_eval.'
+            os.makedirs(self.res_path, exist_ok=True)
+        self.format_tabs = eval_cfg.get('format_tabs', False)
 
     @torch.no_grad()
     def build_data(self):
         # load data
-        if self.download:
+        if self.dataset == 'human_eval':
+            testenc = read_problems()  # dict of {task_id: problem} from the human-eval package
+        else:
+            if self.download:
+                if self.dataset == 'wikitext2':
+                    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
+                elif self.dataset == 'c4':
+                    testdata = load_dataset(
+                        'allenai/c4',
+                        data_files={
+                            'validation': 'en/c4-validation.00000-of-00008.json.gz'
+                        },
+                        split='validation',
+                    )
+                elif self.dataset == 'ptb':
+                    testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
+            else:
+                if not self.load_from_txt:
+                    assert self.path, 'Please set path in eval_cfg.'
+                    testdata = load_from_disk(self.path)
+                else:
+                    """Load dataset from your custom txt file.
+
+                    Each line in the txt file represents one input text data.
+                    """
+                    assert self.path.endswith('.txt')
+                    logger.info(f'eval dataset path: {self.path}')
+                    with open(self.path, 'r') as fp:
+                        lines = fp.readlines()
+                    testdata = []
+                    for line in lines:
+                        testdata.append(line.strip())
+            # encode data
             if self.dataset == 'wikitext2':
-                testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
+                testenc = self.tokenizer('\n\n'.join(testdata['text']), return_tensors='pt')
             elif self.dataset == 'c4':
-                testdata = load_dataset(
-                    'allenai/c4',
-                    data_files={
-                        'validation': 'en/c4-validation.00000-of-00008.json.gz'
-                    },
-                    split='validation',
+                testenc = self.tokenizer(
+                    ' '.join(testdata[:1100]['text']), return_tensors='pt'
                 )
+                testenc.input_ids = testenc.input_ids[:, : (256 * self.seq_len)]
             elif self.dataset == 'ptb':
-                testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
-        else:
-            if not self.load_from_txt:
-                assert self.path, 'Please set path in eval_cfg.'
-                testdata = load_from_disk(self.path)
-            else:
-                """Load dataset from your custom txt file.
-
-                Each line in the txt file represents one input text data.
- """ - assert self.path.endswith('.txt') - logger.info(f'eval dataset path: {self.path}') - with open(self.path, 'r') as fp: - lines = fp.readlines() - testdata = [] - for line in lines: - testdata.append(line.strip()) - # encode data - if self.dataset == 'wikitext2': - testenc = self.tokenizer('\n\n'.join(testdata['text']), return_tensors='pt') - elif self.dataset == 'c4': - testenc = self.tokenizer( - ' '.join(testdata[:1100]['text']), return_tensors='pt' - ) - testenc.input_ids = testenc.input_ids[:, : (256 * self.seq_len)] - elif self.dataset == 'ptb': - testenc = self.tokenizer( - ' '.join(testdata['sentence']), return_tensors='pt' - ) - elif self.dataset == 'custom': - testenc = self.tokenizer( - '\n'.join(testdata), return_tensors='pt' - ) + testenc = self.tokenizer( + ' '.join(testdata['sentence']), return_tensors='pt' + ) + elif self.dataset == 'custom': + testenc = self.tokenizer( + '\n'.join(testdata), return_tensors='pt' + ) return testenc @torch.no_grad() @@ -102,7 +113,7 @@ def register_hooks(self, model): return handles @torch.no_grad() - def eval(self, model_llmc, model_org=None): + def eval(self, model_llmc, model_org=None, eval_pos=None): handles, handles_org = [], [] if self.inference_per_block: handles = self.register_hooks(model_llmc) @@ -118,7 +129,12 @@ def eval(self, model_llmc, model_org=None): model_org.model.eval() - eval_res = self.eval_func(model_org, model_llmc, self.testenc, self.seq_len, self.bs) + eval_res = self.eval_func(model_org, + model_llmc, + self.testenc, + self.seq_len, + self.bs, + eval_pos) if self.inference_per_block: for h in handles + handles_org: h.remove() @@ -130,3 +146,6 @@ def eval(self, model_llmc, model_org=None): gc.collect() torch.cuda.empty_cache() return eval_res + + def post_process(self, testenc): + pass diff --git a/llmc/eval/eval_code.py b/llmc/eval/eval_code.py new file mode 100644 index 00000000..ff576818 --- /dev/null +++ b/llmc/eval/eval_code.py @@ -0,0 +1,121 @@ +import glob +import os + +import torch +from human_eval.data import stream_jsonl, write_jsonl +from human_eval.evaluation import evaluate_functional_correctness +from loguru import logger +from tqdm import tqdm + +from .eval_base import BaseEval + + +class HumanEval(BaseEval): + + @torch.no_grad() + def eval_func(self, org_model, model, testenc, seq_len, bs, eval_pos): + samples = [] + pbar = tqdm(total=len(testenc) * bs, dynamic_ncols=True, position=0, desc='Evaluating') + + for task_id in testenc: + if self.format_tabs: + prompt = testenc[task_id]['prompt'].replace(' ', '\t') + else: + prompt = testenc[task_id]['prompt'] + batch_completions = self.generate_batch_completion( + model, prompt, bs + ) + + for sample in batch_completions: + result = dict( + task_id=task_id, + completion=sample, + ) + samples += [result] + + pbar.update(bs) + + pbar.close() + + self.output_dir = os.path.join(self.res_path, eval_pos) + + os.makedirs(self.output_dir, exist_ok=True) + out_path = os.path.join(self.output_dir, 'eval.jsonl') + write_jsonl(out_path, samples) + + res = self.post_process(testenc) + return res + + @torch.no_grad() + def generated_llama( + self, + model, + inputs, + max_new_tokens=512, + temperature=0.2, + top_p=0.95, + do_sample=True, + ): + generated_ids = model.model.generate( + **inputs, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + do_sample=do_sample, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.eos_token_id, + use_cache=True, + ) + return generated_ids + + @torch.no_grad() + def 
+    def generate_batch_completion(self, model, prompt, bs):
+        input_batch = [prompt for _ in range(bs)]
+        inputs = self.tokenizer(input_batch, return_tensors='pt').to(model.model.device)
+        input_ids_cutoff = inputs.input_ids.size(dim=1)
+
+        if self.model_type in ['Llama']:
+            generated_ids = self.generated_llama(model, inputs)
+            model.reset_kv()
+        else:
+            raise NotImplementedError('This model is not supported yet.')
+
+        batch_completions = self.tokenizer.batch_decode(
+            [ids[input_ids_cutoff:] for ids in generated_ids],
+            skip_special_tokens=True,
+        )
+
+        return [
+            self.filter_code(self.fix_indents(completion))
+            for completion in batch_completions
+        ]
+
+    @torch.no_grad()
+    def post_process(self, testenc):
+        files = sorted(glob.glob(os.path.join(self.output_dir, 'eval.jsonl')))
+        logger.info(f'{len(files)} files in {self.output_dir}')
+        output = []
+
+        for code_file in tqdm(files, total=len(files)):
+            codes = [c for c in stream_jsonl(code_file)]
+            output += codes
+
+        out_path = os.path.join(self.output_dir, 'processed.jsonl')
+        logger.info(f'save to {out_path}')
+        write_jsonl(out_path, output)
+        res = self.entry_point(out_path)
+        return res
+
+    @torch.no_grad()
+    def filter_code(self, completion):
+        completion = completion.lstrip('\n')
+        return completion.split('\n\n')[0]  # keep only the completion up to the first blank line
+
+    @torch.no_grad()
+    def fix_indents(self, text):
+        return text.replace('\t', '    ')  # convert tabs back to 4-space indentation
+
+    @torch.no_grad()
+    def entry_point(self, sample_file):
+        results = evaluate_functional_correctness(sample_file)
+        return results
diff --git a/llmc/eval/eval_ppl.py b/llmc/eval/eval_ppl.py
index afc2abc7..a2beaffc 100644
--- a/llmc/eval/eval_ppl.py
+++ b/llmc/eval/eval_ppl.py
@@ -12,7 +12,7 @@
 class PerplexityEval(BaseEval):
 
     @torch.no_grad()
-    def eval_func(self, org_model, model, testenc, seq_len, bs):
+    def eval_func(self, org_model, model, testenc, seq_len, bs, eval_pos):
         testenc = testenc.input_ids
 
         nsamples = testenc.numel() // seq_len
diff --git a/llmc/eval/eval_token_consist.py b/llmc/eval/eval_token_consist.py
index 77055937..adc3f1bf 100644
--- a/llmc/eval/eval_token_consist.py
+++ b/llmc/eval/eval_token_consist.py
@@ -12,7 +12,7 @@
 class TokenConsistencyEval(BaseEval):
 
     @torch.no_grad()
-    def eval_func(self, org_model, model, testenc, seq_len, bs):
+    def eval_func(self, org_model, model, testenc, seq_len, bs, eval_pos):
         testenc = testenc.input_ids
 
         nsamples = testenc.numel() // seq_len
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index a0975dab..7a4a1838 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -30,3 +30,4 @@ einops
 qwen-vl-utils
 tiktoken
 librosa
+human_eval
\ No newline at end of file
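Not part of the patch, but a minimal sketch of how the written completions could be re-scored offline, assuming the example config above (res_path: ./human_eval/, eval_pos: [pretrain, fake_quant]) and the human-eval package added to requirements. HumanEval.eval_func writes eval.jsonl under res_path/<eval_pos>/, post_process merges it into processed.jsonl, and entry_point simply calls evaluate_functional_correctness on that file, so the same scoring can be reproduced standalone; the script name and the loop below are illustrative assumptions only:

# rescore_human_eval.py -- illustrative sketch, not part of the patch
import os

from human_eval.evaluation import evaluate_functional_correctness

res_path = './human_eval'  # matches res_path in the example YAML config
for eval_pos in ('pretrain', 'fake_quant'):  # matches eval_pos in the config
    sample_file = os.path.join(res_path, eval_pos, 'processed.jsonl')
    if not os.path.exists(sample_file):
        continue
    # runs the HumanEval unit tests against each completion and returns a dict of pass@k scores
    results = evaluate_functional_correctness(sample_file)
    print(eval_pos, results)

With bs: 1 in the config there is a single completion per task, so only pass@1 would be reported.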