From 357f082147b8ba01668b1827e4728b7d9692ede0 Mon Sep 17 00:00:00 2001 From: zhujintao <2030442731@qq.com> Date: Thu, 31 Oct 2024 11:04:39 +0800 Subject: [PATCH] merge --- supernode.py | 90 ---------------------------------------------------- 1 file changed, 90 deletions(-) diff --git a/supernode.py b/supernode.py index 7a25f04f..75a1b9d9 100644 --- a/supernode.py +++ b/supernode.py @@ -137,99 +137,9 @@ def generate_image(self, prompt, seed, width, height, cfg, batch_size): return (tensors,) -class BizyAirSegmentAnythingText: - API_URL = f"{BIZYAIR_SERVER_ADDRESS}/supernode/sam" - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "prompt": ("STRING", {}), - "box_threshold": ( - "FLOAT", - {"default": 0.3, "min": 0, "max": 1.0, "step": 0.01}, - ), - "text_threshold": ( - "FLOAT", - {"default": 0.3, "min": 0, "max": 1.0, "step": 0.01}, - ), - } - } - - RETURN_TYPES = ("IMAGE", "MASK") - FUNCTION = "text_sam" - - CATEGORY = "☁️BizyAir/segment-anything" - - def text_sam(self, image, prompt, box_threshold, text_threshold): - API_KEY = get_api_key() - SIZE_LIMIT = 1536 - device = image.device - _, w, h, c = image.shape - assert ( - w <= SIZE_LIMIT and h <= SIZE_LIMIT - ), f"width and height must be less than {SIZE_LIMIT}x{SIZE_LIMIT}, but got {w} and {h}" - - payload = { - "image": None, - "mode": 1, # 文本分割模式 - "params": { - "prompt": prompt, - "box_threshold": box_threshold, - "text_threshold": text_threshold, - }, - } - auth = f"Bearer {API_KEY}" - headers = { - "accept": "application/json", - "content-type": "application/json", - "authorization": auth, - } - image = image.squeeze(0).numpy() - image_pil = Image.fromarray((image * 255).astype(np.uint8)) - input_image = encode_image_to_base64(image_pil, format="webp") - payload["image"] = input_image - - ret: str = send_post_request(self.API_URL, payload=payload, headers=headers) - ret = json.loads(ret) - - try: - if "result" in ret: - ret = json.loads(ret["result"]) - except Exception as e: - raise Exception(f"Unexpected response: {ret} {e=}") - - if ret["status"] == "error": - raise Exception(ret["message"]) - - msg = ret["data"] - if msg["type"] not in ("bizyair",): - raise Exception(f"Unexpected response type: {msg}") - - if "error" in msg: - raise Exception(f"Error happens: {msg}") - - img = msg["image"] - mask_image = msg["mask_image"] - - img = ( - (torch.from_numpy(decode_base64_to_np(img)).float() / 255.0) - .unsqueeze(0) - .to(device) - ) - img_mask = ( - torch.from_numpy(decode_base64_to_np(mask_image)).float() / 255.0 - ).to(device) - img_mask = img_mask.mean(dim=-1) - img_mask = img_mask.unsqueeze(0) - return (img, img_mask) - - NODE_CLASS_MAPPINGS = { "BizyAirRemoveBackground": RemoveBackground, "BizyAirGenerateLightningImage": GenerateLightningImage, - "BizyAirSegmentAnythingText": BizyAirSegmentAnythingText, } NODE_DISPLAY_NAME_MAPPINGS = { "BizyAirRemoveBackground": "☁️BizyAir Remove Image Background",