Skip to content

Commit

Permalink
working versioned lora test
Browse files Browse the repository at this point in the history
  • Loading branch information
db0 committed Dec 25, 2023
1 parent 453a6e6 commit 4bfd5db
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 35 deletions.
41 changes: 29 additions & 12 deletions hordelib/horde.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,10 +459,14 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
valid_loras = []
for lora in payload.get("loras"):
# Determine the actual lora filename
if not SharedModelManager.manager.lora.is_model_available(str(lora["name"])):
is_version: bool = lora.get("is_version", False)
verstext = ""
if is_version:
verstext = " version"
if not SharedModelManager.manager.lora.is_lora_available(str(lora["name"]), is_version):
logger.debug(f"Adhoc lora requested '{lora['name']}' not yet downloaded. Downloading...")
try:
adhoc_lora = SharedModelManager.manager.lora.fetch_adhoc_lora(str(lora["name"]))
adhoc_lora = SharedModelManager.manager.lora.fetch_adhoc_lora(str(lora["name"]), is_version)
except Exception as e:
logger.info(f"Error fetching adhoc lora {lora['name']}: ({type(e).__name__}) {e}")
faults.append(
Expand All @@ -474,7 +478,10 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
)
adhoc_lora = None
if not adhoc_lora:
logger.info(f"Adhoc lora requested '{lora['name']}' could not be found in CivitAI. Ignoring!")
logger.info(
f"Adhoc lora requested{verstext} '{lora['name']} "
"could not be found in CivitAI. Ignoring!",
)
faults.append(
GenMetadataEntry(
type=METADATA_TYPE.lora,
Expand All @@ -484,15 +491,21 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
)
continue
# We store the actual lora name to search for the trigger
lora_name = SharedModelManager.manager.lora.get_lora_name(str(lora["name"]))
# If a version is requested, the lora name we need is the exact version
if is_version:
lora_name = str(lora["name"])
else:
lora_name = SharedModelManager.manager.lora.get_lora_name(str(lora["name"]))
if lora_name:
logger.debug(f"Found valid lora {lora_name}")
logger.debug(f"Found valid lora{verstext} {lora_name}")
if SharedModelManager.manager.compvis is None:
raise RuntimeError("Cannot use LORAs without a compvis loaded!")
model_details = SharedModelManager.manager.compvis.get_model_reference_info(payload["model"])
# If the lora and model do not match baseline, we ignore the lora
if not SharedModelManager.manager.lora.do_baselines_match(lora_name, model_details):
logger.info(f"Skipped lora {lora_name} because its baseline does not match the model's")
if not SharedModelManager.manager.lora.do_baselines_match(lora_name, model_details, is_version):
logger.info(
f"Skipped lora{verstext} {lora_name} because its baseline does not match the model's",
)
faults.append(
GenMetadataEntry(
type=METADATA_TYPE.lora,
Expand All @@ -504,21 +517,25 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
trigger_inject = lora.get("inject_trigger")
trigger = None
if trigger_inject == "any":
triggers = SharedModelManager.manager.lora.get_lora_triggers(lora_name)
triggers = SharedModelManager.manager.lora.get_lora_triggers(lora_name, is_version)
if triggers:
trigger = random.choice(triggers)
elif trigger_inject == "all":
triggers = SharedModelManager.manager.lora.get_lora_triggers(lora_name)
triggers = SharedModelManager.manager.lora.get_lora_triggers(lora_name, is_version)
if triggers:
trigger = ", ".join(triggers)
elif trigger_inject is not None:
trigger = SharedModelManager.manager.lora.find_lora_trigger(lora_name, trigger_inject)
trigger = SharedModelManager.manager.lora.find_lora_trigger(
lora_name,
trigger_inject,
is_version,
)
if trigger:
# We inject at the start, to avoid throwing it in a negative prompt
payload["prompt"] = f'{trigger}, {payload["prompt"]}'
# the fixed up and validated filename (Comfy expect the "name" key to be the filename)
lora["name"] = SharedModelManager.manager.lora.get_lora_filename(lora_name)
SharedModelManager.manager.lora._touch_lora(lora_name)
lora["name"] = SharedModelManager.manager.lora.get_lora_filename(lora_name, is_version)
SharedModelManager.manager.lora._touch_lora(lora_name, is_version)
valid_loras.append(lora)
payload["loras"] = valid_loras
for lora_index, lora in enumerate(payload.get("loras")):
Expand Down
36 changes: 13 additions & 23 deletions hordelib/model_manager/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def load_model_database(self) -> None:
lora_key = Sanitizer.sanitise_model_name(lora["orig_name"]).lower().strip()
lora_filename = f'{Sanitizer.sanitise_filename(lora["orig_name"])}_{lora["version_id"]}'
version = {
"inject": lora_filename,
"filename": f"{lora_filename}.safetensors",
"sha256": lora["sha256"],
"adhoc": lora.get("adhoc"),
Expand Down Expand Up @@ -328,7 +327,6 @@ def _parse_civitai_lora_data(self, item, adhoc=False):
lora["nsfw"] = lora_nsfw
lora["versions"] = {}
lora["versions"][lora_version] = {}
lora["versions"][lora_version]["inject"] = lora_filename
lora["versions"][lora_version]["filename"] = f"{lora_filename}.safetensors"
lora["versions"][lora_version]["sha256"] = file.get("hashes", {}).get("SHA256")
lora["versions"][lora_version]["adhoc"] = adhoc
Expand Down Expand Up @@ -708,25 +706,6 @@ def get_lora_triggers(self, model_name: str, is_version: bool = False):
# and then we keep returning previous items
return []

def get_lora_prompt_inject(self, model_name: str | int, is_version: bool = False) -> str | None:
    """Return the string to inject into the prompt to load this lora.

    Args:
        model_name: The lora name, or the exact version identifier when
            ``is_version`` is True.
        is_version: When True, ``model_name`` refers to a specific lora
            version rather than the lora itself.

    Returns:
        The prompt-injection value stored under the version's ``inject``
        key, or ``None`` if the lora (or the requested version) is not
        found.
    """
    # NOTE: the original docstring claimed "Returns an empty list if no
    # triggers are found" — copy-pasted from get_lora_triggers(); this
    # method only ever returns a string or None.
    lora = self.get_model_reference_info(model_name, is_version)
    if not lora:
        return None
    if is_version:
        version = self.ensure_is_version(model_name)
        if version is None:
            return None
        return lora["versions"][version].get("inject")
    lora_version = self.get_latest_version(lora)
    if lora_version is None:
        return None
    return lora_version.get("inject")

def find_lora_trigger(self, model_name: str, trigger_search: str, is_version: bool = False):
"""Searches for a specific trigger for a specified lora name
Returns None if string not found even with fuzzy search"""
Expand Down Expand Up @@ -985,12 +964,18 @@ def get_lora_last_use(self, lora_name, is_version: bool = False):
if not lora:
logger.warning(f"Could not find lora {lora_name} to get last use")
return None
return datetime.strptime(lora["last_used"], "%Y-%m-%d %H:%M:%S")
if is_version:
version = self.ensure_is_version(lora_name)
else:
version = self.find_latest_version(lora)
if version is None:
return None
return datetime.strptime(lora["versions"][version]["last_used"], "%Y-%m-%d %H:%M:%S")

def fetch_adhoc_lora(self, lora_name, timeout=45, is_version: bool = False):
"""Checks if a LoRa is available
If not, immediately downloads it
Finally, returns a tuple with the lora key and text to inject into the prompt for comfyUI to load the lora"""
Finally, returns the lora name"""
if isinstance(lora_name, str):
if lora_name in self.model_reference:
self._touch_lora(lora_name)
Expand Down Expand Up @@ -1066,3 +1051,8 @@ def is_model_available(self, model_name):
return False
self._touch_lora(found_model_name)
return True

def is_lora_available(self, lora_name: str, is_version: bool = False):
    """Report whether the named lora — or, with ``is_version``, the exact
    lora version — is present in the local model reference."""
    if not is_version:
        return self.is_model_available(lora_name)
    # Version lookup: normalize the identifier and check the version index.
    return self.ensure_is_version(lora_name) in self._index_version_ids
48 changes: 48 additions & 0 deletions tests/test_horde_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,54 @@ def test_download_and_use_adhoc_lora(
pil_image,
)

def test_download_and_use_specific_version_lora(
    self,
    shared_model_manager: type[SharedModelManager],
    hordelib_instance: HordeLib,
    stable_diffusion_model_name_for_testing: str,
):
    """Run inference with a lora requested by explicit version id ("is_version": True)."""
    assert shared_model_manager.manager.lora

    # Presumably 26975 is a version of lora 22591 — delete the parent so the
    # specific version must be fetched adhoc; confirm against CivitAI if ids rotate.
    version_id = "26975"
    shared_model_manager.manager.lora.ensure_lora_deleted("22591")

    prompt = (
        "a skull in a bottle with colorful liquid, painting of one health potion, "
        "colorful concept art, potion of healing, potion, fantasy game art style, "
        "colorfull illustration, magic potions, concept art design illustration, "
        "alchemy concept, 3 d epic illustrations, hyper realistic poison bottle, "
        "magical potions, health potion, detailed game art illustration, by Justin Gerard"
    )
    lora_entry = {"name": version_id, "model": 1.0, "clip": 1.0, "inject_trigger": "any", "is_version": True}
    data = {
        "sampler_name": "k_euler",
        "cfg_scale": 8.0,
        "denoising_strength": 1.0,
        "seed": 1312,
        "height": 512,
        "width": 512,
        "karras": False,
        "tiling": False,
        "hires_fix": False,
        "clip_skip": 1,
        "control_type": None,
        "image_is_control": False,
        "return_control_map": False,
        "prompt": prompt,
        "loras": [lora_entry],
        "ddim_steps": 20,
        "n_iter": 1,
        "model": stable_diffusion_model_name_for_testing,
    }

    result = hordelib_instance.basic_inference_single_image(data).image
    assert result is not None
    assert isinstance(result, Image.Image)

    # Persist and compare against the checked-in expected image.
    img_filename = "lora_version_adhoc.png"
    result.save(f"images/{img_filename}", quality=100)
    assert check_single_lora_image_similarity(
        f"images_expected/{img_filename}",
        result,
    )

def test_for_probability_tensor_runtime_error(
self,
hordelib_instance: HordeLib,
Expand Down

0 comments on commit 4bfd5db

Please sign in to comment.