Skip to content

Commit

Permalink
enhance: add original image handling and return types in BizyAirSegme…
Browse files Browse the repository at this point in the history
…ntAnythingPointBox
  • Loading branch information
lcolok committed Nov 28, 2024
1 parent 19307df commit 788d62a
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions segment_anything.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ def INPUT_TYPES(s):
},
}

RETURN_TYPES = ("IMAGE", "MASK")
RETURN_TYPES = ("IMAGE", "MASK", "IMAGE")
RETURN_NAMES = ("processed_image", "mask", "original_image")
FUNCTION = "apply"

CATEGORY = "☁️BizyAir/segment-anything"
Expand All @@ -145,10 +146,13 @@ def apply(self, image, is_point):
API_KEY = get_api_key()
SIZE_LIMIT = 1536

image, _ = LoadImage().load_image(image)
# 加载原始图像
original_image, _ = LoadImage().load_image(image)
# 直接克隆原始图像用于处理和透传
image_to_process = original_image.clone()

device = image.device
_, w, h, c = image.shape
device = image_to_process.device
_, w, h, c = image_to_process.shape
assert (
w <= SIZE_LIMIT and h <= SIZE_LIMIT
), f"width and height must be less than {SIZE_LIMIT}x{SIZE_LIMIT}, but got {w} and {h}"
Expand All @@ -167,7 +171,7 @@ def apply(self, image, is_point):
input_label = [coord["pointType"] for coord in coordinates]
payload = {
"image": None,
"mode": INFER_MODE.points_box.value, # Point/Box分割模式
"mode": INFER_MODE.points_box.value,
"params": {
"input_points": json.dumps(input_points),
"input_label": json.dumps(input_label),
Expand All @@ -191,7 +195,7 @@ def apply(self, image, is_point):

payload = {
"image": None,
"mode": INFER_MODE.batched_boxes.value, # Point/Box分割模式
"mode": INFER_MODE.batched_boxes.value,
"params": {
"input_points": None,
"input_label": None,
Expand All @@ -205,8 +209,9 @@ def apply(self, image, is_point):
"content-type": "application/json",
"authorization": auth,
}
image = image.squeeze(0).numpy()
image_pil = Image.fromarray((image * 255).astype(np.uint8))
# 处理用于API的图像
api_image = image_to_process.squeeze(0).numpy()
image_pil = Image.fromarray((api_image * 255).astype(np.uint8))
input_image = encode_image_to_base64(image_pil, format="webp")
payload["image"] = input_image

Expand All @@ -232,7 +237,7 @@ def apply(self, image, is_point):
img = msg["image"]
mask_image = msg["mask_image"]

img = (
processed_img = (
(torch.from_numpy(decode_base64_to_np(img)).float() / 255.0)
.unsqueeze(0)
.to(device)
Expand All @@ -243,7 +248,8 @@ def apply(self, image, is_point):
img_mask = img_mask.mean(dim=-1)
img_mask = img_mask.unsqueeze(0)

return (img, img_mask)
# 直接返回克隆的原始图像,无需转换
return (processed_img, img_mask, image_to_process)

@classmethod
def IS_CHANGED(s, image, is_point):
Expand Down

0 comments on commit 788d62a

Please sign in to comment.