Skip to content

Commit

Permalink
Merge pull request #96 from pytorch-labs/initial_hqq
Browse files Browse the repository at this point in the history
initial hqq
  • Loading branch information
mikekgfb authored and malfet committed Jul 17, 2024
1 parent 4be0e5b commit 7eda42e
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 0 deletions.
60 changes: 60 additions & 0 deletions parking_lot/hqq.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
name: Compile main

on:
push:
branches:
- main
pull_request:
workflow_dispatch:

jobs:
run-hqq-tinystories:
strategy:
matrix:
runner: [ubuntu-latest]
runs-on: ${{matrix.runner}}
steps:
- name: Checkout repo
uses: actions/checkout@v2
- name: Setup Python
uses: actions/setup-python@v2
with:
python-version: 3.11
- name: Print machine info
run: |
uname -a
if [ $(uname -s) == Darwin ]; then
sysctl machdep.cpu.brand_string
sysctl machdep.cpu.core_count
fi
- name: Install requirements
run: |
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
pip install -r requirements.txt
pip install hqq
- name: Download checkpoints
run: |
mkdir -p checkpoints/stories15M
pushd checkpoints/stories15M
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
popd
- name: Run inference
run: |
export MODEL_PATH=checkpoints/stories15M/stories15M.pt
export MODEL_NAME=stories15M
export MODEL_DIR=/tmp
echo "******************************************"
echo "******** HQQ: group-wise quantized *******"
echo "******************************************"
python generate.py --quant '{"linear:hqq" : {"group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
cat ./output_eager
python generate.py --compile --quant '{"linear:hqq" : {"group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
cat ./output_compiled
python export.py --quant '{"embedding" : {"group_size": 8}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
python generate.py --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
cat ./output_aoti
echo "tests complete"
echo "******************************************"
58 changes: 58 additions & 0 deletions quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ def quantize_model(model: nn.Module, quantize_options):
model,
**q_kwargs
).quantized_model()
elif quantizer == "linear:hqq":
linears_quantized = True
model = WeightOnlyInt4HqqQuantHandler(
model,
**q_kwargs
).quantized_model()
elif quantizer == "precision":
model.to(**q_kwargs)
else:
Expand Down Expand Up @@ -600,6 +606,7 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:

# r = result_weights.to(dtype=result_scales.dtype).view(list(result_weights.shape[:-1] + (scales.shape[1], -1, )) * result_scales.view(scales.shape[-1] + (scales.shape[1], 1, ))


#########################################################################
##### weight only int4 per channel groupwise quantized code ######

Expand Down Expand Up @@ -683,6 +690,7 @@ def create_quantized_state_dict(self):

return cur_state_dict


def convert_for_runtime(self, use_cuda=False):
replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
return self.mod
Expand Down Expand Up @@ -1255,3 +1263,53 @@ def quantized_model(self) -> nn.Module:
# self.precision,
# )
# return model

##################################################################
### WIP: HQQ ###

class WeightOnlyInt4HqqQuantHandler:
def __init__(self, mod, group_size):
self.mod = mod
self.groupsize = group_size

def create_quantized_state_dict(self):
from hqq.core.quantize import Quantizer # TODO maybe torchao


for m in self.mod.modules():
for name, child in m.named_children():
if isinstance(child, torch.nn.Linear):
child.weight = torch.nn.Parameter(
Quantizer.dequantize(
*Quantizer.quantize(
child.weight,
nbits=4,
group_size=self.groupsize,
axis=1,
)
)
)

# we use Int4 packaged in an int8 for now, packing to follow
# return WeightOnlyInt4QuantHandler(self.mod, self.groupsize).create_quantized_state_dict()
return WeightOnlyInt8QuantHandler(
self.mod, bitwidth=4, group_size=self.groupsize
).create_quantized_state_dict()

def convert_for_runtime(self):
# we use Int4 packaged in an int8 for now, packing to follow
# ALSO: all code must work for CPU, CUDA, MPS
# return WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize).convert_for_runtime(use_cuda=True)
return WeightOnlyInt4GPTQQuantHandler(
self.mod, bitwidth=4, group_size=self.groupsize
).convert_for_runtime()

def quantized_model(self) -> nn.Module:
model_updated_state_dict = self.create_quantized_state_dict()
self.convert_for_runtime()
self.mod.load_state_dict(model_updated_state_dict)
return self.mod


##################################################################

0 comments on commit 7eda42e

Please sign in to comment.