-
Notifications
You must be signed in to change notification settings - Fork 43
/
datamodel.py
68 lines (57 loc) · 1.76 KB
/
datamodel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# from dataclasses import dataclass
import os
from dataclasses import field
from typing import Any, List, Optional, Union

from pydantic.dataclasses import dataclass
# from typer import Option
@dataclass
class ModelConfig:
    """Configuration for the HF Diffuser Model.

    ``token`` defaults to the ``HF_API_TOKEN`` environment variable. The
    variable is read each time an instance is created without an explicit
    token, not once at import time. A token exported after this module has
    been imported is therefore still picked up.
    """

    model: Optional[str] = "runwayml/stable-diffusion-v1-5"
    # default_factory defers the environment lookup to construction time;
    # a plain default would freeze whatever was set when the class was defined.
    token: Optional[str] = field(
        default_factory=lambda: os.environ.get("HF_API_TOKEN")
    )
    device: Optional[str] = None
    revision: Optional[str] = "fp16"
@dataclass
class GeneratorConfig:
    """Configuration for a generation.

    Only ``prompt`` is required; every other field has a default.
    """

    prompt: Union[str, List[str]]
    num_images: int = 1
    height: Optional[int] = 512
    width: Optional[int] = 512
    num_inference_steps: Optional[int] = 25
    guidance_scale: Optional[float] = 7.5
    eta: Optional[float] = 0.0
    strength: float = 0.8
    init_image: Optional[Any] = None
    seed: Optional[int] = None  # e.g. 2147483647
    return_intermediates: bool = False
    mask_image: Optional[Any] = None
    attention_slice: Optional[Union[str, int]] = None
    # Optional so the None default, and an explicitly passed None, satisfy
    # the annotation; the previous Union[str, List[str]] excluded None.
    negative_prompt: Optional[Union[str, List[str]]] = None
    latents: Optional[Any] = None
    callback: Optional[Any] = None
    prompt_weights: Optional[List[float]] = None
    use_prompt_weights: bool = False
    application: Optional[Any] = None
    text_embeddings: Optional[Any] = None
    filter_nsfw: bool = True
@dataclass
class WebRequestData:
    """Websocket payload pairing a message type with generation settings.

    Attributes:
        type: Message type string.
        config: The GeneratorConfig describing the requested generation.
    """

    type: str
    config: GeneratorConfig
@dataclass
class PreviewQuery:
    """Query payload carrying the single prompt string to preview.

    Attributes:
        prompt: The prompt text.
    """

    prompt: str
@dataclass
class SocketData:
    """Generic websocket message envelope.

    Attributes:
        data: Arbitrary payload (any type).
        type: Message type string.
        token: Optional token string; ``None`` unless supplied.
    """

    data: Any
    type: str
    token: Union[str, None] = None
@dataclass
class ModelResponse:
    """Outcome reported back from the model.

    Attributes:
        status: Boolean status flag (presumably True on success; confirm
            with callers).
        message: Optional human-readable message; ``None`` by default.
        data: Optional payload of any type; ``None`` by default.
    """

    status: bool
    message: Union[str, None] = None
    data: Union[Any, None] = None