Commit 9524669

Merge pull request #255 from Haidra-Org/main

feat: stable cascade 2pass

tazlin authored May 13, 2024
2 parents 572e1b2 + f094131

Showing 26 changed files with 3,547 additions and 1,294 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/maintests.yml
@@ -18,6 +18,8 @@ jobs:
       - uses: actions/checkout@v3
       - name: Run pre-commit
         uses: pre-commit/[email protected]
+        with:
+          extra_args: --all-files
 
   build:
     runs-on: self-hosted
2 changes: 2 additions & 0 deletions .github/workflows/prtests.yml
@@ -21,6 +21,8 @@ jobs:
       - uses: actions/checkout@v3
       - name: Run pre-commit
        uses: pre-commit/[email protected]
+        with:
+          extra_args: --all-files
 
   build:
     runs-on: self-hosted
38 changes: 13 additions & 25 deletions hordelib/comfy_horde.py
@@ -458,7 +458,7 @@ def _fix_pipeline_types(self, data: dict) -> dict:
 
         return data
 
-    def _fix_node_names(self, data: dict, design: dict) -> dict:
+    def _fix_node_names(self, data: dict) -> dict:
         """Rename nodes to the "title" set in the design file.
 
         Args:
@@ -472,15 +472,12 @@ def _fix_node_names(self, data: dict, design: dict) -> dict:
         # in the design file. These must be unique names.
         newnodes = {}
         renames = {}
-        nodes = design["nodes"]
-        for nodename, oldnode in data.items():
+        for nodename, nodedata in data.items():
             newname = nodename
-            for node in nodes:
-                if str(node["id"]) == str(nodename) and "title" in node:
-                    newname = node["title"]
-                    break
+            if nodedata.get("_meta", {}).get("title"):
+                newname = nodedata["_meta"]["title"]
             renames[nodename] = newname
-            newnodes[newname] = oldnode
+            newnodes[newname] = nodedata
 
         # Now we've renamed the node names, change any references to them also
         for node in newnodes.values():
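
The rename above now relies on the _meta.title field that ComfyUI writes into API-format workflow exports, instead of a separate design file. As a minimal standalone sketch of the same idea (the node ids, titles, and input wiring below are invented for illustration and are not taken from hordelib's pipelines):

    # Sketch of the _meta.title-based rename; all node ids and titles are made up.
    pipeline = {
        "3": {"class_type": "KSampler", "_meta": {"title": "sampler_stage_c"}, "inputs": {"model": ["4", 0]}},
        "4": {"class_type": "CheckpointLoaderSimple", "_meta": {"title": "model_loader"}, "inputs": {}},
    }

    renames, renamed = {}, {}
    for node_id, node_data in pipeline.items():
        new_name = node_data.get("_meta", {}).get("title") or node_id
        renames[node_id] = new_name
        renamed[new_name] = node_data

    # References to the old ids must be rewritten as well, which is what the
    # loop over newnodes.values() in the real code goes on to do.
    for node_data in renamed.values():
        for value in node_data.get("inputs", {}).values():
            if isinstance(value, list) and value and str(value[0]) in renames:
                value[0] = renames[str(value[0])]

    assert renamed["sampler_stage_c"]["inputs"]["model"][0] == "model_loader"
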
@@ -504,13 +501,13 @@ def _fix_node_names(self, data: dict, design: dict) -> dict:
     #
     # Note also that the format of the design files from web app is expected to change at a fast
     # pace. This is why the only thing that partially relies on that format, is in fact, optional.
-    def _patch_pipeline(self, data: dict, design: dict) -> dict:
+    def _patch_pipeline(self, data: dict) -> dict:
         """Patch the pipeline data with the design data."""
         # FIXME: This can now be done through the _meta.title key included with each API export.
         # First replace comfyui standard types with hordelib node types
         data = self._fix_pipeline_types(data)
         # Now try to find better parameter names
-        return self._fix_node_names(data, design)
+        return self._fix_node_names(data)
 
     def _load_pipeline(self, filename: str) -> bool | None:
         """
@@ -540,17 +537,8 @@ def _load_pipeline(self, filename: str) -> bool | None:
                 # Load the pipeline data from the file
                 pipeline_data = json.loads(jsonfile.read())
                 # Check if there is a design file for this pipeline
-                pipeline_design = os.path.join(
-                    os.path.dirname(os.path.dirname(filename)),
-                    "pipeline_designs",
-                    os.path.basename(filename),
-                )
-                # If there is a design file, patch the pipeline data with it
-                if os.path.exists(pipeline_design):
-                    logger.debug(f"Patching pipeline {pipeline_name}")
-                    with open(pipeline_design) as design_file:
-                        design_data = json.loads(design_file.read())
-                    pipeline_data = self._patch_pipeline(pipeline_data, design_data)
+                logger.debug(f"Patching pipeline {pipeline_name}")
+                pipeline_data = self._patch_pipeline(pipeline_data)
                 # Add the pipeline data to the pipelines dictionary
                 self.pipelines[pipeline_name] = pipeline_data
                 logger.debug(f"Loaded inference pipeline: {pipeline_name}")
@@ -690,10 +678,10 @@ def _run_pipeline(
         # This is useful for dumping the entire pipeline to the terminal when
         # developing and debugging new pipelines. A badly structured pipeline
         # file just results in a cryptic error from comfy
-        # if False: # This isn't here, Tazlin :)
-        #     with open("pipeline_debug.json", "w") as outfile:
-        #         default = lambda o: f"<<non-serializable: {type(o).__qualname__}>>"
-        #         outfile.write(json.dumps(pipeline, indent=4, default=default))
+        # if True: # This isn't here, Tazlin :)
+        #     with open("pipeline_debug.json", "w") as outfile:
+        #         default = lambda o: f"<<non-serializable: {type(o).__qualname__}>>"
+        #         outfile.write(json.dumps(pipeline, indent=4, default=default))
         # pretty_pipeline = pformat(pipeline)
         # logger.warning(pretty_pipeline)
 
53 changes: 45 additions & 8 deletions hordelib/horde.py
@@ -239,6 +239,10 @@ class HordeLib:
         "model_loader_stage_b.ckpt_name": "stable_cascade_stage_b",
         "model_loader_stage_b.model_name": "stable_cascade_stage_b",
         "model_loader_stage_b.horde_model_name": "model_name",
+        # Stable Cascade 2pass
+        "2pass_sampler_stage_c.sampler_name": "sampler_name",
+        "2pass_sampler_stage_c.denoise": "hires_fix_denoising_strength",
+        "2pass_sampler_stage_b.sampler_name": "sampler_name",
     }
 
     _comfyui_callback: Callable[[str, dict, str], None] | None = None
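
The keys added above follow hordelib's "node_name.input_name" convention for mapping payload fields onto pipeline node inputs. Purely as an illustration of that convention (this is not hordelib's actual implementation, and apply_param_map is a hypothetical helper), such a mapping could be applied like this:

    # Hypothetical sketch: apply a "node.input" -> payload-key mapping to an
    # API-format pipeline dict. hordelib's real plumbing differs in its details.
    def apply_param_map(pipeline: dict, payload: dict, param_map: dict[str, str]) -> dict:
        for dotted_key, payload_key in param_map.items():
            node_name, input_name = dotted_key.split(".", 1)
            node = pipeline.get(node_name)
            if node is None or payload_key not in payload:
                continue  # pipeline lacks this node, or payload omits the value
            node.setdefault("inputs", {})[input_name] = payload[payload_key]
        return pipeline

    # e.g. payload["hires_fix_denoising_strength"] would end up in
    # pipeline["2pass_sampler_stage_c"]["inputs"]["denoise"].
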
@@ -771,14 +775,44 @@ def _final_pipeline_adjustments(self, payload, pipeline_data) -> tuple[dict, lis
 
         # For hires fix, change the image sizes as we create an intermediate image first
         if payload.get("hires_fix", False):
-            width = pipeline_params.get("empty_latent_image.width", 0)
-            height = pipeline_params.get("empty_latent_image.height", 0)
-            if width > 512 and height > 512:
-                newwidth, newheight = ImageUtils.calculate_source_image_size(width, height)
-                pipeline_params["latent_upscale.width"] = width
-                pipeline_params["latent_upscale.height"] = height
-                pipeline_params["empty_latent_image.width"] = newwidth
-                pipeline_params["empty_latent_image.height"] = newheight
+            model_details = (
+                SharedModelManager.manager.compvis.get_model_reference_info(payload["model_name"])
+                if SharedModelManager.manager.compvis
+                else None
+            )
+
+            original_width = pipeline_params.get("empty_latent_image.width")
+            original_height = pipeline_params.get("empty_latent_image.height")
+
+            if original_width is None or original_height is None:
+                logger.error("empty_latent_image.width or empty_latent_image.height not found. Using 512x512.")
+                original_width, original_height = (512, 512)
+
+            new_width, new_height = (None, None)
+
+            if model_details and model_details.get("baseline") == "stable_cascade":
+                new_width, new_height = ImageUtils.get_first_pass_image_resolution_max(
+                    original_width,
+                    original_height,
+                )
+            else:
+                new_width, new_height = ImageUtils.get_first_pass_image_resolution_min(
+                    original_width,
+                    original_height,
+                )
+
+            # This is the *target* resolution
+            pipeline_params["latent_upscale.width"] = original_width
+            pipeline_params["latent_upscale.height"] = original_height
+
+            if new_width and new_height:
+                # This is the *first pass* resolution
+                pipeline_params["empty_latent_image.width"] = new_width
+                pipeline_params["empty_latent_image.height"] = new_height
+            else:
+                logger.error("Could not determine new image size for hires fix. Using 1024x1024.")
+                pipeline_params["empty_latent_image.width"] = 1024
+                pipeline_params["empty_latent_image.height"] = 1024
 
         if payload.get("control_type"):
             # Inject control net model manager
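
The hires-fix sizing above defers to ImageUtils.get_first_pass_image_resolution_min and ImageUtils.get_first_pass_image_resolution_max, whose implementations are not part of this diff. As a hedged sketch of the general two-pass idea only (assuming the first pass should land near a native base resolution while preserving aspect ratio and 64-pixel alignment; the base value and rounding rules below are assumptions, not ImageUtils behaviour):

    # Illustrative only; not the actual ImageUtils logic.
    def first_pass_resolution(width: int, height: int, base: int = 1024) -> tuple[int, int]:
        scale = base / max(width, height)  # fit the longer side to the base size
        if scale >= 1.0:
            return width, height  # already at or below the base size; no shrink needed

        def snap(value: int) -> int:
            return max(64, int(value * scale) // 64 * 64)  # keep 64-pixel alignment

        return snap(width), snap(height)

    # The requested size still goes to latent_upscale.*; the value computed here
    # would go to empty_latent_image.*, mirroring the parameter names above.
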
@@ -886,13 +920,16 @@ def _get_appropriate_pipeline(self, params):
         # image_upscale
         # stable_cascade
         # stable_cascade_remix
+        # stable_cascade_2pass
 
         # controlnet, controlnet_hires_fix controlnet_annotator
         if params.get("model_name"):
             model_details = SharedModelManager.manager.compvis.get_model_reference_info(params["model_name"])
             if model_details.get("baseline") == "stable_cascade":
                 if params.get("source_processing") == "remix":
                     return "stable_cascade_remix"
+                if params.get("hires_fix", False):
+                    return "stable_cascade_2pass"
                 return "stable_cascade"
         if params.get("control_type"):
             if params.get("return_control_map", False):
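
Condensed for reference, the stable-cascade branch of the selection above behaves roughly as follows (a sketch of that branch only, with model_details assumed to come from the compvis model reference as in the diff):

    # Sketch of the stable-cascade pipeline selection only; not the full method.
    def pick_cascade_pipeline(params: dict, model_details: dict) -> str | None:
        if model_details.get("baseline") != "stable_cascade":
            return None  # fall through to the non-cascade branches
        if params.get("source_processing") == "remix":
            return "stable_cascade_remix"
        if params.get("hires_fix", False):
            return "stable_cascade_2pass"
        return "stable_cascade"
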
2 changes: 1 addition & 1 deletion hordelib/nodes/node_lora_loader.py
@@ -1,7 +1,7 @@
 import os
 
 import comfy.utils
-import folder_paths
+import folder_paths  # type: ignore
 from loguru import logger
 
 
The diffs for the remaining 21 changed files are not shown here.
