Commit
Merge pull request #289 from Haidra-Org/main
fix: improve hires fix efficiency; fix: better SDXL hires fix default implementation; fix: cascade bug (#287)
tazlin authored Jul 21, 2024
2 parents 2a55de9 + b35d8bf commit f893800
Showing 35 changed files with 663 additions and 110 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
@@ -34,5 +34,6 @@ repos:
types-tabulate,
types-tqdm,
types-urllib3,
-horde_sdk==0.12.0,
+horde_sdk==0.14.0,
+horde_model_reference==0.8.1,
]
7 changes: 2 additions & 5 deletions hordelib/comfy_horde.py
@@ -366,11 +366,8 @@ def __init__(
# Load our pipelines
self._load_pipelines()

-stdio = OutputCollector()
-with contextlib.redirect_stdout(stdio):
-    # Load our custom nodes
-    self._load_custom_nodes()
-stdio.replay()
+# Load our custom nodes
+self._load_custom_nodes()

self._comfyui_callback = comfyui_callback

89 changes: 69 additions & 20 deletions hordelib/horde.py
@@ -13,6 +13,7 @@
from enum import Enum, auto
from types import FunctionType

+from horde_model_reference.meta_consts import STABLE_DIFFUSION_BASELINE_CATEGORY, get_baseline_native_resolution
from horde_sdk.ai_horde_api.apimodels import ImageGenerateJobPopResponse
from horde_sdk.ai_horde_api.apimodels.base import (
GenMetadataEntry,
@@ -78,15 +79,55 @@ def __init__(
self.faults = faults


-def _calc_upscale_sampler_steps(payload):
-    """Calculates the amount of hires_fix upscaler steps based on the denoising used and the steps used for the
-    primary image"""
-    upscale_steps = round(payload["ddim_steps"] * (0.9 - payload["hires_fix_denoising_strength"]))
-    if upscale_steps < 3:
-        upscale_steps = 3
-    logger.debug(f"Upscale steps calculated as {upscale_steps}")
-    return upscale_steps
+def _calc_upscale_sampler_steps(
+    payload: dict,
+) -> int:
+    """Use `ImageUtils.calc_upscale_sampler_steps(...)` to calculate the number of steps for the upscale sampler.
+
+    Args:
+        payload (dict): The payload to use for the calculation.
+
+    Returns:
+        int: The number of steps to use.
+    """
+    model_name = payload.get("model_name")
+    baseline = None
+    native_resolution = None
+    if model_name is not None:
+        baseline = SharedModelManager.model_reference_manager.stable_diffusion.get_model_baseline(model_name)
+        if baseline is not None:
+            try:
+                baseline = STABLE_DIFFUSION_BASELINE_CATEGORY(baseline)
+            except ValueError:
+                baseline = None
+                logger.warning(
+                    f"Model {model_name} has an invalid baseline {baseline} so we cannot calculate "
+                    "hires fix upscale steps.",
+                )
+        if baseline is not None:
+            native_resolution = get_baseline_native_resolution(baseline)
+
+    width: int | None = payload.get("width")
+    height: int | None = payload.get("height")
+    hires_fix_denoising_strength: float | None = payload.get("hires_fix_denoising_strength")
+    ddim_steps: int | None = payload.get("ddim_steps")
+
+    if width is None or height is None:
+        raise ValueError("Width and height must be set to calculate upscale sampler steps")
+
+    if hires_fix_denoising_strength is None:
+        raise ValueError("Hires fix denoising strength must be set to calculate upscale sampler steps")
+
+    if ddim_steps is None:
+        raise ValueError("DDIM steps must be set to calculate upscale sampler steps")
+
+    return ImageUtils.calc_upscale_sampler_steps(
+        model_native_resolution=native_resolution,
+        width=width,
+        height=height,
+        hires_fix_denoising_strength=hires_fix_denoising_strength,
+        ddim_steps=ddim_steps,
+    )


class HordeLib:
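For context on what the rewritten helper replaces: the deleted heuristic tied the upscale step count only to the primary step count and the denoising strength. Below is an illustrative sketch of that old behaviour; the new `ImageUtils.calc_upscale_sampler_steps` additionally receives the model's native resolution, and its exact weighting is not shown in this diff.

# Illustrative only: the step heuristic the deleted code implemented.
# The replacement delegates to ImageUtils.calc_upscale_sampler_steps, which
# also receives the model's native resolution (an input the old code ignored).
def legacy_upscale_steps(ddim_steps: int, hires_fix_denoising_strength: float) -> int:
    """Fewer upscale steps as denoising strength rises, never below 3."""
    return max(3, round(ddim_steps * (0.9 - hires_fix_denoising_strength)))

# Example: 30 primary steps at denoising strength 0.6 -> round(30 * 0.3) = 9 upscale steps.
assert legacy_upscale_steps(30, 0.6) == 9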
@@ -825,13 +866,15 @@ def _final_pipeline_adjustments(self, payload, pipeline_data) -> tuple[dict, lis
    raise RuntimeError(f"Invalid key {key}")
elif "*" in key:
    key, multiplier = key.split("*", 1)
-elif key in payload:
+
+if key in payload:
    if multiplier:
        pipeline_params[newkey] = round(payload.get(key) * float(multiplier))
    else:
        pipeline_params[newkey] = payload.get(key)
-else:
+elif not isinstance(key, FunctionType):
    logger.error(f"Parameter {key} not found")
+
# We inject these parameters to ensure the HordeCheckpointLoader knows what file to load, if necessary
# We don't want to hardcode this into the pipeline.json as we export this directly from ComfyUI
# and don't want to have to remember to re-add those keys
@@ -874,16 +917,22 @@ def _final_pipeline_adjustments(self, payload, pipeline_data) -> tuple[dict, lis
baseline = None
if model_details:
baseline = model_details.get("baseline")
-if baseline and (baseline == "stable_cascade" or baseline == "stable_diffusion_xl"):
-    new_width, new_height = ImageUtils.get_first_pass_image_resolution_max(
-        original_width,
-        original_height,
-    )
-else:
-    new_width, new_height = ImageUtils.get_first_pass_image_resolution_min(
-        original_width,
-        original_height,
-    )
+if baseline:
+    if baseline == "stable_cascade":
+        new_width, new_height = ImageUtils.get_first_pass_image_resolution_max(
+            original_width,
+            original_height,
+        )
+    elif baseline == "stable_diffusion_xl":
+        new_width, new_height = ImageUtils.get_first_pass_image_resolution_sdxl(
+            original_width,
+            original_height,
+        )
+    else:  # fall through case; only `stable diffusion 1` at time of writing
+        new_width, new_height = ImageUtils.get_first_pass_image_resolution_min(
+            original_width,
+            original_height,
+        )

# This is the *target* resolution
pipeline_params["latent_upscale.width"] = original_width
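The SDXL branch above calls `ImageUtils.get_first_pass_image_resolution_sdxl`, whose body is not part of this diff. The sketch below is only an assumption about what such a helper could do (aspect-preserving scaling toward the SDXL-native 1024x1024 pixel budget, snapped to multiples of 64); the real ImageUtils implementation may round or clamp differently.

# Hypothetical sketch; not the actual ImageUtils implementation.
def first_pass_resolution_sdxl_sketch(width: int, height: int, native: int = 1024) -> tuple[int, int]:
    """Scale an oversized request down to roughly native*native pixels for the
    first pass, keeping aspect ratio and snapping to multiples of 64."""
    budget = native * native
    if width * height <= budget:
        return width, height  # already within the first-pass budget
    scale = (budget / (width * height)) ** 0.5

    def snap(value: int) -> int:
        return max(64, round(value * scale / 64) * 64)

    return snap(width), snap(height)

# e.g. a 1536x1536 request would be rendered at 1024x1024 first, then
# latent-upscaled to the target resolution (the latent_upscale parameters above).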
3 changes: 2 additions & 1 deletion hordelib/initialisation.py
@@ -27,6 +27,7 @@ def initialise(
force_normal_vram_mode: bool = True,
extra_comfyui_args: list[str] | None = None,
disable_smart_memory: bool = False,
+    do_not_load_model_mangers: bool = False,
):
"""Initialise hordelib. This is required before using any other hordelib functions.
Expand Down Expand Up @@ -96,7 +97,7 @@ def initialise(
# Initialise model manager
from hordelib.shared_model_manager import SharedModelManager

-    SharedModelManager()
+    SharedModelManager(do_not_load_model_mangers=do_not_load_model_mangers)

sys.argv = sys_arg_bkp

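A short usage sketch for the new flag. This assumes `initialise` is re-exported at the package level (as in hordelib's README); if not, it can be imported from `hordelib.initialisation`.

import hordelib

# Skip constructing the shared model managers during init; SharedModelManager
# can load them later, once they are actually needed.
hordelib.initialise(do_not_load_model_mangers=True)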
17 changes: 1 addition & 16 deletions hordelib/model_manager/base.py
@@ -138,22 +138,7 @@ def load_model_database(self) -> None:
)

def download_model_reference(self) -> dict:
-try:
-    logger.debug(f"Downloading Model Reference for {self.models_db_name}")
-    response = requests.get(self.remote_db)
-    logger.debug("Downloaded Model Reference successfully")
-    models = response.json()
-    logger.info("Updated Model Reference from remote.")
-    return models
-except Exception as e:  # XXX Double check and/or rework this
-    logger.error(
-        f"Download failed: {e}",
-    )
-    logger.warning("Model Reference not downloaded, using local copy")
-    if self.models_db_path.exists():
-        return json.loads(self.models_db_path.read_text())
-    logger.error("No local copy of Model Reference found!")
-    return {}
+raise NotImplementedError("Downloading model databases is no longer supported within hordelib.")

def get_free_ram_mb(self) -> int:
"""Returns the amount of free RAM in MB rounded down to the nearest integer.
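Because `download_model_reference` now raises unconditionally, any caller that still invokes it needs an explicit fallback. A minimal caller-side sketch, assuming `manager` is a `BaseModelManager` subclass and that `load_model_database()` (referenced in the hunk header above) still reads the locally shipped reference:

# Caller-side sketch; `manager` is assumed to be a BaseModelManager subclass.
try:
    manager.download_model_reference()
except NotImplementedError:
    # Remote downloads moved out of hordelib (the pinned horde_model_reference
    # package owns the reference data now); read the local database instead.
    manager.load_model_database()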
2 changes: 1 addition & 1 deletion hordelib/model_manager/lora.py
@@ -44,7 +44,7 @@ class LoraModelManager(BaseModelManager):
)
LORA_API = "https://civitai.com/api/v1/models?types=LORA&sort=Highest%20Rated&primaryFileOnly=true"
MAX_RETRIES = 10 if not TESTS_ONGOING else 3
-    MAX_DOWNLOAD_THREADS = 5 if not TESTS_ONGOING else 15
+    MAX_DOWNLOAD_THREADS = 5 if not TESTS_ONGOING else 75
RETRY_DELAY = 3 if not TESTS_ONGOING else 0.2
"""The time to wait between retries in seconds"""
REQUEST_METADATA_TIMEOUT = 20 # Longer because civitai performs poorly on metadata requests for more than 5 models
1 change: 1 addition & 0 deletions hordelib/nodes/facerestore_cf/__init__.py
@@ -8,6 +8,7 @@
import model_management
import numpy as np
import torch

# from comfy_extras.chainner_models import model_loading
from hordelib.nodes.facerestore_cf.r_chainner import model_loading
from torchvision.transforms.functional import normalize
28 changes: 7 additions & 21 deletions hordelib/nodes/facerestore_cf/r_chainner/gfpganv1_clean_arch.py
@@ -72,16 +72,12 @@ def forward(
if randomize_noise:
noise = [None] * self.num_layers # for each style conv layer
else: # use the stored noise
noise = [
getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
]
noise = [getattr(self.noises, f"noise{i}") for i in range(self.num_layers)]
# style truncation
if truncation < 1:
style_truncation = []
for style in styles:
style_truncation.append(
truncation_latent + truncation * (style - truncation_latent)
)
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
styles = style_truncation
# get style latents with injection
if len(styles) == 1:
@@ -96,9 +92,7 @@ def forward(
if inject_index is None:
inject_index = random.randint(1, self.num_latent - 1)
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = (
styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
)
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
latent = torch.cat([latent1, latent2], 1)

# main generation
@@ -160,14 +154,10 @@ def __init__(self, in_channels, out_channels, mode="down"):
def forward(self, x):
out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
# upsample/downsample
out = F.interpolate(
out, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
)
out = F.interpolate(out, scale_factor=self.scale_factor, mode="bilinear", align_corners=False)
out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
# skip
x = F.interpolate(
x, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
)
x = F.interpolate(x, scale_factor=self.scale_factor, mode="bilinear", align_corners=False)
skip = self.skip(x)
out = out + skip
return out
@@ -283,9 +273,7 @@ def __init__(
# load pre-trained stylegan2 model if necessary
if decoder_load_path:
self.stylegan_decoder.load_state_dict(
torch.load(
decoder_load_path, map_location=lambda storage, loc: storage
)["params_ema"]
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)["params_ema"]
)
# fix decoder without updating params
if fix_decoder:
@@ -317,9 +305,7 @@ def __init__(
)
self.load_state_dict(state_dict)

def forward(
self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs
):
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
"""Forward function for GFPGANv1Clean.
Args:
x (Tensor): Input images.
6 changes: 1 addition & 5 deletions hordelib/nodes/facerestore_cf/r_chainner/model_loading.py
@@ -1,4 +1,3 @@

from hordelib.nodes.facerestore_cf.r_chainner.gfpganv1_clean_arch import GFPGANv1Clean
from hordelib.nodes.facerestore_cf.r_chainner.types import PyTorchModel

@@ -21,9 +20,6 @@ def load_state_dict(state_dict) -> PyTorchModel:
state_dict_keys = list(state_dict.keys())

# GFPGAN
if (
"toRGB.0.weight" in state_dict_keys
and "stylegan_decoder.style_mlp.1.weight" in state_dict_keys
):
if "toRGB.0.weight" in state_dict_keys and "stylegan_decoder.style_mlp.1.weight" in state_dict_keys:
model = GFPGANv1Clean(state_dict)
return model
28 changes: 7 additions & 21 deletions hordelib/nodes/facerestore_cf/r_chainner/stylegan2_clean_arch.py
@@ -117,9 +117,7 @@ def forward(self, x, style):
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

weight = weight.view(
b * self.out_channels, c, self.kernel_size, self.kernel_size
)
weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)

# upsample or downsample if necessary
if self.sample_mode == "upsample":
@@ -224,9 +222,7 @@ def forward(self, x, style, skip=None):
out = out + self.bias
if skip is not None:
if self.upsample:
skip = F.interpolate(
skip, scale_factor=2, mode="bilinear", align_corners=False
)
skip = F.interpolate(skip, scale_factor=2, mode="bilinear", align_corners=False)
out = out + skip
return out

@@ -257,9 +253,7 @@ class StyleGAN2GeneratorClean(nn.Module):
narrow (float): Narrow ratio for channels. Default: 1.0.
"""

def __init__(
self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1
):
def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
super(StyleGAN2GeneratorClean, self).__init__()
# Style MLP layers
self.num_style_feat = num_style_feat
@@ -362,9 +356,7 @@ def get_latent(self, x):
return self.style_mlp(x)

def mean_latent(self, num_latent):
latent_in = torch.randn(
num_latent, self.num_style_feat, device=self.constant_input.weight.device
)
latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
return latent

@@ -398,16 +390,12 @@ def forward(
if randomize_noise:
noise = [None] * self.num_layers # for each style conv layer
else: # use the stored noise
noise = [
getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
]
noise = [getattr(self.noises, f"noise{i}") for i in range(self.num_layers)]
# style truncation
if truncation < 1:
style_truncation = []
for style in styles:
style_truncation.append(
truncation_latent + truncation * (style - truncation_latent)
)
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
styles = style_truncation
# get style latents with injection
if len(styles) == 1:
@@ -422,9 +410,7 @@ def forward(
if inject_index is None:
inject_index = random.randint(1, self.num_latent - 1)
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = (
styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
)
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
latent = torch.cat([latent1, latent2], 1)

# main generation
4 changes: 2 additions & 2 deletions hordelib/nodes/facerestore_cf/r_chainner/types.py
@@ -1,4 +1,3 @@

from typing import Union

from hordelib.nodes.facerestore_cf.r_chainner.gfpganv1_clean_arch import GFPGANv1Clean
@@ -11,7 +10,8 @@
def is_pytorch_face_model(model: object):
return isinstance(model, PyTorchFaceModels)

PyTorchModels = (*PyTorchFaceModels, )

PyTorchModels = (*PyTorchFaceModels,)
PyTorchModel = Union[PyTorchFaceModel]

