
Rethinking cluster-conditioned diffusion models (C-EDM)

Official PyTorch implementation of the WACV 2025 paper "Rethinking cluster-conditioned diffusion models for label-free image synthesis".

This codebase is based on NVIDIA's implementation of the EDM paper Elucidating the Design Space of Diffusion-Based Generative Models by Tero Karras, Miika Aittala, Timo Aila, and Samuli Laine. A copy of the licence is provided in LICENCE.txt.


Requirements and setup

  • We recommend Linux for performance and compatibility reasons.
  • One or more high-end NVIDIA GPUs for training and sampling. We have done all testing and development using V100 and A100 GPUs.
  • 64-bit Python 3.8 and PyTorch 1.12.0 (or later). See https://pytorch.org for PyTorch installation instructions.
  • See environment.yml for Python library dependencies. You can use the following commands with Miniconda3 to create and activate your Python environment:
    • conda env create -f environment.yml -n cedm
    • conda activate cedm
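
Once the environment is activated, you can verify that PyTorch sees your GPUs (an optional sanity check, not part of the original setup):

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"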

Dataset preprocessing as in Karras et al. (EDM)

Datasets are stored in uncompressed ZIP archives containing uncompressed PNG files and a metadata file dataset.json for labels. Custom datasets can be created from a folder containing images; see python dataset_tool.py --help for more information.

CIFAR-10: Download the CIFAR-10 python version and convert to ZIP archive:

python dataset_tool.py --source=downloads/cifar10/cifar-10-python.tar.gz \
    --dest=datasets/cifar10-32x32.zip
python fid.py ref --data=datasets/cifar10-32x32.zip --dest=fid-refs/cifar10-32x32.npz

CIFAR-100: Download CIFAR-100 python version and convert to ZIP archive:

python dataset_tool.py --source=downloads/cifar100/cifar-100-python.tar.gz \
    --dest=datasets/cifar100-32x32.zip
python fid.py ref --data=datasets/cifar100-32x32.zip --dest=fid-refs/cifar100-32x32.npz

FFHQ: Download the Flickr-Faces-HQ dataset as 1024x1024 images and convert to ZIP archive at 64x64 resolution:

python dataset_tool.py --source=downloads/ffhq/images1024x1024 \
    --dest=datasets/ffhq-64x64.zip --resolution=64x64
python fid.py ref --data=datasets/ffhq-64x64.zip --dest=fid-refs/ffhq-64x64.npz
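
AFHQv2: The training script below also supports AFHQv2. Assuming the same dataset_tool.py interface as the upstream EDM codebase (this command mirrors the EDM README and has not been verified against this repository), download AFHQ-v2 and convert to ZIP archive at 64x64 resolution:

python dataset_tool.py --source=downloads/afhqv2 \
    --dest=datasets/afhqv2-64x64.zip --resolution=64x64
python fid.py ref --data=datasets/afhqv2-64x64.zip --dest=fid-refs/afhqv2-64x64.npz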

Reproduce our results

To facilitate reproduction, all cluster assignments used in our experiments are provided in the repository. To launch training, specify the variables at the top of the train_cedm.sh script.

# Specify the following parameters
free_gpu_mem=25000  # ~25 GB of free GPU memory are needed
dataset=cifar100  # choose from cifar10, cifar100, ffhq, and afhqv2
duration=100  # M_img in the paper
clusters=200  # C in the paper

Then launch the training script:

bash train_cedm.sh

The script will run the commands for training a model, generating 3 sets of 50k samples, and evaluating their FID. The results for each set of images are saved in .csv files. We use the default training and sampling hyperparameters from Karras et al.

Dataset          GPUs     Training Time  Sampling Time (50k)
cifar10-32x32    4xA100   ~1 day         ~7 min
cifar100-32x32   4xA100   ~1 day         ~7 min
FFHQ-64x64       4xA100   ~2 days        ~13 min

If you want to train your own models, run the train.py script with appropriate parameters.
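
For reference, a minimal multi-GPU invocation following the upstream EDM interface might look as follows; the exact cluster-conditioning flags used by this repository are set inside train_cedm.sh and may differ:

torchrun --standalone --nproc_per_node=4 train.py \
    --outdir=training-runs \
    --data=datasets/cifar10-32x32.zip \
    --cond=1 --arch=ddpmpp --duration=100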


Generate samples

To generate samples with an already trained model, run the generate.py script with appropriate parameters. Below is an example command that uses a pre-trained EDM model for CIFAR-10 to generate 50k images. See generate.py for more information on the available image generation options.

python generate.py \
    --seeds 0-49999 \
    --outdir $out_path \
    --network https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl \
    --batch 256 \
    --steps 18 \
    --sigma_min 2e-3 \
    --sigma_max 80 \
    --rho 7

Evaluating FID

To compute the FID for a set of 50k generated images, use the calc subcommand of fid.py. Below is an example command for CIFAR-10.

python fid.py calc --batch 128 --ref https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz \
    --images $images_path

where $images_path is the folder containing the generated images.


Training, image generation, and FID computation can be distributed across multiple GPUs by replacing python with torchrun --master_port=$port --nproc_per_node=4, where --nproc_per_node specifies the number of GPUs to use.
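
For example, computing FID on 4 GPUs (the master port value is arbitrary):

torchrun --master_port=12345 --nproc_per_node=4 fid.py calc --batch 128 \
    --ref fid-refs/cifar10-32x32.npz --images $images_path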

Evaluating FDD (Fréchet distance with DINOv2)

The dgm_eval folder contains slightly modified code from the dgm-eval paper Exposing flaws of generative model evaluation metrics and their unfair treatment of diffusion models, which provides implementations of many evaluation metrics for generative models, including FDD. To run the evaluation, specify the required paths in eval_fdd.sh and launch the script with:

bash eval_fdd.sh

This will perform the evaluation and save the results in a .csv file.

