Merge pull request #1 from hack4impact-uiuc/andrewlester/llama-3
Add provider for llama 3
ericcccsliu authored Apr 26, 2024
2 parents fdef69b + aa5e503 commit 9ebcc77
Showing 4 changed files with 57 additions and 0 deletions.
1 change: 1 addition & 0 deletions api/utils/llm_provider_info.py
@@ -13,4 +13,5 @@ def __init__(self, model_name: str, model_provider: str, display_name: str, inpu
    LLMProvider("claude-3-opus-20240229", "anthropic", "claude 3 opus", 15 / 1000000, 75 / 1000000, True),
    LLMProvider("claude-3-sonnet-20240229", "anthropic", "claude 3 sonnet", 3 / 1000000, 15 / 1000000, False),
    LLMProvider("claude-3-haiku-20240307", "anthropic", "claude 3 haiku", 0.25 / 1000000, 1.25 / 1000000, False),
    LLMProvider("llama-3-70b-instruct-iq2xs", "alllama", "llama-3 70B", 0, 0, False),
]
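For context, the new row follows the same positional arguments as the existing entries: model name, provider key, display name, per-token input cost, per-token output cost, and a flagship flag. A minimal sketch of an LLMProvider class consistent with that usage is below; the parameter names after display_name are assumptions, since the hunk header above is truncated.

# Sketch only: parameter names after display_name are assumed, not taken from the diff.
class LLMProvider:
    def __init__(self, model_name: str, model_provider: str, display_name: str,
                 input_cost_per_token: float, output_cost_per_token: float,
                 is_flagship: bool):
        self.model_name = model_name
        self.model_provider = model_provider
        self.display_name = display_name
        self.input_cost_per_token = input_cost_per_token    # e.g. 15 / 1000000 = $15 per million tokens
        self.output_cost_per_token = output_cost_per_token
        self.is_flagship = is_flagship

The llama-3 row sets both costs to 0, presumably because the self-hosted endpoint has no per-token billing.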
44 changes: 44 additions & 0 deletions api/utils/llm_providers/alllama.py
@@ -0,0 +1,44 @@
# api/utils/llm_providers/alllama.py

import tiktoken
from openai import AsyncOpenAI
from starlette.config import Config

from api.utils.conversation_utils import update_user_usage

config = Config(".env")
client = AsyncOpenAI(
    base_url=config("ALLLAMA_API_BASE_URL"), api_key=config("ALLLAMA_API_KEY")
)

async def alllama_generate_response(conversation):
    messages = [
        {"role": message.role, "content": message.content}
        for message in conversation.messages
    ]

    # Count the input tokens
    encoding = tiktoken.encoding_for_model(conversation.model.name)
    input_tokens = sum(len(encoding.encode(message["content"])) for message in messages)

    stream = await client.chat.completions.create(
        max_tokens=1500,
        model=conversation.model.name,
        messages=messages,
        stream=True,
    )

    collected_chunks = []
    output_tokens = 0
    async for chunk in stream:
        content = chunk.choices[0].delta.content
        if content is None:
            content = ""
        collected_chunks.append(content)
        output_tokens += len(encoding.encode(content))
        yield content

    # Update the user's usage
    await update_user_usage(
        conversation.user_email, conversation.model.name, input_tokens, output_tokens
    )
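One caveat worth flagging: tiktoken.encoding_for_model only recognizes OpenAI model names, so calling it with "llama-3-70b-instruct-iq2xs" raises a KeyError. A defensive sketch (not part of this PR) could fall back to the generic cl100k_base encoding, making llama token counts approximate instead of erroring; get_encoding_for is a hypothetical helper name.

import tiktoken

def get_encoding_for(model_name: str) -> tiktoken.Encoding:
    # tiktoken only maps OpenAI model names; fall back to a generic encoding
    # so non-OpenAI models get approximate counts instead of a KeyError.
    try:
        return tiktoken.encoding_for_model(model_name)
    except KeyError:
        return tiktoken.get_encoding("cl100k_base")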
6 changes: 6 additions & 0 deletions api/utils/llm_utils.py
@@ -5,6 +5,7 @@
from api.models.conversation import Message
from api.utils.llm_providers.openai import openai_generate_response
from api.utils.llm_providers.anthropic import anthropic_generate_response, generate_conversation_name
from api.utils.llm_providers.alllama import alllama_generate_response

async def generate_response_stream(conversation):
    visible_messages = [message for message in conversation.messages if not message.hidden]
@@ -28,6 +29,11 @@ async def generate_response_stream(conversation):
            if chunk:
                collected_chunks.append(chunk)
                yield chunk
    elif conversation.model.provider == "alllama":
        async for chunk in alllama_generate_response(conversation):
            if chunk:
                collected_chunks.append(chunk)
                yield chunk
    else:
        # Fallback to mock response for other providers
        async def mock_response_generator():
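For orientation, a sketch of how the stream might be consumed by a route handler, assuming the API layer is FastAPI (the provider module reads config via starlette.config, so this is plausible but not confirmed by the diff); load_conversation is a hypothetical helper.

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

from api.utils.llm_utils import generate_response_stream

app = FastAPI()

@app.post("/conversations/{conversation_id}/response")
async def stream_llm_response(conversation_id: str):
    conversation = await load_conversation(conversation_id)  # hypothetical helper
    # generate_response_stream yields plain text chunks, so it can back a
    # streaming HTTP response directly.
    return StreamingResponse(generate_response_stream(conversation), media_type="text/plain")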
6 changes: 6 additions & 0 deletions app/llm_providers.tsx
@@ -30,4 +30,10 @@ export const LLMProviders = [
    display_name: "claude 3 haiku",
    isFlagship: false,
  },
  {
    model_name: "llama-3-70b-instruct-iq2xs",
    model_provider: "alllama",
    display_name: "llama-3 70B",
    isFlagship: false,
  },
];
