end implement support for text generation web ui
psyb0t committed Jul 2, 2024
1 parent bd5870c commit 24c9f3d
Showing 5 changed files with 157 additions and 38 deletions.
64 changes: 33 additions & 31 deletions src/ezpyai/llm/providers/_http_clients/text_generation_web_ui.py
@@ -19,19 +19,19 @@ class HTTPClientTextGenerationWebUI:
"""

# Endpoint constants
ENDPOINT_ENCODE = "/v1/internal/encode"
ENDPOINT_DECODE = "/v1/internal/decode"
ENDPOINT_TOKEN_COUNT = "/v1/internal/token-count"
ENDPOINT_LOGITS = "/v1/internal/logits"
ENDPOINT_CHAT_PROMPT = "/v1/internal/chat-prompt"
ENDPOINT_STOP_GENERATION = "/v1/internal/stop-generation"
ENDPOINT_MODEL_INFO = "/v1/internal/model/info"
ENDPOINT_MODEL_LIST = "/v1/internal/model/list"
ENDPOINT_MODEL_LOAD = "/v1/internal/model/load"
ENDPOINT_MODEL_UNLOAD = "/v1/internal/model/unload"
ENDPOINT_LORA_LIST = "/v1/internal/lora/list"
ENDPOINT_LORA_LOAD = "/v1/internal/lora/load"
ENDPOINT_LORA_UNLOAD = "/v1/internal/lora/unload"
_ENDPOINT_ENCODE = "/v1/internal/encode"
_ENDPOINT_DECODE = "/v1/internal/decode"
_ENDPOINT_TOKEN_COUNT = "/v1/internal/token-count"
_ENDPOINT_LOGITS = "/v1/internal/logits"
_ENDPOINT_CHAT_PROMPT = "/v1/internal/chat-prompt"
_ENDPOINT_STOP_GENERATION = "/v1/internal/stop-generation"
_ENDPOINT_MODEL_INFO = "/v1/internal/model/info"
_ENDPOINT_MODEL_LIST = "/v1/internal/model/list"
_ENDPOINT_MODEL_LOAD = "/v1/internal/model/load"
_ENDPOINT_MODEL_UNLOAD = "/v1/internal/model/unload"
_ENDPOINT_LORA_LIST = "/v1/internal/lora/list"
_ENDPOINT_LORA_LOAD = "/v1/internal/lora/load"
_ENDPOINT_LORA_UNLOAD = "/v1/internal/lora/unload"

def __init__(self, base_url: str, api_key: str):
"""
@@ -48,7 +48,7 @@ def __init__(self, base_url: str, api_key: str):

self.base_url = base_url
self.headers = {
HTTP_HEADER_AUTHORIZATION: api_key,
HTTP_HEADER_AUTHORIZATION: f"Bearer {api_key}",
HTTP_HEADER_CONTENT_TYPE: HTTP_CONTENT_TYPE_JSON,
}

@@ -95,7 +95,7 @@ def encode_tokens(self, text: str) -> Dict[str, Union[List[int], int]]:
"""

return self._make_request(
HTTP_METHOD_POST, self.ENDPOINT_ENCODE, {"text": text}
HTTP_METHOD_POST, self._ENDPOINT_ENCODE, {"text": text}
)

def decode_tokens(self, tokens: List[int]) -> Dict[str, str]:
@@ -110,7 +110,7 @@ def decode_tokens(self, tokens: List[int]) -> Dict[str, str]:
"""

return self._make_request(
HTTP_METHOD_POST, self.ENDPOINT_DECODE, {"tokens": tokens}
HTTP_METHOD_POST, self._ENDPOINT_DECODE, {"tokens": tokens}
)

def count_tokens(self, text: str) -> Dict[str, int]:
@@ -125,7 +125,7 @@ def count_tokens(self, text: str) -> Dict[str, int]:
"""

return self._make_request(
HTTP_METHOD_POST, self.ENDPOINT_TOKEN_COUNT, {"text": text}
HTTP_METHOD_POST, self._ENDPOINT_TOKEN_COUNT, {"text": text}
)

def get_logits(self, prompt: str, **kwargs: Any) -> Dict[str, Dict[str, float]]:
@@ -142,7 +142,7 @@ def get_logits(self, prompt: str, **kwargs: Any) -> Dict[str, Dict[str, float]]:

data = {"prompt": prompt, **kwargs}

return self._make_request(HTTP_METHOD_POST, self.ENDPOINT_LOGITS, data)
return self._make_request(HTTP_METHOD_POST, self._ENDPOINT_LOGITS, data)

def get_chat_prompt(
self, messages: List[Dict[str, Any]], **kwargs: Any
@@ -160,7 +160,7 @@ def get_chat_prompt(

data = {"messages": messages, **kwargs}

return self._make_request(HTTP_METHOD_POST, self.ENDPOINT_CHAT_PROMPT, data)
return self._make_request(HTTP_METHOD_POST, self._ENDPOINT_CHAT_PROMPT, data)

def stop_generation(self) -> str:
"""
@@ -170,7 +170,7 @@ def stop_generation(self) -> str:
str: A string indicating the result of the operation.
"""

return self._make_request(HTTP_METHOD_POST, self.ENDPOINT_STOP_GENERATION)
return self._make_request(HTTP_METHOD_POST, self._ENDPOINT_STOP_GENERATION)

def get_model_info(self) -> Dict[str, Union[str, List[str]]]:
"""
@@ -180,7 +180,7 @@ def get_model_info(self) -> Dict[str, Union[str, List[str]]]:
Dict[str, Union[str, List[str]]]: A dictionary containing model information.
"""

return self._make_request(HTTP_METHOD_GET, self.ENDPOINT_MODEL_INFO)
return self._make_request(HTTP_METHOD_GET, self._ENDPOINT_MODEL_INFO)

def list_models(self) -> List[str]:
"""
@@ -190,7 +190,7 @@ def list_models(self) -> List[str]:
List[str]: A list of model names.
"""

return self._make_request(HTTP_METHOD_GET, self.ENDPOINT_MODEL_LIST)[
return self._make_request(HTTP_METHOD_GET, self._ENDPOINT_MODEL_LIST)[
"model_names"
]

@@ -218,7 +218,7 @@ def load_model(
"settings": settings or {},
}

return self._make_request(HTTP_METHOD_POST, self.ENDPOINT_MODEL_LOAD, data)
return self._make_request(HTTP_METHOD_POST, self._ENDPOINT_MODEL_LOAD, data)

def unload_model(self) -> str:
"""
@@ -228,31 +228,33 @@ def unload_model(self) -> str:
str: A string indicating the result of the operation.
"""

return self._make_request(HTTP_METHOD_POST, self.ENDPOINT_MODEL_UNLOAD)
return self._make_request(HTTP_METHOD_POST, self._ENDPOINT_MODEL_UNLOAD)

def list_loras(self) -> Dict[str, List[str]]:
def list_loras(self) -> List[str]:
"""
Get a list of available LoRA adapters.
Returns:
Dict[str, List[str]]: A dictionary containing a list of LoRA names.
List[str]: A list of LoRA adapter names.
"""

return self._make_request(HTTP_METHOD_GET, self.ENDPOINT_LORA_LIST)
return self._make_request(HTTP_METHOD_GET, self._ENDPOINT_LORA_LIST)[
"lora_names"
]

def load_loras(self, lora_names: List[str]) -> str:
def load_loras(self, loras: List[str]) -> str:
"""
Load specific LoRA adapters.
Args:
lora_names (List[str]): A list of LoRA adapter names to load.
loras (List[str]): A list of LoRA adapter names to load.
Returns:
str: A string indicating the result of the operation.
"""

return self._make_request(
HTTP_METHOD_POST, self.ENDPOINT_LORA_LOAD, {"lora_names": lora_names}
HTTP_METHOD_POST, self._ENDPOINT_LORA_LOAD, {"lora_names": loras}
)

def unload_loras(self) -> str:
@@ -263,4 +265,4 @@ def unload_loras(self) -> str:
str: A string indicating the result of the operation.
"""

return self._make_request(HTTP_METHOD_POST, self.ENDPOINT_LORA_UNLOAD)
return self._make_request(HTTP_METHOD_POST, self._ENDPOINT_LORA_UNLOAD)
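
For orientation, a minimal usage sketch of the new HTTP client (not part of this diff): the base URL, API key, model name, and LoRA name below are placeholders assumed to exist on the target text-generation-webui server.

from ezpyai.llm.providers._http_clients.text_generation_web_ui import (
    HTTPClientTextGenerationWebUI,
)

# Placeholder server and key; the client itself prepends "Bearer " to the key.
client = HTTPClientTextGenerationWebUI(
    base_url="http://127.0.0.1:5000",
    api_key="my-secret-key",
)

print(client.list_models())                 # list of available model names
print(client.count_tokens("hello world"))   # token count for the given text
client.load_model(model_name="llama-3-8b")  # name must appear in list_models()
client.load_loras(loras=["my-lora"])        # names must appear in list_loras()
client.unload_loras()
client.unload_model()

As of this commit the endpoint constants are class-private (leading underscore) and the Authorization header is sent as a Bearer token, so callers pass the bare API key.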
61 changes: 59 additions & 2 deletions src/ezpyai/llm/providers/_llm_provider.py
@@ -1,8 +1,10 @@
import json
from abc import ABC, abstractmethod
from ezpyai.llm.prompt import Prompt
from typing import Union, Dict, List, Any

from ezpyai.llm.prompt import Prompt
from ezpyai.llm.providers.exceptions import JSONParseError


_STRUCTURED_RESPONSE_OUTPUT_INSTRUCTIONS = (
"Output instructions: your output must be JSON-formatted similar to the following:"
@@ -22,12 +24,40 @@ def get_structured_response(


class BaseLLMProvider(LLMProvider):
"""
Base class for LLM providers.
"""

@abstractmethod
def get_response(self, _: Prompt) -> str:
"""
Get the response for the given prompt.
Args:
prompt (Prompt): The input prompt.
Returns:
str: The response.
"""

return ""

def _validate_response_format(
self, data: Any, response_format: Union[Dict, List]
) -> bool:
"""
Validate the response format.
Args:
data (Any): The data to validate.
response_format (Union[Dict, List]): The response format.
Returns:
bool: True if the response format is valid, False otherwise.
"""

if not response_format:
return True
if isinstance(response_format, dict):
if not isinstance(data, dict):
return False
@@ -44,6 +74,16 @@ def _validate_response_format(
)

def remove_artifacts(self, response: str) -> str:
"""
Remove artifacts from the response.
Args:
response (str): The response to remove artifacts from.
Returns:
str: The response without artifacts.
"""

artifacts = ["```json", "```"]
for artifact in artifacts:
response = response.replace(artifact, "")
@@ -53,6 +93,19 @@ def remove_artifacts(self, response: str) -> str:
def get_structured_response(
self, prompt: Prompt, response_format: Union[List, Dict]
) -> Union[List, Dict]:
"""
Get the structured response for the given prompt.
Args:
prompt (Prompt): The input prompt.
response_format (Union[Dict, List]): The response format.
Returns:
Union[List, Dict]: The structured response.
Raises:
JSONParseError: If the response cannot be parsed as JSON.
"""
prompt = Prompt(
user_message=prompt.get_user_message(),
context=prompt.get_context(),
@@ -61,7 +114,11 @@ def get_structured_response(

response = self.remove_artifacts(self.get_response(prompt)).strip()

structured_resp = json.loads(response)
try:
structured_resp = json.loads(response)
except:
raise JSONParseError(f"Failed to parse structured response: {response}")

if self._validate_response_format(structured_resp, response_format):
return structured_resp

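
A short sketch of the new structured-response path (illustrative only, not from the diff; it assumes Prompt can be built from just a user_message and that BaseLLMProvider only requires get_response to be overridden).

from ezpyai.llm.prompt import Prompt
from ezpyai.llm.providers._llm_provider import BaseLLMProvider
from ezpyai.llm.providers.exceptions import JSONParseError


class CannedProvider(BaseLLMProvider):
    """Toy provider that always returns the same canned string."""

    def __init__(self, canned: str) -> None:
        self._canned = canned

    def get_response(self, _: Prompt) -> str:
        return self._canned


ok = CannedProvider('```json\n{"answer": 42}\n```')
# remove_artifacts() strips the ``` fences, then json.loads() parses the rest;
# an empty response_format makes _validate_response_format() accept anything.
print(ok.get_structured_response(Prompt(user_message="question"), {}))

broken = CannedProvider("this is not JSON")
try:
    broken.get_structured_response(Prompt(user_message="question"), {})
except JSONParseError as exc:
    print(exc)  # as of this commit, parse failures surface as JSONParseError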
14 changes: 14 additions & 0 deletions src/ezpyai/llm/providers/exceptions.py
@@ -3,3 +3,17 @@ class UnsupportedModelError(Exception):

def __init__(self, message="Unsupported model", *args):
super().__init__(message, *args)


class UnsupportedLoraError(Exception):
"""Exception raised when an unsupported lora is used."""

def __init__(self, message="Unsupported lora", *args):
super().__init__(message, *args)


class JSONParseError(Exception):
"""Exception raised when a JSON parse error occurs."""

def __init__(self, message="JSON parse error", *args):
super().__init__(message, *args)
2 changes: 1 addition & 1 deletion src/ezpyai/llm/providers/openai.py
@@ -48,7 +48,7 @@
)

_DEFAULT_MODEL: str = MODEL_GPT_3_5_TURBO
_DEFAULT_TEMPERATURE: float = 0.5
_DEFAULT_TEMPERATURE: float = 0.7
_DEFAULT_MAX_TOKENS: int = 150


54 changes: 50 additions & 4 deletions src/ezpyai/llm/providers/text_generation_web_ui.py
@@ -8,7 +8,7 @@
ENV_VAR_NAME_TEXT_GENERATION_WEBUI_BASE_URL,
)

from ezpyai.llm.providers.exceptions import UnsupportedModelError
from ezpyai.llm.providers.exceptions import UnsupportedModelError, UnsupportedLoraError

from ezpyai.llm.providers.openai import (
LLMProviderOpenAI,
@@ -24,13 +24,24 @@
class LLMProviderTextGenerationWebUI(LLMProviderOpenAI):
"""
LLM provider for Text Generation Web UI's OpenAI compatible API.
"""
_loaded_model: str = None
Args:
model (str): The model to use.
loras (List[str]): The loras to use.
base_url (str): The base URL of the API.
api_key (str): The API key for authentication.
temperature (float): The temperature to use.
max_tokens (int): The maximum number of tokens to generate.
Raises:
UnsupportedModelError: If the model is not supported.
UnsupportedLoraError: If any of the loras is not supported.
"""

def __init__(
self,
model: str,
loras: List[str] = None,
base_url: str = None,
api_key: str = None,
temperature: float = _DEFAULT_TEMPERATURE,
@@ -43,7 +54,7 @@ def __init__(
api_key = os.getenv(ENV_VAR_NAME_TEXT_GENERATION_WEBUI_API_KEY)

self._client = OpenAI(
base_url=base_url,
base_url=f"{base_url}/v1",
api_key=api_key,
)

@@ -52,12 +63,28 @@ def __init__(
api_key=api_key,
)

self._cleanup()

self._ensure_model_available(model=model)
self._ensure_loras_exist(loras=loras)

self._internal_client.load_model(model_name=model)

if loras:
self._internal_client.load_loras(loras=loras)

self._model = model
self._temperature = temperature
self._max_tokens = max_tokens

def _cleanup(self):
"""
Cleanup any resources used by the LLM provider.
"""

self._internal_client.unload_loras()
self._internal_client.unload_model()

def _ensure_model_available(self, model: str):
"""
Ensure that the given model is available.
@@ -75,3 +102,22 @@ def _ensure_model_available(self, model: str):
raise UnsupportedModelError(
f"Model {model} not available. Available models: {available_models}"
)

def _ensure_loras_exist(self, loras: List[str]):
"""
Ensure that the given loras exist.
Args:
loras (List[str]): The loras to check.
Raises:
UnsupportedModelError: If the loras are not available.
"""

available_loras = self._internal_client.list_loras()

for lora in loras:
if lora not in available_loras:
raise UnsupportedLoraError(
f"Lora {lora} not available. Available loras: {available_loras}"
)
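
Putting the pieces together, a hypothetical end-to-end use of the new provider (not part of this diff): the model, LoRA, URL, and key are placeholders, and the final call relies on the get_response inherited from LLMProviderOpenAI. When base_url/api_key are omitted, the imported ENV_VAR_NAME_TEXT_GENERATION_WEBUI_* variables are consulted instead.

from ezpyai.llm.prompt import Prompt
from ezpyai.llm.providers.text_generation_web_ui import LLMProviderTextGenerationWebUI

provider = LLMProviderTextGenerationWebUI(
    model="llama-3-8b-instruct",       # must be present on the webui server
    loras=["my-style-lora"],           # optional; validated against list_loras()
    base_url="http://127.0.0.1:5000",  # placeholder local instance
    api_key="my-secret-key",           # placeholder key
)

# Construction unloads whatever was previously loaded, checks that the model
# and LoRAs exist on the server, then loads them before the first request.
print(provider.get_response(Prompt(user_message="Say hello.")))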
