This is the official repo for the paper One-Step Diffusion Distillation via Deep Equilibrium Models, by Zhengyang Geng*, Ashwini Pokle*, and J. Zico Kolter.
First, download the datasets EDM-Uncond-CIFAR and EDM-Cond-CIFAR from this link.
Set --data_path in run.sh to the directory where you store the datasets, e.g., --data_path DATA_DIR/EDM-Uncond-CIFAR-1M.
In addition, download the precomputed dataset statistics from this link.
Set --stat_path in both run.sh and eval.sh to your download directory plus the name of the statistics file.
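Before launching, it can help to confirm that both paths actually exist. The following is a minimal bash sketch: `check_paths` is a hypothetical helper (not part of this repo), and it assumes the statistics come as a single file whose exact name depends on your download.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not shipped with this repo): verify the values you
# intend to pass as --data_path (a directory) and --stat_path (assumed to be
# a single file) before editing run.sh / eval.sh or launching a job.
# Returns 0 if both exist, 1 otherwise (reasons are printed to stderr).
check_paths() {
  local data_path="$1" stat_path="$2" rc=0
  if [ ! -d "$data_path" ]; then
    echo "missing dataset directory: $data_path" >&2
    rc=1
  fi
  if [ ! -f "$stat_path" ]; then
    echo "missing statistics file: $stat_path" >&2
    rc=1
  fi
  return "$rc"
}
```

For example, `check_paths DATA_DIR/EDM-Uncond-CIFAR-1M STAT_DIR/STAT_NAME && bash run.sh 4 12345 --model GET-S/2 --name test-GET` only starts training when both paths are present (STAT_DIR/STAT_NAME are placeholders).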
To train a GET, run this command:
bash run.sh N_GPU DDP_PORT --model MODEL_NAME --name EXP_NAME
N_GPU is the number of GPUs used for training.
DDP_PORT is the port used to synchronize gradients during distributed training.
MODEL_NAME is the name of the model; run python train.py -h to list all available models.
The training log, checkpoints, and sampled images will be saved to ./results under your EXP_NAME.
For example, this command trains a GET-S/2 (patch size 2) on 4 GPUs:
bash run.sh 4 12345 --model GET-S/2 --name test-GET
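If you prefer not to type N_GPU by hand, a small wrapper can derive it from CUDA_VISIBLE_DEVICES, falling back to `nvidia-smi -L`. This is a hedged sketch: `count_gpus` and `train_cmd` are hypothetical helpers, 12345 is simply the example port used above, and `train_cmd` only prints the command so you can inspect it before running it.

```bash
#!/usr/bin/env bash
# Hypothetical helpers, not part of this repo.

# Number of visible GPUs: count the comma-separated entries in
# CUDA_VISIBLE_DEVICES if it is non-empty, otherwise count the
# "GPU N: ..." lines printed by `nvidia-smi -L`.
count_gpus() {
  if [ -n "${CUDA_VISIBLE_DEVICES:-}" ]; then
    local ids
    IFS=',' read -ra ids <<< "$CUDA_VISIBLE_DEVICES"
    echo "${#ids[@]}"
  else
    nvidia-smi -L 2>/dev/null | grep -c '^GPU'
  fi
}

# Print the training command: train_cmd MODEL_NAME EXP_NAME [DDP_PORT]
train_cmd() {
  local model="$1" name="$2" port="${3:-12345}"
  echo bash run.sh "$(count_gpus)" "$port" --model "$model" --name "$name"
}
```

Piping the output to a shell, e.g. `train_cmd GET-S/2 test-GET | bash`, would then launch GET-S/2 training on every visible GPU.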
To train a ViT, run this command:
bash run.sh N_GPU DDP_PORT --model ViT-B/2 --name EXP_NAME
To train conditional models, add the --cond flag.
For O(1)-memory training, add the --mem flag.
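If you switch between these options often, you could drive them from environment toggles. A hedged sketch: `extra_flags` and the COND/MEM variable names are hypothetical choices made here for illustration, and combining --cond with --mem in one run is an assumption, not something stated above.

```bash
#!/usr/bin/env bash
# Hypothetical helper: COND=1 adds --cond (conditional model),
# MEM=1 adds --mem (O(1)-memory training). Any other value adds nothing.
extra_flags() {
  local flags=()
  if [ "${COND:-0}" = 1 ]; then flags+=(--cond); fi
  if [ "${MEM:-0}" = 1 ]; then flags+=(--mem); fi
  echo "${flags[@]}"
}
```

Usage: `bash run.sh 4 12345 --model GET-B/2 --name EXP_NAME $(COND=1 extra_flags)`. The substitution is left unquoted on purpose so each flag becomes a separate argument.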
Download pretrained models from this link.
To load a checkpoint for evaluation, run this command:
bash eval.sh N_GPU DDP_PORT --model MODEL_NAME --resume CKPT_PATH --name EXP_NAME
The evaluation log and sampled images will be saved to ./eval-results under your EXP_NAME.
To evaluate conditional models, add the --cond flag. Here is an example:
bash eval.sh 4 12345 --model GET-B/2 --cond --resume CKPT_DIR/GET-B-cond-2M-data-bs256.pth
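When evaluating several downloaded checkpoints, it can be convenient to derive EXP_NAME from the checkpoint file name so each evaluation gets its own entry under ./eval-results. A hedged sketch: `resume_flags` and the `eval-` prefix are our own naming choices, not conventions of this repo.

```bash
#!/usr/bin/env bash
# Hypothetical helper: given a checkpoint path, print the --resume and --name
# arguments, naming the experiment after the checkpoint file (minus .pth).
resume_flags() {
  local ckpt="$1" base
  base="$(basename "$ckpt" .pth)"
  echo --resume "$ckpt" --name "eval-$base"
}
```

You would append `$(resume_flags CKPT_DIR/GET-B-cond-2M-data-bs256.pth)` to the evaluation command in place of `--resume CKPT_PATH --name EXP_NAME`. Because the output is word-split, this assumes checkpoint paths contain no spaces.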
You can see the generative performance here, along with a discussion that may be of interest.
First, clone the EDM repo. Then, copy the files under /data to the /edm directory.
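These two steps might look like the following in bash. It is a sketch assuming you run it from the root of this repository, that /data and /edm refer to a local data/ folder and an edm/ checkout respectively, and that EDM is cloned from its official GitHub repository (NVlabs/edm); `copy_data_scripts` is a hypothetical helper.

```bash
#!/usr/bin/env bash
# One-time setup (requires network access):
#   git clone https://github.com/NVlabs/edm.git
# Hypothetical helper: copy everything under SRC (default: data), including
# hidden files, into the EDM checkout DST (default: edm), overwriting files
# with the same name. Fails if DST does not exist yet.
copy_data_scripts() {
  local src="${1:-data}" dst="${2:-edm}"
  if [ ! -d "$dst" ]; then
    echo "EDM checkout not found at $dst; clone it first" >&2
    return 1
  fi
  cp -r "$src"/. "$dst"/
}
```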
Set DATA_PATH in dataset.sh to the location where the synthetic dataset should be stored.
Run the following command to generate both conditional and unconditional training sets.
bash dataset.sh
If you want to generate more data pairs, adjust the range in --seeds=0-MAX_SAMPLES.
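A note on the arithmetic: assuming --seeds is parsed the way EDM's generate.py parses it (an inclusive range, so 0-999 means 1000 seeds), --seeds=0-MAX_SAMPLES yields MAX_SAMPLES + 1 pairs, and requesting exactly N pairs means ending the range at N - 1. The hypothetical `seeds_arg` helper below does that computation; the inclusive parsing is an assumption about the scripts copied into /edm.

```bash
#!/usr/bin/env bash
# Hypothetical helper: print a --seeds argument covering exactly N seeds,
# assuming the range A-B is inclusive on both ends (0-999 -> 1000 seeds).
seeds_arg() {
  local n="$1"
  if [ "$n" -lt 1 ]; then
    echo "need at least one sample" >&2
    return 1
  fi
  echo "--seeds=0-$((n - 1))"
}
```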
If you find our work helpful to your research, please consider citing this paper. :)
@inproceedings{
geng2023onestep,
title={One-Step Diffusion Distillation via Deep Equilibrium Models},
author={Zhengyang Geng and Ashwini Pokle and J Zico Kolter},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023}
}
Feel free to contact us if you have additional questions! Please drop us an email (or reach out on Twitter).
This project is built upon TorchDEQ, DiT, and timm. Thanks for the awesome projects!