try diffusers update
abhishekkrthakur committed Nov 26, 2024
1 parent c413203 commit e16da63
Showing 3 changed files with 12 additions and 98 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -29,7 +29,7 @@ trl==0.12.0
tiktoken==0.6.0
transformers==4.46.2
accelerate==1.1.1
-diffusers==0.27.2
+diffusers==0.31.0
bitsandbytes==0.44.1
# extras
rouge_score==0.1.2
16 changes: 11 additions & 5 deletions src/autotrain/trainers/dreambooth/train.py
@@ -37,7 +37,7 @@
DPMSolverMultistepScheduler,
UNet2DConditionModel,
)
-from diffusers.loaders import LoraLoaderMixin
+from diffusers.loaders import StableDiffusionLoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
from diffusers.utils import convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft, is_wandb_available
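Note: newer diffusers releases split the generic LoraLoaderMixin into pipeline-specific loader mixins, which is what this rename tracks. A minimal compatibility sketch (an assumption, not part of this commit) that keeps the import working on both older and newer diffusers versions:

try:
    from diffusers.loaders import StableDiffusionLoraLoaderMixin
except ImportError:
    # Older diffusers releases only expose the generic class name
    from diffusers.loaders import LoraLoaderMixin as StableDiffusionLoraLoaderMixin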
@@ -63,6 +63,7 @@ def log_validation(
accelerator,
pipeline_args,
epoch,
+torch_dtype, # Add torch_dtype parameter
is_final_validation=False,
):
logger.info(
@@ -82,7 +83,7 @@

pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)

-pipeline = pipeline.to(accelerator.device)
+pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) # Use torch_dtype
pipeline.set_progress_bar_config(disable=True)

# run inference
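Note: casting the validation pipeline to the training dtype keeps inference from silently falling back to fp32 when training runs in fp16/bf16. A rough sketch of the intended call site, assuming the script's usual weight_dtype convention and the surrounding main() variables (accelerator, args, pipeline, epoch):

import torch

# Assumed convention: weight_dtype follows accelerator.mixed_precision
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

images = log_validation(
    pipeline,
    args,
    accelerator,
    pipeline_args={"prompt": args.validation_prompt},
    epoch=epoch,
    torch_dtype=weight_dtype,
)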
@@ -340,6 +341,10 @@ def main(args):
project_config=accelerator_project_config,
)

+# Add MPS support check
+if torch.backends.mps.is_available():
+    accelerator.native_amp = False

if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
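Note: Accelerate's native AMP path relies on torch.autocast, whose support on the MPS backend has been limited, hence the explicit opt-out above. A small companion sketch (the device preference order is an assumption, not part of this commit) of how the same availability check is typically used to pick a device:

import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")  # Apple silicon; native AMP stays disabled here
else:
    device = torch.device("cpu")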
@@ -545,7 +550,7 @@ def save_model_hook(models, weights, output_dir):
# make sure to pop weight so that corresponding model is not saved again
weights.pop()

-LoraLoaderMixin.save_lora_weights(
+StableDiffusionLoraLoaderMixin.save_lora_weights(
output_dir,
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
@@ -565,7 +570,7 @@ def load_model_hook(models, input_dir):
else:
raise ValueError(f"unexpected save model: {model.__class__}")

-lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
+lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)

unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
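Note: after the key renaming and PEFT conversion, comparable diffusers DreamBooth scripts load the result back into the unwrapped UNet's adapter; a sketch along those lines, where unet_ is an assumed name for the UNet model popped from models above:

from peft import set_peft_model_state_dict

# Push the converted UNet LoRA weights into the default PEFT adapter
set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")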
@@ -948,6 +953,7 @@ def compute_text_embeddings(prompt):
accelerator,
pipeline_args,
epoch,
+torch_dtype=weight_dtype,
)

# Save the lora layers
@@ -964,7 +970,7 @@ def compute_text_embeddings(prompt):
else:
text_encoder_state_dict = None

-LoraLoaderMixin.save_lora_weights(
+StableDiffusionLoraLoaderMixin.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
text_encoder_lora_layers=text_encoder_state_dict,
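Note: weights saved this way can be reloaded into a plain pipeline for inference. A short usage sketch; the base model id, output path, and prompt are placeholders, not values from this repository:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder base model
    torch_dtype=torch.float16,
)
pipe.load_lora_weights("path/to/output_dir")  # the save_directory used above
pipe.to("cuda")
image = pipe("a photo of sks dog").images[0]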
92 changes: 0 additions & 92 deletions src/autotrain/trainers/dreambooth/train_xl.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

import contextlib
import gc
import itertools
import json
@@ -26,7 +25,6 @@
from pathlib import Path

import diffusers
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
@@ -36,7 +34,6 @@
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DPMSolverMultistepScheduler,
EDMEulerScheduler,
EulerDiscreteScheduler,
StableDiffusionXLPipeline,
@@ -78,59 +75,6 @@ def determine_scheduler_type(pretrained_model_name_or_path, revision):
return scheduler_type


def log_validation(
pipeline,
args,
accelerator,
pipeline_args,
epoch,
is_final_validation=False,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)

# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
scheduler_args = {}

if not args.do_edm_style_training:
if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type

if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"

scheduler_args["variance_type"] = variance_type

pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)

pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
inference_ctx = (
contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
)

with inference_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")

del pipeline
torch.cuda.empty_cache()

return images


def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
@@ -1239,42 +1183,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
if global_step >= args.max_train_steps:
break

if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
# create pipeline
if not args.train_text_encoder:
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
variant=args.variant,
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder_2",
revision=args.revision,
variant=args.variant,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=accelerator.unwrap_model(text_encoder_one),
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
unet=accelerator.unwrap_model(unet),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline_args = {"prompt": args.validation_prompt}

images = log_validation(
pipeline,
args,
accelerator,
pipeline_args,
epoch,
)

# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
