feat: Add support for TEI API key authentication #11006

Merged
@@ -34,3 +34,11 @@ model_credential_schema:
     placeholder:
       zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080
       en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080
+  - variable: api_key
+    label:
+      en_US: API Key
+    type: secret-input
+    required: false
+    placeholder:
+      zh_Hans: 在此输入您的 API Key
+      en_US: Enter your API Key
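
The new credential is optional (required: false), so existing TEI setups without authentication keep working. Below is a minimal sketch, not part of the diff, of the header-building pattern the rest of this PR applies whenever the credential is present; the dict values are made-up placeholders:

    # Sketch of the pattern used throughout this PR.
    # The credential keys match the schema above; the values are placeholders.
    credentials = {
        "server_url": "http://192.168.1.100:8080",
        "api_key": "my-secret-key",  # may be missing or empty
    }

    headers = {"Content-Type": "application/json"}
    api_key = credentials.get("api_key")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    # headers is then forwarded to every TeiHelper call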
@@ -51,8 +51,13 @@ def _invoke(

         server_url = server_url.removesuffix("/")

+        headers = {"Content-Type": "application/json"}
+        api_key = credentials.get("api_key")
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
+
         try:
-            results = TeiHelper.invoke_rerank(server_url, query, docs)
+            results = TeiHelper.invoke_rerank(server_url, query, docs, headers)

             rerank_documents = []
             for result in results:
@@ -80,7 +85,11 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
         """
         try:
             server_url = credentials["server_url"]
-            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
+            headers = {"Content-Type": "application/json"}
+            api_key = credentials.get("api_key")
+            if api_key:
+                headers["Authorization"] = f"Bearer {api_key}"
+            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
             if extra_args.model_type != "reranker":
                 raise CredentialsValidateFailedError("Current model is not a rerank model")

@@ -26,13 +26,15 @@ def __init__(self, model_type: str, max_input_length: int, max_client_batch_size

 class TeiHelper:
     @staticmethod
-    def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
+    def get_tei_extra_parameter(
+        server_url: str, model_name: str, headers: Optional[dict] = None
+    ) -> TeiModelExtraParameter:
         TeiHelper._clean_cache()
         with cache_lock:
             if model_name not in cache:
                 cache[model_name] = {
                     "expires": time() + 300,
-                    "value": TeiHelper._get_tei_extra_parameter(server_url),
+                    "value": TeiHelper._get_tei_extra_parameter(server_url, headers),
                 }
         return cache[model_name]["value"]

@@ -47,7 +49,7 @@ def _clean_cache() -> None:
                 pass

     @staticmethod
-    def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
+    def _get_tei_extra_parameter(server_url: str, headers: Optional[dict] = None) -> TeiModelExtraParameter:
         """
         get tei model extra parameter like model_type, max_input_length, max_batch_requests
         """
@@ -61,7 +63,7 @@ def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
         session.mount("https://", HTTPAdapter(max_retries=3))

         try:
-            response = session.get(url, timeout=10)
+            response = session.get(url, headers=headers, timeout=10)
         except (MissingSchema, ConnectionError, Timeout) as e:
             raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}")
         if response.status_code != 200:
@@ -86,7 +88,7 @@ def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
         )

     @staticmethod
-    def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
+    def invoke_tokenize(server_url: str, texts: list[str], headers: Optional[dict] = None) -> list[list[dict]]:
         """
         Invoke tokenize endpoint

@@ -114,15 +116,15 @@ def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
         :param server_url: server url
         :param texts: texts to tokenize
         """
-        resp = httpx.post(
-            f"{server_url}/tokenize",
-            json={"inputs": texts},
-        )
+        url = f"{server_url}/tokenize"
+        json_data = {"inputs": texts}
+        resp = httpx.post(url, json=json_data, headers=headers)
+
         resp.raise_for_status()
         return resp.json()

     @staticmethod
-    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
+    def invoke_embeddings(server_url: str, texts: list[str], headers: Optional[dict] = None) -> dict:
         """
         Invoke embeddings endpoint

@@ -147,15 +149,14 @@ def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
         :param texts: texts to embed
         """
         # Use OpenAI compatible API here, which has usage tracking
-        resp = httpx.post(
-            f"{server_url}/v1/embeddings",
-            json={"input": texts},
-        )
+        url = f"{server_url}/v1/embeddings"
+        json_data = {"input": texts}
+        resp = httpx.post(url, json=json_data, headers=headers)
         resp.raise_for_status()
         return resp.json()

     @staticmethod
-    def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
+    def invoke_rerank(server_url: str, query: str, docs: list[str], headers: Optional[dict] = None) -> list[dict]:
         """
         Invoke rerank endpoint

@@ -173,10 +174,7 @@ def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
         :param candidates: candidates to rerank
         """
         params = {"query": query, "texts": docs, "return_text": True}
-
-        response = httpx.post(
-            server_url + "/rerank",
-            json=params,
-        )
+        url = f"{server_url}/rerank"
+        response = httpx.post(url, json=params, headers=headers)
         response.raise_for_status()
         return response.json()
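
With these changes, every TeiHelper entry point takes an optional headers dict, defaulting to None, and forwards it unchanged to the underlying HTTP call. A hedged usage sketch follows; the URL, key, and model name are placeholders, not values from this PR:

    # Sketch: calling the patched helpers against an authenticated TEI server.
    SERVER_URL = "http://192.168.1.100:8080"
    API_KEY = "my-secret-key"

    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_KEY}"}

    extra = TeiHelper.get_tei_extra_parameter(SERVER_URL, "my-model", headers)
    tokens = TeiHelper.invoke_tokenize(SERVER_URL, ["hello", "world"], headers)
    vectors = TeiHelper.invoke_embeddings(SERVER_URL, ["hello", "world"], headers)
    ranked = TeiHelper.invoke_rerank(SERVER_URL, "greeting", ["hello", "goodbye"], headers)

    # Omitting headers (i.e. passing None) preserves the old unauthenticated behaviour.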
@@ -51,6 +51,10 @@ def _invoke(

         server_url = server_url.removesuffix("/")

+        headers = {"Content-Type": "application/json"}
+        api_key = credentials["api_key"]
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
         # get model properties
         context_size = self._get_context_size(model, credentials)
         max_chunks = self._get_max_chunks(model, credentials)
@@ -60,7 +64,7 @@
         used_tokens = 0

         # get tokenized results from TEI
-        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
+        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts, headers)

         for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
             # Check if the number of tokens is larger than the context size
@@ -97,7 +101,7 @@
         used_tokens = 0
         for i in _iter:
             iter_texts = inputs[i : i + max_chunks]
-            results = TeiHelper.invoke_embeddings(server_url, iter_texts)
+            results = TeiHelper.invoke_embeddings(server_url, iter_texts, headers)
             embeddings = results["data"]
             embeddings = [embedding["embedding"] for embedding in embeddings]
             batched_embeddings.extend(embeddings)
@@ -127,7 +131,11 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int

         server_url = server_url.removesuffix("/")

-        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
+        headers = {
+            "Authorization": f"Bearer {credentials.get('api_key')}",
+        }
+
+        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts, headers)
         num_tokens = sum(len(tokens) for tokens in batch_tokens)
         return num_tokens

@@ -141,7 +149,14 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
         """
         try:
             server_url = credentials["server_url"]
-            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
+            headers = {"Content-Type": "application/json"}
+
+            api_key = credentials.get("api_key")
+
+            if api_key:
+                headers["Authorization"] = f"Bearer {api_key}"
+
+            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
             print(extra_args)
             if extra_args.model_type != "embedding":
                 raise CredentialsValidateFailedError("Current model is not a embedding model")
api/pytest.ini — 1 addition, 0 deletions
@@ -20,6 +20,7 @@ env =
     OPENAI_API_KEY = sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
     TEI_EMBEDDING_SERVER_URL = http://a.abc.com:11451
     TEI_RERANK_SERVER_URL = http://a.abc.com:11451
+    TEI_API_KEY = ttttttttttttttt
     UPSTAGE_API_KEY = up-aaaaaaaaaaaaaaaaaaaa
     VOYAGE_API_KEY = va-aaaaaaaaaaaaaaaaaaaa
     XINFERENCE_CHAT_MODEL_UID = chat
@@ -40,13 +40,15 @@ def test_validate_credentials(setup_tei_mock):
             model="reranker",
             credentials={
                 "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+                "api_key": os.environ.get("TEI_API_KEY", ""),
             },
         )

     model.validate_credentials(
         model=model_name,
         credentials={
             "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+            "api_key": os.environ.get("TEI_API_KEY", ""),
         },
     )

@@ -60,6 +62,7 @@ def test_invoke_model(setup_tei_mock):
         model=model_name,
         credentials={
             "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+            "api_key": os.environ.get("TEI_API_KEY", ""),
         },
         texts=["hello", "world"],
         user="abc-123",
@@ -40,13 +40,15 @@ def test_validate_credentials(setup_tei_mock):
             model="embedding",
             credentials={
                 "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+                "api_key": os.environ.get("TEI_API_KEY", ""),
             },
         )

     model.validate_credentials(
         model=model_name,
         credentials={
             "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+            "api_key": os.environ.get("TEI_API_KEY", ""),
         },
     )

@@ -61,6 +63,7 @@ def test_invoke_model(setup_tei_mock):
         model=model_name,
         credentials={
             "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+            "api_key": os.environ.get("TEI_API_KEY", ""),
         },
         query="Who is Kasumi?",
         docs=[
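
The bundled setup_tei_mock fixture only needs the extra credential to pass through; it does not assert that the Bearer token actually reaches the wire. A hypothetical test along these lines could, using pytest's monkeypatch; the import path and the whole test body are illustrative assumptions, not part of this PR:

    import httpx

    # Assumed import path for TeiHelper inside the Dify API package.
    from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper


    def test_invoke_tokenize_sends_bearer_token(monkeypatch):
        captured = {}

        def fake_post(url, json=None, headers=None):
            # Record the headers TeiHelper hands to httpx.post, then fake a reply.
            captured["headers"] = headers
            return httpx.Response(200, json=[[]], request=httpx.Request("POST", url))

        monkeypatch.setattr(httpx, "post", fake_post)

        headers = {"Authorization": "Bearer ttttttttttttttt"}
        TeiHelper.invoke_tokenize("http://a.abc.com:11451", ["hello"], headers)

        assert captured["headers"]["Authorization"] == "Bearer ttttttttttttttt"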