-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
perf(stable-diffusion): use DPM++ 2M Karras scheduler and enable attention slicing (#22198)
- Loading branch information
1 parent
a279120
commit 0322b97
Showing 13 changed files with 104 additions and 25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,43 +1,110 @@ | ||
import logging | ||
import random | ||
from datetime import datetime | ||
from pathlib import Path | ||
|
||
import torch | ||
from diffusers import StableDiffusionPipeline | ||
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def main(prompt: str, output_path: Path, num_inference_steps) -> None: | ||
model_id = "runwayml/stable-diffusion-v1-5" | ||
pipe = StableDiffusionPipeline.from_pretrained( | ||
model_id, | ||
torch_dtype=torch.float16, # Use float16 for better memory efficiency | ||
safety_checker=None, | ||
) | ||
class StableDiffusionGenerator: | ||
@staticmethod | ||
def get_device() -> str: | ||
if torch.cuda.is_available(): | ||
return "cuda" | ||
if torch.backends.mps.is_available(): | ||
return "mps" | ||
return "cpu" | ||
|
||
@staticmethod | ||
def create_pipeline(model_id: str) -> StableDiffusionPipeline: | ||
pipe = StableDiffusionPipeline.from_pretrained( | ||
model_id, | ||
torch_dtype=torch.float16, | ||
safety_checker=None, | ||
) | ||
|
||
# Use DPM++ 2M Karras scheduler | ||
pipe.scheduler = DPMSolverMultistepScheduler.from_config( | ||
pipe.scheduler.config, algorithm_type="dpmsolver++", use_karras_sigmas=True | ||
) | ||
|
||
device = StableDiffusionGenerator.get_device() | ||
pipe = pipe.to(device) | ||
|
||
if torch.cuda.is_available(): | ||
pipe = pipe.to("cuda") | ||
# Enable attention slicing to reduce memory usage by processing attention in chunks | ||
pipe.enable_attention_slicing() | ||
return pipe | ||
|
||
try: | ||
image = pipe( | ||
prompt, | ||
num_inference_steps=num_inference_steps, | ||
guidance_scale=7.5, # Controls how much the image generation follows the prompt | ||
).images[0] | ||
@staticmethod | ||
def generate( | ||
pipe: StableDiffusionPipeline, | ||
prompt: str, | ||
output_dir: Path, | ||
negative_prompt: str | None = None, | ||
image_number: int = 1, | ||
width: int = 768, | ||
height: int = 768, | ||
inference_step_number: int = 50, | ||
guidance_scale: float = 7.5, | ||
seed: int | None = None, | ||
) -> None: | ||
output_dir.mkdir(parents=True, exist_ok=True) | ||
|
||
image.save(output_path) | ||
logger.info(f"Image successfully generated and saved to {output_path}") | ||
# Set up seed for reproducibility | ||
if seed is None: | ||
seed = random.randint(0, 2**32 - 1) | ||
|
||
except Exception as e: | ||
logger.exception(f"Error generating image: {str(e)}") | ||
device = StableDiffusionGenerator.get_device() | ||
generator = torch.Generator(device).manual_seed(seed) | ||
try: | ||
# Generate images | ||
result = pipe( | ||
prompt=prompt, | ||
negative_prompt=negative_prompt, | ||
num_images_per_prompt=image_number, | ||
width=width, | ||
height=height, | ||
inference_step_number=inference_step_number, | ||
guidance_scale=guidance_scale, | ||
generator=generator, | ||
) | ||
|
||
# Save images | ||
for idx, image in enumerate(result.images): | ||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | ||
path = output_dir / f"{timestamp}_{idx}.png" | ||
image.save(path) | ||
|
||
except Exception as e: | ||
logger.exception(f"Error generating images: {e}") | ||
|
||
|
||
def main():
    """Build the Stable Diffusion 2.1 pipeline and render one sample image into ``output/``."""
    sd_pipe = StableDiffusionGenerator.create_pipeline(
        model_id="stabilityai/stable-diffusion-2-1"
    )

    sample_prompt = "A beautiful sunset over mountains, highly detailed, majestic"
    unwanted = "blur, low quality, bad anatomy, worst quality, low resolution, watermark, text, signature, copyright, logo, brand name"
    # Stable Diffusion 2 default is 768x768
    side = 768
    # Controls how much the image generation follows the prompt. Higher values = more prompt adherence
    cfg_scale = 7.5
    target_dir = Path("output")

    StableDiffusionGenerator.generate(
        pipe=sd_pipe,
        prompt=sample_prompt,
        negative_prompt=unwanted,
        image_number=1,
        width=side,
        height=side,
        inference_step_number=50,
        guidance_scale=cfg_scale,
        seed=None,
        output_dir=target_dir,
    )
|
||
|
||
if __name__ == "__main__":
    # Configure logging first so messages emitted during generation are shown.
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )

    # main() takes no arguments; the former keyword-argument call
    # (prompt/output_path/num_inference_steps) no longer matches its signature.
    main()