Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
iYingg committed Oct 31, 2024
1 parent aecf496 commit 357f082
Showing 1 changed file with 0 additions and 90 deletions.
90 changes: 0 additions & 90 deletions supernode.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,99 +137,9 @@ def generate_image(self, prompt, seed, width, height, cfg, batch_size):
return (tensors,)


class BizyAirSegmentAnythingText:
API_URL = f"{BIZYAIR_SERVER_ADDRESS}/supernode/sam"

@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"prompt": ("STRING", {}),
"box_threshold": (
"FLOAT",
{"default": 0.3, "min": 0, "max": 1.0, "step": 0.01},
),
"text_threshold": (
"FLOAT",
{"default": 0.3, "min": 0, "max": 1.0, "step": 0.01},
),
}
}

RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "text_sam"

CATEGORY = "☁️BizyAir/segment-anything"

def text_sam(self, image, prompt, box_threshold, text_threshold):
API_KEY = get_api_key()
SIZE_LIMIT = 1536
device = image.device
_, w, h, c = image.shape
assert (
w <= SIZE_LIMIT and h <= SIZE_LIMIT
), f"width and height must be less than {SIZE_LIMIT}x{SIZE_LIMIT}, but got {w} and {h}"

payload = {
"image": None,
"mode": 1, # 文本分割模式
"params": {
"prompt": prompt,
"box_threshold": box_threshold,
"text_threshold": text_threshold,
},
}
auth = f"Bearer {API_KEY}"
headers = {
"accept": "application/json",
"content-type": "application/json",
"authorization": auth,
}
image = image.squeeze(0).numpy()
image_pil = Image.fromarray((image * 255).astype(np.uint8))
input_image = encode_image_to_base64(image_pil, format="webp")
payload["image"] = input_image

ret: str = send_post_request(self.API_URL, payload=payload, headers=headers)
ret = json.loads(ret)

try:
if "result" in ret:
ret = json.loads(ret["result"])
except Exception as e:
raise Exception(f"Unexpected response: {ret} {e=}")

if ret["status"] == "error":
raise Exception(ret["message"])

msg = ret["data"]
if msg["type"] not in ("bizyair",):
raise Exception(f"Unexpected response type: {msg}")

if "error" in msg:
raise Exception(f"Error happens: {msg}")

img = msg["image"]
mask_image = msg["mask_image"]

img = (
(torch.from_numpy(decode_base64_to_np(img)).float() / 255.0)
.unsqueeze(0)
.to(device)
)
img_mask = (
torch.from_numpy(decode_base64_to_np(mask_image)).float() / 255.0
).to(device)
img_mask = img_mask.mean(dim=-1)
img_mask = img_mask.unsqueeze(0)
return (img, img_mask)


NODE_CLASS_MAPPINGS = {
"BizyAirRemoveBackground": RemoveBackground,
"BizyAirGenerateLightningImage": GenerateLightningImage,
"BizyAirSegmentAnythingText": BizyAirSegmentAnythingText,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"BizyAirRemoveBackground": "☁️BizyAir Remove Image Background",
Expand Down

0 comments on commit 357f082

Please sign in to comment.