Skip to content

Commit

Permalink
Merge branch 'main' of github.com:dikarel/nice-outfit
Browse files Browse the repository at this point in the history
  • Loading branch information
dikarel committed Mar 4, 2024
2 parents 9ac70fa + 6a40b6d commit 56ec250
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 8 deletions.
14 changes: 8 additions & 6 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
from lib.resize_image import resize_image

OUTFIT_SELECTION = [
"Camel trench coat",
"Fuzzy white winter coat",
"Golden ball gown with tiara",
"Mondrian-inspired haute couture",
"Rock and roll t-shirt",
"Summer dress",
"Winter coat",
"Bridal gown",
"Fall jacket",
"Formal wear",
"Ripped swole athletic chest",
"Tie dye t-shirt",
"Velvet dinner jacket",
]


Expand Down Expand Up @@ -44,7 +46,7 @@ def generate_output(img_input: PILImage, drp_outfit: str) -> PILImage:

people_mask = find_people(img_input)
img_output = redraw_image(
prompt=f"person wearing {drp_outfit}",
prompt=f"best quality. high 4k resolution. person wearing {drp_outfit}",
image=img_input,
mask=people_mask,
)
Expand Down
7 changes: 7 additions & 0 deletions lib/expand_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from PIL import Image, ImageFilter


def expand_mask(mask: Image.Image, factor: int = 5):
mask = mask.filter(ImageFilter.MaxFilter(factor * 2 + 1))
mask = mask.filter(ImageFilter.GaussianBlur(factor / 2))
return mask
8 changes: 7 additions & 1 deletion lib/find_people.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch.nn.functional import interpolate
from functools import cache
from lib.cuda import cuda_if_available
from lib.expand_mask import expand_mask


def find_people(image: PILImage) -> PILImage:
Expand All @@ -28,7 +29,12 @@ def find_people(image: PILImage) -> PILImage:
for type in everyhing_but_background_face_and_hair():
mask += (predictions == type.value).long()

return Image.fromarray((mask * 255).byte().numpy(), "L")
mask = Image.fromarray((mask * 255).byte().numpy(), "L")

# Don't have the segmentation be too strict. Loosen it up a bit
mask = expand_mask(mask, 20)

return mask


@cache
Expand Down
3 changes: 2 additions & 1 deletion lib/redraw_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ def redraw_image(prompt: str, image: PILImage, mask: PILImage) -> PILImage:

output = inpaint_model(
prompt=prompt,
negative_prompt="low quality, low resolution, dark, scary, nsfw",
image=image,
mask_image=mask,
width=image.width,
height=image.height,
num_inference_steps=30,
num_inference_steps=60,
).images[0]

return unpad_image(output, original_image_size)
Expand Down

0 comments on commit 56ec250

Please sign in to comment.