
Commit

wip
db0 committed Dec 19, 2023
1 parent 19ef35c commit 1f08f32
Showing 2 changed files with 38 additions and 28 deletions.
58 changes: 34 additions & 24 deletions hordelib/model_manager/lora.py
@@ -118,15 +118,16 @@ def load_model_database(self) -> None:
if self.models_db_path.exists():
try:
self.model_reference = json.loads((self.models_db_path).read_text())

new_model_reference = {}
for old_lora_key in self.model_reference.keys():
lora = self.model_reference[old_lora_key]
if "versions" not in lora:
logger.info(lora)
lora_key = lora["name"].lower().strip()
lora_key = Sanitizer.sanitise_model_name(lora["orig_name"]).lower().strip()
lora_filename = f'{Sanitizer.sanitise_filename(lora["orig_name"])}_{lora["version_id"]}'
version = {
"inject": f'{lora_key}_{lora["version_id"]}',
"filename": f'{lora_key}_{lora["version_id"]}.safetensors',
"inject": lora_filename,
"filename": f"{lora_filename}.safetensors",
"sha256": lora["sha256"],
"adhoc": lora.get("adhoc"),
"size_mb": lora["size_mb"],
@@ -147,27 +148,29 @@ def load_model_database(self) -> None:
new_hashfile = f"{os.path.splitext(new_filename)[0]}.sha256"
os.rename(old_hashfile, new_hashfile)
new_lora_entry = {
"name": lora["name"].lower().strip(),
"name": lora_key,
"orig_name": lora["orig_name"],
"id": lora["id"],
"nsfw": lora["nsfw"],
"versions": {lora["version_id"]: version},
}
del self.model_reference[old_lora_key]
self.model_reference[lora_key] = new_lora_entry
for lora in self.model_reference.values():
new_model_reference[lora_key] = new_lora_entry
else:
new_model_reference[old_lora_key] = lora
for lora in new_model_reference.values():
self._index_ids[lora["id"]] = lora["name"].lower()
orig_name = lora.get("orig_name", lora["name"]).lower()
self._index_orig_names[orig_name] = lora["name"].lower()

self.model_reference = new_model_reference
logger.info("Loaded model reference from disk.")
except json.JSONDecodeError:
logger.error(f"Could not load {self.models_db_name} model reference from disk! Bad JSON?")
self.model_reference = {}
else:
logger.error(f"Could not load {self.models_db_name} model reference from disk! File not found.")
self.model_reference = {}

with self._mutex:
self.save_cached_reference_to_disk()
logger.info(
(f"Got {len(self.model_reference)} models for {self.models_db_name}.",),
)
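
For context: this migration rewrites old single-version lora entries into the new layout, where everything version-specific lives under a `versions` dict keyed by the CivitAI version id, and the dictionary key becomes the sanitised, lowercased model name. A rough before/after sketch of one entry, inferred from the fields touched in this hunk (the ids, hash and size shown are placeholder values, not real CivitAI data):

```python
# Sketch only; values are placeholders inferred from the fields used in this diff.

# Old-style entry (detected by `"versions" not in lora`):
old_entry = {
    "name": "GlowingRunesAI",
    "orig_name": "GlowingRunesAI",
    "id": 25924,           # placeholder model id
    "version_id": 31284,   # placeholder version id
    "nsfw": False,
    "sha256": "ab12...",   # placeholder
    "size_mb": 144,        # placeholder
}

# Migrated entry: the key and "name" become the sanitised, lowercased model name,
# while the version-specific data moves under "versions", keyed by version id.
new_entry = {
    "name": "glowingrunesai",
    "orig_name": "GlowingRunesAI",
    "id": 25924,
    "nsfw": False,
    "versions": {
        31284: {
            "inject": "GlowingRunesAI_31284",
            "filename": "GlowingRunesAI_31284.safetensors",
            "sha256": "ab12...",
            "adhoc": None,
            "size_mb": 144,
            # ...plus the remaining per-version fields (e.g. baseModel, version_id)
        },
    },
}
```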
@@ -316,15 +319,16 @@ def _parse_civitai_lora_data(self, item, adhoc=False):
for file in version.get("files", {}):
if file.get("primary", False) and file.get("name", "").endswith(".safetensors"):
sanitized_name = Sanitizer.sanitise_model_name(lora_name)
lora_name = sanitized_name.lower().strip()
lora["name"] = sanitized_name.lower()
lora_filename = f'{Sanitizer.sanitise_filename(lora_name)}_{version.get("id", 0)}'
lora_key = sanitized_name.lower().strip()
lora["name"] = lora_key
lora["orig_name"] = lora_name
lora["id"] = lora_id
lora["nsfw"] = lora_nsfw
lora["versions"] = {}
lora["versions"][lora_version] = {}
lora["versions"][lora_version]["inject"] = f'{lora_name}_{version.get("id", 0)}'
lora["versions"][lora_version]["filename"] = f'{lora_name}_{version.get("id", 0)}.safetensors'
lora["versions"][lora_version]["inject"] = lora_filename
lora["versions"][lora_version]["filename"] = f"{lora_filename}.safetensors"
lora["versions"][lora_version]["sha256"] = file.get("hashes", {}).get("SHA256")
lora["versions"][lora_version]["adhoc"] = adhoc
try:
@@ -338,7 +342,7 @@ def _parse_civitai_lora_data(self, item, adhoc=False):
lora["versions"][lora_version]["baseModel"] = version.get("baseModel", "SD 1.5")
lora["versions"][lora_version]["version_id"] = version.get("id", 0)
# To be able to refer back to the parent if needed
lora["versions"][lora_version]["lora_key"] = lora_name
lora["versions"][lora_version]["lora_key"] = lora_key
break
# If we don't have everything required, fail
if lora["versions"][lora_version]["adhoc"] and not lora["versions"][lora_version].get("sha256"):
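
The parser now derives two distinct names: `lora_key`, the lowercased sanitised model name used as the dictionary key and as `lora["name"]`, and `lora_filename`, built from the sanitised original name plus the version id and used for the `inject` and `filename` fields. A hypothetical example; the Sanitizer output shown is an assumption for illustration, not taken from hordelib:

```python
lora_name = "GlowingRunesAI"
version_id = 31284                              # placeholder version id
lora_key = "glowingrunesai"                     # sanitise_model_name(lora_name).lower().strip()
lora_filename = f"GlowingRunesAI_{version_id}"  # sanitise_filename(lora_name) + version id (casing assumed preserved)
inject = lora_filename                          # "GlowingRunesAI_31284"
filename = f"{lora_filename}.safetensors"       # "GlowingRunesAI_31284.safetensors"
```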
@@ -535,15 +539,17 @@ def _add_lora_to_reference(self, lora):
self._adhoc_loras.add(lora_key)
# Once added to our set, we don't need to specify it was adhoc anymore
del lora["adhoc"]
# A lora coming to add to our reference should only have 1 version
# we do the version merge in this function
new_lora_version = list(lora["versions"].keys())[0]
if lora_key not in self.model_reference:
self.model_reference[lora_key] = lora
else:
# If we already know of one version of this lora, we simply add the extra version to our versions
new_lora_version = list(lora["versions"].keys())[0]
if new_lora_version not in self.model_reference[lora_key]["versions"]:
self.model_reference[lora_key]["versions"][new_lora_version] = lora["versions"][new_lora_version]
self._index_ids[lora["id"]] = lora_key
self._index_version_ids[lora["versions"][new_lora_version]] = lora_key
self._index_version_ids[new_lora_version] = lora_key
orig_name = lora.get("orig_name", lora["name"]).lower()
self._index_orig_names[orig_name] = lora_key
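
Two things are worth illustrating here: an incoming lora carries exactly one version and is merged into any existing entry for the same key, and the version index fix. The removed line used the whole version dict as an `_index_version_ids` key (dicts are unhashable, so that assignment raises TypeError), while the new line keys on the version id itself. A minimal standalone sketch, not hordelib code, with placeholder ids:

```python
model_reference: dict[str, dict] = {}
index_version_ids: dict[int, str] = {}

def add_lora_to_reference(lora: dict) -> None:
    lora_key = lora["name"]
    # An incoming lora is expected to carry exactly one version; merging happens here.
    new_version = list(lora["versions"].keys())[0]
    if lora_key not in model_reference:
        model_reference[lora_key] = lora
    elif new_version not in model_reference[lora_key]["versions"]:
        model_reference[lora_key]["versions"][new_version] = lora["versions"][new_version]
    # Index the (hashable) version id back to the parent key.
    index_version_ids[new_version] = lora_key

add_lora_to_reference({"name": "glowingrunesai", "versions": {31284: {"filename": "GlowingRunesAI_31284.safetensors"}}})
add_lora_to_reference({"name": "glowingrunesai", "versions": {42001: {"filename": "GlowingRunesAI_42001.safetensors"}}})
assert set(model_reference["glowingrunesai"]["versions"]) == {31284, 42001}
assert index_version_ids[42001] == "glowingrunesai"
```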

@@ -593,12 +599,15 @@ def are_download_threads_idle(self):
return False
return True

def fuzzy_find_lora_key(self, lora_name):
def fuzzy_find_lora_key(self, lora_name: int | str, is_version=False):
# sname = Sanitizer.remove_version(lora_name).lower()
if is_version:
if isinstance(lora_name, int) or lora_name.isdigit():
if int(lora_name) in self._index_version_ids:
return self._index_version_ids[int(lora_name)]
return None
logger.debug(f"Looking for lora {lora_name}")
if isinstance(lora_name, int) or lora_name.isdigit():
if int(lora_name) in self._index_version_ids:
return self._index_version_ids[int(lora_name)]
if int(lora_name) in self._index_ids:
return self._index_ids[int(lora_name)]
return None
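
`fuzzy_find_lora_key` now accepts a name or a numeric id, and the new `is_version` flag restricts the lookup to the version index. Illustrative calls; the first three mirror assertions in tests/model_managers/test_mm_lora.py, while the version id in the last call is a placeholder:

```python
# Assumes `lora_model_manager` is a configured LoraModelManager with these loras known, as in the tests.
lora_model_manager.fuzzy_find_lora_key(25995)         # numeric id -> "blindbox/da gai shi mang he"
lora_model_manager.fuzzy_find_lora_key("25995")       # a digit string works the same way
lora_model_manager.fuzzy_find_lora_key("大概是盲盒")    # fuzzy name match -> "blindbox/da gai shi mang he"
lora_model_manager.fuzzy_find_lora_key(31284, is_version=True)  # placeholder id; only _index_version_ids is consulted
```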
@@ -642,7 +651,8 @@ def get_lora_filename(self, model_name: str):

def get_lora_name(self, model_name: str):
"""Returns the actual lora name for the specified model_name search string
Returns None if lora name not found"""
Returns None if lora name not found
The lora name is effectively the key in the dictionary"""
lora = self.get_model_reference_info(model_name)
if not lora:
return None
@@ -689,7 +699,8 @@ def calculate_downloaded_loras(self, mode=DOWNLOAD_SIZE_CHECK.everything):
continue
if mode == DOWNLOAD_SIZE_CHECK.adhoc and lora not in self._adhoc_loras:
continue
total_size += self.model_reference[lora]["size_mb"]
for version in self.model_reference[lora]["versions"]:
total_size += self.model_reference[lora]["versions"][version]["size_mb"]
return total_size

def calculate_default_loras_cache(self):
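
Download-size accounting now sums `size_mb` per version instead of reading a single top-level value. A rough standalone equivalent of the new loop, ignoring the DOWNLOAD_SIZE_CHECK mode filters above and assuming the migrated `versions` layout:

```python
def downloaded_size_mb(model_reference: dict) -> float:
    # Sum the per-version sizes across every known lora.
    return sum(
        version["size_mb"]
        for lora in model_reference.values()
        for version in lora["versions"].values()
    )
```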
@@ -809,7 +820,7 @@ def reset_adhoc_loras(self):
sorted_items = []
for plora_key, plora in self._previous_model_reference.items():
for version in plora.get("versions", []):
unsorted_items.append(plora_key, version)
unsorted_items.append((plora_key, version))
try:
sorted_items = sorted(
unsorted_items,
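
The replaced line called `list.append` with two arguments, which raises TypeError at runtime; the fix appends a single `(plora_key, version)` tuple so the `sorted(...)` call below can order the pairs. For example (placeholder key and version id):

```python
unsorted_items: list[tuple[str, int]] = []
unsorted_items.append(("some_lora_key", 31284))   # one tuple argument: fine
# unsorted_items.append("some_lora_key", 31284)   # raises TypeError (append takes exactly one argument)
```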
@@ -978,7 +989,6 @@ def do_baselines_match(self, lora_name, model_details):
def is_model_available(self, model_name):
if model_name in self.model_reference:
return True

found_model_name = self.fuzzy_find_lora_key(model_name)
if found_model_name is None:
return False
8 changes: 4 additions & 4 deletions tests/model_managers/test_mm_lora.py
@@ -44,10 +44,10 @@ def test_fuzzy_search(self):
assert lora_model_manager.fuzzy_find_lora_key("Glowing Robots") is None
assert lora_model_manager.fuzzy_find_lora_key("GlowingRobots") is None
assert lora_model_manager.fuzzy_find_lora_key("GlowingRobotsAI") is None
assert lora_model_manager.fuzzy_find_lora_key("blindbox/大概是盲盒") == "blindbox/da gai shi mang he"
assert lora_model_manager.fuzzy_find_lora_key(25995) == "blindbox/da gai shi mang he"
assert lora_model_manager.fuzzy_find_lora_key("25995") == "blindbox/da gai shi mang he"
assert lora_model_manager.fuzzy_find_lora_key("大概是盲盒") == "blindbox/da gai shi mang he"
assert lora_model_manager.fuzzy_find_lora_key("blindbox/大概是盲盒") == "blindbox/da gai shi mang he"
lora_model_manager.stop_all()

def test_lora_search(self):
@@ -57,9 +57,9 @@ def test_lora_search(self):
)
lora_model_manager.download_default_loras()
lora_model_manager.wait_for_downloads(600)
assert lora_model_manager.get_lora_name("GlowingRunesAI") == "GlowingRunesAI"
assert lora_model_manager.get_lora_name("GlowingRunes") == "GlowingRunesAI"
assert lora_model_manager.get_lora_name("Glowing Runes") == "GlowingRunesAI"
assert lora_model_manager.get_lora_name("GlowingRunesAI") == "glowingrunesai"
assert lora_model_manager.get_lora_name("GlowingRunes") == "glowingrunesai"
assert lora_model_manager.get_lora_name("Glowing Runes") == "glowingrunesai"
assert len(lora_model_manager.get_lora_triggers("GlowingRunesAI")) > 1
# We can't rely on triggers not changing
assert lora_model_manager.find_lora_trigger("GlowingRunesAI", "blue") is not None
