diff --git a/generate.py b/generate.py index 2201139b2..eb7a809ae 100644 --- a/generate.py +++ b/generate.py @@ -279,26 +279,6 @@ def _load_model(checkpoint_path, device, precision, use_tp=False): with torch.device("meta"): model = Transformer.from_name(checkpoint_path.parent.name) -# if "int8" in str(checkpoint_path): -# print("Using int8 weight-only quantization!") -# from quantize import WeightOnlyInt8QuantHandler -# -# simple_quantizer = WeightOnlyInt8QuantHandler(model) -# model = simple_quantizer.convert_for_runtime() -# -# if "int4" in str(checkpoint_path): -# print("Using int4 weight-only quantization!") -# path_comps = checkpoint_path.name.split(".") -# assert path_comps[-3].startswith("g") -# assert ( -# path_comps[-2] in device -# ), "weight packed format mismatch, please rerun quantize.py!" -# groupsize = int(path_comps[-3][1:]) -# from quantize import WeightOnlyInt4QuantHandler -# -# simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) -# model = simple_quantizer.convert_for_runtime(use_cuda) - checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) if "model" in checkpoint and "stories" in str(checkpoint_path): checkpoint = checkpoint["model"] diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index b92114c41..78309493e 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -98,8 +98,8 @@ def permute(w, n_head): if __name__ == '__main__': import argparse parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.') - parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf")) - parser.add_argument('--model_name', type=str, default=None) + parser.add_argument('--checkpoint-dir', type=Path, default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf")) + parser.add_argument('--model-name', type=str, default=None) args = parser.parse_args() convert_hf_checkpoint( diff --git a/scripts/download.py b/scripts/download.py index a968cf33b..7cf257f2c 100644 --- a/scripts/download.py +++ b/scripts/download.py @@ -23,8 +23,8 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) - if __name__ == '__main__': import argparse parser = argparse.ArgumentParser(description='Download data from HuggingFace Hub.') - parser.add_argument('--repo_id', type=str, default="checkpoints/meta-llama/llama-2-7b-chat-hf", help='Repository ID to download from.') - parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.') + parser.add_argument('--repo-id', type=str, default="checkpoints/meta-llama/llama-2-7b-chat-hf", help='Repository ID to download from.') + parser.add_argument('--hf-token', type=str, default=None, help='HuggingFace API token.') args = parser.parse_args() hf_download(args.repo_id, args.hf_token) diff --git a/scripts/prepare.sh b/scripts/prepare.sh index 7bc6b5e66..580c4e94b 100755 --- a/scripts/prepare.sh +++ b/scripts/prepare.sh @@ -1 +1 @@ -python scripts/download.py --repo_id $1 && python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$1 +python scripts/download.py --repo-id $1 && python scripts/convert_hf_checkpoint.py --checkpoint-dir checkpoints/$1 diff --git a/scripts/test_flow.sh b/scripts/test_flow.sh index a2636e937..2f1b4317f 100755 --- a/scripts/test_flow.sh +++ b/scripts/test_flow.sh @@ -1,6 +1,5 @@ export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf rm -r checkpoints/$MODEL_REPO -python scripts/download.py --repo_id $MODEL_REPO -python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO -python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth -python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --max_new_tokens 100 +python scripts/download.py --repo-id $MODEL_REPO +python scripts/convert_hf_checkpoint.py --checkpoint-dir checkpoints/$MODEL_REPO +python generate.py --compile --checkpoint-path checkpoints/$MODEL_REPO/model.pth --max_new_tokens 100