feat: torch 2.1; better memory management in high vram scenarios #80

Merged: 3 commits, Oct 5, 2023
Changes from all commits

34 changes: 22 additions & 12 deletions hordelib/comfy_horde.py
@@ -138,15 +138,19 @@ def do_comfy_import():

import comfy.model_management

comfy.model_management.vram_state = comfy.model_management.VRAMState.LOW_VRAM
comfy.model_management.set_vram_to = comfy.model_management.VRAMState.LOW_VRAM
# comfy.model_management.vram_state = comfy.model_management.VRAMState.HIGH_VRAM
# comfy.model_management.set_vram_to = comfy.model_management.VRAMState.HIGH_VRAM

def always_cpu(parameters, dtype):
return torch.device("cpu")
logger.info("Comfy_Horde initialised")

comfy.model_management.unet_inital_load_device = always_cpu
# def always_cpu(parameters, dtype):
# return torch.device("cpu")

# comfy.model_management.unet_inital_load_device = always_cpu
comfy.model_management.DISABLE_SMART_MEMORY = True
comfy.model_management.lowvram_available = True
# comfy.model_management.lowvram_available = True

# comfy.model_management.unet_offload_device = _unet_offload_device_hijack

total_vram = get_torch_total_vram_mb()
total_ram = psutil.virtual_memory().total / (1024 * 1024)
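
For orientation, here is a minimal standalone sketch of the memory-management settings this hunk leaves in place. It only uses the comfy.model_management attributes visible in the diff (vram_state, set_vram_to, VRAMState, DISABLE_SMART_MEMORY) and assumes a ComfyUI checkout is importable; it is not hordelib's actual initialisation code.

```python
# Hedged sketch: force ComfyUI into low-VRAM mode and turn off its smart
# memory cache, mirroring the flags set in do_comfy_import(). Assumes a
# ComfyUI checkout is on sys.path so comfy.model_management imports cleanly.
import comfy.model_management as mm

mm.vram_state = mm.VRAMState.LOW_VRAM
mm.set_vram_to = mm.VRAMState.LOW_VRAM
mm.DISABLE_SMART_MEMORY = True  # do not keep models resident opportunistically

print(f"VRAM state: {mm.vram_state}, smart memory disabled: {mm.DISABLE_SMART_MEMORY}")
```
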
@@ -163,27 +167,33 @@ def always_cpu(parameters, dtype):


def cleanup():
logger.debug(f"{len(_comfy_current_loaded_models)} models loaded in comfy")

_comfy_soft_empty_cache()
logger.debug(f"{len(_comfy_current_loaded_models)} models loaded in comfy")


def unload_all_models_vram():
from hordelib.shared_model_manager import SharedModelManager

logger.debug("In unload_all_models_vram")

logger.debug(f"{len(_comfy_current_loaded_models)} models loaded in comfy")
# _comfy_free_memory(_comfy_get_total_memory(), _comfy_get_torch_device())
with torch.no_grad():
for model in _comfy_current_loaded_models:
model.model_unload()
_comfy_soft_empty_cache()
try:
for model in _comfy_current_loaded_models:
model.model_unload()
_comfy_soft_empty_cache()
except Exception as e:
logger.error(f"Exception during comfy unload: {e}")
_comfy_cleanup_models()
_comfy_soft_empty_cache()
logger.debug(f"{len(_comfy_current_loaded_models)} models loaded in comfy")


def unload_all_models_ram():
from hordelib.shared_model_manager import SharedModelManager

logger.debug("In unload_all_models_ram")

SharedModelManager.manager._models_in_ram = {}
logger.debug(f"{len(_comfy_current_loaded_models)} models loaded in comfy")
with torch.no_grad():
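
The unload path above now wraps the per-model unload loop in a try/except, so one failing model no longer prevents _comfy_cleanup_models() and the final cache flush from running. A condensed, hedged sketch of that pattern with the collaborators passed in explicitly (the helper name and parameters here are illustrative, not hordelib's actual API):

```python
import torch
from loguru import logger


def unload_models_best_effort(loaded_models, cleanup_models, soft_empty_cache):
    """Unload every model from VRAM; never let one failure abort the sweep."""
    with torch.no_grad():
        try:
            for model in loaded_models:
                model.model_unload()
            soft_empty_cache()
        except Exception as e:
            logger.error(f"Exception during comfy unload: {e}")
        # These always run, even if an individual unload raised above.
        cleanup_models()
        soft_empty_cache()
```
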
7 changes: 1 addition & 6 deletions hordelib/consts.py
@@ -6,7 +6,7 @@

from hordelib.config_path import get_hordelib_path

COMFYUI_VERSION = "099226015edcdf595d915f0f39359ef0d83fbba6"
COMFYUI_VERSION = "0e763e880f5e838e7a1e3914444cae6790c48627"
"""The exact version of ComfyUI version to load."""

REMOTE_PROXY = ""
@@ -22,11 +22,6 @@ class HordeSupportedBackends(Enum):
ComfyUI = auto()


# Models Excluded from hordelib (for now) # FIXME
# this could easily be a json file on the AI-Horde-image-model-reference repo
EXCLUDED_MODEL_NAMES = ["pix2pix"]


class MODEL_CATEGORY_NAMES(StrEnum):
"""Look up str enum for the categories of models (compvis, controlnet, clip, etc...)."""

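
COMFYUI_VERSION is now pinned to an exact commit hash rather than a tag. As a hedged illustration of how such a pin could be enforced at install time (the local clone path and the use of GitPython are assumptions for the sketch, not hordelib's actual installer):

```python
# Hypothetical sketch: detach a local ComfyUI clone at the pinned commit.
from git import Repo  # GitPython

COMFYUI_VERSION = "0e763e880f5e838e7a1e3914444cae6790c48627"

repo = Repo("ComfyUI")              # assumed location of the local clone
repo.git.checkout(COMFYUI_VERSION)  # exact commit, so builds are reproducible
```
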
2 changes: 2 additions & 0 deletions hordelib/nodes/node_model_loader.py
@@ -59,6 +59,8 @@ def load_checkpoint(
same_loaded_model[0][0].model.apply(make_regular)
make_regular_vae(same_loaded_model[0][2])

logger.debug("Model was previously loaded, returning it.")

return same_loaded_model[0]

if not ckpt_name:
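
The new debug line marks the early-return path for a checkpoint that is already resident. A hedged sketch of that reuse pattern in isolation (the cache dict and helper name are illustrative only, not node_model_loader's real structures):

```python
from loguru import logger

# Illustrative cache of already-loaded checkpoints: ckpt_name -> (model, clip, vae).
_loaded_checkpoints: dict[str, tuple] = {}


def load_checkpoint_cached(ckpt_name: str, load_fn):
    """Return the cached checkpoint when possible; otherwise load and cache it."""
    if ckpt_name in _loaded_checkpoints:
        logger.debug("Model was previously loaded, returning it.")
        return _loaded_checkpoints[ckpt_name]
    loaded = load_fn(ckpt_name)
    _loaded_checkpoints[ckpt_name] = loaded
    return loaded
```
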
6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,10 +1,10 @@
# Add this in for tox, comment out for build
--extra-index-url https://download.pytorch.org/whl/cu118
--extra-index-url https://download.pytorch.org/whl/cu121
horde_sdk>=0.7.10
horde_model_reference>=0.5.2
pydantic
torch>=2.0.0
xformers>=0.0.19
torch>=2.1.0
# xformers>=0.0.19
torchvision
torchaudio
torchdiffeq
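
With the extra index switched from cu118 to cu121, torch bumped to >=2.1.0, and xformers commented out for now, a quick post-install sanity check can confirm the CUDA 12.1 wheels were picked up; this uses only standard torch attributes.

```python
import torch

# Expect a 2.1+ build compiled against CUDA 12.1 when the cu121 wheels are used.
print(torch.__version__)         # e.g. "2.1.0+cu121"
print(torch.version.cuda)        # e.g. "12.1"
print(torch.cuda.is_available())
```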