diff --git a/hordelib/comfy_horde.py b/hordelib/comfy_horde.py
index 664ee3e7..b70535fd 100644
--- a/hordelib/comfy_horde.py
+++ b/hordelib/comfy_horde.py
@@ -19,6 +19,7 @@
 from pprint import pformat
 import requests
 import psutil
+from collections.abc import Callable
 import torch
 from loguru import logger
 
@@ -40,9 +41,10 @@
 # There may be other ways to skin this cat, but this strategy minimizes certain kinds of hassle.
 #
 # If you tamper with the code in this module to bring the imports out of the function, you may find that you have
-# broken, among other things, the ability of pytest to do its test discovery because you will have lost the ability for
-# modules which, directly or otherwise, import this module without having called `hordelib.initialise()`. Pytest
-# discovery will come across those imports, valiantly attempt to import them and fail with a cryptic error message.
+# broken, among myriad other things, the ability of pytest to do its test discovery because you will have lost the
+# ability for modules which, directly or otherwise, import this module without having called `hordelib.initialise()`.
+# Pytest discovery will come across those imports, valiantly attempt to import them and fail with a cryptic error
+# message.
 #
 # Correspondingly, you will find that to be an enormous hassle if you are are trying to leverage pytest in any
 # reasonability sophisticated way (outside of tox), and you will be forced to adopt solution below or something
@@ -51,17 +53,18 @@
 #
 # Keen readers may have noticed that the aforementioned issues could be side stepped by simply calling
 # `hordelib.initialise()` automatically, such as in `test/__init__.py` or in a `conftest.py`. You would be correct,
-# but that would be a terrible idea if you ever intended to make alterations to the patch file, as each time you
-# triggered pytest discovery which could be as frequently as *every time you save a file* (such as with VSCode), and
-# you would enter a situation where the patch was automatically being applied at times you may not intend.
+# but that would be a terrible idea as a general practice. It would mean that every time you saved a file in your
+# editor, a number of heavyweight operations would be triggered, such as loading comfyui, while pytest discovery runs
+# and that would cause slow and unpredictable behavior in your editor.
 #
-# This would be a nightmare to debug, as this author is able to attest to.
+# This would be a nightmare to debug, as this author is able to attest to and is the reason this wall of text exists.
 #
 # Further, if you are like myself, and enjoy type hints, you will find that any modules have this file in their import
 # chain will be un-importable in certain contexts and you would be unable to provide the relevant type hints.
 #
-# Having read this, I suggest you glance at the code in `hordelib.initialise()` to get a sense of what is going on
-# there, and if you're still confused, ask a hordelib dev who would be happy to share the burden of understanding.
+# Having exercised a herculean amount of focus to read this far, I suggest you also glance at the code in
+# `hordelib.initialise()` to get a sense of what is going on there, and if you're still confused, ask a hordelib dev
+# who would be happy to share the burden of understanding.
 
 _comfy_load_models_gpu: types.FunctionType
 _comfy_current_loaded_models: list = None  # type: ignore
@@ -76,15 +79,24 @@
 _comfy_load_checkpoint_guess_config: types.FunctionType
 
 _comfy_get_torch_device: types.FunctionType
+"""Will return the current torch device, typically the GPU."""
 _comfy_get_free_memory: types.FunctionType
+"""Will return the amount of free memory on the current torch device. This value can be misleading."""
 _comfy_get_total_memory: types.FunctionType
+"""Will return the total amount of memory on the current torch device."""
 _comfy_load_torch_file: types.FunctionType
 _comfy_model_loading: types.ModuleType
-_comfy_free_memory: types.FunctionType
-_comfy_cleanup_models: types.FunctionType
-_comfy_soft_empty_cache: types.FunctionType
+_comfy_free_memory: Callable[[float, torch.device, list], None]
+"""Will aggressively unload models from memory"""
+_comfy_cleanup_models: Callable[[bool], None]
+"""Will unload unused models from memory"""
+_comfy_soft_empty_cache: Callable[[bool], None]
+"""Triggers comfyui and torch to empty their caches"""
 
-_comfy_is_changed_cache_get: types.FunctionType
+_comfy_is_changed_cache_get: Callable
+_comfy_model_patcher_load: Callable
+
+_comfy_interrupt_current_processing: types.FunctionType
 
 _canny: types.ModuleType
 _hed: types.ModuleType
@@ -111,6 +123,9 @@ class InterceptHandler(logging.Handler):
 
     @logger.catch(default=True, reraise=True)
     def emit(self, record):
+        message = record.getMessage()
+        if "lowvram: loaded module regularly" in message:
+            return
         # Get corresponding Loguru level if it exists.
         try:
             level = logger.level(record.levelname).name
@@ -123,7 +138,7 @@ def emit(self, record):
             frame = frame.f_back
             depth += 1
 
-        logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
+        logger.opt(depth=depth, exception=record.exc_info).log(level, message)
 
 
 # ComfyUI uses stdlib logging, so we need to intercept it.
@@ -145,6 +160,8 @@ def do_comfy_import(
     global _comfy_free_memory, _comfy_cleanup_models, _comfy_soft_empty_cache
     global _canny, _hed, _leres, _midas, _mlsd, _openpose, _pidinet, _uniformer
 
+    global _comfy_interrupt_current_processing
+
     if disable_smart_memory:
         logger.info("Disabling smart memory")
         sys.argv.append("--disable-smart-memory")
@@ -173,23 +190,36 @@ def do_comfy_import(
     from execution import IsChangedCache
 
     global _comfy_is_changed_cache_get
-    _comfy_is_changed_cache_get = IsChangedCache.get  # type: ignore
+    _comfy_is_changed_cache_get = IsChangedCache.get
     IsChangedCache.get = IsChangedCache_get_hijack  # type: ignore
 
-    from folder_paths import folder_names_and_paths as _comfy_folder_names_and_paths  # type: ignore
-    from folder_paths import supported_pt_extensions as _comfy_supported_pt_extensions  # type: ignore
+    from folder_paths import folder_names_and_paths as _comfy_folder_names_and_paths
+    from folder_paths import supported_pt_extensions as _comfy_supported_pt_extensions
     from comfy.sd import load_checkpoint_guess_config as _comfy_load_checkpoint_guess_config
     from comfy.model_management import current_loaded_models as _comfy_current_loaded_models
-    from comfy.model_management import load_models_gpu as _comfy_load_models_gpu
+    from comfy.model_management import load_models_gpu
+
+    _comfy_load_models_gpu = load_models_gpu  # type: ignore
+    import comfy.model_management
+
+    comfy.model_management.load_models_gpu = _load_models_gpu_hijack
     from comfy.model_management import get_torch_device as _comfy_get_torch_device
     from comfy.model_management import get_free_memory as _comfy_get_free_memory
     from comfy.model_management import get_total_memory as _comfy_get_total_memory
     from comfy.model_management import free_memory as _comfy_free_memory
     from comfy.model_management import cleanup_models as _comfy_cleanup_models
     from comfy.model_management import soft_empty_cache as _comfy_soft_empty_cache
+    from comfy.model_management import interrupt_current_processing as _comfy_interrupt_current_processing
 
     from comfy.utils import load_torch_file as _comfy_load_torch_file
-    from comfy_extras.chainner_models import model_loading as _comfy_model_loading  # type: ignore
+    from comfy_extras.chainner_models import model_loading as _comfy_model_loading
+
+    from comfy.model_patcher import ModelPatcher
+
+    global _comfy_model_patcher_load
+    _comfy_model_patcher_load = ModelPatcher.load
+    ModelPatcher.load = _model_patcher_load_hijack  # type: ignore
+
     from hordelib.nodes.comfy_controlnet_preprocessors import (
         canny as _canny,
         hed as _hed,
@@ -208,6 +238,70 @@
     # isort: on
 
 
+models_not_to_force_load: list = ["cascade", "sdxl"]  # other possible values could be `basemodel` or `sd1`
+"""Models which should not be forced to load in the comfy model loading hijack.
+
+Possible values include `cascade`, `sdxl`, `basemodel`, `sd1` or any other comfyui classname
+which can be passed to comfyui's `load_models_gpu` function (as a `ModelPatcher.model`).
+"""
+
+disable_force_loading: bool = False
+
+
+def _do_not_force_load_model_in_patcher(model_patcher):
+    for model in models_not_to_force_load:
+        if model in str(type(model_patcher.model)).lower():
+            return True
+
+    return False
+
+
+def _load_models_gpu_hijack(*args, **kwargs):
+    """Intercepts the comfy load_models_gpu function to force full load.
+
+    ComfyUI is too conservative in its loading to GPU for the worker/horde use case where we can have
+    multiple ComfyUI instances running on the same GPU. This function forces a full load of the model
+    and the worker/horde-engine takes responsibility for managing the memory or the problems this may
+    cause.
+    """
+    found_model_to_skip = False
+    for model_patcher in args[0]:
+        found_model_to_skip = _do_not_force_load_model_in_patcher(model_patcher)
+        if found_model_to_skip:
+            break
+
+    global _comfy_current_loaded_models
+    if found_model_to_skip:
+        logger.debug("Not overriding model load")
+        _comfy_load_models_gpu(*args, **kwargs)
+        return
+
+    if "force_full_load" in kwargs:
+        kwargs.pop("force_full_load")
+
+    kwargs["force_full_load"] = True
+    _comfy_load_models_gpu(*args, **kwargs)
+
+
+def _model_patcher_load_hijack(*args, **kwargs):
+    """Intercepts the comfy ModelPatcher.load function to force full load.
+
+    See _load_models_gpu_hijack for more information
+    """
+    global _comfy_model_patcher_load
+
+    model_patcher = args[0]
+    if _do_not_force_load_model_in_patcher(model_patcher):
+        logger.debug("Not overriding model load")
+        _comfy_model_patcher_load(*args, **kwargs)
+        return
+
+    if "full_load" in kwargs:
+        kwargs.pop("full_load")
+
+    kwargs["full_load"] = True
+    _comfy_model_patcher_load(*args, **kwargs)
+
 
 
 _last_pipeline_settings_hash = ""
@@ -324,6 +418,11 @@ def log_free_ram():
     )
 
 
+def interrupt_comfyui_processing():
+    logger.warning("Interrupting comfyui processing")
+    _comfy_interrupt_current_processing()
+
+
 class Comfy_Horde:
     """Handles horde-specific behavior against ComfyUI."""
 
@@ -362,6 +461,7 @@ def __init__(
         self,
         *,
         comfyui_callback: typing.Callable[[str, dict, str], None] | None = None,
+        aggressive_unloading: bool = True,
     ) -> None:
         """Initialise the Comfy_Horde object.
 
@@ -388,6 +488,7 @@ def __init__(
         self._load_custom_nodes()
 
         self._comfyui_callback = comfyui_callback
+        self.aggressive_unloading = aggressive_unloading
 
     def _set_comfyui_paths(self) -> None:
         # These set the default paths for comfyui to look for models and embeddings. From within hordelib,
@@ -764,6 +865,12 @@ def _run_pipeline(
             inference.execute(pipeline, self.client_id, {"client_id": self.client_id}, valid[2])
         except Exception as e:
             logger.exception(f"Exception during comfy execute: {e}")
+        finally:
+            if self.aggressive_unloading:
+                global _comfy_cleanup_models
+                logger.debug("Cleaning up models")
+                _comfy_cleanup_models(False)
+                _comfy_soft_empty_cache(True)
 
         stdio.replay()
 
diff --git a/hordelib/horde.py b/hordelib/horde.py
index f6e2b088..cbca0403 100644
--- a/hordelib/horde.py
+++ b/hordelib/horde.py
@@ -352,10 +352,12 @@ def __init__(
         self,
         *,
         comfyui_callback: Callable[[str, dict, str], None] | None = None,
+        aggressive_unloading: bool = True,
     ):
         if not self._initialised:
             self.generator = Comfy_Horde(
                 comfyui_callback=comfyui_callback if comfyui_callback else self._comfyui_callback,
+                aggressive_unloading=aggressive_unloading,
             )
             self.__class__._initialised = True
 
diff --git a/hordelib/initialisation.py b/hordelib/initialisation.py
index 014d8e4c..1332e315 100644
--- a/hordelib/initialisation.py
+++ b/hordelib/initialisation.py
@@ -28,6 +28,7 @@ def initialise(
     extra_comfyui_args: list[str] | None = None,
     disable_smart_memory: bool = False,
     do_not_load_model_mangers: bool = False,
+    models_not_to_force_load: list[str] | None = None,
 ):
     """Initialise hordelib. This is required before using any other hordelib functions.
 
@@ -40,6 +41,9 @@ def initialise(
         force_low_vram (bool, optional): Whether to forcibly disable ComfyUI's high/med vram modes. Defaults to False.
         extra_comfyui_args (list[str] | None, optional): Any additional CLI args for comfyui that should be used. \
             Defaults to None.
+        models_not_to_force_load (list[str] | None, optional): A list of baselines that should not be force loaded. \
+            **If this is `None`, the defaults are used.** If you wish to override the defaults, pass an empty list. \
+            Defaults to None.
     """
 
     global _is_initialised
@@ -79,6 +83,8 @@ def initialise(
         extra_comfyui_args=extra_comfyui_args,
         disable_smart_memory=disable_smart_memory,
     )
+    if models_not_to_force_load is not None:
+        hordelib.comfy_horde.models_not_to_force_load = models_not_to_force_load.copy()
 
     vram_on_start_free = hordelib.comfy_horde.get_torch_free_vram_mb()
     vram_total = hordelib.comfy_horde.get_torch_total_vram_mb()
diff --git a/images_expected/text_to_image_max_resolution.png b/images_expected/text_to_image_max_resolution.png
new file mode 100644
index 00000000..f41eb0e6
Binary files /dev/null and b/images_expected/text_to_image_max_resolution.png differ
diff --git a/tests/conftest.py b/tests/conftest.py
index 89a007d0..27c0d7d3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -92,8 +92,13 @@ def init_horde(
     hordelib.initialise(
         setup_logging=True,
         logging_verbosity=5,
-        disable_smart_memory=False,
+        disable_smart_memory=True,
+        force_normal_vram_mode=True,
         do_not_load_model_mangers=True,
+        models_not_to_force_load=[
+            "sdxl",
+            "cascade",
+        ],
     )
 
     from hordelib.settings import UserSettings
diff --git a/tests/test_horde_inference.py b/tests/test_horde_inference.py
index c0c58b0b..072473af 100644
--- a/tests/test_horde_inference.py
+++ b/tests/test_horde_inference.py
@@ -47,6 +47,43 @@ def test_text_to_image(
             pil_image,
         )
 
+    @pytest.mark.default_sd15_model
+    def test_text_to_image_max_resolution(
+        self,
+        hordelib_instance: HordeLib,
+        stable_diffusion_model_name_for_testing: str,
+    ):
+        data = {
+            "sampler_name": "k_dpmpp_2m",
+            "cfg_scale": 7.5,
+            "denoising_strength": 1.0,
+            "seed": 123456789,
+            "height": 2048,
+            "width": 2048,
+            "karras": False,
+            "tiling": False,
+            "hires_fix": False,
+            "clip_skip": 1,
+            "control_type": None,
+            "image_is_control": False,
+            "return_control_map": False,
+            "prompt": "an ancient llamia monster",
+            "ddim_steps": 5,
+            "n_iter": 1,
+            "model": stable_diffusion_model_name_for_testing,
+        }
+        pil_image = hordelib_instance.basic_inference_single_image(data).image
+        assert pil_image is not None
+        assert isinstance(pil_image, Image.Image)
+
+        img_filename = "text_to_image_max_resolution.png"
+        pil_image.save(f"images/{img_filename}", quality=100)
+
+        assert check_single_inference_image_similarity(
+            f"images_expected/{img_filename}",
+            pil_image,
+        )
+
     @pytest.mark.default_sd15_model
     def test_text_to_image_n_iter(
         self,
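
A minimal sketch of how a downstream caller might exercise the options this diff introduces. The model name and the
trimmed payload below are placeholders loosely based on the existing test data, not a prescribed minimum; the import
order (initialise before importing hordelib.horde) follows the comment wall in comfy_horde.py.

    # Hypothetical usage sketch; "Deliberate" and the payload values are placeholders,
    # and model manager setup / model downloads are omitted for brevity.
    import hordelib

    hordelib.initialise(
        setup_logging=True,
        disable_smart_memory=True,
        # Baselines to leave to comfyui's normal (partial) loading path instead of force-loading.
        models_not_to_force_load=["sdxl", "cascade"],
    )

    # hordelib.horde may only be imported after initialise().
    from hordelib.horde import HordeLib

    # Ask Comfy_Horde to clean up comfyui models after every pipeline run.
    horde = HordeLib(aggressive_unloading=True)

    result = horde.basic_inference_single_image(
        {
            "sampler_name": "k_dpmpp_2m",
            "cfg_scale": 7.5,
            "seed": 123456789,
            "height": 512,
            "width": 512,
            "prompt": "an ancient llamia monster",
            "ddim_steps": 5,
            "n_iter": 1,
            "model": "Deliberate",  # placeholder model name
        },
    )
    result.image.save("example.png")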