From d04a17aaa42207d360a34ce78f42e42a00071a7e Mon Sep 17 00:00:00 2001 From: FengWen Date: Mon, 2 Dec 2024 14:06:28 +0800 Subject: [PATCH 01/12] tmp save --- src/bizyair/commands/servers/prompt_server.py | 44 +++++++-- src/bizyair/common/__init__.py | 1 + src/bizyair/common/client.py | 96 ++++++++++++++++++- 3 files changed, 132 insertions(+), 9 deletions(-) diff --git a/src/bizyair/commands/servers/prompt_server.py b/src/bizyair/commands/servers/prompt_server.py index a7c0f855..eb023258 100644 --- a/src/bizyair/commands/servers/prompt_server.py +++ b/src/bizyair/commands/servers/prompt_server.py @@ -1,7 +1,11 @@ +import json import pprint import traceback from typing import Any, Dict, List +import requests + +from bizyair.common import BizyAirTask from bizyair.common.env_var import BIZYAIR_DEBUG from bizyair.common.utils import truncate_long_strings from bizyair.image_utils import decode_data, encode_data @@ -14,6 +18,39 @@ def __init__(self, router: Processor, processor: Processor): self.router = router self.processor = processor + # def _get_result(self, result: Dict[str, Any]): + # try: + # response_data = result["data"] + # if result.get("code") == 20000 and result.get("data", {}).get("task_id"): + # task_id = result["data"]["task_id"] + # bz_task = BizyAirTask.from_data(result["data"]) + + # i = 0 + # while i < 1000: + # import time + + # time.sleep(1) + # try: + # resp = bz_task.send_request(offset=i) + # import ipdb + + # ipdb.set_trace() + # i += 1 + # except Exception as e: + # print(f"Exception: {e}") + + # if "upload_to_s3" in result and result["upload_to_s3"]: + # upload_url = result["data"] + # response = requests.get(upload_url) + # assert response.status_code == 200 + # response_data = response.json() + # out = response_data["payload"] + # return out + # except Exception as e: + # raise RuntimeError( + # f'Unexpected error accessing result["data"]["payload"]. Result: {result}' + # ) from e + def execute( self, prompt: Dict[str, Dict[str, Any]], @@ -39,12 +76,7 @@ def execute( if result is None: raise RuntimeError("result is None") - try: - out = result["data"]["payload"] - except Exception as e: - raise RuntimeError( - f'Unexpected error accessing result["data"]["payload"]. Result: {result}' - ) from e + out = self._get_result(result) try: real_out = decode_data(out) return real_out[0] diff --git a/src/bizyair/common/__init__.py b/src/bizyair/common/__init__.py index 55fe7626..828c3fd9 100644 --- a/src/bizyair/common/__init__.py +++ b/src/bizyair/common/__init__.py @@ -1,4 +1,5 @@ from .client import ( + BizyAirTask, fetch_models_by_type, get_api_key, send_request, diff --git a/src/bizyair/common/client.py b/src/bizyair/common/client.py index a990fb6a..7870c22f 100644 --- a/src/bizyair/common/client.py +++ b/src/bizyair/common/client.py @@ -4,6 +4,9 @@ import urllib.error import urllib.request import warnings +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Any, Union import aiohttp @@ -106,8 +109,9 @@ def send_request( data: bytes = None, verbose=False, callback: callable = process_response_data, + response_handler: callable = json.loads, **kwargs, -) -> dict: +) -> Union[dict, Any]: try: headers = kwargs.pop("headers") if "headers" in kwargs else _headers() headers["User-Agent"] = "BizyAir Client" @@ -133,9 +137,11 @@ def send_request( + "Also, verify your network settings and disable any proxies if necessary.\n" + "After checking, please restart the ComfyUI service." ) + if response_handler: + response_data = response_handler(response_data) if callback: - return callback(json.loads(response_data)) - return json.loads(response_data) + return callback(response_data) + return response_data async def async_send_request( @@ -196,3 +202,87 @@ def fetch_models_by_type( verbose=verbose, ) return msg + + +def get_task_result(task_id: str, offset: int = 0) -> dict: + """ + Get the result of a task. + """ + # TODO fix url to config + url = f"https://uat-bizyair-api.siliconflow.cn/x/v1/bizy_task/{task_id}" + out = send_request( + url=url, + data=json.dumps({"offset": offset}).encode("utf-8"), + callback=lambda x: x, + ) + import ipdb + + ipdb.set_trace() + return out + + +@dataclass +class BizyAirTask: + task_id: str + data: list[dict] = field(default_factory=list) + data_status: str = None + + @classmethod + def from_data(cls, data: dict) -> "BizyAirTask": + import ipdb + + self = cls() + ipdb.set_trace() + self.task_id = data["task_id"] + self.data_status = data["status"] + self = cls() + return self + + def is_finished(self) -> bool: + # TODO: fix this + return self.data_status == "finished" + + def _check_request_result(self, data: dict) -> bool: + # TODO: fix this + import ipdb + + ipdb.set_trace() + return False + + def send_request(self, offset: int = 0) -> dict: + if offset >= len(self.data): + return get_task_result(self.task_id, offset) + else: + return self.data[offset] + + # def get_next_data(self, wait_for_result: bool = False) -> dict: + + # # if self.next_offset >= len(self.data): + # # if self.is_finished(): + # # raise ValueError(f"No more data to get for task {self.task_id}") + # # else: + # # data = get_task_result(self.task_id, self.next_offset) + # # if self._check_request_result(data): + # # self.data.extend(data) + # # self.next_offset += 1 + # # else: + # # import ipdb; ipdb.set_trace() + + +# class BizyAirTasks: +# def __init__(self): +# self.task_data_pool = defaultdict(BizyAirTaskData) + +# @classmethod +# def from_data(cls, data: dict) -> "BizyAirTasks": +# self = cls() +# task_id = data["task_id"] +# self.task_data_pool[task_id] = BizyAirTaskData(task_id=task_id) +# return self + +# def get_task_data(self, task_id: str, offset: int = 0) -> dict: +# if task_id not in self.task_data_pool: +# raise ValueError(f"Task {task_id} not found") + +# task_data = self.task_data_pool[task_id] +# return task_data.get_next_data(wait_for_result=wait_for_result) From 6cb62a3e182f75569822d2f9c54e026213a1dcdc Mon Sep 17 00:00:00 2001 From: FengWen Date: Tue, 3 Dec 2024 14:18:29 +0800 Subject: [PATCH 02/12] tmp save --- src/bizyair/commands/servers/prompt_server.py | 74 +++++++++------- src/bizyair/common/client.py | 87 +++++++++++++------ 2 files changed, 103 insertions(+), 58 deletions(-) diff --git a/src/bizyair/commands/servers/prompt_server.py b/src/bizyair/commands/servers/prompt_server.py index eb023258..5d08e780 100644 --- a/src/bizyair/commands/servers/prompt_server.py +++ b/src/bizyair/commands/servers/prompt_server.py @@ -18,38 +18,48 @@ def __init__(self, router: Processor, processor: Processor): self.router = router self.processor = processor - # def _get_result(self, result: Dict[str, Any]): - # try: - # response_data = result["data"] - # if result.get("code") == 20000 and result.get("data", {}).get("task_id"): - # task_id = result["data"]["task_id"] - # bz_task = BizyAirTask.from_data(result["data"]) - - # i = 0 - # while i < 1000: - # import time - - # time.sleep(1) - # try: - # resp = bz_task.send_request(offset=i) - # import ipdb - - # ipdb.set_trace() - # i += 1 - # except Exception as e: - # print(f"Exception: {e}") - - # if "upload_to_s3" in result and result["upload_to_s3"]: - # upload_url = result["data"] - # response = requests.get(upload_url) - # assert response.status_code == 200 - # response_data = response.json() - # out = response_data["payload"] - # return out - # except Exception as e: - # raise RuntimeError( - # f'Unexpected error accessing result["data"]["payload"]. Result: {result}' - # ) from e + def get_task_id(self, result: Dict[str, Any]) -> str: + return result.get("data", {}).get("task_id", "") + + def is_async_task(self, result: Dict[str, Any]) -> str: + """Determine if the result indicates an asynchronous task.""" + return ( + result.get("code") == 20000 + and result.get("status", False) + and "task_id" in result.get("data", {}) + ) + + def _get_result(self, result: Dict[str, Any]): + try: + response_data = result["data"] + if BizyAirTask.check_inputs(result): + bz_task = BizyAirTask.from_data(result, check_inputs=False) + + i = 0 + while i < 1000: + import time + + time.sleep(1) + try: + _ = bz_task.send_request(offset=i) + import ipdb + + ipdb.set_trace() + i += 1 + except Exception as e: + print(f"Exception: {e}") + + if "upload_to_s3" in result and result["upload_to_s3"]: + upload_url = result["data"] + response = requests.get(upload_url) + assert response.status_code == 200 + response_data = response.json() + out = response_data["payload"] + return out + except Exception as e: + raise RuntimeError( + f'Unexpected error accessing result["data"]["payload"]. Result: {result}' + ) from e def execute( self, diff --git a/src/bizyair/common/client.py b/src/bizyair/common/client.py index 7870c22f..d356415c 100644 --- a/src/bizyair/common/client.py +++ b/src/bizyair/common/client.py @@ -1,6 +1,7 @@ import asyncio import json import pprint +import time import urllib.error import urllib.request import warnings @@ -208,52 +209,86 @@ def get_task_result(task_id: str, offset: int = 0) -> dict: """ Get the result of a task. """ - # TODO fix url to config + + print(f"get_task_result: {task_id}, {offset}") + import requests + url = f"https://uat-bizyair-api.siliconflow.cn/x/v1/bizy_task/{task_id}" - out = send_request( - url=url, - data=json.dumps({"offset": offset}).encode("utf-8"), - callback=lambda x: x, - ) import ipdb ipdb.set_trace() - return out + response = requests.get(url, headers=_headers(), params={"offset": offset}) + if response.status_code == 200: + return response.json() + else: + raise Exception( + f"bizyair get task resp failed, status_code: {response.status_code}, response: {response.json()}" + ) + + # TODO fix url to config + # # url = f"https://uat-bizyair-api.siliconflow.cn/x/v1/bizy_task/{task_id}" + # # out = send_request( + # url=url, + # data=json.dumps({"offset": offset}).encode("utf-8"), + # callback=lambda x: x, + # ) + # import ipdb + + # ipdb.set_trace() + # return out @dataclass class BizyAirTask: task_id: str - data: list[dict] = field(default_factory=list) + data_pool: list[dict] = field(default_factory=list) data_status: str = None - @classmethod - def from_data(cls, data: dict) -> "BizyAirTask": - import ipdb + @staticmethod + def check_inputs(inputs: dict) -> bool: + return ( + inputs.get("code") == 20000 + and inputs.get("status", False) + and "task_id" in inputs.get("data", {}) + ) - self = cls() - ipdb.set_trace() - self.task_id = data["task_id"] - self.data_status = data["status"] - self = cls() - return self + @classmethod + def from_data(cls, inputs: dict, check_inputs: bool = True) -> "BizyAirTask": + if check_inputs and not cls.check_inputs(inputs): + raise ValueError(f"Invalid inputs: {inputs}") + data = inputs.get("data", {}) + task_id = data.get("task_id", "") + return cls(task_id=task_id, data_pool=[], data_status="started") def is_finished(self) -> bool: # TODO: fix this return self.data_status == "finished" - def _check_request_result(self, data: dict) -> bool: - # TODO: fix this - import ipdb - - ipdb.set_trace() - return False - def send_request(self, offset: int = 0) -> dict: - if offset >= len(self.data): + if offset >= len(self.data_pool): return get_task_result(self.task_id, offset) else: - return self.data[offset] + return self.data_pool[offset] + + def get_all_outputs(self, timeout: int = 240) -> list[dict]: + offset = 0 + start_time = time.time() + while not self.is_finished(): + try: + data = self.send_request(offset) + import ipdb + + ipdb.set_trace() + self.data_pool.extend(data) + offset += 1 + except Exception as e: + print(f"Exception: {e}") + finally: + if time.time() - start_time > timeout: + raise TimeoutError( + f"Timeout waiting for task {self.task_id} to finish" + ) + return self.data_pool # def get_next_data(self, wait_for_result: bool = False) -> dict: From 99807ad315c2eb7738227bdf0a99bb95b4287921 Mon Sep 17 00:00:00 2001 From: FengWen Date: Wed, 4 Dec 2024 10:38:09 +0800 Subject: [PATCH 03/12] fix BizyAirTask in client.py --- src/bizyair/commands/servers/prompt_server.py | 29 +++------ src/bizyair/common/client.py | 59 +++++-------------- 2 files changed, 25 insertions(+), 63 deletions(-) diff --git a/src/bizyair/commands/servers/prompt_server.py b/src/bizyair/commands/servers/prompt_server.py index 5d08e780..2f9d4853 100644 --- a/src/bizyair/commands/servers/prompt_server.py +++ b/src/bizyair/commands/servers/prompt_server.py @@ -34,26 +34,15 @@ def _get_result(self, result: Dict[str, Any]): response_data = result["data"] if BizyAirTask.check_inputs(result): bz_task = BizyAirTask.from_data(result, check_inputs=False) - - i = 0 - while i < 1000: - import time - - time.sleep(1) - try: - _ = bz_task.send_request(offset=i) - import ipdb - - ipdb.set_trace() - i += 1 - except Exception as e: - print(f"Exception: {e}") - - if "upload_to_s3" in result and result["upload_to_s3"]: - upload_url = result["data"] - response = requests.get(upload_url) - assert response.status_code == 200 - response_data = response.json() + bz_task.do_task_until_completed() + last_data = bz_task.get_last_data() + response_data = last_data + + # if "upload_to_s3" in result and result["upload_to_s3"]: + # upload_url = result["data"] + # response = requests.get(upload_url) + # assert response.status_code == 200 + # response_data = response.json() out = response_data["payload"] return out except Exception as e: diff --git a/src/bizyair/common/client.py b/src/bizyair/common/client.py index d356415c..d549f50a 100644 --- a/src/bizyair/common/client.py +++ b/src/bizyair/common/client.py @@ -214,9 +214,6 @@ def get_task_result(task_id: str, offset: int = 0) -> dict: import requests url = f"https://uat-bizyair-api.siliconflow.cn/x/v1/bizy_task/{task_id}" - import ipdb - - ipdb.set_trace() response = requests.get(url, headers=_headers(), params={"offset": offset}) if response.status_code == 200: return response.json() @@ -240,6 +237,7 @@ def get_task_result(task_id: str, offset: int = 0) -> dict: @dataclass class BizyAirTask: + TASK_DATA_STATUS = ["PENDING", "PROCESSING", "COMPLETED"] task_id: str data_pool: list[dict] = field(default_factory=list) data_status: str = None @@ -261,8 +259,11 @@ def from_data(cls, inputs: dict, check_inputs: bool = True) -> "BizyAirTask": return cls(task_id=task_id, data_pool=[], data_status="started") def is_finished(self) -> bool: - # TODO: fix this - return self.data_status == "finished" + if not self.data_pool: + return False + if self.data_pool[-1].get("data_status") == self.TASK_DATA_STATUS[-1]: + return True + return False def send_request(self, offset: int = 0) -> dict: if offset >= len(self.data_pool): @@ -270,17 +271,20 @@ def send_request(self, offset: int = 0) -> dict: else: return self.data_pool[offset] - def get_all_outputs(self, timeout: int = 240) -> list[dict]: + def get_last_data(self) -> dict: + return self.data_pool[-1] + + def do_task_until_completed(self, *, timeout: int = 480) -> list[dict]: offset = 0 start_time = time.time() while not self.is_finished(): try: data = self.send_request(offset) - import ipdb - - ipdb.set_trace() - self.data_pool.extend(data) - offset += 1 + data_lst = data.get("data", {}).get("events", []) + if not data_lst: + raise ValueError(f"No data found in task {self.task_id}") + self.data_pool.extend(data_lst) + offset += len(data_lst) except Exception as e: print(f"Exception: {e}") finally: @@ -288,36 +292,5 @@ def get_all_outputs(self, timeout: int = 240) -> list[dict]: raise TimeoutError( f"Timeout waiting for task {self.task_id} to finish" ) + time.sleep(2) return self.data_pool - - # def get_next_data(self, wait_for_result: bool = False) -> dict: - - # # if self.next_offset >= len(self.data): - # # if self.is_finished(): - # # raise ValueError(f"No more data to get for task {self.task_id}") - # # else: - # # data = get_task_result(self.task_id, self.next_offset) - # # if self._check_request_result(data): - # # self.data.extend(data) - # # self.next_offset += 1 - # # else: - # # import ipdb; ipdb.set_trace() - - -# class BizyAirTasks: -# def __init__(self): -# self.task_data_pool = defaultdict(BizyAirTaskData) - -# @classmethod -# def from_data(cls, data: dict) -> "BizyAirTasks": -# self = cls() -# task_id = data["task_id"] -# self.task_data_pool[task_id] = BizyAirTaskData(task_id=task_id) -# return self - -# def get_task_data(self, task_id: str, offset: int = 0) -> dict: -# if task_id not in self.task_data_pool: -# raise ValueError(f"Task {task_id} not found") - -# task_data = self.task_data_pool[task_id] -# return task_data.get_next_data(wait_for_result=wait_for_result) From bd737e98f21f53bfdee33768886cbad3269b9883 Mon Sep 17 00:00:00 2001 From: FengWen Date: Wed, 4 Dec 2024 15:05:32 +0800 Subject: [PATCH 04/12] fix payload error --- src/bizyair/commands/servers/prompt_server.py | 12 +++--------- src/bizyair/common/client.py | 11 +++++++++-- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/bizyair/commands/servers/prompt_server.py b/src/bizyair/commands/servers/prompt_server.py index 2f9d4853..5ec17b1f 100644 --- a/src/bizyair/commands/servers/prompt_server.py +++ b/src/bizyair/commands/servers/prompt_server.py @@ -3,8 +3,6 @@ import traceback from typing import Any, Dict, List -import requests - from bizyair.common import BizyAirTask from bizyair.common.env_var import BIZYAIR_DEBUG from bizyair.common.utils import truncate_long_strings @@ -36,13 +34,7 @@ def _get_result(self, result: Dict[str, Any]): bz_task = BizyAirTask.from_data(result, check_inputs=False) bz_task.do_task_until_completed() last_data = bz_task.get_last_data() - response_data = last_data - - # if "upload_to_s3" in result and result["upload_to_s3"]: - # upload_url = result["data"] - # response = requests.get(upload_url) - # assert response.status_code == 200 - # response_data = response.json() + response_data = last_data.get("data") out = response_data["payload"] return out except Exception as e: @@ -77,6 +69,8 @@ def execute( out = self._get_result(result) try: + if "error" in out: + raise RuntimeError(out["error"]) real_out = decode_data(out) return real_out[0] except Exception as e: diff --git a/src/bizyair/common/client.py b/src/bizyair/common/client.py index d549f50a..0dfa9063 100644 --- a/src/bizyair/common/client.py +++ b/src/bizyair/common/client.py @@ -216,7 +216,8 @@ def get_task_result(task_id: str, offset: int = 0) -> dict: url = f"https://uat-bizyair-api.siliconflow.cn/x/v1/bizy_task/{task_id}" response = requests.get(url, headers=_headers(), params={"offset": offset}) if response.status_code == 200: - return response.json() + out = response.json() + return out else: raise Exception( f"bizyair get task resp failed, status_code: {response.status_code}, response: {response.json()}" @@ -272,7 +273,12 @@ def send_request(self, offset: int = 0) -> dict: return self.data_pool[offset] def get_last_data(self) -> dict: - return self.data_pool[-1] + out = self.data_pool[-1] + import requests + + if out.get("data").startswith("https://"): + out["data"] = requests.get(out.get("data")).json() + return out def do_task_until_completed(self, *, timeout: int = 480) -> list[dict]: offset = 0 @@ -283,6 +289,7 @@ def do_task_until_completed(self, *, timeout: int = 480) -> list[dict]: data_lst = data.get("data", {}).get("events", []) if not data_lst: raise ValueError(f"No data found in task {self.task_id}") + self.data_pool.extend(data_lst) offset += len(data_lst) except Exception as e: From 30c2a3e2b98b5056e2614b19d4ce39ae274f2e99 Mon Sep 17 00:00:00 2001 From: FengWen Date: Sat, 7 Dec 2024 00:49:38 +0800 Subject: [PATCH 05/12] tmp save --- src/bizyair/common/client.py | 92 +++++++++++++++++++++++++----------- src/bizyair/common/utils.py | 4 ++ 2 files changed, 68 insertions(+), 28 deletions(-) diff --git a/src/bizyair/common/client.py b/src/bizyair/common/client.py index 0dfa9063..3092160e 100644 --- a/src/bizyair/common/client.py +++ b/src/bizyair/common/client.py @@ -10,6 +10,10 @@ from typing import Any, Union import aiohttp +import comfy + +# TODO refine +import server # comfyui module __all__ = ["send_request"] @@ -209,31 +213,25 @@ def get_task_result(task_id: str, offset: int = 0) -> dict: """ Get the result of a task. """ - - print(f"get_task_result: {task_id}, {offset}") import requests url = f"https://uat-bizyair-api.siliconflow.cn/x/v1/bizy_task/{task_id}" - response = requests.get(url, headers=_headers(), params={"offset": offset}) - if response.status_code == 200: - out = response.json() - return out - else: - raise Exception( - f"bizyair get task resp failed, status_code: {response.status_code}, response: {response.json()}" - ) - - # TODO fix url to config - # # url = f"https://uat-bizyair-api.siliconflow.cn/x/v1/bizy_task/{task_id}" - # # out = send_request( - # url=url, - # data=json.dumps({"offset": offset}).encode("utf-8"), - # callback=lambda x: x, - # ) - # import ipdb - - # ipdb.set_trace() - # return out + response_json = send_request( + method="GET", url=url, data=json.dumps({"offset": offset}).encode("utf-8") + ) + out = response_json + events = out.get("data", {}).get("events", []) + new_events = [] + for event in events: + if ( + "data" in event + and isinstance(event["data"], str) + and event["data"].startswith("https://") + ): + event["data"] = requests.get(event["data"]).json() + new_events.append(event) + out["data"]["events"] = new_events + return out @dataclass @@ -272,26 +270,60 @@ def send_request(self, offset: int = 0) -> dict: else: return self.data_pool[offset] - def get_last_data(self) -> dict: - out = self.data_pool[-1] + def get_data(self, offset: int = 0) -> dict: + if offset >= len(self.data_pool): + return {} + return self.data_pool[offset] + + @staticmethod + def _fetch_remote_data(url: str) -> dict: import requests - if out.get("data").startswith("https://"): - out["data"] = requests.get(out.get("data")).json() - return out + return requests.get(url).json() + + def get_last_data(self) -> dict: + return self.get_data(len(self.data_pool) - 1) def do_task_until_completed(self, *, timeout: int = 480) -> list[dict]: offset = 0 start_time = time.time() + # pbar = comfy.utils.ProgressBar(steps) + # def callback(step, x0, x, total_steps): + # if x0_output_dict is not None: + # x0_output_dict["x0"] = x0 + + # preview_bytes = None + # if previewer: + # preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0) + # pbar.update_absolute(step + 1, total_steps, preview_bytes) + pbar = None while not self.is_finished(): try: + print(f"do_task_until_completed: {offset}") data = self.send_request(offset) data_lst = data.get("data", {}).get("events", []) if not data_lst: raise ValueError(f"No data found in task {self.task_id}") - self.data_pool.extend(data_lst) offset += len(data_lst) + + for data in data_lst: + message = data.get("data", {}).get("message", {}) + if ( + isinstance(message, dict) + and message.get("event", None) == "progress" + ): + value = message["data"]["value"] + total = message["data"]["max"] + if pbar is None: + pbar = comfy.utils.ProgressBar(total) + print("===" * 10, "start") + print(message) + print("===" * 20) + pbar.update_absolute(value + 1, total, None) + # progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id} + # server.send_sync("progress", progress, server.client_id) + except Exception as e: print(f"Exception: {e}") finally: @@ -300,4 +332,8 @@ def do_task_until_completed(self, *, timeout: int = 480) -> list[dict]: f"Timeout waiting for task {self.task_id} to finish" ) time.sleep(2) + + real_result = get_task_result(self.task_id) + all_events = real_result["data"]["events"] + assert len(self.data_pool) == len(all_events) return self.data_pool diff --git a/src/bizyair/common/utils.py b/src/bizyair/common/utils.py index 8ce5daf4..cd3dd6c3 100644 --- a/src/bizyair/common/utils.py +++ b/src/bizyair/common/utils.py @@ -14,6 +14,10 @@ def truncate_long_strings(obj, max_length=50): return {k: truncate_long_strings(v, max_length) for k, v in obj.items()} elif isinstance(obj, list): return [truncate_long_strings(v, max_length) for v in obj] + elif isinstance(obj, tuple): + return tuple(truncate_long_strings(v, max_length) for v in obj) + elif isinstance(obj, torch.Tensor): + return obj.shape, obj.dtype, obj.device else: return obj From 2e53e16808a06dd56f862656b6df4b4fe66a4332 Mon Sep 17 00:00:00 2001 From: FengWen Date: Tue, 10 Dec 2024 11:41:57 +0800 Subject: [PATCH 06/12] refine --- bizyair_extras/nodes_upscale_model.py | 15 ++++++++ src/bizyair/common/client.py | 50 +++++++++------------------ 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/bizyair_extras/nodes_upscale_model.py b/bizyair_extras/nodes_upscale_model.py index 469da05c..dc6b343c 100644 --- a/bizyair_extras/nodes_upscale_model.py +++ b/bizyair_extras/nodes_upscale_model.py @@ -105,3 +105,18 @@ def load_model(self, **kwargs): model = BizyAirNodeIO(self.assigned_id) model.add_node_data(class_type="UpscaleModelLoader", inputs=kwargs) return (model,) + + +class ImageUpscaleWithModel(BizyAirBaseNode): + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "upscale_model": (UPSCALE_MODEL,), + "image": ("IMAGE",), + } + } + + RETURN_TYPES = ("IMAGE",) + # FUNCTION = "upscale" + CATEGORY = "image/upscaling" diff --git a/src/bizyair/common/client.py b/src/bizyair/common/client.py index 3092160e..83135237 100644 --- a/src/bizyair/common/client.py +++ b/src/bizyair/common/client.py @@ -1,4 +1,3 @@ -import asyncio import json import pprint import time @@ -12,9 +11,6 @@ import aiohttp import comfy -# TODO refine -import server # comfyui module - __all__ = ["send_request"] from dataclasses import dataclass, field @@ -284,29 +280,19 @@ def _fetch_remote_data(url: str) -> dict: def get_last_data(self) -> dict: return self.get_data(len(self.data_pool) - 1) - def do_task_until_completed(self, *, timeout: int = 480) -> list[dict]: + def do_task_until_completed( + self, *, timeout: int = 480, poll_interval: float = 1 + ) -> list[dict]: offset = 0 start_time = time.time() - # pbar = comfy.utils.ProgressBar(steps) - # def callback(step, x0, x, total_steps): - # if x0_output_dict is not None: - # x0_output_dict["x0"] = x0 - - # preview_bytes = None - # if previewer: - # preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0) - # pbar.update_absolute(step + 1, total_steps, preview_bytes) pbar = None while not self.is_finished(): try: print(f"do_task_until_completed: {offset}") data = self.send_request(offset) - data_lst = data.get("data", {}).get("events", []) - if not data_lst: - raise ValueError(f"No data found in task {self.task_id}") + data_lst = self._extract_data_list(data) self.data_pool.extend(data_lst) offset += len(data_lst) - for data in data_lst: message = data.get("data", {}).get("message", {}) if ( @@ -317,23 +303,19 @@ def do_task_until_completed(self, *, timeout: int = 480) -> list[dict]: total = message["data"]["max"] if pbar is None: pbar = comfy.utils.ProgressBar(total) - print("===" * 10, "start") - print(message) - print("===" * 20) pbar.update_absolute(value + 1, total, None) - # progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id} - # server.send_sync("progress", progress, server.client_id) - except Exception as e: print(f"Exception: {e}") - finally: - if time.time() - start_time > timeout: - raise TimeoutError( - f"Timeout waiting for task {self.task_id} to finish" - ) - time.sleep(2) - - real_result = get_task_result(self.task_id) - all_events = real_result["data"]["events"] - assert len(self.data_pool) == len(all_events) + + if time.time() - start_time > timeout: + raise TimeoutError(f"Timeout waiting for task {self.task_id} to finish") + + time.sleep(poll_interval) + return self.data_pool + + def _extract_data_list(self, data): + data_lst = data.get("data", {}).get("events", []) + if not data_lst: + raise ValueError(f"No data found in task {self.task_id}") + return data_lst From 99028e5904960c4ef26bfd4a8a5da6c2921f02e6 Mon Sep 17 00:00:00 2001 From: FengWen <109639975+ccssu@users.noreply.github.com> Date: Fri, 13 Dec 2024 16:05:07 +0800 Subject: [PATCH 07/12] add task_manager --- src/bizyair/common/task_manager/task_base.py | 70 +++++++++++++++++++ .../common/task_manager/task_context.py | 24 +++++++ .../common/task_manager/task_manager.py | 33 +++++++++ .../common/task_manager/task_runner.py | 8 +++ .../common/task_manager/task_scheduler.py | 15 ++++ .../common/task_manager/task_state_storage.py | 33 +++++++++ 6 files changed, 183 insertions(+) create mode 100644 src/bizyair/common/task_manager/task_base.py create mode 100644 src/bizyair/common/task_manager/task_context.py create mode 100644 src/bizyair/common/task_manager/task_manager.py create mode 100644 src/bizyair/common/task_manager/task_runner.py create mode 100644 src/bizyair/common/task_manager/task_scheduler.py create mode 100644 src/bizyair/common/task_manager/task_state_storage.py diff --git a/src/bizyair/common/task_manager/task_base.py b/src/bizyair/common/task_manager/task_base.py new file mode 100644 index 00000000..cc56cc44 --- /dev/null +++ b/src/bizyair/common/task_manager/task_base.py @@ -0,0 +1,70 @@ +from abc import ABC, abstractmethod + + +class TaskStateStorageBase(ABC): + @abstractmethod + def save_task_context(self, task_id, context: TaskContextBase): + pass + + @abstractmethod + def load_task_context(self, task_id): + pass + + @abstractmethod + def delete_task_context(self, task_id): + pass + + +class TaskSchedulerBase(ABC): + @abstractmethod + def schedule(self, task_id: str): + pass + + @abstractmethod + def add_task(self, task_id: str, task_data: dict, task_runner: TaskRunnerBase): + pass + +class TaskManagerBase(ABC): + @abstractmethod + def create_task(self, task_id: str, task_data: dict, task_runner: TaskRunnerBase): + pass + + @abstractmethod + def start_task(self, task_id: str): + pass + + + @abstractmethod + def stop_task(self, task_id: str): + pass + + @abstractmethod + def delete_task(self, task_id: str): + pass + + +class TaskContextBase(ABC): + + @abstractmethod + def set_storage(self, storage: TaskStateStorageBase): + pass + @abstractmethod + def update_context(self, context: dict): + pass + + +class TaskRunnerBase(ABC): + @abstractmethod + def load_context(self, context: TaskContextBase): + self.context = context + + @abstractmethod + def execute(self): + pass + # if self.context is None: + # raise ValueError("Context is not loaded, please call load_context first") + + # # self.context.update_context(state="running", registers={"PC": 200, "SP": 300}, memory={"data": [7, 8, 9]}) + # task_data = self.context.task_data + # # 执行任务 + # self.context.save_context() diff --git a/src/bizyair/common/task_manager/task_context.py b/src/bizyair/common/task_manager/task_context.py new file mode 100644 index 00000000..101318c9 --- /dev/null +++ b/src/bizyair/common/task_manager/task_context.py @@ -0,0 +1,24 @@ +from datetime import datetime +from .task_base import TaskContextBase, TaskRunnerBase, TaskStateStorageBase + + +class TaskContext(TaskContextBase): + def __init__(self, task_id: str, task_data: any, storage: TaskStateStorageBase): + """TaskContext is the context of the bizyair task + parameters: + task_id: the id of the task + task_data: the data that the task needs to process + storage: the storage of the task + """ + self.task_id = task_id + self.task_data = ( + task_data # task_data is the data that the task needs to process + ) + self.created_at = datetime.now() + self.storage = storage + + def save_context(self): + self.storage.save_task_context(self.task_id, self) + + def get_context(self): + return self.storage.load_task_context(self.task_id) \ No newline at end of file diff --git a/src/bizyair/common/task_manager/task_manager.py b/src/bizyair/common/task_manager/task_manager.py new file mode 100644 index 00000000..68279d88 --- /dev/null +++ b/src/bizyair/common/task_manager/task_manager.py @@ -0,0 +1,33 @@ +from .task_base import TaskManagerBase, TaskStateStorageBase, TaskSchedulerBase +from .task_context import TaskContext + +class TaskManager(TaskManagerBase): + """TaskManager is a class that manages the tasks. + + + example: + + task_manager = TaskManager(state_storage, scheduler) + task_manager.create_task("task_id", {"task_data": "task_data"}) + task_manager.start_task("task_id") + """ + def __init__(self, state_storage: TaskStateStorageBase, scheduler: TaskSchedulerBase): + self.state_storage = state_storage + self.scheduler = scheduler + + def create_task(self, task_id: str, task_data: dict, *, task_runner: TaskRunnerBase = None)-> str: + context = TaskContext(task_id, task_data, self.state_storage) + if task_runner: + task_runner.load_context(context) + else: + assert False, "task_runner is required" + self.scheduler.add_task(task_id, task_runner) + + def start_task(self, task_id): + """启动任务""" + context = self.state_storage.load_task_context(task_id) + if context: + self.scheduler.schedule(task_id) + print(f"Task {task_id} started.") + else: + print(f"Task {task_id} not found.") diff --git a/src/bizyair/common/task_manager/task_runner.py b/src/bizyair/common/task_manager/task_runner.py new file mode 100644 index 00000000..ad06ebb3 --- /dev/null +++ b/src/bizyair/common/task_manager/task_runner.py @@ -0,0 +1,8 @@ +from .task_base import TaskRunnerBase, TaskContextBase + +class BizyAirTaskRunner(TaskRunnerBase): + def __init__(self, task_context: TaskContextBase): + self.task_context = task_context + + def execute(self, task_id: str): + pass diff --git a/src/bizyair/common/task_manager/task_scheduler.py b/src/bizyair/common/task_manager/task_scheduler.py new file mode 100644 index 00000000..23f4a575 --- /dev/null +++ b/src/bizyair/common/task_manager/task_scheduler.py @@ -0,0 +1,15 @@ +from threading import Lock +from .task_base import TaskSchedulerBase + +class TaskScheduler(TaskSchedulerBase): + def __init__(self) -> None: + super().__init__() + self.task_contexts = {} + self._lock = Lock() + + def add_task(self, task_id: str, task_runner: TaskRunnerBase): + with self._lock: + self.task_contexts[task_id] = task_runner + + def schedule(self, task_id: str): + self.task_contexts[task_id].execute() \ No newline at end of file diff --git a/src/bizyair/common/task_manager/task_state_storage.py b/src/bizyair/common/task_manager/task_state_storage.py new file mode 100644 index 00000000..c5ab5414 --- /dev/null +++ b/src/bizyair/common/task_manager/task_state_storage.py @@ -0,0 +1,33 @@ + +from bizyair.common.task_manager.task_base import TaskStateStorageBase, TaskContextBase +import os +import json +from threading import Lock + +class TaskStateStorage(TaskStateStorageBase): + + def __init__(self): + self.storage_directory = "task_storage" + self._lock = Lock() + self._init_db() + + def _init_db(self): + if not os.path.exists(self.storage_directory): + os.makedirs(self.storage_directory) + + def _get_file_path(self, task_id): + return os.path.join(self.storage_directory, f"{task_id}.json") + + def save_task_context(self, task_id, context: TaskContextBase): + file_path = self._get_file_path(task_id) + with self._lock: + with open(file_path, 'w') as file: + json.dump(context.__dict__, file, default=str, indent=4) + + def load_task_context(self, task_id): + file_path = self._get_file_path(task_id) + try: + with open(file_path, 'r') as file: + return json.load(file) + except FileNotFoundError: + return None \ No newline at end of file From 2a409f89a3ff46a625152fdc6778a7ad6f34f17d Mon Sep 17 00:00:00 2001 From: FengWen Date: Fri, 13 Dec 2024 16:57:10 +0800 Subject: [PATCH 08/12] tmp save --- src/bizyair/common/task_manager/task_base.py | 61 ++++++++----------- .../common/task_manager/task_context.py | 3 +- .../common/task_manager/task_manager.py | 19 ++++-- .../common/task_manager/task_runner.py | 3 +- .../common/task_manager/task_scheduler.py | 8 ++- .../common/task_manager/task_state_storage.py | 21 +++++-- 6 files changed, 64 insertions(+), 51 deletions(-) diff --git a/src/bizyair/common/task_manager/task_base.py b/src/bizyair/common/task_manager/task_base.py index cc56cc44..0380fb32 100644 --- a/src/bizyair/common/task_manager/task_base.py +++ b/src/bizyair/common/task_manager/task_base.py @@ -1,70 +1,59 @@ from abc import ABC, abstractmethod +class TaskContextBase(ABC): + @abstractmethod + def save_context(self): + pass + + @abstractmethod + def get_context(self): + pass + + +class TaskRunnerBase(ABC): + @abstractmethod + def load_context(self, context: TaskContextBase): + self.context = context + + @abstractmethod + def execute(self): + pass + + class TaskStateStorageBase(ABC): @abstractmethod def save_task_context(self, task_id, context: TaskContextBase): pass - + @abstractmethod def load_task_context(self, task_id): pass - - @abstractmethod - def delete_task_context(self, task_id): - pass class TaskSchedulerBase(ABC): @abstractmethod def schedule(self, task_id: str): pass - + @abstractmethod def add_task(self, task_id: str, task_data: dict, task_runner: TaskRunnerBase): pass + class TaskManagerBase(ABC): @abstractmethod def create_task(self, task_id: str, task_data: dict, task_runner: TaskRunnerBase): pass - + @abstractmethod def start_task(self, task_id: str): pass - @abstractmethod def stop_task(self, task_id: str): pass - - @abstractmethod - def delete_task(self, task_id: str): - pass - -class TaskContextBase(ABC): - @abstractmethod - def set_storage(self, storage: TaskStateStorageBase): - pass - @abstractmethod - def update_context(self, context: dict): - pass - - -class TaskRunnerBase(ABC): - @abstractmethod - def load_context(self, context: TaskContextBase): - self.context = context - - @abstractmethod - def execute(self): + def delete_task(self, task_id: str): pass - # if self.context is None: - # raise ValueError("Context is not loaded, please call load_context first") - - # # self.context.update_context(state="running", registers={"PC": 200, "SP": 300}, memory={"data": [7, 8, 9]}) - # task_data = self.context.task_data - # # 执行任务 - # self.context.save_context() diff --git a/src/bizyair/common/task_manager/task_context.py b/src/bizyair/common/task_manager/task_context.py index 101318c9..7cf49bc9 100644 --- a/src/bizyair/common/task_manager/task_context.py +++ b/src/bizyair/common/task_manager/task_context.py @@ -1,4 +1,5 @@ from datetime import datetime + from .task_base import TaskContextBase, TaskRunnerBase, TaskStateStorageBase @@ -21,4 +22,4 @@ def save_context(self): self.storage.save_task_context(self.task_id, self) def get_context(self): - return self.storage.load_task_context(self.task_id) \ No newline at end of file + return self.storage.load_task_context(self.task_id) diff --git a/src/bizyair/common/task_manager/task_manager.py b/src/bizyair/common/task_manager/task_manager.py index 68279d88..646fa9d1 100644 --- a/src/bizyair/common/task_manager/task_manager.py +++ b/src/bizyair/common/task_manager/task_manager.py @@ -1,21 +1,32 @@ -from .task_base import TaskManagerBase, TaskStateStorageBase, TaskSchedulerBase +from .task_base import ( + TaskManagerBase, + TaskRunnerBase, + TaskSchedulerBase, + TaskStateStorageBase, +) from .task_context import TaskContext + class TaskManager(TaskManagerBase): """TaskManager is a class that manages the tasks. example: - + task_manager = TaskManager(state_storage, scheduler) task_manager.create_task("task_id", {"task_data": "task_data"}) task_manager.start_task("task_id") """ - def __init__(self, state_storage: TaskStateStorageBase, scheduler: TaskSchedulerBase): + + def __init__( + self, state_storage: TaskStateStorageBase, scheduler: TaskSchedulerBase + ): self.state_storage = state_storage self.scheduler = scheduler - def create_task(self, task_id: str, task_data: dict, *, task_runner: TaskRunnerBase = None)-> str: + def create_task( + self, task_id: str, task_data: dict, *, task_runner: TaskRunnerBase = None + ) -> str: context = TaskContext(task_id, task_data, self.state_storage) if task_runner: task_runner.load_context(context) diff --git a/src/bizyair/common/task_manager/task_runner.py b/src/bizyair/common/task_manager/task_runner.py index ad06ebb3..384e824b 100644 --- a/src/bizyair/common/task_manager/task_runner.py +++ b/src/bizyair/common/task_manager/task_runner.py @@ -1,4 +1,5 @@ -from .task_base import TaskRunnerBase, TaskContextBase +from .task_base import TaskContextBase, TaskRunnerBase + class BizyAirTaskRunner(TaskRunnerBase): def __init__(self, task_context: TaskContextBase): diff --git a/src/bizyair/common/task_manager/task_scheduler.py b/src/bizyair/common/task_manager/task_scheduler.py index 23f4a575..d3300c1f 100644 --- a/src/bizyair/common/task_manager/task_scheduler.py +++ b/src/bizyair/common/task_manager/task_scheduler.py @@ -1,5 +1,7 @@ from threading import Lock -from .task_base import TaskSchedulerBase + +from .task_base import TaskRunnerBase, TaskSchedulerBase + class TaskScheduler(TaskSchedulerBase): def __init__(self) -> None: @@ -10,6 +12,6 @@ def __init__(self) -> None: def add_task(self, task_id: str, task_runner: TaskRunnerBase): with self._lock: self.task_contexts[task_id] = task_runner - + def schedule(self, task_id: str): - self.task_contexts[task_id].execute() \ No newline at end of file + self.task_contexts[task_id].execute() diff --git a/src/bizyair/common/task_manager/task_state_storage.py b/src/bizyair/common/task_manager/task_state_storage.py index c5ab5414..df701d8c 100644 --- a/src/bizyair/common/task_manager/task_state_storage.py +++ b/src/bizyair/common/task_manager/task_state_storage.py @@ -1,9 +1,10 @@ - -from bizyair.common.task_manager.task_base import TaskStateStorageBase, TaskContextBase -import os import json +import os from threading import Lock +from bizyair.common.task_manager.task_base import TaskContextBase, TaskStateStorageBase + + class TaskStateStorage(TaskStateStorageBase): def __init__(self): @@ -21,13 +22,21 @@ def _get_file_path(self, task_id): def save_task_context(self, task_id, context: TaskContextBase): file_path = self._get_file_path(task_id) with self._lock: - with open(file_path, 'w') as file: + with open(file_path, "w") as file: json.dump(context.__dict__, file, default=str, indent=4) def load_task_context(self, task_id): file_path = self._get_file_path(task_id) try: - with open(file_path, 'r') as file: + with open(file_path, "r") as file: return json.load(file) except FileNotFoundError: - return None \ No newline at end of file + return None + + def list_tasks(self): + return os.listdir(self.storage_directory) + + def delete_task(self, task_id): + file_path = self._get_file_path(task_id) + if os.path.exists(file_path): + os.remove(file_path) From 3099a61afd6ef32567af4b0c20e698942e58ed30 Mon Sep 17 00:00:00 2001 From: FengWen Date: Sat, 14 Dec 2024 07:29:58 +0800 Subject: [PATCH 09/12] refine --- src/bizyair/common/client.py | 1 + src/bizyair/common/task_manager/task_base.py | 59 ------------------- .../common/task_manager/task_context.py | 25 -------- .../common/task_manager/task_manager.py | 44 -------------- .../common/task_manager/task_runner.py | 9 --- .../common/task_manager/task_scheduler.py | 17 ------ .../common/task_manager/task_state_storage.py | 42 ------------- 7 files changed, 1 insertion(+), 196 deletions(-) delete mode 100644 src/bizyair/common/task_manager/task_base.py delete mode 100644 src/bizyair/common/task_manager/task_context.py delete mode 100644 src/bizyair/common/task_manager/task_manager.py delete mode 100644 src/bizyair/common/task_manager/task_runner.py delete mode 100644 src/bizyair/common/task_manager/task_scheduler.py delete mode 100644 src/bizyair/common/task_manager/task_state_storage.py diff --git a/src/bizyair/common/client.py b/src/bizyair/common/client.py index 83135237..0c93eb52 100644 --- a/src/bizyair/common/client.py +++ b/src/bizyair/common/client.py @@ -1,3 +1,4 @@ +import hashlib import json import pprint import time diff --git a/src/bizyair/common/task_manager/task_base.py b/src/bizyair/common/task_manager/task_base.py deleted file mode 100644 index 0380fb32..00000000 --- a/src/bizyair/common/task_manager/task_base.py +++ /dev/null @@ -1,59 +0,0 @@ -from abc import ABC, abstractmethod - - -class TaskContextBase(ABC): - @abstractmethod - def save_context(self): - pass - - @abstractmethod - def get_context(self): - pass - - -class TaskRunnerBase(ABC): - @abstractmethod - def load_context(self, context: TaskContextBase): - self.context = context - - @abstractmethod - def execute(self): - pass - - -class TaskStateStorageBase(ABC): - @abstractmethod - def save_task_context(self, task_id, context: TaskContextBase): - pass - - @abstractmethod - def load_task_context(self, task_id): - pass - - -class TaskSchedulerBase(ABC): - @abstractmethod - def schedule(self, task_id: str): - pass - - @abstractmethod - def add_task(self, task_id: str, task_data: dict, task_runner: TaskRunnerBase): - pass - - -class TaskManagerBase(ABC): - @abstractmethod - def create_task(self, task_id: str, task_data: dict, task_runner: TaskRunnerBase): - pass - - @abstractmethod - def start_task(self, task_id: str): - pass - - @abstractmethod - def stop_task(self, task_id: str): - pass - - @abstractmethod - def delete_task(self, task_id: str): - pass diff --git a/src/bizyair/common/task_manager/task_context.py b/src/bizyair/common/task_manager/task_context.py deleted file mode 100644 index 7cf49bc9..00000000 --- a/src/bizyair/common/task_manager/task_context.py +++ /dev/null @@ -1,25 +0,0 @@ -from datetime import datetime - -from .task_base import TaskContextBase, TaskRunnerBase, TaskStateStorageBase - - -class TaskContext(TaskContextBase): - def __init__(self, task_id: str, task_data: any, storage: TaskStateStorageBase): - """TaskContext is the context of the bizyair task - parameters: - task_id: the id of the task - task_data: the data that the task needs to process - storage: the storage of the task - """ - self.task_id = task_id - self.task_data = ( - task_data # task_data is the data that the task needs to process - ) - self.created_at = datetime.now() - self.storage = storage - - def save_context(self): - self.storage.save_task_context(self.task_id, self) - - def get_context(self): - return self.storage.load_task_context(self.task_id) diff --git a/src/bizyair/common/task_manager/task_manager.py b/src/bizyair/common/task_manager/task_manager.py deleted file mode 100644 index 646fa9d1..00000000 --- a/src/bizyair/common/task_manager/task_manager.py +++ /dev/null @@ -1,44 +0,0 @@ -from .task_base import ( - TaskManagerBase, - TaskRunnerBase, - TaskSchedulerBase, - TaskStateStorageBase, -) -from .task_context import TaskContext - - -class TaskManager(TaskManagerBase): - """TaskManager is a class that manages the tasks. - - - example: - - task_manager = TaskManager(state_storage, scheduler) - task_manager.create_task("task_id", {"task_data": "task_data"}) - task_manager.start_task("task_id") - """ - - def __init__( - self, state_storage: TaskStateStorageBase, scheduler: TaskSchedulerBase - ): - self.state_storage = state_storage - self.scheduler = scheduler - - def create_task( - self, task_id: str, task_data: dict, *, task_runner: TaskRunnerBase = None - ) -> str: - context = TaskContext(task_id, task_data, self.state_storage) - if task_runner: - task_runner.load_context(context) - else: - assert False, "task_runner is required" - self.scheduler.add_task(task_id, task_runner) - - def start_task(self, task_id): - """启动任务""" - context = self.state_storage.load_task_context(task_id) - if context: - self.scheduler.schedule(task_id) - print(f"Task {task_id} started.") - else: - print(f"Task {task_id} not found.") diff --git a/src/bizyair/common/task_manager/task_runner.py b/src/bizyair/common/task_manager/task_runner.py deleted file mode 100644 index 384e824b..00000000 --- a/src/bizyair/common/task_manager/task_runner.py +++ /dev/null @@ -1,9 +0,0 @@ -from .task_base import TaskContextBase, TaskRunnerBase - - -class BizyAirTaskRunner(TaskRunnerBase): - def __init__(self, task_context: TaskContextBase): - self.task_context = task_context - - def execute(self, task_id: str): - pass diff --git a/src/bizyair/common/task_manager/task_scheduler.py b/src/bizyair/common/task_manager/task_scheduler.py deleted file mode 100644 index d3300c1f..00000000 --- a/src/bizyair/common/task_manager/task_scheduler.py +++ /dev/null @@ -1,17 +0,0 @@ -from threading import Lock - -from .task_base import TaskRunnerBase, TaskSchedulerBase - - -class TaskScheduler(TaskSchedulerBase): - def __init__(self) -> None: - super().__init__() - self.task_contexts = {} - self._lock = Lock() - - def add_task(self, task_id: str, task_runner: TaskRunnerBase): - with self._lock: - self.task_contexts[task_id] = task_runner - - def schedule(self, task_id: str): - self.task_contexts[task_id].execute() diff --git a/src/bizyair/common/task_manager/task_state_storage.py b/src/bizyair/common/task_manager/task_state_storage.py deleted file mode 100644 index df701d8c..00000000 --- a/src/bizyair/common/task_manager/task_state_storage.py +++ /dev/null @@ -1,42 +0,0 @@ -import json -import os -from threading import Lock - -from bizyair.common.task_manager.task_base import TaskContextBase, TaskStateStorageBase - - -class TaskStateStorage(TaskStateStorageBase): - - def __init__(self): - self.storage_directory = "task_storage" - self._lock = Lock() - self._init_db() - - def _init_db(self): - if not os.path.exists(self.storage_directory): - os.makedirs(self.storage_directory) - - def _get_file_path(self, task_id): - return os.path.join(self.storage_directory, f"{task_id}.json") - - def save_task_context(self, task_id, context: TaskContextBase): - file_path = self._get_file_path(task_id) - with self._lock: - with open(file_path, "w") as file: - json.dump(context.__dict__, file, default=str, indent=4) - - def load_task_context(self, task_id): - file_path = self._get_file_path(task_id) - try: - with open(file_path, "r") as file: - return json.load(file) - except FileNotFoundError: - return None - - def list_tasks(self): - return os.listdir(self.storage_directory) - - def delete_task(self, task_id): - file_path = self._get_file_path(task_id) - if os.path.exists(file_path): - os.remove(file_path) From 5a973fbebe191196ff0b519d8a0c09b28aaa9481 Mon Sep 17 00:00:00 2001 From: FengWen Date: Mon, 16 Dec 2024 09:23:02 +0800 Subject: [PATCH 10/12] refine --- .../commands/processors/prompt_processor.py | 5 + src/bizyair/commands/servers/prompt_server.py | 165 +++++++++++++-- src/bizyair/common/__init__.py | 1 - src/bizyair/common/caching.py | 193 ++++++++++++++++++ src/bizyair/common/client.py | 119 +---------- src/bizyair/configs/conf.py | 3 + src/bizyair/configs/models.yaml | 8 + 7 files changed, 363 insertions(+), 131 deletions(-) create mode 100644 src/bizyair/common/caching.py diff --git a/src/bizyair/commands/processors/prompt_processor.py b/src/bizyair/commands/processors/prompt_processor.py index 4ddfc496..32b4f44c 100644 --- a/src/bizyair/commands/processors/prompt_processor.py +++ b/src/bizyair/commands/processors/prompt_processor.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List from bizyair.common import client +from bizyair.common.caching import BizyAirTaskCache, CacheConfig from bizyair.common.env_var import ( BIZYAIR_DEBUG, BIZYAIR_DEV_REQUEST_URL, @@ -83,6 +84,10 @@ def validate_input( class PromptProcessor(Processor): + bizyair_task_cache = BizyAirTaskCache( + config=CacheConfig.from_config(config_manager.get_cache_config()) + ) + def _exec_info(self, prompt: Dict[str, Dict[str, Any]]): exec_info = { "model_version_ids": [], diff --git a/src/bizyair/commands/servers/prompt_server.py b/src/bizyair/commands/servers/prompt_server.py index 801d972d..bb142db9 100644 --- a/src/bizyair/commands/servers/prompt_server.py +++ b/src/bizyair/commands/servers/prompt_server.py @@ -1,17 +1,144 @@ +import hashlib import json import pprint +import time import traceback +from dataclasses import dataclass, field from typing import Any, Dict, List -from bizyair.common import BizyAirTask +import comfy + +from bizyair.common.caching import BizyAirTaskCache, CacheConfig +from bizyair.common.client import send_request from bizyair.common.env_var import BIZYAIR_DEBUG from bizyair.common.utils import truncate_long_strings +from bizyair.configs.conf import config_manager from bizyair.image_utils import decode_data, encode_data from ..base import Command, Processor # type: ignore +def get_task_result(task_id: str, offset: int = 0) -> dict: + """ + Get the result of a task. + """ + import requests + + url = f"https://uat-bizyair-api.siliconflow.cn/x/v1/bizy_task/{task_id}" + response_json = send_request( + method="GET", url=url, data=json.dumps({"offset": offset}).encode("utf-8") + ) + out = response_json + events = out.get("data", {}).get("events", []) + new_events = [] + for event in events: + if ( + "data" in event + and isinstance(event["data"], str) + and event["data"].startswith("https://") + ): + event["data"] = requests.get(event["data"]).json() + new_events.append(event) + out["data"]["events"] = new_events + return out + + +@dataclass +class BizyAirTask: + TASK_DATA_STATUS = ["PENDING", "PROCESSING", "COMPLETED"] + task_id: str + data_pool: list[dict] = field(default_factory=list) + data_status: str = None + + @staticmethod + def check_inputs(inputs: dict) -> bool: + return ( + inputs.get("code") == 20000 + and inputs.get("status", False) + and "task_id" in inputs.get("data", {}) + ) + + @classmethod + def from_data(cls, inputs: dict, check_inputs: bool = True) -> "BizyAirTask": + if check_inputs and not cls.check_inputs(inputs): + raise ValueError(f"Invalid inputs: {inputs}") + data = inputs.get("data", {}) + task_id = data.get("task_id", "") + return cls(task_id=task_id, data_pool=[], data_status="started") + + def is_finished(self) -> bool: + if not self.data_pool: + return False + if self.data_pool[-1].get("data_status") == self.TASK_DATA_STATUS[-1]: + return True + return False + + def send_request(self, offset: int = 0) -> dict: + if offset >= len(self.data_pool): + return get_task_result(self.task_id, offset) + else: + return self.data_pool[offset] + + def get_data(self, offset: int = 0) -> dict: + if offset >= len(self.data_pool): + return {} + return self.data_pool[offset] + + @staticmethod + def _fetch_remote_data(url: str) -> dict: + import requests + + return requests.get(url).json() + + def get_last_data(self) -> dict: + return self.get_data(len(self.data_pool) - 1) + + def do_task_until_completed( + self, *, timeout: int = 480, poll_interval: float = 1 + ) -> list[dict]: + offset = 0 + start_time = time.time() + pbar = None + while not self.is_finished(): + try: + print(f"do_task_until_completed: {offset}") + data = self.send_request(offset) + data_lst = self._extract_data_list(data) + self.data_pool.extend(data_lst) + offset += len(data_lst) + for data in data_lst: + message = data.get("data", {}).get("message", {}) + if ( + isinstance(message, dict) + and message.get("event", None) == "progress" + ): + value = message["data"]["value"] + total = message["data"]["max"] + if pbar is None: + pbar = comfy.utils.ProgressBar(total) + pbar.update_absolute(value + 1, total, None) + except Exception as e: + print(f"Exception: {e}") + + if time.time() - start_time > timeout: + raise TimeoutError(f"Timeout waiting for task {self.task_id} to finish") + + time.sleep(poll_interval) + + return self.data_pool + + def _extract_data_list(self, data): + data_lst = data.get("data", {}).get("events", []) + if not data_lst: + raise ValueError(f"No data found in task {self.task_id}") + return data_lst + + class PromptServer(Command): + cache_manager: BizyAirTaskCache = BizyAirTaskCache( + config=CacheConfig.from_config(config_manager.get_cache_config()) + ) + def __init__(self, router: Processor, processor: Processor): self.router = router self.processor = processor @@ -27,20 +154,24 @@ def is_async_task(self, result: Dict[str, Any]) -> str: and "task_id" in result.get("data", {}) ) - def _get_result(self, result: Dict[str, Any]): + def _get_result(self, result: Dict[str, Any], *, cache_key: str = None): try: response_data = result["data"] if BizyAirTask.check_inputs(result): + self.cache_manager.set(cache_key, result) bz_task = BizyAirTask.from_data(result, check_inputs=False) bz_task.do_task_until_completed() last_data = bz_task.get_last_data() response_data = last_data.get("data") out = response_data["payload"] + assert out is not None + print(f"Setting cache for {cache_key} with out") + self.cache_manager.set(cache_key, out, overwrite=True) return out except Exception as e: - raise RuntimeError( - f'Unexpected error accessing result["data"]["payload"]. Result: {result}' - ) from e + print(f"Exception: {e}") + self.cache_manager.delete(cache_key) + raise e def execute( self, @@ -61,20 +192,26 @@ def execute( if BIZYAIR_DEBUG: print(f"Generated URL: {url}") - result = self.processor(url, prompt=prompt, last_node_ids=last_node_ids) - if BIZYAIR_DEBUG: - pprint.pprint({"result": truncate_long_strings(result, 50)}, indent=4) - - if result is None: - raise RuntimeError("result is None") + start_time = time.time() + sh256 = hashlib.sha256( + json.dumps({"url": url, "prompt": prompt}).encode("utf-8") + ).hexdigest() + end_time = time.time() + print(f"Time taken to generate sh256-{sh256} : {end_time - start_time} seconds") + if self.cache_manager.get(sh256): + print(f"Cache hit for sh256-{sh256}") + out = self.cache_manager.get(sh256) + else: + result = self.processor(url, prompt=prompt, last_node_ids=last_node_ids) + out = self._get_result(result, cache_key=sh256) - out = self._get_result(result) + if BIZYAIR_DEBUG: + pprint.pprint({"out": truncate_long_strings(out, 50)}, indent=4) try: - if "error" in out: - raise RuntimeError(out["error"]) real_out = decode_data(out) return real_out[0] except Exception as e: print("Exception occurred while decoding data") + self.cache_manager.delete(sh256) traceback.print_exc() raise RuntimeError(f"Exception: {e=}") from e diff --git a/src/bizyair/common/__init__.py b/src/bizyair/common/__init__.py index 828c3fd9..55fe7626 100644 --- a/src/bizyair/common/__init__.py +++ b/src/bizyair/common/__init__.py @@ -1,5 +1,4 @@ from .client import ( - BizyAirTask, fetch_models_by_type, get_api_key, send_request, diff --git a/src/bizyair/common/caching.py b/src/bizyair/common/caching.py new file mode 100644 index 00000000..e7a9081f --- /dev/null +++ b/src/bizyair/common/caching.py @@ -0,0 +1,193 @@ +import glob +import json +import os +import time +from abc import ABC, abstractmethod +from collections import OrderedDict +from dataclasses import dataclass +from typing import Any, Dict + + +@dataclass +class CacheConfig: + max_size: int = 100 + expiration: int = 300 # 300 seconds + cache_dir: str = "./cache" + file_prefix: str = "bizyair_cache_" + file_suffix: str = ".json" + use_cache: bool = True + + @classmethod + def from_config(cls, config: Dict[str, Any]): + return cls( + max_size=config.get("max_size", 100), + expiration=config.get("expiration", 300), + cache_dir=config.get("cache_dir", "./cache"), + file_prefix=config.get("file_prefix", "bizyair_cache_"), + file_suffix=config.get("file_suffix", ".json"), + use_cache=config.get("use_cache", True), + ) + + +class CacheManager(ABC): + @abstractmethod + def get(self, key): + pass + + @abstractmethod + def set(self, key, value): + pass + + @abstractmethod + def clear(self): + pass + + @abstractmethod + def disable(self): + pass + + +class BizyAirTaskCache(CacheManager): + def __init__(self, config: CacheConfig): + self.config = config + self.cache = OrderedDict() + self.cache_dir = config.cache_dir + self.ensure_directory_exists() + self.cache = self.load_cache() if config.use_cache else self.cache + + def ensure_directory_exists(self): + if not os.path.exists(self.cache_dir): + os.makedirs(self.cache_dir) + + def load_cache(self): + cache_v_files = glob.glob( + os.path.join(self.cache_dir, f"{self.config.file_prefix}*.json") + ) + output = OrderedDict() + cache_datas = [] + for cache_file in cache_v_files: + try: + file_name = os.path.basename(cache_file)[ + len(self.config.file_prefix) : -len(self.config.file_suffix) + ] + cache_key = file_name.split("-")[0] + cache_timestamp = file_name.split("-")[1] + cache_datas.append( + { + "key": cache_key, + "timestamp": int(cache_timestamp), + "file_path": cache_file, + } + ) + except Exception as e: + print( + f"Warning: Error loading cache file {cache_file}: because {e}, will delete it" + ) + cache_datas = sorted(cache_datas, key=lambda x: x["timestamp"]) + for cache_data in cache_datas: + output[cache_data["key"]] = ( + cache_data["file_path"], + cache_data["timestamp"], + ) + return output + + def delete(self, key): + if key in self.cache: + self.delete_file(self.cache[key][0]) + del self.cache[key] + + def get(self, key): + if key not in self.cache: + return None + + file_path, timestamp = self.cache[key] + if time.time() - timestamp >= self.config.expiration: + self._remove_expired_entry(file_path, key) + return None + + cache_data = self._read_file(file_path) + if cache_data["cache_key"] == key: + return cache_data["result"] + else: + self._remove_expired_entry(file_path, key) + return None + + def _read_file(self, file_path): + try: + with open(file_path, "r") as f: + cache_data = json.load(f) + return cache_data + except Exception as e: + print(f"Error reading file {file_path}: {e}") + return None + + def _remove_expired_entry(self, file_path, key): + self.delete_file(file_path) + del self.cache[key] + + def set(self, key, value, *, overwrite=False): + if not overwrite and key in self.cache: + raise ValueError( + f"Key '{key}' already exists in cache. Use overwrite=True to replace it." + ) + assert isinstance(key, str), "Key must be a string" + + if len(self.cache) >= self.config.max_size: + self._evict_oldest() + + timestamp = int(time.time()) + file_path = os.path.join( + self.cache_dir, f"{self.config.file_prefix}{key}-{timestamp}.json" + ) + self.write_file(key, value, file_path, timestamp) + + def _evict_oldest(self): + oldest_key, (oldest_file_path, _) = self.cache.popitem(last=False) + self.delete_file(oldest_file_path) + + def write_file(self, key: str, value: Any, file_path: str, timestamp: int): + try: + with open(file_path, "w") as f: + json.dump( + {"result": value, "cache_key": key, "timestamp": timestamp}, f + ) + self.cache[key] = (file_path, timestamp) + except Exception as e: + print(f"Error writing file for key '{key}': {e}") + + def delete_file(self, file_path): + if os.path.exists(file_path): + try: + os.remove(file_path) + except Exception as e: + print(f"Error deleting file '{file_path}': {e}") + + def clear(self): + for file_path, _ in self.cache.values(): + self.delete_file(file_path) + self.cache.clear() + + def disable(self): + self.clear() + + +# Example usage +if __name__ == "__main__": + cache_config = CacheConfig(max_size=12, expiration=10, cache_dir="./cache") + cache = BizyAirTaskCache(cache_config) + + # Set some cache values + cache.set("key1", "This is the value for key1") + cache.set("key2", "This is the value for key2") + + # Retrieve values from cache + print(cache.get("key1")) # Output: This is the value for key1 + print(cache.get("key2")) # Output: This is the value for key2 + + # Wait for expiration + time.sleep(9) + print(cache.get("key1")) # Output: None (expired) + + # Clear cache + cache.clear() + print(cache.get("key2")) # Output: None (cache cleared) diff --git a/src/bizyair/common/client.py b/src/bizyair/common/client.py index 0c93eb52..e5a6dad4 100644 --- a/src/bizyair/common/client.py +++ b/src/bizyair/common/client.py @@ -12,6 +12,8 @@ import aiohttp import comfy +from bizyair.common.caching import CacheManager + __all__ = ["send_request"] from dataclasses import dataclass, field @@ -112,6 +114,7 @@ def send_request( verbose=False, callback: callable = process_response_data, response_handler: callable = json.loads, + cache_manager: CacheManager = None, **kwargs, ) -> Union[dict, Any]: try: @@ -204,119 +207,3 @@ def fetch_models_by_type( verbose=verbose, ) return msg - - -def get_task_result(task_id: str, offset: int = 0) -> dict: - """ - Get the result of a task. - """ - import requests - - url = f"https://uat-bizyair-api.siliconflow.cn/x/v1/bizy_task/{task_id}" - response_json = send_request( - method="GET", url=url, data=json.dumps({"offset": offset}).encode("utf-8") - ) - out = response_json - events = out.get("data", {}).get("events", []) - new_events = [] - for event in events: - if ( - "data" in event - and isinstance(event["data"], str) - and event["data"].startswith("https://") - ): - event["data"] = requests.get(event["data"]).json() - new_events.append(event) - out["data"]["events"] = new_events - return out - - -@dataclass -class BizyAirTask: - TASK_DATA_STATUS = ["PENDING", "PROCESSING", "COMPLETED"] - task_id: str - data_pool: list[dict] = field(default_factory=list) - data_status: str = None - - @staticmethod - def check_inputs(inputs: dict) -> bool: - return ( - inputs.get("code") == 20000 - and inputs.get("status", False) - and "task_id" in inputs.get("data", {}) - ) - - @classmethod - def from_data(cls, inputs: dict, check_inputs: bool = True) -> "BizyAirTask": - if check_inputs and not cls.check_inputs(inputs): - raise ValueError(f"Invalid inputs: {inputs}") - data = inputs.get("data", {}) - task_id = data.get("task_id", "") - return cls(task_id=task_id, data_pool=[], data_status="started") - - def is_finished(self) -> bool: - if not self.data_pool: - return False - if self.data_pool[-1].get("data_status") == self.TASK_DATA_STATUS[-1]: - return True - return False - - def send_request(self, offset: int = 0) -> dict: - if offset >= len(self.data_pool): - return get_task_result(self.task_id, offset) - else: - return self.data_pool[offset] - - def get_data(self, offset: int = 0) -> dict: - if offset >= len(self.data_pool): - return {} - return self.data_pool[offset] - - @staticmethod - def _fetch_remote_data(url: str) -> dict: - import requests - - return requests.get(url).json() - - def get_last_data(self) -> dict: - return self.get_data(len(self.data_pool) - 1) - - def do_task_until_completed( - self, *, timeout: int = 480, poll_interval: float = 1 - ) -> list[dict]: - offset = 0 - start_time = time.time() - pbar = None - while not self.is_finished(): - try: - print(f"do_task_until_completed: {offset}") - data = self.send_request(offset) - data_lst = self._extract_data_list(data) - self.data_pool.extend(data_lst) - offset += len(data_lst) - for data in data_lst: - message = data.get("data", {}).get("message", {}) - if ( - isinstance(message, dict) - and message.get("event", None) == "progress" - ): - value = message["data"]["value"] - total = message["data"]["max"] - if pbar is None: - pbar = comfy.utils.ProgressBar(total) - pbar.update_absolute(value + 1, total, None) - except Exception as e: - print(f"Exception: {e}") - - if time.time() - start_time > timeout: - raise TimeoutError(f"Timeout waiting for task {self.task_id} to finish") - - time.sleep(poll_interval) - - return self.data_pool - - def _extract_data_list(self, data): - data_lst = data.get("data", {}).get("events", []) - if not data_lst: - raise ValueError(f"No data found in task {self.task_id}") - return data_lst diff --git a/src/bizyair/configs/conf.py b/src/bizyair/configs/conf.py index d59bcbdb..81dc035e 100644 --- a/src/bizyair/configs/conf.py +++ b/src/bizyair/configs/conf.py @@ -93,6 +93,9 @@ def get_rules(self, class_type: str) -> List[ModelRule]: def get_model_version_id_prefix(self): return self.model_rule_config["model_version_config"]["model_version_id_prefix"] + def get_cache_config(self): + return self.model_rule_config.get("cache_config", {}) + model_path_config = os.path.join(os.path.dirname(__file__), "models.json") model_rule_config = os.path.join(os.path.dirname(__file__), "models.yaml") diff --git a/src/bizyair/configs/models.yaml b/src/bizyair/configs/models.yaml index 6936c6c2..e1ff6c02 100644 --- a/src/bizyair/configs/models.yaml +++ b/src/bizyair/configs/models.yaml @@ -2,6 +2,14 @@ model_version_config: model_version_id_prefix: "BIZYAIR_MODEL_VERSION_ID:" +cache_config: + max_size: 100 # 100 items + expiration: 604800 # 7 days + cache_dir: ".bizyair_cache" + file_prefix: "bizyair_task_" + file_suffix: ".json" + use_cache: true + model_hub: find_model: From e01467a04ac1920aaed6b02ba8078985dad1dfb4 Mon Sep 17 00:00:00 2001 From: FengWen Date: Mon, 16 Dec 2024 09:41:27 +0800 Subject: [PATCH 11/12] refine --- src/bizyair/commands/servers/prompt_server.py | 25 +++++++++++++------ src/bizyair/common/caching.py | 4 ++- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/bizyair/commands/servers/prompt_server.py b/src/bizyair/commands/servers/prompt_server.py index bb142db9..1bd2cb92 100644 --- a/src/bizyair/commands/servers/prompt_server.py +++ b/src/bizyair/commands/servers/prompt_server.py @@ -164,14 +164,13 @@ def _get_result(self, result: Dict[str, Any], *, cache_key: str = None): last_data = bz_task.get_last_data() response_data = last_data.get("data") out = response_data["payload"] - assert out is not None - print(f"Setting cache for {cache_key} with out") + assert out is not None, "Output payload should not be None" self.cache_manager.set(cache_key, out, overwrite=True) return out except Exception as e: - print(f"Exception: {e}") + print(f"Exception occurred: {e}") self.cache_manager.delete(cache_key) - raise e + raise def execute( self, @@ -182,13 +181,16 @@ def execute( ): prompt = encode_data(prompt) + if BIZYAIR_DEBUG: debug_info = { "prompt": truncate_long_strings(prompt, 50), "last_node_ids": last_node_ids, } pprint.pprint(debug_info, indent=4) + url = self.router(prompt=prompt, last_node_ids=last_node_ids) + if BIZYAIR_DEBUG: print(f"Generated URL: {url}") @@ -197,16 +199,23 @@ def execute( json.dumps({"url": url, "prompt": prompt}).encode("utf-8") ).hexdigest() end_time = time.time() - print(f"Time taken to generate sh256-{sh256} : {end_time - start_time} seconds") - if self.cache_manager.get(sh256): - print(f"Cache hit for sh256-{sh256}") - out = self.cache_manager.get(sh256) + if BIZYAIR_DEBUG: + print( + f"Time taken to generate sh256-{sh256}: {end_time - start_time} seconds" + ) + + cached_output = self.cache_manager.get(sh256) + if cached_output: + if BIZYAIR_DEBUG: + print(f"Cache hit for sh256-{sh256}") + out = cached_output else: result = self.processor(url, prompt=prompt, last_node_ids=last_node_ids) out = self._get_result(result, cache_key=sh256) if BIZYAIR_DEBUG: pprint.pprint({"out": truncate_long_strings(out, 50)}, indent=4) + try: real_out = decode_data(out) return real_out[0] diff --git a/src/bizyair/common/caching.py b/src/bizyair/common/caching.py index e7a9081f..27cdb978 100644 --- a/src/bizyair/common/caching.py +++ b/src/bizyair/common/caching.py @@ -61,7 +61,9 @@ def ensure_directory_exists(self): def load_cache(self): cache_v_files = glob.glob( - os.path.join(self.cache_dir, f"{self.config.file_prefix}*.json") + os.path.join( + self.cache_dir, f"{self.config.file_prefix}*{self.config.file_suffix}" + ) ) output = OrderedDict() cache_datas = [] From 89fde48e5409b24cba6382e6987deedadf84822d Mon Sep 17 00:00:00 2001 From: FengWen Date: Mon, 16 Dec 2024 19:45:52 +0800 Subject: [PATCH 12/12] refine --- .../commands/processors/prompt_processor.py | 3 ++- src/bizyair/commands/servers/prompt_server.py | 6 +++--- src/bizyair/configs/conf.py | 10 ++++++++- src/bizyair/configs/models.yaml | 21 +++++++++++++++++++ src/bizyair/path_utils/path_manager.py | 3 ++- 5 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/bizyair/commands/processors/prompt_processor.py b/src/bizyair/commands/processors/prompt_processor.py index 32b4f44c..67a7dda2 100644 --- a/src/bizyair/commands/processors/prompt_processor.py +++ b/src/bizyair/commands/processors/prompt_processor.py @@ -63,7 +63,8 @@ def process(self, prompt: Dict[str, Dict[str, Any]], last_node_ids: List[str]): base_model, out_route, out_score = None, None, None for rule in results[::-1]: - if rule.mode_type in {"unet", "vae", "checkpoint"}: + # TODO add to config models.yaml + if rule.mode_type in {"unet", "vae", "checkpoint", "upscale_models"}: base_model = rule.base_model out_route = rule.route out_score = rule.score diff --git a/src/bizyair/commands/servers/prompt_server.py b/src/bizyair/commands/servers/prompt_server.py index 1bd2cb92..b6d18a6f 100644 --- a/src/bizyair/commands/servers/prompt_server.py +++ b/src/bizyair/commands/servers/prompt_server.py @@ -10,7 +10,7 @@ from bizyair.common.caching import BizyAirTaskCache, CacheConfig from bizyair.common.client import send_request -from bizyair.common.env_var import BIZYAIR_DEBUG +from bizyair.common.env_var import BIZYAIR_DEBUG, BIZYAIR_SERVER_ADDRESS from bizyair.common.utils import truncate_long_strings from bizyair.configs.conf import config_manager from bizyair.image_utils import decode_data, encode_data @@ -24,7 +24,8 @@ def get_task_result(task_id: str, offset: int = 0) -> dict: """ import requests - url = f"https://uat-bizyair-api.siliconflow.cn/x/v1/bizy_task/{task_id}" + task_api = config_manager.get_task_api() + url = f"{BIZYAIR_SERVER_ADDRESS}/{task_api.task_result_endpoint}/{task_id}" response_json = send_request( method="GET", url=url, data=json.dumps({"offset": offset}).encode("utf-8") ) @@ -101,7 +102,6 @@ def do_task_until_completed( pbar = None while not self.is_finished(): try: - print(f"do_task_until_completed: {offset}") data = self.send_request(offset) data_lst = self._extract_data_list(data) self.data_pool.extend(data_lst) diff --git a/src/bizyair/configs/conf.py b/src/bizyair/configs/conf.py index 81dc035e..2dea9c22 100644 --- a/src/bizyair/configs/conf.py +++ b/src/bizyair/configs/conf.py @@ -17,6 +17,11 @@ class ModelRule: inputs: dict +@dataclass +class TaskApi: + task_result_endpoint: str + + class ModelRuleManager: def __init__(self, model_rules: list[dict]): self.model_rules = model_rules @@ -57,7 +62,7 @@ def find_rules(self, class_type: str) -> List[ModelRule]: score=self.model_rules[idx_1]["score"], route=self.model_rules[idx_1]["route"], class_type=class_type, - inputs=self.model_rules[idx_1]["nodes"][idx_2]["inputs"], + inputs=self.model_rules[idx_1]["nodes"][idx_2].get("inputs", {}), ) for idx_1, idx_2 in rule_indexes ] @@ -96,6 +101,9 @@ def get_model_version_id_prefix(self): def get_cache_config(self): return self.model_rule_config.get("cache_config", {}) + def get_task_api(self): + return TaskApi(**self.model_rule_config["task_api"]) + model_path_config = os.path.join(os.path.dirname(__file__), "models.json") model_rule_config = os.path.join(os.path.dirname(__file__), "models.yaml") diff --git a/src/bizyair/configs/models.yaml b/src/bizyair/configs/models.yaml index e1ff6c02..c4fcf87e 100644 --- a/src/bizyair/configs/models.yaml +++ b/src/bizyair/configs/models.yaml @@ -22,6 +22,11 @@ model_types: # checkpoints: bizyair/checkpoint # vae: bizyair/vae +task_api: + # Base URL for task-related API calls + task_result_endpoint: bizy_task + + model_rules: - mode_type: unet base_model: FLUX @@ -246,3 +251,19 @@ model_rules: inputs: vae_name: - ^flux.1-canny-vae.safetensors$ + + - mode_type: upscale_models + base_model: UPSCALE_MODEL + describe: Upscale Model + score: 1 + route: /bizy_task/bizyair-flux1-dev-fp8-async + nodes: + - class_type: UpscaleModelLoader + + - mode_type: upscale_model + base_model: FLUX + describe: Flux Upscale Model + score: 6 + route: /bizy_task/bizyair-flux1-dev-fp8-async + nodes: + - class_type: UltimateSDUpscale diff --git a/src/bizyair/path_utils/path_manager.py b/src/bizyair/path_utils/path_manager.py index c7f9df94..ebff143d 100644 --- a/src/bizyair/path_utils/path_manager.py +++ b/src/bizyair/path_utils/path_manager.py @@ -77,7 +77,8 @@ def guess_url_from_node( out = [ rule for rule in rules - if all( + if len(rule.inputs) == 0 + or all( any(re.search(p, node["inputs"][key]) is not None for p in patterns) for key, patterns in rule.inputs.items() )