diff --git a/segment_anything.py b/segment_anything.py index de6cc78e..6e478ac7 100644 --- a/segment_anything.py +++ b/segment_anything.py @@ -136,7 +136,8 @@ def INPUT_TYPES(s): }, } - RETURN_TYPES = ("IMAGE", "MASK") + RETURN_TYPES = ("IMAGE", "MASK", "IMAGE") + RETURN_NAMES = ("processed_image", "mask", "original_image") FUNCTION = "apply" CATEGORY = "☁️BizyAir/segment-anything" @@ -145,10 +146,13 @@ def apply(self, image, is_point): API_KEY = get_api_key() SIZE_LIMIT = 1536 - image, _ = LoadImage().load_image(image) + # 加载原始图像 + original_image, _ = LoadImage().load_image(image) + # 直接克隆原始图像用于处理和透传 + image_to_process = original_image.clone() - device = image.device - _, w, h, c = image.shape + device = image_to_process.device + _, w, h, c = image_to_process.shape assert ( w <= SIZE_LIMIT and h <= SIZE_LIMIT ), f"width and height must be less than {SIZE_LIMIT}x{SIZE_LIMIT}, but got {w} and {h}" @@ -167,7 +171,7 @@ def apply(self, image, is_point): input_label = [coord["pointType"] for coord in coordinates] payload = { "image": None, - "mode": INFER_MODE.points_box.value, # Point/Box分割模式 + "mode": INFER_MODE.points_box.value, "params": { "input_points": json.dumps(input_points), "input_label": json.dumps(input_label), @@ -191,7 +195,7 @@ def apply(self, image, is_point): payload = { "image": None, - "mode": INFER_MODE.batched_boxes.value, # Point/Box分割模式 + "mode": INFER_MODE.batched_boxes.value, "params": { "input_points": None, "input_label": None, @@ -205,8 +209,9 @@ def apply(self, image, is_point): "content-type": "application/json", "authorization": auth, } - image = image.squeeze(0).numpy() - image_pil = Image.fromarray((image * 255).astype(np.uint8)) + # 处理用于API的图像 + api_image = image_to_process.squeeze(0).numpy() + image_pil = Image.fromarray((api_image * 255).astype(np.uint8)) input_image = encode_image_to_base64(image_pil, format="webp") payload["image"] = input_image @@ -232,7 +237,7 @@ def apply(self, image, is_point): img = msg["image"] mask_image = msg["mask_image"] - img = ( + processed_img = ( (torch.from_numpy(decode_base64_to_np(img)).float() / 255.0) .unsqueeze(0) .to(device) @@ -243,7 +248,8 @@ def apply(self, image, is_point): img_mask = img_mask.mean(dim=-1) img_mask = img_mask.unsqueeze(0) - return (img, img_mask) + # 直接返回克隆的原始图像,无需转换 + return (processed_img, img_mask, image_to_process) @classmethod def IS_CHANGED(s, image, is_point):