sd_pipeline_call.py

from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionPipelineOutput

@torch.no_grad()
def sd_pipeline_call(
        pipeline: StableDiffusionPipeline,
        prompt_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None):
    """Modification of the standard SD pipeline call to support NeTI embeddings passed via the
    `prompt_embeds` argument, either as a single tensor or as a list holding one embedding per
    denoising timestep.
    """
    # 0. Default height and width to unet
    height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
    width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor

    # 1. Define call parameters
    batch_size = 1
    device = pipeline._execution_device

    # Encode the negative prompt. The pipeline's (NeTI) text encoder returns a pair; the
    # first element is indexed with [0] to get the hidden-state embeddings.
    neg_prompt = get_neg_prompt_input_ids(pipeline, negative_prompt)
    negative_prompt_embeds, _ = pipeline.text_encoder(
        input_ids=neg_prompt.input_ids.to(device),
        attention_mask=None,
    )
    negative_prompt_embeds = negative_prompt_embeds[0]

    # Here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier-free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0
    # 2. Prepare timesteps
    pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipeline.scheduler.timesteps

    # 3. Prepare latent variables
    num_channels_latents = pipeline.unet.config.in_channels
    latents = pipeline.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        pipeline.text_encoder.dtype,
        device,
        generator,
        latents,
    )

    # 4. Prepare extra step kwargs (eta / generator, depending on the scheduler's signature)
    extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
    # 5. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
    with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # scale the latents for the current timestep (needed with or without guidance)
            latent_model_input = pipeline.scheduler.scale_model_input(latents, t)

            if do_classifier_free_guidance:
                # predict the unconditional noise residual from the negative prompt
                noise_pred_uncond = pipeline.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=negative_prompt_embeds.repeat(num_images_per_prompt, 1, 1),
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

            ###############################################################
            # NeTI logic: use the prompt embedding for the current timestep
            ###############################################################
            embed = prompt_embeds[i] if isinstance(prompt_embeds, list) else prompt_embeds
            noise_pred_text = pipeline.unet(
                latent_model_input,
                t,
                encoder_hidden_states=embed,
                cross_attention_kwargs=cross_attention_kwargs,
            ).sample

            # perform classifier-free guidance
            if do_classifier_free_guidance:
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            else:
                noise_pred = noise_pred_text

            # compute the previous noisy sample x_t -> x_t-1
            latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)
    if output_type == "latent":
        image = latents
        has_nsfw_concept = None
    elif output_type == "pil":
        # 6. Post-processing
        image = pipeline.decode_latents(latents)
        # 7. Run safety checker
        image, has_nsfw_concept = pipeline.run_safety_checker(image, device, pipeline.text_encoder.dtype)
        # 8. Convert to PIL
        image = pipeline.numpy_to_pil(image)
    else:
        # 6. Post-processing
        image = pipeline.decode_latents(latents)
        # 7. Run safety checker
        image, has_nsfw_concept = pipeline.run_safety_checker(image, device, pipeline.text_encoder.dtype)
    # Offload last model to CPU
    if hasattr(pipeline, "final_offload_hook") and pipeline.final_offload_hook is not None:
        pipeline.final_offload_hook.offload()

    if not return_dict:
        return image, has_nsfw_concept

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)


def get_neg_prompt_input_ids(pipeline: StableDiffusionPipeline,
                             negative_prompt: Optional[Union[str, List[str]]] = None):
    """Tokenize the negative prompt (an empty string when none is given), padded and
    truncated to the tokenizer's maximum length."""
    if negative_prompt is None:
        negative_prompt = ""
    uncond_tokens = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
    uncond_input = pipeline.tokenizer(
        uncond_tokens,
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    return uncond_input
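

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): how this call is
# typically driven. `load_neti_pipeline` and `neti_mapper_embeds` below are
# hypothetical stand-ins for the repo's own checkpoint-loading and NeTI mapper
# code; the only assumption `sd_pipeline_call` itself makes is that
# `prompt_embeds` is a single tensor, or a list with one embedding per
# scheduler timestep (so its length must match `num_inference_steps`).
#
#     pipeline = load_neti_pipeline("runwayml/stable-diffusion-v1-5")  # hypothetical helper
#     num_steps = 50
#     # one embedding per denoising timestep, e.g. produced by the NeTI mapper
#     prompt_embeds = [neti_mapper_embeds(step) for step in range(num_steps)]  # hypothetical
#     output = sd_pipeline_call(pipeline,
#                               prompt_embeds,
#                               num_inference_steps=num_steps,
#                               guidance_scale=7.5)
#     output.images[0].save("result.png")
# ---------------------------------------------------------------------------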