This is the official implementation of our paper Truncated Consistency Models.
- Tested environment: PyTorch 2.3.1, Python 3.8.13.
- Dependencies:
pip install click Pillow psutil requests scipy tqdm matplotlib wandb
- We tested with NVIDIA A100 GPUs.
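A quick way to confirm that the dependencies are importable before launching a run is sketched below. The helper name `missing_modules` and the module list are illustrative assumptions rather than part of this repository. Note that the `Pillow` package is imported as `PIL`, and `torch` is included because PyTorch is the tested framework.

```python
import importlib.util

# Import names for the pip packages listed above (Pillow is imported as PIL),
# plus torch. This list is an assumption made for illustration.
REQUIRED = ["click", "PIL", "psutil", "requests", "scipy", "tqdm",
            "matplotlib", "wandb", "torch"]

def missing_modules(names):
    """Return the names in `names` that cannot be found for import."""
    return [n for n in names if importlib.util.find_spec(n) is None]

if __name__ == "__main__":
    missing = missing_modules(REQUIRED)
    print("Missing modules:", ", ".join(missing) if missing else "none")
```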
We use the CIFAR-10 and ImageNet datasets for training and evaluation. Prepare CIFAR-10 following the EDM instructions and ImageNet following the EDM2 instructions.
The pre-trained TCM, EDM, and EDM2 checkpoints are available at this link.
Model | Dataset | 1-step FID | 2-step FID |
---|---|---|---|
TCM (DDPM++) | CIFAR-10 | 2.46 | 2.05 |
TCM (EDM2-S) | ImageNet64x64 | 2.88 | 2.31 |
TCM (EDM2-XL) | ImageNet64x64 | 2.20 | 1.62 |
Stage-1 (CIFAR-10):
iteration=3000000
bsize=512
double_iter=25000
print_iter=100
save_iter=10000
vis_iter=1000
# Convert iterations to Mimg (millions of images)
duration=$(($iteration * $bsize / 1000000))
# Convert iterations to ticks
double=$(($double_iter / 100))
snap=$(($save_iter / 100))
vis=$(($vis_iter / 100))
torchrun --nnodes=1 --nproc_per_node=8 --rdzv_backend=c10d --rdzv_endpoint=localhost:8888 ct_train.py --use_wandb False --outdir=ct-runs/09.26 --desc 09.26.cifar.stage1 --arch=ddpmpp \
--duration=$duration --double=$double --batch=$bsize --snap $snap --dump $snap --ckpt $snap --eval_every $snap --sample_every $vis \
--dropout=0.2 --ratio_limit 0.999 --weighting default -c 1e-8 \
--transfer=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-vp.pkl \
--metrics='fid50k_full' --lr=1e-4 --cond=0 --optim=RAdam --augment=0.0 --mean=-1.1 --std=2.0 \
--data=../datasets/cifar10-32x32.zip
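The shell arithmetic at the top of the script can be checked with a small sketch. The function name `to_train_units` is hypothetical. The divisors (1,000,000 for the duration and 100 for the tick counts) are taken from the script itself, so the duration comes out in millions of images, and floor division mirrors the shell's integer arithmetic.

```python
def to_train_units(iteration, bsize, double_iter, save_iter, vis_iter):
    """Mirror the script's integer arithmetic (hypothetical helper)."""
    return {
        "duration": iteration * bsize // 1_000_000,  # millions of images
        "double": double_iter // 100,                # ticks
        "snap": save_iter // 100,                    # ticks
        "vis": vis_iter // 100,                      # ticks
    }

# Stage-1 CIFAR-10 settings from the script above
units = to_train_units(3_000_000, 512, 25_000, 10_000, 1_000)
print(units)
```

For these settings the result is duration=1536, double=250, snap=100, and vis=10. When the product of iterations and batch size is not a multiple of 1,000,000, the fractional part is dropped.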
Stage-2 (CIFAR-10):
iteration=3000000
bsize=1024
double_iter=25000
print_iter=100
save_iter=10000
vis_iter=1000
# Convert iterations to Mimg (millions of images)
duration=$(($iteration * $bsize / 1000000))
# Convert iterations to ticks
double=$(($double_iter / 100))
snap=$(($save_iter / 100))
vis=$(($vis_iter / 100))
torchrun --nnodes=1 --nproc_per_node=8 --rdzv_backend=c10d --rdzv_endpoint=localhost:8888 ct_train.py --use_wandb False --outdir=ct-runs/09.26 --desc 09.26.cifar.stage2 --arch=ddpmpp \
--duration=$duration --double=$double --batch=$bsize --snap $snap --dump $snap --ckpt $snap --eval_every $snap --sample_every $vis \
--dropout=0.2 --ratio_limit 0.999 --weighting default -c 1e-8 \
--tcm_transition_t=1 --t_lower=1 --tcm_teacher_pkl=<PATH_TO_STAGE1_MODEL.pkl> \
--transfer=<PATH_TO_STAGE1_MODEL.pkl> --resume_tick=<TICK_OF_STAGE1_MODEL> \
--metrics='fid50k_full' --lr=5e-5 --cond=0 --optim=RAdam --augment=0.0 \
--mean=0 --std=0.2 --boundary_prob=0.25 --w_boundary=0.1 --tdist=t --df=0.01 \
--data=../datasets/cifar10-32x32.zip
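The Stage-2 command contains two placeholders, `<PATH_TO_STAGE1_MODEL.pkl>` and `<TICK_OF_STAGE1_MODEL>`, that need real values before it runs. Left as-is, a shell will typically treat the angle brackets as redirection operators. A minimal sketch of a pre-flight check follows; the helper `unfilled_placeholders` and its pattern are assumptions for illustration, not part of this repository.

```python
import re

# Matches tokens such as <PATH_TO_STAGE1_MODEL.pkl>: letters, digits,
# underscores, and dots inside angle brackets (an assumed convention).
PLACEHOLDER = re.compile(r"<[A-Za-z0-9_.]+>")

def unfilled_placeholders(command):
    """Return the distinct placeholders still present, in order of appearance."""
    seen = []
    for match in PLACEHOLDER.findall(command):
        if match not in seen:
            seen.append(match)
    return seen

# Fragment of the Stage-2 arguments with the placeholders left untouched
cmd = ("--tcm_teacher_pkl=<PATH_TO_STAGE1_MODEL.pkl> "
       "--transfer=<PATH_TO_STAGE1_MODEL.pkl> "
       "--resume_tick=<TICK_OF_STAGE1_MODEL>")
print(unfilled_placeholders(cmd))
```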
Stage-1 (ImageNet 64x64):
iteration=800000
bsize=2048
double_iter=10000
print_iter=100
save_iter=10000
vis_iter=1000
# Convert iterations to Mimg (millions of images)
duration=$(($iteration * $bsize / 1000000))
# Convert iterations to ticks
double=$(($double_iter / 100))
snap=$(($save_iter / 100))
vis=$(($vis_iter / 100))
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_backend=c10d --rdzv_endpoint=localhost:8888 ct_train.py --use_wandb False --cond=1 --outdir=ct-runs/ --desc imagenet.stage1 --fp16=1 --ls=16 --arch=edm2-img64-s --augment=0 \
--duration=$duration --double=$double --batch=$bsize --snap $snap --dump $snap --ckpt $snap --eval_every $snap --sample_every $vis \
--dropout=0.4 --ratio_limit 0.9961 --weighting cout_sq -c 0.06 \
--decay_iter=2000 --optim=Adam --rampup_iter=0 --lr=1e-3 --beta2=0.99 --metrics=fid50k_full --ema_gamma=16.97 --ema_type=power -q 4 \
--data=../datasets/imgnet64x64-dhariwal.zip --transfer=edm2-imgnet64-s.pth \
--mean=-0.8 --std=1.6
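With `--nnodes=8 --nproc_per_node=8`, torchrun launches 64 processes in total. Assuming `--batch` is the global batch size split evenly across processes (the convention in EDM, but an assumption here for `ct_train.py`), the per-process share can be sketched as follows. `per_process_batch` is a hypothetical helper.

```python
def per_process_batch(batch, nnodes, nproc_per_node):
    """Split a global batch evenly over all torchrun processes (assumed convention)."""
    world_size = nnodes * nproc_per_node
    if batch % world_size != 0:
        raise ValueError(f"batch {batch} is not divisible by world size {world_size}")
    return batch // world_size

# ImageNet 64x64 Stage-1: 8 nodes x 8 processes, global batch 2048
print(per_process_batch(2048, 8, 8))
```

Under this assumption each process handles 32 images per step for the settings above. For a multi-node launch, `--rdzv_endpoint` generally needs to be an address reachable from every node rather than `localhost`.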
Stage-2 (ImageNet 64x64):
iteration=800000
bsize=1024
double_iter=10000
print_iter=100
save_iter=10000
vis_iter=1000
# Convert iterations to Mimg (millions of images)
duration=$(($iteration * $bsize / 1000000))
# Convert iterations to ticks
double=$(($double_iter / 100))
snap=$(($save_iter / 100))
vis=$(($vis_iter / 100))
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_backend=c10d --rdzv_endpoint=localhost:8888 ct_train.py --use_wandb False --cond=1 --outdir=ct-runs/ --desc imagenet.stage2 --fp16=1 --ls=16 --arch=edm2-img64-s --augment=0 \
--duration=$duration --double=$double --batch=$bsize --snap $snap --dump $snap --ckpt $snap --eval_every $snap --sample_every $vis \
--dropout=0.4 --ratio_limit 0.9961 --weighting cout_sq -c 0.06 \
--tcm_transition_t=1 --t_lower=1 --tcm_teacher_pkl=<PATH_TO_STAGE1_MODEL.pkl> \
--transfer=<PATH_TO_STAGE1_MODEL.pkl> --resume_tick=<TICK_OF_STAGE1_MODEL> \
--decay_iter=8000 --optim=Adam --rampup_iter=0 --lr=5e-4 --beta2=0.99 --metrics=fid50k_full --ema_gamma=16.97 --ema_type=power -q 4 \
--data=../datasets/imgnet64x64-dhariwal.zip --mean=0 --std=0.2 --boundary_prob=0.25 --w_boundary=0.1 --tdist=t --df=0.01
# Evaluate a CIFAR-10 model with 4 GPUs
torchrun --standalone --nproc_per_node=4 ct_eval.py \
--outdir=ct-evals --data=../datasets/cifar10-32x32.zip \
--cond=0 --arch=ddpmpp --metrics=fid50k_full \
--resume cifar-stage2-2.46.pkl
# Evaluate an ImageNet64 model (EDM2-S) with 4 GPUs
torchrun --standalone --nproc_per_node=4 ct_eval.py \
--outdir=ct-evals --data=../datasets/imgnet64x64-dhariwal.zip \
--cond=1 --arch=edm2-img64-s --metrics=fid50k_full --fp16=1 \
--resume imgnet-stage2-edm2s-2.88.pkl
This code builds on the EDM and EDM2 implementations. We thank the authors for their great work.
Copyright © 2024, NVIDIA Corporation. All rights reserved.
This work is made available under the NVIDIA Source Code License-NC.
The model checkpoints are shared under Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original.
For business inquiries, please visit our website and submit the form: NVIDIA Research Licensing.
@misc{lee2024truncatedconsistencymodels,
title={Truncated Consistency Models},
author={Sangyun Lee and Yilun Xu and Tomas Geffner and Giulia Fanti and Karsten Kreis and Arash Vahdat and Weili Nie},
year={2024},
eprint={2410.14895},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2410.14895},
}