-
Notifications
You must be signed in to change notification settings - Fork 24
/
mzkolors.py
73 lines (56 loc) · 1.72 KB
/
mzkolors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import uuid
import torch
from .utils import (
decode_and_deserialize,
send_post_request,
serialize_and_encode,
get_api_key,
)
# Category label assigned to the node class(es) defined in this module.
CATEGORY_NAME = "☁️BizyAir/Kolors"

# Base URL of the remote service. The BIZYAIR_SERVER_ADDRESS environment
# variable, when set, takes precedence over the built-in default.
BIZYAIR_SERVER_ADDRESS = os.environ.get(
    "BIZYAIR_SERVER_ADDRESS", "https://api.siliconflow.cn"
)
class BizyAirMZChatGLM3TextEncode:
    """Text-encode node backed by the remote ``mzkolorschatglm3`` endpoint.

    The class exposes the ``INPUT_TYPES`` / ``RETURN_TYPES`` / ``FUNCTION`` /
    ``CATEGORY`` attributes (this looks like the ComfyUI custom-node
    convention -- confirm against the host application). ``encode`` sends the
    prompt to ``API_URL`` and converts the decoded reply into torch tensors.
    """

    API_URL = f"{BIZYAIR_SERVER_ADDRESS}/supernode/mzkolorschatglm3"
    # Longest prompt, in characters, that encode() accepts.
    MAX_PROMPT_LENGTH = 4096

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node's single required input: a multiline string ``text``."""
        return {
            "required": {
                "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"
    CATEGORY = CATEGORY_NAME

    def encode(self, text):
        """Encode ``text`` via the remote service and return the conditioning.

        Args:
            text: Prompt to encode; at most ``MAX_PROMPT_LENGTH`` characters.

        Returns:
            A 1-tuple whose only element is a list of ``[tensor, dict]``
            pairs, where every value in ``dict`` is also a torch tensor.

        Raises:
            ValueError: If ``text`` is longer than ``MAX_PROMPT_LENGTH``.
                (An explicit raise is used instead of ``assert`` so the check
                is not stripped when Python runs with ``-O``.)
        """
        # Reject over-long prompts up front, before any key lookup or request.
        if len(text) > self.MAX_PROMPT_LENGTH:
            raise ValueError(f"the prompt is too long, length: {len(text)}")

        headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {get_api_key()}",
        }
        response: str = send_post_request(
            self.API_URL, payload={"text": text}, headers=headers
        )

        # Each decoded item is unpacked as (array, mapping-of-arrays);
        # presumably these are NumPy arrays, since torch.from_numpy is applied
        # to them -- confirm against decode_and_deserialize.
        ret_conditioning = [
            [
                torch.from_numpy(array),
                {key: torch.from_numpy(value) for key, value in extras.items()},
            ]
            for array, extras in decode_and_deserialize(response)
        ]
        return (ret_conditioning,)
# Identifier shared by both tables below, so the class entry and its display
# label always use the same key.
_NODE_ID = "BizyAirMZChatGLM3TextEncode"

# Node identifier -> implementing class.
NODE_CLASS_MAPPINGS = {_NODE_ID: BizyAirMZChatGLM3TextEncode}

# Node identifier -> human-readable label.
NODE_DISPLAY_NAME_MAPPINGS = {_NODE_ID: "☁️BizyAir ChatGLM3 Text Encode"}