From 7e4a5072e5d0801fde92c523d2d8b75cb37cdc42 Mon Sep 17 00:00:00 2001 From: db0 Date: Fri, 21 Jul 2023 16:32:56 +0200 Subject: [PATCH 01/15] test TIs --- tests/test_horde_ti.py | 57 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 tests/test_horde_ti.py diff --git a/tests/test_horde_ti.py b/tests/test_horde_ti.py new file mode 100644 index 00000000..8930a9f9 --- /dev/null +++ b/tests/test_horde_ti.py @@ -0,0 +1,57 @@ +# test_horde_ti.py +import os +from datetime import datetime, timedelta + +import pytest +from PIL import Image + +from hordelib.horde import HordeLib +from hordelib.shared_model_manager import SharedModelManager + +from .testing_shared_functions import check_single_lora_image_similarity + + +class TestHordeTI: + @pytest.fixture(autouse=True, scope="class") + def setup_and_teardown(self, shared_model_manager: type[SharedModelManager]): + # shared_model_manager.manager.lora.download_default_loras() + # shared_model_manager.manager.lora.wait_for_downloads() + yield + shared_model_manager.manager.lora.stop_all() + + def test_basic_ti( + self, + hordelib_instance: HordeLib, + stable_diffusion_modelname_for_testing: str, + ): + + data = { + "sampler_name": "k_euler", + "cfg_scale": 8.0, + "denoising_strength": 1.0, + "seed": 1312, + "height": 512, + "width": 512, + "karras": False, + "tiling": False, + "hires_fix": False, + "clip_skip": 1, + "control_type": None, + "image_is_control": False, + "return_control_map": False, + "prompt": "Closeup portrait of a Lesotho teenage girl wearing a Seanamarena blanket, walking in a field of flowers, holding a bundle of flowers, detailed background, light rays, atmospheric lighting, embedding:style-sylvamagic###(embedding:easynegative:0.5), embedding:bhands-neg", + "ddim_steps": 20, + "n_iter": 1, + "model": stable_diffusion_modelname_for_testing, + } + + pil_image = hordelib_instance.basic_inference(data) + assert pil_image is not None + + img_filename = "ti_basic.png" + pil_image.save(f"images/{img_filename}", quality=100) + + # assert check_single_lora_image_similarity( + # f"images_expected/{img_filename}", + # pil_image, + # ) From ad1f4883db1e40b30923ef45d2388c3ee618140f Mon Sep 17 00:00:00 2001 From: db0 Date: Fri, 18 Aug 2023 11:55:49 +0200 Subject: [PATCH 02/15] before tests --- hordelib/consts.py | 4 + hordelib/model_manager/ti.py | 783 +++++++++++++++++++++++++++++ tests/model_managers/test_mm_ti.py | 47 ++ tests/test_horde_ti.py | 4 +- 4 files changed, 837 insertions(+), 1 deletion(-) create mode 100644 hordelib/model_manager/ti.py create mode 100644 tests/model_managers/test_mm_ti.py diff --git a/hordelib/consts.py b/hordelib/consts.py index 4b1cce48..c4b57903 100644 --- a/hordelib/consts.py +++ b/hordelib/consts.py @@ -41,6 +41,7 @@ class MODEL_CATEGORY_NAMES(StrEnum): gfpgan = auto() safety_checker = auto() lora = auto() + ti = auto() blip = auto() clip = auto() @@ -55,6 +56,7 @@ class MODEL_CATEGORY_NAMES(StrEnum): MODEL_CATEGORY_NAMES.gfpgan: True, MODEL_CATEGORY_NAMES.safety_checker: True, MODEL_CATEGORY_NAMES.lora: True, + MODEL_CATEGORY_NAMES.ti: True, MODEL_CATEGORY_NAMES.blip: True, MODEL_CATEGORY_NAMES.clip: True, } @@ -69,6 +71,7 @@ class MODEL_CATEGORY_NAMES(StrEnum): MODEL_CATEGORY_NAMES.gfpgan: MODEL_CATEGORY_NAMES.gfpgan, MODEL_CATEGORY_NAMES.safety_checker: MODEL_CATEGORY_NAMES.safety_checker, MODEL_CATEGORY_NAMES.lora: MODEL_CATEGORY_NAMES.lora, + MODEL_CATEGORY_NAMES.ti: MODEL_CATEGORY_NAMES.ti, MODEL_CATEGORY_NAMES.blip: MODEL_CATEGORY_NAMES.blip, MODEL_CATEGORY_NAMES.clip: MODEL_CATEGORY_NAMES.clip, } @@ -83,6 +86,7 @@ class MODEL_CATEGORY_NAMES(StrEnum): MODEL_CATEGORY_NAMES.gfpgan: MODEL_CATEGORY_NAMES.gfpgan, MODEL_CATEGORY_NAMES.safety_checker: MODEL_CATEGORY_NAMES.safety_checker, MODEL_CATEGORY_NAMES.lora: MODEL_CATEGORY_NAMES.lora, + MODEL_CATEGORY_NAMES.ti: MODEL_CATEGORY_NAMES.ti, MODEL_CATEGORY_NAMES.blip: MODEL_CATEGORY_NAMES.blip, MODEL_CATEGORY_NAMES.clip: MODEL_CATEGORY_NAMES.clip, } diff --git a/hordelib/model_manager/ti.py b/hordelib/model_manager/ti.py new file mode 100644 index 00000000..b0d7e20f --- /dev/null +++ b/hordelib/model_manager/ti.py @@ -0,0 +1,783 @@ +import copy +import glob +import hashlib +import json +import os +import re +import threading +import time +import typing +from collections import deque +from datetime import datetime, timedelta +from enum import auto + +import requests +from fuzzywuzzy import fuzz +from loguru import logger +from strenum import StrEnum +from typing_extensions import override + +from hordelib.consts import MODEL_CATEGORY_NAMES +from hordelib.model_manager.base import BaseModelManager +from hordelib.utils.sanitizer import Sanitizer + + +class DOWNLOAD_SIZE_CHECK(StrEnum): + everything = auto() + top = auto() + adhoc = auto() + + +TESTS_ONGOING = os.getenv("TESTS_ONGOING", "0") == "1" + + +class TextualInversionModelManager(BaseModelManager): + TI_API = "https://civitai.com/api/v1/models?types=TextualInversion&sort=Highest%20Rated&primaryFileOnly=true" + HORDELING_API = "https://hordeling.aihorde.net/api/v1/embedding" + MAX_RETRIES = 10 if not TESTS_ONGOING else 1 + MAX_DOWNLOAD_THREADS = 3 # max concurrent downloads + RETRY_DELAY = 5 # seconds + REQUEST_METADATA_TIMEOUT = 30 # seconds + REQUEST_DOWNLOAD_TIMEOUT = 300 # seconds + THREAD_WAIT_TIME = 2 + + def __init__( + self, + download_reference=False, + ): + self._data = None + self._next_page_url = None + self._mutex = threading.Lock() + self._file_count = 0 + self._download_threads = {} + self._download_queue = deque() + self._thread = None + self.stop_downloading = True + # Not yet handled, as we need a global reference to search through. + self._previous_model_reference = {} + self._adhoc_tis = set() + # If false, this MM will only download SFW tis + self.nsfw = True + self._adhoc_reset_thread = None + self._stop_all_threads = False + self._index_ids = {} + self._index_orig_names = {} + + super().__init__( + model_category_name=MODEL_CATEGORY_NAMES.ti, + download_reference=download_reference, + ) + + def loadModelDatabase(self, list_models=False): + if self.model_reference: + logger.info( + ( + "Model reference was already loaded." + f" Got {len(self.model_reference)} models for {self.models_db_name}." + ), + ) + logger.info("Reloading model reference...") + + # TI are always stored to disk and the model reference created slowly through ad-hoc requests + os.makedirs(self.modelFolderPath, exist_ok=True) + if self.models_db_path.exists(): + try: + self.model_reference = json.loads((self.models_db_path).read_text()) + logger.info("Loaded model reference from disk.") + except json.JSONDecodeError: + logger.error(f"Could not load {self.models_db_name} model reference from disk! Bad JSON?") + self.model_reference = {} + self.save_cached_reference_to_disk() + else: + logger.info(f"Initiating new model reference {self.models_db_name} model reference to disk.") + self.model_reference = {} + self.save_cached_reference_to_disk() + logger.info( + (f"Got {len(self.model_reference)} models for {self.models_db_name}.",), + ) + + def download_model_reference(self): + # We have to wipe it, as we are going to be adding it it instead of replacing it + # We're not downloading now, as we need to be able to init without it + self.model_reference = {} + self.save_cached_reference_to_disk() + + def _add_ti_ids_to_download_queue(self, ti_ids, adhoc=False, version_compare=None): + idsq = "&ids=".join([str(id) for id in ti_ids]) + url = f"https://civitai.com/api/v1/models?limit=100&ids={idsq}" + data = self._get_json(url) + if not data: + logger.warning(f"metadata for Textual Inversion {ti_ids} could not be downloaded!") + return + for ti_data in data.get("items", []): + ti = self._parse_civitai_ti_data(ti_data, adhoc=adhoc) + # If we're comparing versions, then we don't download if the existing ti metadata matches + # Instead we just refresh metadata information + if not ti: + continue + if version_compare and ti["version_id"] == version_compare: + logger.debug( + f"Downloaded metadata for Textual Inversion {ti['name']} " + f"('{ti['name']}') and found version match! Refreshing metadata.", + ) + ti["last_checked"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + self._add_ti_to_reference(ti) + continue + logger.debug( + f"Downloaded metadata for Textual Inversions {ti['id']} ('{ti['name']}') and added to download queue", + ) + self._download_ti(ti) + + def _get_json(self, url): + retries = 0 + while retries <= self.MAX_RETRIES: + try: + response = requests.get(url, timeout=self.REQUEST_METADATA_TIMEOUT) + response.raise_for_status() + # Attempt to decode the response to JSON + return response.json() + + except (requests.HTTPError, requests.ConnectionError, requests.Timeout, json.JSONDecodeError): + retries += 1 + if retries <= self.MAX_RETRIES: + time.sleep(self.RETRY_DELAY) + else: + # Max retries exceeded, give up + return None + + except Exception as e: + # Failed badly + logger.error(f"url '{url}' download failed {e}") + return None + + def _get_more_items(self): + if not self._data: + # We need to lowercase the boolean, or CivitAI doesn't understand it >.> + url = f"{self.TI_API}&nsfw={str(self.nsfw).lower()}" + else: + url = self._next_page_url + + # This may be the end of the road, unlikely but... + if not url: + logger.warning("End of Textual Inversion data reached") + self.stop_downloading = True + else: + # Get the actual item data + items = self._get_json(url) + if items: + self._data = items + self._next_page_url = self._data.get("metadata", {}).get("nextPage", "") + else: + # We failed to get more items + logger.error("Failed to download all Textual Inversion data even after retries.") + self._data = None + self._next_page_url = None # give up + + def _parse_civitai_ti_data(self, item, adhoc=False): + """Return a simplified dictionary with the information we actually need about a ti""" + ti = { + "name": "", + "orig_name": "", + "sha256": "", + "filename": "", + "id": "", + "url": "", + "triggers": [], + "size_kb": 0, + "adhoc": adhoc, + "nsfw": False, + } + # get top version + try: + version = item.get("modelVersions", {})[0] + except IndexError: + version = {} + # Get model triggers + triggers = version.get("trainedWords", []) + # get first file that is a primary file and a safetensor + for file in version.get("files", {}): + if file.get("primary", False) and file.get("name", "").endswith(".safetensors"): + ti["name"] = Sanitizer.sanitise_model_name(item.get("name", "")) + ti["orig_name"] = item.get("name", "") + ti["id"] = item.get("id", 0) + ti["filename"] = f'{Sanitizer.sanitise_filename(ti["name"])}.safetensors' + ti["sha256"] = file.get("hashes", {}).get("SHA256") + try: + ti["size_kb"] = round(file.get("sizeKB", 0)) + except TypeError: + ti["size_kb"] = 24 # guess common case of 24Kb, it's not critical here + ti["url"] = file.get("downloadUrl", "") + ti["triggers"] = triggers + ti["nsfw"] = item.get("nsfw", True) + ti["baseModel"] = version.get("baseModel", "SD 1.5") + ti["version_id"] = version.get("id", 0) + break + # If we don't have everything required, fail + if ti["adhoc"] and not ti.get("sha256"): + logger.debug(f"Rejecting Textual Inversion {ti.get('name')} because it doesn't have a sha256") + return + if not ti.get("filename") or not ti.get("url"): + logger.debug(f"Rejecting Textual Inversion {ti.get('name')} because it doesn't have a url") + return + # We don't want to start downloading GBs of a single Textual Inversion. + # We just ignore anything over 150Mb. Them's the breaks... + if ti["adhoc"] and ti["size_kb"] > 20000: + logger.debug(f"Rejecting Textual Inversion {ti.get('name')} because its size is over 20Mb.") + return + if ti["adhoc"] and ti["nsfw"] and not self.nsfw: + logger.debug(f"Rejecting Textual Inversion {ti.get('name')} because worker is SFW.") + return + # Fixup A1111 centric triggers + for i, trigger in enumerate(ti["triggers"]): + if re.match("", trigger): + ti["triggers"][i] = re.sub("", "\\1", trigger) + return ti + + def _download_thread(self, thread_number): + # We try to download the Textual Inversion. There are tens of thousands of these things, we aren't + # picky if downloads fail, as they often will if civitai is the host, we just move on to + # the next one + logger.debug(f"Started Download Thread {thread_number}") + while True: + # Endlessly look for files to download and download them + if self._stop_all_threads: + logger.debug(f"Stopped Download Thread {thread_number}") + return + try: + ti = self._download_queue.popleft() + self._download_threads[thread_number]["ti"] = ti + except IndexError: + # Nothing in the queue + self._download_threads[thread_number]["ti"] = None + time.sleep(self.THREAD_WAIT_TIME) + continue + # Download the ti + retries = 0 + while retries <= self.MAX_RETRIES: + try: + # Just before we download this file, check if we already have it + filename = os.path.join(self.modelFolderPath, ti["filename"]) + hashfile = f"{os.path.splitext(filename)[0]}.sha256" + if os.path.exists(filename) and os.path.exists(hashfile): + # Check the hash + with open(hashfile) as infile: + try: + hashdata = infile.read().split()[0] + except (IndexError, OSError, IOError, PermissionError): + hashdata = "" + if not ti.get("sha256") or hashdata.lower() == ti["sha256"].lower(): + # we already have this ti, consider it downloaded + # the SHA256 might not exist when the ti has been selected in the curation list + # Where we allow them to skip it + if not ti.get("sha256"): + logger.debug( + f"Already have Textual Inversion {ti['filename']}. " + "Bypassing SHA256 check as there's none stored", + ) + else: + logger.debug(f"Already have Textual Inversion {ti['filename']}") + with self._mutex: + # We store as lower to allow case-insensitive search + ti["last_checked"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + self._add_ti_to_reference(ti) + if self.is_default_cache_full(): + self.stop_downloading = True + self.save_cached_reference_to_disk() + break + + logger.info(f"Starting download of Textual Inversion {ti['filename']}") + hordeling_response = requests.get(f"{self.HORDELING_API}/{ti['id']}", timeout=5) + if not hordeling_response.ok: + if hordeling_response.status_code == 404: + logger.debug(f"Textual Inversion {ti['filename']} could not be found on AI Hordeling.") + break + # We will retry + logger.debug( + f"AI Horeling reported error when downloading {ti['filename']}: " + "{hordeling_response.json()}" + f"Retry {retries}/{self.MAX_RETRIES}", + ) + else: + response = requests.get( + hordeling_response.json()["url"], + timeout=self.REQUEST_DOWNLOAD_TIMEOUT, + ) + response.raise_for_status() + # Check the data hash + hash_object = hashlib.sha256() + hash_object.update(response.content) + sha256 = hash_object.hexdigest() + if not ti.get("sha256") or sha256.lower() == ti["sha256"].lower(): + # wow, we actually got a valid file, save it + with open(filename, "wb") as outfile: + outfile.write(response.content) + # Save the hash file + with open(hashfile, "wt") as outfile: + outfile.write(f"{sha256} *{ti['filename']}") + + # Shout about it + logger.info(f"Downloaded Textual Inversion {ti['filename']} ({ti['size_kb']} KB)") + # Maybe we're done + with self._mutex: + # We store as lower to allow case-insensitive search + ti["last_used"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + ti["last_checked"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + self._add_ti_to_reference(ti) + if self.is_adhoc_cache_full(): + self.delete_oldest_ti() + self.save_cached_reference_to_disk() + break + else: + # We will retry + logger.debug( + f"Downloaded Textual Inversion file for {ti['filename']} didn't match hash. " + f"Retry {retries}/{self.MAX_RETRIES}", + ) + + except (requests.HTTPError, requests.ConnectionError, requests.Timeout, json.JSONDecodeError) as e: + # We will retry + logger.debug(f"Error downloading {ti['filename']} {e}. Retry {retries}/{self.MAX_RETRIES}") + + except Exception as e: + # Failed badly, ignore and retry + logger.debug(f"Fatal error downloading {ti['filename']} {e}. Retry {retries}/{self.MAX_RETRIES}") + + retries += 1 + if retries > self.MAX_RETRIES: + break # fail + + time.sleep(self.RETRY_DELAY) + + def _download_ti(self, ti): + with self._mutex: + # Start download threads if they aren't already started + while len(self._download_threads) < self.MAX_DOWNLOAD_THREADS: + thread_iter = len(self._download_threads) + thread = threading.Thread(target=self._download_thread, daemon=True, args=(thread_iter,)) + self._download_threads[thread_iter] = {"thread": thread, "ti": None} + thread.start() + + # Add this ti to the download queue + self._download_queue.append(ti) + + def _process_items(self): + # i.e. given a bunch of TI item metadata, download them + if not self._data: + logger.debug("No Textual Inversion data to process") + return + for item in self._data.get("items", []): + ti = self._parse_civitai_ti_data(item) + if ti: + self._file_count += 1 + # We have valid ti data, download it + self._download_ti(ti) + + def _start_processing(self): + self.stop_downloading = False + + while not self.stop_downloading: + if self._stop_all_threads: + logger.debug("Stopped processing thread") + return + # Get some items to download + self._get_more_items() + + # If we have some items to process, process them + if self._data: + self._process_items() + + def _add_ti_to_reference(self, ti): + ti_key = ti["name"].lower().strip() + if ti.get("adhoc", False): + self._adhoc_tis.add(ti_key) + # Once added to our set, we don't need to specify it was adhoc anymore + del ti["adhoc"] + self.model_reference[ti_key] = ti + self._index_ids[ti["id"]] = ti_key + orig_name = ti.get("orig_name", ti["name"]).lower() + self._index_orig_names[orig_name] = ti_key + + def wait_for_downloads(self, timeout=None): + rtr = 0 + while not self.are_downloads_complete(): + time.sleep(0.5) + rtr += 0.5 + if timeout and rtr > timeout: + raise Exception(f"Textual Inversion downloads exceeded specified timeout ({timeout})") + + def are_downloads_complete(self): + # If we don't have any models in our reference, then we haven't downloaded anything + # perhaps faulty civitai? + if self._thread and self._thread.is_alive(): + return False + if not self.are_download_threads_idle(): + return False + if len(self._download_queue) > 0: + return False + return self.stop_downloading + + def are_download_threads_idle(self): + # logger.debug([dthread["ti"] for dthread in self._download_threads.values()]) + for dthread in self._download_threads.values(): + if dthread["ti"] is not None: + return False + return True + + def fuzzy_find_ti_key(self, ti_name): + # sname = Sanitizer.remove_version(ti_name).lower() + if type(ti_name) is int or ti_name.isdigit(): + if int(ti_name) in self._index_ids: + return self._index_ids[int(ti_name)] + return None + sname = ti_name.lower().strip() + if sname in self.model_reference: + return sname + if sname in self._index_orig_names: + return self._index_orig_names[sname] + if Sanitizer.has_unicode(sname): + for ti in self._index_orig_names: + if sname in ti: + return self._index_orig_names[ti] + # If a unicode name is not found in the orig_names index + # it won't be found anywhere else, as unicode chars are converted to ascii in the keys + # This saves us time doing unnecessary fuzzy searches + return None + for ti in self.model_reference: + if sname in ti: + return ti + for ti in self.model_reference: + if fuzz.ratio(sname, ti) > 80: + return ti + return None + + # Using `get_model` instead of `get_ti` as it exists in the base class + def get_model(self, model_name: str) -> dict | None: + """Returns the actual ti details dict for the specified model_name search string + Returns None if ti name not found""" + ti_name = self.fuzzy_find_ti_key(model_name) + if not ti_name: + return None + return self.model_reference.get(ti_name) + + def get_ti_filename(self, model_name: str): + """Returns the actual ti filename for the specified model_name search string + Returns None if ti name not found""" + ti = self.get_model(model_name) + if not ti: + return None + return ti["filename"] + + def get_ti_name(self, model_name: str): + """Returns the actual ti name for the specified model_name search string + Returns None if ti name not found""" + ti = self.get_model(model_name) + if not ti: + return None + return ti["name"] + + def get_ti_triggers(self, model_name: str): + """Returns a list of triggers for a specified ti name + Returns an empty list if no triggers are found + Returns None if ti name not found""" + ti = self.get_model(model_name) + if not ti: + return None + triggers = ti.get("triggers") + if triggers: + return triggers + # We don't `return ti.get("triggers", [])`` to avoid the returned list object being modified + # and then we keep returning previous items + return [] + + def find_ti_trigger(self, model_name: str, trigger_search: str): + """Searches for a specific trigger for a specified ti name + Returns None if string not found even with fuzzy search""" + triggers = self.get_ti_triggers(model_name) + if triggers is None: + return None + if trigger_search.lower() in [trigger.lower() for trigger in triggers]: + return trigger_search + for trigger in triggers: + if trigger_search.lower() in trigger.lower(): + return trigger + for trigger in triggers: + if fuzz.ratio(trigger_search.lower(), trigger.lower()) > 65: + return trigger + return None + + def save_cached_reference_to_disk(self): + with open(self.models_db_path, "wt", encoding="utf-8", errors="ignore") as outfile: + outfile.write(json.dumps(self.model_reference, indent=4)) + + def calculate_downloaded_tis(self, mode=DOWNLOAD_SIZE_CHECK.everything): + total_size = 0 + for ti in self.model_reference: + if mode == DOWNLOAD_SIZE_CHECK.top and ti in self._adhoc_tis: + continue + if mode == DOWNLOAD_SIZE_CHECK.adhoc and ti not in self._adhoc_tis: + continue + total_size += self.model_reference[ti]["size_kb"] + return total_size + + def is_default_cache_full(self): + return False + + def is_adhoc_cache_full(self): + return False + + def calculate_download_queue(self): + total_queue = 0 + for ti in self._download_queue: + total_queue += ti["size_kb"] + return total_queue + + def _parse_civitai_ti_data(self) -> str | None: + oldest_ti: str | None = None + oldest_datetime: datetime | None = None + for ti in self._adhoc_tis: + ti_datetime = datetime.strptime(self.model_reference[ti]["last_used"], "%Y-%m-%d %H:%M:%S") + if not oldest_ti: + oldest_ti = ti + oldest_datetime = ti_datetime + continue + if oldest_datetime and oldest_datetime > ti_datetime: + oldest_ti = ti + oldest_datetime = ti_datetime + return oldest_ti + + def delete_oldest_ti(self): + oldest_ti = self._parse_civitai_ti_data() + if not oldest_ti: + return + self.delete_ti(oldest_ti) + + def find_ti_from_filename(self, filename: str): + for ti in self.model_reference: + if self.model_reference[ti]["filename"] == filename: + return ti + return None + + def find_unused_tis(self): + files = glob.glob(f"{self.modelFolderPath}/*.safetensors") + filesnames = set() + for stfile in files: + filename = os.path.basename(stfile) + if not self.find_ti_from_filename(filename): + filesnames.add(filename) + return filesnames + + def delete_unused_tis(self, timeout=0): + """Deletes downloaded Textual Inversions which do not appear in the model_reference + By default protects the user by not running if are_downloads_complete() is not done + """ + waited = 0 + while not self.are_downloads_complete(): + if waited >= timeout: + raise Exception( + f"Waiting for current Textual Inversion downloads exceeded specified timeout ({timeout})", + ) + waited += 0.2 + time.sleep(0.2) + tis_to_delete = self.find_unused_tis() + for ti_filename in tis_to_delete: + self.delete_ti_files(ti_filename) + return tis_to_delete + + def delete_ti_files(self, ti_filename: str): + filename = os.path.join(self.modelFolderPath, ti_filename) + if not os.path.exists(filename): + logger.warning(f"Could not find Textual Inversion file on disk to delete: {filename}") + return + os.remove(filename) + logger.info(f"Deleted Textual Inversion file: {filename}") + + def delete_ti(self, ti_name: str): + ti_info = self.get_model(ti_name) + if not ti_info: + logger.warning(f"Could not find ti {ti_name} to delete") + return + self.delete_ti_files(ti_info["filename"]) + self._adhoc_tis.remove(ti_name) + del self._index_ids[ti_info["id"]] + del self._index_orig_names[ti_info["orig_name"].lower()] + del self.model_reference[ti_name] + self.save_cached_reference_to_disk() + + def ensure_ti_deleted(self, ti_name: str): + ti_key = self.fuzzy_find_ti_key(ti_name) + if not ti_key: + return + self.delete_ti(ti_key) + + # def reset_adhoc_tis(self): + # """Compared the known tis from the previous run to the current one + # Adds any definitions as adhoc tis, until we have as many Mb as self.max_adhoc_disk""" + # while not self.are_downloads_complete(): + # if self._stop_all_threads: + # logger.debug("Stopped processing thread") + # return + # time.sleep(0.2) + # now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + # self._adhoc_tis = set() + # sorted_items = [] + # try: + # sorted_items = sorted( + # self._previous_model_reference.items(), + # key=lambda x: x[1].get("last_used", now), + # reverse=True, + # ) + # except Exception as err: + # logger.error(err) + # while not self.is_adhoc_cache_full() and len(sorted_items) > 0: + # prevti_key, prevti_value = sorted_items.pop() + # if prevti_key in self.model_reference: + # continue + # # If False, it will initiates a redownload and call _add_ti_to_reference() later + # if self._check_for_refresh(prevti_key): + # self._add_ti_to_reference(prevti_value) + # self._adhoc_tis.add(prevti_key) + # for ti_key in self.model_reference: + # if ti_key in self._previous_model_reference: + # self.model_reference[ti_key]["last_used"] = self._previous_model_reference[ti_key].get( + # "last_used", + # now, + # ) + # # Final assurance that all our tis have a last_used timestamp + # for ti in self.model_reference.values(): + # if "last_used" not in ti: + # ti["last_used"] = now + # self._previous_model_reference = {} + # self.save_cached_reference_to_disk() + + def _check_for_refresh(self, ti_name: str): + """Returns True if a refresh is needed + and also initiates a refresh + Else returns False + """ + ti_details = self.get_model(ti_name) + if not ti_details: + return True + refresh = False + if "last_checked" not in ti_details: + refresh = True + elif "baseModel" not in ti_details: + refresh = True + else: + ti_datetime = datetime.strptime(ti_details["last_checked"], "%Y-%m-%d %H:%M:%S") + if ti_datetime < datetime.now() - timedelta(days=1): + refresh = True + if refresh: + logger.debug(f"Textual Inversion {ti_name} found needing refresh. Initiating metadata download...") + self._add_ti_ids_to_download_queue([ti_details["id"]], ti_details.get("version_id", -1)) + return False + return True + + # def check_for_valid + + # def is_adhoc_reset_complete(self): + # if self._adhoc_reset_thread and self._adhoc_reset_thread.is_alive(): + # return False + # return True + + # def wait_for_adhoc_reset(self, timeout=15): + # rtr = 0 + # while not self.is_adhoc_reset_complete(): + # time.sleep(0.2) + # rtr += 0.2 + # if timeout and rtr > timeout: + # raise Exception(f"Textual Inversion adhoc reset exceeded specified timeout ({timeout})") + + def stop_all(self): + self._stop_all_threads = True + + def touch_ti(self, ti_name): + """Updates the "last_used" key in a ti entry to current UTC time""" + ti = self.get_model(ti_name) + if not ti: + logger.warning(f"Could not find ti {ti_name} to touch") + return + ti["last_used"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + def get_ti_last_use(self, ti_name): + """Returns a dateimte object based on the "last_used" key in a ti entry""" + ti = self.get_model(ti_name) + if not ti: + logger.warning(f"Could not find ti {ti_name} to get last use") + return None + return datetime.strptime(ti["last_used"], "%Y-%m-%d %H:%M:%S") + + def fetch_adhoc_ti(self, ti_name, timeout=45): + if type(ti_name) is int or ti_name.isdigit(): + url = f"https://civitai.com/api/v1/models/{ti_name}" + else: + url = f"{self.TI_API}&nsfw={str(self.nsfw).lower()}&query={ti_name}" + data = self._get_json(url) + # CivitAI down + if not data: + return None + if "items" in data: + if len(data["items"]) == 0: + return None + ti = self._parse_civitai_ti_data(data["items"][0], adhoc=True) + else: + ti = self._parse_civitai_ti_data(data, adhoc=True) + # For example epi_noiseoffset doesn't have sha256 so we ignore it + # This avoid us faulting + if not ti: + return None + # We double-check that somehow our search missed it but CivitAI searches differently and found it + fuzzy_find = self.fuzzy_find_ti_key(ti["id"]) + if fuzzy_find: + logger.error(fuzzy_find) + return fuzzy_find + self._download_queue.append(ti) + # We need to wait a bit to make sure the threads pick up the download + time.sleep(self.THREAD_WAIT_TIME) + self.wait_for_downloads(timeout) + return ti["name"].lower() + + def do_baselines_match(self, ti_name, model_details): + self._check_for_refresh(ti_name) + lota_details = self.get_model(ti_name) + if not lota_details: + logger.warning(f"Could not find ti {ti_name} to check baselines") + return False + if "SD 1.5" in lota_details["baseModel"] and model_details["baseline"] == "stable diffusion 1": + return True + if "SD 2.1" in lota_details["baseModel"] and model_details["baseline"] == "stable diffusion 2": + return True + return False + + @override + def is_local_model(self, model_name): + return self.fuzzy_find_ti_key(model_name) is not None + + @override + def load( + self, + model_name: str, + *, + half_precision: bool = True, + gpu_id: int | None = 0, + cpu_only: bool = False, + **kwargs, + ) -> bool | None: + error = "load is not supported for Textual Inversions" + logger.error(error) + raise NotImplementedError(error) + + @override + def modelToRam( + self, + model_name: str, + **kwargs, + ) -> dict[str, typing.Any]: + error = "modelToRam is not supported for Textual Inversions" + logger.error(error) + raise NotImplementedError(error) + + def get_available_models(self): + """ + Returns the available model names + """ + return list(self.model_reference.keys()) diff --git a/tests/model_managers/test_mm_ti.py b/tests/model_managers/test_mm_ti.py new file mode 100644 index 00000000..4ea74e76 --- /dev/null +++ b/tests/model_managers/test_mm_ti.py @@ -0,0 +1,47 @@ +import os + +import pytest + +from hordelib.model_manager.ti import TextualInversionModelManager + + +class TestModelManagerLora: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + # We don't want to download a ton of loras for tests by mistake + assert os.getenv("TESTS_ONGOING") == "1" + + def test_fuzzy_search(self): + mmti = TextualInversionModelManager() + mmti.fetch_adhoc_ti("56519") + assert mmti.fuzzy_find_lora_key("negative_hand Negative Embedding") == "glowingrunesai" + assert mmti.fuzzy_find_lora_key("negative_hand") is None + assert mmti.fuzzy_find_lora_key(56519) == "blindbox/da gai shi mang he" + assert mmti.fuzzy_find_lora_key("56519") == "blindbox/da gai shi mang he" + mmti.stop_all() + + def test_ti_trigger_search(self): + mmti = TextualInversionModelManager() + mmti.fetch_adhoc_ti("71961") + assert mmti.get_lora_name("Fast Negative Embedding") == "GlowingRunesAI" + assert len(mmti.get_ti_triggers("Fast Negative Embedding")) > 0 + # We can't rely on triggers not changing + assert mmti.find_lora_trigger("Fast Negative Embedding", "FastNegativeV2") is not None + mmti.stop_all() + + def test_fetch_adhoc_ti(self): + mmti = TextualInversionModelManager() + mmti.ensure_ti_deleted(7808) + lora_key = mmti.fetch_adhoc_ti("7808") + assert lora_key == "EasyNegative".lower() + assert mmti.is_local_model("7808") + assert mmti.get_lora_name("7808") == "EasyNegative" + mmti.stop_all() + + def test_adhoc_non_existing(self): + mmti = TextualInversionModelManager() + ti_name = "__THIS SHOULD NOT EXIST. I SWEAR IF ONE OF YOU UPLOADS A TI WITH THIS NAME I AM GOING TO BE UPSET!" + lora_key = mmti.fetch_adhoc_lora(ti_name) + assert lora_key is None + assert not mmti.is_local_model(ti_name) + mmti.stop_all() diff --git a/tests/test_horde_ti.py b/tests/test_horde_ti.py index 8930a9f9..3fdc3bec 100644 --- a/tests/test_horde_ti.py +++ b/tests/test_horde_ti.py @@ -39,7 +39,9 @@ def test_basic_ti( "control_type": None, "image_is_control": False, "return_control_map": False, - "prompt": "Closeup portrait of a Lesotho teenage girl wearing a Seanamarena blanket, walking in a field of flowers, holding a bundle of flowers, detailed background, light rays, atmospheric lighting, embedding:style-sylvamagic###(embedding:easynegative:0.5), embedding:bhands-neg", + "prompt": "Closeup portrait of a Lesotho teenage girl wearing a Seanamarena blanket, " + "walking in a field of flowers, holding a bundle of flowers, detailed background, light rays, " + "atmospheric lighting, embedding:style-sylvamagic###(embedding:easynegative:0.5), embedding:bhands-neg", "ddim_steps": 20, "n_iter": 1, "model": stable_diffusion_modelname_for_testing, From efc0b55a8fe26e54a5d7ea5fd0608a71a8f1aaa5 Mon Sep 17 00:00:00 2001 From: db0 Date: Fri, 18 Aug 2023 16:38:48 +0200 Subject: [PATCH 03/15] working on tests remote --- hordelib/shared_model_manager.py | 1 + tests/test_horde_ti.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/hordelib/shared_model_manager.py b/hordelib/shared_model_manager.py index ef98dc3a..13dd3cd9 100644 --- a/hordelib/shared_model_manager.py +++ b/hordelib/shared_model_manager.py @@ -76,6 +76,7 @@ def loadModelManagers( gfpgan: bool = False, safety_checker: bool = False, lora: bool = False, + ti: bool = False, blip: bool = False, clip: bool = False, ): diff --git a/tests/test_horde_ti.py b/tests/test_horde_ti.py index 3fdc3bec..00e6ac88 100644 --- a/tests/test_horde_ti.py +++ b/tests/test_horde_ti.py @@ -17,14 +17,14 @@ def setup_and_teardown(self, shared_model_manager: type[SharedModelManager]): # shared_model_manager.manager.lora.download_default_loras() # shared_model_manager.manager.lora.wait_for_downloads() yield - shared_model_manager.manager.lora.stop_all() def test_basic_ti( self, hordelib_instance: HordeLib, stable_diffusion_modelname_for_testing: str, ): - + assert True + return data = { "sampler_name": "k_euler", "cfg_scale": 8.0, From 455e0fc777436de2d11215d183feaab6951d760a Mon Sep 17 00:00:00 2001 From: db0 Date: Thu, 24 Aug 2023 10:19:28 +0200 Subject: [PATCH 04/15] adjust ti tests --- tests/test_horde_ti.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/test_horde_ti.py b/tests/test_horde_ti.py index 00e6ac88..f3de80e6 100644 --- a/tests/test_horde_ti.py +++ b/tests/test_horde_ti.py @@ -8,20 +8,15 @@ from hordelib.horde import HordeLib from hordelib.shared_model_manager import SharedModelManager -from .testing_shared_functions import check_single_lora_image_similarity - - class TestHordeTI: @pytest.fixture(autouse=True, scope="class") def setup_and_teardown(self, shared_model_manager: type[SharedModelManager]): - # shared_model_manager.manager.lora.download_default_loras() - # shared_model_manager.manager.lora.wait_for_downloads() + assert shared_model_manager.manager.ti yield def test_basic_ti( self, hordelib_instance: HordeLib, - stable_diffusion_modelname_for_testing: str, ): assert True return From 4d139371c34b788cd1e1f4acb34e67b2f0815016 Mon Sep 17 00:00:00 2001 From: db0 Date: Fri, 25 Aug 2023 13:41:23 +0200 Subject: [PATCH 05/15] working test_mm_ti --- hordelib/model_manager/ti.py | 21 +++++++++++---------- tests/model_managers/test_mm_ti.py | 25 ++++++++++++------------- tests/test_horde_ti.py | 5 ++++- 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/hordelib/model_manager/ti.py b/hordelib/model_manager/ti.py index b0d7e20f..6aa10125 100644 --- a/hordelib/model_manager/ti.py +++ b/hordelib/model_manager/ti.py @@ -38,7 +38,7 @@ class TextualInversionModelManager(BaseModelManager): MAX_DOWNLOAD_THREADS = 3 # max concurrent downloads RETRY_DELAY = 5 # seconds REQUEST_METADATA_TIMEOUT = 30 # seconds - REQUEST_DOWNLOAD_TIMEOUT = 300 # seconds + REQUEST_DOWNLOAD_TIMEOUT = 30 # seconds THREAD_WAIT_TIME = 2 def __init__( @@ -92,9 +92,6 @@ def loadModelDatabase(self, list_models=False): logger.info(f"Initiating new model reference {self.models_db_name} model reference to disk.") self.model_reference = {} self.save_cached_reference_to_disk() - logger.info( - (f"Got {len(self.model_reference)} models for {self.models_db_name}.",), - ) def download_model_reference(self): # We have to wipe it, as we are going to be adding it it instead of replacing it @@ -195,8 +192,9 @@ def _parse_civitai_ti_data(self, item, adhoc=False): # Get model triggers triggers = version.get("trainedWords", []) # get first file that is a primary file and a safetensor + # logger.debug(json.dumps(version,indent=4)) for file in version.get("files", {}): - if file.get("primary", False) and file.get("name", "").endswith(".safetensors"): + if file.get("primary", False): ti["name"] = Sanitizer.sanitise_model_name(item.get("name", "")) ti["orig_name"] = item.get("name", "") ti["id"] = item.get("id", 0) @@ -294,16 +292,19 @@ def _download_thread(self, thread_number): # We will retry logger.debug( f"AI Horeling reported error when downloading {ti['filename']}: " - "{hordeling_response.json()}" + f"{hordeling_response.json()}" f"Retry {retries}/{self.MAX_RETRIES}", ) else: + hordeling_json = hordeling_response.json() response = requests.get( - hordeling_response.json()["url"], + hordeling_json["url"], timeout=self.REQUEST_DOWNLOAD_TIMEOUT, ) response.raise_for_status() # Check the data hash + if hordeling_json.get("sha256"): + ti["sha256"] = hordeling_json["sha256"] hash_object = hashlib.sha256() hash_object.update(response.content) sha256 = hash_object.hexdigest() @@ -531,7 +532,7 @@ def calculate_download_queue(self): total_queue += ti["size_kb"] return total_queue - def _parse_civitai_ti_data(self) -> str | None: + def find_oldest_adhoc_ti(self) -> str | None: oldest_ti: str | None = None oldest_datetime: datetime | None = None for ti in self._adhoc_tis: @@ -706,7 +707,7 @@ def get_ti_last_use(self, ti_name): return None return datetime.strptime(ti["last_used"], "%Y-%m-%d %H:%M:%S") - def fetch_adhoc_ti(self, ti_name, timeout=45): + def fetch_adhoc_ti(self, ti_name, timeout=15): if type(ti_name) is int or ti_name.isdigit(): url = f"https://civitai.com/api/v1/models/{ti_name}" else: @@ -730,7 +731,7 @@ def fetch_adhoc_ti(self, ti_name, timeout=45): if fuzzy_find: logger.error(fuzzy_find) return fuzzy_find - self._download_queue.append(ti) + self._download_ti(ti) # We need to wait a bit to make sure the threads pick up the download time.sleep(self.THREAD_WAIT_TIME) self.wait_for_downloads(timeout) diff --git a/tests/model_managers/test_mm_ti.py b/tests/model_managers/test_mm_ti.py index 4ea74e76..7aa4dd47 100644 --- a/tests/model_managers/test_mm_ti.py +++ b/tests/model_managers/test_mm_ti.py @@ -5,43 +5,42 @@ from hordelib.model_manager.ti import TextualInversionModelManager -class TestModelManagerLora: +class TestModelManagerTI: @pytest.fixture(autouse=True) def setup_and_teardown(self): - # We don't want to download a ton of loras for tests by mistake + # We don't want to download a ton of tis for tests by mistake assert os.getenv("TESTS_ONGOING") == "1" def test_fuzzy_search(self): mmti = TextualInversionModelManager() mmti.fetch_adhoc_ti("56519") - assert mmti.fuzzy_find_lora_key("negative_hand Negative Embedding") == "glowingrunesai" - assert mmti.fuzzy_find_lora_key("negative_hand") is None - assert mmti.fuzzy_find_lora_key(56519) == "blindbox/da gai shi mang he" - assert mmti.fuzzy_find_lora_key("56519") == "blindbox/da gai shi mang he" + assert mmti.fuzzy_find_ti_key("negative_hand") == "negative_hand negative embedding" + assert mmti.fuzzy_find_ti_key(56519) == "negative_hand negative embedding" + assert mmti.fuzzy_find_ti_key("56519") == "negative_hand negative embedding" mmti.stop_all() def test_ti_trigger_search(self): mmti = TextualInversionModelManager() mmti.fetch_adhoc_ti("71961") - assert mmti.get_lora_name("Fast Negative Embedding") == "GlowingRunesAI" + assert mmti.get_ti_name("Fast Negative Embedding") == "Fast Negative Embedding (+ FastNegativeV2)" assert len(mmti.get_ti_triggers("Fast Negative Embedding")) > 0 # We can't rely on triggers not changing - assert mmti.find_lora_trigger("Fast Negative Embedding", "FastNegativeV2") is not None + assert mmti.find_ti_trigger("Fast Negative Embedding", "FastNegativeV2") is not None mmti.stop_all() def test_fetch_adhoc_ti(self): mmti = TextualInversionModelManager() mmti.ensure_ti_deleted(7808) - lora_key = mmti.fetch_adhoc_ti("7808") - assert lora_key == "EasyNegative".lower() + ti_key = mmti.fetch_adhoc_ti("7808") + assert ti_key == "EasyNegative".lower() assert mmti.is_local_model("7808") - assert mmti.get_lora_name("7808") == "EasyNegative" + assert mmti.get_ti_name("7808") == "EasyNegative" mmti.stop_all() def test_adhoc_non_existing(self): mmti = TextualInversionModelManager() ti_name = "__THIS SHOULD NOT EXIST. I SWEAR IF ONE OF YOU UPLOADS A TI WITH THIS NAME I AM GOING TO BE UPSET!" - lora_key = mmti.fetch_adhoc_lora(ti_name) - assert lora_key is None + ti_key = mmti.fetch_adhoc_ti(ti_name) + assert ti_key is None assert not mmti.is_local_model(ti_name) mmti.stop_all() diff --git a/tests/test_horde_ti.py b/tests/test_horde_ti.py index f3de80e6..51dc8369 100644 --- a/tests/test_horde_ti.py +++ b/tests/test_horde_ti.py @@ -8,6 +8,7 @@ from hordelib.horde import HordeLib from hordelib.shared_model_manager import SharedModelManager + class TestHordeTI: @pytest.fixture(autouse=True, scope="class") def setup_and_teardown(self, shared_model_manager: type[SharedModelManager]): @@ -16,7 +17,9 @@ def setup_and_teardown(self, shared_model_manager: type[SharedModelManager]): def test_basic_ti( self, + shared_model_manager: type[SharedModelManager], hordelib_instance: HordeLib, + stable_diffusion_model_name_for_testing: str, ): assert True return @@ -39,7 +42,7 @@ def test_basic_ti( "atmospheric lighting, embedding:style-sylvamagic###(embedding:easynegative:0.5), embedding:bhands-neg", "ddim_steps": 20, "n_iter": 1, - "model": stable_diffusion_modelname_for_testing, + "model": stable_diffusion_model_name_for_testing, } pil_image = hordelib_instance.basic_inference(data) From 1c70940407e91e0a27bf1d7d7871ee8aeba31c12 Mon Sep 17 00:00:00 2001 From: db0 Date: Fri, 25 Aug 2023 14:15:04 +0200 Subject: [PATCH 06/15] wip test --- hordelib/horde.py | 20 ++++++++++++++++++++ hordelib/model_manager/hyper.py | 7 +++++++ hordelib/model_manager/ti.py | 2 +- tests/test_horde_ti.py | 9 ++++++--- 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/hordelib/horde.py b/hordelib/horde.py index 213da6c5..d817279d 100644 --- a/hordelib/horde.py +++ b/hordelib/horde.py @@ -287,6 +287,26 @@ def _final_pipeline_adjustments(self, payload, pipeline_data): if payload.get("control_type"): payload["control_type"] = HordeLib.CONTROLNET_IMAGE_PREPROCESSOR_MAP.get(payload.get("control_type")) + if payload.get("tis") and SharedModelManager.manager.ti: + # Remove any requested TIs that we don't have + for ti in payload.get("tis"): + # Determine the actual lora filename + if not SharedModelManager.manager.ti.is_local_model(str(ti["name"])): + adhoc_ti = SharedModelManager.manager.ti.fetch_adhoc_ti(str(ti["name"])) + if not adhoc_ti: + logger.info(f"Adhoc TI requested '{ti['name']}' could not be found in CivitAI. Ignoring!") + continue + ti_name = SharedModelManager.manager.lora.get_ti_name(str(ti["name"])) + if ti_name: + logger.debug(f"Found valid TI {ti_name}") + if SharedModelManager.manager.compvis is None: + raise RuntimeError("Cannot use TIs without compvis loaded!") + model_details = SharedModelManager.manager.compvis.get_model(payload["model"]) + # If the lora and model do not match baseline, we ignore the TI + if not SharedModelManager.manager.ti.do_baselines_match(ti_name, model_details): + logger.info(f"Skipped TI {ti_name} because its baseline does not match the model's") + continue + SharedModelManager.manager.ti.touch_ti(ti_name) # Setup controlnet if required # For LORAs we completely build the LORA section of the pipeline dynamically, as we have # to handle n LORA models which form chained nodes in the pipeline. diff --git a/hordelib/model_manager/hyper.py b/hordelib/model_manager/hyper.py index aebb3138..f382e459 100644 --- a/hordelib/model_manager/hyper.py +++ b/hordelib/model_manager/hyper.py @@ -20,6 +20,7 @@ from hordelib.model_manager.gfpgan import GfpganModelManager from hordelib.model_manager.lora import LoraModelManager from hordelib.model_manager.safety_checker import SafetyCheckerModelManager +from hordelib.model_manager.ti import TextualInversionModelManager from hordelib.settings import UserSettings MODEL_MANAGERS_TYPE_LOOKUP: dict[MODEL_CATEGORY_NAMES | str, type[BaseModelManager]] = { @@ -85,6 +86,12 @@ def lora(self) -> LoraModelManager | None: found_mm = self.get_mm_pointer(LoraModelManager) return found_mm if isinstance(found_mm, LoraModelManager) else None + @property + def ti(self) -> TextualInversionModelManager | None: + """The Textual Inversion model manager instance. Returns `None` if not loaded.""" + found_mm = self.get_mm_pointer(TextualInversionModelManager) + return found_mm if isinstance(found_mm, TextualInversionModelManager) else None + @property def blip(self) -> BlipModelManager | None: """The Blip model manager instance. Returns `None` if not loaded.""" diff --git a/hordelib/model_manager/ti.py b/hordelib/model_manager/ti.py index 6aa10125..df077610 100644 --- a/hordelib/model_manager/ti.py +++ b/hordelib/model_manager/ti.py @@ -198,7 +198,7 @@ def _parse_civitai_ti_data(self, item, adhoc=False): ti["name"] = Sanitizer.sanitise_model_name(item.get("name", "")) ti["orig_name"] = item.get("name", "") ti["id"] = item.get("id", 0) - ti["filename"] = f'{Sanitizer.sanitise_filename(ti["name"])}.safetensors' + ti["filename"] = f'{ti["id"]}.safetensors' ti["sha256"] = file.get("hashes", {}).get("SHA256") try: ti["size_kb"] = round(file.get("sizeKB", 0)) diff --git a/tests/test_horde_ti.py b/tests/test_horde_ti.py index 51dc8369..29f658f5 100644 --- a/tests/test_horde_ti.py +++ b/tests/test_horde_ti.py @@ -21,8 +21,6 @@ def test_basic_ti( hordelib_instance: HordeLib, stable_diffusion_model_name_for_testing: str, ): - assert True - return data = { "sampler_name": "k_euler", "cfg_scale": 8.0, @@ -39,7 +37,12 @@ def test_basic_ti( "return_control_map": False, "prompt": "Closeup portrait of a Lesotho teenage girl wearing a Seanamarena blanket, " "walking in a field of flowers, holding a bundle of flowers, detailed background, light rays, " - "atmospheric lighting, embedding:style-sylvamagic###(embedding:easynegative:0.5), embedding:bhands-neg", + "atmospheric lighting, embedding:7523###(embedding:7808:0.5), embedding:64870", + "tis": [ + {"name": 7523}, + {"name": 7808}, + {"name": 64870}, + ], "ddim_steps": 20, "n_iter": 1, "model": stable_diffusion_model_name_for_testing, From 3c5a3bf25740a2832e70af1636882b7aca87fe27 Mon Sep 17 00:00:00 2001 From: db0 Date: Fri, 25 Aug 2023 14:52:38 +0200 Subject: [PATCH 07/15] working TI test --- hordelib/horde.py | 10 ++++++---- hordelib/model_manager/hyper.py | 1 + tests/conftest.py | 2 ++ tests/model_managers/test_mm_ti.py | 2 ++ tests/test_horde_ti.py | 4 ++++ 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/hordelib/horde.py b/hordelib/horde.py index d817279d..29c7e496 100644 --- a/hordelib/horde.py +++ b/hordelib/horde.py @@ -87,6 +87,7 @@ class HordeLib: "prompt": {"datatype": str, "default": ""}, "negative_prompt": {"datatype": str, "default": ""}, "loras": {"datatype": list, "default": []}, + "tis": {"datatype": list, "default": []}, "ddim_steps": {"datatype": int, "min": 1, "max": 500, "default": 30}, "n_iter": {"datatype": int, "min": 1, "max": 100, "default": 1}, "model": {"datatype": str, "default": "stable_diffusion"}, @@ -286,23 +287,24 @@ def _final_pipeline_adjustments(self, payload, pipeline_data): # Remap controlnet models, horde to comfy if payload.get("control_type"): payload["control_type"] = HordeLib.CONTROLNET_IMAGE_PREPROCESSOR_MAP.get(payload.get("control_type")) - + logger.debug(payload.get("tis")) if payload.get("tis") and SharedModelManager.manager.ti: # Remove any requested TIs that we don't have for ti in payload.get("tis"): - # Determine the actual lora filename + # Determine the actual ti filename + logger.debug(ti) if not SharedModelManager.manager.ti.is_local_model(str(ti["name"])): adhoc_ti = SharedModelManager.manager.ti.fetch_adhoc_ti(str(ti["name"])) if not adhoc_ti: logger.info(f"Adhoc TI requested '{ti['name']}' could not be found in CivitAI. Ignoring!") continue - ti_name = SharedModelManager.manager.lora.get_ti_name(str(ti["name"])) + ti_name = SharedModelManager.manager.ti.get_ti_name(str(ti["name"])) if ti_name: logger.debug(f"Found valid TI {ti_name}") if SharedModelManager.manager.compvis is None: raise RuntimeError("Cannot use TIs without compvis loaded!") model_details = SharedModelManager.manager.compvis.get_model(payload["model"]) - # If the lora and model do not match baseline, we ignore the TI + # If the ti and model do not match baseline, we ignore the TI if not SharedModelManager.manager.ti.do_baselines_match(ti_name, model_details): logger.info(f"Skipped TI {ti_name} because its baseline does not match the model's") continue diff --git a/hordelib/model_manager/hyper.py b/hordelib/model_manager/hyper.py index f382e459..df854d5b 100644 --- a/hordelib/model_manager/hyper.py +++ b/hordelib/model_manager/hyper.py @@ -32,6 +32,7 @@ MODEL_CATEGORY_NAMES.gfpgan: GfpganModelManager, MODEL_CATEGORY_NAMES.safety_checker: SafetyCheckerModelManager, MODEL_CATEGORY_NAMES.lora: LoraModelManager, + MODEL_CATEGORY_NAMES.ti: TextualInversionModelManager, MODEL_CATEGORY_NAMES.blip: BlipModelManager, MODEL_CATEGORY_NAMES.clip: ClipModelManager, } diff --git a/tests/conftest.py b/tests/conftest.py index 15e1d61c..56c84ceb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -49,6 +49,7 @@ def shared_model_manager(hordelib_instance: HordeLib) -> Generator[type[SharedMo assert SharedModelManager.manager.gfpgan is not None assert SharedModelManager.manager.safety_checker is not None assert SharedModelManager.manager.lora is not None + assert SharedModelManager.manager.ti is not None assert SharedModelManager.manager.blip is not None assert SharedModelManager.manager.clip is not None @@ -102,6 +103,7 @@ def pytest_collection_modifyitems(items): "tests.test_payload_mapping", "test_shared_model_manager", "test_mm_lora", + "test_mm_ti", ] MODULES_TO_RUN_LAST = [ "test.test_blip", diff --git a/tests/model_managers/test_mm_ti.py b/tests/model_managers/test_mm_ti.py index 7aa4dd47..e9a45efb 100644 --- a/tests/model_managers/test_mm_ti.py +++ b/tests/model_managers/test_mm_ti.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import pytest @@ -35,6 +36,7 @@ def test_fetch_adhoc_ti(self): assert ti_key == "EasyNegative".lower() assert mmti.is_local_model("7808") assert mmti.get_ti_name("7808") == "EasyNegative" + assert Path(os.path.join(mmti.modelFolderPath, "7808.safetensors")).exists() mmti.stop_all() def test_adhoc_non_existing(self): diff --git a/tests/test_horde_ti.py b/tests/test_horde_ti.py index 29f658f5..b2977ad7 100644 --- a/tests/test_horde_ti.py +++ b/tests/test_horde_ti.py @@ -1,6 +1,7 @@ # test_horde_ti.py import os from datetime import datetime, timedelta +from pathlib import Path import pytest from PIL import Image @@ -50,6 +51,9 @@ def test_basic_ti( pil_image = hordelib_instance.basic_inference(data) assert pil_image is not None + assert ( + Path(os.path.join(shared_model_manager.manager.ti.modelFolderPath, "64870.safetensors")).exists() is True + ) img_filename = "ti_basic.png" pil_image.save(f"images/{img_filename}", quality=100) From 4fba77563cbfda2fe8a7bba528e3f9bc318eec1d Mon Sep 17 00:00:00 2001 From: tazlin Date: Fri, 25 Aug 2023 09:56:34 -0400 Subject: [PATCH 08/15] tests: quiet potentially unset error --- tests/test_horde_ti.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/test_horde_ti.py b/tests/test_horde_ti.py index b2977ad7..41795fb7 100644 --- a/tests/test_horde_ti.py +++ b/tests/test_horde_ti.py @@ -11,17 +11,14 @@ class TestHordeTI: - @pytest.fixture(autouse=True, scope="class") - def setup_and_teardown(self, shared_model_manager: type[SharedModelManager]): - assert shared_model_manager.manager.ti - yield - def test_basic_ti( self, shared_model_manager: type[SharedModelManager], hordelib_instance: HordeLib, stable_diffusion_model_name_for_testing: str, ): + assert shared_model_manager.manager.ti + data = { "sampler_name": "k_euler", "cfg_scale": 8.0, From 99700d1b63d7ecbf3c3a712ce314469ef6651dd8 Mon Sep 17 00:00:00 2001 From: tazlin Date: Fri, 25 Aug 2023 09:59:46 -0400 Subject: [PATCH 09/15] fix: TI/Lora model reference json now is now with other reference files --- hordelib/model_manager/base.py | 14 +++++++++----- hordelib/model_manager/lora.py | 4 +++- hordelib/model_manager/ti.py | 3 +++ 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/hordelib/model_manager/base.py b/hordelib/model_manager/base.py index d5db175e..77df5454 100644 --- a/hordelib/model_manager/base.py +++ b/hordelib/model_manager/base.py @@ -75,6 +75,7 @@ def __init__( *, download_reference: bool = False, model_category_name: MODEL_CATEGORY_NAMES = MODEL_CATEGORY_NAMES.default_models, + model_db_path: Path | None = None, ): """Create a new instance of this model manager. @@ -94,10 +95,14 @@ def __init__( self.tainted_models = [] self.pkg = importlib_resources.files("hordelib") # XXX Remove self.models_db_name = MODEL_DB_NAMES[model_category_name] - self.models_db_path = Path(get_hordelib_path()).joinpath( - "model_database/", - f"{self.models_db_name}.json", - ) + + if model_db_path: + self.models_db_path = model_db_path + else: + self.models_db_path = Path(get_hordelib_path()).joinpath( + "model_database/", + f"{self.models_db_name}.json", + ) self.cuda_available = torch.cuda.is_available() # XXX Remove? self.cuda_devices, self.recommended_gpu = self.get_cuda_devices() # XXX Remove? @@ -700,7 +705,6 @@ def download_file(self, url: str, filename: str) -> bool: remote_file_size = 0 while retries: - if os.path.exists(partial_pathname): # If file exists, find the size and append to it partial_size = os.path.getsize(partial_pathname) diff --git a/hordelib/model_manager/lora.py b/hordelib/model_manager/lora.py index 1dc434d4..490d0eed 100644 --- a/hordelib/model_manager/lora.py +++ b/hordelib/model_manager/lora.py @@ -72,12 +72,14 @@ def __init__( self._index_ids = {} self._index_orig_names = {} + models_db_path = LEGACY_REFERENCE_FOLDER.joinpath("lora.json").resolve() + super().__init__( model_category_name=MODEL_CATEGORY_NAMES.lora, download_reference=download_reference, + models_db_path=models_db_path, ) # FIXME (shift lora.json handling into horde_model_reference?) - self.models_db_path = LEGACY_REFERENCE_FOLDER.joinpath("lora.json").resolve() def loadModelDatabase(self, list_models=False): if self.model_reference: diff --git a/hordelib/model_manager/ti.py b/hordelib/model_manager/ti.py index df077610..3d7b176f 100644 --- a/hordelib/model_manager/ti.py +++ b/hordelib/model_manager/ti.py @@ -10,6 +10,7 @@ from collections import deque from datetime import datetime, timedelta from enum import auto +from horde_model_reference import LEGACY_REFERENCE_FOLDER import requests from fuzzywuzzy import fuzz @@ -63,9 +64,11 @@ def __init__( self._index_ids = {} self._index_orig_names = {} + models_db_path = LEGACY_REFERENCE_FOLDER.joinpath("ti.json").resolve() super().__init__( model_category_name=MODEL_CATEGORY_NAMES.ti, download_reference=download_reference, + model_db_path=models_db_path, ) def loadModelDatabase(self, list_models=False): From cabb7a8c10a9ecd81faa63408b60fca14297491b Mon Sep 17 00:00:00 2001 From: tazlin Date: Fri, 25 Aug 2023 10:01:39 -0400 Subject: [PATCH 10/15] docs: moved class inline comments to docstrings --- hordelib/model_manager/lora.py | 10 +++++++--- hordelib/model_manager/ti.py | 21 +++++++++++++-------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/hordelib/model_manager/lora.py b/hordelib/model_manager/lora.py index 490d0eed..cd648cc3 100644 --- a/hordelib/model_manager/lora.py +++ b/hordelib/model_manager/lora.py @@ -38,10 +38,14 @@ class LoraModelManager(BaseModelManager): LORA_API = "https://civitai.com/api/v1/models?types=LORA&sort=Highest%20Rated&primaryFileOnly=true" MAX_RETRIES = 10 if not TESTS_ONGOING else 1 MAX_DOWNLOAD_THREADS = 3 # max concurrent downloads - RETRY_DELAY = 5 # seconds - REQUEST_METADATA_TIMEOUT = 30 # seconds - REQUEST_DOWNLOAD_TIMEOUT = 300 # seconds + RETRY_DELAY = 5 + """The time to wait between retries in seconds""" + REQUEST_METADATA_TIMEOUT = 30 + """The time to wait for a response from the server in seconds""" + REQUEST_DOWNLOAD_TIMEOUT = 300 + """The time to wait for a response from the server in seconds""" THREAD_WAIT_TIME = 2 + """The time to wait between checking the download queue in seconds""" def __init__( self, diff --git a/hordelib/model_manager/ti.py b/hordelib/model_manager/ti.py index 3d7b176f..03bc3beb 100644 --- a/hordelib/model_manager/ti.py +++ b/hordelib/model_manager/ti.py @@ -33,14 +33,19 @@ class DOWNLOAD_SIZE_CHECK(StrEnum): class TextualInversionModelManager(BaseModelManager): - TI_API = "https://civitai.com/api/v1/models?types=TextualInversion&sort=Highest%20Rated&primaryFileOnly=true" - HORDELING_API = "https://hordeling.aihorde.net/api/v1/embedding" - MAX_RETRIES = 10 if not TESTS_ONGOING else 1 - MAX_DOWNLOAD_THREADS = 3 # max concurrent downloads - RETRY_DELAY = 5 # seconds - REQUEST_METADATA_TIMEOUT = 30 # seconds - REQUEST_DOWNLOAD_TIMEOUT = 30 # seconds - THREAD_WAIT_TIME = 2 + TI_API: str = "https://civitai.com/api/v1/models?types=TextualInversion&sort=Highest%20Rated&primaryFileOnly=true" + HORDELING_API: str = "https://hordeling.aihorde.net/api/v1/embedding" + MAX_RETRIES: int = 10 if not TESTS_ONGOING else 1 + MAX_DOWNLOAD_THREADS: int = 3 + """The number of threads to use for downloading (the max number of concurrent downloads).""" + RETRY_DELAY: int = 5 + """The time to wait between retries in seconds""" + REQUEST_METADATA_TIMEOUT: int = 30 + """The time to wait for a metadata request to complete in seconds""" + REQUEST_DOWNLOAD_TIMEOUT: int = 30 + """The time to wait for a download request to complete in seconds""" + THREAD_WAIT_TIME: int = 2 + """The time to wait between checking the download queue in seconds""" def __init__( self, From 2da8c51462d78759c919709fa3fa9cf44752bbd7 Mon Sep 17 00:00:00 2001 From: tazlin Date: Fri, 25 Aug 2023 10:05:20 -0400 Subject: [PATCH 11/15] style: lint violation (unsorted import) --- hordelib/model_manager/ti.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hordelib/model_manager/ti.py b/hordelib/model_manager/ti.py index 03bc3beb..fe6692f8 100644 --- a/hordelib/model_manager/ti.py +++ b/hordelib/model_manager/ti.py @@ -10,10 +10,10 @@ from collections import deque from datetime import datetime, timedelta from enum import auto -from horde_model_reference import LEGACY_REFERENCE_FOLDER import requests from fuzzywuzzy import fuzz +from horde_model_reference import LEGACY_REFERENCE_FOLDER from loguru import logger from strenum import StrEnum from typing_extensions import override From 77c70b0c7598915dbe3e0db279ac4f6215bbfc9f Mon Sep 17 00:00:00 2001 From: tazlin Date: Fri, 25 Aug 2023 10:12:53 -0400 Subject: [PATCH 12/15] fix: variable name typo (missing an intended `s`) --- hordelib/model_manager/base.py | 6 +++--- hordelib/model_manager/ti.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/hordelib/model_manager/base.py b/hordelib/model_manager/base.py index 77df5454..4741c676 100644 --- a/hordelib/model_manager/base.py +++ b/hordelib/model_manager/base.py @@ -75,7 +75,7 @@ def __init__( *, download_reference: bool = False, model_category_name: MODEL_CATEGORY_NAMES = MODEL_CATEGORY_NAMES.default_models, - model_db_path: Path | None = None, + models_db_path: Path | None = None, ): """Create a new instance of this model manager. @@ -96,8 +96,8 @@ def __init__( self.pkg = importlib_resources.files("hordelib") # XXX Remove self.models_db_name = MODEL_DB_NAMES[model_category_name] - if model_db_path: - self.models_db_path = model_db_path + if models_db_path: + self.models_db_path = models_db_path else: self.models_db_path = Path(get_hordelib_path()).joinpath( "model_database/", diff --git a/hordelib/model_manager/ti.py b/hordelib/model_manager/ti.py index fe6692f8..bea3bacd 100644 --- a/hordelib/model_manager/ti.py +++ b/hordelib/model_manager/ti.py @@ -73,7 +73,7 @@ def __init__( super().__init__( model_category_name=MODEL_CATEGORY_NAMES.ti, download_reference=download_reference, - model_db_path=models_db_path, + models_db_path=models_db_path, ) def loadModelDatabase(self, list_models=False): From 5ecfacf2186060b6793549c8c0b461fa10424bab Mon Sep 17 00:00:00 2001 From: db0 Date: Fri, 25 Aug 2023 15:58:32 +0200 Subject: [PATCH 13/15] working ti ibjects --- hordelib/horde.py | 13 ++++++++++-- hordelib/model_manager/ti.py | 8 ++++++++ tests/test_horde_ti.py | 40 ++++++++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 2 deletions(-) diff --git a/hordelib/horde.py b/hordelib/horde.py index 29c7e496..6e028023 100644 --- a/hordelib/horde.py +++ b/hordelib/horde.py @@ -287,12 +287,10 @@ def _final_pipeline_adjustments(self, payload, pipeline_data): # Remap controlnet models, horde to comfy if payload.get("control_type"): payload["control_type"] = HordeLib.CONTROLNET_IMAGE_PREPROCESSOR_MAP.get(payload.get("control_type")) - logger.debug(payload.get("tis")) if payload.get("tis") and SharedModelManager.manager.ti: # Remove any requested TIs that we don't have for ti in payload.get("tis"): # Determine the actual ti filename - logger.debug(ti) if not SharedModelManager.manager.ti.is_local_model(str(ti["name"])): adhoc_ti = SharedModelManager.manager.ti.fetch_adhoc_ti(str(ti["name"])) if not adhoc_ti: @@ -308,6 +306,17 @@ def _final_pipeline_adjustments(self, payload, pipeline_data): if not SharedModelManager.manager.ti.do_baselines_match(ti_name, model_details): logger.info(f"Skipped TI {ti_name} because its baseline does not match the model's") continue + ti_inject = ti.get("inject_ti") + ti_strength = ti.get("strength", 1.0) + if type(ti_strength) not in [float, int]: + ti_strength = 1.0 + ti_id = SharedModelManager.manager.ti.get_ti_id(str(ti["name"])) + if ti_inject == "prompt": + payload["prompt"] = f'(embedding:{ti_id}:{ti_strength}),{payload["prompt"]}' + elif ti_inject == "negprompt": + if "###" not in payload["prompt"]: + payload["prompt"] += "###" + payload["prompt"] = f'{payload["prompt"]},(embedding:{ti_id}:{ti_strength})' SharedModelManager.manager.ti.touch_ti(ti_name) # Setup controlnet if required # For LORAs we completely build the LORA section of the pipeline dynamically, as we have diff --git a/hordelib/model_manager/ti.py b/hordelib/model_manager/ti.py index bea3bacd..2544a0fa 100644 --- a/hordelib/model_manager/ti.py +++ b/hordelib/model_manager/ti.py @@ -484,6 +484,14 @@ def get_ti_name(self, model_name: str): return None return ti["name"] + def get_ti_id(self, model_name: str): + """Returns the civitai ti ID for the specified model_name search string + Returns None if ti name not found""" + ti = self.get_model(model_name) + if not ti: + return None + return ti["id"] + def get_ti_triggers(self, model_name: str): """Returns a list of triggers for a specified ti name Returns an empty list if no triggers are found diff --git a/tests/test_horde_ti.py b/tests/test_horde_ti.py index 41795fb7..d0d658a8 100644 --- a/tests/test_horde_ti.py +++ b/tests/test_horde_ti.py @@ -55,6 +55,46 @@ def test_basic_ti( img_filename = "ti_basic.png" pil_image.save(f"images/{img_filename}", quality=100) + def test_inject_ti( + self, + shared_model_manager: type[SharedModelManager], + hordelib_instance: HordeLib, + stable_diffusion_model_name_for_testing: str, + ): + data = { + "sampler_name": "k_euler", + "cfg_scale": 8.0, + "denoising_strength": 1.0, + "seed": 1312, + "height": 512, + "width": 512, + "karras": False, + "tiling": False, + "hires_fix": False, + "clip_skip": 1, + "control_type": None, + "image_is_control": False, + "return_control_map": False, + "prompt": "Closeup portrait of a Lesotho teenage girl wearing a Seanamarena blanket, " + "walking in a field of flowers, holding a bundle of flowers, detailed background, light rays, " + "atmospheric lighting", + "tis": [ + {"name": 7523, "inject_ti": "prompt", "strength": 0.5}, + {"name": 7808, "inject_ti": "negprompt", "strength": 0.5}, + {"name": 64870, "inject_ti": "negprompt", "strength": 0.5}, + ], + "ddim_steps": 20, + "n_iter": 1, + "model": stable_diffusion_model_name_for_testing, + } + + pil_image = hordelib_instance.basic_inference(data) + assert pil_image is not None + assert False + + img_filename = "ti_inject.png" + pil_image.save(f"images/{img_filename}", quality=100) + # assert check_single_lora_image_similarity( # f"images_expected/{img_filename}", # pil_image, From 2b70922798f99f43bfb21acf6cbf3de3acdda87f Mon Sep 17 00:00:00 2001 From: db0 Date: Fri, 25 Aug 2023 16:18:56 +0200 Subject: [PATCH 14/15] fix downloaded tis not checking correct hash from hordeling --- hordelib/model_manager/ti.py | 79 +++++++++++++++++++----------------- tests/test_horde_ti.py | 42 ++++++++++++++++++- 2 files changed, 82 insertions(+), 39 deletions(-) diff --git a/hordelib/model_manager/ti.py b/hordelib/model_manager/ti.py index 2544a0fa..fa8c7241 100644 --- a/hordelib/model_manager/ti.py +++ b/hordelib/model_manager/ti.py @@ -262,70 +262,73 @@ def _download_thread(self, thread_number): while retries <= self.MAX_RETRIES: try: # Just before we download this file, check if we already have it - filename = os.path.join(self.modelFolderPath, ti["filename"]) - hashfile = f"{os.path.splitext(filename)[0]}.sha256" - if os.path.exists(filename) and os.path.exists(hashfile): - # Check the hash - with open(hashfile) as infile: - try: - hashdata = infile.read().split()[0] - except (IndexError, OSError, IOError, PermissionError): - hashdata = "" - if not ti.get("sha256") or hashdata.lower() == ti["sha256"].lower(): - # we already have this ti, consider it downloaded - # the SHA256 might not exist when the ti has been selected in the curation list - # Where we allow them to skip it - if not ti.get("sha256"): - logger.debug( - f"Already have Textual Inversion {ti['filename']}. " - "Bypassing SHA256 check as there's none stored", - ) - else: - logger.debug(f"Already have Textual Inversion {ti['filename']}") - with self._mutex: - # We store as lower to allow case-insensitive search - ti["last_checked"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - self._add_ti_to_reference(ti) - if self.is_default_cache_full(): - self.stop_downloading = True - self.save_cached_reference_to_disk() - break - - logger.info(f"Starting download of Textual Inversion {ti['filename']}") + filepath = os.path.join(self.modelFolderPath, ti["filename"]) + hashpath = f"{os.path.splitext(filepath)[0]}.sha256" + logger.debug(f"Retrieving TI metadata from Hordeling for ID: {ti['filename']}") hordeling_response = requests.get(f"{self.HORDELING_API}/{ti['id']}", timeout=5) if not hordeling_response.ok: if hordeling_response.status_code == 404: - logger.debug(f"Textual Inversion {ti['filename']} could not be found on AI Hordeling.") + logger.debug(f"Textual Inversion: {ti['filename']} could not be found on AI Hordeling.") break # We will retry logger.debug( - f"AI Horeling reported error when downloading {ti['filename']}: " + "AI Horeling reported error when downloading metadata " + f"for Textual Inversion: {ti['filename']}: " f"{hordeling_response.json()}" f"Retry {retries}/{self.MAX_RETRIES}", ) else: hordeling_json = hordeling_response.json() + if os.path.exists(filepath) and os.path.exists(hashpath): + if hordeling_json.get("sha256"): + ti["sha256"] = hordeling_json["sha256"] + # Check the hash + with open(hashpath) as infile: + try: + hashdata = infile.read().split()[0] + except (IndexError, OSError, IOError, PermissionError): + hashdata = "" + + if not ti.get("sha256") or hashdata.lower() == ti["sha256"].lower(): + # we already have this ti, consider it downloaded + # the SHA256 might not exist when the ti has been selected in the curation list + # Where we allow them to skip it + if not ti.get("sha256"): + logger.debug( + f"Already have Textual Inversion: {ti['filename']}. " + "Bypassing SHA256 check as there's none stored", + ) + else: + logger.debug(f"Already have Textual Inversion: {ti['filename']}") + with self._mutex: + # We store as lower to allow case-insensitive search + ti["last_checked"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + self._add_ti_to_reference(ti) + if self.is_default_cache_full(): + self.stop_downloading = True + self.save_cached_reference_to_disk() + break + + logger.info(f"Starting download of Textual Inversion: {ti['filename']}") response = requests.get( hordeling_json["url"], timeout=self.REQUEST_DOWNLOAD_TIMEOUT, ) response.raise_for_status() # Check the data hash - if hordeling_json.get("sha256"): - ti["sha256"] = hordeling_json["sha256"] hash_object = hashlib.sha256() hash_object.update(response.content) sha256 = hash_object.hexdigest() if not ti.get("sha256") or sha256.lower() == ti["sha256"].lower(): # wow, we actually got a valid file, save it - with open(filename, "wb") as outfile: + with open(filepath, "wb") as outfile: outfile.write(response.content) # Save the hash file - with open(hashfile, "wt") as outfile: + with open(hashpath, "wt") as outfile: outfile.write(f"{sha256} *{ti['filename']}") # Shout about it - logger.info(f"Downloaded Textual Inversion {ti['filename']} ({ti['size_kb']} KB)") + logger.info(f"Downloaded Textual Inversion: {ti['filename']} ({ti['size_kb']} KB)") # Maybe we're done with self._mutex: # We store as lower to allow case-insensitive search @@ -339,7 +342,7 @@ def _download_thread(self, thread_number): else: # We will retry logger.debug( - f"Downloaded Textual Inversion file for {ti['filename']} didn't match hash. " + f"Downloaded Textual Inversion file {ti['filename']} didn't match hash. " f"Retry {retries}/{self.MAX_RETRIES}", ) diff --git a/tests/test_horde_ti.py b/tests/test_horde_ti.py index d0d658a8..4aef6f29 100644 --- a/tests/test_horde_ti.py +++ b/tests/test_horde_ti.py @@ -90,7 +90,6 @@ def test_inject_ti( pil_image = hordelib_instance.basic_inference(data) assert pil_image is not None - assert False img_filename = "ti_inject.png" pil_image.save(f"images/{img_filename}", quality=100) @@ -99,3 +98,44 @@ def test_inject_ti( # f"images_expected/{img_filename}", # pil_image, # ) + + def test_bad_inject_ti( + self, + shared_model_manager: type[SharedModelManager], + hordelib_instance: HordeLib, + stable_diffusion_model_name_for_testing: str, + ): + data = { + "sampler_name": "k_euler", + "cfg_scale": 8.0, + "denoising_strength": 1.0, + "seed": 1312, + "height": 512, + "width": 512, + "karras": False, + "tiling": False, + "hires_fix": False, + "clip_skip": 1, + "control_type": None, + "image_is_control": False, + "return_control_map": False, + "prompt": "Closeup portrait of a Lesotho teenage girl wearing a Seanamarena blanket, " + "walking in a field of flowers, holding a bundle of flowers, detailed background, light rays, " + "atmospheric lighting", + "tis": [ + {"name": 7523, "inject_ti": "prompt", "strength": "0.5"}, + {"name": 7808, "inject_ti": "negprompt", "strength": None}, + {"name": 64870, "inject_ti": "YOLO", "strength": "YOLO"}, + ], + "ddim_steps": 20, + "n_iter": 1, + "model": stable_diffusion_model_name_for_testing, + } + + pil_image = hordelib_instance.basic_inference(data) + assert pil_image is not None + + # assert check_single_lora_image_similarity( + # f"images_expected/{img_filename}", + # pil_image, + # ) From df2f7016d0253db88671caa3f1ad4cb000dfdfd5 Mon Sep 17 00:00:00 2001 From: tazlin Date: Fri, 25 Aug 2023 11:45:33 -0400 Subject: [PATCH 15/15] fix: embeddings now directly notifies comfy --- hordelib/comfy_horde.py | 11 +++++++++-- hordelib/model_manager/compvis.py | 6 ++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/hordelib/comfy_horde.py b/hordelib/comfy_horde.py index d88e9850..edf6842e 100644 --- a/hordelib/comfy_horde.py +++ b/hordelib/comfy_horde.py @@ -325,11 +325,18 @@ def __init__(self) -> None: self._callers = 0 self._gc_timer = time.time() self._counter_mutex = threading.Lock() - # FIXME Temporary hack to set the model dir for LORAs - _comfy_folder_paths["loras"] = ( # type: ignore + + # These set the default paths for comfyui to look for models and embeddings. From within hordelib, + # we aren't ever going to use the default ones, and this may help lubricate any further changes. + _comfy_folder_paths["lora"] = ( # type: ignore [os.path.join(UserSettings.get_model_directory(), "loras")], [".safetensors"], ) + _comfy_folder_paths["embeddings"] = ( # type: ignore + [os.path.join(UserSettings.get_model_directory(), "ti")], + [".safetensors"], + ) + # Set custom node path _comfy_folder_paths["custom_nodes"] = ([os.path.join(get_hordelib_path(), "nodes")], []) # type: ignore # Load our pipelines diff --git a/hordelib/model_manager/compvis.py b/hordelib/model_manager/compvis.py index 76e7c436..d4b14cf8 100644 --- a/hordelib/model_manager/compvis.py +++ b/hordelib/model_manager/compvis.py @@ -36,7 +36,9 @@ def modelToRam( model_name: str, **kwargs, ) -> dict[str, typing.Any]: - embeddings_path = os.getenv("HORDE_MODEL_DIR_EMBEDDINGS", "./") + embeddings_path = os.path.join(UserSettings.get_model_directory(), "ti") + if not embeddings_path: + logger.debug("No embeddings path found, disabling embeddings") if not kwargs.get("local", False): ckpt_path = self.getFullModelPath(model_name) @@ -44,7 +46,7 @@ def modelToRam( ckpt_path = os.path.join(self.modelFolderPath, model_name) return horde_load_checkpoint( ckpt_path=ckpt_path, - embeddings_path=embeddings_path, + embeddings_path=embeddings_path if embeddings_path else None, ) def can_cache_on_disk(self):