feat: Automatic downloading of TIs (#55)
* Working TI download and injection

* tests: quiet potentially unset error

* fix: TI/Lora model reference json is now stored with the other reference files

* docs: moved class inline comments to docstrings

* style: lint violation (unsorted import)

* fix: variable name typo (missing an intended `s`)

* working TI injects

* fix: downloaded TIs not checking the correct hash from hordeling

* fix: embeddings now directly notify comfy

---------

Co-authored-by: tazlin <[email protected]>
db0 and tazlin authored Aug 25, 2023
1 parent 69fc756 commit 304c793
Showing 12 changed files with 1,071 additions and 14 deletions.

hordelib/comfy_horde.py (9 additions, 2 deletions)
@@ -325,11 +325,18 @@ def __init__(self) -> None:
self._callers = 0
self._gc_timer = time.time()
self._counter_mutex = threading.Lock()
# FIXME Temporary hack to set the model dir for LORAs
_comfy_folder_paths["loras"] = ( # type: ignore

# These set the default paths for comfyui to look for models and embeddings. From within hordelib,
# we aren't ever going to use the default ones, and this may help lubricate any further changes.
_comfy_folder_paths["lora"] = ( # type: ignore
[os.path.join(UserSettings.get_model_directory(), "loras")],
[".safetensors"],
)
_comfy_folder_paths["embeddings"] = ( # type: ignore
[os.path.join(UserSettings.get_model_directory(), "ti")],
[".safetensors"],
)

# Set custom node path
_comfy_folder_paths["custom_nodes"] = ([os.path.join(get_hordelib_path(), "nodes")], []) # type: ignore
# Load our pipelines
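
For context, ComfyUI resolves model files through a registry that maps a folder type to a pair of (search paths, allowed extensions); the hunk above repoints the `lora` entry and adds an `embeddings` entry so both resolve inside hordelib's configured model directory. Below is a minimal sketch of that registry pattern, not hordelib's actual code; the directory and environment-variable names are invented for illustration.

import os

# Registry in the same shape ComfyUI's folder paths use:
# folder type -> (list of search paths, list of allowed file extensions).
folder_paths: dict[str, tuple[list[str], list[str]]] = {}

model_dir = os.environ.get("MODEL_DIR", "./models")  # hypothetical stand-in for UserSettings.get_model_directory()

folder_paths["lora"] = ([os.path.join(model_dir, "loras")], [".safetensors"])
folder_paths["embeddings"] = ([os.path.join(model_dir, "ti")], [".safetensors"])

def find_model_file(kind: str, name: str) -> str | None:
    """Return the first existing file named `name` under the registered paths for `kind`."""
    paths, extensions = folder_paths[kind]
    for base in paths:
        for ext in extensions:
            candidate = os.path.join(base, name + ext)
            if os.path.exists(candidate):
                return candidate
    return None

Keying every model type off a single base directory is also what lets the compvis manager further down derive its embeddings path instead of reading a separate environment variable.
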
hordelib/consts.py (4 additions, 0 deletions)
@@ -41,6 +41,7 @@ class MODEL_CATEGORY_NAMES(StrEnum):
gfpgan = auto()
safety_checker = auto()
lora = auto()
ti = auto()
blip = auto()
clip = auto()

@@ -55,6 +56,7 @@ class MODEL_CATEGORY_NAMES(StrEnum):
MODEL_CATEGORY_NAMES.gfpgan: True,
MODEL_CATEGORY_NAMES.safety_checker: True,
MODEL_CATEGORY_NAMES.lora: True,
MODEL_CATEGORY_NAMES.ti: True,
MODEL_CATEGORY_NAMES.blip: True,
MODEL_CATEGORY_NAMES.clip: True,
}
@@ -69,6 +71,7 @@ class MODEL_CATEGORY_NAMES(StrEnum):
MODEL_CATEGORY_NAMES.gfpgan: MODEL_CATEGORY_NAMES.gfpgan,
MODEL_CATEGORY_NAMES.safety_checker: MODEL_CATEGORY_NAMES.safety_checker,
MODEL_CATEGORY_NAMES.lora: MODEL_CATEGORY_NAMES.lora,
MODEL_CATEGORY_NAMES.ti: MODEL_CATEGORY_NAMES.ti,
MODEL_CATEGORY_NAMES.blip: MODEL_CATEGORY_NAMES.blip,
MODEL_CATEGORY_NAMES.clip: MODEL_CATEGORY_NAMES.clip,
}
@@ -83,6 +86,7 @@ class MODEL_CATEGORY_NAMES(StrEnum):
MODEL_CATEGORY_NAMES.gfpgan: MODEL_CATEGORY_NAMES.gfpgan,
MODEL_CATEGORY_NAMES.safety_checker: MODEL_CATEGORY_NAMES.safety_checker,
MODEL_CATEGORY_NAMES.lora: MODEL_CATEGORY_NAMES.lora,
MODEL_CATEGORY_NAMES.ti: MODEL_CATEGORY_NAMES.ti,
MODEL_CATEGORY_NAMES.blip: MODEL_CATEGORY_NAMES.blip,
MODEL_CATEGORY_NAMES.clip: MODEL_CATEGORY_NAMES.clip,
}
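
Since `MODEL_CATEGORY_NAMES` is a `StrEnum` populated with `auto()`, the new `ti` member is interchangeable with the plain string "ti", which is why it can be dropped straight into the dictionaries above as both key and value. A standalone illustration of that behaviour (not hordelib's real enum):

from enum import StrEnum, auto  # stdlib StrEnum needs Python 3.11+; older setups use a backport

class ModelCategory(StrEnum):
    lora = auto()
    ti = auto()

# auto() on a StrEnum yields the lower-cased member name,
# so members hash and compare like ordinary strings.
assert ModelCategory.ti == "ti"
enabled = {ModelCategory.lora: True, ModelCategory.ti: True}
assert enabled["ti"] is True
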
hordelib/horde.py (32 additions, 1 deletion)
@@ -87,6 +87,7 @@ class HordeLib:
"prompt": {"datatype": str, "default": ""},
"negative_prompt": {"datatype": str, "default": ""},
"loras": {"datatype": list, "default": []},
"tis": {"datatype": list, "default": []},
"ddim_steps": {"datatype": int, "min": 1, "max": 500, "default": 30},
"n_iter": {"datatype": int, "min": 1, "max": 100, "default": 1},
"model": {"datatype": str, "default": "stable_diffusion"},
@@ -286,7 +287,37 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
# Remap controlnet models, horde to comfy
if payload.get("control_type"):
payload["control_type"] = HordeLib.CONTROLNET_IMAGE_PREPROCESSOR_MAP.get(payload.get("control_type"))

if payload.get("tis") and SharedModelManager.manager.ti:
# Remove any requested TIs that we don't have
for ti in payload.get("tis"):
# Determine the actual ti filename
if not SharedModelManager.manager.ti.is_local_model(str(ti["name"])):
adhoc_ti = SharedModelManager.manager.ti.fetch_adhoc_ti(str(ti["name"]))
if not adhoc_ti:
logger.info(f"Adhoc TI requested '{ti['name']}' could not be found in CivitAI. Ignoring!")
continue
ti_name = SharedModelManager.manager.ti.get_ti_name(str(ti["name"]))
if ti_name:
logger.debug(f"Found valid TI {ti_name}")
if SharedModelManager.manager.compvis is None:
raise RuntimeError("Cannot use TIs without compvis loaded!")
model_details = SharedModelManager.manager.compvis.get_model(payload["model"])
# If the ti and model do not match baseline, we ignore the TI
if not SharedModelManager.manager.ti.do_baselines_match(ti_name, model_details):
logger.info(f"Skipped TI {ti_name} because its baseline does not match the model's")
continue
ti_inject = ti.get("inject_ti")
ti_strength = ti.get("strength", 1.0)
if type(ti_strength) not in [float, int]:
ti_strength = 1.0
ti_id = SharedModelManager.manager.ti.get_ti_id(str(ti["name"]))
if ti_inject == "prompt":
payload["prompt"] = f'(embedding:{ti_id}:{ti_strength}),{payload["prompt"]}'
elif ti_inject == "negprompt":
if "###" not in payload["prompt"]:
payload["prompt"] += "###"
payload["prompt"] = f'{payload["prompt"]},(embedding:{ti_id}:{ti_strength})'
SharedModelManager.manager.ti.touch_ti(ti_name)
# Setup controlnet if required
# For LORAs we completely build the LORA section of the pipeline dynamically, as we have
# to handle n LORA models which form chained nodes in the pipeline.
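
The effect of the injection branch above is easiest to see on a toy payload. The snippet below mirrors the prompt-rewriting logic rather than calling hordelib; the TI ids, strengths, and prompt text are invented, and in the real code the id comes from `get_ti_id()` after the baseline check passes.

payload = {
    "model": "Deliberate",  # hypothetical model name
    "prompt": "a portrait photo ### blurry",  # "###" separates the prompt from the negative prompt
    "tis": [
        {"name": "7808", "inject_ti": "prompt", "strength": 0.8},      # hypothetical TI
        {"name": "71961", "inject_ti": "negprompt", "strength": 1.0},  # hypothetical TI
    ],
}

for ti in payload["tis"]:
    ti_id = ti["name"]  # hordelib resolves this through the TI model manager instead
    ti_strength = ti.get("strength", 1.0)
    if type(ti_strength) not in [float, int]:
        ti_strength = 1.0
    if ti.get("inject_ti") == "prompt":
        payload["prompt"] = f'(embedding:{ti_id}:{ti_strength}),{payload["prompt"]}'
    elif ti.get("inject_ti") == "negprompt":
        if "###" not in payload["prompt"]:
            payload["prompt"] += "###"
        payload["prompt"] = f'{payload["prompt"]},(embedding:{ti_id}:{ti_strength})'

print(payload["prompt"])
# (embedding:7808:0.8),a portrait photo ### blurry,(embedding:71961:1.0)

Prompt-side embeddings are prepended, negative-prompt embeddings are appended after the `###` separator, so that ComfyUI can pick the files up from the `embeddings` folder registered in comfy_horde.py above.
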
hordelib/model_manager/base.py (9 additions, 5 deletions)
@@ -75,6 +75,7 @@ def __init__(
*,
download_reference: bool = False,
model_category_name: MODEL_CATEGORY_NAMES = MODEL_CATEGORY_NAMES.default_models,
models_db_path: Path | None = None,
):
"""Create a new instance of this model manager.
@@ -94,10 +95,14 @@ def __init__(
self.tainted_models = []
self.pkg = importlib_resources.files("hordelib") # XXX Remove
self.models_db_name = MODEL_DB_NAMES[model_category_name]
self.models_db_path = Path(get_hordelib_path()).joinpath(
"model_database/",
f"{self.models_db_name}.json",
)

if models_db_path:
self.models_db_path = models_db_path
else:
self.models_db_path = Path(get_hordelib_path()).joinpath(
"model_database/",
f"{self.models_db_name}.json",
)

self.cuda_available = torch.cuda.is_available() # XXX Remove?
self.cuda_devices, self.recommended_gpu = self.get_cuda_devices() # XXX Remove?
@@ -700,7 +705,6 @@ def download_file(self, url: str, filename: str) -> bool:
remote_file_size = 0

while retries:

if os.path.exists(partial_pathname):
# If file exists, find the size and append to it
partial_size = os.path.getsize(partial_pathname)
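
The new optional `models_db_path` argument keeps the old behaviour as the fallback: when no path is passed, the manager still loads the JSON reference bundled under hordelib's `model_database/` directory. A hedged sketch of the same "explicit path wins, otherwise bundled default" pattern, with an invented class and paths:

from pathlib import Path

class ReferenceStore:
    """Illustrative only: an explicit reference path wins, otherwise fall back to a bundled default."""

    def __init__(self, db_name: str, models_db_path: Path | None = None):
        if models_db_path:
            self.models_db_path = models_db_path
        else:
            self.models_db_path = Path(__file__).parent / "model_database" / f"{db_name}.json"

default_store = ReferenceStore("stable_diffusion")                # bundled reference
lora_store = ReferenceStore("lora", Path("/tmp/refs/lora.json"))  # externally managed reference

The LoRA manager (below) and the new TI manager use the override to point at reference files kept alongside the other downloaded references rather than inside the package.
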
hordelib/model_manager/compvis.py (4 additions, 2 deletions)
@@ -36,15 +36,17 @@ def modelToRam(
model_name: str,
**kwargs,
) -> dict[str, typing.Any]:
embeddings_path = os.getenv("HORDE_MODEL_DIR_EMBEDDINGS", "./")
embeddings_path = os.path.join(UserSettings.get_model_directory(), "ti")
if not embeddings_path:
logger.debug("No embeddings path found, disabling embeddings")

if not kwargs.get("local", False):
ckpt_path = self.getFullModelPath(model_name)
else:
ckpt_path = os.path.join(self.modelFolderPath, model_name)
return horde_load_checkpoint(
ckpt_path=ckpt_path,
embeddings_path=embeddings_path,
embeddings_path=embeddings_path if embeddings_path else None,
)

def can_cache_on_disk(self):
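
Previously the embeddings directory came from a dedicated `HORDE_MODEL_DIR_EMBEDDINGS` environment variable defaulting to `./`; it is now always derived from the single configured model directory, so checkpoints and TIs live under the same root. A trivial sketch with an invented directory:

import os

model_dir = "/workspace/models"                  # stands in for UserSettings.get_model_directory()
embeddings_path = os.path.join(model_dir, "ti")  # matches the folder registered for "embeddings" in comfy_horde.py
print(embeddings_path)                           # /workspace/models/ti
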
hordelib/model_manager/hyper.py (8 additions, 0 deletions)
@@ -20,6 +20,7 @@
from hordelib.model_manager.gfpgan import GfpganModelManager
from hordelib.model_manager.lora import LoraModelManager
from hordelib.model_manager.safety_checker import SafetyCheckerModelManager
from hordelib.model_manager.ti import TextualInversionModelManager
from hordelib.settings import UserSettings

MODEL_MANAGERS_TYPE_LOOKUP: dict[MODEL_CATEGORY_NAMES | str, type[BaseModelManager]] = {
@@ -31,6 +32,7 @@
MODEL_CATEGORY_NAMES.gfpgan: GfpganModelManager,
MODEL_CATEGORY_NAMES.safety_checker: SafetyCheckerModelManager,
MODEL_CATEGORY_NAMES.lora: LoraModelManager,
MODEL_CATEGORY_NAMES.ti: TextualInversionModelManager,
MODEL_CATEGORY_NAMES.blip: BlipModelManager,
MODEL_CATEGORY_NAMES.clip: ClipModelManager,
}
@@ -85,6 +87,12 @@ def lora(self) -> LoraModelManager | None:
found_mm = self.get_mm_pointer(LoraModelManager)
return found_mm if isinstance(found_mm, LoraModelManager) else None

@property
def ti(self) -> TextualInversionModelManager | None:
"""The Textual Inversion model manager instance. Returns `None` if not loaded."""
found_mm = self.get_mm_pointer(TextualInversionModelManager)
return found_mm if isinstance(found_mm, TextualInversionModelManager) else None

@property
def blip(self) -> BlipModelManager | None:
"""The Blip model manager instance. Returns `None` if not loaded."""
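
The new property returns `None` whenever the TI manager was not loaded, so callers guard before touching it, exactly as `_final_pipeline_adjustments` does with `SharedModelManager.manager.ti` above. A usage sketch under that assumption; the import path and helper function are illustrative, and the manager methods are the ones appearing in the horde.py hunk:

from hordelib.shared_model_manager import SharedModelManager  # import path assumed for this sketch

def ensure_ti_available(ti_name: str) -> bool:
    """Return True if the named TI can be used, downloading it on demand if needed."""
    ti_manager = SharedModelManager.manager.ti
    if ti_manager is None:
        return False  # TI model manager not loaded; TIs are simply skipped
    if not ti_manager.is_local_model(ti_name) and not ti_manager.fetch_adhoc_ti(ti_name):
        return False  # not a local file and not found on CivitAI
    return ti_manager.get_ti_name(ti_name) is not None

# ensure_ti_available("7808")  # hypothetical TI id
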
hordelib/model_manager/lora.py (10 additions, 4 deletions)
@@ -38,10 +38,14 @@ class LoraModelManager(BaseModelManager):
LORA_API = "https://civitai.com/api/v1/models?types=LORA&sort=Highest%20Rated&primaryFileOnly=true"
MAX_RETRIES = 10 if not TESTS_ONGOING else 1
MAX_DOWNLOAD_THREADS = 3 # max concurrent downloads
RETRY_DELAY = 5 # seconds
REQUEST_METADATA_TIMEOUT = 30 # seconds
REQUEST_DOWNLOAD_TIMEOUT = 300 # seconds
RETRY_DELAY = 5
"""The time to wait between retries in seconds"""
REQUEST_METADATA_TIMEOUT = 30
"""The time to wait for a response from the server in seconds"""
REQUEST_DOWNLOAD_TIMEOUT = 300
"""The time to wait for a response from the server in seconds"""
THREAD_WAIT_TIME = 2
"""The time to wait between checking the download queue in seconds"""

def __init__(
self,
@@ -72,12 +76,14 @@ def __init__(
self._index_ids = {}
self._index_orig_names = {}

models_db_path = LEGACY_REFERENCE_FOLDER.joinpath("lora.json").resolve()

super().__init__(
model_category_name=MODEL_CATEGORY_NAMES.lora,
download_reference=download_reference,
models_db_path=models_db_path,
)
# FIXME (shift lora.json handling into horde_model_reference?)
self.models_db_path = LEGACY_REFERENCE_FOLDER.joinpath("lora.json").resolve()

def loadModelDatabase(self, list_models=False):
if self.model_reference:
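
The bare string literals placed directly under `RETRY_DELAY` and friends are attribute docstrings: they change nothing at runtime, but documentation generators and IDE hovers surface them where an end-of-line comment would be invisible. A minimal example of the convention:

class DownloadSettings:
    """Illustrative container for download tuning constants."""

    RETRY_DELAY = 5
    """Seconds to wait between download retries."""

    REQUEST_METADATA_TIMEOUT = 30
    """Seconds to wait for a metadata response before giving up."""

# The docstrings are documentation-only; the attributes behave as plain ints.
assert DownloadSettings.RETRY_DELAY + DownloadSettings.REQUEST_METADATA_TIMEOUT == 35
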
(Diffs for the remaining changed files were not loaded in this view.)
