DiM: Diffusion Mamba for Efficient High-Resolution Image Synthesis

The official implementation of our paper DiM: Diffusion Mamba for Efficient High-Resolution Image Synthesis.


Method Overview

(Figure: method overview.)

Acknowledgements

This code is mainly built on U-ViT and Mamba.

Installing Mamba may take considerable effort. If you encounter problems, the issues page of the Mamba repository may be very helpful.

Installation

# create env:
conda env create -f environment.yaml

# if you want to update the env `mamba` with the contents in `~/mamba_attn/environment.yaml`:
conda env update --name mamba --file ~/mamba_attn/environment.yaml --prune

# Switch to the correct environment
conda activate mamba-attn
conda install chardet

# Compile Mamba. This step may take a long time; please be patient.
# You need to successfully install causal-conv1d first.
CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install --user -e .
# If compilation fails, you can copy the files in './build/' from another server that has compiled successfully; the --user flag may be necessary.

# Optional: if you have only 8 A100 GPUs to train the Huge model with a batch size of 768, we recommend installing DeepSpeed to reduce the required GPU memory:
pip install deepspeed
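
After compilation finishes, a quick import check can confirm that the CUDA kernels were actually built. This is a minimal sanity-check sketch, not part of the official setup; the module names follow the packages referenced in this README and in the Mamba project (causal_conv1d, mamba_ssm, and the compiled selective_scan_cuda extension).

# Minimal post-install sanity check (run inside the `mamba-attn` environment).
import torch

import causal_conv1d          # causal depthwise-conv kernels required by Mamba
import selective_scan_cuda    # compiled CUDA extension; an ImportError here means compilation failed
from mamba_ssm import Mamba   # core Mamba block

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("Mamba and causal-conv1d imported successfully.")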

Frequently Asked Questions:

  • If you encounter errors like ModuleNotFoundError: No module named 'selective_scan_cuda':

    Answer: you need to correctly install and compile Mamba:

    pip install causal-conv1d==1.2.0.post2 # The version may differ depending on your CUDA version
    CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install --user -e .
  • Failed compilation:

    • The detected CUDA version mismatches the version that was used to compile PyTorch. Please make sure to use the same CUDA versions:

      Answer: you need to reinstall PyTorch with a build that matches your CUDA version (a quick version-check sketch follows this list):

      # For example, on CUDA 11.8:
      conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 pytorch-cuda=11.8 -c pytorch -c nvidia
      # Then, compile the Mamba in our project again:
      CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install --user -e .
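
Before recompiling, it can help to confirm which CUDA version PyTorch was built with and which toolkit nvcc will use. A minimal check, assuming a standard PyTorch install:

# Compare the CUDA version PyTorch was built with against the local toolkit.
import subprocess
import torch

print("PyTorch built with CUDA:", torch.version.cuda)
print("GPU available:", torch.cuda.is_available())
# `nvcc --version` reports the toolkit that will compile the Mamba kernels.
print(subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout)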

Preparation Before Training and Evaluation

Please follow the section with the same title in the U-ViT repository.

Checkpoints

Model | FID | Training iterations | Batch size
ImageNet 256x256 (Huge/2) | 2.40 | 425K | 768
ImageNet 256x256 (Huge/2) | 2.21 | 625K | 768
ImageNet 512x512 (fine-tuned Huge/2) | 3.94 | Fine-tune | 240

About the checkpoint files:

  • We use nnet_ema.pth for evaluation instead of nnet.pth.

  • nnet.pth contains the raw trained weights, while nnet_ema.pth is the exponential moving average (EMA) of the model weights (see the sketch below).
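
For reference, the EMA checkpoint is a running average of the training weights, which typically evaluates better than the raw weights. A minimal sketch of the usual update rule; the decay value is illustrative, not necessarily the one used in this repo:

# EMA update applied after each optimizer step:
# ema <- decay * ema + (1 - decay) * current weights.
import torch

@torch.no_grad()
def ema_update(ema_model, model, decay=0.9999):
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)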

Evaluation

Use eval_ldm_discrete.py for evaluation and for generating images with classifier-free guidance (CFG).

# ImageNet 256x256 Huge, 425K
# If your model checkpoint path is not 'workdir/imagenet256_H_DiM/default/ckpts/425000.ckpt/nnet_ema.pth', you can change the path after '--nnet_path='
accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --main_process_port 20039 --num_processes 8 --mixed_precision bf16 ./eval_ldm_discrete.py --config=configs/imagenet256_H_DiM.py --nnet_path='workdir/imagenet256_H_DiM/default/ckpts/425000.ckpt/nnet_ema.pth'

# ImageNet 512x512 Huge
# The 512x512 images generated for evaluation take ~22 GB of disk space.
# If space for temporary files is tight, we recommend setting `config.sample.path` in the `imagenet512_H_DiM_ft` config to a location with enough room.
accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --main_process_port 20039 --num_processes 8 --mixed_precision bf16 ./eval_ldm_discrete.py --config=configs/imagenet512_H_DiM_ft.py --nnet_path='workdir/imagenet512_H_DiM_ft/default/ckpts/64000.ckpt/nnet_ema.pth'

# ImageNet 512x512 Huge, upsampled 2x; the generated images are saved in `workdir/imagenet512_H_DiM_ft/test_tmp`, which is set in the config.
accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --main_process_port 20039 --num_processes 8 --mixed_precision bf16 ./eval_ldm_discrete.py --config=configs/imagenet512_H_DiM_upsample_test.py --nnet_path='workdir/imagenet512_H_DiM_ft/default/ckpts/64000.ckpt/nnet_ema.pth'

# ImageNet 512x512 Huge, upsampled 3x; the generated images are saved in `workdir/imagenet512_H_DiM_ft/test_tmp`, which is set in the config.
accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --main_process_port 20039 --num_processes 8 --mixed_precision bf16 ./eval_ldm_discrete.py --config=configs/imagenet512_H_DiM_upsample_3x_test.py --nnet_path='workdir/imagenet512_H_DiM_ft/default/ckpts/64000.ckpt/nnet_ema.pth'
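
For context, CFG combines a conditional and an unconditional noise prediction at each sampling step. A minimal sketch of the standard formulation; the function names, argument names, and guidance scale below are illustrative, not this repo's API:

# Classifier-free guidance: push the conditional prediction away from the unconditional one.
def cfg_prediction(nnet, x_t, t, y, null_y, guidance_scale=1.5):
    eps_cond = nnet(x_t, t, y=y)          # prediction with the class label
    eps_uncond = nnet(x_t, t, y=null_y)   # prediction with the "null" (unconditional) label
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)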

Training

# Cifar 32x32 Small
accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 ./train.py --config=configs/cifar10_S_DiM.py

# ImageNet 256x256 Large
accelerate launch --multi_gpu --num_processes 8 --mixed_precision bf16 ./train_ldm_discrete.py --config=configs/imagenet256_L_DiM.py

# ImageNet 256x256 Huge (DeepSpeed ZeRO-2 for memory-efficient training)
accelerate launch --multi_gpu --num_processes 8 --mixed_precision bf16 ./train_ldm_discrete.py --config=configs/imagenet256_H_DiM.py

# ImageNet 512x512 Huge (DeepSpeed ZeRO-2 for memory-efficient training)
# This run fine-tunes the 256x256 model, so carefully check that the pre-trained weights
# are at `workdir/imagenet256_H_DiM/default/ckpts/425000.ckpt/nnet_ema.pth`.
# This location is set in the config file: `config.nnet.pretrained_path`.
# If no such checkpoint exists, no pre-trained weights will be loaded.
accelerate launch --multi_gpu --num_processes 8 --mixed_precision bf16 ./train_ldm_discrete.py --config=configs/imagenet512_H_DiM_ft.py
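
Before launching the fine-tuning run, it is worth verifying that the pre-trained weights are actually at the path the config expects. A small sketch; the path matches the comment above, so adjust it if your config differs:

# Check that the pre-trained checkpoint for fine-tuning exists.
import os

ckpt = "workdir/imagenet256_H_DiM/default/ckpts/425000.ckpt/nnet_ema.pth"
# If this prints False, the 512x512 run will train without loading pre-trained weights.
print("Pre-trained weights found:", os.path.isfile(ckpt))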

Citation

@misc{teng2024dim,
      title={DiM: Diffusion Mamba for Efficient High-Resolution Image Synthesis}, 
      author={Yao Teng and Yue Wu and Han Shi and Xuefei Ning and Guohao Dai and Yu Wang and Zhenguo Li and Xihui Liu},
      year={2024},
      eprint={2405.14224},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
