Skip to content

Commit

Permalink
rate limit 100 requests / 10mins (#14)
Browse files Browse the repository at this point in the history
* rate format

* update
  • Loading branch information
swuecho authored Mar 21, 2023
1 parent a27a755 commit b30da11
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 62 deletions.
2 changes: 1 addition & 1 deletion api/chat_message_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,4 @@ func (h *ChatMessageHandler) DeleteChatMessagesBySesionUUID(w http.ResponseWrite
return
}
w.WriteHeader(http.StatusOK)
}
}
12 changes: 10 additions & 2 deletions api/chat_message_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,19 @@ func (s *ChatMessageService) GetLastNChatMessages(ctx context.Context, uuid stri
return message, nil
}

//DeleteChatMessagesBySesionUUID deletes chat messages by session uuid.
// DeleteChatMessagesBySesionUUID deletes chat messages by session uuid.
func (s *ChatMessageService) DeleteChatMessagesBySesionUUID(ctx context.Context, uuid string) error {
err := s.q.DeleteChatMessagesBySesionUUID(ctx, uuid)
if err != nil {
return errors.New("failed to delete message")
}
return nil
}
}

func (s *ChatMessageService) GetChatMessagesCount(ctx context.Context, userID int32) (int32, error) {
count, err := s.q.GetChatMessagesCount(ctx, userID)
if err != nil {
return 0, err
}
return int32(count), nil
}
6 changes: 4 additions & 2 deletions api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ func main() {
}
OPENAI_API_KEY = os.Getenv("OPENAI_API_KEY")


if JWT_SECRET, exists = os.LookupEnv("JWT_SECRET"); !exists {
log.Fatal("JWT_SECRET not set")
}
Expand All @@ -36,7 +35,7 @@ func main() {
if JWT_AUD, exists = os.LookupEnv("JWT_AUD"); !exists {
log.Fatal("JWT_AUD not set")
}
JWT_AUD= os.Getenv("JWT_AUD")
JWT_AUD = os.Getenv("JWT_AUD")

// Create a new logger instance, configure it as desired
logger = log.New()
Expand Down Expand Up @@ -165,6 +164,9 @@ func main() {
router.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.Dir("static"))))
router.Use(IsAuthorizedMiddleware)
// Wrap the router with the logging middleware
// 10 min < 100 requests
limitedRouter := RateLimitByUserID(sqlc_q)
router.Use(limitedRouter)
// loggedMux := loggingMiddleware(router, logger)
loggedRouter := handlers.LoggingHandler(logger.Out, router)
err = http.ListenAndServe(":8077", loggedRouter)
Expand Down
1 change: 1 addition & 0 deletions api/middleware_authenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ func IsAuthorizedMiddleware(handler http.Handler) http.Handler {
}
ctx := context.WithValue(r.Context(), userContextKey, userID)
ctx = context.WithValue(ctx, roleContextKey, role)

// TODO: get trace id and add it to context
//traceID := r.Header.Get("X-Request-Id")
//if len(traceID) > 0 {
Expand Down
56 changes: 0 additions & 56 deletions api/middleware_log.go

This file was deleted.

37 changes: 37 additions & 0 deletions api/middleware_rateLimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package main

import (
"net/http"
"strconv"

"github.com/swuecho/chatgpt_backend/sqlc_queries"
)

// This function returns a middleware that limits requests from each user by their ID.
func RateLimitByUserID(q *sqlc_queries.Queries) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get the user ID from the request, e.g. from a JWT token.
ctx := r.Context()
userIDStr := ctx.Value(userContextKey).(string)
userIDInt, err := strconv.Atoi(userIDStr)
if err != nil {
http.Error(w, "Error: '"+userIDStr+"' is not a valid user ID. Please enter a valid user ID.", http.StatusBadRequest)
return
}
messageCount, err := q.GetChatMessagesCount(r.Context(), int32(userIDInt))
if err != nil {
http.Error(w, "Error: Could not get message count.", http.StatusInternalServerError)
return
}

// Check if the request exceeds the rate limit.
if messageCount > 100 {
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
return
}
// Call the next handler.
next.ServeHTTP(w, r)
})
}
}
10 changes: 9 additions & 1 deletion api/sqlc/queries/chat_message.sql
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,12 @@ WHERE uuid = $1 ;

-- name: DeleteChatMessagesBySesionUUID :exec
DELETE FROM chat_message
WHERE chat_session_uuid = $1;
WHERE chat_session_uuid = $1;


-- name: GetChatMessagesCount :one
-- Get total chat message count for user in last 10 minutes
SELECT COUNT(*)
FROM chat_message
WHERE user_id = $1
AND created_at >= NOW() - INTERVAL '10 minutes';
15 changes: 15 additions & 0 deletions api/sqlc_queries/chat_message.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions chat.code-workspace
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"folders": [
{
"path": "web"
},
{
"path": "api"
},
{
"path": "e2e"
}
],
"settings": {}
}

0 comments on commit b30da11

Please sign in to comment.