Skip to content

Commit

Permalink
fix_infer_share_id (#266)
Browse files Browse the repository at this point in the history
* fix_infer_share_id

* refine
  • Loading branch information
ccssu authored Dec 6, 2024
1 parent 8e5bd01 commit 9459054
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 18 deletions.
9 changes: 7 additions & 2 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,10 +956,11 @@ def VALIDATE_INPUTS(cls, share_id: str, lora_name: str):
def shared_load_lora(
self, model, clip, lora_name, strength_model, strength_clip, **kwargs
):
resolved_path = folder_paths.filename_path_mapping["loras"][lora_name]
return super().load_lora(
model=model,
clip=clip,
lora_name=lora_name,
lora_name=resolved_path,
strength_model=strength_model,
strength_clip=strength_clip,
)
Expand Down Expand Up @@ -1155,10 +1156,14 @@ def VALIDATE_INPUTS(cls, share_id: str, control_net_name: str):
raise ValueError(
f"ControlNet {control_net_name} not found in share {share_id} with {outs}"
)

return True

def load_controlnet(self, control_net_name, share_id, **kwargs):
return super().load_controlnet(control_net_name=control_net_name, **kwargs)
resolved_path = folder_paths.filename_path_mapping["controlnet"][
control_net_name
]
return super().load_controlnet(control_net_name=resolved_path, **kwargs)


class CLIPVisionEncode(BizyAirBaseNode):
Expand Down
29 changes: 29 additions & 0 deletions src/bizy_server/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,32 @@ async def get_download_url(
except Exception as e:
print(f"\033[31m[BizyAir]\033[0m Fail to get download url: {str(e)}")
return None, errnos.GET_DOWNLOAD_URL

async def get_share_model_files(self, shareId, payload) -> (dict, ErrorNo):
server_url = f"{BIZYAIR_SERVER_ADDRESS}/{shareId}/models/files"
try:

def callback(ret: dict):
if ret["code"] != errnos.OK.code:
return [], ErrorNo(500, ret["code"], None, f"{ret}")
if not ret or "data" not in ret or ret["data"] is None:
return [], None

outputs = [
x["label_path"] for x in ret["data"]["files"] if x["label_path"]
]
outputs = bizyair.path_utils.filter_files_extensions(
outputs,
extensions=bizyair.path_utils.path_manager.supported_pt_extensions,
)
return outputs, None

ret = await bizyair.common.client.async_send_request(
method="GET", url=server_url, params=payload, callback=callback
)
return ret[0], ret[1]
except Exception as e:
print(
f"\033[31m[BizyAir]\033[0m Fail to list share model files: response {ret} error {str(e)}"
)
return [], errnos.LIST_SHARE_MODEL_FILE_ERR
3 changes: 3 additions & 0 deletions src/bizy_server/errno.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,6 @@ class errnos:
TOGGLE_USER_LIKE = ErrorNo(500, 500125, None, "Failed to toggle user like")
GET_DOWNLOAD_URL = ErrorNo(500, 500126, None, "Failed to get download url")
DOWNLOAD_JSON = ErrorNo(500, 500127, None, "Failed to download json")
LIST_SHARE_MODEL_FILE_ERR = ErrorNo(
500, 500128, None, "Failed to list share model file"
)
21 changes: 5 additions & 16 deletions src/bizy_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,29 +475,18 @@ async def get_workflow_json(request):
@self.prompt_server.routes.get(f"/{MODEL_HOST_API}" + "/{shareId}/models/files")
async def list_share_model_files(request):
shareId = request.match_info["shareId"]

if not is_string_valid(shareId):
return ErrResponse("INVALID_SHARE_ID")

err = check_type(request.rel_url.query)
if err is not None:
return err

payload = {
"type": request.rel_url.query["type"],
}

if "name" in request.rel_url.query:
payload["name"] = request.rel_url.query["name"]

if "ext_name" in request.rel_url.query:
payload["ext_name"] = request.rel_url.query["ext_name"]
payload = {}
query_params = ["type", "name", "ext_name"]
for param in query_params:
if param in request.rel_url.query and request.rel_url.query[param]:
payload[param] = request.rel_url.query[param]
model_files, err = await self.api_client.get_share_model_files(
shareId=shareId, payload=payload
)
if err is not None:
return ErrResponse(err)

return OKResponse(model_files)

async def send_json(self, event, data, sid=None):
Expand Down

0 comments on commit 9459054

Please sign in to comment.