Skip to content

Commit

Permalink
Merge branch 'dev_bbs' of https://github.com/siliconflow/BizyAir into…
Browse files Browse the repository at this point in the history
… dev_bbs
  • Loading branch information
linjm8780860 committed Nov 26, 2024
2 parents f1145ef + 32623e7 commit 63cef4e
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 53 deletions.
8 changes: 4 additions & 4 deletions js/biz_lib_frontend.js

Large diffs are not rendered by default.

130 changes: 128 additions & 2 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import comfy

from bizyair import BizyAirBaseNode, BizyAirNodeIO, create_node_data, data_types
from bizyair.configs.conf import config_manager
from bizyair.path_utils import path_manager as folder_paths

LOGO = "☁️"
Expand Down Expand Up @@ -242,6 +243,85 @@ def decode(self, vae, samples):


class BizyAir_LoraLoader(BizyAirBaseNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": (data_types.MODEL,),
"clip": (data_types.CLIP,),
"lora_name": (
[
"to choose",
],
),
"strength_model": (
"FLOAT",
{"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01},
),
"strength_clip": (
"FLOAT",
{"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01},
),
"model_version_id": (
"STRING",
{
"default": "",
},
),
}
}

RETURN_TYPES = (data_types.MODEL, data_types.CLIP)
RETURN_NAMES = ("MODEL", "CLIP")

FUNCTION = "load_lora"
CATEGORY = f"{PREFIX}/loaders"

def load_lora(
self,
model,
clip,
lora_name,
strength_model,
strength_clip,
model_version_id: str = None,
):
assigned_id = self.assigned_id
new_model: BizyAirNodeIO = model.copy(assigned_id)
new_clip: BizyAirNodeIO = clip.copy(assigned_id)
instances: List[BizyAirNodeIO] = [new_model, new_clip]

if model_version_id is not None and model_version_id != "":
# use model version id as lora name
lora_name = (
f"{config_manager.get_model_version_id_prefix()}{model_version_id}"
)

for slot_index, ins in zip(range(2), instances):
ins.add_node_data(
class_type="LoraLoader",
inputs={
"model": model,
"clip": clip,
"lora_name": lora_name,
"strength_model": strength_model,
"strength_clip": strength_clip,
},
outputs={"slot_index": slot_index},
)
return (
new_model,
new_clip,
)

@classmethod
def VALIDATE_INPUTS(cls, lora_name):
if lora_name == "" or lora_name is None:
return False
return True


class BizyAir_LoraLoader_Legacy(BizyAirBaseNode):
@classmethod
def INPUT_TYPES(s):
return {
Expand Down Expand Up @@ -346,6 +426,52 @@ def encode(self, vae, pixels, mask, grow_mask_by=6):


class BizyAir_ControlNetLoader(BizyAirBaseNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"control_net_name": (
[
"to choose",
],
),
"model_version_id": ("STRING", {"default": "", "multiline": False}),
}
}

RETURN_TYPES = (data_types.CONTROL_NET,)
RETURN_NAMES = ("CONTROL_NET",)
FUNCTION = "load_controlnet"

CATEGORY = f"{PREFIX}/loaders"

@classmethod
def VALIDATE_INPUTS(cls, control_net_name, model_version_id):
if control_net_name == "to choose":
return False
if model_version_id is not None and model_version_id != "":
return True
return True

def load_controlnet(self, control_net_name, model_version_id):
if model_version_id is not None and model_version_id != "":
control_net_name = (
f"{config_manager.get_model_version_id_prefix()}{model_version_id}"
)

node_data = create_node_data(
class_type="ControlNetLoader",
inputs={
"control_net_name": control_net_name,
},
outputs={"slot_index": 0},
)
assigned_id = self.assigned_id
node = BizyAirNodeIO(assigned_id, {assigned_id: node_data})
return (node,)


class BizyAir_ControlNetLoader_Legacy(BizyAirBaseNode):
@classmethod
def INPUT_TYPES(s):
return {
Expand Down Expand Up @@ -789,7 +915,7 @@ def INPUT_TYPES(s):
return ret


class SharedLoraLoader(BizyAir_LoraLoader):
class SharedLoraLoader(BizyAir_LoraLoader_Legacy):
@classmethod
def INPUT_TYPES(s):
return {
Expand Down Expand Up @@ -1010,7 +1136,7 @@ def INPUT_TYPES(s):
CATEGORY = "advanced/conditioning"


class SharedControlNetLoader(BizyAir_ControlNetLoader):
class SharedControlNetLoader(BizyAir_ControlNetLoader_Legacy):
@classmethod
def INPUT_TYPES(s):
ret = super().INPUT_TYPES()
Expand Down
88 changes: 44 additions & 44 deletions src/bizy_server/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,56 +41,56 @@ async def do_get(self, url, params=None, headers=None):
if params:
query_string = urllib.parse.urlencode(params, doseq=True)
url = f"{url}?{query_string}"
session = await self.get_session()
async with session.get(url, headers=headers) as response:
resp_json = await response.json()
if response.status != 200:
return None, ErrorNo(
response.status,
resp_json.get("code", response.status),
None,
resp_json.get("message", await response.text()),
)
return resp_json, None
async with await self.get_session() as session:
async with session.get(url, headers=headers) as response:
resp_json = await response.json()
if response.status != 200:
return None, ErrorNo(
response.status,
resp_json.get("code", response.status),
None,
resp_json.get("message", await response.text()),
)
return resp_json, None

async def do_post(self, url, data=None, headers=None):
session = await self.get_session()
async with session.post(url, json=data, headers=headers) as response:
resp_json = await response.json()
if response.status != 200:
return None, ErrorNo(
response.status,
resp_json.get("code", response.status),
None,
resp_json.get("message", await response.text()),
)
return resp_json, None
async with await self.get_session() as session:
async with session.post(url, json=data, headers=headers) as response:
resp_json = await response.json()
if response.status != 200:
return None, ErrorNo(
response.status,
resp_json.get("code", response.status),
None,
resp_json.get("message", await response.text()),
)
return resp_json, None

async def do_put(self, url, data=None, headers=None):
session = await self.get_session()
async with session.put(url, json=data, headers=headers) as response:
resp_json = await response.json()
if response.status != 200:
return None, ErrorNo(
response.status,
resp_json.get("code", response.status),
None,
resp_json.get("message", await response.text()),
)
return resp_json, None
async with await self.get_session() as session:
async with session.put(url, json=data, headers=headers) as response:
resp_json = await response.json()
if response.status != 200:
return None, ErrorNo(
response.status,
resp_json.get("code", response.status),
None,
resp_json.get("message", await response.text()),
)
return resp_json, None

async def do_delete(self, url, data=None, headers=None):
session = await self.get_session()
async with session.delete(url, json=data, headers=headers) as response:
resp_json = await response.json()
if response.status != 200:
return None, ErrorNo(
response.status,
resp_json.get("code", response.status),
None,
resp_json.get("message", await response.text()),
)
return resp_json, None
async with await self.get_session() as session:
async with session.delete(url, json=data, headers=headers) as response:
resp_json = await response.json()
if response.status != 200:
return None, ErrorNo(
response.status,
resp_json.get("code", response.status),
None,
resp_json.get("message", await response.text()),
)
return resp_json, None

async def user_info(self) -> tuple[dict | None, ErrorNo | None]:
headers, err = self.auth_header()
Expand Down
6 changes: 3 additions & 3 deletions src/bizy_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ async def commit_bizy_model(request):
# 开启线程检查同步状态
threading.Thread(
target=self.check_sync_status,
args=(self, resp["id"], resp["version_ids"], sid),
args=(resp["id"], resp["version_ids"], sid),
daemon=True,
).start()

Expand Down Expand Up @@ -414,7 +414,7 @@ async def update_model(request):
# 开启线程检查同步状态
threading.Thread(
target=self.check_sync_status,
args=(self, resp["id"], resp["version_ids"]),
args=(resp["id"], resp["version_ids"]),
daemon=True,
).start()

Expand Down Expand Up @@ -622,7 +622,7 @@ def check_sync_status(self, bizy_model_id: str, version_ids: list, sid=None):
removed.append(version_id)
continue

if model_version["available"]:
if "available" in model_version and model_version["available"]:
self.send_sync(
event="synced",
data={
Expand Down

0 comments on commit 63cef4e

Please sign in to comment.