This repository contains the code for our paper "Discriminative Class Tokens for Text-to-Image Diffusion Models".
Idan Schwartz*¹, Vésteinn Snæbjarnarson*², Hila Chefer¹, Serge Belongie², Lior Wolf¹, Sagie Benaim²

¹Tel Aviv University, ²University of Copenhagen, ³ETH Zürich

*Denotes equal contribution
Recent advances in text-to-image diffusion models have enabled the generation of diverse and high-quality images. However, generated images often fall short of depicting subtle details and are susceptible to errors due to ambiguity in the input text. One way of alleviating these issues is to train diffusion models on class-labeled datasets. This approach has a downside, as it limits their expressive power: (i) supervised datasets are generally small compared to the large-scale scraped text-image datasets on which text-to-image models are trained, so the quality and diversity of generated images are severely affected, or (ii) the input is a hard-coded label, as opposed to free-form text, which limits control over the generated images. In this work, we propose a non-invasive fine-tuning technique that capitalizes on the expressive potential of free-form text while achieving high accuracy through discriminative signals from a pretrained classifier, which guides the generation. This is done by iteratively modifying the embedding of a single input token of a text-to-image diffusion model, using the classifier to steer generated images toward a given target class. Our method is fast compared to prior fine-tuning methods and requires neither a collection of in-class images nor retraining of a noise-tolerant classifier. We evaluate our method extensively, showing that the generated images (i) are more accurate and of higher quality than those of standard diffusion models, (ii) can be used to augment training data in a low-resource setting, and (iii) reveal information about the data used to train the guiding classifier.
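The core idea, updating only a single token embedding by backpropagating a classifier loss while the generator and the classifier stay frozen, can be illustrated with a toy example. This is a minimal sketch, not the repository's implementation: it assumes a linear map (`generator`) in place of the diffusion model, a linear softmax classifier in place of the pretrained classifier, and cross-entropy as the classification loss. All names are hypothetical. Because the composite map `A = classifier @ generator` is linear, the gradient with respect to the embedding has the closed form `A^T (p - onehot(target))`, so no autodiff library is needed.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in z]
    total = sum(exps)
    return [x / total for x in exps]


def optimize_token(embedding, generator, classifier, target, lr=0.3, steps=800):
    """Gradient descent on `embedding` only; generator and classifier stay frozen.

    generator:  features x embed_dim matrix (toy stand-in for the diffusion model)
    classifier: classes x features matrix (toy stand-in for the pretrained classifier)
    """
    a = matmul(classifier, generator)  # frozen composite map: embedding -> logits
    e = list(embedding)
    losses = []
    for _ in range(steps):
        logits = [sum(w * x for w, x in zip(row, e)) for row in a]
        p = softmax(logits)
        losses.append(-math.log(p[target]))  # cross-entropy for the target class
        # dL/dlogits = p - onehot(target); the chain rule gives dL/de = A^T (p - onehot)
        d_logits = [pi - (1.0 if i == target else 0.0) for i, pi in enumerate(p)]
        grad = [sum(a[i][j] * d_logits[i] for i in range(len(a))) for j in range(len(e))]
        e = [x - lr * g for x, g in zip(e, grad)]
    return e, losses


def predict(embedding, generator, classifier):
    """Class index the frozen classifier assigns to the toy 'generated' features."""
    a = matmul(classifier, generator)
    logits = [sum(w * x for w, x in zip(row, embedding)) for row in a]
    return max(range(len(logits)), key=logits.__getitem__)


random.seed(0)
EMBED, FEATURES, CLASSES, TARGET = 16, 8, 3, 2
generator = [[random.gauss(0, 1) / math.sqrt(EMBED) for _ in range(EMBED)] for _ in range(FEATURES)]
classifier = [[random.gauss(0, 1) / math.sqrt(FEATURES) for _ in range(FEATURES)] for _ in range(CLASSES)]
init = [random.gauss(0, 1) for _ in range(EMBED)]

token, losses = optimize_token(init, generator, classifier, TARGET)
print(f"loss {losses[0]:.3f} -> {losses[-1]:.3f}, "
      f"predicted class {predict(token, generator, classifier)}")
```

In the actual method the classifier scores images produced by the diffusion model, so gradients have to pass through the generation process, which is far more expensive than this linear toy. The sketch only shows that the optimized vector, not the model weights, is what changes.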
We propose a technique that introduces a token (
- Release code and support for ImageNet
- Release support for iNaturalist and CUB200
- Add Google Colab
- [ ] Add Hugging Face Spaces
Create and activate the conda environment:
```bash
conda env create -f requirements.yml
conda activate discriminative-token
```
If you have not logged in before, run this command to log in with your Hugging Face Hub token:
```bash
huggingface-cli login
```
An overview of our method for optimizing a new discriminative token representation (
To train and evaluate, run:
```bash
python run.py --class_index 283 --train True --evaluate True
```
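To train tokens for several classes in one go, the command above can be wrapped in a small Python helper. This is a hypothetical convenience sketch, not part of the repository: `build_command` and `run_classes` are illustrative names, and it uses only the three flags documented above. By default it only prints the commands (`dry_run=True`), so nothing runs until you opt in.

```python
import shlex
import subprocess
import sys


def build_command(class_index, train=True, evaluate=True, python=sys.executable):
    """Return the argv list for one run.py invocation, mirroring the README command."""
    return [
        python, "run.py",
        "--class_index", str(class_index),
        "--train", str(train),        # str(True) == "True", as in the README
        "--evaluate", str(evaluate),
    ]


def run_classes(class_indices, dry_run=True, **kwargs):
    """Build one command per class index; execute them only when dry_run is False."""
    commands = [build_command(idx, **kwargs) for idx in class_indices]
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing run
    return commands
```

For example, `run_classes([283, 284])` prints two commands, and `run_classes([283, 284], dry_run=False)` executes them sequentially; it should be called from the repository root so that `run.py` resolves.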
The hyperparameters can be changed in `config.py`. Note that the results in the paper are based on Stable Diffusion v1.4.
The script creates the output folders and stores the learned token representations in `pipeline_token` and the generated images in `img`.
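After a run, the outputs can be located with `pathlib`. This is a minimal sketch, assuming it is called from the directory where `run.py` was executed. The README does not document the file names or formats inside `pipeline_token` and `img`, so the hypothetical helper `collect_outputs` lists every file under `pipeline_token` and filters `img` by common image suffixes, which is an assumption.

```python
from pathlib import Path

# Assumption: generated images are saved in one of these common formats.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def _files_under(directory, suffixes=None):
    """Recursively list files in `directory`, optionally filtered by suffix."""
    if not directory.is_dir():
        return []  # the folder does not exist yet, e.g. before the first run
    return sorted(
        p for p in directory.rglob("*")
        if p.is_file() and (suffixes is None or p.suffix.lower() in suffixes)
    )


def collect_outputs(root="."):
    """Return (token_files, image_files) found under pipeline_token/ and img/."""
    root = Path(root)
    tokens = _files_under(root / "pipeline_token")
    images = _files_under(root / "img", IMAGE_SUFFIXES)
    return tokens, images


if __name__ == "__main__":
    tokens, images = collect_outputs()
    print(f"{len(tokens)} token file(s), {len(images)} image(s)")
```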
If you make use of our work, please cite our paper:
```bibtex
@article{schwartz2023discriminative,
  title={Discriminative Class Tokens for Text-to-Image Diffusion Models},
  author={Schwartz, Idan and Sn{\ae}bjarnarson, V{\'e}steinn and Chefer, Hila and Belongie, Serge and Wolf, Lior and Benaim, Sagie},
  journal={arXiv preprint arXiv:2303.17155},
  year={2023}
}
```