Linear Diffusion

A Python library for experimenting with a minimal implementation of a diffusion model built entirely from linear components.

The accompanying Count Bayesie post gives a good overview of how this model works (and diffusion models in general).

Design

Linear Diffusion is an attempt to explore the performance of essentially the minimum viable Diffusion Model, in contrast to full-scale systems such as Dall-E 2 and Stable Diffusion.

These models represent the current state of the art both in terms of their performance and their immense complexity. Linear Diffusion aims to be to Diffusion Models what Logistic Regression is to Large Convolutional Neural Networks. One of my personal favorite benchmarks is that Logistic Regression, often dismissed by data scientists as "just" a linear model, is able to achieve > 90% accuracy on the MNIST data set. While this is far from state of the art, it is much better than many people naively guess. Likewise, while Linear Diffusion is far from the capabilities of models multiple orders of magnitude larger, it still performs surprisingly well!

Diffusion models can be broken down into 3 major parts:

  • An image encoder, typically a Variational Autoencoder
  • A text embedder which creates vector representation of the target text, often an LSTM, Transformer or similar style LLM
  • A denoiser that predicts noise in the image given a caption represented by the text embedder, typically a Denoising UNET

For those familiar with current work in deep learning, what makes diffusion models so impressive and complex is that each of these components is itself a very sophisticated neural network performing powerful non-linear transformations.
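Schematically, the three components cooperate at sampling time roughly like this (a conceptual sketch only, with the encoder, embedder, and denoiser passed in as hypothetical arguments rather than any particular library's API):

import numpy as np

def sample(prompt, text_embedder, denoiser, decoder, latent_dim, steps=50, seed=0):
    rng = np.random.default_rng(seed)
    cond = text_embedder(prompt)                # caption -> embedding vector
    z = rng.normal(size=latent_dim)             # start from pure Gaussian noise in latent space
    for t in reversed(range(1, steps + 1)):
        eps_hat = denoiser(z, cond, t / steps)  # predict the noise currently in the latent
        z = z - eps_hat / steps                 # take one small denoising step
    return decoder(z)                           # map the cleaned-up latent back to pixels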

Clearly it is impossible for Linear Models to achieve anything like this! Or is it?

The goal of this project is to replace each of these parts with a simple, linear model and see if we can replicate even the most rudimentary task of these much larger models. The simple challenge is:

from a simple language of only single digits, can we generate passable images of digits?

I don't expect to use linear models to create an "Astronaut Riding a Horse", but these results are pretty cool:

"linear diffusion results"

The architecture of Linear Diffusion consists of these linear components:

  • PCA for image encoding
  • One-Hot encoding for "text embedding"
  • Linear regression to denoise

Given only these simple components, Linear Diffusion is able to do a surprisingly good job of creating a purely linear generative model!
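As a rough illustration, here is a minimal sketch (assuming scikit-learn and NumPy; this is not the library's actual implementation) of how those three linear pieces could be wired together:

import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression

class ToyLinearDiffusion:
    def __init__(self, n_components=32, n_steps=10):
        self.pca = PCA(n_components=n_components)  # linear image "encoder"
        self.n_steps = n_steps

    def fit(self, labels, images):
        # PCA encoder: flatten the images and project them to a low-dimensional latent space
        X = images.reshape(len(images), -1).astype(float)
        latents = self.pca.fit_transform(X)
        # one-hot "text embedding" for the digit labels "0" through "9"
        onehot = np.eye(10)[[int(l) for l in labels]]
        # linear regression "denoiser": given a noisy latent, the label embedding,
        # and the noise level, predict the noise that was added
        rng = np.random.default_rng(0)
        t = rng.integers(1, self.n_steps + 1, size=len(latents)) / self.n_steps
        noise = rng.normal(size=latents.shape)
        noisy = latents + t[:, None] * noise
        features = np.hstack([noisy, onehot, t[:, None]])
        self.denoiser = LinearRegression().fit(features, noise)
        return self

    def predict(self, labels, seed=0):
        rng = np.random.default_rng(seed)
        onehot = np.eye(10)[[int(l) for l in labels]]
        z = rng.normal(size=(len(labels), self.pca.n_components_))
        # iteratively subtract the predicted noise, moving from high to low noise levels
        for step in range(self.n_steps, 0, -1):
            t = np.full((len(labels), 1), step / self.n_steps)
            z = z - self.denoiser.predict(np.hstack([z, onehot, t])) / self.n_steps
        imgs = self.pca.inverse_transform(z)
        side = int(np.sqrt(imgs.shape[1]))
        return imgs.reshape(len(labels), side, side)

The details (in particular the noise schedule) differ from the real LinearDiffusion class, but the basic recipe is the same: encode, condition, denoise, decode.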

Install

To install, clone the repo (or just download the package file in the /dist folder) and run:

pip install linear-diffusion-0.1.tar.gz

Usage

Currently, it is assumed that Linear Diffusion will be used exclusively to play around with MNIST digits, though it should be possible to use it with other (square) image data sets.

Here is a basic example case:

import mnist
import numpy as np
from lineardiffusion import LinearDiffusion

# set up the training data: combine the MNIST train and test splits
train_imgs = mnist.train_images()
test_imgs = mnist.test_images()
all_imgs = np.concatenate([train_imgs, test_imgs])

# labels are passed as strings, since they act as the model's "text"
all_labels = [str(val) for val in np.concatenate([mnist.train_labels(), mnist.test_labels()])]

# fit the model on the full label/image set
ld = LinearDiffusion()
ld.fit(all_labels, all_imgs)

The predict method takes a list of "text" prompts (the only "language" the model knows consists of single digits) and outputs a list of image matrices.

The code used to generate the image above looks like this:

import matplotlib.pyplot as plt
from itertools import chain

rows = 10
cols = 5
fig, ax = plt.subplots(rows, cols, facecolor='white', figsize=(3, 9))

# one prompt per cell: five copies of each digit "0" through "9"
test_labels = list(chain.from_iterable([[str(i)] * 5 for i in range(10)]))

# here's our prediction!
test_images = ld.predict(test_labels, seed=137)

for i in range(rows * cols):
    ax[i // cols][i % cols].imshow(test_images[i], cmap='gray_r')
    ax[i // cols][i % cols].axis('off')
    ax[i // cols][i % cols].set_title(f"\"{test_labels[i]}\"")
fig.suptitle("Images Generated from Prompt")
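If you just want to eyeball a single generated digit, the same predict call works with a one-element prompt list (continuing from the fitted ld and matplotlib import above; the choice of digit and seed here is arbitrary):

# generate and display a single "7"
img = ld.predict(["7"], seed=42)[0]
plt.imshow(img, cmap='gray_r')
plt.axis('off')
plt.show()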

This code was just made as a quick demo, so there are undoubtedly many, many bugs, and any deviation from this basic behavior will likely break things.
