Skip to content

Commit

Permalink
fix_super_node (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
ccssu authored Aug 6, 2024
1 parent bfa43ef commit faea95b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
18 changes: 9 additions & 9 deletions src/bizyair/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,18 +398,18 @@ def _(input: str):


@singledispatch
def encode_data(output):
def encode_data(output, disable_image_marker=False):
raise NotImplementedError(f"Unsupported type: {type(output)}")


@encode_data.register(dict)
def _(output):
return {k: encode_data(v) for k, v in output.items()}
def _(output, **kwargs):
return {k: encode_data(v, **kwargs) for k, v in output.items()}


@encode_data.register(list)
def _(output):
return [encode_data(x) for x in output]
def _(output, **kwargs):
return [encode_data(x, **kwargs) for x in output]


def is_image_tensor(tensor) -> bool:
Expand Down Expand Up @@ -440,14 +440,14 @@ def is_image_tensor(tensor) -> bool:


@encode_data.register(torch.Tensor)
def _(output):
if is_image_tensor(output):
def _(output, **kwargs):
if is_image_tensor(output) and not kwargs.get("disable_image_marker", False):
return IMAGE_MARKER + encode_comfy_image(output, image_format="WEBP")
return TENSOR_MARKER + tensor_to_base64(output)


@encode_data.register(BizyAirNodeIO)
def _(output: BizyAirNodeIO):
def _(output: BizyAirNodeIO, **kwargs):
origin_id = output.node_id
origin_slot = output.nodes[origin_id]["outputs"]["slot_index"]
return [origin_id, origin_slot]
Expand All @@ -458,5 +458,5 @@ def _(output: BizyAirNodeIO):
@encode_data.register(str)
@encode_data.register(bool)
@encode_data.register(type(None))
def _(output):
def _(output, **kwargs):
return output
2 changes: 1 addition & 1 deletion supernode.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def super_resolution(self, image):
"content-type": "application/json",
"authorization": auth,
}
input_image = encode_data(image)
input_image = encode_data(image, disable_image_marker=True)
payload["image"] = input_image
payload["is_compress"] = True

Expand Down

0 comments on commit faea95b

Please sign in to comment.