From e16da6360d182e4da65e31d2521e9d6cfdf0b71e Mon Sep 17 00:00:00 2001 From: abhishekkrthakur Date: Tue, 26 Nov 2024 16:40:02 +0100 Subject: [PATCH] try diffusers update --- requirements.txt | 2 +- src/autotrain/trainers/dreambooth/train.py | 16 +++- src/autotrain/trainers/dreambooth/train_xl.py | 92 ------------------- 3 files changed, 12 insertions(+), 98 deletions(-) diff --git a/requirements.txt b/requirements.txt index 9e04c3d3d9..56380abe79 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,7 +29,7 @@ trl==0.12.0 tiktoken==0.6.0 transformers==4.46.2 accelerate==1.1.1 -diffusers==0.27.2 +diffusers==0.31.0 bitsandbytes==0.44.1 # extras rouge_score==0.1.2 diff --git a/src/autotrain/trainers/dreambooth/train.py b/src/autotrain/trainers/dreambooth/train.py index bedabddbb3..0131f9629a 100644 --- a/src/autotrain/trainers/dreambooth/train.py +++ b/src/autotrain/trainers/dreambooth/train.py @@ -37,7 +37,7 @@ DPMSolverMultistepScheduler, UNet2DConditionModel, ) -from diffusers.loaders import LoraLoaderMixin +from diffusers.loaders import StableDiffusionLoraLoaderMixin from diffusers.optimization import get_scheduler from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params from diffusers.utils import convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft, is_wandb_available @@ -63,6 +63,7 @@ def log_validation( accelerator, pipeline_args, epoch, + torch_dtype, # Add torch_dtype parameter is_final_validation=False, ): logger.info( @@ -82,7 +83,7 @@ def log_validation( pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - pipeline = pipeline.to(accelerator.device) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) # Use torch_dtype pipeline.set_progress_bar_config(disable=True) # run inference @@ -340,6 +341,10 @@ def main(args): project_config=accelerator_project_config, ) + # Add MPS support check + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") @@ -545,7 +550,7 @@ def save_model_hook(models, weights, output_dir): # make sure to pop weight so that corresponding model is not saved again weights.pop() - LoraLoaderMixin.save_lora_weights( + StableDiffusionLoraLoaderMixin.save_lora_weights( output_dir, unet_lora_layers=unet_lora_layers_to_save, text_encoder_lora_layers=text_encoder_lora_layers_to_save, @@ -565,7 +570,7 @@ def load_model_hook(models, input_dir): else: raise ValueError(f"unexpected save model: {model.__class__}") - lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) + lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) @@ -948,6 +953,7 @@ def compute_text_embeddings(prompt): accelerator, pipeline_args, epoch, + torch_dtype=weight_dtype, ) # Save the lora layers @@ -964,7 +970,7 @@ def compute_text_embeddings(prompt): else: text_encoder_state_dict = None - LoraLoaderMixin.save_lora_weights( + StableDiffusionLoraLoaderMixin.save_lora_weights( save_directory=args.output_dir, unet_lora_layers=unet_lora_state_dict, text_encoder_lora_layers=text_encoder_state_dict, diff --git a/src/autotrain/trainers/dreambooth/train_xl.py b/src/autotrain/trainers/dreambooth/train_xl.py index 467ee9ddc5..d0b844537c 100644 --- a/src/autotrain/trainers/dreambooth/train_xl.py +++ b/src/autotrain/trainers/dreambooth/train_xl.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -import contextlib import gc import itertools import json @@ -26,7 +25,6 @@ from pathlib import Path import diffusers -import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint @@ -36,7 +34,6 @@ from diffusers import ( AutoencoderKL, DDPMScheduler, - DPMSolverMultistepScheduler, EDMEulerScheduler, EulerDiscreteScheduler, StableDiffusionXLPipeline, @@ -78,59 +75,6 @@ def determine_scheduler_type(pretrained_model_name_or_path, revision): return scheduler_type -def log_validation( - pipeline, - args, - accelerator, - pipeline_args, - epoch, - is_final_validation=False, -): - logger.info( - f"Running validation... \n Generating {args.num_validation_images} images with prompt:" - f" {args.validation_prompt}." - ) - - # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it - scheduler_args = {} - - if not args.do_edm_style_training: - if "variance_type" in pipeline.scheduler.config: - variance_type = pipeline.scheduler.config.variance_type - - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" - - scheduler_args["variance_type"] = variance_type - - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - - pipeline = pipeline.to(accelerator.device) - pipeline.set_progress_bar_config(disable=True) - - # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None - # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better - # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051 - inference_ctx = ( - contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast() - ) - - with inference_ctx: - images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] - - for tracker in accelerator.trackers: - phase_name = "test" if is_final_validation else "validation" - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") - - del pipeline - torch.cuda.empty_cache() - - return images - - def import_model_class_from_model_name_or_path( pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" ): @@ -1239,42 +1183,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if global_step >= args.max_train_steps: break - if accelerator.is_main_process: - if args.validation_prompt is not None and epoch % args.validation_epochs == 0: - # create pipeline - if not args.train_text_encoder: - text_encoder_one = text_encoder_cls_one.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder", - revision=args.revision, - variant=args.variant, - ) - text_encoder_two = text_encoder_cls_two.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder_2", - revision=args.revision, - variant=args.variant, - ) - pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, - vae=vae, - text_encoder=accelerator.unwrap_model(text_encoder_one), - text_encoder_2=accelerator.unwrap_model(text_encoder_two), - unet=accelerator.unwrap_model(unet), - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - pipeline_args = {"prompt": args.validation_prompt} - - images = log_validation( - pipeline, - args, - accelerator, - pipeline_args, - epoch, - ) - # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: