Skip to content

Commit

Permalink
Merge pull request #213 from Haidra-Org/main
Browse files Browse the repository at this point in the history
fix: increase timeout rate of lora metadata/model downloads
  • Loading branch information
tazlin authored Mar 7, 2024
2 parents b735d16 + e0fb477 commit 3f51023
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 20 deletions.
63 changes: 46 additions & 17 deletions hordelib/model_manager/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,19 @@ class LoraModelManager(BaseModelManager):
)
LORA_API = "https://civitai.com/api/v1/models?types=LORA&sort=Highest%20Rated&primaryFileOnly=true"
MAX_RETRIES = 10 if not TESTS_ONGOING else 3
MAX_DOWNLOAD_THREADS = 3
MAX_DOWNLOAD_THREADS = 5 if not TESTS_ONGOING else 15
RETRY_DELAY = 3 if not TESTS_ONGOING else 0.2
"""The time to wait between retries in seconds"""
REQUEST_METADATA_TIMEOUT = 20
"""The time to wait for a response from the server in seconds"""
REQUEST_DOWNLOAD_TIMEOUT = 300
"""The time to wait for a response from the server in seconds"""
THREAD_WAIT_TIME = 2
REQUEST_METADATA_TIMEOUT = 20 # Longer because civitai performs poorly on metadata requests for more than 5 models
"""The maximum time for no data to be received before we give up on a metadata fetch, in seconds"""
REQUEST_DOWNLOAD_TIMEOUT = 10 if not TESTS_ONGOING else 1
"""The maximum time for no data to be received before we give up on a download, in seconds
This is not the time to download the file, but the time to wait in between data packets. \
If we're actively downloading and the connection to the server is alive, this doesn't come into play
"""

THREAD_WAIT_TIME = 0.1
"""The time to wait between checking the download queue in seconds"""

_file_lock: multiprocessing_lock | nullcontext
Expand Down Expand Up @@ -274,21 +279,45 @@ def _add_lora_ids_to_download_queue(self, lora_ids, adhoc=False, version_compare
def _get_json(self, url):
retries = 0
while retries <= self.MAX_RETRIES:
response = None
try:
response = requests.get(url, timeout=self.REQUEST_METADATA_TIMEOUT)
response = requests.get(
url,
timeout=self.REQUEST_METADATA_TIMEOUT if len(url) < 200 else self.REQUEST_METADATA_TIMEOUT * 1.5,
)
response.raise_for_status()
# Attempt to decode the response to JSON
return response.json()

except (requests.HTTPError, requests.ConnectionError, requests.Timeout, json.JSONDecodeError) as e:
# CivitAI Errors when the model ID is too long
if response.status_code in [404, 500]:
logger.debug(f"url '{url}' download failed with status code {response.status_code}")
return None

logger.debug(f"url '{url}' download failed {type(e)} {e}")

# If this is a 401, 404, or 500, we're not going to get anywhere, just give up
# The following are the CivitAI errors encountered so far
# [401: requires a token, 404: model ID too long, 500: internal server error]
if response is not None:
if response.status_code in [401, 404]:
logger.debug(f"url '{url}' download failed with status code {response.status_code}")
return None
if response.status_code == 500:
logger.debug(f"url '{url}' download failed with status code {response.status_code}")
retries += 3

# The json being invalid is a CivitAI issue, possibly it showing an HTML page and
# this isn't likely to change in the next 30 seconds, so we'll try twice more
# and give up if it doesn't work
if isinstance(e, json.JSONDecodeError):
logger.debug(f"url '{url}' download failed with {type(e)} {e}")
retries += 3

# If the network connection timed out, then self.REQUEST_METADATA_TIMEOUT seconds passed
# and the clock is ticking, so we'll try once more
if response is None:
retries += 5

retries += 1
self.total_retries_attempted += 1

if retries <= self.MAX_RETRIES:
time.sleep(self.RETRY_DELAY)
else:
Expand Down Expand Up @@ -674,8 +703,8 @@ def clear_all_references(self):
def wait_for_downloads(self, timeout=None):
rtr = 0
while not self.are_downloads_complete():
time.sleep(0.5)
rtr += 0.5
time.sleep(self.THREAD_WAIT_TIME)
rtr += self.THREAD_WAIT_TIME
if timeout and rtr > timeout:
raise Exception(f"Lora downloads exceeded specified timeout ({timeout})")
logger.debug("Downloads complete")
Expand Down Expand Up @@ -973,7 +1002,7 @@ def reset_adhoc_loras(self):
if self._stop_all_threads:
logger.debug("Stopped processing thread")
return
time.sleep(0.2)
time.sleep(self.THREAD_WAIT_TIME)
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self._adhoc_loras = set()
unsorted_items = []
Expand Down Expand Up @@ -1073,8 +1102,8 @@ def is_adhoc_reset_complete(self):
def wait_for_adhoc_reset(self, timeout=15):
rtr = 0
while not self.is_adhoc_reset_complete():
time.sleep(0.2)
rtr += 0.2
time.sleep(self.THREAD_WAIT_TIME)
rtr += self.THREAD_WAIT_TIME
if timeout and rtr > timeout:
raise Exception(f"Lora adhoc reset exceeded specified timeout ({timeout})")

Expand Down
16 changes: 14 additions & 2 deletions hordelib/model_manager/ti.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def _add_ti_ids_to_download_queue(self, ti_ids, adhoc=False, version_compare=Non
def _get_json(self, url):
retries = 0
while retries <= self.MAX_RETRIES:
response = None
try:
response = requests.get(url, timeout=self.REQUEST_METADATA_TIMEOUT)
response.raise_for_status()
Expand All @@ -159,8 +160,19 @@ def _get_json(self, url):

except (requests.HTTPError, requests.ConnectionError, requests.Timeout, json.JSONDecodeError):
# CivitAI Errors when the model ID is too long
if response.status_code in [404, 500]:
return None
if response is not None:
if response.status_code in [401, 404]:
return None
if response.status_code == 500:
retries += 3
logger.debug(
"CivitAI reported an internal error when downloading metadata. "
"Fewer retries will be attempted.",
)

if response is None:
retries += 5

retries += 1
self.total_retries_attempted += 1
if retries <= self.MAX_RETRIES:
Expand Down
2 changes: 1 addition & 1 deletion tests/model_managers/test_mm_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def test_adhoc_non_existing_intstring_large(self):
lora_model_manager.wait_for_adhoc_reset(15)
lora_name = "99999999999999"
lora_key = lora_model_manager.fetch_adhoc_lora(lora_name)
assert lora_model_manager.total_retries_attempted == 0
assert lora_model_manager.total_retries_attempted == 1
assert lora_key is None
assert not lora_model_manager.is_model_available(lora_name)
lora_model_manager.stop_all()
Expand Down

0 comments on commit 3f51023

Please sign in to comment.