From dd7a750d21b01a440ca7e7f3b4956313d008b7b4 Mon Sep 17 00:00:00 2001 From: zsxkib Date: Mon, 22 Apr 2024 09:40:24 +0000 Subject: [PATCH] webp support added (output_format + output_quality params added) --- .dockerignore | 1 + cog/predict.py | 32 ++++++++++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/.dockerignore b/.dockerignore index 6b27982f..9a406309 100644 --- a/.dockerignore +++ b/.dockerignore @@ -32,6 +32,7 @@ safety-cache/ # Exclude Python cache files __pycache__ +.pytest_cache/ .mypy_cache .pytest_cache .ruff_cache diff --git a/cog/predict.py b/cog/predict.py index bd7dcdd9..e839f88b 100644 --- a/cog/predict.py +++ b/cog/predict.py @@ -10,6 +10,7 @@ import cv2 import time import torch +import mimetypes import subprocess import numpy as np from typing import List @@ -40,6 +41,8 @@ draw_kps, ) +mimetypes.add_type("image/webp", ".webp") + # GPU global variables DEVICE = get_torch_device() DTYPE = torch.float16 if str(DEVICE).__contains__("cuda") else torch.float32 @@ -626,6 +629,17 @@ def predict( enhance_nonface_region: bool = Input( description="Enhance non-face region", default=True ), + output_format: str = Input( + description="Format of the output images", + choices=["webp", "jpg", "png"], + default="webp", + ), + output_quality: int = Input( + description="Quality of the output images, from 0 to 100. 100 is best quality, 0 is lowest quality.", + default=80, + ge=0, + le=100, + ), seed: int = Input( description="Random seed. Leave blank to randomize the seed", default=None, @@ -691,7 +705,21 @@ def predict( raise Exception( "NSFW content detected. Try running it again, or try a different prompt." ) - output_path = f"/tmp/out_{i}.png" - output_image.save(output_path) + + extension = output_format.lower() + extension = "jpeg" if extension == "jpg" else extension + output_path = f"/tmp/out_{i}.{extension}" + + print(f"[~] Saving to {output_path}...") + print(f"[~] Output format: {extension.upper()}") + if output_format != "png": + print(f"[~] Output quality: {output_quality}") + + save_params = {"format": extension.upper()} + if output_format != "png": + save_params["quality"] = output_quality + save_params["optimize"] = True + + output_image.save(output_path, **save_params) output_paths.append(Path(output_path)) return output_paths