Upgrade Flax NNX Gemma Sampling Inference doc #4325
base: main
Check out this pull request on ReviewNB. See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB
docs_nnx/guides/gemma.ipynb
Outdated
"You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it."
"In this tutorial, you will learn step-by-step how to use Flax NNX to load the [Gemma](https://ai.google.dev/gemma) open model files and use them to perform sampling/inference for generating text. You will use the [Flax NNX `gemma` code](https://github.com/google/flax.git) that was written with Flax and JAX.\n",
"\n",
"> Gemma is a family of lightweight, state-of-the-art open models based on Google DeepMind’s [Gemini](https://deepmind.google/technologies/gemini/#introduction). Read more about [Gemma](https://blog.google/technology/developers/gemma-open-models/) and [Gemma 2](https://blog.google/technology/developers/google-gemma-2/).\n",
Similar to what we did in other Gemma docs - added some background.
"\n",
"You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it."
"guide" in the title, "tutorial" in the first paragraph -> let's use "tutorial".
"\n",
"Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models."
@cgarciae Since there are checkpoints and tokenizer files, changed to "model" instead of "checkpoint".
Checking if free TPU v2-8 is sufficient.
docs_nnx/guides/gemma.ipynb
Outdated
@@ -19,16 +19,24 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Getting Started with Gemma Sampling using NNX: A Step-by-Step Guide\n",
@cgarciae Adding "inference" next to "sampling" for search.
docs_nnx/guides/gemma.ipynb
Outdated
"\n",
"1. Visit https://www.kaggle.com/ and create an account.\n",
"2. Go to your account settings, then the 'API' section.\n",
"3. Click 'Create new token' to download your key.\n",
@cgarciae Adding Step 3 as "OPTIONAL" (possibly dropping the "OPTIONAL" label), since Colab asks for access here after running the code below, so users won't have to enter the details manually if they are stored in Colab:
import kagglehub
kagglehub.login()
"1. To create an account, visit Kaggle and click on 'Register'."
"2. If/once you have an account, you need to sign in, go to your 'Settings', and under 'API' click on 'Create New Token' to generate and download your Kaggle API key."
"3. OPTIONAL: In Google Colab, under 'Secrets' add your Kaggle username and API key, storing the username as `KAGGLE_USERNAME` and the key as `KAGGLE_KEY`. If you are using a Kaggle Notebook for free TPU or other hardware acceleration, it has a key storage feature under 'Add-ons' > 'Secrets', along with instructions for accessing stored keys."
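For reference, a minimal sketch of how a notebook cell could check whether the two values from step 3 are available before falling back to an interactive login. The helper name `kaggle_credentials` is hypothetical and not part of `kagglehub`; it only reads environment variables, and how Colab or Kaggle secrets get exposed to the runtime is assumed to be handled elsewhere.

```python
import os


def kaggle_credentials(env=None):
    """Return (username, key) if both variables are set and non-empty, else None.

    Hypothetical helper for illustration; `env` defaults to os.environ.
    """
    env = os.environ if env is None else env
    username = env.get("KAGGLE_USERNAME")
    key = env.get("KAGGLE_KEY")
    if username and key:
        return username, key
    return None


if kaggle_credentials() is None:
    # Nothing stored: the user would need to authenticate interactively.
    print("KAGGLE_USERNAME / KAGGLE_KEY not set; interactive login needed.")
```

Passing a plain dict as `env` makes the check easy to try without touching the real environment.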
docs_nnx/guides/gemma.ipynb
Outdated
@@ -82,13 +90,21 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"If everything went well, you should see:\n",
Adding an extra optional step here, similar to what we have in the Gemma docs.
"Note: In Google Colab, you can instead authenticate into Kaggle using the code below after following the optional step 3 from above...."
docs_nnx/guides/gemma.ipynb
Outdated
@@ -124,7 +140,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Flax examples are not exposed as packages so you need to use the workaround in the next cells to import from NNX's Gemma example."
@cgarciae Edited:
"To interact with the Gemma model, you will use the Flax NNX Gemma code from `google/flax` examples on GitHub. Since it is not exposed as packages, you need to use the following workaround in the next cells to import from the Flax NNX Gemma example."
docs_nnx/guides/gemma.ipynb
Outdated
@@ -195,7 +218,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the `transformer_lib.TransformerConfig.from_params` function to automatically load the correct configuration from a checkpoint. Note that the vocabulary size is smaller than the number of input embeddings due to unused tokens in this release."
Added the source code that has more docstring(s), since `transformer_lib` is an alias for Flax NNX examples -> `gemma.transformer`:
"Then, use the Flax NNX `transformer_lib.TransformerConfig.from_params` function to automatically load the correct configuration from a checkpoint."
docs_nnx/guides/gemma.ipynb
Outdated
@@ -212,7 +237,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, build a sampler on top of your model and your tokenizer."
Added the source code with the docstring.
"Build a Flax NNX `Sampler` on top of your model and tokenizer with the right parameter shapes."
docs_nnx/guides/gemma.ipynb
Outdated
@@ -235,7 +261,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"You're ready to start sampling ! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent."
Added some background on JAX JIT after studying the source code (it's not NNX JIT).
"Note: This Flax NNX `gemma.Sampler` uses JAX’s just-in-time (JIT) compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent."
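To illustrate the note above, here is a small JAX sketch of the general `jax.jit` behavior it describes. It does not use the Gemma `Sampler`; the function `double` and the counter `trace_count` are made up for the demonstration, and the shapes are arbitrary.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces the function


@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return x * 2


double(jnp.ones((2, 8)))  # first call: traced and compiled for shape (2, 8)
double(jnp.ones((2, 8)))  # same shape and dtype: reuses the cached compilation
double(jnp.ones((4, 8)))  # different batch size: traced and compiled again

print(trace_count)
```

The counter ends at 2 rather than 3: only the change in the leading (batch) dimension triggered a second trace, which is the cost the note suggests avoiding by keeping the batch size fixed.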
e7e7195 to 243439f
docs_nnx/guides/gemma.ipynb
Outdated
@@ -136,6 +152,14 @@
"! git clone https://github.com/google/flax.git flax_examples"
Fixing the import path:
! git clone https://github.com/google/flax.git flax_examples
...
- sys.path.append("./flax_examples/flax/nnx/examples/gemma")
+ sys.path.append("./flax_examples/examples/gemma")
...
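A hedged sketch of how the corrected path could be added defensively, so that a wrong directory fails with a clear message instead of a later import error. The helper name `add_example_to_path` is hypothetical; the default relative path is the one from the fix above.

```python
import sys
from pathlib import Path


def add_example_to_path(repo_dir, relative="examples/gemma"):
    """Append <repo_dir>/<relative> to sys.path if that directory exists.

    Hypothetical helper for illustration; returns the resolved path string.
    """
    path = Path(repo_dir) / relative
    if not path.is_dir():
        raise FileNotFoundError(f"Expected example directory not found: {path}")
    resolved = str(path.resolve())
    if resolved not in sys.path:  # avoid duplicate entries on cell re-runs
        sys.path.append(resolved)
    return resolved
```

In the notebook this would be called as `add_example_to_path("./flax_examples")` after the `git clone` cell.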
243439f to 52dc69b
Throws an error
@cgarciae PTAL
docs_nnx/guides/gemma.ipynb
Outdated
"3. Click 'Create new token' to download your key.\n",
"1. To create an account, visit [Kaggle](https://www.kaggle.com/) and click on 'Register'.\n",
"2. If/once you have an account, you need to sign in, go to your ['Settings'](https://www.kaggle.com/settings), and under 'API' click on 'Create New Token' to generate and download your Kaggle API key.\n",
"3. OPTIONAL: In [Google Colab](https://colab.research.google.com/), under 'Secrets' add your Kaggle username and API key, storing the username as `KAGGLE_USERNAME` and the key as `KAGGLE_KEY`. If you are using a [Kaggle Notebook](https://www.kaggle.com/code) for free TPU or other hardware acceleration, it has a key storage feature under 'Add-ons' > 'Secrets', along with instructions for accessing stored keys.\n",
TODO: Should we remove "OPTIONAL" for Colab users?
Hey @8bitmp3! I cleaned up this guide a little bit. Can you take a look at the new version?
thanks @cgarciae 👍
52dc69b to d8b1a92
Reopening after #4334 fixes
Preview: https://flax--4325.org.readthedocs.build/en/4325/guides/gemma.html
Also fixes broken code after: