Skip to content

Commit

Permalink
perf(stable-diffusion): use DPM++ 2M Karras scheduler and enable atte…
Browse files Browse the repository at this point in the history
…ntion slicing (#22198)
  • Loading branch information
hongbo-miao authored Jan 2, 2025
1 parent a279120 commit 0322b97
Show file tree
Hide file tree
Showing 13 changed files with 104 additions and 25 deletions.
1 change: 1 addition & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ machine-learning/hm-supervision/*/data
machine-learning/mineru/data
machine-learning/mineru/output
machine-learning/neural-forecasting/*/lightning_logs
machine-learning/stable-diffusion/output
machine-learning/triton/amazon-sagemaker-triton-resnet-50/infer/data
mobile/mobile-android/.gradle
mobile/mobile-android/local.properties
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ machine-learning/hm-supervision/*/data/**/*
machine-learning/mineru/data/**/*
machine-learning/mineru/output/**/*
machine-learning/neural-forecasting/*/lightning_logs/**/*
machine-learning/stable-diffusion/output/**/*
machine-learning/triton/amazon-sagemaker-triton-resnet-50/infer/data/**/*
mobile/mobile-android/.gradle/**/*
mobile/mobile-android/local.properties
Expand Down
1 change: 1 addition & 0 deletions .markdownlint-cli2.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@
"machine-learning/mineru/data",
"machine-learning/mineru/output",
"machine-learning/neural-forecasting/*/lightning_logs",
"machine-learning/stable-diffusion/output",
"machine-learning/triton/amazon-sagemaker-triton-resnet-50/infer/data",
"mobile/mobile-android/.gradle",
"mobile/mobile-android/local.properties",
Expand Down
1 change: 1 addition & 0 deletions .prettierignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ machine-learning/hm-supervision/*/data
machine-learning/mineru/data
machine-learning/mineru/output
machine-learning/neural-forecasting/*/lightning_logs
machine-learning/stable-diffusion/output
machine-learning/triton/amazon-sagemaker-triton-resnet-50/infer/data
mobile/mobile-android/.gradle
mobile/mobile-android/local.properties
Expand Down
1 change: 1 addition & 0 deletions .rubocop.yml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ AllCops:
- 'machine-learning/mineru/data/**/*'
- 'machine-learning/mineru/output/**/*'
- 'machine-learning/neural-forecasting/*/lightning_logs/**/*'
- 'machine-learning/stable-diffusion/output/**/*'
- 'machine-learning/triton/amazon-sagemaker-triton-resnet-50/infer/data/**/*'
- 'mobile/mobile-android/.gradle/**/*'
- 'mobile/mobile-android/local.properties'
Expand Down
1 change: 1 addition & 0 deletions .ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ exclude = [
"machine-learning/mineru/data",
"machine-learning/mineru/output",
"machine-learning/neural-forecasting/*/lightning_logs",
"machine-learning/stable-diffusion/output",
"machine-learning/triton/amazon-sagemaker-triton-resnet-50/infer/data",
"mobile/mobile-android/.gradle",
"mobile/mobile-android/local.properties",
Expand Down
1 change: 1 addition & 0 deletions .solhintignore
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ machine-learning/hm-supervision/*/data
machine-learning/mineru/data
machine-learning/mineru/output
machine-learning/neural-forecasting/*/lightning_logs
machine-learning/stable-diffusion/output
machine-learning/triton/amazon-sagemaker-triton-resnet-50/infer/data
mobile/mobile-android/.gradle
mobile/mobile-android/local.properties
Expand Down
1 change: 1 addition & 0 deletions .sqlfluffignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ machine-learning/hm-supervision/*/data
machine-learning/mineru/data
machine-learning/mineru/output
machine-learning/neural-forecasting/*/lightning_logs
machine-learning/stable-diffusion/output
machine-learning/triton/amazon-sagemaker-triton-resnet-50/infer/data
mobile/mobile-android/.gradle
mobile/mobile-android/local.properties
Expand Down
1 change: 1 addition & 0 deletions .stylelintignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ machine-learning/hm-supervision/*/data
machine-learning/mineru/data
machine-learning/mineru/output
machine-learning/neural-forecasting/*/lightning_logs
machine-learning/stable-diffusion/output
machine-learning/triton/amazon-sagemaker-triton-resnet-50/infer/data
mobile/mobile-android/.gradle
mobile/mobile-android/local.properties
Expand Down
1 change: 1 addition & 0 deletions .textlintignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ machine-learning/hm-supervision/*/data/**/*
machine-learning/mineru/data/**/*
machine-learning/mineru/output/**/*
machine-learning/neural-forecasting/*/lightning_logs/**/*
machine-learning/stable-diffusion/output/**/*
machine-learning/triton/amazon-sagemaker-triton-resnet-50/infer/data/**/*
mobile/mobile-android/.gradle/**/*
mobile/mobile-android/local.properties
Expand Down
1 change: 1 addition & 0 deletions .yamllint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ ignore: |
machine-learning/mineru/data
machine-learning/mineru/output
machine-learning/neural-forecasting/*/lightning_logs
machine-learning/stable-diffusion/output
machine-learning/triton/amazon-sagemaker-triton-resnet-50/infer/data
mobile/mobile-android/.gradle
mobile/mobile-android/local.properties
Expand Down
1 change: 1 addition & 0 deletions eslint.config.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ export default [
'machine-learning/mineru/data',
'machine-learning/mineru/output',
'machine-learning/neural-forecasting/*/lightning_logs',
'machine-learning/stable-diffusion/output',
'machine-learning/triton/amazon-sagemaker-triton-resnet-50/infer/data',
'mobile/mobile-android/.gradle',
'mobile/mobile-android/local.properties',
Expand Down
117 changes: 92 additions & 25 deletions machine-learning/stable-diffusion/src/main.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,110 @@
import logging
import random
from datetime import datetime
from pathlib import Path

import torch
from diffusers import StableDiffusionPipeline
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

logger = logging.getLogger(__name__)


def main(prompt: str, output_path: Path, num_inference_steps) -> None:
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(
model_id,
torch_dtype=torch.float16, # Use float16 for better memory efficiency
safety_checker=None,
)
class StableDiffusionGenerator:
@staticmethod
def get_device() -> str:
if torch.cuda.is_available():
return "cuda"
if torch.backends.mps.is_available():
return "mps"
return "cpu"

@staticmethod
def create_pipeline(model_id: str) -> StableDiffusionPipeline:
pipe = StableDiffusionPipeline.from_pretrained(
model_id,
torch_dtype=torch.float16,
safety_checker=None,
)

# Use DPM++ 2M Karras scheduler
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
pipe.scheduler.config, algorithm_type="dpmsolver++", use_karras_sigmas=True
)

device = StableDiffusionGenerator.get_device()
pipe = pipe.to(device)

if torch.cuda.is_available():
pipe = pipe.to("cuda")
# Enable attention slicing to reduce memory usage by processing attention in chunks
pipe.enable_attention_slicing()
return pipe

try:
image = pipe(
prompt,
num_inference_steps=num_inference_steps,
guidance_scale=7.5, # Controls how much the image generation follows the prompt
).images[0]
@staticmethod
def generate(
pipe: StableDiffusionPipeline,
prompt: str,
output_dir: Path,
negative_prompt: str | None = None,
image_number: int = 1,
width: int = 768,
height: int = 768,
inference_step_number: int = 50,
guidance_scale: float = 7.5,
seed: int | None = None,
) -> None:
output_dir.mkdir(parents=True, exist_ok=True)

image.save(output_path)
logger.info(f"Image successfully generated and saved to {output_path}")
# Set up seed for reproducibility
if seed is None:
seed = random.randint(0, 2**32 - 1)

except Exception as e:
logger.exception(f"Error generating image: {str(e)}")
device = StableDiffusionGenerator.get_device()
generator = torch.Generator(device).manual_seed(seed)
try:
# Generate images
result = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
num_images_per_prompt=image_number,
width=width,
height=height,
inference_step_number=inference_step_number,
guidance_scale=guidance_scale,
generator=generator,
)

# Save images
for idx, image in enumerate(result.images):
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
path = output_dir / f"{timestamp}_{idx}.png"
image.save(path)

except Exception as e:
logger.exception(f"Error generating images: {e}")


def main():
pipe = StableDiffusionGenerator.create_pipeline(
model_id="stabilityai/stable-diffusion-2-1"
)

StableDiffusionGenerator.generate(
pipe=pipe,
prompt="A beautiful sunset over mountains, highly detailed, majestic",
negative_prompt="blur, low quality, bad anatomy, worst quality, low resolution, watermark, text, signature, copyright, logo, brand name",
image_number=1,
# Stable Diffusion 2 default is 768x768
width=768,
height=768,
inference_step_number=50,
# Controls how much the image generation follows the prompt. Higher values = more prompt adherence
guidance_scale=7.5,
seed=None,
output_dir=Path("output"),
)


if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

prompt = "A beautiful sunset over mountains"
output_path = Path("image.png")

main(prompt=prompt, output_path=output_path, num_inference_steps=50)
main()

0 comments on commit 0322b97

Please sign in to comment.