# Base of modular backend (#6606)
## Summary

Base code of the new modular backend from #6577.
Contains support for normal generation and regional prompts.
A preview extension is also included, to verify that the extensions logic works.
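
For orientation, here is a minimal sketch of the extension pattern this PR introduces, distilled from the diff below. The `ExtensionBase` base class and `@callback` decorator are assumptions about the registration API in `extensions_manager.py`, which is not shown in this excerpt; `ExtensionCallbackType.PRE_STEP` is inferred from the `pre_step` callback referenced in `DenoiseContext`:

```python
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback  # assumed location


class LoggingExt(ExtensionBase):
    """Hypothetical extension that logs progress at each denoising step."""

    @callback(ExtensionCallbackType.PRE_STEP)
    def log_step(self, ctx: DenoiseContext) -> None:
        # step_index and timestep are populated by the time `pre_step`
        # callbacks run (see DenoiseContext in this diff).
        print(f"denoise step {ctx.step_index} @ timestep {int(ctx.timestep)}")


# Registration mirrors how PreviewExt is added in _new_invoke():
#   ext_manager.add_extension(LoggingExt())
```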

## Related Issues / Discussions


https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without the `USE_MODULAR_DENOISE` environment variable set.
Currently only normal and regional conditioning are supported, so just
generate some images and compare them with output from `main`.
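
One detail worth knowing for QA: the dispatch in `invoke()` (shown in the diff below) treats any non-empty string as "set", so `USE_MODULAR_DENOISE=0` would still select the modular path; unset the variable entirely to exercise the old code path.

```python
import os

# Dispatch logic from DenoiseLatentsInvocation.invoke() in this PR.
# os.environ.get() returns a string when the variable is set, and any
# non-empty string is truthy, so even "0" or "false" enables the modular path.
if os.environ.get("USE_MODULAR_DENOISE", False):
    ...  # _new_invoke(): modular backend
else:
    ...  # _old_invoke(): existing pipeline
```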

## Merge Plan

Discuss injection point names a bit more?
For example, if the UNet call becomes overridable in the future, the current
`pre_unet`/`post_unet` naming implies that the override itself would be named
`unet`, which feels a bit odd.
Similarly for `apply_cfg`: a future implementation could skip CFG entirely, in
which case `combine_noise_predictions`/`combine_noise` seems more suitable.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
RyanJDick authored Jul 19, 2024
2 parents e16faa6 + 78d2b1b commit 473f4cc
Showing 13 changed files with 908 additions and 25 deletions.
122 changes: 115 additions & 7 deletions invokeai/app/invocations/denoise_latents.py
@@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
import os
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

@@ -39,6 +40,7 @@
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
StableDiffusionGeneratorPipeline,
@@ -53,6 +55,11 @@
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import TorchDevice
@@ -314,9 +321,10 @@ def get_conditioning_data(
context: InvocationContext,
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
device: torch.device,
dtype: torch.dtype,
cfg_scale: float | list[float],
steps: int,
cfg_rescale_multiplier: float,
@@ -330,25 +338,25 @@ def get_conditioning_data(
uncond_list = [uncond_list]

cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
cond_list, context, device, dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
uncond_list, context, device, dtype
)

cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
dtype=dtype,
)
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
dtype=dtype,
)

if isinstance(cfg_scale, list):
@@ -707,9 +715,108 @@ def prepare_noise_and_latents(

return seed, noise, latents

def invoke(self, context: InvocationContext) -> LatentsOutput:
if os.environ.get("USE_MODULAR_DENOISE", False):
return self._new_invoke(context)
else:
return self._old_invoke(context)

@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
ext_manager = ExtensionsManager(is_canceled=context.util.is_canceled)

device = TorchDevice.choose_torch_device()
dtype = TorchDevice.choose_torch_dtype()

seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
latents = latents.to(device=device, dtype=dtype)
if noise is not None:
noise = noise.to(device=device, dtype=dtype)

_, _, latent_height, latent_width = latents.shape

conditioning_data = self.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
cfg_scale=self.cfg_scale,
steps=self.steps,
latent_height=latent_height,
latent_width=latent_width,
device=device,
dtype=dtype,
# TODO: old backend, remove
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)

scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)

timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
seed=seed,
device=device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
)

denoise_ctx = DenoiseContext(
inputs=DenoiseInputs(
orig_latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
attention_processor_cls=CustomAttnProcessor2_0,
),
unet=None,
scheduler=scheduler,
)

# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)

### preview
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)

ext_manager.add_extension(PreviewExt(step_callback))

# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)

unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (model_state_dict, unet),
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
# ext: controlnet
ext_manager.patch_extensions(unet),
# ext: freeu, seamless, ip adapter, lora
ext_manager.patch_unet(model_state_dict, unet),
):
sd_backend = StableDiffusionBackend(unet, scheduler)
denoise_ctx.unet = unet
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.detach().to("cpu")
TorchDevice.empty_cache()

name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)

@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)

mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
@@ -788,7 +895,8 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
device=unet.device,
dtype=unet.dtype,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
23 changes: 21 additions & 2 deletions invokeai/backend/model_patcher.py
@@ -5,7 +5,7 @@

import pickle
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union

import numpy as np
import torch
@@ -32,8 +32,27 @@
"""


# TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher:
@staticmethod
@contextmanager
def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]):
"""A context manager that patches `unet` with the provided attention processor.
Args:
unet (UNet2DConditionModel): The UNet model to patch.
processor_cls (Type[Any]): Class that will be instantiated for each attention key and passed to set_attn_processor(...).
"""
unet_orig_processors = unet.attn_processors

# Create a separate processor instance for each attention layer so that each can be modified independently.
unet_new_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
try:
unet.set_attn_processor(unet_new_processors)
yield None

finally:
unet.set_attn_processor(unet_orig_processors)

@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
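For reference, a usage sketch of the new context manager, mirroring how `_new_invoke()` uses it above (`unet`, `denoise_ctx`, `sd_backend`, and `ext_manager` are assumed to already be set up as in that method):

```python
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0

with ModelPatcher.patch_unet_attention_processor(unet, CustomAttnProcessor2_0):
    # Each attention layer now holds its own CustomAttnProcessor2_0 instance,
    # so per-layer state (e.g. regional prompt masks) can be attached safely.
    result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
# On exit the original attention processors are restored, even on error.
```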
131 changes: 131 additions & 0 deletions invokeai/backend/stable_diffusion/denoise_context.py
@@ -0,0 +1,131 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union

import torch
from diffusers import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput

if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode, TextConditioningData


@dataclass
class UNetKwargs:
sample: torch.Tensor
timestep: Union[torch.Tensor, float, int]
encoder_hidden_states: torch.Tensor

class_labels: Optional[torch.Tensor] = None
timestep_cond: Optional[torch.Tensor] = None
attention_mask: Optional[torch.Tensor] = None
cross_attention_kwargs: Optional[Dict[str, Any]] = None
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None
mid_block_additional_residual: Optional[torch.Tensor] = None
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None
encoder_attention_mask: Optional[torch.Tensor] = None
# return_dict: bool = True


@dataclass
class DenoiseInputs:
"""Initial variables passed to denoise. Supposed to be unchanged."""

# The latent-space image to denoise.
# Shape: [batch, channels, latent_height, latent_width]
# - If we are inpainting, this is the initial latent image before noise has been added.
# - If we are generating a new image, this should be initialized to zeros.
# - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
orig_latents: torch.Tensor

# kwargs forwarded to the scheduler.step() method.
scheduler_step_kwargs: dict[str, Any]

# Text conditioning data.
conditioning_data: TextConditioningData

# Noise used for two purposes:
# 1. Used by the scheduler to noise the initial `latents` before denoising.
# 2. Used to noise the `masked_latents` when inpainting.
# `noise` should be None if the `latents` tensor has already been noised.
# Shape: [1 or batch, channels, latent_height, latent_width]
noise: Optional[torch.Tensor]

# The seed used to generate the noise for the denoising process.
# HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
# same noise used earlier in the pipeline. This should really be handled in a clearer way.
seed: int

# The timestep schedule for the denoising process.
timesteps: torch.Tensor

# The first timestep in the schedule. This is used to determine the initial noise level, so
# should be populated if you want noise applied *even* if timesteps is empty.
init_timestep: torch.Tensor

# Class of attention processor that is used.
attention_processor_cls: Type[Any]


@dataclass
class DenoiseContext:
"""Context with all variables in denoise"""

# Initial variables passed to denoise. Supposed to be unchanged.
inputs: DenoiseInputs

# Scheduler used to apply the noise predictions.
scheduler: SchedulerMixin

# UNet model.
unet: Optional[UNet2DConditionModel] = None

# Current state of latent-space image in denoising process.
# None until `pre_denoise_loop` callback.
# Shape: [batch, channels, latent_height, latent_width]
latents: Optional[torch.Tensor] = None

# Current denoising step index.
# None until `pre_step` callback.
step_index: Optional[int] = None

# Current denoising step timestep.
# None until `pre_step` callback.
timestep: Optional[torch.Tensor] = None

# Arguments which will be passed to UNet model.
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
unet_kwargs: Optional[UNetKwargs] = None

# SchedulerOutput returned from the step function (normally generated by the scheduler).
# Intended to be used only in the `post_step` callback; otherwise it can be None.
step_output: Optional[SchedulerOutput] = None

# Scaled version of `latents`, which will be passed to the unet_kwargs initialization.
# Available in callbacks inside step() (between `pre_step` and `post_step`).
# Shape: [batch, channels, latent_height, latent_width]
latent_model_input: Optional[torch.Tensor] = None

# [TMP] Defines which conditionings the current UNet call will be run on.
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
conditioning_mode: Optional[ConditioningMode] = None

# [TMP] Noise predictions from negative conditioning.
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
# Shape: [batch, channels, latent_height, latent_width]
negative_noise_pred: Optional[torch.Tensor] = None

# [TMP] Noise predictions from positive conditioning.
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
# Shape: [batch, channels, latent_height, latent_width]
positive_noise_pred: Optional[torch.Tensor] = None

# Combined noise prediction from passed conditionings.
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
# Shape: [batch, channels, latent_height, latent_width]
noise_pred: Optional[torch.Tensor] = None

# Dictionary for extensions to pass extra info about the denoise process to other extensions.
extra: dict = field(default_factory=dict)
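
The noise-prediction fields above exist to support the classifier-free-guidance combination step. For reference, the textbook formula that an `apply_cfg` implementation computes (this is a sketch of the standard calculation, not code from this PR):

```python
import torch


def apply_cfg(ctx: DenoiseContext, cfg_scale: float) -> torch.Tensor:
    # Standard classifier-free guidance: start from the unconditional
    # prediction and extrapolate toward the conditional one.
    assert ctx.negative_noise_pred is not None and ctx.positive_noise_pred is not None
    return ctx.negative_noise_pred + cfg_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
```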
11 changes: 1 addition & 10 deletions invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -23,21 +23,12 @@
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel


@dataclass
class PipelineIntermediateState:
step: int
order: int
total_steps: int
timestep: int
latents: torch.Tensor
predicted_original: Optional[torch.Tensor] = None


@dataclass
class AddsMaskGuidance:
mask: torch.Tensor
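The `PipelineIntermediateState` dataclass itself is unchanged; it simply moves to `invokeai/backend/stable_diffusion/extensions/preview.py` alongside its main consumer. A plausible shape for `PreviewExt`, hedged since `extensions/preview.py` is not shown in this excerpt (the `ExtensionBase`/`@callback` API is assumed, as above):

```python
from typing import Callable

from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback  # assumed location
from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState


class PreviewExt(ExtensionBase):
    def __init__(self, callback_fn: Callable[[PipelineIntermediateState], None]) -> None:
        super().__init__()
        self._callback_fn = callback_fn

    @callback(ExtensionCallbackType.POST_STEP)
    def send_preview(self, ctx: DenoiseContext) -> None:
        # Report the partially-denoised latents after every scheduler step.
        self._callback_fn(
            PipelineIntermediateState(
                step=ctx.step_index,
                order=ctx.scheduler.order,
                total_steps=len(ctx.inputs.timesteps),
                timestep=int(ctx.timestep),
                latents=ctx.step_output.prev_sample,
            )
        )
```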