From 88742adb10bd25e5eed02af5394632e0a32a4bf7 Mon Sep 17 00:00:00 2001
From: Tasha Upchurch
Date: Tue, 30 Jul 2024 16:47:38 -0600
Subject: [PATCH] High res GPU-enabled version of the CLIPSeg node.

Enables higher-resolution segmentation by running the model on image
patches and combining and filtering the results.
---
 WAS_Node_Suite.py | 163 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 163 insertions(+)

diff --git a/WAS_Node_Suite.py b/WAS_Node_Suite.py
index 82dd5c7..46b834c 100644
--- a/WAS_Node_Suite.py
+++ b/WAS_Node_Suite.py
@@ -11427,6 +11427,168 @@ def CLIPSeg_image(self, image, text=None, clipseg_model=None):
         mask: torch.tensor = mask.unsqueeze(-1)
         mask_img = mask.repeat(1, 1, 1, 3)
         return (mask, mask_img,)
+class CLIPSeg2:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "text": ("STRING", {"default": "", "multiline": False}),
+                "use_cuda": ("BOOLEAN", {"default": False}),
+            },
+            "optional": {
+                "clipseg_model": ("CLIPSEG_MODEL",),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "apply_transform"
+
+    CATEGORY = "image/transformation"
+
+    def apply_transform(self, image, text, use_cuda, clipseg_model=None):
+        import torch
+        import torch.nn.functional as F
+        from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+
+        B, H, W, C = image.shape
+
+        if B != 1:
+            raise NotImplementedError("Batch size must be 1")
+
+        # Desired slice size and overlap
+        slice_size = 352
+        overlap = slice_size // 2
+
+        # Calculate the number of slices needed along each dimension
+        num_slices_h = (H - overlap) // (slice_size - overlap) + 1
+        num_slices_w = (W - overlap) // (slice_size - overlap) + 1
+
+        # Prepare a list to store the slices
+        slices = []
+
+        # Generate the slices
+        for i in range(num_slices_h):
+            for j in range(num_slices_w):
+                start_h = i * (slice_size - overlap)
+                start_w = j * (slice_size - overlap)
+
+                end_h = min(start_h + slice_size, H)
+                end_w = min(start_w + slice_size, W)
+
+                start_h = max(0, end_h - slice_size)
+                start_w = max(0, end_w - slice_size)
+
+                slice_ = image[:, start_h:end_h, start_w:end_w, :]
+                slices.append(slice_)
+
+        # Initialize CLIPSeg model and processor
+        if clipseg_model:
+            processor = clipseg_model[0]
+            model = clipseg_model[1]
+        else:
+            processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
+            model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
+
+        # Move model to CUDA if requested
+        if use_cuda and torch.cuda.is_available():
+            model = model.to('cuda')
+
+        processor.image_processor.do_rescale = True
+        processor.image_processor.do_resize = False
+
+        # Global pass: segment a downscaled copy of the whole image, then upscale the mask back to full resolution
+        image_global = image.permute(0, 3, 1, 2)
+        image_global = F.interpolate(image_global, size=(slice_size, slice_size), mode='bilinear', align_corners=False)
+        image_global = image_global.permute(0, 2, 3, 1)
+        _, image_global = self.CLIPSeg_image(image_global.float(), text, processor, model, use_cuda)
+        image_global = image_global.permute(0, 3, 1, 2)
+        image_global = F.interpolate(image_global, size=(H, W), mode='bilinear', align_corners=False)
+        image_global = image_global.permute(0, 2, 3, 1)
+
+        # Apply the transformation to each slice
+        transformed_slices = []
+        for slice_ in slices:
+            transformed_mask, transformed_slice = self.CLIPSeg_image(slice_, text, processor, model, use_cuda)
+            transformed_slices.append(transformed_slice)
+
+        transformed_slices = torch.cat(transformed_slices)
+
+        # Initialize tensors for reconstruction
+        reconstructed_image = torch.zeros((B, H, W, C))
+        count_map = torch.zeros((B, H, W, C))
+
+        # Create a blending mask
+        mask = np.ones((slice_size, slice_size))
+        mask[:overlap, :] *= np.linspace(0, 1, overlap)[:, None]
+        mask[-overlap:, :] *= np.linspace(1, 0, overlap)[:, None]
+        mask[:, :overlap] *= np.linspace(0, 1, overlap)[None, :]
+        mask[:, -overlap:] *= np.linspace(1, 0, overlap)[None, :]
+        mask = torch.tensor(mask, dtype=torch.float32).unsqueeze(0).unsqueeze(-1)
+
+        # Place the transformed slices back into the original image dimensions
+        for idx in range(transformed_slices.shape[0]):
+            i = idx // num_slices_w
+            j = idx % num_slices_w
+
+            start_h = i * (slice_size - overlap)
+            start_w = j * (slice_size - overlap)
+
+            end_h = min(start_h + slice_size, H)
+            end_w = min(start_w + slice_size, W)
+
+            start_h = max(0, end_h - slice_size)
+            start_w = max(0, end_w - slice_size)
+
+            reconstructed_image[:, start_h:end_h, start_w:end_w, :] += transformed_slices[idx] * mask
+            count_map[:, start_h:end_h, start_w:end_w, :] += mask
+
+        # Avoid division by zero
+        count_map[count_map == 0] = 1
+
+        # Average the overlapping regions
+        y = reconstructed_image / count_map
+
+        total_power = (y + image_global) / 2
+        just_black = image_global < 0.01
+
+        p1 = total_power > 0.5
+        p2 = y > 0.5
+        p3 = image_global > 0.5
+
+        condition = p1 | p2 | p3
+        condition = condition & ~just_black
+        y = torch.where(condition, 1.0, 0.0)
+
+        return (y,)
+
+    def CLIPSeg_image(self, image, text, processor, model, use_cuda):
+        import torch
+        import torchvision
+        import torchvision.transforms.functional as TF
+
+        B, H, W, C = image.shape
+
+        with torch.no_grad():
+            image = image.permute(0, 3, 1, 2).to(torch.float32) * 255
+
+            inputs = processor(text=[text] * B, images=image, padding=True, return_tensors="pt")
+
+            # Move model and image tensors to CUDA if requested
+            if use_cuda and torch.cuda.is_available():
+                model = model.to('cuda')
+                inputs = {k: v.to('cuda') if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
+
+            result = model(**inputs)
+            t = torch.sigmoid(result[0])
+            mask = (t - t.min()) / t.max()
+            mask = torchvision.transforms.functional.resize(mask, (H, W))
+            mask = mask.unsqueeze(-1)
+            mask_img = mask.repeat(1, 1, 1, 3)
+
+            # Move mask and mask_img back to CPU if they were moved to CUDA
+            if use_cuda and torch.cuda.is_available():
+                mask = mask.cpu()
+                mask_img = mask_img.cpu()
+
+        return (mask, mask_img,)
 
 
 # CLIPSeg Node
@@ -14258,6 +14420,7 @@ def count_places(self, int_input):
     "Write to Video": WAS_Video_Writer,
     "VAE Input Switch": WAS_VAE_Input_Switch,
     "Video Dump Frames": WAS_Video_Frame_Dump,
+    "CLIPSEG2": CLIPSeg2
 }
 
 #! EXTRA NODES
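
Note (illustrative, not part of the patch): the overlap-tiling and feathered-merge logic above can be exercised on its own. Below is a minimal sketch; the helper name blend_tiles, the tile_fn argument, and the test image size are assumptions for illustration, and an identity function stands in for the per-tile CLIPSeg call so the merge can be sanity-checked without loading the model.

    import numpy as np
    import torch

    def blend_tiles(image, tile_fn, tile_size=352, overlap=176):
        """Run tile_fn over overlapping tiles of a (1, H, W, C) tensor and feather-blend the outputs."""
        B, H, W, C = image.shape
        step = tile_size - overlap
        num_h = (H - overlap) // step + 1
        num_w = (W - overlap) // step + 1

        # Linear ramps on all four edges so overlapping tiles cross-fade instead of seaming
        ramp = np.ones((tile_size, tile_size))
        ramp[:overlap, :] *= np.linspace(0, 1, overlap)[:, None]
        ramp[-overlap:, :] *= np.linspace(1, 0, overlap)[:, None]
        ramp[:, :overlap] *= np.linspace(0, 1, overlap)[None, :]
        ramp[:, -overlap:] *= np.linspace(1, 0, overlap)[None, :]
        ramp = torch.tensor(ramp, dtype=torch.float32)[None, :, :, None]

        out = torch.zeros((B, H, W, C))
        weight = torch.zeros((B, H, W, C))
        for i in range(num_h):
            for j in range(num_w):
                # Clamp the window so every tile is exactly tile_size, as the patch does
                end_h = min(i * step + tile_size, H)
                end_w = min(j * step + tile_size, W)
                start_h = max(0, end_h - tile_size)
                start_w = max(0, end_w - tile_size)
                tile = tile_fn(image[:, start_h:end_h, start_w:end_w, :])
                out[:, start_h:end_h, start_w:end_w, :] += tile * ramp
                weight[:, start_h:end_h, start_w:end_w, :] += ramp
        weight[weight == 0] = 1  # avoid division by zero where the ramps hit exactly zero
        return out / weight

    if __name__ == "__main__":
        img = torch.rand(1, 704, 1056, 3)
        merged = blend_tiles(img, lambda t: t)  # identity "model": blending should reproduce the input
        # The ramps fall to zero at the outermost pixel, so compare the interior only
        print(torch.allclose(merged[:, 1:-1, 1:-1, :], img[:, 1:-1, 1:-1, :], atol=1e-5))

With the identity stand-in the check passes because the ramp weights are divided back out wherever tiles overlap; in the node, the same weighting is what smooths seams between per-tile CLIPSeg masks before the global-mask thresholding step.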
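
Also for reference, a minimal standalone sketch of the per-tile Hugging Face CLIPSeg call that the new CLIPSeg_image helper wraps; the image path and prompt are placeholders, and the epsilon-stabilized min-max normalization shown here is a common variant rather than the exact (t - t.min()) / t.max() used in the patch.

    import torch
    from PIL import Image
    from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

    processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

    image = Image.open("example.jpg").convert("RGB")  # placeholder path
    inputs = processor(text=["a cat"], images=[image], padding=True, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits  # low-resolution segmentation logits for the prompt/image pair
    mask = torch.sigmoid(logits)
    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)  # normalize to [0, 1]
    print(mask.shape, float(mask.min()), float(mask.max()))

Inside the node this call runs once per 352x352 tile plus once on a downscaled copy of the whole image, which is where the higher effective resolution comes from.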