Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scripts update #59

Merged
merged 2 commits (source and target branch names not captured in this extraction)
Apr 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,26 +279,6 @@ def _load_model(checkpoint_path, device, precision, use_tp=False):
with torch.device("meta"):
model = Transformer.from_name(checkpoint_path.parent.name)

# if "int8" in str(checkpoint_path):
# print("Using int8 weight-only quantization!")
# from quantize import WeightOnlyInt8QuantHandler
#
# simple_quantizer = WeightOnlyInt8QuantHandler(model)
# model = simple_quantizer.convert_for_runtime()
#
# if "int4" in str(checkpoint_path):
# print("Using int4 weight-only quantization!")
# path_comps = checkpoint_path.name.split(".")
# assert path_comps[-3].startswith("g")
# assert (
# path_comps[-2] in device
# ), "weight packed format mismatch, please rerun quantize.py!"
# groupsize = int(path_comps[-3][1:])
# from quantize import WeightOnlyInt4QuantHandler
#
# simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
# model = simple_quantizer.convert_for_runtime(use_cuda)

checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
if "model" in checkpoint and "stories" in str(checkpoint_path):
checkpoint = checkpoint["model"]
Expand Down
4 changes: 2 additions & 2 deletions scripts/convert_hf_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def permute(w, n_head):
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.')
parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"))
parser.add_argument('--model_name', type=str, default=None)
parser.add_argument('--checkpoint-dir', type=Path, default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"))
parser.add_argument('--model-name', type=str, default=None)

args = parser.parse_args()
convert_hf_checkpoint(
Expand Down
4 changes: 2 additions & 2 deletions scripts/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Download data from HuggingFace Hub.')
parser.add_argument('--repo_id', type=str, default="checkpoints/meta-llama/llama-2-7b-chat-hf", help='Repository ID to download from.')
parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.')
parser.add_argument('--repo-id', type=str, default="checkpoints/meta-llama/llama-2-7b-chat-hf", help='Repository ID to download from.')
parser.add_argument('--hf-token', type=str, default=None, help='HuggingFace API token.')

args = parser.parse_args()
hf_download(args.repo_id, args.hf_token)
2 changes: 1 addition & 1 deletion scripts/prepare.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
python scripts/download.py --repo_id $1 && python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$1
python scripts/download.py --repo-id $1 && python scripts/convert_hf_checkpoint.py --checkpoint-dir checkpoints/$1
7 changes: 3 additions & 4 deletions scripts/test_flow.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
rm -r checkpoints/$MODEL_REPO
python scripts/download.py --repo_id $MODEL_REPO
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth
python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --max_new_tokens 100
python scripts/download.py --repo-id $MODEL_REPO
python scripts/convert_hf_checkpoint.py --checkpoint-dir checkpoints/$MODEL_REPO
python generate.py --compile --checkpoint-path checkpoints/$MODEL_REPO/model.pth --max_new_tokens 100
Loading