Merge pull request #452 from TashaSkyUp/patch-2
High-res, GPU-enabled version of the CLIPSeg node.
WAS-PlaiLabs authored Aug 8, 2024
2 parents 30af49f + 88742ad commit df24828
Showing 1 changed file with 163 additions and 0 deletions: WAS_Node_Suite.py
@@ -11427,6 +11427,168 @@ def CLIPSeg_image(self, image, text=None, clipseg_model=None):
        mask: torch.tensor = mask.unsqueeze(-1)
        mask_img = mask.repeat(1, 1, 1, 3)
        return (mask, mask_img,)
class CLIPSeg2:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "text": ("STRING", {"default": "", "multiline": False}),
                "use_cuda": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "clipseg_model": ("CLIPSEG_MODEL",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "apply_transform"

    CATEGORY = "image/transformation"

    def apply_transform(self, image, text, use_cuda, clipseg_model=None):
        import numpy as np
        import torch
        import torch.nn.functional as F
        from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

        B, H, W, C = image.shape

        if B != 1:
            raise NotImplementedError("Batch size must be 1")

        # Desired slice size and overlap
        slice_size = 352
        overlap = slice_size // 2

        # Calculate the number of slices needed along each dimension
        num_slices_h = (H - overlap) // (slice_size - overlap) + 1
        num_slices_w = (W - overlap) // (slice_size - overlap) + 1
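        # e.g. H = 1024 gives (1024 - 176) // 176 + 1 = 5 rows of 352 px tiles at a
        # 176 px stride; edge tiles are re-anchored to the image border in the loop below.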

        # Prepare a list to store the slices
        slices = []

        # Generate the slices
        for i in range(num_slices_h):
            for j in range(num_slices_w):
                start_h = i * (slice_size - overlap)
                start_w = j * (slice_size - overlap)

                end_h = min(start_h + slice_size, H)
                end_w = min(start_w + slice_size, W)

                start_h = max(0, end_h - slice_size)
                start_w = max(0, end_w - slice_size)

                slice_ = image[:, start_h:end_h, start_w:end_w, :]
                slices.append(slice_)

        # Initialize CLIPSeg model and processor
        if clipseg_model:
            processor = clipseg_model[0]
            model = clipseg_model[1]
        else:
            processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
            model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

        # Move model to CUDA if requested
        if use_cuda and torch.cuda.is_available():
            model = model.to('cuda')

        processor.image_processor.do_rescale = True
        processor.image_processor.do_resize = False
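
        # Coarse global pass: downsample the whole image to 352x352, run CLIPSeg once,
        # and upsample the resulting mask back to (H, W) for blending with the tiles.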

        image_global = image.permute(0, 3, 1, 2)
        image_global = F.interpolate(image_global, size=(slice_size, slice_size), mode='bilinear', align_corners=False)
        image_global = image_global.permute(0, 2, 3, 1)
        _, image_global = self.CLIPSeg_image(image_global.float(), text, processor, model, use_cuda)
        image_global = image_global.permute(0, 3, 1, 2)
        image_global = F.interpolate(image_global, size=(H, W), mode='bilinear', align_corners=False)
        image_global = image_global.permute(0, 2, 3, 1)

        # Apply the transformation to each slice
        transformed_slices = []
        for slice_ in slices:
            transformed_mask, transformed_slice = self.CLIPSeg_image(slice_, text, processor, model, use_cuda)
            transformed_slices.append(transformed_slice)

        transformed_slices = torch.cat(transformed_slices)

        # Initialize tensors for reconstruction
        reconstructed_image = torch.zeros((B, H, W, C))
        count_map = torch.zeros((B, H, W, C))

        # Create a blending mask
        mask = np.ones((slice_size, slice_size))
        mask[:overlap, :] *= np.linspace(0, 1, overlap)[:, None]
        mask[-overlap:, :] *= np.linspace(1, 0, overlap)[:, None]
        mask[:, :overlap] *= np.linspace(0, 1, overlap)[None, :]
        mask[:, -overlap:] *= np.linspace(1, 0, overlap)[None, :]
        mask = torch.tensor(mask, dtype=torch.float32).unsqueeze(0).unsqueeze(-1)
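        # The mask ramps linearly from 0 to 1 over the first `overlap` rows/columns and
        # back down over the last ones, so overlapping tiles cross-fade instead of seaming;
        # its final shape is (1, slice_size, slice_size, 1) for broadcasting below.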

        # Place the transformed slices back into the original image dimensions
        for idx in range(transformed_slices.shape[0]):
            i = idx // num_slices_w
            j = idx % num_slices_w

            start_h = i * (slice_size - overlap)
            start_w = j * (slice_size - overlap)

            end_h = min(start_h + slice_size, H)
            end_w = min(start_w + slice_size, W)

            start_h = max(0, end_h - slice_size)
            start_w = max(0, end_w - slice_size)

            reconstructed_image[:, start_h:end_h, start_w:end_w, :] += transformed_slices[idx] * mask
            count_map[:, start_h:end_h, start_w:end_w, :] += mask

        # Avoid division by zero
        count_map[count_map == 0] = 1

        # Average the overlapping regions
        y = reconstructed_image / count_map
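        # Combine the tiled result with the coarse global mask: a pixel is set to 1.0
        # if the average of the two, the tiled mask, or the global mask clears 0.5,
        # except where the global mask is essentially black.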

        total_power = (y + image_global) / 2
        just_black = image_global < 0.01

        p1 = total_power > .5
        p2 = y > .5
        p3 = image_global > .5

        condition = p1 | p2 | p3
        condition = condition & ~just_black
        y = torch.where(condition, 1.0, 0.0)

        return (y,)

    def CLIPSeg_image(self, image, text, processor, model, use_cuda):
        import torch
        import torchvision.transforms.functional as TF

        B, H, W, C = image.shape

        with torch.no_grad():
            image = image.permute(0, 3, 1, 2).to(torch.float32) * 255

            inputs = processor(text=[text] * B, images=image, padding=True, return_tensors="pt")
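            # `do_rescale` on the processor maps the 0-255 input back to 0-1 (the HF image
            # processors' default rescale factor is 1/255) before CLIP normalization, and
            # `do_resize = False` leaves the 352x352 tiles at their native resolution.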

            # Move model and image tensors to CUDA if requested
            if use_cuda and torch.cuda.is_available():
                model = model.to('cuda')
                inputs = {k: v.to('cuda') if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

            result = model(**inputs)
            t = torch.sigmoid(result[0])
            mask = (t - t.min()) / t.max()
            mask = TF.resize(mask, (H, W))
            mask = mask.unsqueeze(-1)
            mask_img = mask.repeat(1, 1, 1, 3)

            # Move mask and mask_img back to CPU if they were moved to CUDA
            if use_cuda and torch.cuda.is_available():
                mask = mask.cpu()
                mask_img = mask_img.cpu()

        return (mask, mask_img,)

# CLIPSeg Node

@@ -14258,6 +14420,7 @@ def count_places(self, int_input):
    "Write to Video": WAS_Video_Writer,
    "VAE Input Switch": WAS_VAE_Input_Switch,
    "Video Dump Frames": WAS_Video_Frame_Dump,
    "CLIPSEG2": CLIPSeg2
}

#! EXTRA NODES

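A minimal sketch (not part of the commit) of how the new node class could be exercised directly in Python, outside a ComfyUI graph. The prompt, image size, and CPU-only setting are placeholders; it assumes CLIPSeg2 can be imported from WAS_Node_Suite (which normally requires a ComfyUI environment), and the first run downloads the CIDAS/clipseg-rd64-refined weights.

import torch
from WAS_Node_Suite import CLIPSeg2  # assumption: the suite imports cleanly outside ComfyUI

node = CLIPSeg2()
image = torch.rand(1, 1024, 1024, 3)  # ComfyUI IMAGE layout: (B, H, W, C), values in [0, 1]
(result,) = node.apply_transform(image, text="a cat", use_cuda=False, clipseg_model=None)
print(result.shape)  # torch.Size([1, 1024, 1024, 3]); values are 0.0 or 1.0

Note that the input should be at least 352 px on each side, since the node re-anchors edge tiles to the border rather than padding smaller images.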