Skip to content

Commit

Permalink
Merge pull request #171 from Haidra-Org/main
Browse files Browse the repository at this point in the history
fix: improve lora model management
  • Loading branch information
tazlin authored Jan 24, 2024
2 parents 5a6c746 + 75341ff commit 91ecd30
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 44 deletions.
96 changes: 52 additions & 44 deletions hordelib/horde.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,51 +502,59 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
lora_name = str(lora["name"])
else:
lora_name = SharedModelManager.manager.lora.get_lora_name(str(lora["name"]))
if lora_name:
logger.debug(f"Found valid lora{verstext} {lora_name}")
if SharedModelManager.manager.compvis is None:
raise RuntimeError("Cannot use LORAs without a compvis loaded!")
model_details = SharedModelManager.manager.compvis.get_model_reference_info(payload["model"])
# If the lora and model do not match baseline, we ignore the lora
if not SharedModelManager.manager.lora.do_baselines_match(
if lora_name is None:
faults.append(
GenMetadataEntry(
type=METADATA_TYPE.lora,
value=METADATA_VALUE.download_failed,
ref=lora_name,
),
)
continue
logger.debug(f"Found valid lora{verstext} {lora_name}")
if SharedModelManager.manager.compvis is None:
raise RuntimeError("Cannot use LORAs without a compvis loaded!")
model_details = SharedModelManager.manager.compvis.get_model_reference_info(payload["model"])
# If the lora and model do not match baseline, we ignore the lora
if not SharedModelManager.manager.lora.do_baselines_match(
lora_name,
model_details,
is_version=is_version,
):
logger.info(
f"Skipped lora{verstext} {lora_name} because its baseline does not match the model's",
)
faults.append(
GenMetadataEntry(
type=METADATA_TYPE.lora,
value=METADATA_VALUE.baseline_mismatch,
ref=lora_name,
),
)
continue
trigger_inject = lora.get("inject_trigger")
trigger = None
if trigger_inject == "any":
triggers = SharedModelManager.manager.lora.get_lora_triggers(lora_name, is_version=is_version)
if triggers:
trigger = random.choice(triggers)
elif trigger_inject == "all":
triggers = SharedModelManager.manager.lora.get_lora_triggers(lora_name, is_version=is_version)
if triggers:
trigger = ", ".join(triggers)
elif trigger_inject is not None:
trigger = SharedModelManager.manager.lora.find_lora_trigger(
lora_name,
model_details,
is_version=is_version,
):
logger.info(
f"Skipped lora{verstext} {lora_name} because its baseline does not match the model's",
)
faults.append(
GenMetadataEntry(
type=METADATA_TYPE.lora,
value=METADATA_VALUE.baseline_mismatch,
ref=lora_name,
),
)
continue
trigger_inject = lora.get("inject_trigger")
trigger = None
if trigger_inject == "any":
triggers = SharedModelManager.manager.lora.get_lora_triggers(lora_name, is_version=is_version)
if triggers:
trigger = random.choice(triggers)
elif trigger_inject == "all":
triggers = SharedModelManager.manager.lora.get_lora_triggers(lora_name, is_version=is_version)
if triggers:
trigger = ", ".join(triggers)
elif trigger_inject is not None:
trigger = SharedModelManager.manager.lora.find_lora_trigger(
lora_name,
trigger_inject,
is_version,
)
if trigger:
# We inject at the start, to avoid throwing it in a negative prompt
payload["prompt"] = f'{trigger}, {payload["prompt"]}'
# the fixed up and validated filename (Comfy expects the "name" key to be the filename)
lora["name"] = SharedModelManager.manager.lora.get_lora_filename(lora_name, is_version=is_version)
SharedModelManager.manager.lora._touch_lora(lora_name, is_version=is_version)
valid_loras.append(lora)
trigger_inject,
is_version,
)
if trigger:
# We inject at the start, to avoid throwing it in a negative prompt
payload["prompt"] = f'{trigger}, {payload["prompt"]}'
# the fixed up and validated filename (Comfy expects the "name" key to be the filename)
lora["name"] = SharedModelManager.manager.lora.get_lora_filename(lora_name, is_version=is_version)
SharedModelManager.manager.lora._touch_lora(lora_name, is_version=is_version)
valid_loras.append(lora)
payload["loras"] = valid_loras
for lora_index, lora in enumerate(payload.get("loras")):
# Inject a lora node (first lora)
Expand Down
18 changes: 18 additions & 0 deletions hordelib/model_manager/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,24 @@ def get_model_reference_info(self, model_name: str | int, is_version: bool = Fal
lora_key = self.find_lora_key_by_version(model_name)
if lora_key is None:
return None
if lora_key not in self.model_reference:
logger.warning(
"LoRa key still in versions but missing from our reference. Attempting to refresh indexes...",
)
self.reload_reference_from_disk()
self.save_cached_reference_to_disk()
lora_key = self.find_lora_key_by_version(model_name)
if lora_key not in self.model_reference:
logger.error(
"Missing lora still in the versions index. Something has gone wrong. "
"Please contact the developers.",
)
else:
logger.info(
"Reloaded LoRa reference and now lora properly cleared. "
"Will not attempt to download it again at this time.",
)
return None
return self.model_reference[self.find_lora_key_by_version(model_name)]
lora_name = self.fuzzy_find_lora_key(model_name)
if not lora_name:
Expand Down

0 comments on commit 91ecd30

Please sign in to comment.