Update external documentation
PiperOrigin-RevId: 700641441
Conchylicultor authored and The kauldron Authors committed Nov 27, 2024
1 parent 57ded1f commit 47ef4e8
Showing 6 changed files with 85 additions and 142 deletions.
24 changes: 24 additions & 0 deletions README.md
@@ -4,4 +4,28 @@
[![PyPI version](https://badge.fury.io/py/kauldron.svg)](https://badge.fury.io/py/kauldron)
[![Documentation Status](https://readthedocs.org/projects/kauldron/badge/?version=latest)](https://kauldron.readthedocs.io/en/latest/?badge=latest)

Kauldron is a library for training machine learning models, optimized for
**research velocity** and **modularity**.

**Modularity**:

* All parts of Kauldron are self-contained, so they can be used independently
  outside of Kauldron.
* Use any dataset (TFDS, Grain, SeqIO, your custom pipeline), any (flax)
  model, any optimizer, ... Kauldron provides the glue that links everything
  together.
* Everything can be customized and overwritten (e.g. sweep over model
  architectures, overwrite any inner layer parameter, ...), as sketched below.
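
For instance, a hypothetical sketch of what overriding a nested parameter
could look like (`my_project.get_config` and the `model.encoder` attribute are
illustrative placeholders, not names from this repository):

```python
# Hypothetical sketch: overwrite nested config attributes before launching.
cfg = my_project.get_config()      # illustrative placeholder config
cfg.model.encoder.num_layers = 6   # overwrite an inner layer parameter
cfg.train_ds.batch_size = 512      # pipeline options can be swapped the same way
```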

**Research velocity**:

* Everything should work out of the box. The
  [example configs](https://github.com/google-research/kauldron/tree/HEAD/kauldron/examples/mnist_autoencoder.py)
  can be used and customized as a starting point.
* Colab-first workflow for easy prototyping and fast iteration.
* Polished user experience (integrated XM plots, profiler, post-mortem
  debugging on borg, runtime shape checking, and many others...). Missing
  something? [Open an issue](https://github.com/google-research/kauldron/issues).

*This is not an officially supported Google product.*
33 changes: 3 additions & 30 deletions docs/conf.py
@@ -23,39 +23,9 @@
```
"""

import sys
from unittest import mock

import apitree


# TODO(epot): Delete once `grain` can be imported
sys.modules['grain'] = mock.MagicMock()
sys.modules['grain._src'] = mock.MagicMock()
sys.modules['grain._src.core'] = mock.MagicMock()
sys.modules['grain._src.core.constants'] = mock.MagicMock()
sys.modules['grain._src.tensorflow'] = mock.MagicMock()
sys.modules['grain._src.tensorflow.transforms'] = mock.MagicMock()
sys.modules['grain.tensorflow'] = mock.MagicMock()

import grain.tensorflow as _mocked_grain # pylint: disable=g-import-not-at-top


class _MockedTransform:
  pass


# Required for inheritance `class MyTransform(grain.MapTransform)`
_mocked_grain.MapTransform = _MockedTransform
_mocked_grain.RandomMapTransform = _MockedTransform


# Fail early if kauldron cannot be imported.
# Read the Docs installs kauldron without the `-e` (editable) mode, so
# kauldron should only be imported after `apitree` has imported it from the
# right path.
# from kauldron import kd  # pylint: disable=g-import-not-at-top


apitree.make_project(
    modules=apitree.ModuleInfo(
        api='kauldron.kd',
@@ -64,7 +34,10 @@ class _MockedTransform:
    ),
    includes_paths={
        'kauldron/konfig/docs/demo.ipynb': 'konfig.ipynb',
        'kauldron/kontext/README.md': 'kontext.md',
        'kauldron/data/py/README.md': 'data_py.md',
        'kauldron/klinen/README.md': 'klinen.md',
        'kauldron/random/README.md': 'random.md',
    },
    globals=globals(),
)
71 changes: 1 addition & 70 deletions docs/data.md
@@ -1,4 +1,4 @@
# Train, eval, randomness
# Data pipelines

## Pipeline options

@@ -20,78 +20,9 @@ By default, Kauldron provides two main pipeline implementations:
)
```

* `tf.data` based: the `kd.data.TFDataPipeline` base class, which has
  multiple sub-classes (see next section). For example:

```python
cfg.train_ds = kd.data.Tfds(
    # TFDS specific args
    name='mnist',
    split='train',
    shuffle=True,

    # `kd.data.TFDataPipeline` args (common to all TFDataPipeline)
    batch_size=256,
    transforms=[
        kd.data.Elements(keep=["image"]),
        kd.data.ValueRange(key="image", vrange=(0, 1)),
    ],
)
```

While it's easy to implement a custom pipeline, please contact us if the
existing pipelines do not fit your use case.

## TFDataPipeline

The following `tf.data` sources are available:

* `kd.data.Tfds`: TFDS dataset (note that this requires the dataset to be in
ArrayRecord format)
* `kd.data.TfdsLegacy`: TFDS dataset for datasets that do not support random
  access (e.g. in `tfrecord` format)
* `kd.data.SeqIOTask`: SeqIO task
* `kd.data.SeqIOMixture`: SeqIO mixture
* Your custom `tf.data` pipeline. See: https://kauldron.rtfd.io/en/latest-kmix#implement-your-own

Additionally, any of those source datasets can be combined using:

* `kd.data.SampleFromDatasets`: Sample from a combination of datasets.

Other sources will be added in the future. If your dataset is not yet supported,
please [contact us](https://kauldron.rtfd.io/en/latest-help#bugs-feedback).

See https://kauldron.rtfd.io/en/latest-kmix for details on how to implement a custom `tf.data` source.

Example of dataset mixture with nested transforms:

```python
cfg.train_ds = kd.data.SampleFromDatasets(
    datasets=[
        kd.data.Tfds(
            name='cifar100',
            split='train',
            transforms=[
                kd.data.Elements(keep=["image", "label"]),
            ],
        ),
        kd.data.Tfds(
            name='imagenet2012',
            split='train',
            transforms=[
                kd.data.Elements(keep=["image", "label"]),
                kd.data.Resize(key='image', height=32, width=32),
            ],
        ),
    ],
    seed=0,
    batch_size=256,
    transforms=[
        kd.data.RandomCrop(shape=(15, 15, None)),
    ],
)
```

## Transformations

Both `kd.data.py.PyGrainPipeline` and `kd.data.TFDataPipeline` can be customized
through their `transforms=` argument.
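
As an illustration, here is a minimal sketch of a custom map transform,
assuming the `kd.data.MapTransform` base class exported from `kauldron.data`
(see the `kauldron/data/__init__.py` change below); the `Rescale` transform
itself is hypothetical, not part of the library:

```python
import dataclasses

from kauldron import kd


@dataclasses.dataclass(frozen=True)
class Rescale(kd.data.MapTransform):
  """Hypothetical transform scaling one element of the batch."""

  key: str = "image"
  factor: float = 1.0 / 255

  def map(self, element):
    # Elements are dict-like; scale the selected key and return the element.
    element[self.key] = element[self.key] * self.factor
    return element
```

Such a transform could then be passed alongside the built-in ones, e.g.
`transforms=[Rescale(key="image")]`.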
52 changes: 33 additions & 19 deletions docs/index.md
@@ -1,29 +1,43 @@
```{include} ../README.md
```

```{toctree}
:hidden:
:caption: Guides

train
eval
checkpoint
konfig
klinen
```

```{toctree}
:hidden:
:caption: Links

GitHub <https://github.com/google-research/kauldron>
```

```{toctree}
:hidden:
:caption: API

api/kd/index
```

```{eval-rst}
.. toctree::
    :hidden:
    :caption: Guides

    intro
    eval
    sharding
    checkpoint
    data

.. toctree::
    :hidden:
    :caption: Modules

    konfig
    kontext
    data_py
    metrics
    klinen
    random

.. toctree::
    :hidden:
    :caption: Links

    GitHub <https://github.com/google-research/kauldron>
    Issues <https://github.com/google-research/kauldron/issues>

.. toctree::
    :hidden:
    :caption: API

    api/kd/index
```

<!--
3 changes: 0 additions & 3 deletions kauldron/data/__init__.py
@@ -29,9 +29,6 @@
# PyGrain based data pipeline.
from kauldron.data import py

# tf.data based data pipeline.
from kauldron.data import tf

# Users should inherit from these base classes so their transformations are
# supported by both TfGrain (`kd.data.tf`) and PyGrain (`kd.data.py`).
from kauldron.data.transforms.abc import MapTransform
44 changes: 24 additions & 20 deletions kauldron/metrics/auto_state.py
@@ -38,6 +38,7 @@ class AutoState(base_state.State[_MetricT]):
Subclasses of AutoState have to use the @flax.struct.dataclass decorator and
can define two kinds of fields:
1) Data fields are defined by the `sum_field`, `concat_field` or
`truncate_field` functions. E.g. `d : Float['n'] = sum_field()`.
Data fields are pytrees of Jax arrays.
@@ -173,15 +174,16 @@ def sum_field(
Preserves shape and assumes that the other (merged) field has the same shape.
Usage:
```python
@flax.struct.dataclass
class ShapePreservingAverage(AutoState):
  summed_values: Float['*any'] = sum_field()
  total_values: Float['*any'] = sum_field()

  def compute(self):
    return self.summed_values / self.total_values
```
Args:
default: The default value of the field.
@@ -212,12 +214,13 @@ def concat_field(
The final compute() method concatenates the arrays along the given axis.
Usage:
```python
@flax.struct.dataclass
class CollectTokens(AutoState):
  # merged along token axis ('n') by concatenation
  tokens: Float['b n d'] = concat_field(axis=1)
```
Args:
axis: The axis along which to concatenate the two arrays. Defaults to 0.
@@ -252,12 +255,13 @@ def truncate_field(
of a tensor, e.g. the first few images for plotting.
Usage:
```python
@flax.struct.dataclass
class CollectFirstKImages(AutoState):
  num_images: int
  images: Float['n h w 3'] = truncate_field(num_field="num_images")
```
Args:
num_field: The name of the field (in the state) that determines the number
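
Taken together, a hedged sketch of how such `AutoState` instances are
typically accumulated across steps (assuming the clu-style `merge` API implied
by the docstrings above; `per_batch_states` is a stand-in iterable, not a name
from the library):

```python
# Hedged sketch: accumulate per-step states, then compute the final value.
state = None
for batch_state in per_batch_states:  # stand-in for the state of each step
  state = batch_state if state is None else state.merge(batch_state)
result = state.compute()
```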
