Skip to content

Commit

Permalink
implement rate limiting
Browse files Browse the repository at this point in the history
  • Loading branch information
ericcccsliu committed Mar 31, 2024
1 parent be31e57 commit 8bbfa8c
Show file tree
Hide file tree
Showing 26 changed files with 293 additions and 49 deletions.
Binary file modified api/models/__pycache__/user.cpython-311.pyc
Binary file not shown.
6 changes: 5 additions & 1 deletion api/models/user.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
#api/models/user.py
from datetime import date
from pydantic import BaseModel

class User(BaseModel):
    """Application user account with per-day LLM spend counters.

    `daily_usage` / `daily_flagship_usage` accumulate dollar cost for the
    current day; `last_usage_update` records which day those counters belong
    to so they can be reset on rollover.
    """
    email: str
    first_name: str
    last_name: str
    # Dollar cost accrued today against flagship models only.
    daily_flagship_usage: float = 0.0
    # Total dollar cost accrued today across all models.
    daily_usage: float = 0.0
    # Day the counters were last touched; None until the first usage.
    # Fix: the field must be explicitly Optional — `date = None` (implicit
    # Optional) is deprecated in Pydantic v1 and rejected by Pydantic v2.
    last_usage_update: date | None = None
Binary file modified api/routes/__pycache__/auth_routes.cpython-311.pyc
Binary file not shown.
Binary file modified api/routes/__pycache__/conversation_routes.cpython-311.pyc
Binary file not shown.
2 changes: 2 additions & 0 deletions api/routes/auth_routes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#api/routes/auth_routes.py

import httpx
from fastapi import APIRouter, Request, HTTPException, Depends
from authlib.integrations.starlette_client import OAuth
Expand Down
32 changes: 30 additions & 2 deletions api/routes/conversation_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
from typing import List
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from starlette.config import Config
from pydantic import BaseModel
from api.models.conversation import Message
from api.utils.conversation_utils import update_conversation_messages, create_conversation, add_message, get_conversation_by_id, get_conversations_by_user, update_conversation_model
from api.utils.llm_utils import generate_response_stream
from api.utils.llm_provider_info import LLM_PROVIDERS
from api.utils.auth_utils import get_current_user
from api.models.user import User
from api.utils.db_utils import get_db


router = APIRouter()
config = Config('.env')

class ConversationCreate(BaseModel):
model_provider: str
Expand All @@ -33,10 +37,25 @@ async def create_conversation_route(conversation_create: ConversationCreate, cur

@router.post("/conversations/{conversation_id}/message")
async def add_message_route(conversation_id: str, message_create: MessageCreate, current_user: User = Depends(get_current_user), db = Depends(get_db)):
daily_flagship_usage_limit = float(config('DAILY_FLAGSHIP_USAGE_LIMIT', default='1.0'))
daily_usage_limit = float(config('DAILY_USAGE_LIMIT', default='2.0'))

conversation = await get_conversation_by_id(conversation_id, current_user.email)

# Check if the user has exceeded their daily flagship limit
if current_user.daily_flagship_usage >= daily_flagship_usage_limit:
if any(p.model_name == conversation.model.name and p.is_flagship for p in LLM_PROVIDERS):
raise HTTPException(status_code=429, detail="Daily flagship usage limit exceeded")

# Check if the user has exceeded their daily general limit
if current_user.daily_usage >= daily_usage_limit:
raise HTTPException(status_code=429, detail="Daily usage limit exceeded")

user_message = Message(role='user', content=message_create.message)

await add_message(conversation_id, user_message)

conversation = await get_conversation_by_id(conversation_id, current_user.email)

if conversation:
return StreamingResponse(generate_response_stream(conversation), media_type="text/event-stream")
else:
Expand Down Expand Up @@ -91,4 +110,13 @@ async def get_conversation_route(conversation_id: str, current_user: User = Depe
if conversation:
return conversation
else:
raise HTTPException(status_code=404, detail="Conversation not found")
raise HTTPException(status_code=404, detail="Conversation not found")


@router.get("/usage")
async def get_usage_route(current_user: User = Depends(get_current_user), db = Depends(get_db)):
    """Report the authenticated user's current daily usage counters.

    Returns the flagship and overall dollar spend for today, plus the ISO
    date the counters were last updated (None if the user has no usage yet).
    """
    last_update = current_user.last_usage_update
    payload = {
        "daily_flagship_usage": current_user.daily_flagship_usage,
        "daily_usage": current_user.daily_usage,
        "last_usage_update": None,
    }
    if last_update is not None:
        payload["last_usage_update"] = last_update.isoformat()
    return payload
Binary file modified api/utils/__pycache__/auth_utils.cpython-311.pyc
Binary file not shown.
Binary file modified api/utils/__pycache__/conversation_utils.cpython-311.pyc
Binary file not shown.
Binary file modified api/utils/__pycache__/db_utils.cpython-311.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified api/utils/__pycache__/llm_utils.cpython-311.pyc
Binary file not shown.
24 changes: 21 additions & 3 deletions api/utils/auth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from api.utils.db_utils import get_db
from api.models.user import User
from starlette.config import Config
from datetime import datetime
import pytz

config = Config('.env')

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
SECRET_KEY = config('SECRET_KEY')

Expand All @@ -20,12 +21,29 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
db = await get_db()
users_collection = db["users"]
user_dict = await users_collection.find_one({"email": email})

if user_dict is None:
raise HTTPException(status_code=401, detail="User not found")

user = User(**user_dict)
return user

# Check if the day has changed and reset the API usage if necessary
central_tz = pytz.timezone('US/Central')
today = datetime.now(central_tz).date().isoformat()
if user.last_usage_update is None or user.last_usage_update.isoformat() < today:
# Reset the daily usage and flagship usage
await users_collection.update_one(
{"email": email},
{"$set": {
"daily_usage": 0,
"daily_flagship_usage": 0,
"last_usage_update": today
}}
)
# Update the user object with the reset values
user.daily_usage = 0
user.daily_flagship_usage = 0
user.last_usage_update = datetime.strptime(today, "%Y-%m-%d").date()

return user
except JWTError:
raise HTTPException(status_code=401, detail="Invalid authentication token")
42 changes: 41 additions & 1 deletion api/utils/conversation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from typing import List

from api.utils.db_utils import get_db
from api.utils.llm_provider_info import LLM_PROVIDERS
from api.models.conversation import Message, Conversation, LanguageModel
from bson import ObjectId
from fastapi import HTTPException
import pytz

async def create_conversation(user_email: str, name: str, model_provider: str, model_name: str):
db = await get_db()
Expand Down Expand Up @@ -91,4 +93,42 @@ async def update_conversation_messages(conversation_id: str, updated_messages: L
)
return True
else:
return False
return False

async def update_user_usage(user_email: str, model_name: str, input_tokens: int, output_tokens: int):
    """Accrue the dollar cost of one LLM call onto the user's daily counters.

    Cost is computed from the per-token rates in LLM_PROVIDERS. Flagship
    models additionally count against `daily_flagship_usage`.

    Args:
        user_email: key of the user document in the `users` collection.
        model_name: provider model identifier (must appear in LLM_PROVIDERS).
        input_tokens / output_tokens: token counts for this call.

    Raises:
        ValueError: if `model_name` is not a known provider model.
    """
    db = await get_db()
    central_tz = pytz.timezone('US/Central')
    # Daily buckets roll over at midnight US/Central; the bucket's day is
    # stored on the user document as an ISO date string.
    today = datetime.now(central_tz).date().isoformat()
    user_collection = db['users']

    llm_provider = next((p for p in LLM_PROVIDERS if p.model_name == model_name), None)
    if not llm_provider:
        raise ValueError(f"Unknown model: {model_name}")

    total_cost = (input_tokens * llm_provider.input_token_cost
                  + output_tokens * llm_provider.output_token_cost)

    update_query = {
        '$inc': {
            'daily_usage': total_cost,
        }
    }
    if llm_provider.is_flagship:
        update_query['$inc']['daily_flagship_usage'] = total_cost

    # Fast path: increment in place, but only when the user's counters
    # already belong to today's bucket.
    result = await user_collection.update_one(
        {'email': user_email, 'last_usage_update': today},
        update_query,
    )

    # Fix: use matched_count, not modified_count. A zero-cost call ($inc by
    # 0.0) matches the document but modifies nothing, so modified_count == 0
    # would wrongly fall through and RESET the user's accumulated usage for
    # the day. matched_count correctly distinguishes "no document for today"
    # from "matched but unchanged".
    if result.matched_count == 0:
        # New day (or first-ever call): start today's bucket at this call's
        # cost, creating the user document if necessary.
        await user_collection.update_one(
            {'email': user_email},
            {'$set': {
                'daily_usage': total_cost,
                'daily_flagship_usage': total_cost if llm_provider.is_flagship else 0,
                'last_usage_update': today,
            }},
            upsert=True,
        )
2 changes: 2 additions & 0 deletions api/utils/db_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#api/utils/db_utils.py

from motor.motor_asyncio import AsyncIOMotorClient
from starlette.config import Config

Expand Down
16 changes: 16 additions & 0 deletions api/utils/llm_provider_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
class LLMProvider:
    """Static metadata for one hosted LLM: identity, pricing, and tier.

    Token costs are expressed as dollars per single token; flagship models
    are the ones subject to the stricter daily usage limit.
    """

    def __init__(self, model_name: str, model_provider: str, display_name: str,
                 input_token_cost: float, output_token_cost: float,
                 is_flagship: bool) -> None:
        # Tier flag — checked by the rate limiter.
        self.is_flagship = is_flagship
        # Pricing, dollars per token.
        self.output_token_cost = output_token_cost
        self.input_token_cost = input_token_cost
        # Identity.
        self.display_name = display_name
        self.model_provider = model_provider
        self.model_name = model_name

# Registry of every model the app can route to. Rates are written as
# (dollars per 1K tokens) / 1000, i.e. stored per single token.
# NOTE(review): verify these rates against current OpenAI/Anthropic pricing.
LLM_PROVIDERS = [
    LLMProvider(model_name="gpt-4-0125-preview", model_provider="openai",
                display_name="gpt-4 turbo",
                input_token_cost=0.03 / 1000, output_token_cost=0.06 / 1000,
                is_flagship=True),
    LLMProvider(model_name="gpt-3.5-turbo-0125", model_provider="openai",
                display_name="gpt-3.5 turbo",
                input_token_cost=0.002 / 1000, output_token_cost=0.002 / 1000,
                is_flagship=False),
    LLMProvider(model_name="claude-3-opus-20240229", model_provider="anthropic",
                display_name="claude 3 opus",
                input_token_cost=0.02 / 1000, output_token_cost=0.04 / 1000,
                is_flagship=True),
    LLMProvider(model_name="claude-3-sonnet-20240229", model_provider="anthropic",
                display_name="claude 3 sonnet",
                input_token_cost=0.001 / 1000, output_token_cost=0.001 / 1000,
                is_flagship=False),
    LLMProvider(model_name="claude-3-haiku-20240307", model_provider="anthropic",
                display_name="claude 3 haiku",
                input_token_cost=0.001 / 1000, output_token_cost=0.001 / 1000,
                is_flagship=False),
]
Binary file modified api/utils/llm_providers/__pycache__/anthropic.cpython-311.pyc
Binary file not shown.
Binary file modified api/utils/llm_providers/__pycache__/openai.cpython-311.pyc
Binary file not shown.
13 changes: 12 additions & 1 deletion api/utils/llm_providers/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from anthropic import AsyncAnthropic
from starlette.config import Config
from api.utils.conversation_utils import update_user_usage

config = Config('.env')
client = AsyncAnthropic(api_key=config("ANTHROPIC_API_KEY"))
Expand All @@ -10,6 +11,9 @@ async def anthropic_generate_response(conversation):
for message in conversation.messages
]

input_tokens = 0
output_tokens = 0

stream = await client.messages.create(
model=conversation.model.name,
messages=messages,
Expand All @@ -18,10 +22,17 @@ async def anthropic_generate_response(conversation):
)

async for event in stream:
if event.type == "content_block_delta":
if event.type == "message_start":
input_tokens = event.message.usage.input_tokens
elif event.type == "message_delta":
output_tokens = event.usage.output_tokens
elif event.type == "content_block_delta":
content = event.delta.text
yield content

# Update the user's usage
await update_user_usage(conversation.user_email, conversation.model.name, input_tokens, output_tokens)

async def generate_conversation_name(conversation):
messages = [
{"role": message.role, "content": message.content}
Expand Down
21 changes: 18 additions & 3 deletions api/utils/llm_providers/openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# api/utils/llm_providers/openai.py

import tiktoken
from openai import AsyncOpenAI
from starlette.config import Config
from api.utils.conversation_utils import update_user_usage

config = Config('.env')
client = AsyncOpenAI(api_key=config("OPENAI_API_KEY"))
Expand All @@ -10,14 +14,25 @@ async def openai_generate_response(conversation):
for message in conversation.messages
]

# Count the input tokens
encoding = tiktoken.encoding_for_model(conversation.model.name)
input_tokens = sum(len(encoding.encode(message["content"])) for message in messages)

stream = await client.chat.completions.create(
model=conversation.model.name,
messages=messages,
stream=True,
)

collected_chunks = []
output_tokens = 0
async for chunk in stream:
content = chunk.choices[0].delta.content
if content is None:
content = chunk.choices[0].delta.content
if content is None:
content = ""
yield content
collected_chunks.append(content)
output_tokens += len(encoding.encode(content))
yield content

# Update the user's usage
await update_user_usage(conversation.user_email, conversation.model.name, input_tokens,output_tokens)
2 changes: 2 additions & 0 deletions api/utils/llm_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#api/utils/llm_utils.py

import asyncio
from api.utils.conversation_utils import add_message, get_conversation_by_id, update_conversation_model
from api.models.conversation import Message
Expand Down
8 changes: 7 additions & 1 deletion app/llm_providers.tsx
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
// app/llm_providers.tsx
export const LLMProviders = [
{
model_name: "gpt-4-0125-preview",
model_provider: "openai",
display_name: "gpt-4 turbo",
isFlagship: true,
},
{
model_name: "gpt-3.5-turbo-0125",
model_provider: "openai",
display_name: "gpt-3.5 turbo",
isFlagship: false,
},
{
model_name: "claude-3-opus-20240229",
model_provider: "anthropic",
display_name: "claude 3 opus",
isFlagship: true,
},
{
model_name: "claude-3-sonnet-20240229",
model_provider: "anthropic",
display_name: "claude 3 sonnet",
isFlagship: false,
},
{
model_name: "claude-3-haiku-20240307",
model_provider: "anthropic",
display_name: "claude 3 haiku",
isFlagship: false,
},
];
];
Loading

0 comments on commit 8bbfa8c

Please sign in to comment.