Commit

Don't hardcode CUDA
dikarel committed Mar 4, 2024
1 parent 38e2d68 commit 12f157f
Showing 3 changed files with 21 additions and 7 deletions.
13 changes: 13 additions & 0 deletions lib/cuda.py
@@ -0,0 +1,13 @@
+from torch import cuda, float16
+
+
+def cuda_if_available() -> str:
+    return "cuda" if cuda.is_available() else "cpu"
+
+
+def fp16_if_available() -> str | None:
+    return "fp16" if cuda.is_available() else None
+
+
+def float16_if_available():
+    return float16 if cuda.is_available() else None
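
For reference, a minimal sketch of how the new helpers behave (assuming it is run from the repository root so that lib/cuda.py is importable):

    from lib.cuda import cuda_if_available, fp16_if_available, float16_if_available

    # With a CUDA-capable GPU these return "cuda", "fp16", and torch.float16;
    # on a CPU-only machine they fall back to "cpu", None, and None.
    print(cuda_if_available())
    print(fp16_if_available())
    print(float16_if_available())
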
7 changes: 4 additions & 3 deletions lib/find_people.py
@@ -5,13 +5,14 @@
 from torch import zeros_like
 from torch.nn.functional import interpolate
 from functools import cache
+from lib.cuda import cuda_if_available


 def find_people(image: PILImage) -> PILImage:
     processor = get_seg_processor()
     model = get_seg_model()

-    inputs = processor(images=image, return_tensors="pt").to("cuda")
+    inputs = processor(images=image, return_tensors="pt").to(cuda_if_available())
     logits = model(**inputs).logits.cpu()

     upsampled_logits = interpolate(
@@ -33,12 +34,12 @@ def find_people(image: PILImage) -> PILImage:
 @cache
 def get_seg_processor():
     return SegformerImageProcessor.from_pretrained(
-        "mattmdjaga/segformer_b2_clothes", device="cuda"
+        "mattmdjaga/segformer_b2_clothes", device=cuda_if_available()
     )


 @cache
 def get_seg_model():
     return AutoModelForSemanticSegmentation.from_pretrained(
         "mattmdjaga/segformer_b2_clothes"
-    ).to("cuda")
+    ).to(cuda_if_available())
8 changes: 4 additions & 4 deletions lib/redraw_image.py
@@ -3,7 +3,7 @@
 from diffusers import StableDiffusionInpaintPipeline
 from lib.optimize import optimize_sd_model
 from lib.pad_image import pad_image, unpad_image
-from torch import float16
+from lib.cuda import cuda_if_available, float16_if_available, fp16_if_available


 def redraw_image(prompt: str, image: PILImage, mask: PILImage) -> PILImage:
@@ -29,8 +29,8 @@ def redraw_image(prompt: str, image: PILImage, mask: PILImage) -> PILImage:
 def get_inpaint_model():
     inpaint_model = StableDiffusionInpaintPipeline.from_pretrained(
         "runwayml/stable-diffusion-inpainting",
-        revision="fp16",
-        torch_dtype=float16,
-    ).to("cuda")
+        revision=fp16_if_available(),
+        torch_dtype=float16_if_available(),
+    ).to(cuda_if_available())

     return optimize_sd_model(inpaint_model)
