Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Automatic downloading of TIs #55

Merged
merged 17 commits into from
Aug 25, 2023
Merged
11 changes: 9 additions & 2 deletions hordelib/comfy_horde.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,18 @@ def __init__(self) -> None:
self._callers = 0
self._gc_timer = time.time()
self._counter_mutex = threading.Lock()
# FIXME Temporary hack to set the model dir for LORAs
_comfy_folder_paths["loras"] = ( # type: ignore

# These set the default paths for comfyui to look for models and embeddings. From within hordelib,
# we aren't ever going to use the default ones, and this may help lubricate any further changes.
_comfy_folder_paths["lora"] = ( # type: ignore
[os.path.join(UserSettings.get_model_directory(), "loras")],
[".safetensors"],
)
_comfy_folder_paths["embeddings"] = ( # type: ignore
[os.path.join(UserSettings.get_model_directory(), "ti")],
[".safetensors"],
)

# Set custom node path
_comfy_folder_paths["custom_nodes"] = ([os.path.join(get_hordelib_path(), "nodes")], []) # type: ignore
# Load our pipelines
Expand Down
4 changes: 4 additions & 0 deletions hordelib/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class MODEL_CATEGORY_NAMES(StrEnum):
gfpgan = auto()
safety_checker = auto()
lora = auto()
ti = auto()
blip = auto()
clip = auto()

Expand All @@ -55,6 +56,7 @@ class MODEL_CATEGORY_NAMES(StrEnum):
MODEL_CATEGORY_NAMES.gfpgan: True,
MODEL_CATEGORY_NAMES.safety_checker: True,
MODEL_CATEGORY_NAMES.lora: True,
MODEL_CATEGORY_NAMES.ti: True,
MODEL_CATEGORY_NAMES.blip: True,
MODEL_CATEGORY_NAMES.clip: True,
}
Expand All @@ -69,6 +71,7 @@ class MODEL_CATEGORY_NAMES(StrEnum):
MODEL_CATEGORY_NAMES.gfpgan: MODEL_CATEGORY_NAMES.gfpgan,
MODEL_CATEGORY_NAMES.safety_checker: MODEL_CATEGORY_NAMES.safety_checker,
MODEL_CATEGORY_NAMES.lora: MODEL_CATEGORY_NAMES.lora,
MODEL_CATEGORY_NAMES.ti: MODEL_CATEGORY_NAMES.ti,
MODEL_CATEGORY_NAMES.blip: MODEL_CATEGORY_NAMES.blip,
MODEL_CATEGORY_NAMES.clip: MODEL_CATEGORY_NAMES.clip,
}
Expand All @@ -83,6 +86,7 @@ class MODEL_CATEGORY_NAMES(StrEnum):
MODEL_CATEGORY_NAMES.gfpgan: MODEL_CATEGORY_NAMES.gfpgan,
MODEL_CATEGORY_NAMES.safety_checker: MODEL_CATEGORY_NAMES.safety_checker,
MODEL_CATEGORY_NAMES.lora: MODEL_CATEGORY_NAMES.lora,
MODEL_CATEGORY_NAMES.ti: MODEL_CATEGORY_NAMES.ti,
MODEL_CATEGORY_NAMES.blip: MODEL_CATEGORY_NAMES.blip,
MODEL_CATEGORY_NAMES.clip: MODEL_CATEGORY_NAMES.clip,
}
Expand Down
33 changes: 32 additions & 1 deletion hordelib/horde.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class HordeLib:
"prompt": {"datatype": str, "default": ""},
"negative_prompt": {"datatype": str, "default": ""},
"loras": {"datatype": list, "default": []},
"tis": {"datatype": list, "default": []},
"ddim_steps": {"datatype": int, "min": 1, "max": 500, "default": 30},
"n_iter": {"datatype": int, "min": 1, "max": 100, "default": 1},
"model": {"datatype": str, "default": "stable_diffusion"},
Expand Down Expand Up @@ -286,7 +287,37 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
# Remap controlnet models, horde to comfy
if payload.get("control_type"):
payload["control_type"] = HordeLib.CONTROLNET_IMAGE_PREPROCESSOR_MAP.get(payload.get("control_type"))

if payload.get("tis") and SharedModelManager.manager.ti:
# Remove any requested TIs that we don't have
for ti in payload.get("tis"):
# Determine the actual ti filename
if not SharedModelManager.manager.ti.is_local_model(str(ti["name"])):
adhoc_ti = SharedModelManager.manager.ti.fetch_adhoc_ti(str(ti["name"]))
if not adhoc_ti:
logger.info(f"Adhoc TI requested '{ti['name']}' could not be found in CivitAI. Ignoring!")
continue
ti_name = SharedModelManager.manager.ti.get_ti_name(str(ti["name"]))
if ti_name:
logger.debug(f"Found valid TI {ti_name}")
if SharedModelManager.manager.compvis is None:
raise RuntimeError("Cannot use TIs without compvis loaded!")
model_details = SharedModelManager.manager.compvis.get_model(payload["model"])
# If the ti and model do not match baseline, we ignore the TI
if not SharedModelManager.manager.ti.do_baselines_match(ti_name, model_details):
logger.info(f"Skipped TI {ti_name} because its baseline does not match the model's")
continue
ti_inject = ti.get("inject_ti")
ti_strength = ti.get("strength", 1.0)
if type(ti_strength) not in [float, int]:
ti_strength = 1.0
ti_id = SharedModelManager.manager.ti.get_ti_id(str(ti["name"]))
if ti_inject == "prompt":
payload["prompt"] = f'(embedding:{ti_id}:{ti_strength}),{payload["prompt"]}'
elif ti_inject == "negprompt":
if "###" not in payload["prompt"]:
payload["prompt"] += "###"
payload["prompt"] = f'{payload["prompt"]},(embedding:{ti_id}:{ti_strength})'
SharedModelManager.manager.ti.touch_ti(ti_name)
# Setup controlnet if required
# For LORAs we completely build the LORA section of the pipeline dynamically, as we have
# to handle n LORA models which form chained nodes in the pipeline.
Expand Down
14 changes: 9 additions & 5 deletions hordelib/model_manager/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
*,
download_reference: bool = False,
model_category_name: MODEL_CATEGORY_NAMES = MODEL_CATEGORY_NAMES.default_models,
models_db_path: Path | None = None,
):
"""Create a new instance of this model manager.

Expand All @@ -94,10 +95,14 @@ def __init__(
self.tainted_models = []
self.pkg = importlib_resources.files("hordelib") # XXX Remove
self.models_db_name = MODEL_DB_NAMES[model_category_name]
self.models_db_path = Path(get_hordelib_path()).joinpath(
"model_database/",
f"{self.models_db_name}.json",
)

if models_db_path:
self.models_db_path = models_db_path
else:
self.models_db_path = Path(get_hordelib_path()).joinpath(
"model_database/",
f"{self.models_db_name}.json",
)

self.cuda_available = torch.cuda.is_available() # XXX Remove?
self.cuda_devices, self.recommended_gpu = self.get_cuda_devices() # XXX Remove?
Expand Down Expand Up @@ -700,7 +705,6 @@ def download_file(self, url: str, filename: str) -> bool:
remote_file_size = 0

while retries:

if os.path.exists(partial_pathname):
# If file exists, find the size and append to it
partial_size = os.path.getsize(partial_pathname)
Expand Down
6 changes: 4 additions & 2 deletions hordelib/model_manager/compvis.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,17 @@ def modelToRam(
model_name: str,
**kwargs,
) -> dict[str, typing.Any]:
embeddings_path = os.getenv("HORDE_MODEL_DIR_EMBEDDINGS", "./")
embeddings_path = os.path.join(UserSettings.get_model_directory(), "ti")
if not embeddings_path:
logger.debug("No embeddings path found, disabling embeddings")

if not kwargs.get("local", False):
ckpt_path = self.getFullModelPath(model_name)
else:
ckpt_path = os.path.join(self.modelFolderPath, model_name)
return horde_load_checkpoint(
ckpt_path=ckpt_path,
embeddings_path=embeddings_path,
embeddings_path=embeddings_path if embeddings_path else None,
)

def can_cache_on_disk(self):
Expand Down
8 changes: 8 additions & 0 deletions hordelib/model_manager/hyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from hordelib.model_manager.gfpgan import GfpganModelManager
from hordelib.model_manager.lora import LoraModelManager
from hordelib.model_manager.safety_checker import SafetyCheckerModelManager
from hordelib.model_manager.ti import TextualInversionModelManager
from hordelib.settings import UserSettings

MODEL_MANAGERS_TYPE_LOOKUP: dict[MODEL_CATEGORY_NAMES | str, type[BaseModelManager]] = {
Expand All @@ -31,6 +32,7 @@
MODEL_CATEGORY_NAMES.gfpgan: GfpganModelManager,
MODEL_CATEGORY_NAMES.safety_checker: SafetyCheckerModelManager,
MODEL_CATEGORY_NAMES.lora: LoraModelManager,
MODEL_CATEGORY_NAMES.ti: TextualInversionModelManager,
MODEL_CATEGORY_NAMES.blip: BlipModelManager,
MODEL_CATEGORY_NAMES.clip: ClipModelManager,
}
Expand Down Expand Up @@ -85,6 +87,12 @@ def lora(self) -> LoraModelManager | None:
found_mm = self.get_mm_pointer(LoraModelManager)
return found_mm if isinstance(found_mm, LoraModelManager) else None

@property
def ti(self) -> TextualInversionModelManager | None:
"""The Textual Inversion model manager instance. Returns `None` if not loaded."""
found_mm = self.get_mm_pointer(TextualInversionModelManager)
return found_mm if isinstance(found_mm, TextualInversionModelManager) else None

@property
def blip(self) -> BlipModelManager | None:
"""The Blip model manager instance. Returns `None` if not loaded."""
Expand Down
14 changes: 10 additions & 4 deletions hordelib/model_manager/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,14 @@ class LoraModelManager(BaseModelManager):
LORA_API = "https://civitai.com/api/v1/models?types=LORA&sort=Highest%20Rated&primaryFileOnly=true"
MAX_RETRIES = 10 if not TESTS_ONGOING else 1
MAX_DOWNLOAD_THREADS = 3 # max concurrent downloads
RETRY_DELAY = 5 # seconds
REQUEST_METADATA_TIMEOUT = 30 # seconds
REQUEST_DOWNLOAD_TIMEOUT = 300 # seconds
RETRY_DELAY = 5
"""The time to wait between retries in seconds"""
REQUEST_METADATA_TIMEOUT = 30
"""The time to wait for a response from the server in seconds"""
REQUEST_DOWNLOAD_TIMEOUT = 300
"""The time to wait for a response from the server in seconds"""
THREAD_WAIT_TIME = 2
"""The time to wait between checking the download queue in seconds"""

def __init__(
self,
Expand Down Expand Up @@ -72,12 +76,14 @@ def __init__(
self._index_ids = {}
self._index_orig_names = {}

models_db_path = LEGACY_REFERENCE_FOLDER.joinpath("lora.json").resolve()

super().__init__(
model_category_name=MODEL_CATEGORY_NAMES.lora,
download_reference=download_reference,
models_db_path=models_db_path,
)
# FIXME (shift lora.json handling into horde_model_reference?)
self.models_db_path = LEGACY_REFERENCE_FOLDER.joinpath("lora.json").resolve()

def loadModelDatabase(self, list_models=False):
if self.model_reference:
Expand Down
Loading