Dreambooth is an algorithm for fine-tuning text-to-image diffusion models for subject-driven generation, introduced by Google Research in 2023. Dreambooth personalizes a diffusion model by "implanting" a (unique identifier, subject) pair into the model's output domain using a very small set of subject images.
In this work, Stable Diffusion inference was first replicated, following the original Stable Diffusion codebase, for text-to-image generation and text-guided image-to-image translation. The Dreambooth algorithm was then implemented from scratch.
Create the anaconda environment sd with the following command and activate it.
conda env create -f environment.yaml
conda activate sd
The official weights of Stable Diffusion V1 can be downloaded from Hugging Face using the sd-weights.sh script.
bash download/sd-weights.sh
This implementation is based on the "High-Resolution Image Synthesis with Latent Diffusion Models" paper published at CVPR 2022. The Stable Diffusion part of this implementation focuses only on sampling: the text-to-image and image-to-image processes.
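For intuition, the text-to-image path amounts to iteratively denoising a random latent under classifier-free guidance and then decoding the result with the VAE. Below is a minimal sketch of that loop; the unet, clip_encode, vae_decode, and ddim_step names are assumed placeholders, not this repository's actual API.

```python
import torch

@torch.no_grad()
def sample_txt2img(unet, clip_encode, vae_decode, ddim_step,
                   prompt, steps=50, guidance_scale=7.5):
    cond = clip_encode([prompt])            # text conditioning from the CLIP text encoder
    uncond = clip_encode([""])              # unconditional branch (empty prompt)
    z = torch.randn(1, 4, 64, 64)           # start from Gaussian noise in latent space
    for t in torch.linspace(999, 0, steps).long():
        eps_cond = unet(z, t, cond)         # noise prediction with the prompt
        eps_uncond = unet(z, t, uncond)     # noise prediction without it
        # Classifier-free guidance: amplify the direction the prompt pulls in.
        eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
        z = ddim_step(z, eps, t)            # one deterministic DDIM update
    return vae_decode(z)                    # decode the final latent to pixel space
```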
Use the txt2img.py script. For example,
python txt2img.py -p "Hide the pain Harold drinking tea" -v 4
Use the img2img.py script. For example,
python img2img.py -i imgs/bear.jpg -p "A realistic bear in its natural habitat under the moonlight. Cinematic lighting" -v 4
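The image-to-image path differs only in where sampling starts: the input image is encoded into latent space, partially noised according to a strength parameter, and then denoised from that intermediate timestep. Below is a minimal sketch under the same assumed placeholder interfaces as above.

```python
import torch

@torch.no_grad()
def sample_img2img(unet, clip_encode, vae_encode, vae_decode, ddim_step, add_noise,
                   image, prompt, strength=0.75, steps=50, guidance_scale=7.5):
    cond, uncond = clip_encode([prompt]), clip_encode([""])
    z = vae_encode(image)                       # encode the input image into latent space
    t_start = int(strength * 999)               # higher strength = start from a noisier latent
    z = add_noise(z, torch.randn_like(z), t_start)  # forward diffusion up to t_start
    for t in torch.linspace(t_start, 0, max(1, int(strength * steps))).long():
        eps_c, eps_u = unet(z, t, cond), unet(z, t, uncond)
        eps = eps_u + guidance_scale * (eps_c - eps_u)  # classifier-free guidance
        z = ddim_step(z, eps, t)                # one deterministic DDIM update
    return vae_decode(z)                        # decode back to pixel space
```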
Please note that most of the source code is borrowed from the official Stable Diffusion repository.
To start training, users will have to collect a set of images of the subject and choose a suitable class name. Once satisfied with these, users can invoke the Dreambooth script dreambooth.py. For example,
python dreambooth.py -c "water bottle" -i data/other/bottle -s 200 -l 1.5 --learning-rate 1e-5 --iterations 3000
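For reference, one Dreambooth training step roughly combines the usual denoising loss on the subject images (captioned with the rare identifier) with a class-specific prior-preservation loss on images the frozen model generated for the bare class prompt. The sketch below is illustrative only: all helper names are assumed placeholders, and the mapping of the -l flag above to the prior-loss weight is an assumption.

```python
import torch
import torch.nn.functional as F

def dreambooth_step(unet, vae_encode, clip_encode, add_noise, optimizer,
                    subject_img, class_img, class_name="water bottle",
                    prior_weight=1.0):
    # Latents for a subject image and a pre-generated class ("prior") image.
    z_subj, z_prior = vae_encode(subject_img), vae_encode(class_img)
    c_subj = clip_encode([f"a photo of sks {class_name}"])   # identifier prompt
    c_prior = clip_encode([f"a photo of {class_name}"])      # bare class prompt

    t = torch.randint(0, 1000, (1,))          # random diffusion timestep
    eps, eps_p = torch.randn_like(z_subj), torch.randn_like(z_prior)

    # Denoising loss on the subject + prior-preservation loss on the class.
    loss = F.mse_loss(unet(add_noise(z_subj, eps, t), t, c_subj), eps)
    loss = loss + prior_weight * F.mse_loss(
        unet(add_noise(z_prior, eps_p, t), t, c_prior), eps_p)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```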
Fine-tuning a diffusion model can be unstable, and the precise number of training iterations varies from subject to subject. Hence, it is recommended to use the --save-every (-s) flag in the above script to save multiple checkpoints throughout training. Users may also have to experiment to find the best hyperparameter values for their images.
After training, the diffusion model will know the subject by the name "sks" (this rare token is hard-coded in this implementation). With the trained weights, users may invoke the inference script txt2img.py, providing the path to the newly trained weights with the --ckpt-path (-c) flag. For ease of use, --ckpt-path also accepts a directory containing a set of checkpoints, so users can run inference jobs on multiple checkpoints straight away to evaluate their performance.
Depending on the available GPU memory, users may also choose to train the text encoder (using the -t flag) and to train at a higher precision (using the --precision flag) to obtain better results.
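In code terms, training the text encoder just means including its parameters in the optimizer alongside the U-Net's. A minimal sketch follows; the module names are hypothetical, and the mapping to the -t flag is an assumption.

```python
import torch

def build_optimizer(unet, text_encoder, train_text_encoder=False, lr=1e-5):
    # Always fine-tune the U-Net; optionally also fine-tune the CLIP text
    # encoder (plausibly what the -t flag toggles; an assumption).
    params = list(unet.parameters())
    if train_text_encoder:
        params += list(text_encoder.parameters())
    return torch.optim.AdamW(params, lr=lr)
```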
A Stable Diffusion model was trained using the Dreambooth algorithm. The learning rate was set to
Input images
Generated images
This implementation was done as part of two presentations I gave for the In19-S8-EN4583 - Advances in Machine Vision course at the University of Moratuwa, Sri Lanka on 08.03.2024 and 21.03.2024.
- Stable Diffusion presentation slides: Google Slides
- Stable Diffusion presentation recording: YouTube
- Dreambooth presentation slides: Google Slides
- Dreambooth presentation recording: YouTube
This implementation would not have been possible without the computational power made accessible to me through the Aberdeen server at the Department of Electronic and Telecommunication Engineering, University of Moratuwa, Sri Lanka. All training and inference were done on a single 24 GB Nvidia Quadro RTX 6000 graphics card.