import os
import time
import yaml
from typing import Iterator

import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql.functions import col, udf, length, pandas_udf

import mlflow
from mlflow import MlflowClient

mlflow.set_registry_uri('databricks-uc')
########################################################################
###### Functions for setting up vector search index for RAG and cache
########################################################################
def vs_endpoint_exists(vsc, vs_endpoint_name):
    '''Check if a vector search endpoint exists.'''
    try:
        return vs_endpoint_name in [e['name'] for e in vsc.list_endpoints().get('endpoints', [])]
    except Exception as e:
        # Temp fix for potential REQUEST_LIMIT_EXCEEDED issue
        if "REQUEST_LIMIT_EXCEEDED" in str(e):
            print("WARN: couldn't get endpoint status due to REQUEST_LIMIT_EXCEEDED error. The demo will assume the endpoint exists.")
            return True
        else:
            raise e
def create_or_wait_for_endpoint(vsc, vs_endpoint_name):
    '''Create a vector search endpoint if it doesn't exist, then wait for it to be ready.'''
    if not vs_endpoint_exists(vsc, vs_endpoint_name):
        vsc.create_endpoint(name=vs_endpoint_name, endpoint_type="STANDARD")
    wait_for_vs_endpoint_to_be_ready(vsc, vs_endpoint_name)
def wait_for_vs_endpoint_to_be_ready(vsc, vs_endpoint_name):
    '''Wait for a vector search endpoint to be ready (polls every 10s, up to ~30 minutes).'''
    for i in range(180):
        try:
            endpoint = vsc.get_endpoint(vs_endpoint_name)
        except Exception as e:
            # Temp fix for potential REQUEST_LIMIT_EXCEEDED issue
            if "REQUEST_LIMIT_EXCEEDED" in str(e):
                print("WARN: couldn't get endpoint status due to REQUEST_LIMIT_EXCEEDED error. Please manually check your endpoint status.")
                return
            else:
                raise e
        status = endpoint.get("endpoint_status", endpoint.get("status"))["state"].upper()
        if "ONLINE" in status:
            return endpoint
        elif "PROVISIONING" in status or i < 6:
            if i % 20 == 0:
                print(f"Waiting for endpoint to be ready, this can take a few min... {endpoint}")
            time.sleep(10)
        else:
            raise Exception(f'''Error with the endpoint {vs_endpoint_name} - this shouldn't happen: {endpoint}.\nPlease delete it and re-run the previous cell: vsc.delete_endpoint("{vs_endpoint_name}")''')
    raise Exception(f"Timeout, your endpoint isn't ready yet: {vsc.get_endpoint(vs_endpoint_name)}")
def delete_endpoint(vsc, vs_endpoint_name):
    '''Delete a vector search endpoint.'''
    print(f"Deleting endpoint {vs_endpoint_name}...")
    try:
        vsc.delete_endpoint(vs_endpoint_name)
        print(f"Endpoint {vs_endpoint_name} deleted successfully")
    except Exception as e:
        print(f"Error deleting endpoint {vs_endpoint_name}: {str(e)}")
def index_exists(vsc, vs_endpoint_name, vs_index_name):
    '''Check if a vector search index exists.'''
    try:
        vsc.get_index(vs_endpoint_name, vs_index_name).describe()
        return True
    except Exception as e:
        if 'RESOURCE_DOES_NOT_EXIST' not in str(e):
            print('Unexpected error describing the index. This could be a permission issue.')
            raise e
    return False
def wait_for_index_to_be_ready(vsc, vs_endpoint_name, vs_index_fullname):
    '''Wait for a vector search index to be ready (polls every 10s, up to ~30 minutes).'''
    for i in range(180):
        idx = vsc.get_index(vs_endpoint_name, vs_index_fullname).describe()
        index_status = idx.get('status', idx.get('index_status', {}))
        status = index_status.get('detailed_state', index_status.get('status', 'UNKNOWN')).upper()
        url = index_status.get('index_url', index_status.get('url', 'UNKNOWN'))
        if "ONLINE" in status:
            return
        if "UNKNOWN" in status:
            print(f"Can't get the status - will assume the index is ready {idx} - url: {url}")
            return
        elif "PROVISIONING" in status:
            if i % 40 == 0:
                print(f"Waiting for index to be ready, this can take a few min... {index_status} - pipeline url: {url}")
            time.sleep(10)
        else:
            raise Exception(f'''Error with the index - this shouldn't happen. The DLT pipeline might have been killed.\nPlease delete it and re-run the previous cell: vsc.delete_index("{vs_endpoint_name}", "{vs_index_fullname}")\nIndex details: {idx}''')
    raise Exception(f"Timeout, your index isn't ready yet: {vsc.get_index(vs_endpoint_name, vs_index_fullname)}")
def create_or_update_direct_index(vsc, vs_endpoint_name, vs_index_fullname, vector_search_index_schema, vector_search_index_config):
    '''Create a direct access vector search index if it doesn't exist (reusing it if it does), then wait for it to be ready.'''
    try:
        vsc.create_direct_access_index(
            endpoint_name=vs_endpoint_name,
            index_name=vs_index_fullname,
            schema=vector_search_index_schema,
            **vector_search_index_config
        )
    except Exception as e:
        if 'RESOURCE_ALREADY_EXISTS' not in str(e):
            print('Unexpected error creating the index...')
            raise e
    wait_for_index_to_be_ready(vsc, vs_endpoint_name, vs_index_fullname)
    print(f"index {vs_index_fullname} is ready")
########################################################################
###### Functions for deploying a chain in Model Serving
########################################################################
def get_latest_model_version(model_name):
    '''Get the latest model version for a given model name.'''
    mlflow_client = MlflowClient(registry_uri="databricks-uc")
    latest_version = 1
    for mv in mlflow_client.search_model_versions(f"name='{model_name}'"):
        version_int = int(mv.version)
        if version_int > latest_version:
            latest_version = version_int
    return latest_version
def deploy_model_serving_endpoint(
    spark,
    model_full_name,
    catalog,
    logging_schema,
    endpoint_name,
    host,
    token,
):
    '''Deploy a model serving endpoint, creating it if needed or updating it in place.'''
    from mlflow.deployments import get_deploy_client
    client = get_deploy_client("databricks")
    _config = {
        "served_models": [{
            "model_name": model_full_name,
            "model_version": get_latest_model_version(model_full_name),
            "workload_type": "CPU",
            "workload_size": "Small",
            "scale_to_zero_enabled": True,
            "environment_vars": {
                "DATABRICKS_HOST": host,
                "DATABRICKS_TOKEN": token,
                "ENABLE_MLFLOW_TRACING": "true",
            }
        }],
        "auto_capture_config": {
            "catalog_name": catalog,
            "schema_name": logging_schema,
            "table_name_prefix": endpoint_name,
        }
    }
    try:
        client.get_endpoint(endpoint_name)
        endpoint = client.update_endpoint(
            endpoint=endpoint_name,
            config=_config,
        )
    except Exception:
        # Make sure the schema for the inference table exists
        _ = spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog}.{logging_schema}")
        # Make sure to drop the inference table if it exists
        _ = spark.sql(f"DROP TABLE IF EXISTS {catalog}.{logging_schema}.`{endpoint_name}_payload`")
        endpoint = client.create_endpoint(
            name=endpoint_name,
            config=_config,
        )
    return endpoint
def wait_for_model_serving_endpoint_to_be_ready(endpoint_name):
    '''Wait for a model serving endpoint to be ready.'''
    from databricks.sdk import WorkspaceClient
    from databricks.sdk.service.serving import EndpointStateReady, EndpointStateConfigUpdate
    import time

    # Poll the endpoint state until the config update finishes or the endpoint is ready
    w = WorkspaceClient()
    state = ""
    for i in range(400):
        state = w.serving_endpoints.get(endpoint_name).state
        if state.config_update == EndpointStateConfigUpdate.IN_PROGRESS:
            if i % 40 == 0:
                print(f"Waiting for endpoint to deploy {endpoint_name}. Current state: {state}")
            time.sleep(10)
        elif state.ready == EndpointStateReady.READY:
            print('endpoint ready.')
            return
        else:
            break
    raise Exception(f"Couldn't start the endpoint, timeout, please check your endpoint for more details: {state}")
def send_request_to_endpoint(endpoint_name, data):
    '''Send a request to a model serving endpoint.'''
    from mlflow.deployments import get_deploy_client
    client = get_deploy_client("databricks")
    response = client.predict(endpoint=endpoint_name, inputs=data)
    return response
def delete_model_serving_endpoint(endpoint_name):
    '''Delete a model serving endpoint.'''
    from mlflow.deployments import get_deploy_client
    client = get_deploy_client("databricks")
    client.delete_endpoint(endpoint_name)
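# Illustrative usage sketch for the serving helpers above (not called anywhere).
# The model name, endpoint name, and logging schema are hypothetical assumptions,
# not values defined in this file; host and token must come from the caller.
def example_deploy_and_query(spark, host, token):
    '''Minimal sketch: deploy a registered chain, wait until ready, send one request, then clean up.'''
    model_full_name = "main.rag_demo.chat_chain"  # hypothetical UC model (catalog.schema.model)
    endpoint_name = "rag_demo_endpoint"           # hypothetical serving endpoint name
    deploy_model_serving_endpoint(spark, model_full_name, "main", "rag_demo_logging", endpoint_name, host, token)
    wait_for_model_serving_endpoint_to_be_ready(endpoint_name)
    # Chat-style payload; the exact input schema depends on the logged model's signature
    response = send_request_to_endpoint(endpoint_name, {"messages": [{"role": "user", "content": "Hello!"}]})
    print(response)
    delete_model_serving_endpoint(endpoint_name)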