forked from OPHoperHPO/image-background-remove-tool
-
Notifications
You must be signed in to change notification settings - Fork 0
/
conftest.py
138 lines (99 loc) · 4.64 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
from pathlib import Path
import pytest
import torch
from PIL import Image
from typing import Callable, Tuple, List, Union, Optional, Any
from carvekit.api.high import HiInterface
from carvekit.api.interface import Interface
from carvekit.trimap.cv_gen import CV2TrimapGenerator
from carvekit.trimap.generator import TrimapGenerator
from carvekit.utils.image_utils import convert_image, load_image
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.pipelines.preprocessing import PreprocessingStub
from carvekit.ml.wrap.u2net import U2NET
from carvekit.ml.wrap.basnet import BASNET
from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
@pytest.fixture()
def u2net_model() -> Callable[[], U2NET]:
return lambda: U2NET(layers_cfg="full",
device='cuda' if torch.cuda.is_available() else 'cpu',
input_image_size=320,
batch_size=10,
load_pretrained=True)
@pytest.fixture()
def trimap_instance() -> Callable[[], TrimapGenerator]:
return lambda: TrimapGenerator()
@pytest.fixture()
def cv2_trimap_instance() -> Callable[[], CV2TrimapGenerator]:
return lambda: CV2TrimapGenerator(kernel_size=30, erosion_iters=0)
@pytest.fixture()
def preprocessing_stub_instance() -> Callable[[], PreprocessingStub]:
return lambda: PreprocessingStub()
@pytest.fixture()
def matting_method_instance(fba_model, trimap_instance):
return lambda: MattingMethod(matting_module=fba_model(), trimap_generator=trimap_instance(), device="cpu")
@pytest.fixture()
def high_interface_instance() -> Callable[[], HiInterface]:
return lambda: HiInterface(batch_size_seg=5, batch_size_matting=1,
device='cuda' if torch.cuda.is_available() else 'cpu',
seg_mask_size=320, matting_mask_size=2048)
@pytest.fixture()
def interface_instance(u2net_model, preprocessing_stub_instance,
matting_method_instance) -> Callable[[], Interface]:
return lambda: Interface(u2net_model(),
pre_pipe=preprocessing_stub_instance(),
post_pipe=matting_method_instance(),
device='cuda' if torch.cuda.is_available() else 'cpu')
@pytest.fixture()
def fba_model() -> Callable[[], FBAMatting]:
return lambda: FBAMatting(device='cuda' if torch.cuda.is_available() else 'cpu',
input_tensor_size=1024,
batch_size=2,
load_pretrained=True)
@pytest.fixture()
def deeplabv3_model() -> Callable[[], DeepLabV3]:
return lambda: DeepLabV3(device='cuda' if torch.cuda.is_available() else 'cpu',
batch_size=10,
load_pretrained=True)
@pytest.fixture()
def basnet_model() -> Callable[[], BASNET]:
return lambda: BASNET(device='cuda' if torch.cuda.is_available() else 'cpu',
input_tensor_size=320,
batch_size=10,
load_pretrained=True)
@pytest.fixture()
def image_str(image_path) -> str:
return str(image_path.absolute())
@pytest.fixture()
def image_path() -> Path:
return Path(__file__).parent.joinpath('tests').joinpath('data', 'cat.jpg')
@pytest.fixture()
def image_mask(image_path) -> Image.Image:
return Image.open(image_path.with_name('cat_mask').with_suffix(".png"))
@pytest.fixture()
def image_trimap(image_path) -> Image.Image:
return Image.open(image_path.with_name('cat_trimap').with_suffix(".png")).convert("L")
@pytest.fixture()
def image_pil(image_path) -> Image.Image:
return Image.open(image_path)
@pytest.fixture()
def black_image_pil() -> Image.Image:
return Image.new("RGB", (512, 512))
@pytest.fixture()
def converted_pil_image(image_pil) -> Image.Image:
return convert_image(load_image(image_pil))
@pytest.fixture()
def available_models(u2net_model, deeplabv3_model, basnet_model,
preprocessing_stub_instance, matting_method_instance) -> Tuple[
List[Union[Callable[[], U2NET], Callable[[], DeepLabV3], Callable[[], BASNET]]], List[
Optional[Callable[[], PreprocessingStub]]], List[Union[Optional[Callable[[], MattingMethod]], Any]]]:
models = [u2net_model, deeplabv3_model, basnet_model]
pre_pipes = [None, preprocessing_stub_instance]
post_pipes = [None, matting_method_instance]
return models, pre_pipes, post_pipes