-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
## Summary Base code of new modular backend from #6577. Contains normal generation and regional prompts support. Also preview extension included to test if extensions logic works. ## Related Issues / Discussions https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d ## QA Instructions Run with and without set `USE_MODULAR_DENOISE` environment. Currently only normal and regional conditionings supported, so just generate some images and compare with main output. ## Merge Plan Discuss a bit more about injection point names? As if for example in future unet will be overridable, current `pre_unet`/`post_unet` assumes to name override as `unet` what feels a bit odd. Also `apply_cfg` - future implementation could ignore/not use cfg, so in this case `combine_noise_predictions`/`combine_noise` seems more suitable. ## Checklist - [x] _The PR has a short but descriptive title, suitable for a changelog_ - [x] _Tests added / updated (if applicable)_ - [ ] _Documentation added / updated (if applicable)_
- Loading branch information
Showing
13 changed files
with
908 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass, field | ||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union | ||
|
||
import torch | ||
from diffusers import UNet2DConditionModel | ||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput | ||
|
||
if TYPE_CHECKING: | ||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode, TextConditioningData | ||
|
||
|
||
@dataclass | ||
class UNetKwargs: | ||
sample: torch.Tensor | ||
timestep: Union[torch.Tensor, float, int] | ||
encoder_hidden_states: torch.Tensor | ||
|
||
class_labels: Optional[torch.Tensor] = None | ||
timestep_cond: Optional[torch.Tensor] = None | ||
attention_mask: Optional[torch.Tensor] = None | ||
cross_attention_kwargs: Optional[Dict[str, Any]] = None | ||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None | ||
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None | ||
mid_block_additional_residual: Optional[torch.Tensor] = None | ||
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None | ||
encoder_attention_mask: Optional[torch.Tensor] = None | ||
# return_dict: bool = True | ||
|
||
|
||
@dataclass | ||
class DenoiseInputs: | ||
"""Initial variables passed to denoise. Supposed to be unchanged.""" | ||
|
||
# The latent-space image to denoise. | ||
# Shape: [batch, channels, latent_height, latent_width] | ||
# - If we are inpainting, this is the initial latent image before noise has been added. | ||
# - If we are generating a new image, this should be initialized to zeros. | ||
# - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner). | ||
orig_latents: torch.Tensor | ||
|
||
# kwargs forwarded to the scheduler.step() method. | ||
scheduler_step_kwargs: dict[str, Any] | ||
|
||
# Text conditionging data. | ||
conditioning_data: TextConditioningData | ||
|
||
# Noise used for two purposes: | ||
# 1. Used by the scheduler to noise the initial `latents` before denoising. | ||
# 2. Used to noise the `masked_latents` when inpainting. | ||
# `noise` should be None if the `latents` tensor has already been noised. | ||
# Shape: [1 or batch, channels, latent_height, latent_width] | ||
noise: Optional[torch.Tensor] | ||
|
||
# The seed used to generate the noise for the denoising process. | ||
# HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the | ||
# same noise used earlier in the pipeline. This should really be handled in a clearer way. | ||
seed: int | ||
|
||
# The timestep schedule for the denoising process. | ||
timesteps: torch.Tensor | ||
|
||
# The first timestep in the schedule. This is used to determine the initial noise level, so | ||
# should be populated if you want noise applied *even* if timesteps is empty. | ||
init_timestep: torch.Tensor | ||
|
||
# Class of attention processor that is used. | ||
attention_processor_cls: Type[Any] | ||
|
||
|
||
@dataclass | ||
class DenoiseContext: | ||
"""Context with all variables in denoise""" | ||
|
||
# Initial variables passed to denoise. Supposed to be unchanged. | ||
inputs: DenoiseInputs | ||
|
||
# Scheduler which used to apply noise predictions. | ||
scheduler: SchedulerMixin | ||
|
||
# UNet model. | ||
unet: Optional[UNet2DConditionModel] = None | ||
|
||
# Current state of latent-space image in denoising process. | ||
# None until `pre_denoise_loop` callback. | ||
# Shape: [batch, channels, latent_height, latent_width] | ||
latents: Optional[torch.Tensor] = None | ||
|
||
# Current denoising step index. | ||
# None until `pre_step` callback. | ||
step_index: Optional[int] = None | ||
|
||
# Current denoising step timestep. | ||
# None until `pre_step` callback. | ||
timestep: Optional[torch.Tensor] = None | ||
|
||
# Arguments which will be passed to UNet model. | ||
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None. | ||
unet_kwargs: Optional[UNetKwargs] = None | ||
|
||
# SchedulerOutput class returned from step function(normally, generated by scheduler). | ||
# Supposed to be used only in `post_step` callback, otherwise can be None. | ||
step_output: Optional[SchedulerOutput] = None | ||
|
||
# Scaled version of `latents`, which will be passed to unet_kwargs initialization. | ||
# Available in events inside step(between `pre_step` and `post_stop`). | ||
# Shape: [batch, channels, latent_height, latent_width] | ||
latent_model_input: Optional[torch.Tensor] = None | ||
|
||
# [TMP] Defines on which conditionings current unet call will be runned. | ||
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None. | ||
conditioning_mode: Optional[ConditioningMode] = None | ||
|
||
# [TMP] Noise predictions from negative conditioning. | ||
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None. | ||
# Shape: [batch, channels, latent_height, latent_width] | ||
negative_noise_pred: Optional[torch.Tensor] = None | ||
|
||
# [TMP] Noise predictions from positive conditioning. | ||
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None. | ||
# Shape: [batch, channels, latent_height, latent_width] | ||
positive_noise_pred: Optional[torch.Tensor] = None | ||
|
||
# Combined noise prediction from passed conditionings. | ||
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None. | ||
# Shape: [batch, channels, latent_height, latent_width] | ||
noise_pred: Optional[torch.Tensor] = None | ||
|
||
# Dictionary for extensions to pass extra info about denoise process to other extensions. | ||
extra: dict = field(default_factory=dict) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.