
Replace WeightOnlyInt8Linear with TorchAO int8_weight_only quantization #1328

Open
wants to merge 2 commits into base: main
Conversation

@vmpuri (Contributor) commented Oct 24, 2024

Replace the WeightOnlyInt8Linear quantization code with TorchAO's int8_weight_only quantization.

Note - this commit also contains lintrunner changes.
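
For context, the core of the change is swapping the hand-rolled WeightOnlyInt8Linear module swap for torchao's quantize_ API. A minimal sketch of the new-style call, with a toy model standing in for the loaded torchchat transformer (the model, layer sizes, and device handling below are placeholders, not code from this PR):

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

device = "cuda" if torch.cuda.is_available() else "cpu"

# Toy stand-in for the loaded torchchat model.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 64)).to(device)

# Rewrites the weights of eligible nn.Linear modules in place as int8 tensor
# subclasses, replacing the previous custom WeightOnlyInt8Linear module swap.
quantize_(model, int8_weight_only())

with torch.no_grad():
    out = model(torch.randn(1, 512, device=device))
print(out.shape)  # torch.Size([1, 64])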

Testing:

python3 torchchat.py eval llama3.2-1b --quantize '{"linear:int8": {"groupsize": 0}, "executor":{"accelerator":"cuda"}}' --compile
Using device=cuda
Loading model...
Time to load model: 1.21 seconds
Quantizing the model with: {'linear:int8': {'groupsize': 0}, 'executor': {'accelerator': 'cuda'}}
quantizer is linear int8
Time to quantize model: 0.31 seconds
-----------------------------------------------------------
2024-10-24:15:55:20,261 INFO     [huggingface.py:162] Using device 'cuda'
2024-10-24:15:55:27,792 WARNING  [task.py:763] [Task: wikitext] metric word_perplexity is defined, but aggregation is not. using default aggregation=weighted_perplexity
2024-10-24:15:55:27,792 WARNING  [task.py:775] [Task: wikitext] metric word_perplexity is defined, but higher_is_better is not. using default higher_is_better=False
2024-10-24:15:55:27,792 WARNING  [task.py:763] [Task: wikitext] metric byte_perplexity is defined, but aggregation is not. using default aggregation=weighted_perplexity
2024-10-24:15:55:27,792 WARNING  [task.py:775] [Task: wikitext] metric byte_perplexity is defined, but higher_is_better is not. using default higher_is_better=False
2024-10-24:15:55:27,792 WARNING  [task.py:763] [Task: wikitext] metric bits_per_byte is defined, but aggregation is not. using default aggregation=bits_per_byte
2024-10-24:15:55:27,792 WARNING  [task.py:775] [Task: wikitext] metric bits_per_byte is defined, but higher_is_better is not. using default higher_is_better=False
Repo card metadata block was not found. Setting CardData to empty.
2024-10-24:15:55:28,687 WARNING  [repocard.py:108] Repo card metadata block was not found. Setting CardData to empty.
2024-10-24:15:55:28,760 INFO     [task.py:395] Building contexts for wikitext on rank 0...
100%|███████████████████████████████████████████████████████████████████████████████████████████| 62/62 [00:00<00:00, 501.80it/s]
2024-10-24:15:55:28,889 INFO     [evaluator.py:362] Running loglikelihood_rolling requests
100%|████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [01:10<00:00,  1.13s/it]
Time to run eval: 78.96s.
Time in model.forward: 62.57s, over 162 model evaluations
forward run time stats - Median: 0.00s Min: 0.00s Max: 41.80s
For model /home/puri/.torchchat/model-cache/meta-llama/Meta-Llama-3.2-1B-Instruct/model.pth
wikitext:
 word_perplexity,none: 19.2032
 byte_perplexity,none: 1.7378
 bits_per_byte,none: 0.7973
 alias: wikitext

From current master:

python3 torchchat.py eval llama3.2-1b --quantize '{"linear:int8": {"groupsize": 0}, "executor":{"accelerator":"cuda"}}' --compile
Using device=cuda
Loading model...
Time to load model: 1.20 seconds
Quantizing the model with: {'linear:int8': {'groupsize': 0}, 'executor': {'accelerator': 'cuda'}}
Time to quantize model: 0.19 seconds
-----------------------------------------------------------
2024-10-24:15:43:59,945 INFO     [huggingface.py:162] Using device 'cuda'
2024-10-24:15:44:07,664 WARNING  [task.py:763] [Task: wikitext] metric word_perplexity is defined, but aggregation is not. using default aggregation=weighted_perplexity
2024-10-24:15:44:07,664 WARNING  [task.py:775] [Task: wikitext] metric word_perplexity is defined, but higher_is_better is not. using default higher_is_better=False
2024-10-24:15:44:07,664 WARNING  [task.py:763] [Task: wikitext] metric byte_perplexity is defined, but aggregation is not. using default aggregation=weighted_perplexity
2024-10-24:15:44:07,664 WARNING  [task.py:775] [Task: wikitext] metric byte_perplexity is defined, but higher_is_better is not. using default higher_is_better=False
2024-10-24:15:44:07,664 WARNING  [task.py:763] [Task: wikitext] metric bits_per_byte is defined, but aggregation is not. using default aggregation=bits_per_byte
2024-10-24:15:44:07,664 WARNING  [task.py:775] [Task: wikitext] metric bits_per_byte is defined, but higher_is_better is not. using default higher_is_better=False
Repo card metadata block was not found. Setting CardData to empty.
2024-10-24:15:44:09,261 WARNING  [repocard.py:108] Repo card metadata block was not found. Setting CardData to empty.
2024-10-24:15:44:09,342 INFO     [task.py:395] Building contexts for wikitext on rank 0...
100%|████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [00:00<00:00, 463.50it/s]
2024-10-24:15:44:09,482 INFO     [evaluator.py:362] Running loglikelihood_rolling requests
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [01:00<00:00,  1.03it/s]
Time to run eval: 70.16s.
Time in model.forward: 53.46s, over 162 model evaluations
forward run time stats - Median: 0.00s Min: 0.00s Max: 33.02s
For model /home/puri/.torchchat/model-cache/meta-llama/Meta-Llama-3.2-1B-Instruct/model.pth
wikitext:
 word_perplexity,none: 19.2432
 byte_perplexity,none: 1.7385
 bits_per_byte,none: 0.7978
 alias: wikitext

Lint

pip install -r install/requirements-lintrunner.txt 
lintrunner -a

pytorch-bot (bot) commented Oct 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1328

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1a42fb6 with merge base e30aaa0:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 24, 2024
@vmpuri vmpuri marked this pull request as ready for review October 24, 2024 22:57
@jerryzh168 (Contributor) commented Oct 24, 2024

Thanks! Can you add a generate.py speed benchmark result for before and after as well?

# Use tensor subclass API for int4 weight only.
if device == "cuda" and quantizer == "linear:int4":
    quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
elif quantizer == "linear:int8":
    print("quantizer is linear int8")
Contributor:
Suggested change (delete this line):
-    print("quantizer is linear int8")
"precision": PrecisionHandler,
"executor": ExecutorHandler,
"linear:int4": Int4WeightOnlyQuantizer,
"linear:int8": int8_weight_only,
Contributor:
Do we need this?

Contributor:
we can probably use None for now, and remove this later

Contributor:
I think we check for int8_weight_only and finish that check before it ever looks at the table.

@vmpuri can you check?

@Jack-Khuu (Contributor):
Can you ack that the numerics look good for MPS and CPU as well?

# Use tensor subclass API for int4 weight only.
if device == "cuda" and quantizer == "linear:int4":
    quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
elif quantizer == "linear:int8":
    print("quantizer is linear int8")
    quantize_(model, int8_weight_only())
Contributor:
Why not integrate it into a QuantHandler class dispatched through the handler dict at a single call site, rather than building a chain of if statements?
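
A rough sketch of what that could look like, purely for illustration (the handler name, constructor signature, and quantized_model method below are hypothetical, not torchchat's actual QuantHandler interface, and nothing like this is in the PR):

from torchao.quantization import quantize_, int8_weight_only

# Hypothetical wrapper that adapts torchao's functional API to a class-based
# handler, so it can be dispatched through the handler dict at one call site.
class Int8WeightOnlyHandler:
    def __init__(self, model, device="cpu", precision=None, **kwargs):
        self.model = model

    def quantized_model(self):
        quantize_(self.model, int8_weight_only())
        return self.model

# Single dispatch table instead of a chain of if/elif branches.
handlers = {
    "linear:int8": Int8WeightOnlyHandler,
    # "linear:int4": ..., "embedding": ..., etc.
}

def quantize_model(model, quantizer_name, **q_kwargs):
    return handlers[quantizer_name](model, **q_kwargs).quantized_model()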

Contributor:
Hi @mikekgfb, I think we will refactor this part in the future, after all quant APIs are moved to torchao.

Contributor:
torchAO already has a class-based API that is used for other quantizers? Why do these differently, and then later refactor them? Or why not do them all a consistent way now, and if you refactor later, do that?

Contributor:
Yeah, the quantizer API is deprecated in favor of quantize_; that's why we are gradually refactoring the quantizer APIs to use quantize_. The reason we do it one by one is that there might be missing support or numerics alignment that we need to address during the migration.

        return linear_int8_aoti(input, self.weight, self.scales)

    def et_forward(self, input: torch.Tensor) -> torch.Tensor:
        return linear_int8_et(input, self.weight, self.scales)
Contributor:
Int8 seems like it is special-cased for ET; reminder to check that as well.
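
For reference, both forward paths of the module being replaced compute a plain weight-only int8 linear: cast the int8 weight up to the activation dtype, run the matmul, then apply the per-output-channel scales. A rough sketch of that math (not the exact torchchat, AOTI, or ET kernels):

import torch
import torch.nn.functional as F

def int8_weight_only_linear(x, weight_int8, scales):
    # Dequantize on the fly, then scale each output channel.
    return F.linear(x, weight_int8.to(dtype=x.dtype)) * scales

x = torch.randn(2, 8)
w = torch.randint(-128, 128, (4, 8), dtype=torch.int8)  # [out_features, in_features]
s = torch.rand(4)                                        # one scale per output channel
print(int8_weight_only_linear(x, w, s).shape)            # torch.Size([2, 4])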
