remove dreambooth task #816

Merged · 3 commits · Nov 29, 2024
Changes from 1 commit
40 changes: 40 additions & 0 deletions docs/source/tasks/dreambooth.mdx.bck
@@ -0,0 +1,40 @@
# DreamBooth

DreamBooth is an innovative method that allows for the customization of text-to-image
models like Stable Diffusion using just a few images of a subject.
DreamBooth enables the generation of new, contextually varied images of the
subject in a range of scenes, poses, and viewpoints, expanding the creative
possibilities of generative models.


## Data Preparation

The data format for DreamBooth training is simple. All you need is images of a concept (e.g. a person) and a concept token.

### Step 1: Gather Your Images

Collect 3-5 high-quality images of the subject you want the model to learn.
These images should vary slightly in pose or background to give the model a
diverse learning set. You can use more images if you want to train a more robust model.
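
For a quick sanity check, the sketch below lists the images in a local folder and
counts them; the `images/` directory name is only an illustrative assumption.

```python
from pathlib import Path

# Hypothetical location of the training images; adjust to your own setup.
image_dir = Path("images")
image_files = sorted(
    p for p in image_dir.iterdir()
    if p.suffix.lower() in {".jpg", ".jpeg", ".png"}
)

print(f"Found {len(image_files)} training images")
if len(image_files) < 3:
    print("Consider adding more: 3-5 varied shots are a good starting point.")
```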


### Step 2: Select Your Model

Choose a base model from the Hugging Face Hub that fits your needs.
Models on the Hub often have specific requirements or capabilities, so make sure
the model you choose supports the image size and dimensions of your training data.
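
One way to browse candidate base models is the Hub API. The sketch below mirrors the
filtering that AutoTrain's `_fetch_dreambooth_models` helper applied (Stable Diffusion XL
text-to-image pipelines, sorted by downloads); the `limit` of 10 is an arbitrary choice
for illustration.

```python
from huggingface_hub import list_models

# List popular SDXL text-to-image base models, most downloaded first.
candidates = list_models(
    task="text-to-image",
    sort="downloads",
    direction=-1,
    limit=10,
    full=False,
    filter=["diffusers:StableDiffusionXLPipeline"],
)

for model in candidates:
    print(model.id)
```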


### Step 3: Define Your Concept Token

The concept token is a crucial element in DreamBooth training.
This token acts as a unique identifier for your subject within the model.
Typically, you will pass a short, descriptive keyword as the `prompt` parameter
in your training setup. The model will then use this token to generate new
images of your subject.
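
To make the token's role concrete, here is a minimal sketch of a training config in the
shape the colab interface used to assemble for the dreambooth task; the base model,
folder name, and `sks` prompt are illustrative assumptions.

```python
import yaml

config = {
    "task": "dreambooth",
    "base_model": "stabilityai/stable-diffusion-xl-base-1.0",  # example base model
    "project_name": "my-dreambooth-project",
    "backend": "local",
    "data": {
        "path": "images/",                # local folder with your subject photos
        "prompt": "photo of sks person",  # concept token: rare identifier + class noun
    },
    "params": {},  # task-specific training parameters (see Parameters below)
    "hub": {
        "username": "${HF_USERNAME}",
        "token": "${HF_TOKEN}",
        "push_to_hub": True,
    },
}

# Written to disk in the same way the colab app saves its generated config.
with open("config.yml", "w") as f:
    yaml.dump(config, f)
```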


## Parameters

[[autodoc]] trainers.dreambooth.params.DreamBoothTrainingParams
16 changes: 1 addition & 15 deletions src/autotrain/app/api_routes.py
@@ -12,7 +12,6 @@
from autotrain.app.utils import token_verification
from autotrain.project import AutoTrainProject
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams
from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams
from autotrain.trainers.image_classification.params import ImageClassificationParams
from autotrain.trainers.image_regression.params import ImageRegressionParams
@@ -99,7 +98,6 @@ def create_api_base_model(base_class, class_name):
LLMORPOTrainingParamsAPI = create_api_base_model(LLMTrainingParams, "LLMORPOTrainingParamsAPI")
LLMGenericTrainingParamsAPI = create_api_base_model(LLMTrainingParams, "LLMGenericTrainingParamsAPI")
LLMRewardTrainingParamsAPI = create_api_base_model(LLMTrainingParams, "LLMRewardTrainingParamsAPI")
DreamBoothTrainingParamsAPI = create_api_base_model(DreamBoothTrainingParams, "DreamBoothTrainingParamsAPI")
ImageClassificationParamsAPI = create_api_base_model(ImageClassificationParams, "ImageClassificationParamsAPI")
Seq2SeqParamsAPI = create_api_base_model(Seq2SeqParams, "Seq2SeqParamsAPI")
TabularClassificationParamsAPI = create_api_base_model(TabularParams, "TabularClassificationParamsAPI")
@@ -141,10 +139,6 @@ class LLMRewardColumnMapping(BaseModel):
rejected_text_column: str


class DreamBoothColumnMapping(BaseModel):
default: Optional[str] = None


class ImageClassificationColumnMapping(BaseModel):
image_column: str
target_column: str
@@ -237,7 +231,7 @@ class APICreateProjectModel(BaseModel):
Attributes:
project_name (str): The name of the project.
task (Literal): The type of task for the project. Supported tasks include various LLM tasks,
image classification, dreambooth, seq2seq, token classification, text classification,
image classification, seq2seq, token classification, text classification,
text regression, tabular classification, tabular regression, image regression, VLM tasks,
and extractive question answering.
base_model (str): The base model to be used for the project.
@@ -270,7 +264,6 @@ class APICreateProjectModel(BaseModel):
"st:triplet",
"st:qa",
"image-classification",
"dreambooth",
"seq2seq",
"token-classification",
"text-classification",
@@ -308,7 +301,6 @@ class APICreateProjectModel(BaseModel):
LLMGenericTrainingParamsAPI,
LLMRewardTrainingParamsAPI,
SentenceTransformersParamsAPI,
DreamBoothTrainingParamsAPI,
ImageClassificationParamsAPI,
Seq2SeqParamsAPI,
TabularClassificationParamsAPI,
@@ -329,7 +321,6 @@ class APICreateProjectModel(BaseModel):
LLMORPOColumnMapping,
LLMGenericColumnMapping,
LLMRewardColumnMapping,
DreamBoothColumnMapping,
ImageClassificationColumnMapping,
Seq2SeqColumnMapping,
TabularClassificationColumnMapping,
@@ -395,9 +386,6 @@ def validate_column_mapping(cls, values):
if not values.get("column_mapping").get("rejected_text_column"):
raise ValueError("rejected_text_column is required for llm:reward")
values["column_mapping"] = LLMRewardColumnMapping(**values["column_mapping"])
elif values.get("task") == "dreambooth":
if values.get("column_mapping"):
raise ValueError("column_mapping is not required for dreambooth")
elif values.get("task") == "seq2seq":
if not values.get("column_mapping"):
raise ValueError("column_mapping is required for seq2seq")
@@ -561,8 +549,6 @@ def validate_params(cls, values):
values["params"] = LLMGenericTrainingParamsAPI(**values["params"])
elif values.get("task") == "llm:reward":
values["params"] = LLMRewardTrainingParamsAPI(**values["params"])
elif values.get("task") == "dreambooth":
values["params"] = DreamBoothTrainingParamsAPI(**values["params"])
elif values.get("task") == "seq2seq":
values["params"] = Seq2SeqParamsAPI(**values["params"])
elif values.get("task") == "image-classification":
83 changes: 23 additions & 60 deletions src/autotrain/app/colab.py
@@ -32,7 +32,6 @@ def colab_app():
"Text Regression",
"Sequence to Sequence",
"Token Classification",
"DreamBooth LoRA",
"Image Classification",
"Image Regression",
"Object Detection",
Expand All @@ -55,7 +54,6 @@ def colab_app():
"Text Regression": "text-regression",
"Sequence to Sequence": "seq2seq",
"Token Classification": "token-classification",
"DreamBooth LoRA": "dreambooth",
"Image Classification": "image-classification",
"Image Regression": "image-regression",
"Object Detection": "image-object-detection",
@@ -260,11 +258,6 @@ def update_col_mapping(*args):
col_mapping.value = '{"text": "text", "label": "target"}'
dataset_source_dropdown.disabled = False
valid_split.disabled = False
elif task == "dreambooth":
col_mapping.value = '{"image": "image"}'
dataset_source_dropdown.value = "Local"
dataset_source_dropdown.disabled = True
valid_split.disabled = True
elif task == "image-classification":
col_mapping.value = '{"image": "image", "label": "label"}'
dataset_source_dropdown.disabled = False
@@ -315,8 +308,6 @@ def update_base_model(*args):
base_model.value = MODEL_CHOICES["llm"][0]
elif TASK_MAP[task_dropdown.value] == "image-classification":
base_model.value = MODEL_CHOICES["image-classification"][0]
elif TASK_MAP[task_dropdown.value] == "dreambooth":
base_model.value = MODEL_CHOICES["dreambooth"][0]
elif TASK_MAP[task_dropdown.value] == "seq2seq":
base_model.value = MODEL_CHOICES["seq2seq"][0]
elif TASK_MAP[task_dropdown.value] == "tabular:classification":
@@ -351,61 +342,33 @@ def start_training(b):
if chat_template is not None:
params_val = {k: v for k, v in params_val.items() if k != "chat_template"}

if TASK_MAP[task_dropdown.value] == "dreambooth":
prompt = params_val.get("prompt")
if prompt is None:
raise ValueError("Prompt is required for DreamBooth task")
if not isinstance(prompt, str):
raise ValueError("Prompt should be a string")
params_val = {k: v for k, v in params_val.items() if k != "prompt"}
else:
prompt = None

push_to_hub = params_val.get("push_to_hub", True)
if "push_to_hub" in params_val:
params_val = {k: v for k, v in params_val.items() if k != "push_to_hub"}

if TASK_MAP[task_dropdown.value] != "dreambooth":
config = {
"task": TASK_MAP[task_dropdown.value].split(":")[0],
"base_model": base_model.value,
"project_name": project_name.value,
"log": "tensorboard",
"backend": "local",
"data": {
"path": dataset_path.value,
"train_split": train_split_value,
"valid_split": valid_split_value,
"column_mapping": json.loads(col_mapping.value),
},
"params": params_val,
"hub": {
"username": "${{HF_USERNAME}}",
"token": "${{HF_TOKEN}}",
"push_to_hub": push_to_hub,
},
}
if TASK_MAP[task_dropdown.value].startswith("llm"):
config["data"]["chat_template"] = chat_template
if config["data"]["chat_template"] == "none":
config["data"]["chat_template"] = None
else:
config = {
"task": TASK_MAP[task_dropdown.value],
"base_model": base_model.value,
"project_name": project_name.value,
"backend": "local",
"data": {
"path": dataset_path.value,
"prompt": prompt,
},
"params": params_val,
"hub": {
"username": "${HF_USERNAME}",
"token": "${HF_TOKEN}",
"push_to_hub": push_to_hub,
},
}
config = {
"task": TASK_MAP[task_dropdown.value].split(":")[0],
"base_model": base_model.value,
"project_name": project_name.value,
"log": "tensorboard",
"backend": "local",
"data": {
"path": dataset_path.value,
"train_split": train_split_value,
"valid_split": valid_split_value,
"column_mapping": json.loads(col_mapping.value),
},
"params": params_val,
"hub": {
"username": "${{HF_USERNAME}}",
"token": "${{HF_TOKEN}}",
"push_to_hub": push_to_hub,
},
}
if TASK_MAP[task_dropdown.value].startswith("llm"):
config["data"]["chat_template"] = chat_template
if config["data"]["chat_template"] == "none":
config["data"]["chat_template"] = None

with open("config.yml", "w") as f:
yaml.dump(config, f)
54 changes: 0 additions & 54 deletions src/autotrain/app/models.py
@@ -166,59 +166,6 @@ def _fetch_image_object_detection_models():
return hub_models


def _fetch_dreambooth_models():
hub_models1 = list(
list_models(
task="text-to-image",
sort="downloads",
direction=-1,
limit=100,
full=False,
filter=["diffusers:StableDiffusionXLPipeline"],
)
)
hub_models2 = list(
list_models(
task="text-to-image",
sort="downloads",
direction=-1,
limit=100,
full=False,
filter=["diffusers:StableDiffusionPipeline"],
)
)
hub_models = list(hub_models1) + list(hub_models2)
hub_models = get_sorted_models(hub_models)

trending_models1 = list(
list_models(
task="text-to-image",
sort="likes7d",
direction=-1,
limit=30,
full=False,
filter=["diffusers:StableDiffusionXLPipeline"],
)
)
trending_models2 = list(
list_models(
task="text-to-image",
sort="likes7d",
direction=-1,
limit=30,
full=False,
filter=["diffusers:StableDiffusionPipeline"],
)
)
trending_models = list(trending_models1) + list(trending_models2)
if len(trending_models) > 0:
trending_models = get_sorted_models(trending_models)
hub_models = [m for m in hub_models if m not in trending_models]
hub_models = trending_models + hub_models

return hub_models


def _fetch_seq2seq_models():
hub_models = list(
list_models(
@@ -392,7 +339,6 @@ def fetch_models():
_mc["llm"] = _fetch_llm_models()
_mc["image-classification"] = _fetch_image_classification_models()
_mc["image-regression"] = _fetch_image_classification_models()
_mc["dreambooth"] = _fetch_dreambooth_models()
_mc["seq2seq"] = _fetch_seq2seq_models()
_mc["token-classification"] = _fetch_token_classification_models()
_mc["text-regression"] = _fetch_text_classification_models()