Skip to content

Commit

Permalink
remove dreambooth task
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur committed Nov 29, 2024
1 parent 2d787b2 commit 1656a2d
Show file tree
Hide file tree
Showing 28 changed files with 77 additions and 4,301 deletions.
40 changes: 40 additions & 0 deletions docs/source/tasks/dreambooth.mdx.bck
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# DreamBooth

DreamBooth is an innovative method that allows for the customization of text-to-image
models like Stable Diffusion using just a few images of a subject.
DreamBooth enables the generation of new, contextually varied images of the
subject in a range of scenes, poses, and viewpoints, expanding the creative
possibilities of generative models.


## Data Preparation

The data format for DreamBooth training is simple. All you need is images of a concept (e.g. a person) and a concept token.

### Step 1: Gather Your Images

Collect 3-5 high-quality images of the subject you wish to personalize.
These images should vary slightly in pose or background to provide the model with a
diverse learning set. You can select more images if you want to train a more robust model.


### Step 2: Select Your Model

Choose a base model from the Hugging Face Hub that is compatible with your needs.
It's essential to select a model that supports the image size of your training data.
Models available on the hub often have specific requirements or capabilities,
so ensure the model you choose can accommodate the dimensions of your images.


### Step 3: Define Your Concept Token

The concept token is a crucial element in DreamBooth training.
This token acts as a unique identifier for your subject within the model.
Typically, you will supply a simple, descriptive keyword as the prompt in the
parameters section of your training setup. The model will then use this token
to generate new images of your subject.


## Parameters

[[autodoc]] trainers.dreambooth.params.DreamBoothTrainingParams
16 changes: 1 addition & 15 deletions src/autotrain/app/api_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from autotrain.app.utils import token_verification
from autotrain.project import AutoTrainProject
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams
from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams
from autotrain.trainers.image_classification.params import ImageClassificationParams
from autotrain.trainers.image_regression.params import ImageRegressionParams
Expand Down Expand Up @@ -99,7 +98,6 @@ def create_api_base_model(base_class, class_name):
LLMORPOTrainingParamsAPI = create_api_base_model(LLMTrainingParams, "LLMORPOTrainingParamsAPI")
LLMGenericTrainingParamsAPI = create_api_base_model(LLMTrainingParams, "LLMGenericTrainingParamsAPI")
LLMRewardTrainingParamsAPI = create_api_base_model(LLMTrainingParams, "LLMRewardTrainingParamsAPI")
DreamBoothTrainingParamsAPI = create_api_base_model(DreamBoothTrainingParams, "DreamBoothTrainingParamsAPI")
ImageClassificationParamsAPI = create_api_base_model(ImageClassificationParams, "ImageClassificationParamsAPI")
Seq2SeqParamsAPI = create_api_base_model(Seq2SeqParams, "Seq2SeqParamsAPI")
TabularClassificationParamsAPI = create_api_base_model(TabularParams, "TabularClassificationParamsAPI")
Expand Down Expand Up @@ -141,10 +139,6 @@ class LLMRewardColumnMapping(BaseModel):
rejected_text_column: str


class DreamBoothColumnMapping(BaseModel):
    """Column mapping for the DreamBooth task.

    DreamBooth training consumes raw images plus a concept prompt rather than
    tabular columns, so no real mapping is required; a single optional
    placeholder field keeps the API shape consistent with other tasks.
    """

    # Placeholder — presumably unused by the trainer; verify against callers.
    default: Optional[str] = None


class ImageClassificationColumnMapping(BaseModel):
image_column: str
target_column: str
Expand Down Expand Up @@ -237,7 +231,7 @@ class APICreateProjectModel(BaseModel):
Attributes:
project_name (str): The name of the project.
task (Literal): The type of task for the project. Supported tasks include various LLM tasks,
image classification, dreambooth, seq2seq, token classification, text classification,
image classification, seq2seq, token classification, text classification,
text regression, tabular classification, tabular regression, image regression, VLM tasks,
and extractive question answering.
base_model (str): The base model to be used for the project.
Expand Down Expand Up @@ -270,7 +264,6 @@ class APICreateProjectModel(BaseModel):
"st:triplet",
"st:qa",
"image-classification",
"dreambooth",
"seq2seq",
"token-classification",
"text-classification",
Expand Down Expand Up @@ -308,7 +301,6 @@ class APICreateProjectModel(BaseModel):
LLMGenericTrainingParamsAPI,
LLMRewardTrainingParamsAPI,
SentenceTransformersParamsAPI,
DreamBoothTrainingParamsAPI,
ImageClassificationParamsAPI,
Seq2SeqParamsAPI,
TabularClassificationParamsAPI,
Expand All @@ -329,7 +321,6 @@ class APICreateProjectModel(BaseModel):
LLMORPOColumnMapping,
LLMGenericColumnMapping,
LLMRewardColumnMapping,
DreamBoothColumnMapping,
ImageClassificationColumnMapping,
Seq2SeqColumnMapping,
TabularClassificationColumnMapping,
Expand Down Expand Up @@ -395,9 +386,6 @@ def validate_column_mapping(cls, values):
if not values.get("column_mapping").get("rejected_text_column"):
raise ValueError("rejected_text_column is required for llm:reward")
values["column_mapping"] = LLMRewardColumnMapping(**values["column_mapping"])
elif values.get("task") == "dreambooth":
if values.get("column_mapping"):
raise ValueError("column_mapping is not required for dreambooth")
elif values.get("task") == "seq2seq":
if not values.get("column_mapping"):
raise ValueError("column_mapping is required for seq2seq")
Expand Down Expand Up @@ -561,8 +549,6 @@ def validate_params(cls, values):
values["params"] = LLMGenericTrainingParamsAPI(**values["params"])
elif values.get("task") == "llm:reward":
values["params"] = LLMRewardTrainingParamsAPI(**values["params"])
elif values.get("task") == "dreambooth":
values["params"] = DreamBoothTrainingParamsAPI(**values["params"])
elif values.get("task") == "seq2seq":
values["params"] = Seq2SeqParamsAPI(**values["params"])
elif values.get("task") == "image-classification":
Expand Down
83 changes: 23 additions & 60 deletions src/autotrain/app/colab.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def colab_app():
"Text Regression",
"Sequence to Sequence",
"Token Classification",
"DreamBooth LoRA",
"Image Classification",
"Image Regression",
"Object Detection",
Expand All @@ -55,7 +54,6 @@ def colab_app():
"Text Regression": "text-regression",
"Sequence to Sequence": "seq2seq",
"Token Classification": "token-classification",
"DreamBooth LoRA": "dreambooth",
"Image Classification": "image-classification",
"Image Regression": "image-regression",
"Object Detection": "image-object-detection",
Expand Down Expand Up @@ -260,11 +258,6 @@ def update_col_mapping(*args):
col_mapping.value = '{"text": "text", "label": "target"}'
dataset_source_dropdown.disabled = False
valid_split.disabled = False
elif task == "dreambooth":
col_mapping.value = '{"image": "image"}'
dataset_source_dropdown.value = "Local"
dataset_source_dropdown.disabled = True
valid_split.disabled = True
elif task == "image-classification":
col_mapping.value = '{"image": "image", "label": "label"}'
dataset_source_dropdown.disabled = False
Expand Down Expand Up @@ -315,8 +308,6 @@ def update_base_model(*args):
base_model.value = MODEL_CHOICES["llm"][0]
elif TASK_MAP[task_dropdown.value] == "image-classification":
base_model.value = MODEL_CHOICES["image-classification"][0]
elif TASK_MAP[task_dropdown.value] == "dreambooth":
base_model.value = MODEL_CHOICES["dreambooth"][0]
elif TASK_MAP[task_dropdown.value] == "seq2seq":
base_model.value = MODEL_CHOICES["seq2seq"][0]
elif TASK_MAP[task_dropdown.value] == "tabular:classification":
Expand Down Expand Up @@ -351,61 +342,33 @@ def start_training(b):
if chat_template is not None:
params_val = {k: v for k, v in params_val.items() if k != "chat_template"}

if TASK_MAP[task_dropdown.value] == "dreambooth":
prompt = params_val.get("prompt")
if prompt is None:
raise ValueError("Prompt is required for DreamBooth task")
if not isinstance(prompt, str):
raise ValueError("Prompt should be a string")
params_val = {k: v for k, v in params_val.items() if k != "prompt"}
else:
prompt = None

push_to_hub = params_val.get("push_to_hub", True)
if "push_to_hub" in params_val:
params_val = {k: v for k, v in params_val.items() if k != "push_to_hub"}

if TASK_MAP[task_dropdown.value] != "dreambooth":
config = {
"task": TASK_MAP[task_dropdown.value].split(":")[0],
"base_model": base_model.value,
"project_name": project_name.value,
"log": "tensorboard",
"backend": "local",
"data": {
"path": dataset_path.value,
"train_split": train_split_value,
"valid_split": valid_split_value,
"column_mapping": json.loads(col_mapping.value),
},
"params": params_val,
"hub": {
"username": "${{HF_USERNAME}}",
"token": "${{HF_TOKEN}}",
"push_to_hub": push_to_hub,
},
}
if TASK_MAP[task_dropdown.value].startswith("llm"):
config["data"]["chat_template"] = chat_template
if config["data"]["chat_template"] == "none":
config["data"]["chat_template"] = None
else:
config = {
"task": TASK_MAP[task_dropdown.value],
"base_model": base_model.value,
"project_name": project_name.value,
"backend": "local",
"data": {
"path": dataset_path.value,
"prompt": prompt,
},
"params": params_val,
"hub": {
"username": "${HF_USERNAME}",
"token": "${HF_TOKEN}",
"push_to_hub": push_to_hub,
},
}
config = {
"task": TASK_MAP[task_dropdown.value].split(":")[0],
"base_model": base_model.value,
"project_name": project_name.value,
"log": "tensorboard",
"backend": "local",
"data": {
"path": dataset_path.value,
"train_split": train_split_value,
"valid_split": valid_split_value,
"column_mapping": json.loads(col_mapping.value),
},
"params": params_val,
"hub": {
"username": "${{HF_USERNAME}}",
"token": "${{HF_TOKEN}}",
"push_to_hub": push_to_hub,
},
}
if TASK_MAP[task_dropdown.value].startswith("llm"):
config["data"]["chat_template"] = chat_template
if config["data"]["chat_template"] == "none":
config["data"]["chat_template"] = None

with open("config.yml", "w") as f:
yaml.dump(config, f)
Expand Down
54 changes: 0 additions & 54 deletions src/autotrain/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,59 +166,6 @@ def _fetch_image_object_detection_models():
return hub_models


def _fetch_dreambooth_models():
    """Fetch text-to-image hub models usable for DreamBooth training.

    Queries the Hub for SDXL and SD pipelines, ranks the overall list by
    downloads, and promotes the models trending over the last 7 days to the
    front (without duplicating them in the tail).
    """
    pipeline_filters = (
        "diffusers:StableDiffusionXLPipeline",
        "diffusers:StableDiffusionPipeline",
    )

    def _query(sort_key, limit):
        # One hub query per pipeline filter, concatenated in filter order.
        found = []
        for pipeline_filter in pipeline_filters:
            found.extend(
                list_models(
                    task="text-to-image",
                    sort=sort_key,
                    direction=-1,
                    limit=limit,
                    full=False,
                    filter=[pipeline_filter],
                )
            )
        return found

    hub_models = get_sorted_models(_query("downloads", 100))

    trending_models = _query("likes7d", 30)
    if len(trending_models) > 0:
        trending_models = get_sorted_models(trending_models)
        # Drop trending entries from the main list, then prepend them.
        hub_models = [m for m in hub_models if m not in trending_models]
        hub_models = trending_models + hub_models

    return hub_models


def _fetch_seq2seq_models():
hub_models = list(
list_models(
Expand Down Expand Up @@ -392,7 +339,6 @@ def fetch_models():
_mc["llm"] = _fetch_llm_models()
_mc["image-classification"] = _fetch_image_classification_models()
_mc["image-regression"] = _fetch_image_classification_models()
_mc["dreambooth"] = _fetch_dreambooth_models()
_mc["seq2seq"] = _fetch_seq2seq_models()
_mc["token-classification"] = _fetch_token_classification_models()
_mc["text-regression"] = _fetch_text_classification_models()
Expand Down
Loading

0 comments on commit 1656a2d

Please sign in to comment.