forked from ptrckqnln/runpod-worker-oobabooga
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rp_handler.py
151 lines (118 loc) · 4.29 KB
/
rp_handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import json
import time
import requests
import runpod
from runpod.serverless.utils.rp_validator import validate
from runpod.serverless.modules.rp_logger import RunPodLogger
from requests.adapters import HTTPAdapter, Retry
from schemas.api import API_SCHEMA
from schemas.chat import CHAT_SCHEMA
from schemas.generate import GENERATE_SCHEMA
from schemas.token_count import TOKEN_COUNT_SCHEMA
from schemas.model import MODEL_SCHEMA
# Root of the locally running API that every job is forwarded to.
BASE_URL = 'http://127.0.0.1:5000/api/v1'

# Value handed to the `timeout=` argument of each forwarded request.
TIMEOUT = 600

# Payload schemas keyed by API endpoint name.
VALIDATION_SCHEMAS = {
    'chat': CHAT_SCHEMA,
    'generate': GENERATE_SCHEMA,
    'token-count': TOKEN_COUNT_SCHEMA,
}

# One shared session for all forwarded requests; plain-http URLs go through
# an adapter configured with the retry policy below.
session = requests.Session()
retries = Retry(
    total=10,
    backoff_factor=0.1,
    status_forcelist=[502, 503, 504],
)
session.mount('http://', HTTPAdapter(max_retries=retries))

logger = RunPodLogger()
# ---------------------------------------------------------------------------- #
# Application Functions #
# ---------------------------------------------------------------------------- #
def wait_for_service(url, request_timeout=10):
    """Block until an HTTP GET to ``url`` completes.

    The URL is probed in a loop with a 0.2 second pause between attempts.
    Any completed HTTP response counts as "ready"; the status code is not
    inspected.

    Args:
        url: Address to probe.
        request_timeout: Seconds allowed for a single probe. Without a
            timeout, a connection that is accepted but never answered would
            block this loop forever; a timed-out probe is simply retried.
    """
    # Named `attempts` so it does not shadow the module-level `retries`.
    attempts = 0

    while True:
        try:
            requests.get(url, timeout=request_timeout)
            return
        except requests.exceptions.RequestException:
            attempts += 1

            # Only log every 15 retries so the logs don't get spammed
            if attempts % 15 == 0:
                logger.info('Service not ready yet. Retrying...')
        except Exception as err:
            logger.error(f'Error: {err}')

        time.sleep(0.2)
def send_get_request(endpoint):
    """GET ``endpoint`` (relative to BASE_URL) through the shared session.

    Returns the response object produced by ``session.get``.
    """
    target = f'{BASE_URL}/{endpoint}'
    return session.get(url=target, timeout=TIMEOUT)
def send_post_request(endpoint, payload):
    """POST ``payload`` as JSON to ``endpoint`` (relative to BASE_URL).

    The request goes through the shared session; the response object
    produced by ``session.post`` is returned.
    """
    target = f'{BASE_URL}/{endpoint}'
    return session.post(url=target, json=payload, timeout=TIMEOUT)
def validate_api(event):
    """Validate the ``api`` section of a job's input.

    Leading slashes are stripped from ``api['endpoint']`` in place, so later
    reads of the event see the normalized endpoint.

    Args:
        event: Job dictionary; ``event['input']['api']`` is expected to be a
            dictionary containing "method" and "endpoint".

    Returns:
        A dictionary with an ``'errors'`` key when ``api`` is missing or is
        not a dictionary; otherwise the result of
        ``validate(api, API_SCHEMA)``.
    """
    job_input = event.get('input')

    # A missing or non-dict "input" cannot contain "api"; report that rather
    # than crashing on the membership test / subscript.
    if not isinstance(job_input, dict) or 'api' not in job_input:
        return {
            'errors': '"api" is a required field in the "input" payload'
        }

    api = job_input['api']

    if not isinstance(api, dict):
        return {
            'errors': '"api" must be a dictionary containing "method" and "endpoint"'
        }

    # Normalize only when there is a string to normalize; a missing or
    # wrongly typed endpoint is left for the schema validation to report.
    endpoint = api.get('endpoint')
    if isinstance(endpoint, str):
        api['endpoint'] = endpoint.lstrip('/')

    return validate(api, API_SCHEMA)
def validate_payload(event):
    """Validate a job's payload against the schema for its endpoint.

    Args:
        event: Job dictionary whose ``input.api`` holds "method" and
            "endpoint". ``input.payload`` is optional; a missing or null
            payload is treated as an empty dictionary so that requests
            without a body do not fail with ``KeyError``.

    Returns:
        A ``(endpoint, method, validated_input)`` tuple. ``validated_input``
        is the result of ``validate`` for the "generate", "chat" and
        "token-count" endpoints and for POST requests to "model"; for
        anything else it is an empty dictionary.
    """
    job_input = event['input']
    method = job_input['api']['method']
    endpoint = job_input['api']['endpoint']

    payload = job_input.get('payload')
    if payload is None:
        payload = {}

    validated_input = {}

    if endpoint == 'generate':
        validated_input = validate(payload, GENERATE_SCHEMA)
    elif endpoint == 'chat':
        validated_input = validate(payload, CHAT_SCHEMA)
    elif endpoint == 'token-count':
        validated_input = validate(payload, TOKEN_COUNT_SCHEMA)
    elif endpoint == 'model' and method == 'POST':
        validated_input = validate(payload, MODEL_SCHEMA)

    return endpoint, method, validated_input
# ---------------------------------------------------------------------------- #
# RunPod Handler #
# ---------------------------------------------------------------------------- #
def handler(event):
    """Forward a job to the API endpoint it names and return the reply.

    Args:
        event: Job dictionary. ``event['input']['api']`` supplies "method"
            and "endpoint"; ``event['input']['payload']`` is optional and is
            sent as the JSON body of POST requests.

    Returns:
        The decoded JSON body of the API response, or ``{'error': ...}``
        when the ``api`` section is invalid, the method is not GET or POST,
        the request fails, or the response is not valid JSON.
    """
    validated_api = validate_api(event)

    if 'errors' in validated_api:
        return {
            'error': validated_api['errors']
        }

    job_input = event['input']
    endpoint = job_input['api']['endpoint']
    method = job_input['api']['method']

    # NOTE: the payload is forwarded as received; schema validation through
    # validate_payload() is not applied here. GET requests need no payload,
    # so a missing one must not raise.
    payload = job_input.get('payload', {})

    try:
        logger.log(f'Sending {method} request to: {endpoint}')

        if method == 'GET':
            # GET requests carry no body.
            response = send_get_request(endpoint)
        elif method == 'POST':
            response = send_post_request(endpoint, payload)
        else:
            # Without this branch `response` would be unbound below.
            return {
                'error': f'Unsupported HTTP method: {method}'
            }

        # Decoding stays inside the try so a non-JSON body is reported as an
        # error result instead of escaping as an exception.
        return response.json()
    except Exception as e:
        return {
            'error': str(e)
        }
if __name__ == '__main__':
    # Do not start accepting jobs until the local API answers.
    wait_for_service(url='http://127.0.0.1:5000/api/v1/model')

    for startup_message in ('Oobabooga API is ready', 'Starting RunPod Serverless...'):
        logger.log(startup_message, 'INFO')

    # Hand control to the RunPod serverless loop with `handler` as the job callback.
    runpod.serverless.start({'handler': handler})