diff --git a/docs/source/tasks/dreambooth.mdx.bck b/docs/source/tasks/dreambooth.mdx.bck
new file mode 100644
index 0000000000..f8402a9ed8
--- /dev/null
+++ b/docs/source/tasks/dreambooth.mdx.bck
@@ -0,0 +1,40 @@
+# DreamBooth
+
+DreamBooth is an innovative method that allows for the customization of text-to-image
+models like Stable Diffusion using just a few images of a subject.
+DreamBooth enables the generation of new, contextually varied images of the
+subject in a range of scenes, poses, and viewpoints, expanding the creative
+possibilities of generative models.
+
+
+## Data Preparation
+
+The data format for DreamBooth training is simple. All you need is images of a concept (e.g. a person) and a concept token.
+
+### Step 1: Gather Your Images
+
+Collect 3-5 high-quality images of the subject you wish to personalize.
+These images should vary slightly in pose or background to provide the model with a
+diverse learning set. You can select more images if you want to train a more robust model.
+
+
+### Step 2: Select Your Model
+
+Choose a base model from the Hugging Face Hub that is compatible with your needs.
+It's essential to select a model that supports the image size of your training data.
+Models available on the hub often have specific requirements or capabilities,
+so ensure the model you choose can accommodate the dimensions of your images.
+
+
+### Step 3: Define Your Concept Token
+
+The concept token is a crucial element in DreamBooth training.
+This token acts as a unique identifier for your subject within the model.
+Typically, this is a simple, descriptive keyword that you pass as the `prompt` in the
+parameters section of your training setup. The model will then use this token to
+generate new images of your subject.
+
+
+## Parameters
+
+[[autodoc]] trainers.dreambooth.params.DreamBoothTrainingParams
\ No newline at end of file
diff --git a/src/autotrain/app/api_routes.py b/src/autotrain/app/api_routes.py
index 5b6031692d..8563ab15b8 100644
--- a/src/autotrain/app/api_routes.py
+++ b/src/autotrain/app/api_routes.py
@@ -12,7 +12,6 @@
 from autotrain.app.utils import token_verification
 from autotrain.project import AutoTrainProject
 from autotrain.trainers.clm.params import LLMTrainingParams
-from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams
 from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams
 from autotrain.trainers.image_classification.params import ImageClassificationParams
 from autotrain.trainers.image_regression.params import ImageRegressionParams
@@ -99,7 +98,6 @@ def create_api_base_model(base_class, class_name):
 LLMORPOTrainingParamsAPI = create_api_base_model(LLMTrainingParams, "LLMORPOTrainingParamsAPI")
 LLMGenericTrainingParamsAPI = create_api_base_model(LLMTrainingParams, "LLMGenericTrainingParamsAPI")
 LLMRewardTrainingParamsAPI = create_api_base_model(LLMTrainingParams, "LLMRewardTrainingParamsAPI")
-DreamBoothTrainingParamsAPI = create_api_base_model(DreamBoothTrainingParams, "DreamBoothTrainingParamsAPI")
 ImageClassificationParamsAPI = create_api_base_model(ImageClassificationParams, "ImageClassificationParamsAPI")
 Seq2SeqParamsAPI = create_api_base_model(Seq2SeqParams, "Seq2SeqParamsAPI")
 TabularClassificationParamsAPI = create_api_base_model(TabularParams, "TabularClassificationParamsAPI")
@@ -141,10 +139,6 @@ class LLMRewardColumnMapping(BaseModel):
     rejected_text_column: str
 
 
-class DreamBoothColumnMapping(BaseModel):
-    default: Optional[str] = None
-
-
 class 
ImageClassificationColumnMapping(BaseModel): image_column: str target_column: str @@ -237,7 +231,7 @@ class APICreateProjectModel(BaseModel): Attributes: project_name (str): The name of the project. task (Literal): The type of task for the project. Supported tasks include various LLM tasks, - image classification, dreambooth, seq2seq, token classification, text classification, + image classification, seq2seq, token classification, text classification, text regression, tabular classification, tabular regression, image regression, VLM tasks, and extractive question answering. base_model (str): The base model to be used for the project. @@ -270,7 +264,6 @@ class APICreateProjectModel(BaseModel): "st:triplet", "st:qa", "image-classification", - "dreambooth", "seq2seq", "token-classification", "text-classification", @@ -308,7 +301,6 @@ class APICreateProjectModel(BaseModel): LLMGenericTrainingParamsAPI, LLMRewardTrainingParamsAPI, SentenceTransformersParamsAPI, - DreamBoothTrainingParamsAPI, ImageClassificationParamsAPI, Seq2SeqParamsAPI, TabularClassificationParamsAPI, @@ -329,7 +321,6 @@ class APICreateProjectModel(BaseModel): LLMORPOColumnMapping, LLMGenericColumnMapping, LLMRewardColumnMapping, - DreamBoothColumnMapping, ImageClassificationColumnMapping, Seq2SeqColumnMapping, TabularClassificationColumnMapping, @@ -395,9 +386,6 @@ def validate_column_mapping(cls, values): if not values.get("column_mapping").get("rejected_text_column"): raise ValueError("rejected_text_column is required for llm:reward") values["column_mapping"] = LLMRewardColumnMapping(**values["column_mapping"]) - elif values.get("task") == "dreambooth": - if values.get("column_mapping"): - raise ValueError("column_mapping is not required for dreambooth") elif values.get("task") == "seq2seq": if not values.get("column_mapping"): raise ValueError("column_mapping is required for seq2seq") @@ -561,8 +549,6 @@ def validate_params(cls, values): values["params"] = LLMGenericTrainingParamsAPI(**values["params"]) elif values.get("task") == "llm:reward": values["params"] = LLMRewardTrainingParamsAPI(**values["params"]) - elif values.get("task") == "dreambooth": - values["params"] = DreamBoothTrainingParamsAPI(**values["params"]) elif values.get("task") == "seq2seq": values["params"] = Seq2SeqParamsAPI(**values["params"]) elif values.get("task") == "image-classification": diff --git a/src/autotrain/app/colab.py b/src/autotrain/app/colab.py index 9abc25695d..2193ba048f 100644 --- a/src/autotrain/app/colab.py +++ b/src/autotrain/app/colab.py @@ -32,7 +32,6 @@ def colab_app(): "Text Regression", "Sequence to Sequence", "Token Classification", - "DreamBooth LoRA", "Image Classification", "Image Regression", "Object Detection", @@ -55,7 +54,6 @@ def colab_app(): "Text Regression": "text-regression", "Sequence to Sequence": "seq2seq", "Token Classification": "token-classification", - "DreamBooth LoRA": "dreambooth", "Image Classification": "image-classification", "Image Regression": "image-regression", "Object Detection": "image-object-detection", @@ -260,11 +258,6 @@ def update_col_mapping(*args): col_mapping.value = '{"text": "text", "label": "target"}' dataset_source_dropdown.disabled = False valid_split.disabled = False - elif task == "dreambooth": - col_mapping.value = '{"image": "image"}' - dataset_source_dropdown.value = "Local" - dataset_source_dropdown.disabled = True - valid_split.disabled = True elif task == "image-classification": col_mapping.value = '{"image": "image", "label": "label"}' dataset_source_dropdown.disabled = 
False @@ -315,8 +308,6 @@ def update_base_model(*args): base_model.value = MODEL_CHOICES["llm"][0] elif TASK_MAP[task_dropdown.value] == "image-classification": base_model.value = MODEL_CHOICES["image-classification"][0] - elif TASK_MAP[task_dropdown.value] == "dreambooth": - base_model.value = MODEL_CHOICES["dreambooth"][0] elif TASK_MAP[task_dropdown.value] == "seq2seq": base_model.value = MODEL_CHOICES["seq2seq"][0] elif TASK_MAP[task_dropdown.value] == "tabular:classification": @@ -351,61 +342,33 @@ def start_training(b): if chat_template is not None: params_val = {k: v for k, v in params_val.items() if k != "chat_template"} - if TASK_MAP[task_dropdown.value] == "dreambooth": - prompt = params_val.get("prompt") - if prompt is None: - raise ValueError("Prompt is required for DreamBooth task") - if not isinstance(prompt, str): - raise ValueError("Prompt should be a string") - params_val = {k: v for k, v in params_val.items() if k != "prompt"} - else: - prompt = None - push_to_hub = params_val.get("push_to_hub", True) if "push_to_hub" in params_val: params_val = {k: v for k, v in params_val.items() if k != "push_to_hub"} - if TASK_MAP[task_dropdown.value] != "dreambooth": - config = { - "task": TASK_MAP[task_dropdown.value].split(":")[0], - "base_model": base_model.value, - "project_name": project_name.value, - "log": "tensorboard", - "backend": "local", - "data": { - "path": dataset_path.value, - "train_split": train_split_value, - "valid_split": valid_split_value, - "column_mapping": json.loads(col_mapping.value), - }, - "params": params_val, - "hub": { - "username": "${{HF_USERNAME}}", - "token": "${{HF_TOKEN}}", - "push_to_hub": push_to_hub, - }, - } - if TASK_MAP[task_dropdown.value].startswith("llm"): - config["data"]["chat_template"] = chat_template - if config["data"]["chat_template"] == "none": - config["data"]["chat_template"] = None - else: - config = { - "task": TASK_MAP[task_dropdown.value], - "base_model": base_model.value, - "project_name": project_name.value, - "backend": "local", - "data": { - "path": dataset_path.value, - "prompt": prompt, - }, - "params": params_val, - "hub": { - "username": "${HF_USERNAME}", - "token": "${HF_TOKEN}", - "push_to_hub": push_to_hub, - }, - } + config = { + "task": TASK_MAP[task_dropdown.value].split(":")[0], + "base_model": base_model.value, + "project_name": project_name.value, + "log": "tensorboard", + "backend": "local", + "data": { + "path": dataset_path.value, + "train_split": train_split_value, + "valid_split": valid_split_value, + "column_mapping": json.loads(col_mapping.value), + }, + "params": params_val, + "hub": { + "username": "${{HF_USERNAME}}", + "token": "${{HF_TOKEN}}", + "push_to_hub": push_to_hub, + }, + } + if TASK_MAP[task_dropdown.value].startswith("llm"): + config["data"]["chat_template"] = chat_template + if config["data"]["chat_template"] == "none": + config["data"]["chat_template"] = None with open("config.yml", "w") as f: yaml.dump(config, f) diff --git a/src/autotrain/app/models.py b/src/autotrain/app/models.py index 4f42b0f5ec..1d1f658113 100644 --- a/src/autotrain/app/models.py +++ b/src/autotrain/app/models.py @@ -166,59 +166,6 @@ def _fetch_image_object_detection_models(): return hub_models -def _fetch_dreambooth_models(): - hub_models1 = list( - list_models( - task="text-to-image", - sort="downloads", - direction=-1, - limit=100, - full=False, - filter=["diffusers:StableDiffusionXLPipeline"], - ) - ) - hub_models2 = list( - list_models( - task="text-to-image", - sort="downloads", - direction=-1, - 
limit=100, - full=False, - filter=["diffusers:StableDiffusionPipeline"], - ) - ) - hub_models = list(hub_models1) + list(hub_models2) - hub_models = get_sorted_models(hub_models) - - trending_models1 = list( - list_models( - task="text-to-image", - sort="likes7d", - direction=-1, - limit=30, - full=False, - filter=["diffusers:StableDiffusionXLPipeline"], - ) - ) - trending_models2 = list( - list_models( - task="text-to-image", - sort="likes7d", - direction=-1, - limit=30, - full=False, - filter=["diffusers:StableDiffusionPipeline"], - ) - ) - trending_models = list(trending_models1) + list(trending_models2) - if len(trending_models) > 0: - trending_models = get_sorted_models(trending_models) - hub_models = [m for m in hub_models if m not in trending_models] - hub_models = trending_models + hub_models - - return hub_models - - def _fetch_seq2seq_models(): hub_models = list( list_models( @@ -392,7 +339,6 @@ def fetch_models(): _mc["llm"] = _fetch_llm_models() _mc["image-classification"] = _fetch_image_classification_models() _mc["image-regression"] = _fetch_image_classification_models() - _mc["dreambooth"] = _fetch_dreambooth_models() _mc["seq2seq"] = _fetch_seq2seq_models() _mc["token-classification"] = _fetch_token_classification_models() _mc["text-regression"] = _fetch_text_classification_models() diff --git a/src/autotrain/app/params.py b/src/autotrain/app/params.py index 41b569b39a..a6f4addbc5 100644 --- a/src/autotrain/app/params.py +++ b/src/autotrain/app/params.py @@ -3,7 +3,6 @@ from typing import Optional from autotrain.trainers.clm.params import LLMTrainingParams -from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams from autotrain.trainers.image_classification.params import ImageClassificationParams from autotrain.trainers.image_regression.params import ImageRegressionParams @@ -110,20 +109,6 @@ numerical_imputer="median", numeric_scaler="robust", ).model_dump() -PARAMS["dreambooth"] = DreamBoothTrainingParams( - prompt="", - vae_model="", - num_steps=500, - disable_gradient_checkpointing=False, - mixed_precision="fp16", - batch_size=1, - gradient_accumulation=4, - resolution=1024, - use_8bit_adam=False, - xformers=False, - train_text_encoder=False, - lr=1e-4, -).model_dump() PARAMS["token-classification"] = TokenClassificationParams( mixed_precision="fp16", log="tensorboard", @@ -187,7 +172,6 @@ class AppParams: _munge_params_img_reg(): Processes parameters for image regression task. _munge_params_img_obj_det(): Processes parameters for image object detection task. _munge_params_tabular(): Processes parameters for tabular data task. - _munge_params_dreambooth(): Processes parameters for DreamBooth training task. 
""" job_params_json: str @@ -218,8 +202,6 @@ def munge(self): return self._munge_params_img_obj_det() elif self.task.startswith("tabular"): return self._munge_params_tabular() - elif self.task == "dreambooth": - return self._munge_params_dreambooth() elif self.task.startswith("llm"): return self._munge_params_llm() elif self.task == "token-classification": @@ -506,17 +488,6 @@ def _munge_params_tabular(self): return TabularParams(**_params) - def _munge_params_dreambooth(self): - _params = self._munge_common_params() - _params["model"] = self.base_model - _params["image_path"] = self.data_path - - if "weight_decay" in _params: - _params["adam_weight_decay"] = _params["weight_decay"] - _params.pop("weight_decay") - - return DreamBoothTrainingParams(**_params) - def get_task_params(task, param_type): """ @@ -764,34 +735,5 @@ def get_task_params(task, param_type): "early_stopping_threshold", ] task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params} - if task == "dreambooth": - more_hidden_params = [ - "epochs", - "logging", - "bf16", - ] - if param_type == "basic": - more_hidden_params.extend( - [ - "prior_preservation", - "prior_loss_weight", - "seed", - "center_crop", - "train_text_encoder", - "disable_gradient_checkpointing", - "scale_lr", - "warmup_steps", - "num_cycles", - "lr_power", - "adam_beta1", - "adam_beta2", - "adam_weight_decay", - "adam_epsilon", - "max_grad_norm", - "pre_compute_text_embeddings", - "text_encoder_use_attention_mask", - ] - ) - task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params} return task_params diff --git a/src/autotrain/app/templates/index.html b/src/autotrain/app/templates/index.html index 437076e1c7..0ee5226c9d 100644 --- a/src/autotrain/app/templates/index.html +++ b/src/autotrain/app/templates/index.html @@ -76,10 +76,6 @@ fields = ['tokens', 'tags']; fieldNames = ['tokens', 'tags']; break; - case 'dreambooth': - fields = ['image']; - fieldNames = ['image']; - break; case 'image-classification': fields = ['image', 'label']; fieldNames = ['image', 'label']; @@ -132,8 +128,8 @@ function toggleValidationTab() { const task = taskSelect.value; - // Check if the selected task is DreamBooth or any LLM task - if (task === 'dreambooth' || task.includes('llm:')) { + // Check if the selected task is any LLM task + if (task.includes('llm:')) { validDataTab.style.display = 'none'; // Hide the tab } else { validDataTab.style.display = 'block'; // Show the tab @@ -222,7 +218,6 @@ - diff --git a/src/autotrain/app/ui_routes.py b/src/autotrain/app/ui_routes.py index ee1ebbf80a..78aa04b781 100644 --- a/src/autotrain/app/ui_routes.py +++ b/src/autotrain/app/ui_routes.py @@ -19,7 +19,6 @@ from autotrain.app.utils import get_running_jobs, get_user_and_orgs, kill_process_by_pid, token_verification from autotrain.dataset import ( AutoTrainDataset, - AutoTrainDreamboothDataset, AutoTrainImageClassificationDataset, AutoTrainImageRegressionDataset, AutoTrainObjectDetectionDataset, @@ -464,8 +463,6 @@ async def fetch_model_choices( hub_models = MODEL_CHOICE["sentence-transformers"] elif task == "image-classification": hub_models = MODEL_CHOICE["image-classification"] - elif task == "dreambooth": - hub_models = MODEL_CHOICE["dreambooth"] elif task == "seq2seq": hub_models = MODEL_CHOICE["seq2seq"] elif task == "tabular:classification": @@ -574,9 +571,6 @@ async def handle_form( status_code=400, detail="Please upload a dataset or choose a dataset from the Hugging Face Hub." 
) - if len(hub_dataset) > 0 and task == "dreambooth": - raise HTTPException(status_code=400, detail="Dreambooth does not support Hugging Face Hub datasets.") - if len(hub_dataset) > 0: if not train_split: raise HTTPException(status_code=400, detail="Please enter a training split.") @@ -614,15 +608,6 @@ async def handle_form( percent_valid=None, # TODO: add to UI local=hardware.lower() == "local-ui", ) - elif task == "dreambooth": - dset = AutoTrainDreamboothDataset( - concept_images=data_files_training, - concept_name=params["prompt"], - token=token, - project_name=project_name, - username=autotrain_user, - local=hardware.lower() == "local-ui", - ) elif task.startswith("vlm:"): dset = AutoTrainVLMDataset( train_data=training_files[0], diff --git a/src/autotrain/backends/base.py b/src/autotrain/backends/base.py index 418a93a375..01aac4e30c 100644 --- a/src/autotrain/backends/base.py +++ b/src/autotrain/backends/base.py @@ -3,7 +3,6 @@ from typing import Union from autotrain.trainers.clm.params import LLMTrainingParams -from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams from autotrain.trainers.generic.params import GenericParams from autotrain.trainers.image_classification.params import ImageClassificationParams @@ -69,7 +68,7 @@ class BaseBackend: Attributes: params (Union[TextClassificationParams, ImageClassificationParams, LLMTrainingParams, - GenericParams, TabularParams, DreamBoothTrainingParams, Seq2SeqParams, + GenericParams, TabularParams, Seq2SeqParams, TokenClassificationParams, TextRegressionParams, ObjectDetectionParams, SentenceTransformersParams, ImageRegressionParams, VLMTrainingParams, ExtractiveQuestionAnsweringParams]): Training parameters. 
@@ -86,7 +85,6 @@ class BaseBackend: LLMTrainingParams, GenericParams, TabularParams, - DreamBoothTrainingParams, Seq2SeqParams, TokenClassificationParams, TextRegressionParams, @@ -123,8 +121,6 @@ def __post_init__(self): self.task_id = 26 elif isinstance(self.params, GenericParams): self.task_id = 27 - elif isinstance(self.params, DreamBoothTrainingParams): - self.task_id = 25 elif isinstance(self.params, Seq2SeqParams): self.task_id = 28 elif isinstance(self.params, ImageClassificationParams): @@ -161,10 +157,7 @@ def __post_init__(self): "TASK_ID": str(self.task_id), "PARAMS": json.dumps(self.params.model_dump_json()), } - if isinstance(self.params, DreamBoothTrainingParams): - self.env_vars["DATA_PATH"] = self.params.image_path - else: - self.env_vars["DATA_PATH"] = self.params.data_path + self.env_vars["DATA_PATH"] = self.params.data_path if not isinstance(self.params, GenericParams): self.env_vars["MODEL"] = self.params.model diff --git a/src/autotrain/backends/spaces.py b/src/autotrain/backends/spaces.py index 92ad2d39a9..a19976c510 100644 --- a/src/autotrain/backends/spaces.py +++ b/src/autotrain/backends/spaces.py @@ -3,7 +3,6 @@ from huggingface_hub import HfApi from autotrain.backends.base import BaseBackend -from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams from autotrain.trainers.generic.params import GenericParams @@ -59,11 +58,7 @@ def _add_secrets(self, api, space_id): api.add_space_secret(repo_id=space_id, key="PROJECT_NAME", value=self.params.project_name) api.add_space_secret(repo_id=space_id, key="TASK_ID", value=str(self.task_id)) api.add_space_secret(repo_id=space_id, key="PARAMS", value=self.params.model_dump_json()) - - if isinstance(self.params, DreamBoothTrainingParams): - api.add_space_secret(repo_id=space_id, key="DATA_PATH", value=self.params.image_path) - else: - api.add_space_secret(repo_id=space_id, key="DATA_PATH", value=self.params.data_path) + api.add_space_secret(repo_id=space_id, key="DATA_PATH", value=self.params.data_path) if not isinstance(self.params, GenericParams): api.add_space_secret(repo_id=space_id, key="MODEL", value=self.params.model) diff --git a/src/autotrain/cli/autotrain.py b/src/autotrain/cli/autotrain.py index ce8c869e46..fcd85b9828 100644 --- a/src/autotrain/cli/autotrain.py +++ b/src/autotrain/cli/autotrain.py @@ -3,7 +3,6 @@ from autotrain import __version__, logger from autotrain.cli.run_api import RunAutoTrainAPICommand from autotrain.cli.run_app import RunAutoTrainAppCommand -from autotrain.cli.run_dreambooth import RunAutoTrainDreamboothCommand from autotrain.cli.run_extractive_qa import RunAutoTrainExtractiveQACommand from autotrain.cli.run_image_classification import RunAutoTrainImageClassificationCommand from autotrain.cli.run_image_regression import RunAutoTrainImageRegressionCommand @@ -35,7 +34,6 @@ def main(): RunAutoTrainAppCommand.register_subcommand(commands_parser) RunAutoTrainLLMCommand.register_subcommand(commands_parser) RunSetupCommand.register_subcommand(commands_parser) - RunAutoTrainDreamboothCommand.register_subcommand(commands_parser) RunAutoTrainAPICommand.register_subcommand(commands_parser) RunAutoTrainTextClassificationCommand.register_subcommand(commands_parser) RunAutoTrainImageClassificationCommand.register_subcommand(commands_parser) diff --git a/src/autotrain/cli/run_dreambooth.py b/src/autotrain/cli/run_dreambooth.py deleted file mode 100644 index 31589d4b06..0000000000 --- a/src/autotrain/cli/run_dreambooth.py +++ /dev/null @@ -1,392 +0,0 @@ -import glob -import os -from 
argparse import ArgumentParser - -from autotrain import logger -from autotrain.cli import BaseAutoTrainCommand -from autotrain.cli.utils import common_args -from autotrain.project import AutoTrainProject -from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams -from autotrain.trainers.dreambooth.utils import VALID_IMAGE_EXTENSIONS, XL_MODELS - - -def count_images(directory): - files_grabbed = [] - for files in VALID_IMAGE_EXTENSIONS: - files_grabbed.extend(glob.glob(os.path.join(directory, "*" + files))) - return len(files_grabbed) - - -def run_dreambooth_command_factory(args): - return RunAutoTrainDreamboothCommand(args) - - -class RunAutoTrainDreamboothCommand(BaseAutoTrainCommand): - @staticmethod - def register_subcommand(parser: ArgumentParser): - arg_list = [ - { - "arg": "--revision", - "help": "Model revision to use for training", - "required": False, - "type": str, - }, - { - "arg": "--tokenizer", - "help": "Tokenizer to use for training", - "required": False, - "type": str, - }, - { - "arg": "--image-path", - "help": "Path to the images", - "required": True, - "type": str, - }, - { - "arg": "--class-image-path", - "help": "Path to the class images", - "required": False, - "type": str, - }, - { - "arg": "--prompt", - "help": "Instance prompt", - "required": True, - "type": str, - }, - { - "arg": "--class-prompt", - "help": "Class prompt", - "required": False, - "type": str, - }, - { - "arg": "--num-class-images", - "help": "Number of class images", - "required": False, - "default": 100, - "type": int, - }, - { - "arg": "--class-labels-conditioning", - "help": "Class labels conditioning", - "required": False, - "type": str, - }, - { - "arg": "--prior-preservation", - "help": "With prior preservation", - "required": False, - "action": "store_true", - }, - { - "arg": "--prior-loss-weight", - "help": "Prior loss weight", - "required": False, - "default": 1.0, - "type": float, - }, - { - "arg": "--resolution", - "help": "Resolution", - "required": True, - "type": int, - }, - { - "arg": "--center-crop", - "help": "Center crop", - "required": False, - "action": "store_true", - }, - { - "arg": "--train-text-encoder", - "help": "Train text encoder", - "required": False, - "action": "store_true", - }, - { - "arg": "--sample-batch-size", - "help": "Sample batch size", - "required": False, - "default": 4, - "type": int, - }, - { - "arg": "--num-steps", - "help": "Max train steps", - "required": False, - "type": int, - }, - { - "arg": "--checkpointing-steps", - "help": "Checkpointing steps", - "required": False, - "default": 100000, - "type": int, - }, - { - "arg": "--resume-from-checkpoint", - "help": "Resume from checkpoint", - "required": False, - "type": str, - }, - { - "arg": "--scale-lr", - "help": "Scale learning rate", - "required": False, - "action": "store_true", - }, - { - "arg": "--scheduler", - "help": "Learning rate scheduler", - "required": False, - "default": "constant", - }, - { - "arg": "--warmup-steps", - "help": "Learning rate warmup steps", - "required": False, - "default": 0, - "type": int, - }, - { - "arg": "--num-cycles", - "help": "Learning rate num cycles", - "required": False, - "default": 1, - "type": int, - }, - { - "arg": "--lr-power", - "help": "Learning rate power", - "required": False, - "default": 1.0, - "type": float, - }, - { - "arg": "--dataloader-num-workers", - "help": "Dataloader num workers", - "required": False, - "default": 0, - "type": int, - }, - { - "arg": "--use-8bit-adam", - "help": "Use 8bit adam", - "required": False, - 
"action": "store_true", - }, - { - "arg": "--adam-beta1", - "help": "Adam beta 1", - "required": False, - "default": 0.9, - "type": float, - }, - { - "arg": "--adam-beta2", - "help": "Adam beta 2", - "required": False, - "default": 0.999, - "type": float, - }, - { - "arg": "--adam-weight-decay", - "help": "Adam weight decay", - "required": False, - "default": 1e-2, - "type": float, - }, - { - "arg": "--adam-epsilon", - "help": "Adam epsilon", - "required": False, - "default": 1e-8, - "type": float, - }, - { - "arg": "--max-grad-norm", - "help": "Max grad norm", - "required": False, - "default": 1.0, - "type": float, - }, - { - "arg": "--allow-tf32", - "help": "Allow TF32", - "required": False, - "action": "store_true", - }, - { - "arg": "--prior-generation-precision", - "help": "Prior generation precision", - "required": False, - "type": str, - }, - { - "arg": "--local-rank", - "help": "Local rank", - "required": False, - "default": -1, - "type": int, - }, - { - "arg": "--xformers", - "help": "Enable xformers memory efficient attention", - "required": False, - "action": "store_true", - }, - { - "arg": "--pre-compute-text-embeddings", - "help": "Pre compute text embeddings", - "required": False, - "action": "store_true", - }, - { - "arg": "--tokenizer-max-length", - "help": "Tokenizer max length", - "required": False, - "type": int, - }, - { - "arg": "--text-encoder-use-attention-mask", - "help": "Text encoder use attention mask", - "required": False, - "action": "store_true", - }, - { - "arg": "--rank", - "help": "Rank", - "required": False, - "default": 4, - "type": int, - }, - { - "arg": "--xl", - "help": "XL", - "required": False, - "action": "store_true", - }, - { - "arg": "--mixed-precision", - "help": "mixed precision, fp16, bf16, none", - "required": False, - "type": str, - "default": "none", - }, - { - "arg": "--validation-prompt", - "help": "Validation prompt", - "required": False, - "type": str, - }, - { - "arg": "--num-validation-images", - "help": "Number of validation images", - "required": False, - "default": 4, - "type": int, - }, - { - "arg": "--validation-epochs", - "help": "Validation epochs", - "required": False, - "default": 50, - "type": int, - }, - { - "arg": "--checkpoints-total-limit", - "help": "Checkpoints total limit", - "required": False, - "type": int, - }, - { - "arg": "--validation-images", - "help": "Validation images", - "required": False, - "type": str, - }, - { - "arg": "--logging", - "help": "Logging using tensorboard", - "required": False, - "action": "store_true", - }, - ] - - arg_list = common_args() + arg_list - run_dreambooth_parser = parser.add_parser("dreambooth", description="✨ Run AutoTrain DreamBooth Training") - for arg in arg_list: - if "action" in arg: - run_dreambooth_parser.add_argument( - arg["arg"], - help=arg["help"], - required=arg.get("required", False), - action=arg.get("action"), - default=arg.get("default"), - ) - else: - run_dreambooth_parser.add_argument( - arg["arg"], - help=arg["help"], - required=arg.get("required", False), - type=arg.get("type"), - default=arg.get("default"), - choices=arg.get("choices"), - ) - run_dreambooth_parser.set_defaults(func=run_dreambooth_command_factory) - - def __init__(self, args): - self.args = args - - store_true_arg_names = [ - "center_crop", - "train_text_encoder", - "disable_gradient_checkpointing", - "scale_lr", - "use_8bit_adam", - "allow_tf32", - "xformers", - "pre_compute_text_embeddings", - "text_encoder_use_attention_mask", - "xl", - "push_to_hub", - "logging", - "prior_preservation", 
- ] - - for arg_name in store_true_arg_names: - if getattr(self.args, arg_name) is None: - setattr(self.args, arg_name, False) - - # check if self.args.image_path is a directory with images - if not os.path.isdir(self.args.image_path): - raise ValueError("❌ Please specify a valid image directory") - - # count the number of images in the directory. valid images are .jpg, .jpeg, .png - num_images = count_images(self.args.image_path) - if num_images == 0: - raise ValueError("❌ Please specify a valid image directory") - - if self.args.push_to_hub: - if self.args.username is None: - raise ValueError("❌ Please specify a username to push to hub") - - if self.args.model in XL_MODELS: - self.args.xl = True - - if self.args.backend.startswith("spaces") or self.args.backend.startswith("ep-"): - if not self.args.push_to_hub: - raise ValueError("Push to hub must be specified for spaces backend") - if self.args.username is None: - raise ValueError("Username must be specified for spaces backend") - if self.args.token is None: - raise ValueError("Token must be specified for spaces backend") - - def run(self): - logger.info("Running DreamBooth Training") - params = DreamBoothTrainingParams(**vars(self.args)) - project = AutoTrainProject(params=params, backend=self.args.backend, process=True) - job_id = project.create() - logger.info(f"Job ID: {job_id}") diff --git a/src/autotrain/client.py b/src/autotrain/client.py index f204bf36c8..ea3734c0e9 100644 --- a/src/autotrain/client.py +++ b/src/autotrain/client.py @@ -76,20 +76,6 @@ "numeric_scaler": "robust", } -PARAMS["dreambooth"] = { - "vae_model": "", - "num_steps": 500, - "disable_gradient_checkpointing": False, - "mixed_precision": "fp16", - "batch_size": 1, - "gradient_accumulation": 4, - "resolution": 1024, - "use_8bit_adam": False, - "xformers": False, - "train_text_encoder": False, - "lr": 1e-4, -} - PARAMS["token-classification"] = { "mixed_precision": "fp16", "log": "tensorboard", @@ -163,7 +149,6 @@ DEFAULT_COLUMN_MAPPING["seq2seq"] = {"text_column": "text", "target_column": "target"} DEFAULT_COLUMN_MAPPING["text-regression"] = {"text_column": "text", "target_column": "target"} DEFAULT_COLUMN_MAPPING["token-classification"] = {"text_column": "tokens", "target_column": "tags"} -DEFAULT_COLUMN_MAPPING["dreambooth"] = {"default": "default"} DEFAULT_COLUMN_MAPPING["image-classification"] = {"image_column": "image", "target_column": "label"} DEFAULT_COLUMN_MAPPING["image-regression"] = {"image_column": "image", "target_column": "target"} DEFAULT_COLUMN_MAPPING["image-object-detection"] = {"image_column": "image", "objects_column": "objects"} @@ -273,9 +258,6 @@ def create( if missing_cols: raise ValueError(f"Missing columns in column_mapping: {missing_cols}") - if task == "dreambooth" and params.get("prompt") is None: - raise ValueError("Please provide a prompt for the DreamBooth task") - data = { "project_name": project_name, "task": task, diff --git a/src/autotrain/commands.py b/src/autotrain/commands.py index c3dfd34170..23182c26c5 100644 --- a/src/autotrain/commands.py +++ b/src/autotrain/commands.py @@ -5,7 +5,6 @@ from autotrain import logger from autotrain.trainers.clm.params import LLMTrainingParams -from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams from autotrain.trainers.generic.params import GenericParams from autotrain.trainers.image_classification.params import ImageClassificationParams @@ -103,7 +102,6 @@ def 
launch_command(params): Args: params (object): An instance of one of the training parameter classes. This can be one of the following: - LLMTrainingParams - - DreamBoothTrainingParams - GenericParams - TabularParams - TextClassificationParams @@ -152,14 +150,7 @@ def launch_command(params): os.path.join(params.project_name, "training_params.json"), ] ) - elif isinstance(params, DreamBoothTrainingParams): - cmd = [ - "python", - "-m", - "autotrain.trainers.dreambooth", - "--training_config", - os.path.join(params.project_name, "training_params.json"), - ] + elif isinstance(params, GenericParams): cmd = [ "python", diff --git a/src/autotrain/dataset.py b/src/autotrain/dataset.py index ab69abbdc9..5da820985f 100644 --- a/src/autotrain/dataset.py +++ b/src/autotrain/dataset.py @@ -3,11 +3,10 @@ import uuid import zipfile from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional import pandas as pd -from autotrain.preprocessor.dreambooth import DreamboothPreprocessor from autotrain.preprocessor.tabular import ( TabularBinaryClassificationPreprocessor, TabularMultiClassClassificationPreprocessor, @@ -67,63 +66,6 @@ def remove_non_image_files(folder): remove_non_image_files(os.path.join(root, subfolder)) -@dataclass -class AutoTrainDreamboothDataset: - """ - AutoTrainDreamboothDataset prepares dataset for Dreambooth task. - - Attributes: - concept_images (List[Any]): A list of images related to the concept. - concept_name (str): The name of the concept. - token (str): The token associated with the concept. - project_name (str): The name of the project. - username (Optional[str]): The username of the person associated with the project. Defaults to None. - local (bool): A flag indicating whether the dataset is local. Defaults to False. - - Methods: - __str__() -> str: - Returns a string representation of the dataset, including the project name and task. - - __post_init__(): - Initializes the task attribute to "dreambooth". - - num_samples() -> int: - Returns the number of samples in the concept_images list. - - prepare(): - Prepares the dataset using the DreamboothPreprocessor and returns the preprocessed data. 
- """ - - concept_images: List[Any] - concept_name: str - token: str - project_name: str - username: Optional[str] = None - local: bool = False - - def __str__(self) -> str: - info = f"Dataset: {self.project_name} ({self.task})\n" - return info - - def __post_init__(self): - self.task = "dreambooth" - - @property - def num_samples(self): - return len(self.concept_images) - - def prepare(self): - preprocessor = DreamboothPreprocessor( - concept_images=self.concept_images, - concept_name=self.concept_name, - token=self.token, - project_name=self.project_name, - username=self.username, - local=self.local, - ) - return preprocessor.prepare() - - @dataclass class AutoTrainImageClassificationDataset: """ diff --git a/src/autotrain/params.py b/src/autotrain/params.py index adf49bb1de..5ab8f8f569 100644 --- a/src/autotrain/params.py +++ b/src/autotrain/params.py @@ -1,5 +1,4 @@ from autotrain.trainers.clm.params import LLMTrainingParams -from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams from autotrain.trainers.image_classification.params import ImageClassificationParams from autotrain.trainers.image_regression.params import ImageRegressionParams diff --git a/src/autotrain/parser.py b/src/autotrain/parser.py index f1f7e256d6..fd7327e904 100644 --- a/src/autotrain/parser.py +++ b/src/autotrain/parser.py @@ -7,7 +7,6 @@ from autotrain import logger from autotrain.project import ( AutoTrainProject, - dreambooth_munge_data, ext_qa_munge_data, img_clf_munge_data, img_obj_detect_munge_data, @@ -23,7 +22,6 @@ ) from autotrain.tasks import TASKS from autotrain.trainers.clm.params import LLMTrainingParams -from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams from autotrain.trainers.image_classification.params import ImageClassificationParams from autotrain.trainers.image_regression.params import ImageRegressionParams @@ -76,7 +74,6 @@ def __post_init__(self): self.task_param_map = { "lm_training": LLMTrainingParams, - "dreambooth": DreamBoothTrainingParams, "image_binary_classification": ImageClassificationParams, "image_multi_class_classification": ImageClassificationParams, "image_object_detection": ObjectDetectionParams, @@ -93,7 +90,6 @@ def __post_init__(self): } self.munge_data_map = { "lm_training": llm_munge_data, - "dreambooth": dreambooth_munge_data, "tabular": tabular_munge_data, "seq2seq": seq2seq_munge_data, "image_multi_class_classification": img_clf_munge_data, @@ -113,7 +109,6 @@ def __post_init__(self): "llm-generic": "lm_training", "llm-dpo": "lm_training", "llm-reward": "lm_training", - "dreambooth": "dreambooth", "image_binary_classification": "image_multi_class_classification", "image-binary-classification": "image_multi_class_classification", "image_classification": "image_multi_class_classification", @@ -178,11 +173,7 @@ def _parse_config(self): "project_name": self.config["project_name"], } - if self.task == "dreambooth": - params["image_path"] = self.config["data"]["path"] - params["prompt"] = self.config["data"]["prompt"] - else: - params["data_path"] = self.config["data"]["path"] + params["data_path"] = self.config["data"]["path"] if self.task == "lm_training": params["chat_template"] = self.config["data"]["chat_template"] @@ -199,12 +190,11 @@ def _parse_config(self): if self.task == "vlm": params["trainer"] = 
self.config["task"].split(":")[1] - if self.task != "dreambooth": - for k, v in self.config["data"]["column_mapping"].items(): - params[k] = v - params["train_split"] = self.config["data"]["train_split"] - params["valid_split"] = self.config["data"]["valid_split"] - params["log"] = self.config["log"] + for k, v in self.config["data"]["column_mapping"].items(): + params[k] = v + params["train_split"] = self.config["data"]["train_split"] + params["valid_split"] = self.config["data"]["valid_split"] + params["log"] = self.config["log"] if "hub" in self.config: params["username"] = self.config["hub"]["username"] diff --git a/src/autotrain/preprocessor/dreambooth.py b/src/autotrain/preprocessor/dreambooth.py deleted file mode 100644 index e6beeb1401..0000000000 --- a/src/autotrain/preprocessor/dreambooth.py +++ /dev/null @@ -1,129 +0,0 @@ -import io -import json -import os -from dataclasses import dataclass -from typing import Any, List - -from huggingface_hub import HfApi, create_repo - -from autotrain import logger - - -@dataclass -class DreamboothPreprocessor: - """ - DreamboothPreprocessor is a class responsible for preparing concept images and prompts data for DreamBooth Task. - - Attributes: - concept_images (List[Any]): A list of concept images to be processed. - concept_name (str): The name of the concept. - username (str): The username of the person creating the project. - project_name (str): The name of the project. - token (str): The authentication token for accessing the repository. - local (bool): A flag indicating whether the processing is local or remote. - - Methods: - __post_init__(): Initializes the repository name and creates a remote repository if not local. - _upload_concept_images(file, api): Uploads a concept image to the remote repository. - _upload_concept_prompts(api): Uploads the concept prompts to the remote repository. - _save_concept_images(file): Saves a concept image locally. - _save_concept_prompts(): Saves the concept prompts locally. - prepare(): Prepares the concept images and prompts by either saving them locally or uploading them to a remote repository. 
- """ - - concept_images: List[Any] - concept_name: str - username: str - project_name: str - token: str - local: bool - - def __post_init__(self): - self.repo_name = f"{self.username}/autotrain-data-{self.project_name}" - if not self.local: - try: - create_repo( - repo_id=self.repo_name, - repo_type="dataset", - token=self.token, - private=True, - exist_ok=False, - ) - except Exception: - logger.error("Error creating repo") - raise ValueError("Error creating repo") - - def _upload_concept_images(self, file, api): - logger.info(f"Uploading {file} to concept1") - if isinstance(file, str): - path_in_repo = f"concept1/{file.split('/')[-1]}" - else: - path_in_repo = f"concept1/{file.filename.split('/')[-1]}" - - api.upload_file( - path_or_fileobj=file if isinstance(file, str) else file.file.read(), - path_in_repo=path_in_repo, - repo_id=self.repo_name, - repo_type="dataset", - token=self.token, - ) - - def _upload_concept_prompts(self, api): - _prompts = {} - _prompts["concept1"] = self.concept_name - - prompts = json.dumps(_prompts) - prompts = prompts.encode("utf-8") - prompts = io.BytesIO(prompts) - api.upload_file( - path_or_fileobj=prompts, - path_in_repo="prompts.json", - repo_id=self.repo_name, - repo_type="dataset", - token=self.token, - ) - - def _save_concept_images(self, file): - logger.info("Saving concept images") - logger.info(file) - if isinstance(file, str): - _file = file - path = f"{self.project_name}/autotrain-data/concept1/{_file.split('/')[-1]}" - - else: - _file = file.file.read() - path = f"{self.project_name}/autotrain-data/concept1/{file.filename.split('/')[-1]}" - - os.makedirs(os.path.dirname(path), exist_ok=True) - # if file is a string, copy the file to the new location - if isinstance(file, str): - with open(_file, "rb") as f: - with open(path, "wb") as f2: - f2.write(f.read()) - else: - with open(path, "wb") as f: - f.write(_file) - - def _save_concept_prompts(self): - _prompts = {} - _prompts["concept1"] = self.concept_name - path = f"{self.project_name}/autotrain-data/prompts.json" - with open(path, "w", encoding="utf-8") as f: - json.dump(_prompts, f) - - def prepare(self): - api = HfApi(token=self.token) - for _file in self.concept_images: - if self.local: - self._save_concept_images(_file) - else: - self._upload_concept_images(_file, api) - - if self.local: - self._save_concept_prompts() - else: - self._upload_concept_prompts(api) - - if self.local: - return f"{self.project_name}/autotrain-data" - return f"{self.username}/autotrain-data-{self.project_name}" diff --git a/src/autotrain/project.py b/src/autotrain/project.py index 528a4aa7cd..86d5933f02 100644 --- a/src/autotrain/project.py +++ b/src/autotrain/project.py @@ -14,14 +14,12 @@ from autotrain.backends.spaces import SpaceRunner from autotrain.dataset import ( AutoTrainDataset, - AutoTrainDreamboothDataset, AutoTrainImageClassificationDataset, AutoTrainImageRegressionDataset, AutoTrainObjectDetectionDataset, AutoTrainVLMDataset, ) from autotrain.trainers.clm.params import LLMTrainingParams -from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams from autotrain.trainers.image_classification.params import ImageClassificationParams from autotrain.trainers.image_regression.params import ImageRegressionParams @@ -289,22 +287,6 @@ def img_clf_munge_data(params, local): return params -def dreambooth_munge_data(params, local): - # check if params.image_path is a directory - if 
os.path.isdir(params.image_path): - training_data = [os.path.join(params.image_path, f) for f in os.listdir(params.image_path)] - dset = AutoTrainDreamboothDataset( - concept_images=training_data, - concept_name=params.prompt, - token=params.token, - project_name=params.project_name, - username=params.username, - local=local, - ) - params.image_path = dset.prepare() - return params - - def img_obj_detect_munge_data(params, local): train_data_path = f"{params.data_path}/{params.train_split}" if params.valid_split is not None: @@ -469,7 +451,6 @@ class AutoTrainProject: LLMTrainingParams, TextClassificationParams, TabularParams, - DreamBoothTrainingParams, Seq2SeqParams, ImageClassificationParams, TextRegressionParams, @@ -513,7 +494,6 @@ class AutoTrainProject: LLMTrainingParams, TextClassificationParams, TabularParams, - DreamBoothTrainingParams, Seq2SeqParams, ImageClassificationParams, TextRegressionParams, @@ -535,8 +515,6 @@ def __post_init__(self): def _process_params_data(self): if isinstance(self.params, LLMTrainingParams): return llm_munge_data(self.params, self.local) - elif isinstance(self.params, DreamBoothTrainingParams): - return dreambooth_munge_data(self.params, self.local) elif isinstance(self.params, ExtractiveQuestionAnsweringParams): return ext_qa_munge_data(self.params, self.local) elif isinstance(self.params, ImageClassificationParams): diff --git a/src/autotrain/tasks.py b/src/autotrain/tasks.py index 5a9bc0d049..05c1fed942 100644 --- a/src/autotrain/tasks.py +++ b/src/autotrain/tasks.py @@ -18,7 +18,6 @@ "image_multi_class_classification": 18, "image_single_column_regression": 24, "image_object_detection": 29, - "dreambooth": 25, } TABULAR_TASKS = { diff --git a/src/autotrain/trainers/dreambooth/__init__.py b/src/autotrain/trainers/dreambooth/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/autotrain/trainers/dreambooth/__main__.py b/src/autotrain/trainers/dreambooth/__main__.py deleted file mode 100644 index 8e2764d45e..0000000000 --- a/src/autotrain/trainers/dreambooth/__main__.py +++ /dev/null @@ -1,252 +0,0 @@ -import argparse -import json -import os - -from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya -from huggingface_hub import create_repo, snapshot_download, upload_folder -from safetensors.torch import load_file, save_file - -from autotrain import logger -from autotrain.trainers.common import monitor, pause_space, remove_autotrain_data -from autotrain.trainers.dreambooth import utils -from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams - - -def parse_args(): - # get training_config.json from the end user - parser = argparse.ArgumentParser() - parser.add_argument("--training_config", type=str, required=True) - return parser.parse_args() - - -@monitor -def train(config): - if isinstance(config, dict): - config = DreamBoothTrainingParams(**config) - config.prompt = str(config.prompt).strip() - - if config.model in utils.XL_MODELS: - config.xl = True - try: - snapshot_download( - repo_id=config.image_path, - local_dir=config.project_name, - token=config.token, - repo_type="dataset", - ) - config.image_path = os.path.join(config.project_name, "concept1") - except Exception as e: - logger.warning(f"Failed to download dataset: {e}") - pass - if config.image_path == f"{config.project_name}/autotrain-data": - config.image_path = os.path.join(config.image_path, "concept1") - - if config.vae_model is not None: - if config.vae_model.strip() == "": - config.vae_model = None - - 
if config.xl: - from autotrain.trainers.dreambooth.train_xl import main - - class Args: - pretrained_model_name_or_path = config.model - pretrained_vae_model_name_or_path = config.vae_model - revision = config.revision - variant = None - dataset_name = None - dataset_config_name = None - instance_data_dir = config.image_path - cache_dir = None - image_column = "image" - caption_column = None - repeats = 1 - class_data_dir = config.class_image_path - instance_prompt = config.prompt - class_prompt = config.class_prompt - validation_prompt = None - num_validation_images = 4 - validation_epochs = 50 - with_prior_preservation = config.prior_preservation - num_class_images = config.num_class_images - output_dir = config.project_name - seed = config.seed - resolution = config.resolution - center_crop = config.center_crop - train_text_encoder = config.train_text_encoder - train_batch_size = config.batch_size - sample_batch_size = config.sample_batch_size - num_train_epochs = config.epochs - max_train_steps = config.num_steps - checkpointing_steps = config.checkpointing_steps - checkpoints_total_limit = None - resume_from_checkpoint = config.resume_from_checkpoint - gradient_accumulation_steps = config.gradient_accumulation - gradient_checkpointing = not config.disable_gradient_checkpointing - learning_rate = config.lr - text_encoder_lr = 5e-6 - scale_lr = config.scale_lr - lr_scheduler = config.scheduler - snr_gamma = None - lr_warmup_steps = config.warmup_steps - lr_num_cycles = config.num_cycles - lr_power = config.lr_power - dataloader_num_workers = config.dataloader_num_workers - optimizer = "AdamW" - use_8bit_adam = config.use_8bit_adam - adam_beta1 = config.adam_beta1 - adam_beta2 = config.adam_beta2 - prodigy_beta3 = None - prodigy_decouple = True - adam_weight_decay = config.adam_weight_decay - adam_weight_decay_text_encoder = 1e-3 - adam_epsilon = config.adam_epsilon - prodigy_use_bias_correction = True - prodigy_safeguard_warmup = True - max_grad_norm = config.max_grad_norm - push_to_hub = config.push_to_hub - hub_token = config.token - hub_model_id = f"{config.username}/{config.project_name}" - logging_dir = os.path.join(config.project_name, "logs") - allow_tf32 = config.allow_tf32 - report_to = "tensorboard" if config.logging else None - mixed_precision = config.mixed_precision - prior_generation_precision = config.prior_generation_precision - local_rank = config.local_rank - enable_xformers_memory_efficient_attention = config.xformers - rank = config.rank - do_edm_style_training = False - random_flip = False - use_dora = False - - _args = Args() - main(_args) - else: - from autotrain.trainers.dreambooth.train import main - - class Args: - pretrained_model_name_or_path = config.model - pretrained_vae_model_name_or_path = config.vae_model - revision = config.revision - variant = None - tokenizer_name = None - instance_data_dir = config.image_path - class_data_dir = config.class_image_path - instance_prompt = config.prompt - class_prompt = config.class_prompt - validation_prompt = None - num_validation_images = 4 - validation_epochs = 50 - with_prior_preservation = config.prior_preservation - num_class_images = config.num_class_images - output_dir = config.project_name - seed = config.seed - resolution = config.resolution - center_crop = config.center_crop - train_text_encoder = config.train_text_encoder - train_batch_size = config.batch_size - sample_batch_size = config.sample_batch_size - max_train_steps = config.num_steps - checkpointing_steps = config.checkpointing_steps - 
checkpoints_total_limit = None - resume_from_checkpoint = config.resume_from_checkpoint - gradient_accumulation_steps = config.gradient_accumulation - gradient_checkpointing = not config.disable_gradient_checkpointing - learning_rate = config.lr - scale_lr = config.scale_lr - lr_scheduler = config.scheduler - lr_warmup_steps = config.warmup_steps - lr_num_cycles = config.num_cycles - lr_power = config.lr_power - dataloader_num_workers = config.dataloader_num_workers - use_8bit_adam = config.use_8bit_adam - adam_beta1 = config.adam_beta1 - adam_beta2 = config.adam_beta2 - adam_weight_decay = config.adam_weight_decay - adam_epsilon = config.adam_epsilon - max_grad_norm = config.max_grad_norm - push_to_hub = config.push_to_hub - hub_token = config.token - hub_model_id = f"{config.username}/{config.project_name}" - logging_dir = os.path.join(config.project_name, "logs") - allow_tf32 = config.allow_tf32 - report_to = "tensorboard" if config.logging else None - mixed_precision = config.mixed_precision - prior_generation_precision = config.prior_generation_precision - local_rank = config.local_rank - enable_xformers_memory_efficient_attention = config.xformers - pre_compute_text_embeddings = config.pre_compute_text_embeddings - tokenizer_max_length = config.tokenizer_max_length - text_encoder_use_attention_mask = config.text_encoder_use_attention_mask - validation_images = None - class_labels_conditioning = config.class_labels_conditioning - rank = config.rank - - _args = Args() - main(_args) - - if os.path.exists(f"{config.project_name}/training_params.json"): - training_params = json.load(open(f"{config.project_name}/training_params.json")) - if "token" in training_params: - training_params.pop("token") - json.dump( - training_params, - open(f"{config.project_name}/training_params.json", "w"), - ) - - # add config.prompt as a text file in the output directory - with open(f"{config.project_name}/prompt.txt", "w") as f: - f.write(config.prompt) - - try: - logger.info("Converting model to Kohya format...") - lora_state_dict = load_file(f"{config.project_name}/pytorch_lora_weights.safetensors") - peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict) - kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict) - save_file(kohya_state_dict, f"{config.project_name}/pytorch_lora_weights_kohya.safetensors") - except Exception as e: - logger.warning(e) - logger.warning("Failed to convert model to Kohya format, skipping...") - - if config.push_to_hub: - remove_autotrain_data(config) - - repo_id = create_repo( - repo_id=f"{config.username}/{config.project_name}", - exist_ok=True, - private=True, - token=config.token, - ).repo_id - if config.xl: - utils.save_model_card_xl( - repo_id, - base_model=config.model, - train_text_encoder=config.train_text_encoder, - instance_prompt=config.prompt, - vae_path=config.vae_model, - repo_folder=config.project_name, - ) - else: - utils.save_model_card( - repo_id, - base_model=config.model, - train_text_encoder=config.train_text_encoder, - instance_prompt=config.prompt, - repo_folder=config.project_name, - ) - - upload_folder( - repo_id=repo_id, - folder_path=config.project_name, - commit_message="End of training", - ignore_patterns=["step_*", "epoch_*"], - token=config.token, - ) - - pause_space(config) - - -if __name__ == "__main__": - args = parse_args() - training_config = json.load(open(args.training_config)) - config = DreamBoothTrainingParams(**training_config) - train(config) diff --git a/src/autotrain/trainers/dreambooth/datasets.py 
b/src/autotrain/trainers/dreambooth/datasets.py deleted file mode 100644 index 8121094ba8..0000000000 --- a/src/autotrain/trainers/dreambooth/datasets.py +++ /dev/null @@ -1,236 +0,0 @@ -from pathlib import Path - -import torch -from PIL import Image -from PIL.ImageOps import exif_transpose -from torch.utils.data import Dataset -from torchvision import transforms - - -class PromptDataset(Dataset): - "A simple dataset to prepare the prompts to generate class images on multiple GPUs." - - def __init__(self, prompt, num_samples): - self.prompt = prompt - self.num_samples = num_samples - - def __len__(self): - return self.num_samples - - def __getitem__(self, index): - example = {} - example["prompt"] = self.prompt - example["index"] = index - return example - - -class DreamBoothDatasetXL(Dataset): - """ - A dataset to prepare the instance and class images with the prompts for fine-tuning the model. - It pre-processes the images. - """ - - def __init__( - self, - instance_data_root, - class_data_root=None, - class_num=None, - size=1024, - center_crop=False, - ): - self.size = size - self.center_crop = center_crop - - self.instance_data_root = Path(instance_data_root) - if not self.instance_data_root.exists(): - raise ValueError("Instance images root doesn't exists.") - - self.instance_images_path = list(Path(instance_data_root).iterdir()) - self.num_instance_images = len(self.instance_images_path) - self._length = self.num_instance_images - - if class_data_root is not None: - self.class_data_root = Path(class_data_root) - self.class_data_root.mkdir(parents=True, exist_ok=True) - self.class_images_path = list(self.class_data_root.iterdir()) - if class_num is not None: - self.num_class_images = min(len(self.class_images_path), class_num) - else: - self.num_class_images = len(self.class_images_path) - self._length = max(self.num_class_images, self.num_instance_images) - else: - self.class_data_root = None - - self.image_transforms = transforms.Compose( - [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) - - def __len__(self): - return self._length - - def __getitem__(self, index): - example = {} - instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) - instance_image = exif_transpose(instance_image) - - if not instance_image.mode == "RGB": - instance_image = instance_image.convert("RGB") - example["instance_images"] = self.image_transforms(instance_image) - - if self.class_data_root: - class_image = Image.open(self.class_images_path[index % self.num_class_images]) - class_image = exif_transpose(class_image) - - if not class_image.mode == "RGB": - class_image = class_image.convert("RGB") - example["class_images"] = self.image_transforms(class_image) - - return example - - -class DreamBoothDataset(Dataset): - """ - A dataset to prepare the instance and class images with the prompts for fine-tuning the model. - It pre-processes the images and tokenizes prompts. 
- """ - - def __init__(self, config, tokenizers, encoder_hidden_states, instance_prompt_encoder_hidden_states): - self.config = config - self.tokenizer = tokenizers[0] - self.size = self.config.resolution - self.center_crop = self.config.center_crop - self.tokenizer_max_length = self.config.tokenizer_max_length - self.instance_data_root = Path(self.config.image_path) - self.instance_prompt = self.config.prompt - self.class_data_root = Path(self.config.class_image_path) if self.config.prior_preservation else None - self.class_prompt = self.config.class_prompt - self.class_num = self.config.num_class_images - - self.encoder_hidden_states = encoder_hidden_states - self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states - - if not self.instance_data_root.exists(): - raise ValueError("Instance images root doesn't exists.") - - self.instance_images_path = list(Path(self.instance_data_root).iterdir()) - - self.num_instance_images = len(self.instance_images_path) - self._length = self.num_instance_images - - if self.class_data_root is not None: - self.class_data_root.mkdir(parents=True, exist_ok=True) - self.class_images_path = list(self.class_data_root.iterdir()) - if self.class_num is not None: - self.num_class_images = min(len(self.class_images_path), self.class_num) - else: - self.num_class_images = len(self.class_images_path) - self._length = max(self.num_class_images, self.num_instance_images) - else: - self.class_data_root = None - - self.image_transforms = transforms.Compose( - [ - transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(self.size) if self.center_crop else transforms.RandomCrop(self.size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) - - def __len__(self): - return self._length - - def _tokenize_prompt(self, tokenizer, prompt, tokenizer_max_length=None): - # this function is here to avoid cyclic import issues - if tokenizer_max_length is not None: - max_length = tokenizer_max_length - else: - max_length = tokenizer.model_max_length - - text_inputs = tokenizer( - prompt, - truncation=True, - padding="max_length", - max_length=max_length, - return_tensors="pt", - ) - - return text_inputs - - def __getitem__(self, index): - example = {} - instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) - instance_image = exif_transpose(instance_image) - - if not instance_image.mode == "RGB": - instance_image = instance_image.convert("RGB") - example["instance_images"] = self.image_transforms(instance_image) - - if not self.config.xl: - if self.encoder_hidden_states is not None: - example["instance_prompt_ids"] = self.encoder_hidden_states - else: - text_inputs = self._tokenize_prompt( - self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length - ) - example["instance_prompt_ids"] = text_inputs.input_ids - example["instance_attention_mask"] = text_inputs.attention_mask - - if self.class_data_root: - class_image = Image.open(self.class_images_path[index % self.num_class_images]) - class_image = exif_transpose(class_image) - - if not class_image.mode == "RGB": - class_image = class_image.convert("RGB") - example["class_images"] = self.image_transforms(class_image) - - if not self.config.xl: - if self.instance_prompt_encoder_hidden_states is not None: - example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states - else: - class_text_inputs = self._tokenize_prompt( - self.tokenizer, self.class_prompt, 
tokenizer_max_length=self.tokenizer_max_length - ) - example["class_prompt_ids"] = class_text_inputs.input_ids - example["class_attention_mask"] = class_text_inputs.attention_mask - - return example - - -def collate_fn(examples, config): - pixel_values = [example["instance_images"] for example in examples] - - if not config.xl: - has_attention_mask = "instance_attention_mask" in examples[0] - input_ids = [example["instance_prompt_ids"] for example in examples] - - if has_attention_mask: - attention_mask = [example["instance_attention_mask"] for example in examples] - - if config.prior_preservation: - pixel_values += [example["class_images"] for example in examples] - if not config.xl: - input_ids += [example["class_prompt_ids"] for example in examples] - if has_attention_mask: - attention_mask += [example["class_attention_mask"] for example in examples] - - pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - - batch = { - "pixel_values": pixel_values, - } - - if not config.xl: - input_ids = torch.cat(input_ids, dim=0) - batch["input_ids"] = input_ids - if has_attention_mask: - # attention_mask = torch.cat(attention_mask, dim=0) - batch["attention_mask"] = attention_mask - - return batch diff --git a/src/autotrain/trainers/dreambooth/params.py b/src/autotrain/trainers/dreambooth/params.py deleted file mode 100644 index 6109b85976..0000000000 --- a/src/autotrain/trainers/dreambooth/params.py +++ /dev/null @@ -1,140 +0,0 @@ -from typing import Optional - -from pydantic import Field - -from autotrain.trainers.common import AutoTrainParams - - -class DreamBoothTrainingParams(AutoTrainParams): - """ - DreamBoothTrainingParams - - Attributes: - model (str): Name of the model to be used for training. - vae_model (Optional[str]): Name of the VAE model to be used, if any. - revision (Optional[str]): Specific model version to use. - tokenizer (Optional[str]): Tokenizer to be used, if different from the model. - image_path (str): Path to the training images. - class_image_path (Optional[str]): Path to the class images. - prompt (str): Prompt for the instance images. - class_prompt (Optional[str]): Prompt for the class images. - num_class_images (int): Number of class images to generate. - class_labels_conditioning (Optional[str]): Conditioning labels for class images. - prior_preservation (bool): Enable prior preservation during training. - prior_loss_weight (float): Weight of the prior preservation loss. - project_name (str): Name of the project for output directory. - seed (int): Random seed for reproducibility. - resolution (int): Resolution of the training images. - center_crop (bool): Enable center cropping of images. - train_text_encoder (bool): Enable training of the text encoder. - batch_size (int): Batch size for training. - sample_batch_size (int): Batch size for sampling. - epochs (int): Number of training epochs. - num_steps (int): Maximum number of training steps. - checkpointing_steps (int): Steps interval for checkpointing. - resume_from_checkpoint (Optional[str]): Path to resume training from a checkpoint. - gradient_accumulation (int): Number of gradient accumulation steps. - disable_gradient_checkpointing (bool): Disable gradient checkpointing. - lr (float): Learning rate for training. - scale_lr (bool): Enable scaling of the learning rate. - scheduler (str): Type of learning rate scheduler. - warmup_steps (int): Number of warmup steps for learning rate scheduler. 
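The `collate_fn` removed above stacks the class images after the instance images so that prior preservation needs only one forward pass per batch. A toy sketch of just that batching step, with made-up tensor shapes:

```python
# Sketch of prior-preservation batching: instance and class images are stacked
# into one tensor so the model runs a single forward pass over both halves.
import torch


def collate_with_prior_preservation(examples):
    pixel_values = [e["instance_images"] for e in examples]
    # Append the class images after the instance images.
    pixel_values += [e["class_images"] for e in examples]
    pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
    return {"pixel_values": pixel_values}


# Toy batch: two examples, 3x64x64 "images".
examples = [
    {"instance_images": torch.randn(3, 64, 64), "class_images": torch.randn(3, 64, 64)}
    for _ in range(2)
]
batch = collate_with_prior_preservation(examples)
print(batch["pixel_values"].shape)  # torch.Size([4, 3, 64, 64]): first 2 instance, last 2 class
```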
- num_cycles (int): Number of cycles for learning rate scheduler. - lr_power (float): Power factor for learning rate scheduler. - dataloader_num_workers (int): Number of workers for data loading. - use_8bit_adam (bool): Enable use of 8-bit Adam optimizer. - adam_beta1 (float): Beta1 parameter for Adam optimizer. - adam_beta2 (float): Beta2 parameter for Adam optimizer. - adam_weight_decay (float): Weight decay for Adam optimizer. - adam_epsilon (float): Epsilon parameter for Adam optimizer. - max_grad_norm (float): Maximum gradient norm for clipping. - allow_tf32 (bool): Allow use of TF32 for training. - prior_generation_precision (Optional[str]): Precision for prior generation. - local_rank (int): Local rank for distributed training. - xformers (bool): Enable xformers memory efficient attention. - pre_compute_text_embeddings (bool): Pre-compute text embeddings before training. - tokenizer_max_length (Optional[int]): Maximum length for tokenizer. - text_encoder_use_attention_mask (bool): Use attention mask for text encoder. - rank (int): Rank for distributed training. - xl (bool): Enable XL model training. - mixed_precision (Optional[str]): Enable mixed precision training. - token (Optional[str]): Token for accessing the model hub. - push_to_hub (bool): Enable pushing the model to the hub. - username (Optional[str]): Username for the model hub. - validation_prompt (Optional[str]): Prompt for validation images. - num_validation_images (int): Number of validation images to generate. - validation_epochs (int): Epoch interval for validation. - checkpoints_total_limit (Optional[int]): Total limit for checkpoints. - validation_images (Optional[str]): Path to validation images. - logging (bool): Enable logging using TensorBoard. - """ - - model: str = Field(None, title="Name of the model to be used for training") - vae_model: Optional[str] = Field(None, title="Name of the VAE model to be used, if any") - revision: Optional[str] = Field(None, title="Specific model version to use") - tokenizer: Optional[str] = Field(None, title="Tokenizer to be used, if different from the model") - image_path: str = Field(None, title="Path to the training images") - class_image_path: Optional[str] = Field(None, title="Path to the class images") - prompt: str = Field(None, title="Prompt for the instance images") - class_prompt: Optional[str] = Field(None, title="Prompt for the class images") - num_class_images: int = Field(100, title="Number of class images to generate") - class_labels_conditioning: Optional[str] = Field(None, title="Conditioning labels for class images") - - prior_preservation: bool = Field(False, title="Enable prior preservation during training") - prior_loss_weight: float = Field(1.0, title="Weight of the prior preservation loss") - - project_name: str = Field("dreambooth-model", title="Name of the project for output directory") - seed: int = Field(42, title="Random seed for reproducibility") - resolution: int = Field(512, title="Resolution of the training images") - center_crop: bool = Field(False, title="Enable center cropping of images") - train_text_encoder: bool = Field(False, title="Enable training of the text encoder") - batch_size: int = Field(4, title="Batch size for training") - sample_batch_size: int = Field(4, title="Batch size for sampling") - epochs: int = Field(1, title="Number of training epochs") - num_steps: int = Field(None, title="Maximum number of training steps") - checkpointing_steps: int = Field(500, title="Steps interval for checkpointing") - resume_from_checkpoint: 
Optional[str] = Field(None, title="Path to resume training from a checkpoint") - - gradient_accumulation: int = Field(1, title="Number of gradient accumulation steps") - disable_gradient_checkpointing: bool = Field(False, title="Disable gradient checkpointing") - - lr: float = Field(1e-4, title="Learning rate for training") - scale_lr: bool = Field(False, title="Enable scaling of the learning rate") - scheduler: str = Field("constant", title="Type of learning rate scheduler") - warmup_steps: int = Field(0, title="Number of warmup steps for learning rate scheduler") - num_cycles: int = Field(1, title="Number of cycles for learning rate scheduler") - lr_power: float = Field(1.0, title="Power factor for learning rate scheduler") - - dataloader_num_workers: int = Field(0, title="Number of workers for data loading") - use_8bit_adam: bool = Field(False, title="Enable use of 8-bit Adam optimizer") - adam_beta1: float = Field(0.9, title="Beta1 parameter for Adam optimizer") - adam_beta2: float = Field(0.999, title="Beta2 parameter for Adam optimizer") - adam_weight_decay: float = Field(1e-2, title="Weight decay for Adam optimizer") - adam_epsilon: float = Field(1e-8, title="Epsilon parameter for Adam optimizer") - max_grad_norm: float = Field(1.0, title="Maximum gradient norm for clipping") - - allow_tf32: bool = Field(False, title="Allow use of TF32 for training") - prior_generation_precision: Optional[str] = Field(None, title="Precision for prior generation") - local_rank: int = Field(-1, title="Local rank for distributed training") - xformers: bool = Field(False, title="Enable xformers memory efficient attention") - pre_compute_text_embeddings: bool = Field(False, title="Pre-compute text embeddings before training") - tokenizer_max_length: Optional[int] = Field(None, title="Maximum length for tokenizer") - text_encoder_use_attention_mask: bool = Field(False, title="Use attention mask for text encoder") - - rank: int = Field(4, title="Rank for distributed training") - xl: bool = Field(False, title="Enable XL model training") - - mixed_precision: Optional[str] = Field(None, title="Enable mixed precision training") - - token: Optional[str] = Field(None, title="Token for accessing the model hub") - push_to_hub: bool = Field(False, title="Enable pushing the model to the hub") - username: Optional[str] = Field(None, title="Username for the model hub") - - # disabled: - validation_prompt: Optional[str] = Field(None, title="Prompt for validation images") - num_validation_images: int = Field(4, title="Number of validation images to generate") - validation_epochs: int = Field(50, title="Epoch interval for validation") - checkpoints_total_limit: Optional[int] = Field(None, title="Total limit for checkpoints") - validation_images: Optional[str] = Field(None, title="Path to validation images") - - logging: bool = Field(False, title="Enable logging using TensorBoard") diff --git a/src/autotrain/trainers/dreambooth/train.py b/src/autotrain/trainers/dreambooth/train.py deleted file mode 100644 index 0131f9629a..0000000000 --- a/src/autotrain/trainers/dreambooth/train.py +++ /dev/null @@ -1,979 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# taken from: https://github.com/huggingface/diffusers/blob/v0.27.2-patch/ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and - -import copy -import gc -import logging -import math -import os -import shutil -from pathlib import Path - -import diffusers -import numpy as np -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -import transformers -from accelerate import Accelerator -from accelerate.utils import ProjectConfiguration, set_seed -from diffusers import ( - AutoencoderKL, - DDPMScheduler, - DiffusionPipeline, - DPMSolverMultistepScheduler, - UNet2DConditionModel, -) -from diffusers.loaders import StableDiffusionLoraLoaderMixin -from diffusers.optimization import get_scheduler -from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params -from diffusers.utils import convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft, is_wandb_available -from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.torch_utils import is_compiled_module -from huggingface_hub.utils import insecure_hashlib -from packaging import version -from peft import LoraConfig -from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict -from PIL import Image -from PIL.ImageOps import exif_transpose -from torch.utils.data import Dataset -from torchvision import transforms -from tqdm.auto import tqdm -from transformers import AutoTokenizer, PretrainedConfig - -from autotrain import logger - - -def log_validation( - pipeline, - args, - accelerator, - pipeline_args, - epoch, - torch_dtype, # Add torch_dtype parameter - is_final_validation=False, -): - logger.info( - f"Running validation... \n Generating {args.num_validation_images} images with prompt:" - f" {args.validation_prompt}." - ) - # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it - scheduler_args = {} - - if "variance_type" in pipeline.scheduler.config: - variance_type = pipeline.scheduler.config.variance_type - - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" - - scheduler_args["variance_type"] = variance_type - - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - - pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) # Use torch_dtype - pipeline.set_progress_bar_config(disable=True) - - # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None - - if args.validation_images is None: - images = [] - for _ in range(args.num_validation_images): - with torch.cuda.amp.autocast(): - image = pipeline(**pipeline_args, generator=generator).images[0] - images.append(image) - else: - images = [] - for image in args.validation_images: - image = Image.open(image) - with torch.cuda.amp.autocast(): - image = pipeline(**pipeline_args, image=image, generator=generator).images[0] - images.append(image) - - for tracker in accelerator.trackers: - phase_name = "test" if is_final_validation else "validation" - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") - - del pipeline - torch.cuda.empty_cache() - - return images - - -def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): - text_encoder_config = PretrainedConfig.from_pretrained( - pretrained_model_name_or_path, - subfolder="text_encoder", - revision=revision, - ) - model_class = text_encoder_config.architectures[0] - - if model_class == "CLIPTextModel": - from transformers import CLIPTextModel - - return CLIPTextModel - elif model_class == "RobertaSeriesModelWithTransformation": - from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation - - return RobertaSeriesModelWithTransformation - elif model_class == "T5EncoderModel": - from transformers import T5EncoderModel - - return T5EncoderModel - else: - raise ValueError(f"{model_class} is not supported.") - - -class DreamBoothDataset(Dataset): - """ - A dataset to prepare the instance and class images with the prompts for fine-tuning the model. - It pre-processes the images and the tokenizes prompts. 
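The dataset and training loop in this removed script rely on the usual tokenize-then-encode pattern for prompt conditioning: pad to the tokenizer's maximum length, then take the text encoder's per-token hidden states. A standalone sketch of that pattern, assuming a CLIP text encoder and a placeholder Stable Diffusion 1.x checkpoint:

```python
# Sketch of the tokenize-then-encode pattern used for prompt conditioning.
# The model id is a placeholder; any Stable Diffusion 1.x checkpoint has these subfolders.
import torch
from transformers import CLIPTextModel, CLIPTokenizer

model_id = "runwayml/stable-diffusion-v1-5"
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

text_inputs = tokenizer(
    "a photo of sks dog",
    truncation=True,
    padding="max_length",  # pad to the tokenizer's model_max_length (77 for CLIP)
    max_length=tokenizer.model_max_length,
    return_tensors="pt",
)

with torch.no_grad():
    # The first element of the output is the per-token hidden states used as UNet conditioning.
    prompt_embeds = text_encoder(text_inputs.input_ids, return_dict=False)[0]

print(prompt_embeds.shape)  # e.g. torch.Size([1, 77, 768])
```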
- """ - - def __init__( - self, - instance_data_root, - instance_prompt, - tokenizer, - class_data_root=None, - class_prompt=None, - class_num=None, - size=512, - center_crop=False, - encoder_hidden_states=None, - class_prompt_encoder_hidden_states=None, - tokenizer_max_length=None, - ): - self.size = size - self.center_crop = center_crop - self.tokenizer = tokenizer - self.encoder_hidden_states = encoder_hidden_states - self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states - self.tokenizer_max_length = tokenizer_max_length - - self.instance_data_root = Path(instance_data_root) - if not self.instance_data_root.exists(): - raise ValueError("Instance images root doesn't exists.") - - self.instance_images_path = list(Path(instance_data_root).iterdir()) - self.num_instance_images = len(self.instance_images_path) - self.instance_prompt = instance_prompt - self._length = self.num_instance_images - - if class_data_root is not None: - self.class_data_root = Path(class_data_root) - self.class_data_root.mkdir(parents=True, exist_ok=True) - self.class_images_path = list(self.class_data_root.iterdir()) - if class_num is not None: - self.num_class_images = min(len(self.class_images_path), class_num) - else: - self.num_class_images = len(self.class_images_path) - self._length = max(self.num_class_images, self.num_instance_images) - self.class_prompt = class_prompt - else: - self.class_data_root = None - - self.image_transforms = transforms.Compose( - [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) - - def __len__(self): - return self._length - - def __getitem__(self, index): - example = {} - instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) - instance_image = exif_transpose(instance_image) - - if not instance_image.mode == "RGB": - instance_image = instance_image.convert("RGB") - example["instance_images"] = self.image_transforms(instance_image) - - if self.encoder_hidden_states is not None: - example["instance_prompt_ids"] = self.encoder_hidden_states - else: - text_inputs = tokenize_prompt( - self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length - ) - example["instance_prompt_ids"] = text_inputs.input_ids - example["instance_attention_mask"] = text_inputs.attention_mask - - if self.class_data_root: - class_image = Image.open(self.class_images_path[index % self.num_class_images]) - class_image = exif_transpose(class_image) - - if not class_image.mode == "RGB": - class_image = class_image.convert("RGB") - example["class_images"] = self.image_transforms(class_image) - - if self.class_prompt_encoder_hidden_states is not None: - example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states - else: - class_text_inputs = tokenize_prompt( - self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length - ) - example["class_prompt_ids"] = class_text_inputs.input_ids - example["class_attention_mask"] = class_text_inputs.attention_mask - - return example - - -def collate_fn(examples, with_prior_preservation=False): - has_attention_mask = "instance_attention_mask" in examples[0] - - input_ids = [example["instance_prompt_ids"] for example in examples] - pixel_values = [example["instance_images"] for example in examples] - - if has_attention_mask: - attention_mask = [example["instance_attention_mask"] for example in 
examples] - - # Concat class and instance examples for prior preservation. - # We do this to avoid doing two forward passes. - if with_prior_preservation: - input_ids += [example["class_prompt_ids"] for example in examples] - pixel_values += [example["class_images"] for example in examples] - if has_attention_mask: - attention_mask += [example["class_attention_mask"] for example in examples] - - pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - - input_ids = torch.cat(input_ids, dim=0) - - batch = { - "input_ids": input_ids, - "pixel_values": pixel_values, - } - - if has_attention_mask: - batch["attention_mask"] = attention_mask - - return batch - - -class PromptDataset(Dataset): - "A simple dataset to prepare the prompts to generate class images on multiple GPUs." - - def __init__(self, prompt, num_samples): - self.prompt = prompt - self.num_samples = num_samples - - def __len__(self): - return self.num_samples - - def __getitem__(self, index): - example = {} - example["prompt"] = self.prompt - example["index"] = index - return example - - -def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): - if tokenizer_max_length is not None: - max_length = tokenizer_max_length - else: - max_length = tokenizer.model_max_length - - text_inputs = tokenizer( - prompt, - truncation=True, - padding="max_length", - max_length=max_length, - return_tensors="pt", - ) - - return text_inputs - - -def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): - text_input_ids = input_ids.to(text_encoder.device) - - if text_encoder_use_attention_mask: - attention_mask = attention_mask.to(text_encoder.device) - else: - attention_mask = None - - prompt_embeds = text_encoder( - text_input_ids, - attention_mask=attention_mask, - return_dict=False, - ) - prompt_embeds = prompt_embeds[0] - - return prompt_embeds - - -def main(args): - if args.report_to == "wandb" and args.hub_token is not None: - raise ValueError( - "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." - " Please use `huggingface-cli login` to authenticate with the Hub." - ) - - logging_dir = Path(args.output_dir, args.logging_dir) - - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) - - accelerator = Accelerator( - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision, - log_with=args.report_to, - project_config=accelerator_project_config, - ) - - # Add MPS support check - if torch.backends.mps.is_available(): - accelerator.native_amp = False - - if args.report_to == "wandb": - if not is_wandb_available(): - raise ImportError("Make sure to install wandb if you want to use it for logging during training.") - - # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate - # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. - # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate. - if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: - raise ValueError( - "Gradient accumulation is not supported when training the text encoder in distributed training. " - "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 
- ) - - # Make one log on every process with the configuration for debugging. - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - if accelerator.is_local_main_process: - transformers.utils.logging.set_verbosity_warning() - diffusers.utils.logging.set_verbosity_info() - else: - transformers.utils.logging.set_verbosity_error() - diffusers.utils.logging.set_verbosity_error() - - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) - - # Generate class images if prior preservation is enabled. - if args.with_prior_preservation: - class_images_dir = Path(args.class_data_dir) - if not class_images_dir.exists(): - class_images_dir.mkdir(parents=True) - cur_class_images = len(list(class_images_dir.iterdir())) - - if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 - if args.prior_generation_precision == "fp32": - torch_dtype = torch.float32 - elif args.prior_generation_precision == "fp16": - torch_dtype = torch.float16 - elif args.prior_generation_precision == "bf16": - torch_dtype = torch.bfloat16 - pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - torch_dtype=torch_dtype, - safety_checker=None, - revision=args.revision, - variant=args.variant, - ) - pipeline.set_progress_bar_config(disable=True) - - num_new_images = args.num_class_images - cur_class_images - logger.info(f"Number of class images to sample: {num_new_images}.") - - sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) - - sample_dataloader = accelerator.prepare(sample_dataloader) - pipeline.to(accelerator.device) - - for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process - ): - images = pipeline(example["prompt"]).images - - for i, image in enumerate(images): - hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() - image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" - image.save(image_filename) - - del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # Handle the repository creation - if accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - # Load the tokenizer - if args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) - elif args.pretrained_model_name_or_path: - tokenizer = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="tokenizer", - revision=args.revision, - use_fast=False, - ) - - # import correct text encoder class - text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) - - # Load scheduler and models - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - text_encoder = text_encoder_cls.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant - ) - vae_path = ( - args.pretrained_model_name_or_path - if args.pretrained_vae_model_name_or_path is None - else args.pretrained_vae_model_name_or_path - ) - try: - vae = AutoencoderKL.from_pretrained( - vae_path, - subfolder="vae" if 
args.pretrained_vae_model_name_or_path is None else None, - revision=args.revision, - variant=args.variant, - ) - except OSError: - # IF does not have a VAE so let's just set it to None - # We don't have to error out here - vae = None - - unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant - ) - - # We only train the additional adapter LoRA layers - if vae is not None: - vae.requires_grad_(False) - text_encoder.requires_grad_(False) - unet.requires_grad_(False) - - # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision - # as these weights are only used for inference, keeping weights in full precision is not required. - weight_dtype = torch.float32 - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - # Move unet, vae and text_encoder to device and cast to weight_dtype - unet.to(accelerator.device, dtype=weight_dtype) - if vae is not None: - vae.to(accelerator.device, dtype=weight_dtype) - text_encoder.to(accelerator.device, dtype=weight_dtype) - - if args.enable_xformers_memory_efficient_attention: - if is_xformers_available(): - import xformers - - xformers_version = version.parse(xformers.__version__) - if xformers_version == version.parse("0.0.16"): - logger.warning( - "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." - ) - unet.enable_xformers_memory_efficient_attention() - else: - raise ValueError("xformers is not available. Make sure it is installed correctly") - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - if args.train_text_encoder: - text_encoder.gradient_checkpointing_enable() - - # now we will add new LoRA weights to the attention layers - unet_lora_config = LoraConfig( - r=args.rank, - lora_alpha=args.rank, - init_lora_weights="gaussian", - target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"], - ) - unet.add_adapter(unet_lora_config) - - # The text encoder comes from 🤗 transformers, we will also attach adapters to it. - if args.train_text_encoder: - text_lora_config = LoraConfig( - r=args.rank, - lora_alpha=args.rank, - init_lora_weights="gaussian", - target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], - ) - text_encoder.add_adapter(text_lora_config) - - def unwrap_model(model): - model = accelerator.unwrap_model(model) - model = model._orig_mod if is_compiled_module(model) else model - return model - - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - if accelerator.is_main_process: - # there are only two options here. 
Either are just the unet attn processor layers - # or there are the unet and text encoder atten layers - unet_lora_layers_to_save = None - text_encoder_lora_layers_to_save = None - - for model in models: - if isinstance(model, type(unwrap_model(unet))): - unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) - elif isinstance(model, type(unwrap_model(text_encoder))): - text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers( - get_peft_model_state_dict(model) - ) - else: - raise ValueError(f"unexpected save model: {model.__class__}") - - # make sure to pop weight so that corresponding model is not saved again - weights.pop() - - StableDiffusionLoraLoaderMixin.save_lora_weights( - output_dir, - unet_lora_layers=unet_lora_layers_to_save, - text_encoder_lora_layers=text_encoder_lora_layers_to_save, - ) - - def load_model_hook(models, input_dir): - unet_ = None - text_encoder_ = None - - while len(models) > 0: - model = models.pop() - - if isinstance(model, type(unwrap_model(unet))): - unet_ = model - elif isinstance(model, type(unwrap_model(text_encoder))): - text_encoder_ = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") - - lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) - - unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} - unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) - incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") - - if incompatible_keys is not None: - # check only for unexpected keys - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - logger.warning( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) - - if args.train_text_encoder: - _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_) - - # Make sure the trainable params are in float32. This is again needed since the base models - # are in `weight_dtype`. More details: - # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 - if args.mixed_precision == "fp16": - models = [unet_] - if args.train_text_encoder: - models.append(text_encoder_) - - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params(models, dtype=torch.float32) - - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) - - # Enable TF32 for faster training on Ampere GPUs, - # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices - if args.allow_tf32: - torch.backends.cuda.matmul.allow_tf32 = True - - if args.scale_lr: - args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes - ) - - # Make sure the trainable params are in float32. - if args.mixed_precision == "fp16": - models = [unet] - if args.train_text_encoder: - models.append(text_encoder) - - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params(models, dtype=torch.float32) - - # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError( - "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
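The adapter setup above trains only LoRA weights attached to the attention projections; the base UNet stays frozen. A minimal sketch of that idea with peft's `LoraConfig`, using a placeholder model id and assuming peft is installed alongside diffusers:

```python
# Sketch: attach a LoRA adapter to a Stable Diffusion UNet's attention projections.
# Only the low-rank adapter weights become trainable; the base weights stay frozen.
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # placeholder base model
)
unet.requires_grad_(False)

unet_lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.add_adapter(unet_lora_config)

trainable = sum(p.numel() for p in unet.parameters() if p.requires_grad)
total = sum(p.numel() for p in unet.parameters())
print(f"trainable params: {trainable} / {total}")
```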
- ) - - optimizer_class = bnb.optim.AdamW8bit - else: - optimizer_class = torch.optim.AdamW - - # Optimizer creation - params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters())) - if args.train_text_encoder: - params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters())) - - optimizer = optimizer_class( - params_to_optimize, - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - ) - - if args.pre_compute_text_embeddings: - - def compute_text_embeddings(prompt): - with torch.no_grad(): - text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length) - prompt_embeds = encode_prompt( - text_encoder, - text_inputs.input_ids, - text_inputs.attention_mask, - text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, - ) - - return prompt_embeds - - pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) - validation_prompt_negative_prompt_embeds = compute_text_embeddings("") - - if args.validation_prompt is not None: - validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt) - else: - validation_prompt_encoder_hidden_states = None - - if args.class_prompt is not None: - pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt) - else: - pre_computed_class_prompt_encoder_hidden_states = None - - text_encoder = None - tokenizer = None - - gc.collect() - torch.cuda.empty_cache() - else: - pre_computed_encoder_hidden_states = None - validation_prompt_encoder_hidden_states = None - validation_prompt_negative_prompt_embeds = None - pre_computed_class_prompt_encoder_hidden_states = None - - # Dataset and DataLoaders creation: - train_dataset = DreamBoothDataset( - instance_data_root=args.instance_data_dir, - instance_prompt=args.instance_prompt, - class_data_root=args.class_data_dir if args.with_prior_preservation else None, - class_prompt=args.class_prompt, - class_num=args.num_class_images, - tokenizer=tokenizer, - size=args.resolution, - center_crop=args.center_crop, - encoder_hidden_states=pre_computed_encoder_hidden_states, - class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states, - tokenizer_max_length=args.tokenizer_max_length, - ) - - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - batch_size=args.train_batch_size, - shuffle=True, - collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), - num_workers=args.dataloader_num_workers, - ) - - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, - num_training_steps=args.max_train_steps * accelerator.num_processes, - num_cycles=args.lr_num_cycles, - power=args.lr_power, - ) - - # Prepare everything with our `accelerator`. 
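The step bookkeeping above, which derives `max_train_steps` from the epoch count and then re-derives the epoch count after the dataloader is prepared, is easy to misread, so here it is in isolation with illustrative numbers:

```python
# Sketch of the steps/epochs bookkeeping around gradient accumulation.
import math

num_batches_per_epoch = 25          # len(train_dataloader), illustrative
gradient_accumulation_steps = 4
num_train_epochs = 2

# One optimizer update happens every `gradient_accumulation_steps` batches.
num_update_steps_per_epoch = math.ceil(num_batches_per_epoch / gradient_accumulation_steps)  # 7

# If max_train_steps is not set, it is derived from the requested epochs...
max_train_steps = num_train_epochs * num_update_steps_per_epoch  # 14

# ...and after the dataloader is (re)prepared, the epoch count is re-derived from it,
# so the progress bar and the LR schedule always agree on the number of updates.
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)  # 2
print(num_update_steps_per_epoch, max_train_steps, num_train_epochs)
```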
- if args.train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler - ) - else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, optimizer, train_dataloader, lr_scheduler - ) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. - if accelerator.is_main_process: - tracker_config = vars(copy.deepcopy(args)) - accelerator.init_trackers("dreambooth-lora", config=tracker_config) - - # Train! - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num batches each epoch = {len(train_dataloader)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - global_step = 0 - first_epoch = 0 - - # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: - if args.resume_from_checkpoint != "latest": - path = os.path.basename(args.resume_from_checkpoint) - else: - # Get the mos recent checkpoint - dirs = os.listdir(args.output_dir) - dirs = [d for d in dirs if d.startswith("checkpoint")] - dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] if len(dirs) > 0 else None - - if path is None: - accelerator.print( - f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." - ) - args.resume_from_checkpoint = None - initial_global_step = 0 - else: - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) - - initial_global_step = global_step - first_epoch = global_step // num_update_steps_per_epoch - else: - initial_global_step = 0 - - progress_bar = tqdm( - range(0, args.max_train_steps), - initial=initial_global_step, - desc="Steps", - # Only show the progress bar once on each machine. 
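When `resume_from_checkpoint` is `"latest"`, the code above picks the `checkpoint-<step>` directory with the highest step. A small sketch of that selection logic:

```python
# Sketch: resolve "latest" to the most recent checkpoint-<global_step> directory.
import os


def find_latest_checkpoint(output_dir: str):
    dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    if not dirs:
        return None
    # Directory names look like "checkpoint-500"; sort numerically by the step suffix.
    dirs = sorted(dirs, key=lambda name: int(name.split("-")[1]))
    return dirs[-1]


# e.g. with checkpoint-500, checkpoint-1000, checkpoint-1500 on disk:
# find_latest_checkpoint("my-project") -> "checkpoint-1500"
```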
- disable=not accelerator.is_local_main_process, - ) - - for epoch in range(first_epoch, args.num_train_epochs): - unet.train() - if args.train_text_encoder: - text_encoder.train() - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(unet): - pixel_values = batch["pixel_values"].to(dtype=weight_dtype) - - if vae is not None: - # Convert images to latent space - model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = model_input * vae.config.scaling_factor - else: - model_input = pixel_values - - # Sample noise that we'll add to the latents - noise = torch.randn_like(model_input) - bsz, channels, height, width = model_input.shape - # Sample a random timestep for each image - timesteps = torch.randint( - 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device - ) - timesteps = timesteps.long() - - # Add noise to the model input according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) - - # Get the text embedding for conditioning - if args.pre_compute_text_embeddings: - encoder_hidden_states = batch["input_ids"] - else: - encoder_hidden_states = encode_prompt( - text_encoder, - batch["input_ids"], - batch["attention_mask"], - text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, - ) - - if unwrap_model(unet).config.in_channels == channels * 2: - noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) - - if args.class_labels_conditioning == "timesteps": - class_labels = timesteps - else: - class_labels = None - - # Predict the noise residual - model_pred = unet( - noisy_model_input, - timesteps, - encoder_hidden_states, - class_labels=class_labels, - return_dict=False, - )[0] - - # if model predicts variance, throw away the prediction. we will only train on the - # simplified training objective. This means that all schedulers using the fine tuned - # model must be configured to use one of the fixed variance variance types. - if model_pred.shape[1] == 6: - model_pred, _ = torch.chunk(model_pred, 2, dim=1) - - # Get the target for loss depending on the prediction type - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(model_input, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - if args.with_prior_preservation: - # Chunk the noise and model_pred into two parts and compute the loss on each part separately. - model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) - target, target_prior = torch.chunk(target, 2, dim=0) - - # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - # Compute prior loss - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") - - # Add the prior loss to the instance loss. 
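Because instance examples come first and class examples second in each batch, the prediction and target can be split back with `torch.chunk` and the two MSE terms combined with `prior_loss_weight`, as the loop above does. A toy sketch of that loss computation with random tensors:

```python
# Sketch: split a prior-preservation batch back into instance and class halves
# and combine the two losses. Shapes are illustrative (batch of 2 + 2 latents).
import torch
import torch.nn.functional as F

prior_loss_weight = 1.0
model_pred = torch.randn(4, 4, 64, 64)  # first 2: instance, last 2: class
target = torch.randn(4, 4, 64, 64)

model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)

instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
loss = instance_loss + prior_loss_weight * prior_loss
print(loss.item())
```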
- loss = loss + args.prior_loss_weight * prior_loss - else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - accelerator.backward(loss) - if accelerator.sync_gradients: - accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - if accelerator.is_main_process: - if global_step % args.checkpointing_steps == 0: - # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` - if args.checkpoints_total_limit is not None: - checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) - - # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints - if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 - removing_checkpoints = checkpoints[0:num_to_remove] - - logger.info( - f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" - ) - logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") - - for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) - shutil.rmtree(removing_checkpoint) - - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - accelerator.save_state(save_path) - logger.info(f"Saved state to {save_path}") - - logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - accelerator.log(logs, step=global_step) - - if global_step >= args.max_train_steps: - break - - if accelerator.is_main_process: - if args.validation_prompt is not None and epoch % args.validation_epochs == 0: - # create pipeline - pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=unwrap_model(unet), - text_encoder=None if args.pre_compute_text_embeddings else unwrap_model(text_encoder), - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - - if args.pre_compute_text_embeddings: - pipeline_args = { - "prompt_embeds": validation_prompt_encoder_hidden_states, - "negative_prompt_embeds": validation_prompt_negative_prompt_embeds, - } - else: - pipeline_args = {"prompt": args.validation_prompt} - - images = log_validation( - pipeline, - args, - accelerator, - pipeline_args, - epoch, - torch_dtype=weight_dtype, - ) - - # Save the lora layers - accelerator.wait_for_everyone() - if accelerator.is_main_process: - unet = unwrap_model(unet) - unet = unet.to(torch.float32) - - unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) - - if args.train_text_encoder: - text_encoder = unwrap_model(text_encoder) - text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder)) - else: - text_encoder_state_dict = None - - StableDiffusionLoraLoaderMixin.save_lora_weights( - save_directory=args.output_dir, - unet_lora_layers=unet_lora_state_dict, - text_encoder_lora_layers=text_encoder_state_dict, - ) - - accelerator.end_training() diff --git a/src/autotrain/trainers/dreambooth/train_xl.py b/src/autotrain/trainers/dreambooth/train_xl.py deleted file mode 100644 index 
d0b844537c..0000000000 --- a/src/autotrain/trainers/dreambooth/train_xl.py +++ /dev/null @@ -1,1213 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# taken from: https://github.com/huggingface/diffusers/blob/v0.27.2-patch/ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and - -import gc -import itertools -import json -import logging -import math -import os -import random -import shutil -from pathlib import Path - -import diffusers -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -import transformers -from accelerate import Accelerator -from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed -from diffusers import ( - AutoencoderKL, - DDPMScheduler, - EDMEulerScheduler, - EulerDiscreteScheduler, - StableDiffusionXLPipeline, - UNet2DConditionModel, -) -from diffusers.loaders import LoraLoaderMixin -from diffusers.optimization import get_scheduler -from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr -from diffusers.utils import convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft, is_wandb_available -from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.torch_utils import is_compiled_module -from huggingface_hub import hf_hub_download -from huggingface_hub.utils import insecure_hashlib -from packaging import version -from peft import LoraConfig, set_peft_model_state_dict -from peft.utils import get_peft_model_state_dict -from PIL import Image -from PIL.ImageOps import exif_transpose -from torch.utils.data import Dataset -from torchvision import transforms -from torchvision.transforms.functional import crop -from tqdm.auto import tqdm -from transformers import AutoTokenizer, PretrainedConfig - -from autotrain import logger - - -def determine_scheduler_type(pretrained_model_name_or_path, revision): - model_index_filename = "model_index.json" - if os.path.isdir(pretrained_model_name_or_path): - model_index = os.path.join(pretrained_model_name_or_path, model_index_filename) - else: - model_index = hf_hub_download( - repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision - ) - - with open(model_index, "r") as f: - scheduler_type = json.load(f)["scheduler"][1] - return scheduler_type - - -def import_model_class_from_model_name_or_path( - pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" -): - text_encoder_config = PretrainedConfig.from_pretrained( - pretrained_model_name_or_path, subfolder=subfolder, revision=revision - ) - model_class = text_encoder_config.architectures[0] - - if model_class == "CLIPTextModel": - from transformers import CLIPTextModel - - return CLIPTextModel - elif model_class == "CLIPTextModelWithProjection": - from transformers import CLIPTextModelWithProjection - - return CLIPTextModelWithProjection - else: - raise ValueError(f"{model_class} is not supported.") - - -class DreamBoothDataset(Dataset): - """ - A dataset to prepare 
the instance and class images with the prompts for fine-tuning the model. - It pre-processes the images. - """ - - def __init__( - self, - instance_data_root, - instance_prompt, - class_prompt, - class_data_root=None, - class_num=None, - size=1024, - repeats=1, - center_crop=False, - random_flip=False, - ): - self.size = size - self.resolution = size - self.center_crop = center_crop - - self.instance_prompt = instance_prompt - self.custom_instance_prompts = None - self.class_prompt = class_prompt - - self.random_flip = random_flip - - self.instance_data_root = Path(instance_data_root) - if not self.instance_data_root.exists(): - raise ValueError("Instance images root doesn't exists.") - - instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] - self.custom_instance_prompts = None - - self.instance_images = [] - for img in instance_images: - self.instance_images.extend(itertools.repeat(img, repeats)) - - # image processing to prepare for using SD-XL micro-conditioning - self.original_sizes = [] - self.crop_top_lefts = [] - self.pixel_values = [] - train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) - train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) - train_flip = transforms.RandomHorizontalFlip(p=1.0) - train_transforms = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) - for image in self.instance_images: - image = exif_transpose(image) - if not image.mode == "RGB": - image = image.convert("RGB") - self.original_sizes.append((image.height, image.width)) - image = train_resize(image) - if self.random_flip and random.random() < 0.5: - # flip - image = train_flip(image) - if self.center_crop: - y1 = max(0, int(round((image.height - self.resolution) / 2.0))) - x1 = max(0, int(round((image.width - self.resolution) / 2.0))) - image = train_crop(image) - else: - y1, x1, h, w = train_crop.get_params(image, (self.resolution, self.resolution)) - image = crop(image, y1, x1, h, w) - crop_top_left = (y1, x1) - self.crop_top_lefts.append(crop_top_left) - image = train_transforms(image) - self.pixel_values.append(image) - - self.num_instance_images = len(self.instance_images) - self._length = self.num_instance_images - - if class_data_root is not None: - self.class_data_root = Path(class_data_root) - self.class_data_root.mkdir(parents=True, exist_ok=True) - self.class_images_path = list(self.class_data_root.iterdir()) - if class_num is not None: - self.num_class_images = min(len(self.class_images_path), class_num) - else: - self.num_class_images = len(self.class_images_path) - self._length = max(self.num_class_images, self.num_instance_images) - else: - self.class_data_root = None - - self.image_transforms = transforms.Compose( - [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) - - def __len__(self): - return self._length - - def __getitem__(self, index): - example = {} - instance_image = self.pixel_values[index % self.num_instance_images] - original_size = self.original_sizes[index % self.num_instance_images] - crop_top_left = self.crop_top_lefts[index % self.num_instance_images] - example["instance_images"] = instance_image - example["original_size"] = original_size - example["crop_top_left"] = crop_top_left - - if self.custom_instance_prompts: - caption = 
self.custom_instance_prompts[index % self.num_instance_images] - if caption: - example["instance_prompt"] = caption - else: - example["instance_prompt"] = self.instance_prompt - - else: # costum prompts were provided, but length does not match size of image dataset - example["instance_prompt"] = self.instance_prompt - - if self.class_data_root: - class_image = Image.open(self.class_images_path[index % self.num_class_images]) - class_image = exif_transpose(class_image) - - if not class_image.mode == "RGB": - class_image = class_image.convert("RGB") - example["class_images"] = self.image_transforms(class_image) - example["class_prompt"] = self.class_prompt - - return example - - -def collate_fn(examples, with_prior_preservation=False): - pixel_values = [example["instance_images"] for example in examples] - prompts = [example["instance_prompt"] for example in examples] - original_sizes = [example["original_size"] for example in examples] - crop_top_lefts = [example["crop_top_left"] for example in examples] - - # Concat class and instance examples for prior preservation. - # We do this to avoid doing two forward passes. - if with_prior_preservation: - pixel_values += [example["class_images"] for example in examples] - prompts += [example["class_prompt"] for example in examples] - original_sizes += [example["original_size"] for example in examples] - crop_top_lefts += [example["crop_top_left"] for example in examples] - - pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - - batch = { - "pixel_values": pixel_values, - "prompts": prompts, - "original_sizes": original_sizes, - "crop_top_lefts": crop_top_lefts, - } - return batch - - -class PromptDataset(Dataset): - "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 
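The XL dataset above records each image's original size and crop offset because SDXL conditions on them (micro-conditioning); they are later packed together with the target resolution into the `add_time_ids` tensor fed to the UNet. A sketch of that packing, following the standard diffusers SDXL recipe with illustrative values:

```python
# Sketch: build SDXL "micro-conditioning" time ids from the values the dataset tracks.
import torch

original_size = (768, 1024)   # (height, width) of the source image before resizing
crop_top_left = (0, 128)      # where the training crop was taken
target_size = (1024, 1024)    # training resolution

# SDXL expects the six values concatenated in this order, one row per sample.
add_time_ids = torch.tensor([list(original_size + crop_top_left + target_size)], dtype=torch.float32)
print(add_time_ids)  # tensor([[ 768., 1024.,    0.,  128., 1024., 1024.]])

# At train/inference time this is passed to the UNet via
# added_cond_kwargs={"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds}
```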
- - def __init__(self, prompt, num_samples): - self.prompt = prompt - self.num_samples = num_samples - - def __len__(self): - return self.num_samples - - def __getitem__(self, index): - example = {} - example["prompt"] = self.prompt - example["index"] = index - return example - - -def tokenize_prompt(tokenizer, prompt): - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - return text_input_ids - - -# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt -def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): - prompt_embeds_list = [] - - for i, text_encoder in enumerate(text_encoders): - if tokenizers is not None: - tokenizer = tokenizers[i] - text_input_ids = tokenize_prompt(tokenizer, prompt) - else: - assert text_input_ids_list is not None - text_input_ids = text_input_ids_list[i] - - prompt_embeds = text_encoder( - text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False - ) - - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds[-1][-2] - bs_embed, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) - return prompt_embeds, pooled_prompt_embeds - - -def main(args): - if args.report_to == "wandb" and args.hub_token is not None: - raise ValueError( - "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." - " Please use `huggingface-cli login` to authenticate with the Hub." - ) - - if args.do_edm_style_training and args.snr_gamma is not None: - raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.") - - logging_dir = Path(args.output_dir, args.logging_dir) - - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) - kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) - accelerator = Accelerator( - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision, - log_with=args.report_to, - project_config=accelerator_project_config, - kwargs_handlers=[kwargs], - ) - - if args.report_to == "wandb": - if not is_wandb_available(): - raise ImportError("Make sure to install wandb if you want to use it for logging during training.") - - # Make one log on every process with the configuration for debugging. - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - if accelerator.is_local_main_process: - transformers.utils.logging.set_verbosity_warning() - diffusers.utils.logging.set_verbosity_info() - else: - transformers.utils.logging.set_verbosity_error() - diffusers.utils.logging.set_verbosity_error() - - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) - - # Generate class images if prior preservation is enabled. 
- if args.with_prior_preservation: - class_images_dir = Path(args.class_data_dir) - if not class_images_dir.exists(): - class_images_dir.mkdir(parents=True) - cur_class_images = len(list(class_images_dir.iterdir())) - - if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 - if args.prior_generation_precision == "fp32": - torch_dtype = torch.float32 - elif args.prior_generation_precision == "fp16": - torch_dtype = torch.float16 - elif args.prior_generation_precision == "bf16": - torch_dtype = torch.bfloat16 - pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, - torch_dtype=torch_dtype, - revision=args.revision, - variant=args.variant, - ) - pipeline.set_progress_bar_config(disable=True) - - num_new_images = args.num_class_images - cur_class_images - logger.info(f"Number of class images to sample: {num_new_images}.") - - sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) - - sample_dataloader = accelerator.prepare(sample_dataloader) - pipeline.to(accelerator.device) - - for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process - ): - images = pipeline(example["prompt"]).images - - for i, image in enumerate(images): - hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() - image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" - image.save(image_filename) - - del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # Handle the repository creation - if accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - # Load the tokenizers - tokenizer_one = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="tokenizer", - revision=args.revision, - use_fast=False, - ) - tokenizer_two = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="tokenizer_2", - revision=args.revision, - use_fast=False, - ) - - # import correct text encoder classes - text_encoder_cls_one = import_model_class_from_model_name_or_path( - args.pretrained_model_name_or_path, args.revision - ) - text_encoder_cls_two = import_model_class_from_model_name_or_path( - args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" - ) - - # Load scheduler and models - scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision) - if "EDM" in scheduler_type: - args.do_edm_style_training = True - noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - logger.info("Performing EDM-style training!") - elif args.do_edm_style_training: - noise_scheduler = EulerDiscreteScheduler.from_pretrained( - args.pretrained_model_name_or_path, subfolder="scheduler" - ) - logger.info("Performing EDM-style training!") - else: - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - - text_encoder_one = text_encoder_cls_one.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant - ) - text_encoder_two = text_encoder_cls_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant - ) - vae_path 
= ( - args.pretrained_model_name_or_path - if args.pretrained_vae_model_name_or_path is None - else args.pretrained_vae_model_name_or_path - ) - vae = AutoencoderKL.from_pretrained( - vae_path, - subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, - revision=args.revision, - variant=args.variant, - ) - latents_mean = latents_std = None - if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None: - latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None: - latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1) - - unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant - ) - - # We only train the additional adapter LoRA layers - vae.requires_grad_(False) - text_encoder_one.requires_grad_(False) - text_encoder_two.requires_grad_(False) - unet.requires_grad_(False) - - # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision - # as these weights are only used for inference, keeping weights in full precision is not required. - weight_dtype = torch.float32 - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - # Move unet, vae and text_encoder to device and cast to weight_dtype - unet.to(accelerator.device, dtype=weight_dtype) - - # The VAE is always in float32 to avoid NaN losses. - vae.to(accelerator.device, dtype=torch.float32) - - text_encoder_one.to(accelerator.device, dtype=weight_dtype) - text_encoder_two.to(accelerator.device, dtype=weight_dtype) - - if args.enable_xformers_memory_efficient_attention: - if is_xformers_available(): - import xformers - - xformers_version = version.parse(xformers.__version__) - if xformers_version == version.parse("0.0.16"): - logger.warning( - "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, " - "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." - ) - unet.enable_xformers_memory_efficient_attention() - else: - raise ValueError("xformers is not available. Make sure it is installed correctly") - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - if args.train_text_encoder: - text_encoder_one.gradient_checkpointing_enable() - text_encoder_two.gradient_checkpointing_enable() - - # now we will add new LoRA weights to the attention layers - unet_lora_config = LoraConfig( - r=args.rank, - use_dora=args.use_dora, - lora_alpha=args.rank, - init_lora_weights="gaussian", - target_modules=["to_k", "to_q", "to_v", "to_out.0"], - ) - unet.add_adapter(unet_lora_config) - - # The text encoder comes from 🤗 transformers, so we cannot directly modify it. - # So, instead, we monkey-patch the forward calls of its attention-blocks. 
- if args.train_text_encoder: - text_lora_config = LoraConfig( - r=args.rank, - use_dora=args.use_dora, - lora_alpha=args.rank, - init_lora_weights="gaussian", - target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], - ) - text_encoder_one.add_adapter(text_lora_config) - text_encoder_two.add_adapter(text_lora_config) - - def unwrap_model(model): - model = accelerator.unwrap_model(model) - model = model._orig_mod if is_compiled_module(model) else model - return model - - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - if accelerator.is_main_process: - # there are only two options here. Either are just the unet attn processor layers - # or there are the unet and text encoder atten layers - unet_lora_layers_to_save = None - text_encoder_one_lora_layers_to_save = None - text_encoder_two_lora_layers_to_save = None - - for model in models: - if isinstance(model, type(unwrap_model(unet))): - unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) - elif isinstance(model, type(unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( - get_peft_model_state_dict(model) - ) - elif isinstance(model, type(unwrap_model(text_encoder_two))): - text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( - get_peft_model_state_dict(model) - ) - else: - raise ValueError(f"unexpected save model: {model.__class__}") - - # make sure to pop weight so that corresponding model is not saved again - weights.pop() - - StableDiffusionXLPipeline.save_lora_weights( - output_dir, - unet_lora_layers=unet_lora_layers_to_save, - text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, - text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, - ) - - def load_model_hook(models, input_dir): - unet_ = None - text_encoder_one_ = None - text_encoder_two_ = None - - while len(models) > 0: - model = models.pop() - - if isinstance(model, type(unwrap_model(unet))): - unet_ = model - elif isinstance(model, type(unwrap_model(text_encoder_one))): - text_encoder_one_ = model - elif isinstance(model, type(unwrap_model(text_encoder_two))): - text_encoder_two_ = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") - - lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) - - unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} - unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) - incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") - if incompatible_keys is not None: - # check only for unexpected keys - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - logger.warning( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) - - if args.train_text_encoder: - # Do we need to call `scale_lora_layers()` here? - _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) - - _set_state_dict_into_text_encoder( - lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_ - ) - - # Make sure the trainable params are in float32. This is again needed since the base models - # are in `weight_dtype`. 
More details: - # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 - if args.mixed_precision == "fp16": - models = [unet_] - if args.train_text_encoder: - models.extend([text_encoder_one_, text_encoder_two_]) - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params(models) - - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) - - # Enable TF32 for faster training on Ampere GPUs, - # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices - if args.allow_tf32: - torch.backends.cuda.matmul.allow_tf32 = True - - if args.scale_lr: - args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes - ) - - # Make sure the trainable params are in float32. - if args.mixed_precision == "fp16": - models = [unet] - if args.train_text_encoder: - models.extend([text_encoder_one, text_encoder_two]) - - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params(models, dtype=torch.float32) - - unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters())) - - if args.train_text_encoder: - text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) - text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters())) - - # Optimization parameters - unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate} - if args.train_text_encoder: - # different learning rate for text encoder and unet - text_lora_parameters_one_with_lr = { - "params": text_lora_parameters_one, - "weight_decay": args.adam_weight_decay_text_encoder, - "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, - } - text_lora_parameters_two_with_lr = { - "params": text_lora_parameters_two, - "weight_decay": args.adam_weight_decay_text_encoder, - "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, - } - params_to_optimize = [ - unet_lora_parameters_with_lr, - text_lora_parameters_one_with_lr, - text_lora_parameters_two_with_lr, - ] - else: - params_to_optimize = [unet_lora_parameters_with_lr] - - # Optimizer creation - if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): - logger.warning( - f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." - "Defaulting to adamW" - ) - args.optimizer = "adamw" - - if args.use_8bit_adam and not args.optimizer.lower() == "adamw": - logger.warning( - f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " - f"set to {args.optimizer.lower()}" - ) - - if args.optimizer.lower() == "adamw": - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError( - "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
- ) - - optimizer_class = bnb.optim.AdamW8bit - else: - optimizer_class = torch.optim.AdamW - - optimizer = optimizer_class( - params_to_optimize, - betas=(args.adam_beta1, args.adam_beta2), - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - ) - - if args.optimizer.lower() == "prodigy": - try: - import prodigyopt - except ImportError: - raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") - - optimizer_class = prodigyopt.Prodigy - - if args.learning_rate <= 0.1: - logger.warning( - "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" - ) - if args.train_text_encoder and args.text_encoder_lr: - logger.warning( - f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:" - f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " - f"When using prodigy only learning_rate is used as the initial learning rate." - ) - # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be - # --learning_rate - params_to_optimize[1]["lr"] = args.learning_rate - params_to_optimize[2]["lr"] = args.learning_rate - - optimizer = optimizer_class( - params_to_optimize, - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - beta3=args.prodigy_beta3, - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - decouple=args.prodigy_decouple, - use_bias_correction=args.prodigy_use_bias_correction, - safeguard_warmup=args.prodigy_safeguard_warmup, - ) - - # Dataset and DataLoaders creation: - train_dataset = DreamBoothDataset( - instance_data_root=args.instance_data_dir, - instance_prompt=args.instance_prompt, - class_prompt=args.class_prompt, - class_data_root=args.class_data_dir if args.with_prior_preservation else None, - class_num=args.num_class_images, - size=args.resolution, - repeats=args.repeats, - center_crop=args.center_crop, - random_flip=args.random_flip, - ) - - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - batch_size=args.train_batch_size, - shuffle=True, - collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), - num_workers=args.dataloader_num_workers, - ) - - # Computes additional embeddings/ids required by the SDXL UNet. - # regular text embeddings (when `train_text_encoder` is not True) - # pooled text embeddings - # time ids - - def compute_time_ids(original_size, crops_coords_top_left): - # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids - target_size = (args.resolution, args.resolution) - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids]) - add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) - return add_time_ids - - if not args.train_text_encoder: - tokenizers = [tokenizer_one, tokenizer_two] - text_encoders = [text_encoder_one, text_encoder_two] - - def compute_text_embeddings(prompt, text_encoders, tokenizers): - with torch.no_grad(): - prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt) - prompt_embeds = prompt_embeds.to(accelerator.device) - pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) - return prompt_embeds, pooled_prompt_embeds - - # If no type of tuning is done on the text_encoder and custom instance prompts are NOT - # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid - # the redundant encoding. 
- if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( - args.instance_prompt, text_encoders, tokenizers - ) - - # Handle class prompt for prior-preservation. - if args.with_prior_preservation: - if not args.train_text_encoder: - class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings( - args.class_prompt, text_encoders, tokenizers - ) - - # Clear the memory here - if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - del tokenizers, text_encoders - gc.collect() - torch.cuda.empty_cache() - - # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), - # pack the statically computed variables appropriately here. This is so that we don't - # have to pass them to the dataloader. - - if not train_dataset.custom_instance_prompts: - if not args.train_text_encoder: - prompt_embeds = instance_prompt_hidden_states - unet_add_text_embeds = instance_pooled_prompt_embeds - if args.with_prior_preservation: - prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) - unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) - # if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the - # batch prompts on all training steps - else: - tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt) - tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt) - if args.with_prior_preservation: - class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt) - class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt) - tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) - tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) - - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, - num_training_steps=args.max_train_steps * accelerator.num_processes, - num_cycles=args.lr_num_cycles, - power=args.lr_power, - ) - - # Prepare everything with our `accelerator`. - if args.train_text_encoder: - unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler - ) - else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, optimizer, train_dataloader, lr_scheduler - ) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - # We need to initialize the trackers we use, and also store our configuration. 
- # The trackers initialize automatically on the main process. - if accelerator.is_main_process: - tracker_name = ( - "dreambooth-lora-sd-xl" - if "playground" not in args.pretrained_model_name_or_path - else "dreambooth-lora-playground" - ) - accelerator.init_trackers(tracker_name, config=vars(args)) - - # Train! - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num batches each epoch = {len(train_dataloader)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - global_step = 0 - first_epoch = 0 - - # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: - if args.resume_from_checkpoint != "latest": - path = os.path.basename(args.resume_from_checkpoint) - else: - # Get the most recent checkpoint - dirs = os.listdir(args.output_dir) - dirs = [d for d in dirs if d.startswith("checkpoint")] - dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] if len(dirs) > 0 else None - - if path is None: - accelerator.print( - f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." - ) - args.resume_from_checkpoint = None - initial_global_step = 0 - else: - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) - - initial_global_step = global_step - first_epoch = global_step // num_update_steps_per_epoch - - else: - initial_global_step = 0 - - progress_bar = tqdm( - range(0, args.max_train_steps), - initial=initial_global_step, - desc="Steps", - # Only show the progress bar once on each machine.
- disable=not accelerator.is_local_main_process, - ) - - def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): - sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype) - schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) - timesteps = timesteps.to(accelerator.device) - - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) - return sigma - - for epoch in range(first_epoch, args.num_train_epochs): - unet.train() - if args.train_text_encoder: - text_encoder_one.train() - text_encoder_two.train() - - # set top parameter requires_grad = True for gradient checkpointing works - accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) - accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True) - - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(unet): - pixel_values = batch["pixel_values"].to(dtype=vae.dtype) - prompts = batch["prompts"] - - # encode batch prompts when custom prompts are provided for each image - - if train_dataset.custom_instance_prompts: - if not args.train_text_encoder: - prompt_embeds, unet_add_text_embeds = compute_text_embeddings( - prompts, text_encoders, tokenizers - ) - else: - tokens_one = tokenize_prompt(tokenizer_one, prompts) - tokens_two = tokenize_prompt(tokenizer_two, prompts) - - # Convert images to latent space - model_input = vae.encode(pixel_values).latent_dist.sample() - - if latents_mean is None and latents_std is None: - model_input = model_input * vae.config.scaling_factor - if args.pretrained_vae_model_name_or_path is None: - model_input = model_input.to(weight_dtype) - else: - latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype) - latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype) - model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std - model_input = model_input.to(dtype=weight_dtype) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(model_input) - bsz = model_input.shape[0] - - # Sample a random timestep for each image - if not args.do_edm_style_training: - timesteps = torch.randint( - 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device - ) - timesteps = timesteps.long() - else: - # in EDM formulation, the model is conditioned on the pre-conditioned noise levels - # instead of discrete timesteps, so here we sample indices to get the noise levels - # from `scheduler.timesteps` - indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,)) - timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device) - - # Add noise to the model input according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) - # For EDM-style training, we first obtain the sigmas based on the continuous timesteps. - # We then precondition the final model inputs based on these sigmas instead of the timesteps. - # Follow: Section 5 of https://arxiv.org/abs/2206.00364. 
- if args.do_edm_style_training: - sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype) - if "EDM" in scheduler_type: - inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas) - else: - inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5) - - # time ids - add_time_ids = torch.cat( - [ - compute_time_ids(original_size=s, crops_coords_top_left=c) - for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"]) - ] - ) - - # Calculate the elements to repeat depending on the use of prior-preservation and custom captions. - if not train_dataset.custom_instance_prompts: - elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz - else: - elems_to_repeat_text_embeds = 1 - - # Predict the noise residual - if not args.train_text_encoder: - unet_added_conditions = { - "time_ids": add_time_ids, - "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1), - } - prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) - model_pred = unet( - inp_noisy_latents if args.do_edm_style_training else noisy_model_input, - timesteps, - prompt_embeds_input, - added_cond_kwargs=unet_added_conditions, - return_dict=False, - )[0] - else: - unet_added_conditions = {"time_ids": add_time_ids} - prompt_embeds, pooled_prompt_embeds = encode_prompt( - text_encoders=[text_encoder_one, text_encoder_two], - tokenizers=None, - prompt=None, - text_input_ids_list=[tokens_one, tokens_two], - ) - unet_added_conditions.update( - {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)} - ) - prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) - model_pred = unet( - inp_noisy_latents if args.do_edm_style_training else noisy_model_input, - timesteps, - prompt_embeds_input, - added_cond_kwargs=unet_added_conditions, - return_dict=False, - )[0] - - weighting = None - if args.do_edm_style_training: - # Similar to the input preconditioning, the model predictions are also preconditioned - # on noised model inputs (before preconditioning) and the sigmas. - # Follow: Section 5 of https://arxiv.org/abs/2206.00364. - if "EDM" in scheduler_type: - model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas) - else: - if noise_scheduler.config.prediction_type == "epsilon": - model_pred = model_pred * (-sigmas) + noisy_model_input - elif noise_scheduler.config.prediction_type == "v_prediction": - model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + ( - noisy_model_input / (sigmas**2 + 1) - ) - # We are not doing weighting here because it tends to result in numerical problems.
- # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051 - # There might be other alternatives for weighting as well: - # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686 - if "EDM" not in scheduler_type: - weighting = (sigmas**-2.0).float() - - # Get the target for loss depending on the prediction type - if noise_scheduler.config.prediction_type == "epsilon": - target = model_input if args.do_edm_style_training else noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = ( - model_input - if args.do_edm_style_training - else noise_scheduler.get_velocity(model_input, noise, timesteps) - ) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - if args.with_prior_preservation: - # Chunk the noise and model_pred into two parts and compute the loss on each part separately. - model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) - target, target_prior = torch.chunk(target, 2, dim=0) - - # Compute prior loss - if weighting is not None: - prior_loss = torch.mean( - (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( - target_prior.shape[0], -1 - ), - 1, - ) - prior_loss = prior_loss.mean() - else: - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") - - if args.snr_gamma is None: - if weighting is not None: - loss = torch.mean( - (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape( - target.shape[0], -1 - ), - 1, - ) - loss = loss.mean() - else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - else: - # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. - # Since we predict the noise instead of x_0, the original formulation is slightly changed. - # This is discussed in Section 4.2 of the same paper. - snr = compute_snr(noise_scheduler, timesteps) - base_weight = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) - - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective needs to be floored to an SNR weight of one. - mse_loss_weights = base_weight + 1 - else: - # Epsilon and sample both use the same loss weights. - mse_loss_weights = base_weight - - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") - loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights - loss = loss.mean() - - if args.with_prior_preservation: - # Add the prior loss to the instance loss. 
- loss = loss + args.prior_loss_weight * prior_loss - - accelerator.backward(loss) - if accelerator.sync_gradients: - params_to_clip = ( - itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two) - if args.train_text_encoder - else unet_lora_parameters - ) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - if accelerator.is_main_process: - if global_step % args.checkpointing_steps == 0: - # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` - if args.checkpoints_total_limit is not None: - checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) - - # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints - if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 - removing_checkpoints = checkpoints[0:num_to_remove] - - logger.info( - f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" - ) - logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") - - for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) - shutil.rmtree(removing_checkpoint) - - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - accelerator.save_state(save_path) - logger.info(f"Saved state to {save_path}") - - logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - accelerator.log(logs, step=global_step) - - if global_step >= args.max_train_steps: - break - - # Save the lora layers - accelerator.wait_for_everyone() - if accelerator.is_main_process: - unet = unwrap_model(unet) - unet = unet.to(torch.float32) - unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) - - if args.train_text_encoder: - text_encoder_one = unwrap_model(text_encoder_one) - text_encoder_lora_layers = convert_state_dict_to_diffusers( - get_peft_model_state_dict(text_encoder_one.to(torch.float32)) - ) - text_encoder_two = unwrap_model(text_encoder_two) - text_encoder_2_lora_layers = convert_state_dict_to_diffusers( - get_peft_model_state_dict(text_encoder_two.to(torch.float32)) - ) - else: - text_encoder_lora_layers = None - text_encoder_2_lora_layers = None - - StableDiffusionXLPipeline.save_lora_weights( - save_directory=args.output_dir, - unet_lora_layers=unet_lora_layers, - text_encoder_lora_layers=text_encoder_lora_layers, - text_encoder_2_lora_layers=text_encoder_2_lora_layers, - ) - - accelerator.end_training() diff --git a/src/autotrain/trainers/dreambooth/trainer.py b/src/autotrain/trainers/dreambooth/trainer.py deleted file mode 100644 index b650c9d2be..0000000000 --- a/src/autotrain/trainers/dreambooth/trainer.py +++ /dev/null @@ -1,484 +0,0 @@ -import itertools -import math -import os -import shutil - -import torch -import torch.nn.functional as F -from diffusers import StableDiffusionXLPipeline -from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict -from diffusers.optimization import get_scheduler -from huggingface_hub import 
create_repo, upload_folder -from tqdm import tqdm - -from autotrain import logger -from autotrain.trainers.dreambooth import utils - - -class Trainer: - def __init__( - self, - unet, - vae, - train_dataloader, - train_dataset, - text_encoders, - config, - optimizer, - accelerator, - noise_scheduler, - weight_dtype, - text_lora_parameters, - unet_lora_parameters, - tokenizers, - ): - self.train_dataloader = train_dataloader - self.config = config - self.optimizer = optimizer - self.accelerator = accelerator - self.unet = unet - self.vae = vae - self.noise_scheduler = noise_scheduler - self.train_dataset = train_dataset - self.weight_dtype = weight_dtype - self.text_lora_parameters = text_lora_parameters - self.unet_lora_parameters = unet_lora_parameters - self.tokenizers = tokenizers - self.text_encoders = text_encoders - - if self.config.xl: - self._setup_xl() - - self.text_encoder1 = text_encoders[0] - self.text_encoder2 = None - if len(text_encoders) == 2: - self.text_encoder2 = text_encoders[1] - - overrode_max_train_steps = False - self.num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.gradient_accumulation) - if self.config.num_steps is None: - self.config.num_steps = self.config.epochs * self.num_update_steps_per_epoch - overrode_max_train_steps = True - - self.scheduler = get_scheduler( - self.config.scheduler, - optimizer=self.optimizer, - num_warmup_steps=self.config.warmup_steps * self.accelerator.num_processes, - num_training_steps=self.config.num_steps * self.accelerator.num_processes, - num_cycles=self.config.num_cycles, - power=self.config.lr_power, - ) - - if self.config.train_text_encoder: - if len(text_encoders) == 1: - ( - self.unet, - self.text_encoder1, - self.optimizer, - self.train_dataloader, - self.scheduler, - ) = self.accelerator.prepare( - self.unet, - self.text_encoder1, - self.optimizer, - self.train_dataloader, - self.scheduler, - ) - elif len(text_encoders) == 2: - ( - self.unet, - self.text_encoder1, - self.text_encoder2, - self.optimizer, - self.train_dataloader, - self.scheduler, - ) = self.accelerator.prepare( - self.unet, - self.text_encoder1, - self.text_encoder2, - self.optimizer, - self.train_dataloader, - self.scheduler, - ) - - else: - ( - self.unet, - self.optimizer, - self.train_dataloader, - self.scheduler, - ) = accelerator.prepare(self.unet, self.optimizer, self.train_dataloader, self.scheduler) - - self.num_update_steps_per_epoch = math.ceil(len(self.train_dataloader) / self.config.gradient_accumulation) - if overrode_max_train_steps: - self.config.num_steps = self.config.epochs * self.num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - self.config.epochs = math.ceil(self.config.num_steps / self.num_update_steps_per_epoch) - - if self.accelerator.is_main_process: - self.accelerator.init_trackers("dreambooth") - - self.total_batch_size = ( - self.config.batch_size * self.accelerator.num_processes * self.config.gradient_accumulation - ) - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(self.train_dataset)}") - logger.info(f" Num batches each epoch = {len(self.train_dataloader)}") - logger.info(f" Num Epochs = {self.config.epochs}") - logger.info(f" Instantaneous batch size per device = {config.batch_size}") - logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {self.total_batch_size}") - logger.info(f" Gradient Accumulation steps = {self.config.gradient_accumulation}") - logger.info(f" Total optimization steps = {self.config.num_steps}") - logger.info(f" Training config = {self.config}") - self.global_step = 0 - self.first_epoch = 0 - - if config.resume_from_checkpoint: - self._resume_from_checkpoint() - - def compute_text_embeddings(self, prompt): - logger.info(f"Computing text embeddings for prompt: {prompt}") - with torch.no_grad(): - prompt_embeds, pooled_prompt_embeds = utils.encode_prompt_xl(self.text_encoders, self.tokenizers, prompt) - prompt_embeds = prompt_embeds.to(self.accelerator.device) - pooled_prompt_embeds = pooled_prompt_embeds.to(self.accelerator.device) - return prompt_embeds, pooled_prompt_embeds - - def compute_time_ids(self): - # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids - original_size = (self.config.resolution, self.config.resolution) - target_size = (self.config.resolution, self.config.resolution) - # crops_coords_top_left = (self.config.crops_coords_top_left_h, self.config.crops_coords_top_left_w) - crops_coords_top_left = (0, 0) - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids]) - add_time_ids = add_time_ids.to(self.accelerator.device, dtype=self.weight_dtype) - return add_time_ids - - def _setup_xl(self): - # Handle instance prompt. - instance_time_ids = self.compute_time_ids() - if not self.config.train_text_encoder: - ( - instance_prompt_hidden_states, - instance_pooled_prompt_embeds, - ) = self.compute_text_embeddings(self.config.prompt) - - # Handle class prompt for prior-preservation. - if self.config.prior_preservation: - class_time_ids = self.compute_time_ids() - if not self.config.train_text_encoder: - ( - class_prompt_hidden_states, - class_pooled_prompt_embeds, - ) = self.compute_text_embeddings(self.config.class_prompt) - - self.add_time_ids = instance_time_ids - if self.config.prior_preservation: - self.add_time_ids = torch.cat([self.add_time_ids, class_time_ids], dim=0) - - if not self.config.train_text_encoder: - self.prompt_embeds = instance_prompt_hidden_states - self.unet_add_text_embeds = instance_pooled_prompt_embeds - if self.config.prior_preservation: - self.prompt_embeds = torch.cat([self.prompt_embeds, class_prompt_hidden_states], dim=0) - self.unet_add_text_embeds = torch.cat([self.unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) - else: - self.tokens_one = utils.tokenize_prompt(self.tokenizers[0], self.config.prompt).input_ids - self.tokens_two = utils.tokenize_prompt(self.tokenizers[1], self.config.prompt).input_ids - if self.config.prior_preservation: - class_tokens_one = utils.tokenize_prompt(self.tokenizers[0], self.config.class_prompt).input_ids - class_tokens_two = utils.tokenize_prompt(self.tokenizers[1], self.config.class_prompt).input_ids - self.tokens_one = torch.cat([self.tokens_one, class_tokens_one], dim=0) - self.tokens_two = torch.cat([self.tokens_two, class_tokens_two], dim=0) - - def _resume_from_checkpoint(self): - if self.config.resume_from_checkpoint != "latest": - path = os.path.basename(self.config.resume_from_checkpoint) - else: - # Get the mos recent checkpoint - dirs = os.listdir(self.config.project_name) - dirs = [d for d in dirs if d.startswith("checkpoint")] - dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] if len(dirs) > 0 else None - - if path is None: - self.accelerator.print( - f"Checkpoint 
'{self.config.resume_from_checkpoint}' does not exist. Starting a new training run." - ) - self.config.resume_from_checkpoint = None - else: - self.accelerator.print(f"Resuming from checkpoint {path}") - self.accelerator.load_state(os.path.join(self.config.project_name, path)) - self.global_step = int(path.split("-")[1]) - - resume_global_step = self.global_step * self.config.gradient_accumulation - self.first_epoch = self.global_step // self.num_update_steps_per_epoch - self.resume_step = resume_global_step % ( - self.num_update_steps_per_epoch * self.config.gradient_accumulation - ) - - def _calculate_loss(self, model_pred, noise, model_input, timesteps): - if model_pred.shape[1] == 6 and not self.config.xl: - model_pred, _ = torch.chunk(model_pred, 2, dim=1) - - # Get the target for loss depending on the prediction type - if self.noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif self.noise_scheduler.config.prediction_type == "v_prediction": - target = self.noise_scheduler.get_velocity(model_input, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}") - - if self.config.prior_preservation: - # Chunk the noise and model_pred into two parts and compute the loss on each part separately. - model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) - target, target_prior = torch.chunk(target, 2, dim=0) - - # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - # Compute prior loss - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") - - # Add the prior loss to the instance loss. - loss = loss + self.config.prior_loss_weight * prior_loss - else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - return loss - - def _clip_gradients(self): - if self.accelerator.sync_gradients: - if len(self.text_lora_parameters) == 0: - params_to_clip = self.unet_lora_parameters - elif len(self.text_lora_parameters) == 1: - params_to_clip = itertools.chain(self.unet_lora_parameters, self.text_lora_parameters[0]) - elif len(self.text_lora_parameters) == 2: - params_to_clip = itertools.chain( - self.unet_lora_parameters, - self.text_lora_parameters[0], - self.text_lora_parameters[1], - ) - else: - raise ValueError("More than 2 text encoders are not supported.") - self.accelerator.clip_grad_norm_(params_to_clip, self.config.max_grad_norm) - - def _save_checkpoint(self): - if self.accelerator.is_main_process: - if self.global_step % self.config.checkpointing_steps == 0: - # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` - if self.config.checkpoints_total_limit is not None: - checkpoints = os.listdir(self.config.project_name) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) - - # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints - if len(checkpoints) >= self.config.checkpoints_total_limit: - num_to_remove = len(checkpoints) - self.config.checkpoints_total_limit + 1 - removing_checkpoints = checkpoints[0:num_to_remove] - - logger.info( - f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" - ) - logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") - - for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(self.config.project_name, 
removing_checkpoint) - shutil.rmtree(removing_checkpoint) - - save_path = os.path.join(self.config.project_name, f"checkpoint-{self.global_step}") - self.accelerator.save_state(save_path) - logger.info(f"Saved state to {save_path}") - - def _get_model_pred(self, batch, channels, noisy_model_input, timesteps, bsz): - if self.config.xl: - elems_to_repeat = bsz // 2 if self.config.prior_preservation else bsz - if not self.config.train_text_encoder: - unet_added_conditions = { - "time_ids": self.add_time_ids.repeat(elems_to_repeat, 1), - "text_embeds": self.unet_add_text_embeds.repeat(elems_to_repeat, 1), - } - model_pred = self.unet( - noisy_model_input, - timesteps, - self.prompt_embeds.repeat(elems_to_repeat, 1, 1), - added_cond_kwargs=unet_added_conditions, - ).sample - else: - unet_added_conditions = {"time_ids": self.add_time_ids.repeat(elems_to_repeat, 1)} - prompt_embeds, pooled_prompt_embeds = utils.encode_prompt_xl( - text_encoders=self.text_encoders, - tokenizers=None, - prompt=None, - text_input_ids_list=[self.tokens_one, self.tokens_two], - ) - unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)}) - prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1) - model_pred = self.unet( - noisy_model_input, - timesteps, - prompt_embeds, - added_cond_kwargs=unet_added_conditions, - ).sample - - else: - if self.config.pre_compute_text_embeddings: - encoder_hidden_states = batch["input_ids"] - else: - encoder_hidden_states = utils.encode_prompt( - self.text_encoder1, - batch["input_ids"], - batch["attention_mask"], - text_encoder_use_attention_mask=self.config.text_encoder_use_attention_mask, - ) - - if self.accelerator.unwrap_model(self.unet).config.in_channels == channels * 2: - noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) - - if self.config.class_labels_conditioning == "timesteps": - class_labels = timesteps - else: - class_labels = None - - model_pred = self.unet( - noisy_model_input, - timesteps, - encoder_hidden_states, - class_labels=class_labels, - ).sample - - return model_pred - - def train(self): - progress_bar = tqdm( - range(self.global_step, self.config.num_steps), - disable=not self.accelerator.is_local_main_process, - ) - progress_bar.set_description("Steps") - - for epoch in range(self.first_epoch, self.config.epochs): - self.unet.train() - - if self.config.train_text_encoder: - self.text_encoder1.train() - if self.config.xl: - self.text_encoder2.train() - - for step, batch in enumerate(self.train_dataloader): - # Skip steps until we reach the resumed step - if self.config.resume_from_checkpoint and epoch == self.first_epoch and step < self.resume_step: - if step % self.config.gradient_accumulation == 0: - progress_bar.update(1) - continue - - with self.accelerator.accumulate(self.unet): - if self.config.xl: - pixel_values = batch["pixel_values"] - else: - pixel_values = batch["pixel_values"].to(dtype=self.weight_dtype) - - if self.vae is not None: - # Convert images to latent space - model_input = self.vae.encode(pixel_values).latent_dist.sample() - model_input = model_input * self.vae.config.scaling_factor - model_input = model_input.to(dtype=self.weight_dtype) - else: - model_input = pixel_values - - # Sample noise that we'll add to the latents - noise = torch.randn_like(model_input) - bsz, channels, height, width = model_input.shape - # Sample a random timestep for each image - timesteps = torch.randint( - 0, - self.noise_scheduler.config.num_train_timesteps, - (bsz,), - 
device=model_input.device, - ) - timesteps = timesteps.long() - - # Add noise to the model input according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_model_input = self.noise_scheduler.add_noise(model_input, noise, timesteps) - model_pred = self._get_model_pred(batch, channels, noisy_model_input, timesteps, bsz) - loss = self._calculate_loss(model_pred, noise, model_input, timesteps) - self.accelerator.backward(loss) - - self._clip_gradients() - self.optimizer.step() - self.scheduler.step() - self.optimizer.zero_grad() - - if self.accelerator.sync_gradients: - progress_bar.update(1) - self.global_step += 1 - self._save_checkpoint() - - logs = { - "loss": loss.detach().item(), - "lr": self.scheduler.get_last_lr()[0], - } - progress_bar.set_postfix(**logs) - self.accelerator.log(logs, step=self.global_step) - - if self.global_step >= self.config.num_steps: - break - - self.accelerator.wait_for_everyone() - if self.accelerator.is_main_process: - self.unet = self.accelerator.unwrap_model(self.unet) - self.unet = self.unet.to(torch.float32) - unet_lora_layers = utils.unet_attn_processors_state_dict(self.unet) - text_encoder_lora_layers_1 = None - text_encoder_lora_layers_2 = None - - if self.text_encoder1 is not None and self.config.train_text_encoder: - text_encoder1 = self.accelerator.unwrap_model(self.text_encoder1) - text_encoder1 = text_encoder1.to(torch.float32) - text_encoder_lora_layers_1 = text_encoder_lora_state_dict(text_encoder1) - - if self.text_encoder2 is not None and self.config.train_text_encoder: - text_encoder2 = self.accelerator.unwrap_model(self.text_encoder2) - text_encoder2 = text_encoder2.to(torch.float32) - text_encoder_lora_layers_2 = text_encoder_lora_state_dict(text_encoder2) - - if self.config.xl: - StableDiffusionXLPipeline.save_lora_weights( - save_directory=self.config.project_name, - unet_lora_layers=unet_lora_layers, - text_encoder_lora_layers=text_encoder_lora_layers_1, - text_encoder_2_lora_layers=text_encoder_lora_layers_2, - safe_serialization=True, - ) - else: - LoraLoaderMixin.save_lora_weights( - save_directory=self.config.project_name, - unet_lora_layers=unet_lora_layers, - text_encoder_lora_layers=text_encoder_lora_layers_1, - safe_serialization=True, - ) - self.accelerator.end_training() - - def push_to_hub(self): - repo_id = create_repo( - repo_id=f"{self.config.username}/{self.config.project_name}", - exist_ok=True, - private=True, - token=self.config.token, - ).repo_id - - utils.create_model_card( - repo_id, - base_model=self.config.model, - train_text_encoder=self.config.train_text_encoder, - prompt=self.config.prompt, - repo_folder=self.config.project_name, - ) - upload_folder( - repo_id=repo_id, - folder_path=self.config.project_name, - commit_message="End of training", - ignore_patterns=["step_*", "epoch_*"], - token=self.config.token, - ) diff --git a/src/autotrain/trainers/dreambooth/utils.py b/src/autotrain/trainers/dreambooth/utils.py deleted file mode 100644 index b80a2b764a..0000000000 --- a/src/autotrain/trainers/dreambooth/utils.py +++ /dev/null @@ -1,120 +0,0 @@ -import os - -from huggingface_hub import list_models - -from autotrain import logger - - -VALID_IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"] -try: - XL_MODELS = [ - m.id - for m in list( - list_models( - task="text-to-image", - sort="downloads", - limit=200, - direction=-1, - filter=["diffusers:StableDiffusionXLPipeline"], - ) - ) - ] -except Exception: - logger.info("Unable to reach Hugging Face Hub, 
using default models as XL models.") - XL_MODELS = [ - "stabilityai/stable-diffusion-xl-base-1.0", - "stabilityai/stable-diffusion-xl-base-0.9", - "diffusers/stable-diffusion-xl-base-1.0", - "stabilityai/sdxl-turbo", - ] - - -def save_model_card_xl( - repo_id: str, - base_model=str, - train_text_encoder=False, - instance_prompt=str, - repo_folder=None, - vae_path=None, -): - img_str = "" - yaml = f""" ---- -tags: -- autotrain -- stable-diffusion-xl -- stable-diffusion-xl-diffusers -- text-to-image -- diffusers -- lora -- template:sd-lora -{img_str} -base_model: {base_model} -instance_prompt: {instance_prompt} -license: openrail++ ---- - """ - - model_card = f""" -# AutoTrain SDXL LoRA DreamBooth - {repo_id} - - - -## Model description - -These are {repo_id} LoRA adaption weights for {base_model}. - -The weights were trained using [DreamBooth](https://dreambooth.github.io/). - -LoRA for the text encoder was enabled: {train_text_encoder}. - -Special VAE used for training: {vae_path}. - -## Trigger words - -You should use {instance_prompt} to trigger the image generation. - -## Download model - -Weights for this model are available in Safetensors format. - -[Download]({repo_id}/tree/main) them in the Files & versions tab. - -""" - with open(os.path.join(repo_folder, "README.md"), "w") as f: - f.write(yaml + model_card) - - -def save_model_card( - repo_id: str, - base_model=str, - train_text_encoder=False, - instance_prompt=str, - repo_folder=None, -): - img_str = "" - model_description = f""" -# AutoTrain LoRA DreamBooth - {repo_id} - -These are LoRA adaption weights for {base_model}. The weights were trained on {instance_prompt} using [DreamBooth](https://dreambooth.github.io/). -LoRA for the text encoder was enabled: {train_text_encoder}. -""" - - yaml = f""" ---- -tags: -- autotrain -- stable-diffusion -- stable-diffusion-diffusers -- text-to-image -- diffusers -- lora -- template:sd-lora -{img_str} -base_model: {base_model} -instance_prompt: {instance_prompt} -license: openrail++ ---- - """ - with open(os.path.join(repo_folder, "README.md"), "w") as f: - f.write(yaml + model_description) diff --git a/src/autotrain/utils.py b/src/autotrain/utils.py index fe1a3306d3..e26cfb2bef 100644 --- a/src/autotrain/utils.py +++ b/src/autotrain/utils.py @@ -4,7 +4,6 @@ from autotrain.commands import launch_command from autotrain.trainers.clm.params import LLMTrainingParams -from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams from autotrain.trainers.generic.params import GenericParams from autotrain.trainers.image_classification.params import ImageClassificationParams @@ -51,8 +50,6 @@ def run_training(params, task_id, local=False, wait=False): params = TabularParams(**params) elif task_id == 27: params = GenericParams(**params) - elif task_id == 25: - params = DreamBoothTrainingParams(**params) elif task_id == 18: params = ImageClassificationParams(**params) elif task_id == 4: