Http #6

Merged
merged 5 commits on Jun 21, 2024
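
This PR replaces the per-provider SDK clients (openai, anthropic, cohere, groq, google.generativeai, replicate) used inside `LLMClient` with a single `requests`-based helper, `llm_post`, added in `l2m2/_internal/http.py`. Provider endpoints and auth headers move into `PROVIDER_INFO` in `l2m2/model_info.py`, with `<<API_KEY>>` and `<<MODEL_ID>>` placeholders substituted at request time.
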
2 changes: 0 additions & 2 deletions .github/workflows/tests.yml
@@ -3,8 +3,6 @@ name: tests
on:
pull_request:
push:
branches:
- main
release:

env:
1 change: 0 additions & 1 deletion Makefile
@@ -8,7 +8,6 @@ default: lint typecheck test
init:
pip install --upgrade pip
pip install -r requirements-dev.txt
pip install -r requirements.txt

test:
pytest -v --cov=l2m2 --cov=test_utils --cov-report=term-missing --failed-first --durations=0
3 changes: 3 additions & 0 deletions l2m2/_internal/__init__.py
@@ -0,0 +1,3 @@
from .http import llm_post

__all__ = ["llm_post"]
61 changes: 61 additions & 0 deletions l2m2/_internal/http.py
@@ -0,0 +1,61 @@
from typing import Optional, Dict, Any
import requests

from l2m2.model_info import API_KEY, MODEL_ID, PROVIDER_INFO


def _get_headers(provider: str, api_key: str) -> Dict[str, str]:
provider_info = PROVIDER_INFO[provider]
headers = provider_info["headers"].copy()
return {key: value.replace(API_KEY, api_key) for key, value in headers.items()}


def _handle_replicate_201(response: requests.Response, api_key: str) -> Any:
# See https://replicate.com/docs/reference/http#models.versions.get
resource = response.json()
if "status" in resource and "urls" in resource and "get" in resource["urls"]:

while resource["status"] != "succeeded":
if resource["status"] == "failed" or resource["status"] == "cancelled":
raise Exception(resource)

next_response = requests.get(
resource["urls"]["get"],
headers=_get_headers("replicate", api_key),
)

if next_response.status_code != 200:
raise Exception(next_response.text)

resource = next_response.json()

return resource

else:
raise Exception(resource)


def llm_post(
provider: str,
api_key: str,
data: Dict[str, Any],
model_id: Optional[str] = None,
) -> Any:
endpoint = PROVIDER_INFO[provider]["endpoint"]
endpoint = endpoint.replace(API_KEY, api_key)
if model_id is not None:
endpoint = endpoint.replace(MODEL_ID, model_id)

response = requests.post(
endpoint,
headers=_get_headers(provider, api_key),
json=data,
)

if provider == "replicate" and response.status_code == 201:
return _handle_replicate_201(response, api_key)

if response.status_code != 200:
raise Exception(response.text)

return response.json()
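
A minimal sketch of calling the new `llm_post` helper directly (the API key and model ID below are placeholders; within this PR the helper is only called from `LLMClient`'s private `_call_*` methods). Note that a Replicate 201 response is handled separately by polling the prediction URL until it succeeds, as in `_handle_replicate_201` above.

```python
# Hypothetical standalone use of llm_post; the key and model are placeholders.
from l2m2._internal.http import llm_post

result = llm_post(
    "openai",                              # any provider key in PROVIDER_INFO
    "sk-placeholder",                      # assumed API key, not a real one
    {
        "model": "gpt-4o",                 # assumed OpenAI model ID
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 0.5,
    },
)
# llm_post returns the decoded JSON body, so the shape matches the raw API response.
print(result["choices"][0]["message"]["content"])
```
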
107 changes: 52 additions & 55 deletions l2m2/client/llm_client.py
@@ -1,11 +1,4 @@
from typing import Any, Set, Dict, Optional, Tuple

import google.generativeai as google
from cohere import Client as CohereClient
from openai import OpenAI
from anthropic import Anthropic
from groq import Groq
import replicate
from typing import Any, List, Set, Dict, Optional, Tuple

from l2m2.model_info import (
MODEL_INFO,
@@ -29,6 +22,7 @@
get_extra_message,
run_json_strats_out,
)
from l2m2._internal.http import llm_post


class LLMClient:
@@ -131,11 +125,14 @@ def add_provider(self, provider: str, api_key: str) -> None:

Raises:
ValueError: If the provider is not one of the available providers.
ValueError: If the API key is not a string.
"""
if provider not in (providers := self.get_available_providers()):
raise ValueError(
f"Invalid provider: {provider}. Available providers: {providers}"
)
if not isinstance(api_key, str):
raise ValueError(f"API key for provider {provider} must be a string.")

self.api_keys[provider] = api_key
self.active_providers.add(provider)
@@ -462,19 +459,18 @@ def _call_openai(
params: Dict[str, Any],
*_: Any, # json_mode and json_mode_strategy are not used here
) -> str:
oai = OpenAI(api_key=self.api_keys["openai"])
messages = []
if system_prompt is not None:
messages.append({"role": "system", "content": system_prompt})
if isinstance(self.memory, ChatMemory):
messages.extend(self.memory.unpack("role", "content", "user", "assistant"))
messages.append({"role": "user", "content": prompt})
result = oai.chat.completions.create(
model=model_id,
messages=messages, # type: ignore
**params,
result = llm_post(
"openai",
self.api_keys["openai"],
{"model": model_id, "messages": messages, **params},
)
return str(result.choices[0].message.content)
return str(result["choices"][0]["message"]["content"])

def _call_anthropic(
self,
@@ -485,7 +481,6 @@ def _call_anthropic(
json_mode: bool,
json_mode_strategy: JsonModeStrategy,
) -> str:
anthr = Anthropic(api_key=self.api_keys["anthropic"])
if system_prompt is not None:
params["system"] = system_prompt
messages = []
@@ -498,12 +493,12 @@
if append_msg:
messages.append({"role": "assistant", "content": append_msg})

result = anthr.messages.create(
model=model_id,
messages=messages, # type: ignore
**params,
result = llm_post(
"anthropic",
self.api_keys["anthropic"],
{"model": model_id, "messages": messages, **params},
)
return str(result.content[0].text)
return str(result["content"][0]["text"])

def _call_cohere(
self,
@@ -514,7 +509,6 @@ def _call_cohere(
json_mode: bool,
json_mode_strategy: JsonModeStrategy,
) -> str:
cohere = CohereClient(api_key=self.api_keys["cohere"])
if system_prompt is not None:
params["preamble"] = system_prompt
if isinstance(self.memory, ChatMemory):
@@ -528,12 +522,12 @@
entry = {"role": "CHATBOT", "message": append_msg}
params.setdefault("chat_history", []).append(entry)

result = cohere.chat(
model=model_id,
message=prompt,
**params,
result = llm_post(
"cohere",
self.api_keys["cohere"],
{"model": model_id, "message": prompt, **params},
)
return str(result.text)
return str(result["text"])

def _call_groq(
self,
@@ -544,7 +538,6 @@ def _call_groq(
json_mode: bool,
json_mode_strategy: JsonModeStrategy,
) -> str:
groq = Groq(api_key=self.api_keys["groq"])
messages = []
if system_prompt is not None:
messages.append({"role": "system", "content": system_prompt})
@@ -557,12 +550,12 @@
if append_msg:
messages.append({"role": "assistant", "content": append_msg})

result = groq.chat.completions.create(
model=model_id,
messages=messages, # type: ignore
**params,
result = llm_post(
"groq",
self.api_keys["groq"],
{"model": model_id, "messages": messages, **params},
)
return str(result.choices[0].message.content)
return str(result["choices"][0]["message"]["content"])

def _call_google(
self,
@@ -572,28 +565,32 @@ def _call_google(
params: Dict[str, Any],
*_: Any, # json_mode and json_mode_strategy are not used here
) -> str:
google.configure(api_key=self.api_keys["google"])
data: Dict[str, Any] = {}

model_params = {"model_name": model_id}
if system_prompt is not None:
# Earlier versions don't support system prompts, so prepend it to the prompt
if model_id not in ["gemini-1.5-pro-latest"]:
# Earlier models don't support system prompts, so prepend it to the prompt
if model_id not in ["gemini-1.5-pro"]:
prompt = f"{system_prompt}\n{prompt}"
else:
model_params["system_instruction"] = system_prompt
model = google.GenerativeModel(**model_params)
data["system_instruction"] = {"parts": {"text": system_prompt}}

messages = []
messages: List[Dict[str, Any]] = []
if isinstance(self.memory, ChatMemory):
messages.extend(self.memory.unpack("role", "parts", "user", "model"))
messages.append({"role": "user", "parts": prompt})
mem_items = self.memory.unpack("role", "parts", "user", "model")
# Need to do this wrap – see https://ai.google.dev/api/rest/v1beta/cachedContents#Part
messages.extend([{**m, "parts": {"text": m["parts"]}} for m in mem_items])

messages.append({"role": "user", "parts": {"text": prompt}})

result = model.generate_content(messages, generation_config=params)
result = result.candidates[0]
data["contents"] = messages
data["generation_config"] = params

result = llm_post("google", self.api_keys["google"], data, model_id=model_id)
result = result["candidates"][0]

# Will sometimes fail due to safety filters
if result.content:
return str(result.content.parts[0].text)
if "content" in result:
return str(result["content"]["parts"][0]["text"])
else:
return str(result)

@@ -608,25 +605,25 @@ def _call_replicate(
) -> str:
if isinstance(self.memory, ChatMemory):
raise ValueError(
"Chat memory is not supported with Replicate. Try using Groq."
"Chat memory is not supported with Replicate."
+ " Try using Groq, or using ExternalMemory instead."
)
if json_mode_strategy.strategy_name == StrategyName.PREPEND:
raise ValueError(
"JsonModeStrategy.prepend() is not supported with Replicate."
+ "Try using Groq, or using JsonModeStrategy.strip() instead."
+ " Try using Groq, or using JsonModeStrategy.strip() instead."
)

client = replicate.Client(api_token=self.api_keys["replicate"])
if system_prompt is not None:
params["system_prompt"] = system_prompt
result = client.run(
model_id,
input={
"prompt": prompt,
**params,
},

result = llm_post(
"replicate",
self.api_keys["replicate"],
{"input": {"prompt": prompt, **params}},
model_id=model_id,
)
return "".join(result)
return "".join(result["output"])

def _get_external_memory_prompts(
self, system_prompt: Optional[str], prompt: str
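
As a usage sketch of the stricter `add_provider` validation added above (assuming the no-argument `LLMClient` constructor; the key values are placeholders):

```python
# Sketch of the new add_provider behaviour; keys below are placeholders.
from l2m2.client.llm_client import LLMClient

client = LLMClient()
client.add_provider("openai", "sk-placeholder")   # stored and activated

try:
    client.add_provider("anthropic", None)  # type: ignore[arg-type]
except ValueError as err:
    # "API key for provider anthropic must be a string."
    print(err)
```
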
38 changes: 36 additions & 2 deletions l2m2/model_info.py
@@ -9,13 +9,18 @@
class ProviderEntry(TypedDict):
name: str
homepage: str
endpoint: str
headers: Dict[str, str]


T = TypeVar("T")

Marker = Enum("Marker", {"PROVIDER_DEFAULT": "<<PROVIDER_DEFAULT>>"})
PROVIDER_DEFAULT: Marker = Marker.PROVIDER_DEFAULT

API_KEY = "<<API_KEY>>"
MODEL_ID = "<<MODEL_ID>>"


class Param(TypedDict, Generic[T]):
custom_key: NotRequired[str]
@@ -47,26 +52,55 @@ class ModelEntry(TypedDict):
"openai": {
"name": "OpenAI",
"homepage": "https://openai.com/product",
"endpoint": "https://api.openai.com/v1/chat/completions",
"headers": {
"Authorization": f"Bearer {API_KEY}",
"Content-Type": "application/json",
},
},
"google": {
"name": "Google",
"homepage": "https://ai.google.dev/",
"endpoint": f"https://generativelanguage.googleapis.com/v1beta/models/{MODEL_ID}:generateContent?key={API_KEY}",
"headers": {"Content-Type": "application/json"},
},
"anthropic": {
"name": "Anthropic",
"homepage": "https://www.anthropic.com/api",
"endpoint": "https://api.anthropic.com/v1/messages",
"headers": {
"x-api-key": API_KEY,
"anthropic-version": "2023-06-01",
"Content-Type": "application/json",
},
},
"cohere": {
"name": "Cohere",
"homepage": "https://docs.cohere.com/",
"endpoint": "https://api.cohere.com/v1/chat",
"headers": {
"accept": "application/json",
"content-type": "application/json",
"Authorization": f"Bearer {API_KEY}",
},
},
"groq": {
"name": "Groq",
"homepage": "https://wow.groq.com/",
"endpoint": "https://api.groq.com/openai/v1/chat/completions",
"headers": {
"Authorization": f"Bearer {API_KEY}",
"Content-Type": "application/json",
},
},
"replicate": {
"name": "Replicate",
"homepage": "https://replicate.com/",
"endpoint": f"https://api.replicate.com/v1/models/{MODEL_ID}/predictions",
"headers": {
"Authorization": f"Bearer {API_KEY}",
"Content-Type": "application/json",
},
},
}

@@ -121,7 +155,7 @@ class ModelEntry(TypedDict):
},
"gemini-1.5-pro": {
"google": {
"model_id": "gemini-1.5-pro-latest",
"model_id": "gemini-1.5-pro",
"params": {
"temperature": {
"default": PROVIDER_DEFAULT,
@@ -139,7 +173,7 @@
},
"gemini-1.0-pro": {
"google": {
"model_id": "gemini-1.0-pro-latest",
"model_id": "gemini-1.0-pro",
"params": {
"temperature": {
"default": PROVIDER_DEFAULT,
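
A short illustration of how the new `<<API_KEY>>` / `<<MODEL_ID>>` placeholders in `PROVIDER_INFO` are resolved at request time, mirroring `_get_headers` and `llm_post` above (the token and model ID below are made up):

```python
# Illustration only: resolve endpoint/header placeholders the way llm_post does.
from l2m2.model_info import API_KEY, MODEL_ID, PROVIDER_INFO

api_key = "r8_placeholder"                     # assumed Replicate token
model_id = "meta/meta-llama-3-8b-instruct"     # assumed model ID

info = PROVIDER_INFO["replicate"]
endpoint = info["endpoint"].replace(API_KEY, api_key).replace(MODEL_ID, model_id)
headers = {k: v.replace(API_KEY, api_key) for k, v in info["headers"].items()}

# https://api.replicate.com/v1/models/meta/meta-llama-3-8b-instruct/predictions
print(endpoint)
print(headers["Authorization"])                # Bearer r8_placeholder
```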