Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allows multiple lora versions (#127) #132

Merged
merged 1 commit into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 37 additions & 12 deletions hordelib/horde.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ class HordeLib:
"model": {"datatype": float, "min": -10.0, "max": 10.0, "default": 1.0},
"clip": {"datatype": float, "min": -10.0, "max": 10.0, "default": 1.0},
"inject_trigger": {"datatype": str},
"is_version": {"datatype": bool},
}

TIS_SCHEMA = {
Expand Down Expand Up @@ -459,10 +460,17 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
valid_loras = []
for lora in payload.get("loras"):
# Determine the actual lora filename
if not SharedModelManager.manager.lora.is_model_available(str(lora["name"])):
is_version: bool = lora.get("is_version", False)
verstext = ""
if is_version:
verstext = " version"
if not SharedModelManager.manager.lora.is_lora_available(str(lora["name"]), is_version=is_version):
logger.debug(f"Adhoc lora requested '{lora['name']}' not yet downloaded. Downloading...")
try:
adhoc_lora = SharedModelManager.manager.lora.fetch_adhoc_lora(str(lora["name"]))
adhoc_lora = SharedModelManager.manager.lora.fetch_adhoc_lora(
str(lora["name"]),
is_version=is_version,
)
except Exception as e:
logger.info(f"Error fetching adhoc lora {lora['name']}: ({type(e).__name__}) {e}")
faults.append(
Expand All @@ -474,7 +482,10 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
)
adhoc_lora = None
if not adhoc_lora:
logger.info(f"Adhoc lora requested '{lora['name']}' could not be found in CivitAI. Ignoring!")
                logger.info(
                    f"Adhoc lora requested{verstext} '{lora['name']}' "
                    "could not be found in CivitAI. Ignoring!",
                )
faults.append(
GenMetadataEntry(
type=METADATA_TYPE.lora,
Expand All @@ -484,15 +495,25 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
)
continue
# We store the actual lora name to search for the trigger
lora_name = SharedModelManager.manager.lora.get_lora_name(str(lora["name"]))
# If a version is requested, the lora name we need is the exact version
if is_version:
lora_name = str(lora["name"])
else:
lora_name = SharedModelManager.manager.lora.get_lora_name(str(lora["name"]))
if lora_name:
logger.debug(f"Found valid lora {lora_name}")
logger.debug(f"Found valid lora{verstext} {lora_name}")
if SharedModelManager.manager.compvis is None:
raise RuntimeError("Cannot use LORAs without a compvis loaded!")
model_details = SharedModelManager.manager.compvis.get_model_reference_info(payload["model"])
# If the lora and model do not match baseline, we ignore the lora
if not SharedModelManager.manager.lora.do_baselines_match(lora_name, model_details):
logger.info(f"Skipped lora {lora_name} because its baseline does not match the model's")
if not SharedModelManager.manager.lora.do_baselines_match(
lora_name,
model_details,
is_version=is_version,
):
logger.info(
f"Skipped lora{verstext} {lora_name} because its baseline does not match the model's",
)
faults.append(
GenMetadataEntry(
type=METADATA_TYPE.lora,
Expand All @@ -504,21 +525,25 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
trigger_inject = lora.get("inject_trigger")
trigger = None
if trigger_inject == "any":
triggers = SharedModelManager.manager.lora.get_lora_triggers(lora_name)
triggers = SharedModelManager.manager.lora.get_lora_triggers(lora_name, is_version=is_version)
if triggers:
trigger = random.choice(triggers)
elif trigger_inject == "all":
triggers = SharedModelManager.manager.lora.get_lora_triggers(lora_name)
triggers = SharedModelManager.manager.lora.get_lora_triggers(lora_name, is_version=is_version)
if triggers:
trigger = ", ".join(triggers)
elif trigger_inject is not None:
trigger = SharedModelManager.manager.lora.find_lora_trigger(lora_name, trigger_inject)
trigger = SharedModelManager.manager.lora.find_lora_trigger(
lora_name,
trigger_inject,
is_version,
)
if trigger:
# We inject at the start, to avoid throwing it in a negative prompt
payload["prompt"] = f'{trigger}, {payload["prompt"]}'
            # the fixed up and validated filename (Comfy expects the "name" key to be the filename)
lora["name"] = SharedModelManager.manager.lora.get_lora_filename(lora_name)
SharedModelManager.manager.lora._touch_lora(lora_name)
lora["name"] = SharedModelManager.manager.lora.get_lora_filename(lora_name, is_version=is_version)
SharedModelManager.manager.lora._touch_lora(lora_name, is_version=is_version)
valid_loras.append(lora)
payload["loras"] = valid_loras
for lora_index, lora in enumerate(payload.get("loras")):
Expand Down
4 changes: 4 additions & 0 deletions hordelib/model_manager/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(
download_reference: bool = False,
model_category_name: MODEL_CATEGORY_NAMES = MODEL_CATEGORY_NAMES.default_models,
models_db_path: Path | None = None,
**kwargs,
):
"""Create a new instance of this model manager.

Expand All @@ -62,6 +63,9 @@ def __init__(
download_reference (bool, optional): Get the model DB from github. Defaults to False.
"""

if len(kwargs) > 0:
logger.debug(f"Unused kwargs in {type(self)}: {kwargs}")

self.model_folder_path = UserSettings.get_model_directory() / f"{MODEL_FOLDER_NAMES[model_category_name]}"

os.makedirs(self.model_folder_path, exist_ok=True)
Expand Down
3 changes: 2 additions & 1 deletion hordelib/model_manager/codeformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@


class CodeFormerModelManager(BaseModelManager):
def __init__(self, download_reference=False):
def __init__(self, download_reference=False, **kwargs):
super().__init__(
model_category_name=MODEL_CATEGORY_NAMES.codeformer,
download_reference=download_reference,
**kwargs,
)
2 changes: 2 additions & 0 deletions hordelib/model_manager/compvis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ class CompVisModelManager(BaseModelManager):
def __init__(
self,
download_reference=False,
**kwargs,
# custom_path="models/custom", # XXX Remove this and any others like it?
):
super().__init__(
model_category_name=MODEL_CATEGORY_NAMES.compvis,
download_reference=download_reference,
**kwargs,
)
7 changes: 6 additions & 1 deletion hordelib/model_manager/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,15 @@ class CONTROLNET_BASELINE_NAMES(StrEnum): # XXX # TODO Move this to consts.py


class ControlNetModelManager(BaseModelManager):
def __init__(self, download_reference=False):
def __init__(
self,
download_reference=False,
**kwargs,
):
super().__init__(
model_category_name=MODEL_CATEGORY_NAMES.controlnet,
download_reference=download_reference,
**kwargs,
)

def download_control_type(
Expand Down
7 changes: 6 additions & 1 deletion hordelib/model_manager/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@


class DiffusersModelManager(BaseModelManager):
def __init__(self, download_reference=False):
def __init__(
self,
download_reference=False,
**kwargs,
):
raise NotImplementedError("Diffusers are not yet supported")

super().__init__(
model_category_name=MODEL_CATEGORY_NAMES.diffusers,
download_reference=download_reference,
**kwargs,
)
7 changes: 6 additions & 1 deletion hordelib/model_manager/esrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@


class EsrganModelManager(BaseModelManager):
def __init__(self, download_reference=False):
def __init__(
self,
download_reference=False,
**kwargs,
):
super().__init__(
model_category_name=MODEL_CATEGORY_NAMES.esrgan,
download_reference=download_reference,
**kwargs,
)
7 changes: 6 additions & 1 deletion hordelib/model_manager/gfpgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@


class GfpganModelManager(BaseModelManager):
def __init__(self, download_reference=False):
def __init__(
self,
download_reference=False,
**kwargs,
):
super().__init__(
model_category_name=MODEL_CATEGORY_NAMES.gfpgan,
download_reference=download_reference,
**kwargs,
)
8 changes: 7 additions & 1 deletion hordelib/model_manager/hyper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Home for the controller class ModelManager, and related meta information."""
import threading
from collections.abc import Callable, Iterable
from multiprocessing.synchronize import Lock as multiprocessing_lock

import torch
from loguru import logger
Expand Down Expand Up @@ -155,6 +156,7 @@ def __init__(
def init_model_managers(
self,
managers_to_load: Iterable[str | MODEL_CATEGORY_NAMES | type[BaseModelManager]],
multiprocessing_lock: multiprocessing_lock | None = None,
) -> None:
for manager_to_load in managers_to_load:
resolve_manager_to_load_type: type[BaseModelManager] | None = None
Expand All @@ -174,7 +176,11 @@ def init_model_managers(
f"Attempted to load a model manager which is already loaded: '{resolve_manager_to_load_type}'.",
)
continue
self.active_model_managers.append(resolve_manager_to_load_type())
self.active_model_managers.append(
resolve_manager_to_load_type(
multiprocessing_lock=multiprocessing_lock,
),
)

def unload_model_managers(
self,
Expand Down
Loading