Skip to content

Commit

Permalink
webp support added (output_format + output_quality params added)
Browse files Browse the repository at this point in the history
  • Loading branch information
zsxkib committed Apr 22, 2024
1 parent aa9fa62 commit dd7a750
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
1 change: 1 addition & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ safety-cache/

# Exclude Python cache files
__pycache__
.pytest_cache/
.mypy_cache
.pytest_cache
.ruff_cache
Expand Down
32 changes: 30 additions & 2 deletions cog/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import cv2
import time
import torch
import mimetypes
import subprocess
import numpy as np
from typing import List
Expand Down Expand Up @@ -40,6 +41,8 @@
draw_kps,
)

mimetypes.add_type("image/webp", ".webp")

# GPU global variables
DEVICE = get_torch_device()
DTYPE = torch.float16 if str(DEVICE).__contains__("cuda") else torch.float32
Expand Down Expand Up @@ -626,6 +629,17 @@ def predict(
enhance_nonface_region: bool = Input(
description="Enhance non-face region", default=True
),
output_format: str = Input(
description="Format of the output images",
choices=["webp", "jpg", "png"],
default="webp",
),
output_quality: int = Input(
description="Quality of the output images, from 0 to 100. 100 is best quality, 0 is lowest quality.",
default=80,
ge=0,
le=100,
),
seed: int = Input(
description="Random seed. Leave blank to randomize the seed",
default=None,
Expand Down Expand Up @@ -691,7 +705,21 @@ def predict(
raise Exception(
"NSFW content detected. Try running it again, or try a different prompt."
)
output_path = f"/tmp/out_{i}.png"
output_image.save(output_path)

extension = output_format.lower()
extension = "jpeg" if extension == "jpg" else extension
output_path = f"/tmp/out_{i}.{extension}"

print(f"[~] Saving to {output_path}...")
print(f"[~] Output format: {extension.upper()}")
if output_format != "png":
print(f"[~] Output quality: {output_quality}")

save_params = {"format": extension.upper()}
if output_format != "png":
save_params["quality"] = output_quality
save_params["optimize"] = True

output_image.save(output_path, **save_params)
output_paths.append(Path(output_path))
return output_paths

0 comments on commit dd7a750

Please sign in to comment.