Skip to content

Commit

Permalink
add text-generation-webui llm provider + refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
psyb0t committed Jun 23, 2024
1 parent 6b70607 commit bd47c2b
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 16 deletions.
9 changes: 9 additions & 0 deletions src/ezpyai/_constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from typing import Dict

# Name of this library/package.
_LIB_NAME: str = "ezpyai"

# Names (not values) of the environment variables holding provider settings.
# The TEXT_GENERATION_WEBUI_* names are looked up with os.getenv by the
# Text Generation Web UI provider; the OPENAI_* names are presumably read by
# the OpenAI provider -- usage is outside this view, confirm there.
_ENV_VAR_NAME_OPENAI_API_KEY: str = "OPENAI_API_KEY"
_ENV_VAR_NAME_OPENAI_ORGANIZATION: str = "OPENAI_ORGANIZATION"
_ENV_VAR_NAME_OPENAI_PROJECT: str = "OPENAI_PROJECT"
_ENV_VAR_NAME_TEXT_GENERATION_WEBUI_API_KEY: str = "TEXT_GENERATION_WEBUI_API_KEY"
_ENV_VAR_NAME_TEXT_GENERATION_WEBUI_BASE_URL: str = "TEXT_GENERATION_WEBUI_BASE_URL"


# Dictionary key names shared across the package (usage sites are outside
# this view).
_DICT_KEY_ID: str = "id"
_DICT_KEY_METADATA: str = "metadata"
_DICT_KEY_CONTENT: str = "content"
Expand Down
15 changes: 12 additions & 3 deletions src/ezpyai/llm/knowledge/_knowledge_db.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List
from abc import ABC, abstractmethod
from ezpyai.llm._llm import LLM
from ezpyai.llm.providers._llm_provider import LLMProvider
from ezpyai.llm.knowledge.knowledge_item import KnowledgeItem


Expand Down Expand Up @@ -43,8 +43,17 @@ def get_dsn(self) -> str:
def destroy(self) -> None:
pass

def store(self, collection: str, data_path: str) -> None:
def store(
self,
collection: str,
data_path: str,
summarizer: LLMProvider = None,
) -> None:
pass

def search(self, collection: str, query: str) -> List[KnowledgeItem]:
def search(
self,
collection: str,
query: str,
) -> List[KnowledgeItem]:
pass
12 changes: 6 additions & 6 deletions src/ezpyai/llm/knowledge/_knowledge_gatherer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from docx import Document
from ezpyai._logger import logger
from ezpyai._constants import _DICT_KEY_SUMMARY
from ezpyai.llm._llm import LLM
from ezpyai.llm.providers._llm_provider import LLMProvider
from ezpyai.llm.prompt import Prompt, get_summarizer_prompt
from ezpyai.llm.knowledge.knowledge_item import KnowledgeItem

Expand All @@ -31,6 +31,7 @@
_MIMETYPE_XML = "text/xml"


# TODO: implement semantic chunking
class KnowledgeGatherer:
"""
A class to gather knowledge from files within a directory or from a single file.
Expand All @@ -40,16 +41,15 @@ class KnowledgeGatherer:
It adds each file's data to the _items dictionary with its SHA256 hash as the key.
Attributes:
_items (Dict[str, KnowledgeItem]): A dictionary containing file paths
and their processed content indexed by SHA256 hashes of the content.
_summarizer (LLM): The LLM summarizer to use for knowledge collection.
_items (Dict[str, KnowledgeItem]): A dictionary of KnowledgeItem objects indexed by SHA256 hashes of their content.
_summarizer (LLMProvider): The LLMProvider hosting the summarizer model to use for knowledge collection.
"""

def __init__(self, summarizer: LLM = None) -> None:
def __init__(self, summarizer: LLMProvider = None) -> None:
"""Initialize the KnowledgeGatherer with an empty _items dictionary."""

self._items: Dict[str, KnowledgeItem] = {}
self._summarizer: LLM = summarizer
self._summarizer: LLMProvider = summarizer

logger.debug("KnowledgeGatherer initialized with an empty _items dictionary.")

Expand Down
8 changes: 5 additions & 3 deletions src/ezpyai/llm/knowledge/chroma_db.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import chromadb
import json
import chromadb.utils.embedding_functions as ef

from typing import Dict, List
from ezpyai._logger import logger
from ezpyai._constants import _DICT_KEY_SUMMARY
from ezpyai.llm._llm import LLM
from ezpyai.llm.providers._llm_provider import LLMProvider
from ezpyai.llm.knowledge._knowledge_db import BaseKnowledgeDB
from ezpyai.llm.knowledge._knowledge_gatherer import KnowledgeGatherer
from ezpyai.llm.knowledge.knowledge_item import KnowledgeItem
Expand Down Expand Up @@ -50,13 +49,16 @@ def destroy(self) -> None:

self._client.reset()

def store(self, collection: str, data_path: str, summarizer: LLM = None) -> None:
def store(
self, collection: str, data_path: str, summarizer: LLMProvider = None
) -> None:
"""
Store the data in the given collection.
Args:
collection (str): The name of the collection.
data_path (str): The path to the data.
summarizer (LLMProvider): The LLMProvider summarizer to use for knowledge collection.
"""
logger.debug(f"Storing data in collection: {collection} from: {data_path}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
)


class LLM(ABC):
class LLMProvider(ABC):
@abstractmethod
def get_response(self, prompt: Prompt) -> str:
pass
Expand All @@ -25,7 +25,7 @@ def remove_artifacts(self, response: str) -> str:
pass


class BaseLLM(LLM):
class BaseLLMProvider(LLMProvider):
def get_response(self, _: Prompt) -> str:
return ""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Annotated
from openai import OpenAI as _OpenAI
from ezpyai._logger import logger
from ezpyai.llm._llm import BaseLLM
from ezpyai.llm.providers._llm_provider import BaseLLMProvider
from ezpyai.llm.prompt import Prompt


Expand Down Expand Up @@ -46,7 +46,7 @@
_DEFAULT_MAX_TOKENS: int = 150


class OpenAI(BaseLLM):
class LLMProviderOpenAI(BaseLLMProvider):
def __init__(
self,
model: str = _DEFAULT_MODEL,
Expand Down
38 changes: 38 additions & 0 deletions src/ezpyai/llm/providers/text_generation_web_ui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os
from sys import api_version

import openai
from openai import OpenAI
from ezpyai._constants import (
_ENV_VAR_NAME_TEXT_GENERATION_WEBUI_API_KEY,
_ENV_VAR_NAME_TEXT_GENERATION_WEBUI_BASE_URL,
)
from ezpyai.llm.providers.openai import (
LLMProviderOpenAI,
_DEFAULT_TEMPERATURE,
_DEFAULT_MAX_TOKENS,
)


class LLMProviderTextGenerationWebUI(LLMProviderOpenAI):
    """
    LLM provider for Text Generation Web UI's OpenAI compatible API.

    The provider reuses the behaviour of ``LLMProviderOpenAI`` but talks to a
    client built from a custom ``base_url`` / ``api_key`` pair.
    """

    def __init__(
        self,
        model: str,
        base_url: str = None,
        temperature: float = _DEFAULT_TEMPERATURE,
        max_tokens: int = _DEFAULT_MAX_TOKENS,
        api_key: str = None,
    ) -> None:
        """
        Initialize the provider.

        Args:
            model (str): The name of the model to use.
            base_url (str): Base URL of the OpenAI compatible API. When None,
                the TEXT_GENERATION_WEBUI_BASE_URL environment variable is
                read at call time.
            temperature (float): The sampling temperature stored on the
                provider.
            max_tokens (int): The max tokens value stored on the provider.
            api_key (str): API key for the server. When None, the
                TEXT_GENERATION_WEBUI_API_KEY environment variable is read
                at call time.
        """

        # Resolve the environment here rather than in the signature: default
        # argument expressions are evaluated once, when the module is
        # imported, so variables exported later (e.g. loaded from a .env
        # file) would otherwise be silently ignored.
        if base_url is None:
            base_url = os.getenv(_ENV_VAR_NAME_TEXT_GENERATION_WEBUI_BASE_URL)

        if api_key is None:
            api_key = os.getenv(_ENV_VAR_NAME_TEXT_GENERATION_WEBUI_API_KEY)

        # NOTE(review): this assigns a process-wide attribute on the openai
        # module, not on the client created below. Kept so existing behaviour
        # is unchanged -- confirm whether it is still required.
        openai.api_version = "2023-05-15"

        # The parent __init__ is intentionally not invoked (as in the
        # original code); presumably it would build a client from the OpenAI
        # settings instead -- TODO confirm against LLMProviderOpenAI.
        self._client = OpenAI(
            base_url=base_url,
            api_key=api_key,
        )

        self._model = model
        self._temperature = temperature
        self._max_tokens = max_tokens

0 comments on commit bd47c2b

Please sign in to comment.