
Granite code support #1336

Open · wants to merge 15 commits into main from GraniteCodeSupport
Conversation

gabe-l-hart
Contributor

@gabe-l-hart gabe-l-hart commented Oct 31, 2024

Dependencies

This PR is part of a sequence in support of adding Granite Code. It depends on merging the following PRs:

Issues

Closes #1262

Description

This PR adds support for Granite Code in 3B and 8B sizes. Given current limitations with the export of tokenizers, these models will only work in the Python environment with this PR.

Discussion

Usage

To test these models, I ran them both via the aliases and by pointing directly at the checkpoint/tokenizer:

# Run with alias
python torchchat.py generate granite-code \
  --prompt "Write a python function to sort numbers and strings with numeric prefixes"

# Run with direct reference to artifacts
python torchchat.py generate \
  --prompt "Write a python function to sort numbers and strings with numeric prefixes" \
  --checkpoint-path $HOME/models/ibm-granite/granite-3b-code-instruct-128k/model.pth \
  --tokenizer-path $HOME/models/ibm-granite/granite-3b-code-instruct-128k/tokenizer.json \
  --params-path torchchat/model_params/Granite-3B-Code.json

Open Questions

There are several outstanding issues, beyond the upstream tokenizers PR, that need to be solved before this PR is ready for full review:

  • It seems that in chat mode, the models produce very unreliable results, sometimes generating a single token while other times generating a reasonable result but stopping mid-sentence before reaching the max token limit. My current hypothesis is that the Granite chat template is not being picked up anywhere, so we're falling back to the llama chat template automatically.
  • The 8B model currently produces garbage after a few tokens. The main difference between the 3B and 8B models, besides common parameter differences like number of layers and hidden size, is that the 8B uses grouped query attention. I've seen similar behavior in other frameworks where the model starts on a good track, then devolves into garbage and in those cases GQA was also at play, so I suspect it's something along these lines here as well.


pytorch-bot bot commented Oct 31, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1336

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 10918a1 with merge base 6895a18:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 31, 2024
@gabe-l-hart
Contributor Author

Also, I used the following script to perform conversion of a pre-existing HF snapshot. It's similar to the if __name__ == "__main__" block in convert_hf_checkpoint.py:

convert_existing_checkpoint.py
#!/usr/bin/env python
"""
Simple script to convert an existing HF snapshot into torchchat format
"""

# Standard
import argparse
from pathlib import Path

# Local
from torchchat.cli.convert_hf_checkpoint import convert_hf_checkpoint, convert_hf_checkpoint_to_tune

def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("checkpoint_dir", help="Directory containing HF checkpoint")
    parser.add_argument("--name", "-n", default=None, help="Name to use for the model")
    parser.add_argument("--torchtune", "-t", action="store_true", default=False, help="Convert to torchtune format")
    args = parser.parse_args()
    if args.torchtune:
        convert_hf_checkpoint_to_tune(model_dir=Path(args.checkpoint_dir), model_name=args.name)
    else:
        convert_hf_checkpoint(model_dir=Path(args.checkpoint_dir), model_name=args.name)

if __name__ == "__main__":
    main()
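For reference, a hypothetical invocation of this script against an existing HF snapshot (the path and alias name below are illustrative, not the exact directories from my setup):

python convert_existing_checkpoint.py \
  /path/to/hf/snapshot/granite-3b-code-instruct-128k \
  --name granite-code-3b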

@gabe-l-hart gabe-l-hart marked this pull request as ready for review November 5, 2024 16:28
@mikekgfb mikekgfb mentioned this pull request Nov 5, 2024
@gabe-l-hart gabe-l-hart force-pushed the GraniteCodeSupport branch 2 times, most recently from daeeb79 to 19aa6c7 Compare November 7, 2024 00:29
@gabe-l-hart
Contributor Author

I confirmed that it was falling back to the llama2 chat formatter because it wasn't using tiktoken. I've added basic jinja2 chat template support when using the HF tokenizer.

@mikekgfb
Contributor

mikekgfb commented Nov 8, 2024

A pointer to this PR and the example commands from the PR description, in conjunction with some explanatory text, would make a good starting point for docs/new_model.md to (at least partially?) address #1038 / #1041:

# wget artifacts here
# Run with direct reference to artifacts
python torchchat.py generate \
  --prompt "Write a python function to sort numbers and strings with numeric prefixes" \
  --checkpoint-path $HOME/models/ibm-granite/granite-3b-code-instruct-128k/model.pth \
  --tokenizer-path $HOME/models/ibm-granite/granite-3b-code-instruct-128k/tokenizer.json \
  --params-path torchchat/model_params/Granite-3B-Code.json

Explain how to add to model list....

# Run with alias
python torchchat.py generate granite-code \
  --prompt "Write a python function to sort numbers and strings with numeric prefixes"

If added to .ci/scripts/run-docs new_model, it might also make a test case for the features used in Granite.

@gabe-l-hart
Contributor Author

gabe-l-hart commented Nov 8, 2024

@Jack-Khuu I'm a bit stumped trying to get the 8B model working. I'm trying to mentally diff the Attention implementation in torchchat vs transformers to see if I can find anything that would indicate something behaving differently with Grouped Query Attention.

I'm not really following the different way that the torchchat version is manipulating the tensors for tensor parallel inference (need to do some background reading there), but this feels like it's got to be close to the root of the issue. The only other place that I could imagine things going wrong is in the unpacking of the unified wqkv here. Any insight you can offer would be much appreciated!
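For readers following along, here is a rough sketch of the KV-head expansion that grouped query attention requires; the function name, shapes, and comments are illustrative, not torchchat's or transformers' actual code.

# Rough illustration of grouped query attention's KV-head expansion: the 8B
# model has fewer key/value heads than query heads, so k/v must be repeated
# before the usual scaled dot-product attention.
import torch

def expand_kv_heads(kv: torch.Tensor, n_query_heads: int) -> torch.Tensor:
    # kv: (batch, n_kv_heads, seq_len, head_dim), with n_query_heads % n_kv_heads == 0
    n_kv_heads = kv.shape[1]
    n_rep = n_query_heads // n_kv_heads
    # -> (batch, n_query_heads, seq_len, head_dim)
    return kv.repeat_interleave(n_rep, dim=1)

# A mismatch in how the fused wqkv weight is split into per-head q/k/v (or in
# n_rep) would corrupt exactly this step, which fits the "starts fine, then
# devolves into garbage" symptom.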

Results with 3B
?> python torchchat.py generate granite-code-3b --prompt "Write a python hello world function"
NumExpr defaulting to 16 threads.
PyTorch version 2.6.0.dev20241002 available.
lm_eval is not installed, GPTQ may not be usable
W1108 13:18:36.747000 52813 site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
Using device=mps 
Loading model...
Time to load model: 3.86 seconds
-----------------------------------------------------------
Write a python hello world function

`​``python
def say_hello():
    print("hello world")
`​``

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~                
Generated 19 tokens                 
Time for inference 1: 1.6639 sec total                 
Time to first token: 0.4289 sec with parallel prefill.                

      Total throughput: 12.0199 tokens/sec, 0.0832 s/token                 
First token throughput: 2.3316 tokens/sec, 0.4289 s/token                 
 Next token throughput: 15.3844 tokens/sec, 0.0650 s/token                     

Bandwidth achieved: 86.74 GB/s
*** This first iteration will include cold start effects for dynamic import, hardware caches. ***

========================================


      Average tokens/sec (total): 12.02                 
Average tokens/sec (first token): 2.33                 
Average tokens/sec (next tokens): 15.38 

NOTE (because I feel compelled): The above snippet uses zero-width-spaces to escape the triple backticks inside the code blocks, so copy-paste at your own peril!

Results with 8B
?> python torchchat.py generate granite-code-8b -p "Write a python hello world function"
usage: torchchat [-h] {chat,generate,browser,export,download,list,remove,where,server,eval} ...
torchchat: error: unrecognized arguments: -p Write a python hello world function
(torchchat2) ghart@Mac [torchchat GraniteCodeSupport ?]$ python torchchat.py generate granite-code-8b --prompt "Write a python hello world function"
NumExpr defaulting to 16 threads.
PyTorch version 2.6.0.dev20241002 available.
lm_eval is not installed, GPTQ may not be usable
W1108 13:13:21.744000 51816 site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
Using device=mps 
Loading model...
Time to load model: 11.67 seconds
-----------------------------------------------------------
Write a python hello world function function function function function function function function function function function

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~                
Generated 11 tokens                 
Time for inference 1: 7.5729 sec total                 
Time to first token: 4.8976 sec with parallel prefill.                

      Total throughput: 1.5846 tokens/sec, 0.6311 s/token                 
First token throughput: 0.2042 tokens/sec, 4.8976 s/token                 
 Next token throughput: 4.1117 tokens/sec, 0.2432 s/token                     

Bandwidth achieved: 26.17 GB/s
*** This first iteration will include cold start effects for dynamic import, hardware caches. ***

========================================


      Average tokens/sec (total): 1.58                 
Average tokens/sec (first token): 0.20                 
Average tokens/sec (next tokens): 4.11 

@Jack-Khuu
Contributor

Thanks for the details @gabe-l-hart

I'll try to give it a gander this weekend. It's weird that the 3B works but the 8B doesn't. I assume they use the same template, so that at least clears that part.

@byjlw
Contributor

byjlw commented Nov 19, 2024

Looks like this has been open for several weeks now.
Yeah, the template thing is super hacky right now and I knew it was going to hang up our ability to add new models.
In general we need to make a smoother path for adding new models with different architectures, templates, and storage locations.

It's been on @varunfb and @Jack-Khuu's plate for a while, but they've been swamped with other work.
Fortunately, it's planning season and the design for this is on the list.
@gabe-l-hart would love to get your feedback on how best to support folks like yourself.

@gabe-l-hart
Contributor Author

Thanks @byjlw! I definitely understand juggling priorities. The path to adding new models in the model_params and model_config is relatively straightforward (could use a doc, but TBH I never read docs anyway, so easy-to-read code is always best). The real challenge has come up around places where the models differ from the llama series models. In particular, Granite Code uses the llama architecture, but uses several optional bits that the Meta Llama models don't (e.g. HF tokenizers, tied embeddings). Getting these pieces to work has been a decently steep learning curve (fun though!). I think the thing that would be most helpful would be some kind of compatibility matrix doc that shows architectures that have support, sub-features within architectures, and which "layers" they're supported in (e.g. python, c++, executorch). This would help a lot in figuring out where to dig in to add new model support.

For the specific issues for Granite Code, the place I'm a bit stuck is trying to figure out why the 8B model is flopping while the 3B model is working just fine. My gut is that it has something to do with the alternate attention mechanism in TC, but I'm not deeply versed in attention enough to spot it quickly. The only architectural difference between 3B and 8B is the use of grouped query attention, so it's either something there or there's some incompatibility between the attention implementations in transformers and TC that's only being exercised by the specific weights of the 8B. Any help and/or expert bug spotting would be much appreciated!

@gabe-l-hart
Contributor Author

gabe-l-hart commented Nov 20, 2024

I just rebased on main and it now looks like even the 3b model is producing only a single token as output in chat mode. Will try to get to the bottom of it.

@mikekgfb
Contributor

I just rebased on main and it now looks like even the 3b model is producing only a single token as output in chat mode. Will try to get to the bottom of it.

Have you tried bisecting the 3B fail? Even if the change was legit and necessary, the type of change that would break the 3B model might give insight into how to "fix" both the 3B and 8B models?

@mikekgfb
Contributor

mikekgfb commented Nov 21, 2024

The real challenge has come up around places where the models differ from the llama series models. In particular, Granite Code uses the llama architecture, but uses several optional bits that the Meta Llama models don't (e.g. HF tokenizers, tied embeddings). Getting these pieces to work has been a decently steep learning curve (fun though!).

I'm a bit surprised by this because ChatGPT had this to say (understanding that I'm quoting ChatGPT about an IBM model to an IBMer, so skating on seriously thin ice!!!):

what tokenization scheme does the ibm granite model use

Searched 4 sites
The IBM Granite models, including its base and instruction-tuned variants, utilize the Llama2 tokenizer for tokenization. This choice aligns with the models’ architectural similarity to Meta's Llama2 series, such as the Granite-7b model, which follows the Llama2-7B architecture and employs similar tokenization strategies. These tokenizers are designed to handle diverse data sources, including programming languages and natural language, ensuring compatibility and efficiency in tasks like code synthesis and language understanding​

So in theory, SentencePiece should do the trick? Is it the pre- and post-processing with regexps? (I think I saw some discussion about regexps in one of your PRs or issues?)

In any event, it's cool that we have HF tokenizers because they are a proper superset of SentencePiece+TikToken. (I think @lessw2020 and @kwen2501 had also added some HF tokenizer support for distributed if I remember correctly?)
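As a concrete illustration of the HF tokenizers path (not the PR's actual wiring; the path below is hypothetical), loading an exported tokenizer.json with the standalone tokenizers package looks roughly like this:

# Load an exported tokenizer.json and round-trip a snippet of code.
from tokenizers import Tokenizer

tok = Tokenizer.from_file("/path/to/granite-3b-code-instruct-128k/tokenizer.json")
encoding = tok.encode("def hello():")
print(encoding.ids)              # token ids
print(tok.decode(encoding.ids))  # back to text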

@gabe-l-hart
Contributor Author

Have you tried bisecting the 3B fail? Even if the change was legit and necessary, the type of change that would break the 3B model might give insight in how to "fix" both the 3B and 8B models?

That's on my todo list for my next chunk of uninterrupted dev time! I'm hoping that will be today.

I'm a bit surprised by this because ChatGPT had this to say (understanding that I'm quoting ChatGPT about an IBM model to an IBMer, so skating on seriously thin ice!!!):

Heh, as you know I'm sure, IBM is a big place, so I'm definitely doing a lot of learning myself in this space. My info from the models team is that we've been using the starcoder tokenizer up until now (including Granite Code and the Granite 3.0 series). When first trying to understand how best to support that in torchchat, I was missing a lot of knowledge about sentencepiece, so was working off of the tokenizer_config.json in HF. I suspect it would be possible to reverse-convert from tokenizers back to sentencepiece for this config, but I haven't done that work yet since I was already halfway down the rabbit hole of tokenizers support. We can certainly look into that as an alternative approach if the preference is to avoid the complexity of the c++ tokenizer buildout.

@gabe-l-hart
Contributor Author

@Jack-Khuu @mikekgfb @byjlw I figured out where the issues were coming from. It was two things:

  1. The logic was always inserting a BOS token at the beginning of the sequence, which the 3B model was sometimes OK with, but the 8B never was
    • To solve this, I added tokenizer_prepend_bos as a parameter in TransformerArgs and ModelArgs. It seemed a little clunky to plumb it through multiple blobs, but this got things working for both models with raw generation
  2. The chat template logic was not robust beyond llama2 and llama3 templating. Solving this resulted in a fair bit of refactoring:
    • Refactor the Llama2ChatFormatter and Llama3ChatFormatter to encapsulate all logic in a single abstract method encode_dialog_prompt
    • Remove all formatter-specific logic from the primary generation loop in def chat
    • Add the HFTokenizerChatFormatter
    • Plumb the ability to use the chat template with jinja2 through HFTokenizer (a minimal sketch follows after this list)
      • NOTE: jinja2 was already a transitive dependency, so I just formalized it
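As an illustration of item 2, here is a minimal sketch of rendering an HF chat template with jinja2. It is not the exact HFTokenizerChatFormatter implementation, and it assumes the template string is stored under "chat_template" in the model directory's tokenizer_config.json:

import json
from pathlib import Path

import jinja2

def render_chat_prompt(model_dir: str, messages: list) -> str:
    # messages: list of {"role": ..., "content": ...} dicts
    config = json.loads((Path(model_dir) / "tokenizer_config.json").read_text())
    env = jinja2.Environment(trim_blocks=True, lstrip_blocks=True)
    template = env.from_string(config["chat_template"])
    return template.render(messages=messages, add_generation_prompt=True)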

To get to the bottom of all of this, I also tweaked the logging a bit. There was already a hard-coded logging config call in cli, so I just added the ability to parse the LOG_LEVEL env var to set it. I also added a fair number of additional log lines and uncommented some that were there but commented out.

NOTE: Many of the existing log lines were using f-strings, which cause the string to be interpolated regardless of whether the logger/level is enabled. I switched all of these to lazy %-style interpolation so that it's safe to have them uncommented without a performance hit.
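To make the distinction concrete, a minimal sketch (the LOG_LEVEL wiring and messages below are illustrative, not torchchat's actual code):

import logging
import os

# Hypothetical wiring: let the LOG_LEVEL env var control the root logger's level
logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO").upper())
logger = logging.getLogger(__name__)

num_tokens = 19
logger.debug(f"Generated {num_tokens} tokens")   # eager: the f-string is built even if DEBUG is off
logger.debug("Generated %d tokens", num_tokens)  # lazy: only formatted when DEBUG is enabled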

Finally, I was getting lost trying to make sure I didn't break anything in the chat templating, so I bit the bullet and added some basic unit tests. They only cover the chat formatting, but they're a place to start. I did not go any further with unit testing, including not adding pytest as a dependency or adding any CI steps to invoke the tests. If you're interested, I'd be happy to push on unit testing, but I didn't want to lump that conversation into this PR.

@@ -9,6 +9,10 @@ gguf
# Tiktoken tokenizer for Llama 3 and other advanced models
tiktoken

# Tokenizers and jinja2 for other non-llama models that use HF tokenizers
Contributor Author


I added these here, but did not add pytest (yet). I think there's a pending conversation about introducing optional dependency sets, so it would make sense to add a test or dev set at that point, but I didn't want to accidentally carry pytest along as a runtime dependency.

import os
import sys

# Make sure tests can import torchchat
Contributor Author


This would be a lot cleaner if we move to having a pyproject.toml or setup.py to bundle torchchat as a package that could be installed with pip install -e.

Contributor


on the list

return tokens


B_INST, E_INST = "[INST]", "[/INST]"
Contributor Author


I moved these into the class as members to enforce encapsulation: they should only be used in the context of this formatter. A simplified sketch of the resulting shape is below.
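Roughly the shape of the change, heavily simplified; the real formatter handles the full Llama2 template, system prompts, and tokenization, and the class and method names here mirror the PR description rather than the exact code:

from abc import ABC, abstractmethod

class ChatFormatter(ABC):
    @abstractmethod
    def encode_dialog_prompt(self, dialog: list) -> str: ...

class Llama2ChatFormatter(ChatFormatter):
    # Markers are class members so they are only used inside this formatter
    B_INST, E_INST = "[INST]", "[/INST]"

    def encode_dialog_prompt(self, dialog: list) -> str:
        return "".join(
            f"{self.B_INST} {msg['content'].strip()} {self.E_INST}"
            for msg in dialog
            if msg["role"] == "user"
        )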

@byjlw
Contributor

byjlw commented Nov 23, 2024

Thanks @gabe-l-hart
Yeah, a lot of this feedback resonates really well with me, and resolving it has already made it onto our H1 roadmap: making it easy to have model-specific templates, adding test infra and guidelines around tests, and abstracting the code into a more modular structure so that there is a specific core module with well-defined APIs that the CLI and API can use. We will also figure out the release strategy and publish two or three specific pip packages.

Will be able to share the details soon and will have them as RFCs on GH so everyone can comment and contribute.

Contributor

@byjlw byjlw left a comment


Longer term we want to implement the tokenizer and drop the dependency on the HF tokenizer, but this will get things going for now.


Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
* Use the right tokenizer_file name
* Use the right transformer_params_key based on the file name in model_params
* Use the updated name to indicate HF tokenizers

Signed-off-by: Gabe Goodhart <[email protected]>
Something isn't quite working with this model yet, but the config should be
accurate at this point.

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
It was implicitly being pulled in via lm_eval -> transformers, but it's
better to have it explicit since we use it directly

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
…HF tokenizers

This is a much simplified version of the corresponding logic in
transformers. I opted for this so that the full transformers dependency is
not added here.

CITE: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1522

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
This will allow the jinja2 templates for HF tokenizers to be applied
without needing to hard-code the formatter logic. This will likely need to
be duplicated in the embedded code version of chat.

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
It was getting pulled in implicitly via flask and lm_eval -> transformers,
but better to have it explicit.

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
In generate, there were a number of commented-out log lines. These are safe
to leave in as long as lazy string interpolation is used.

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
And disable it for Granite Code models

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
… in classes

The formatted strings may not be perfectly 1:1 with the previous impl, but
they should be in line with the official model guidelines:

* https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3
* https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-2

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
There's no formal execution framework for pytest yet, but these were
helpful in ensuring that the formatting was working correctly!

To run them, install pytest and run `pytest tests/`

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
There is an incompatibility with logging and torch._dynamo, so this
disables it unless the developer asks for it explicitly.

NOTE: The TC team has stated that they have holistic logging on the roadmap
so this is a short-term solution pending a more robust approach.

REF: https://github.com/pytorch/torchchat/actions/runs/11963066986/job/33493237302#step:14:3599

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Development

Successfully merging this pull request may close these issues.

Support Granite Code 3B/8B
5 participants