From bd5870c02f0f02384edd5e3464a0dd2bce466647 Mon Sep 17 00:00:00 2001
From: Ciprian Mandache
Date: Tue, 2 Jul 2024 02:39:53 +0300
Subject: [PATCH] refactor + add text generation webui internal client

---
 README.md                                     |   1 +
 src/ezpyai/_constants.py                      |  32 ++-
 src/ezpyai/_logger.py                         |   4 +-
 .../llm/knowledge/_knowledge_gatherer.py      |  13 +-
 src/ezpyai/llm/knowledge/chroma_db.py         |   8 +-
 src/ezpyai/llm/knowledge/knowledge_item.py    |  16 +-
 .../llm/providers/_http_clients/__init__.py   |   0
 .../_http_clients/text_generation_web_ui.py   | 266 ++++++++++++++++++
 src/ezpyai/llm/providers/_llm_provider.py     |   4 -
 src/ezpyai/llm/providers/exceptions.py        |   5 +
 src/ezpyai/llm/providers/openai.py            |  22 +-
 .../llm/providers/text_generation_web_ui.py   |  49 +++-
 12 files changed, 379 insertions(+), 41 deletions(-)
 create mode 100644 src/ezpyai/llm/providers/_http_clients/__init__.py
 create mode 100644 src/ezpyai/llm/providers/_http_clients/text_generation_web_ui.py
 create mode 100644 src/ezpyai/llm/providers/exceptions.py

diff --git a/README.md b/README.md
index 196ab1b..5b53565 100644
--- a/README.md
+++ b/README.md
@@ -44,3 +44,4 @@ TODO
 - prompt - add prompt enhancer
 - prompt - add prompt compression using LLMLingua
 - llm provider - optionally use NuExtract text to json model to get structured response(like instead of embedding the instruction to the base llm, take the resp and send it to nuextract)
+- llm provider - text generation web ui - support multiple instances hosted in different locations
diff --git a/src/ezpyai/_constants.py b/src/ezpyai/_constants.py
index 689fc9a..5c341ef 100644
--- a/src/ezpyai/_constants.py
+++ b/src/ezpyai/_constants.py
@@ -1,15 +1,25 @@
-from typing import Dict
+LIB_NAME: str = "ezpyai"
 
-_LIB_NAME: str = "ezpyai"
+# ENVIRONMENT VARIABLES
+ENV_VAR_NAME_OPENAI_API_KEY: str = "OPENAI_API_KEY"
+ENV_VAR_NAME_OPENAI_ORGANIZATION: str = "OPENAI_ORGANIZATION"
+ENV_VAR_NAME_OPENAI_PROJECT: str = "OPENAI_PROJECT"
+ENV_VAR_NAME_TEXT_GENERATION_WEBUI_BASE_URL: str = "TEXT_GENERATION_WEBUI_BASE_URL"
+ENV_VAR_NAME_TEXT_GENERATION_WEBUI_API_KEY: str = "TEXT_GENERATION_WEBUI_API_KEY"
 
-_ENV_VAR_NAME_OPENAI_API_KEY: str = "OPENAI_API_KEY"
-_ENV_VAR_NAME_OPENAI_ORGANIZATION: str = "OPENAI_ORGANIZATION"
-_ENV_VAR_NAME_OPENAI_PROJECT: str = "OPENAI_PROJECT"
-_ENV_VAR_NAME_TEXT_GENERATION_WEBUI_API_KEY: str = "TEXT_GENERATION_WEBUI_API_KEY"
-_ENV_VAR_NAME_TEXT_GENERATION_WEBUI_BASE_URL: str = "TEXT_GENERATION_WEBUI_BASE_URL"
+# DICTIONARY KEYS
+DICT_KEY_ID: str = "id"
+DICT_KEY_METADATA: str = "metadata"
+DICT_KEY_CONTENT: str = "content"
+DICT_KEY_SUMMARY: str = "summary"
 
+# HTTP METHODS
+HTTP_METHOD_GET: str = "GET"
+HTTP_METHOD_POST: str = "POST"
 
-_DICT_KEY_ID: str = "id"
-_DICT_KEY_METADATA: str = "metadata"
-_DICT_KEY_CONTENT: str = "content"
-_DICT_KEY_SUMMARY: str = "summary"
+# HTTP CONTENT TYPES
+HTTP_CONTENT_TYPE_JSON: str = "application/json"
+
+# HTTP HEADERS
+HTTP_HEADER_CONTENT_TYPE: str = "Content-Type"
+HTTP_HEADER_AUTHORIZATION: str = "Authorization"
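
The renamed constants drop the leading underscore, making them part of the library's public surface. A minimal sketch of the intended downstream usage (the client code here is hypothetical, not part of the patch):

    import os

    from ezpyai._constants import (
        ENV_VAR_NAME_OPENAI_API_KEY,
        HTTP_HEADER_AUTHORIZATION,
        HTTP_HEADER_CONTENT_TYPE,
        HTTP_CONTENT_TYPE_JSON,
    )

    # Resolve credentials the same way the providers below do.
    api_key = os.getenv(ENV_VAR_NAME_OPENAI_API_KEY)

    # Build request headers from the shared HTTP constants.
    headers = {
        HTTP_HEADER_AUTHORIZATION: api_key,
        HTTP_HEADER_CONTENT_TYPE: HTTP_CONTENT_TYPE_JSON,
    }
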
diff --git a/src/ezpyai/_logger.py b/src/ezpyai/_logger.py
index 645bfab..a012e03 100644
--- a/src/ezpyai/_logger.py
+++ b/src/ezpyai/_logger.py
@@ -1,4 +1,4 @@
 import logging
-from ezpyai._constants import _LIB_NAME
+from ezpyai._constants import LIB_NAME
 
-logger = logging.getLogger(_LIB_NAME)
+logger = logging.getLogger(LIB_NAME)
diff --git a/src/ezpyai/llm/knowledge/_knowledge_gatherer.py b/src/ezpyai/llm/knowledge/_knowledge_gatherer.py
index 96481b5..a3c8bda 100644
--- a/src/ezpyai/llm/knowledge/_knowledge_gatherer.py
+++ b/src/ezpyai/llm/knowledge/_knowledge_gatherer.py
@@ -14,7 +14,7 @@
 from PyPDF2 import PdfReader
 from docx import Document
 from ezpyai._logger import logger
-from ezpyai._constants import _DICT_KEY_SUMMARY
+from ezpyai._constants import DICT_KEY_SUMMARY
 from ezpyai.llm.providers._llm_provider import LLMProvider
 from ezpyai.llm.prompt import Prompt, get_summarizer_prompt
 from ezpyai.llm.knowledge.knowledge_item import KnowledgeItem
@@ -45,7 +45,12 @@ class KnowledgeGatherer:
     """
 
     def __init__(self, summarizer: LLMProvider = None) -> None:
-        """Initialize the KnowledgeGatherer with an empty _items dictionary."""
+        """
+        Initialize the KnowledgeGatherer with an empty _items dictionary.
+
+        Args:
+            summarizer (LLMProvider, optional): The LLMProvider hosting the summarizer model to use for knowledge collection. Defaults to None.
+        """
         self._items: Dict[str, KnowledgeItem] = {}
         self._summarizer: LLMProvider = summarizer
 
@@ -114,8 +119,8 @@ def _summarize(self, knowledge_item: KnowledgeItem) -> None:
         prompt: Prompt = get_summarizer_prompt(knowledge_item.content)
 
         knowledge_item.summary = self._summarizer.get_structured_response(
-            prompt, response_format={_DICT_KEY_SUMMARY: ""}
-        )[_DICT_KEY_SUMMARY]
+            prompt, response_format={DICT_KEY_SUMMARY: ""}
+        )[DICT_KEY_SUMMARY]
 
         logger.debug(f"Summarized knowledge item: {knowledge_item}")
diff --git a/src/ezpyai/llm/knowledge/chroma_db.py b/src/ezpyai/llm/knowledge/chroma_db.py
index bb7ede2..37c12c4 100644
--- a/src/ezpyai/llm/knowledge/chroma_db.py
+++ b/src/ezpyai/llm/knowledge/chroma_db.py
@@ -3,7 +3,7 @@
 from typing import Dict, List
 
 from ezpyai._logger import logger
-from ezpyai._constants import _DICT_KEY_SUMMARY
+from ezpyai._constants import DICT_KEY_SUMMARY
 from ezpyai.llm.providers._llm_provider import LLMProvider
 from ezpyai.llm.knowledge._knowledge_db import BaseKnowledgeDB
 from ezpyai.llm.knowledge._knowledge_gatherer import KnowledgeGatherer
@@ -88,7 +88,7 @@ def store(
             logger.debug(f"Pre-processing item: {knowledge_item}")
 
             metadata = knowledge_item.metadata
-            metadata[_DICT_KEY_SUMMARY] = knowledge_item.summary
+            metadata[DICT_KEY_SUMMARY] = knowledge_item.summary
 
             document_ids.append(knowledge_item.id)
             documents.append(knowledge_item.content)
@@ -141,8 +141,8 @@ def search(
         for i in range(len(documents)):
             summary = ""
 
-            if _DICT_KEY_SUMMARY in metadatas[i]:
-                summary = metadatas[i].pop(_DICT_KEY_SUMMARY, "")
+            if DICT_KEY_SUMMARY in metadatas[i]:
+                summary = metadatas[i].pop(DICT_KEY_SUMMARY, "")
 
             knowledge_items.append(
                 KnowledgeItem(
diff --git a/src/ezpyai/llm/knowledge/knowledge_item.py b/src/ezpyai/llm/knowledge/knowledge_item.py
index 3fccc19..dd45c5e 100644
--- a/src/ezpyai/llm/knowledge/knowledge_item.py
+++ b/src/ezpyai/llm/knowledge/knowledge_item.py
@@ -1,9 +1,9 @@
 from typing import Dict
 
 from ezpyai._constants import (
-    _DICT_KEY_ID,
-    _DICT_KEY_METADATA,
-    _DICT_KEY_CONTENT,
-    _DICT_KEY_SUMMARY,
+    DICT_KEY_ID,
+    DICT_KEY_METADATA,
+    DICT_KEY_CONTENT,
+    DICT_KEY_SUMMARY,
 )
 
@@ -38,8 +38,8 @@ def __str__(self):
 
     def to_dict(self) -> Dict[str, str]:
         return {
-            _DICT_KEY_ID: self.id,
-            _DICT_KEY_METADATA: self.metadata,
-            _DICT_KEY_CONTENT: self.content,
-            _DICT_KEY_SUMMARY: self.summary,
+            DICT_KEY_ID: self.id,
+            DICT_KEY_METADATA: self.metadata,
+            DICT_KEY_CONTENT: self.content,
+            DICT_KEY_SUMMARY: self.summary,
         }
diff --git a/src/ezpyai/llm/providers/_http_clients/__init__.py b/src/ezpyai/llm/providers/_http_clients/__init__.py
new file mode 100644
index 0000000..e69de29
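
With the key constants now public, the serialized shape of a KnowledgeItem is easy to pin down. A hedged sketch of the round trip, assuming keyword construction mirroring the class's attributes (all values illustrative):

    from ezpyai.llm.knowledge.knowledge_item import KnowledgeItem

    item = KnowledgeItem(
        id="doc-1",                          # illustrative
        metadata={"source": "example.txt"},  # illustrative
        content="full text of the document",
        summary="one-line summary",
    )

    # to_dict() keys mirror DICT_KEY_ID/METADATA/CONTENT/SUMMARY:
    # {"id": ..., "metadata": ..., "content": ..., "summary": ...}
    print(item.to_dict())
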
diff --git a/src/ezpyai/llm/providers/_http_clients/text_generation_web_ui.py b/src/ezpyai/llm/providers/_http_clients/text_generation_web_ui.py
new file mode 100644
index 0000000..2f1b089
--- /dev/null
+++ b/src/ezpyai/llm/providers/_http_clients/text_generation_web_ui.py
@@ -0,0 +1,266 @@
+import requests
+
+from typing import List, Dict, Any, Optional, Union
+
+from ezpyai._constants import (
+    HTTP_METHOD_GET,
+    HTTP_METHOD_POST,
+    HTTP_CONTENT_TYPE_JSON,
+    HTTP_HEADER_CONTENT_TYPE,
+    HTTP_HEADER_AUTHORIZATION,
+)
+
+from ezpyai._logger import logger
+
+
+class HTTPClientTextGenerationWebUI:
+    """
+    A client for interacting with Text Generation Web UI's internal API endpoints.
+    """
+
+    # Endpoint constants
+    ENDPOINT_ENCODE = "/v1/internal/encode"
+    ENDPOINT_DECODE = "/v1/internal/decode"
+    ENDPOINT_TOKEN_COUNT = "/v1/internal/token-count"
+    ENDPOINT_LOGITS = "/v1/internal/logits"
+    ENDPOINT_CHAT_PROMPT = "/v1/internal/chat-prompt"
+    ENDPOINT_STOP_GENERATION = "/v1/internal/stop-generation"
+    ENDPOINT_MODEL_INFO = "/v1/internal/model/info"
+    ENDPOINT_MODEL_LIST = "/v1/internal/model/list"
+    ENDPOINT_MODEL_LOAD = "/v1/internal/model/load"
+    ENDPOINT_MODEL_UNLOAD = "/v1/internal/model/unload"
+    ENDPOINT_LORA_LIST = "/v1/internal/lora/list"
+    ENDPOINT_LORA_LOAD = "/v1/internal/lora/load"
+    ENDPOINT_LORA_UNLOAD = "/v1/internal/lora/unload"
+
+    def __init__(self, base_url: str, api_key: str):
+        """
+        Initialize the API client.
+
+        Args:
+            base_url (str): The base URL of the API.
+            api_key (str): The API key for authentication.
+        """
+
+        logger.debug(
+            f"Initializing HTTPClientTextGenerationWebUI with base_url={base_url}"
+        )
+
+        self.base_url = base_url
+        self.headers = {
+            HTTP_HEADER_AUTHORIZATION: api_key,
+            HTTP_HEADER_CONTENT_TYPE: HTTP_CONTENT_TYPE_JSON,
+        }
+
+    def _make_request(
+        self, method: str, endpoint: str, data: Optional[Dict[str, Any]] = None
+    ) -> Any:
+        """
+        Make an HTTP request to the API.
+
+        Args:
+            method (str): The HTTP method (GET, POST, etc.).
+            endpoint (str): The API endpoint.
+            data (Optional[Dict[str, Any]]): The request payload (for POST requests).
+
+        Returns:
+            Any: The JSON response from the API.
+        """
+
+        logger.debug(f"Making request to {self.base_url}{endpoint} with data={data}")
+
+        url = f"{self.base_url}{endpoint}"
+        response = requests.request(method, url, headers=self.headers, json=data)
+
+        logger.debug(
+            f"""Response:
+Status: {response.status_code}
+Content: {response.content}
+"""
+        )
+
+        response.raise_for_status()
+
+        return response.json()
+
+    def encode_tokens(self, text: str) -> Dict[str, Union[List[int], int]]:
+        """
+        Encode text into tokens.
+
+        Args:
+            text (str): The text to encode.
+
+        Returns:
+            Dict[str, Union[List[int], int]]: A dictionary containing the tokens and token count.
+        """
+
+        return self._make_request(
+            HTTP_METHOD_POST, self.ENDPOINT_ENCODE, {"text": text}
+        )
+
+    def decode_tokens(self, tokens: List[int]) -> Dict[str, str]:
+        """
+        Decode tokens back into text.
+
+        Args:
+            tokens (List[int]): The list of tokens to decode.
+
+        Returns:
+            Dict[str, str]: A dictionary containing the decoded text.
+        """
+
+        return self._make_request(
+            HTTP_METHOD_POST, self.ENDPOINT_DECODE, {"tokens": tokens}
+        )
+
+    def count_tokens(self, text: str) -> Dict[str, int]:
+        """
+        Count the number of tokens in the given text.
+
+        Args:
+            text (str): The text to count tokens for.
+
+        Returns:
+            Dict[str, int]: A dictionary containing the token count.
+        """
+
+        return self._make_request(
+            HTTP_METHOD_POST, self.ENDPOINT_TOKEN_COUNT, {"text": text}
+        )
+
+    def get_logits(self, prompt: str, **kwargs: Any) -> Dict[str, Dict[str, float]]:
+        """
+        Get the logits for the given prompt.
+
+        Args:
+            prompt (str): The input prompt.
+            **kwargs (Any): Additional parameters for the logits calculation.
+
+        Returns:
+            Dict[str, Dict[str, float]]: A dictionary containing the logits.
+        """
+
+        data = {"prompt": prompt, **kwargs}
+
+        return self._make_request(HTTP_METHOD_POST, self.ENDPOINT_LOGITS, data)
+
+    def get_chat_prompt(
+        self, messages: List[Dict[str, Any]], **kwargs: Any
+    ) -> Dict[str, str]:
+        """
+        Generate a chat prompt from the given messages.
+
+        Args:
+            messages (List[Dict[str, Any]]): A list of message dictionaries.
+            **kwargs (Any): Additional parameters for prompt generation.
+
+        Returns:
+            Dict[str, str]: A dictionary containing the generated prompt.
+        """
+
+        data = {"messages": messages, **kwargs}
+
+        return self._make_request(HTTP_METHOD_POST, self.ENDPOINT_CHAT_PROMPT, data)
+
+    def stop_generation(self) -> str:
+        """
+        Stop the current text generation process.
+
+        Returns:
+            str: A string indicating the result of the operation.
+        """
+
+        return self._make_request(HTTP_METHOD_POST, self.ENDPOINT_STOP_GENERATION)
+
+    def get_model_info(self) -> Dict[str, Union[str, List[str]]]:
+        """
+        Get information about the currently loaded model.
+
+        Returns:
+            Dict[str, Union[str, List[str]]]: A dictionary containing model information.
+        """
+
+        return self._make_request(HTTP_METHOD_GET, self.ENDPOINT_MODEL_INFO)
+
+    def list_models(self) -> List[str]:
+        """
+        Get a list of available models.
+
+        Returns:
+            List[str]: A list of model names.
+        """
+
+        return self._make_request(HTTP_METHOD_GET, self.ENDPOINT_MODEL_LIST)[
+            "model_names"
+        ]
+
+    def load_model(
+        self,
+        model_name: str,
+        args: Optional[Dict[str, Any]] = None,
+        settings: Optional[Dict[str, Any]] = None,
+    ) -> str:
+        """
+        Load a specific model.
+
+        Args:
+            model_name (str): The name of the model to load.
+            args (Optional[Dict[str, Any]]): Optional arguments for model loading.
+            settings (Optional[Dict[str, Any]]): Optional settings for the model.
+
+        Returns:
+            str: A string indicating the result of the operation.
+        """
+
+        data = {
+            "model_name": model_name,
+            "args": args or {},
+            "settings": settings or {},
+        }
+
+        return self._make_request(HTTP_METHOD_POST, self.ENDPOINT_MODEL_LOAD, data)
+
+    def unload_model(self) -> str:
+        """
+        Unload the currently loaded model.
+
+        Returns:
+            str: A string indicating the result of the operation.
+        """
+
+        return self._make_request(HTTP_METHOD_POST, self.ENDPOINT_MODEL_UNLOAD)
+
+    def list_loras(self) -> Dict[str, List[str]]:
+        """
+        Get a list of available LoRA adapters.
+
+        Returns:
+            Dict[str, List[str]]: A dictionary containing a list of LoRA names.
+        """
+
+        return self._make_request(HTTP_METHOD_GET, self.ENDPOINT_LORA_LIST)
+
+    def load_loras(self, lora_names: List[str]) -> str:
+        """
+        Load specific LoRA adapters.
+
+        Args:
+            lora_names (List[str]): A list of LoRA adapter names to load.
+
+        Returns:
+            str: A string indicating the result of the operation.
+        """
+
+        return self._make_request(
+            HTTP_METHOD_POST, self.ENDPOINT_LORA_LOAD, {"lora_names": lora_names}
+        )
+
+    def unload_loras(self) -> str:
+        """
+        Unload all currently loaded LoRA adapters.
+
+        Returns:
+            str: A string indicating the result of the operation.
+        """
+
+        return self._make_request(HTTP_METHOD_POST, self.ENDPOINT_LORA_UNLOAD)
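
A short usage sketch of the internal client above. The base URL and key are placeholders, and the exact response keys (e.g. "tokens") follow the upstream Text Generation Web UI internal API, so they are worth verifying against your instance:

    from ezpyai.llm.providers._http_clients.text_generation_web_ui import (
        HTTPClientTextGenerationWebUI,
    )

    client = HTTPClientTextGenerationWebUI(
        base_url="http://127.0.0.1:5000",  # placeholder
        api_key="...",                     # placeholder
    )

    encoded = client.encode_tokens("hello world")   # tokens + token count
    text = client.decode_tokens(encoded["tokens"])  # assumes a "tokens" key
    models = client.list_models()                   # plain list of model names
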
+ """ + + return self._make_request( + HTTP_METHOD_POST, self.ENDPOINT_TOKEN_COUNT, {"text": text} + ) + + def get_logits(self, prompt: str, **kwargs: Any) -> Dict[str, Dict[str, float]]: + """ + Get the logits for the given prompt. + + Args: + prompt (str): The input prompt. + **kwargs (Any): Additional parameters for the logits calculation. + + Returns: + Dict[str, Dict[str, float]]: A dictionary containing the logits. + """ + + data = {"prompt": prompt, **kwargs} + + return self._make_request(HTTP_METHOD_POST, self.ENDPOINT_LOGITS, data) + + def get_chat_prompt( + self, messages: List[Dict[str, Any]], **kwargs: Any + ) -> Dict[str, str]: + """ + Generate a chat prompt from the given messages. + + Args: + messages (List[Dict[str, Any]]): A list of message dictionaries. + **kwargs (Any): Additional parameters for prompt generation. + + Returns: + Dict[str, str]: A dictionary containing the generated prompt. + """ + + data = {"messages": messages, **kwargs} + + return self._make_request(HTTP_METHOD_POST, self.ENDPOINT_CHAT_PROMPT, data) + + def stop_generation(self) -> str: + """ + Stop the current text generation process. + + Returns: + str: A string indicating the result of the operation. + """ + + return self._make_request(HTTP_METHOD_POST, self.ENDPOINT_STOP_GENERATION) + + def get_model_info(self) -> Dict[str, Union[str, List[str]]]: + """ + Get information about the currently loaded model. + + Returns: + Dict[str, Union[str, List[str]]]: A dictionary containing model information. + """ + + return self._make_request(HTTP_METHOD_GET, self.ENDPOINT_MODEL_INFO) + + def list_models(self) -> List[str]: + """ + Get a list of available models. + + Returns: + List[str]: A list of model names. + """ + + return self._make_request(HTTP_METHOD_GET, self.ENDPOINT_MODEL_LIST)[ + "model_names" + ] + + def load_model( + self, + model_name: str, + args: Optional[Dict[str, Any]] = None, + settings: Optional[Dict[str, Any]] = None, + ) -> str: + """ + Load a specific model. + + Args: + model_name (str): The name of the model to load. + args (Optional[Dict[str, Any]]): Optional arguments for model loading. + settings (Optional[Dict[str, Any]]): Optional settings for the model. + + Returns: + str: A string indicating the result of the operation. + """ + + data = { + "model_name": model_name, + "args": args or {}, + "settings": settings or {}, + } + + return self._make_request(HTTP_METHOD_POST, self.ENDPOINT_MODEL_LOAD, data) + + def unload_model(self) -> str: + """ + Unload the currently loaded model. + + Returns: + str: A string indicating the result of the operation. + """ + + return self._make_request(HTTP_METHOD_POST, self.ENDPOINT_MODEL_UNLOAD) + + def list_loras(self) -> Dict[str, List[str]]: + """ + Get a list of available LoRA adapters. + + Returns: + Dict[str, List[str]]: A dictionary containing a list of LoRA names. + """ + + return self._make_request(HTTP_METHOD_GET, self.ENDPOINT_LORA_LIST) + + def load_loras(self, lora_names: List[str]) -> str: + """ + Load specific LoRA adapters. + + Args: + lora_names (List[str]): A list of LoRA adapter names to load. + + Returns: + str: A string indicating the result of the operation. + """ + + return self._make_request( + HTTP_METHOD_POST, self.ENDPOINT_LORA_LOAD, {"lora_names": lora_names} + ) + + def unload_loras(self) -> str: + """ + Unload all currently loaded LoRA adapters. + + Returns: + str: A string indicating the result of the operation. 
+ """ + + return self._make_request(HTTP_METHOD_POST, self.ENDPOINT_LORA_UNLOAD) diff --git a/src/ezpyai/llm/providers/_llm_provider.py b/src/ezpyai/llm/providers/_llm_provider.py index b1e1ec7..77307f1 100644 --- a/src/ezpyai/llm/providers/_llm_provider.py +++ b/src/ezpyai/llm/providers/_llm_provider.py @@ -20,10 +20,6 @@ def get_structured_response( ) -> Union[List, Dict]: pass - @abstractmethod - def remove_artifacts(self, response: str) -> str: - pass - class BaseLLMProvider(LLMProvider): def get_response(self, _: Prompt) -> str: diff --git a/src/ezpyai/llm/providers/exceptions.py b/src/ezpyai/llm/providers/exceptions.py new file mode 100644 index 0000000..252bffb --- /dev/null +++ b/src/ezpyai/llm/providers/exceptions.py @@ -0,0 +1,5 @@ +class UnsupportedModelError(Exception): + """Exception raised when an unsupported model is used.""" + + def __init__(self, message="Unsupported model", *args): + super().__init__(message, *args) diff --git a/src/ezpyai/llm/providers/openai.py b/src/ezpyai/llm/providers/openai.py index b87ade5..fa778b9 100644 --- a/src/ezpyai/llm/providers/openai.py +++ b/src/ezpyai/llm/providers/openai.py @@ -7,6 +7,12 @@ from ezpyai.llm.providers._llm_provider import BaseLLMProvider from ezpyai.llm.prompt import Prompt +from ezpyai._constants import ( + ENV_VAR_NAME_OPENAI_API_KEY, + ENV_VAR_NAME_OPENAI_ORGANIZATION, + ENV_VAR_NAME_OPENAI_PROJECT, +) + # Constants for OpenAI GPT models with context window sizes and specific versions MODEL_GPT_4O: str = ( @@ -47,15 +53,25 @@ class LLMProviderOpenAI(BaseLLMProvider): + def __init__( self, model: str = _DEFAULT_MODEL, temperature: float = _DEFAULT_TEMPERATURE, max_tokens: int = _DEFAULT_MAX_TOKENS, - api_key: str = os.getenv("OPENAI_API_KEY"), - organization: str = os.getenv("OPENAI_ORGANIZATION"), - project: str = os.getenv("OPENAI_PROJECT"), + api_key: str = None, + organization: str = None, + project: str = None, ) -> None: + if api_key is None: + api_key = os.getenv(ENV_VAR_NAME_OPENAI_API_KEY) + + if organization is None: + organization = os.getenv(ENV_VAR_NAME_OPENAI_ORGANIZATION) + + if project is None: + project = os.getenv(ENV_VAR_NAME_OPENAI_PROJECT) + self._client = _OpenAI( api_key=api_key, organization=organization, diff --git a/src/ezpyai/llm/providers/text_generation_web_ui.py b/src/ezpyai/llm/providers/text_generation_web_ui.py index a978941..2e189be 100644 --- a/src/ezpyai/llm/providers/text_generation_web_ui.py +++ b/src/ezpyai/llm/providers/text_generation_web_ui.py @@ -1,17 +1,25 @@ import os -import openai from openai import OpenAI +from typing import List + from ezpyai._constants import ( - _ENV_VAR_NAME_TEXT_GENERATION_WEBUI_API_KEY, - _ENV_VAR_NAME_TEXT_GENERATION_WEBUI_BASE_URL, + ENV_VAR_NAME_TEXT_GENERATION_WEBUI_API_KEY, + ENV_VAR_NAME_TEXT_GENERATION_WEBUI_BASE_URL, ) + +from ezpyai.llm.providers.exceptions import UnsupportedModelError + from ezpyai.llm.providers.openai import ( LLMProviderOpenAI, _DEFAULT_TEMPERATURE, _DEFAULT_MAX_TOKENS, ) +from ezpyai.llm.providers._http_clients.text_generation_web_ui import ( + HTTPClientTextGenerationWebUI, +) + class LLMProviderTextGenerationWebUI(LLMProviderOpenAI): """ @@ -23,16 +31,47 @@ class LLMProviderTextGenerationWebUI(LLMProviderOpenAI): def __init__( self, model: str, - base_url: str = os.getenv(_ENV_VAR_NAME_TEXT_GENERATION_WEBUI_BASE_URL), + base_url: str = None, + api_key: str = None, temperature: float = _DEFAULT_TEMPERATURE, max_tokens: int = _DEFAULT_MAX_TOKENS, - api_key: str = os.getenv(_ENV_VAR_NAME_TEXT_GENERATION_WEBUI_API_KEY), ) 
diff --git a/src/ezpyai/llm/providers/text_generation_web_ui.py b/src/ezpyai/llm/providers/text_generation_web_ui.py
index a978941..2e189be 100644
--- a/src/ezpyai/llm/providers/text_generation_web_ui.py
+++ b/src/ezpyai/llm/providers/text_generation_web_ui.py
@@ -1,17 +1,25 @@
 import os
-import openai
 
 from openai import OpenAI
 
+from typing import List
+
 from ezpyai._constants import (
-    _ENV_VAR_NAME_TEXT_GENERATION_WEBUI_API_KEY,
-    _ENV_VAR_NAME_TEXT_GENERATION_WEBUI_BASE_URL,
+    ENV_VAR_NAME_TEXT_GENERATION_WEBUI_API_KEY,
+    ENV_VAR_NAME_TEXT_GENERATION_WEBUI_BASE_URL,
 )
+
+from ezpyai.llm.providers.exceptions import UnsupportedModelError
+
 from ezpyai.llm.providers.openai import (
     LLMProviderOpenAI,
     _DEFAULT_TEMPERATURE,
     _DEFAULT_MAX_TOKENS,
 )
+
+from ezpyai.llm.providers._http_clients.text_generation_web_ui import (
+    HTTPClientTextGenerationWebUI,
+)
+
 
 class LLMProviderTextGenerationWebUI(LLMProviderOpenAI):
     """
@@ -23,16 +31,47 @@ class LLMProviderTextGenerationWebUI(LLMProviderOpenAI):
     def __init__(
         self,
         model: str,
-        base_url: str = os.getenv(_ENV_VAR_NAME_TEXT_GENERATION_WEBUI_BASE_URL),
+        base_url: str = None,
+        api_key: str = None,
         temperature: float = _DEFAULT_TEMPERATURE,
         max_tokens: int = _DEFAULT_MAX_TOKENS,
-        api_key: str = os.getenv(_ENV_VAR_NAME_TEXT_GENERATION_WEBUI_API_KEY),
     ) -> None:
+        if base_url is None:
+            base_url = os.getenv(ENV_VAR_NAME_TEXT_GENERATION_WEBUI_BASE_URL)
+
+        if api_key is None:
+            api_key = os.getenv(ENV_VAR_NAME_TEXT_GENERATION_WEBUI_API_KEY)
+
         self._client = OpenAI(
             base_url=base_url,
             api_key=api_key,
         )
 
+        self._internal_client = HTTPClientTextGenerationWebUI(
+            base_url=base_url,
+            api_key=api_key,
+        )
+
+        self._ensure_model_available(model=model)
+
         self._model = model
         self._temperature = temperature
         self._max_tokens = max_tokens
+
+    def _ensure_model_available(self, model: str):
+        """
+        Ensure that the given model is available.
+
+        Args:
+            model (str): The model name to check.
+
+        Raises:
+            UnsupportedModelError: If the model is not available.
+        """
+
+        available_models = self._internal_client.list_models()
+
+        if model not in available_models:
+            raise UnsupportedModelError(
+                f"Model {model} not available. Available models: {available_models}"
+            )
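
End to end, the provider now validates the requested model against the target instance before any completion call. A hedged sketch of the wiring (URL, key, and model name are placeholders):

    from ezpyai.llm.providers.text_generation_web_ui import (
        LLMProviderTextGenerationWebUI,
    )

    provider = LLMProviderTextGenerationWebUI(
        model="my-local-model",            # must appear in list_models()
        base_url="http://127.0.0.1:5000",  # placeholder
        api_key="...",                     # placeholder
    )
    # __init__ raises UnsupportedModelError if the model is not in the
    # /v1/internal/model/list response of the target instance.
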