Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flux upscale #253

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
15 changes: 15 additions & 0 deletions bizyair_extras/nodes_upscale_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,18 @@ def load_model(self, **kwargs):
model = BizyAirNodeIO(self.assigned_id)
model.add_node_data(class_type="UpscaleModelLoader", inputs=kwargs)
return (model,)


class ImageUpscaleWithModel(BizyAirBaseNode):
    """BizyAir proxy node mirroring ComfyUI's ImageUpscaleWithModel.

    Declares an IMAGE-in / IMAGE-out node that applies a loaded upscale
    model to an image; the actual computation is delegated remotely by the
    BizyAir node machinery (see BizyAirBaseNode).
    """

    @classmethod
    def INPUT_TYPES(s):
        # ComfyUI node-registry contract: map of required input sockets.
        return {
            "required": {
                "upscale_model": (UPSCALE_MODEL,),
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    # FUNCTION = "upscale"  # NOTE(review): left commented out — presumably
    # BizyAirBaseNode supplies the execution entry point; confirm.
    CATEGORY = "image/upscaling"
8 changes: 7 additions & 1 deletion src/bizyair/commands/processors/prompt_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, List

from bizyair.common import client
from bizyair.common.caching import BizyAirTaskCache, CacheConfig
from bizyair.common.env_var import (
BIZYAIR_DEBUG,
BIZYAIR_DEV_REQUEST_URL,
Expand Down Expand Up @@ -62,7 +63,8 @@ def process(self, prompt: Dict[str, Dict[str, Any]], last_node_ids: List[str]):

base_model, out_route, out_score = None, None, None
for rule in results[::-1]:
if rule.mode_type in {"unet", "vae", "checkpoint"}:
# TODO add to config models.yaml
if rule.mode_type in {"unet", "vae", "checkpoint", "upscale_models"}:
base_model = rule.base_model
out_route = rule.route
out_score = rule.score
Expand All @@ -83,6 +85,10 @@ def validate_input(


class PromptProcessor(Processor):
bizyair_task_cache = BizyAirTaskCache(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个 bizyair_task_cache 是在哪里使用呢,在这个PR里没搜到。

config=CacheConfig.from_config(config_manager.get_cache_config())
)

def _exec_info(self, prompt: Dict[str, Dict[str, Any]]):
exec_info = {
"model_version_ids": [],
Expand Down
194 changes: 182 additions & 12 deletions src/bizyair/commands/servers/prompt_server.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,177 @@
import hashlib
import json
import pprint
import time
import traceback
from dataclasses import dataclass, field
from typing import Any, Dict, List

from bizyair.common.env_var import BIZYAIR_DEBUG
import comfy

from bizyair.common.caching import BizyAirTaskCache, CacheConfig
from bizyair.common.client import send_request
from bizyair.common.env_var import BIZYAIR_DEBUG, BIZYAIR_SERVER_ADDRESS
from bizyair.common.utils import truncate_long_strings
from bizyair.configs.conf import config_manager
from bizyair.image_utils import decode_data, encode_data

from ..base import Command, Processor # type: ignore


def get_task_result(task_id: str, offset: int = 0) -> dict:
    """Fetch the (possibly partial) result of a server-side task.

    Args:
        task_id: Identifier of the task to query.
        offset: Number of events already consumed by the caller; passed to
            the server so only newer events are returned.

    Returns:
        The server response dict. Any event whose ``data`` field is an
        ``https://`` URL (indirect payload) is resolved in place to the
        JSON document it points to.
    """
    import requests

    task_api = config_manager.get_task_api()
    url = f"{BIZYAIR_SERVER_ADDRESS}/{task_api.task_result_endpoint}/{task_id}"
    response_json = send_request(
        method="GET", url=url, data=json.dumps({"offset": offset}).encode("utf-8")
    )
    events = response_json.get("data", {}).get("events", [])
    for event in events:
        data = event.get("data")
        # Large payloads are delivered indirectly via a URL; inline them.
        if isinstance(data, str) and data.startswith("https://"):
            # Explicit timeout so a dead link cannot hang the poll loop.
            event["data"] = requests.get(data, timeout=60).json()
    # Guarded write: the original unconditionally did
    # response_json["data"]["events"] = ..., which raised KeyError when the
    # response carried no "data" key even though the read path tolerated it.
    if "data" in response_json:
        response_json["data"]["events"] = events
    return response_json


@dataclass
class BizyAirTask:
    """Client-side handle for a server-side asynchronous task.

    Events fetched from the server accumulate in ``data_pool``; the task is
    considered finished once the most recent event carries the terminal
    status ("COMPLETED").
    """

    # Server-reported lifecycle states in order; the last entry is terminal.
    # Unannotated on purpose: dataclass treats it as a class attribute.
    TASK_DATA_STATUS = ["PENDING", "PROCESSING", "COMPLETED"]

    task_id: str
    data_pool: list[dict] = field(default_factory=list)
    # Fixed annotation: was `str = None`, which mis-declares the default.
    data_status: "str | None" = None

    @staticmethod
    def check_inputs(inputs: dict) -> bool:
        """Return True when *inputs* is a successful async-task response."""
        return (
            inputs.get("code") == 20000
            and inputs.get("status", False)
            and "task_id" in inputs.get("data", {})
        )

    @classmethod
    def from_data(cls, inputs: dict, check_inputs: bool = True) -> "BizyAirTask":
        """Build a task handle from a server response dict.

        Raises:
            ValueError: if *check_inputs* is True and the response does not
                describe a valid async task.
        """
        if check_inputs and not cls.check_inputs(inputs):
            raise ValueError(f"Invalid inputs: {inputs}")
        data = inputs.get("data", {})
        task_id = data.get("task_id", "")
        return cls(task_id=task_id, data_pool=[], data_status="started")

    def is_finished(self) -> bool:
        """True once the last fetched event has the terminal data_status."""
        if not self.data_pool:
            return False
        return self.data_pool[-1].get("data_status") == self.TASK_DATA_STATUS[-1]

    def send_request(self, offset: int = 0) -> dict:
        """Return the cached event at *offset*, or query the server for more."""
        if offset >= len(self.data_pool):
            return get_task_result(self.task_id, offset)
        return self.data_pool[offset]

    def get_data(self, offset: int = 0) -> dict:
        """Return the event at *offset*, or {} when out of range."""
        if offset >= len(self.data_pool):
            return {}
        return self.data_pool[offset]

    @staticmethod
    def _fetch_remote_data(url: str) -> dict:
        """Download and decode a JSON document (indirect event payloads)."""
        import requests

        # Fixed: no timeout meant a dead link could block forever.
        return requests.get(url, timeout=60).json()

    def get_last_data(self) -> dict:
        """Return the most recently fetched event, or {} when none exist."""
        return self.get_data(len(self.data_pool) - 1)

    def do_task_until_completed(
        self, *, timeout: int = 480, poll_interval: float = 1
    ) -> list[dict]:
        """Poll the server until the task completes; return all events.

        "progress" events are mirrored to a ComfyUI progress bar.

        Raises:
            TimeoutError: when the task does not finish within *timeout*
                seconds.
        """
        offset = 0
        start_time = time.time()
        pbar = None
        while not self.is_finished():
            try:
                response = self.send_request(offset)
                events = self._extract_data_list(response)
                self.data_pool.extend(events)
                offset += len(events)
                # Renamed loop variable (was `data`) to stop shadowing the
                # response dict fetched above.
                for event in events:
                    message = event.get("data", {}).get("message", {})
                    if (
                        isinstance(message, dict)
                        and message.get("event", None) == "progress"
                    ):
                        value = message["data"]["value"]
                        total = message["data"]["max"]
                        if pbar is None:
                            pbar = comfy.utils.ProgressBar(total)
                        pbar.update_absolute(value + 1, total, None)
            except Exception as e:
                # Transient fetch/parse errors are tolerated; the timeout
                # below bounds how long we keep retrying.
                print(f"Exception: {e}")

            if self.is_finished():
                break  # fixed: skip the pointless final sleep after completion

            if time.time() - start_time > timeout:
                raise TimeoutError(f"Timeout waiting for task {self.task_id} to finish")

            time.sleep(poll_interval)

        return self.data_pool

    def _extract_data_list(self, data):
        """Pull the events list out of a server response; raise when empty."""
        data_lst = data.get("data", {}).get("events", [])
        if not data_lst:
            raise ValueError(f"No data found in task {self.task_id}")
        return data_lst


class PromptServer(Command):
cache_manager: BizyAirTaskCache = BizyAirTaskCache(
config=CacheConfig.from_config(config_manager.get_cache_config())
)

def __init__(self, router: Processor, processor: Processor):
self.router = router
self.processor = processor

def get_task_id(self, result: Dict[str, Any]) -> str:
return result.get("data", {}).get("task_id", "")

def is_async_task(self, result: Dict[str, Any]) -> str:
"""Determine if the result indicates an asynchronous task."""
return (
result.get("code") == 20000
and result.get("status", False)
and "task_id" in result.get("data", {})
)

    def _get_result(self, result: Dict[str, Any], *, cache_key: str = None):
        """Resolve *result* into the final output payload, caching it.

        Synchronous responses carry the payload directly; asynchronous ones
        (detected via BizyAirTask.check_inputs) are polled to completion
        first. On any failure the cache entry is evicted so a stale or
        partial value is never replayed.
        """
        try:
            response_data = result["data"]
            if BizyAirTask.check_inputs(result):
                # Async path: cache the raw task handle before polling —
                # NOTE(review): presumably so an identical repeated prompt
                # can observe the in-flight task; confirm intent.
                self.cache_manager.set(cache_key, result)
                bz_task = BizyAirTask.from_data(result, check_inputs=False)
                bz_task.do_task_until_completed()
                last_data = bz_task.get_last_data()
                response_data = last_data.get("data")
            out = response_data["payload"]
            assert out is not None, "Output payload should not be None"
            # Replace the task handle (async path) with the final payload.
            self.cache_manager.set(cache_key, out, overwrite=True)
            return out
        except Exception as e:
            print(f"Exception occurred: {e}")
            # Evict so the next identical request retries instead of
            # replaying the failed attempt.
            self.cache_manager.delete(cache_key)
            raise

def execute(
self,
prompt: Dict[str, Dict[str, Any]],
Expand All @@ -23,34 +181,46 @@ def execute(
):

prompt = encode_data(prompt)

if BIZYAIR_DEBUG:
debug_info = {
"prompt": truncate_long_strings(prompt, 50),
"last_node_ids": last_node_ids,
}
pprint.pprint(debug_info, indent=4)

url = self.router(prompt=prompt, last_node_ids=last_node_ids)

if BIZYAIR_DEBUG:
print(f"Generated URL: {url}")

result = self.processor(url, prompt=prompt, last_node_ids=last_node_ids)
start_time = time.time()
sh256 = hashlib.sha256(
json.dumps({"url": url, "prompt": prompt}).encode("utf-8")
).hexdigest()
end_time = time.time()
if BIZYAIR_DEBUG:
pprint.pprint({"result": truncate_long_strings(result, 50)}, indent=4)
print(
f"Time taken to generate sh256-{sh256}: {end_time - start_time} seconds"
)

if result is None:
raise RuntimeError("result is None")
cached_output = self.cache_manager.get(sh256)
if cached_output:
if BIZYAIR_DEBUG:
print(f"Cache hit for sh256-{sh256}")
out = cached_output
else:
result = self.processor(url, prompt=prompt, last_node_ids=last_node_ids)
out = self._get_result(result, cache_key=sh256)

if BIZYAIR_DEBUG:
pprint.pprint({"out": truncate_long_strings(out, 50)}, indent=4)

try:
out = result["data"]["payload"]
assert out is not None
except Exception as e:
raise RuntimeError(
f'Unexpected error accessing result["data"]["payload"]. Result: {result}'
) from e
try:
real_out = decode_data(out)
return real_out[0]
except Exception as e:
print("Exception occurred while decoding data")
self.cache_manager.delete(sh256)
traceback.print_exc()
raise RuntimeError(f"Exception: {e=}") from e
Loading
Loading