Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

try diffusers update #814

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trl==0.12.0
tiktoken==0.6.0
transformers==4.46.2
accelerate==1.1.1
diffusers==0.27.2
diffusers==0.31.0
bitsandbytes==0.44.1
# extras
rouge_score==0.1.2
Expand Down
16 changes: 11 additions & 5 deletions src/autotrain/trainers/dreambooth/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
DPMSolverMultistepScheduler,
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.loaders import StableDiffusionLoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
from diffusers.utils import convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft, is_wandb_available
Expand All @@ -63,6 +63,7 @@ def log_validation(
accelerator,
pipeline_args,
epoch,
torch_dtype, # Add torch_dtype parameter
is_final_validation=False,
):
logger.info(
Expand All @@ -82,7 +83,7 @@ def log_validation(

pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)

pipeline = pipeline.to(accelerator.device)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) # Use torch_dtype
pipeline.set_progress_bar_config(disable=True)

# run inference
Expand Down Expand Up @@ -340,6 +341,10 @@ def main(args):
project_config=accelerator_project_config,
)

# Add MPS support check
if torch.backends.mps.is_available():
accelerator.native_amp = False

if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
Expand Down Expand Up @@ -545,7 +550,7 @@ def save_model_hook(models, weights, output_dir):
# make sure to pop weight so that corresponding model is not saved again
weights.pop()

LoraLoaderMixin.save_lora_weights(
StableDiffusionLoraLoaderMixin.save_lora_weights(
output_dir,
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
Expand All @@ -565,7 +570,7 @@ def load_model_hook(models, input_dir):
else:
raise ValueError(f"unexpected save model: {model.__class__}")

lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)

unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
Expand Down Expand Up @@ -948,6 +953,7 @@ def compute_text_embeddings(prompt):
accelerator,
pipeline_args,
epoch,
torch_dtype=weight_dtype,
)

# Save the lora layers
Expand All @@ -964,7 +970,7 @@ def compute_text_embeddings(prompt):
else:
text_encoder_state_dict = None

LoraLoaderMixin.save_lora_weights(
StableDiffusionLoraLoaderMixin.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
text_encoder_lora_layers=text_encoder_state_dict,
Expand Down
92 changes: 0 additions & 92 deletions src/autotrain/trainers/dreambooth/train_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

import contextlib
import gc
import itertools
import json
Expand All @@ -26,7 +25,6 @@
from pathlib import Path

import diffusers
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
Expand All @@ -36,7 +34,6 @@
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DPMSolverMultistepScheduler,
EDMEulerScheduler,
EulerDiscreteScheduler,
StableDiffusionXLPipeline,
Expand Down Expand Up @@ -78,59 +75,6 @@ def determine_scheduler_type(pretrained_model_name_or_path, revision):
return scheduler_type


def log_validation(
pipeline,
args,
accelerator,
pipeline_args,
epoch,
is_final_validation=False,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)

# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
scheduler_args = {}

if not args.do_edm_style_training:
if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type

if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"

scheduler_args["variance_type"] = variance_type

pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)

pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
inference_ctx = (
contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
)

with inference_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")

del pipeline
torch.cuda.empty_cache()

return images


def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
Expand Down Expand Up @@ -1239,42 +1183,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
if global_step >= args.max_train_steps:
break

if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
# create pipeline
if not args.train_text_encoder:
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
variant=args.variant,
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder_2",
revision=args.revision,
variant=args.variant,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=accelerator.unwrap_model(text_encoder_one),
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
unet=accelerator.unwrap_model(unet),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline_args = {"prompt": args.validation_prompt}

images = log_validation(
pipeline,
args,
accelerator,
pipeline_args,
epoch,
)

# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
Expand Down
Loading