From b30da11d975736d4b88e7b08e2d50574eb4cb4c3 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Tue, 21 Mar 2023 17:23:50 +0800 Subject: [PATCH] rate limit 100 requests / 10mins (#14) * rate format * update --- api/chat_message_handler.go | 2 +- api/chat_message_service.go | 12 +++++- api/main.go | 6 ++- api/middleware_authenticate.go | 1 + api/middleware_log.go | 56 ---------------------------- api/middleware_rateLimit.go | 37 ++++++++++++++++++ api/sqlc/queries/chat_message.sql | 10 ++++- api/sqlc_queries/chat_message.sql.go | 15 ++++++++ chat.code-workspace | 14 +++++++ 9 files changed, 91 insertions(+), 62 deletions(-) delete mode 100644 api/middleware_log.go create mode 100644 api/middleware_rateLimit.go create mode 100644 chat.code-workspace diff --git a/api/chat_message_handler.go b/api/chat_message_handler.go index 59239d64..4250ddd5 100644 --- a/api/chat_message_handler.go +++ b/api/chat_message_handler.go @@ -219,4 +219,4 @@ func (h *ChatMessageHandler) DeleteChatMessagesBySesionUUID(w http.ResponseWrite return } w.WriteHeader(http.StatusOK) -} \ No newline at end of file +} diff --git a/api/chat_message_service.go b/api/chat_message_service.go index 5a2d6a10..31cd52bf 100644 --- a/api/chat_message_service.go +++ b/api/chat_message_service.go @@ -202,11 +202,19 @@ func (s *ChatMessageService) GetLastNChatMessages(ctx context.Context, uuid stri return message, nil } -//DeleteChatMessagesBySesionUUID deletes chat messages by session uuid. +// DeleteChatMessagesBySesionUUID deletes chat messages by session uuid. func (s *ChatMessageService) DeleteChatMessagesBySesionUUID(ctx context.Context, uuid string) error { err := s.q.DeleteChatMessagesBySesionUUID(ctx, uuid) if err != nil { return errors.New("failed to delete message") } return nil -} \ No newline at end of file +} + +func (s *ChatMessageService) GetChatMessagesCount(ctx context.Context, userID int32) (int32, error) { + count, err := s.q.GetChatMessagesCount(ctx, userID) + if err != nil { + return 0, err + } + return int32(count), nil +} diff --git a/api/main.go b/api/main.go index 81edc1eb..118e9183 100644 --- a/api/main.go +++ b/api/main.go @@ -27,7 +27,6 @@ func main() { } OPENAI_API_KEY = os.Getenv("OPENAI_API_KEY") - if JWT_SECRET, exists = os.LookupEnv("JWT_SECRET"); !exists { log.Fatal("JWT_SECRET not set") } @@ -36,7 +35,7 @@ func main() { if JWT_AUD, exists = os.LookupEnv("JWT_AUD"); !exists { log.Fatal("JWT_AUD not set") } - JWT_AUD= os.Getenv("JWT_AUD") + JWT_AUD = os.Getenv("JWT_AUD") // Create a new logger instance, configure it as desired logger = log.New() @@ -165,6 +164,9 @@ func main() { router.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.Dir("static")))) router.Use(IsAuthorizedMiddleware) // Wrap the router with the logging middleware + // 10 min < 100 requests + limitedRouter := RateLimitByUserID(sqlc_q) + router.Use(limitedRouter) // loggedMux := loggingMiddleware(router, logger) loggedRouter := handlers.LoggingHandler(logger.Out, router) err = http.ListenAndServe(":8077", loggedRouter) diff --git a/api/middleware_authenticate.go b/api/middleware_authenticate.go index cc3f469d..07242570 100644 --- a/api/middleware_authenticate.go +++ b/api/middleware_authenticate.go @@ -91,6 +91,7 @@ func IsAuthorizedMiddleware(handler http.Handler) http.Handler { } ctx := context.WithValue(r.Context(), userContextKey, userID) ctx = context.WithValue(ctx, roleContextKey, role) + // TODO: get trace id and add it to context //traceID := r.Header.Get("X-Request-Id") //if len(traceID) > 0 { diff --git a/api/middleware_log.go b/api/middleware_log.go deleted file mode 100644 index 4004dae8..00000000 --- a/api/middleware_log.go +++ /dev/null @@ -1,56 +0,0 @@ -package main - -import ( - "net/http" - - log "github.com/sirupsen/logrus" -) - -func LoggingMiddleware(next http.Handler, logger *log.Logger) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Log the incoming request - logger.WithFields(log.Fields{ - "method": r.Method, - "path": r.URL.Path, - "remote": r.RemoteAddr, - }).Info("Request received") - - // Create a new ResponseWriter to capture status code - recorder := newLoggingResponseWriter(w) - - // Call the next handler in the chain - next.ServeHTTP(recorder, r) - - // Log the outgoing response - logger.WithFields(log.Fields{ - "method": r.Method, - "path": r.URL.Path, - "remote": r.RemoteAddr, - "status": recorder.statusCode, - "size": recorder.size, - }).Info("Response sent") - }) -} - -func newLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter { - return &loggingResponseWriter{w, http.StatusOK, 0} -} - -// loggingResponseWriter wraps an http.ResponseWriter, allowing us to log -// the response status code and size. -type loggingResponseWriter struct { - http.ResponseWriter - statusCode int - size int -} - -func (lrw *loggingResponseWriter) WriteHeader(statusCode int) { - lrw.statusCode = statusCode - lrw.ResponseWriter.WriteHeader(statusCode) -} - -func (lrw *loggingResponseWriter) Write(b []byte) (int, error) { - size, err := lrw.ResponseWriter.Write(b) - lrw.size += size - return size, err -} diff --git a/api/middleware_rateLimit.go b/api/middleware_rateLimit.go new file mode 100644 index 00000000..789899ef --- /dev/null +++ b/api/middleware_rateLimit.go @@ -0,0 +1,37 @@ +package main + +import ( + "net/http" + "strconv" + + "github.com/swuecho/chatgpt_backend/sqlc_queries" +) + +// This function returns a middleware that limits requests from each user by their ID. +func RateLimitByUserID(q *sqlc_queries.Queries) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Get the user ID from the request, e.g. from a JWT token. + ctx := r.Context() + userIDStr := ctx.Value(userContextKey).(string) + userIDInt, err := strconv.Atoi(userIDStr) + if err != nil { + http.Error(w, "Error: '"+userIDStr+"' is not a valid user ID. Please enter a valid user ID.", http.StatusBadRequest) + return + } + messageCount, err := q.GetChatMessagesCount(r.Context(), int32(userIDInt)) + if err != nil { + http.Error(w, "Error: Could not get message count.", http.StatusInternalServerError) + return + } + + // Check if the request exceeds the rate limit. + if messageCount > 100 { + http.Error(w, "Too Many Requests", http.StatusTooManyRequests) + return + } + // Call the next handler. + next.ServeHTTP(w, r) + }) + } +} diff --git a/api/sqlc/queries/chat_message.sql b/api/sqlc/queries/chat_message.sql index 6fd75c90..7b9b9a51 100644 --- a/api/sqlc/queries/chat_message.sql +++ b/api/sqlc/queries/chat_message.sql @@ -108,4 +108,12 @@ WHERE uuid = $1 ; -- name: DeleteChatMessagesBySesionUUID :exec DELETE FROM chat_message -WHERE chat_session_uuid = $1; \ No newline at end of file +WHERE chat_session_uuid = $1; + + +-- name: GetChatMessagesCount :one +-- Get total chat message count for user in last 10 minutes +SELECT COUNT(*) +FROM chat_message +WHERE user_id = $1 +AND created_at >= NOW() - INTERVAL '10 minutes'; diff --git a/api/sqlc_queries/chat_message.sql.go b/api/sqlc_queries/chat_message.sql.go index bbfae5b9..4305f8e4 100644 --- a/api/sqlc_queries/chat_message.sql.go +++ b/api/sqlc_queries/chat_message.sql.go @@ -311,6 +311,21 @@ func (q *Queries) GetChatMessagesBySessionUUID(ctx context.Context, arg GetChatM return items, nil } +const getChatMessagesCount = `-- name: GetChatMessagesCount :one +SELECT COUNT(*) +FROM chat_message +WHERE user_id = $1 +AND created_at >= NOW() - INTERVAL '10 minutes' +` + +// Get total chat message count for user in last 10 minutes +func (q *Queries) GetChatMessagesCount(ctx context.Context, userID int32) (int64, error) { + row := q.db.QueryRowContext(ctx, getChatMessagesCount, userID) + var count int64 + err := row.Scan(&count) + return count, err +} + const getFirstMessageBySessionUUID = `-- name: GetFirstMessageBySessionUUID :one SELECT id, uuid, chat_session_uuid, role, content, score, user_id, created_at, updated_at, created_by, updated_by, raw FROM chat_message diff --git a/chat.code-workspace b/chat.code-workspace new file mode 100644 index 00000000..59e7fa36 --- /dev/null +++ b/chat.code-workspace @@ -0,0 +1,14 @@ +{ + "folders": [ + { + "path": "web" + }, + { + "path": "api" + }, + { + "path": "e2e" + } + ], + "settings": {} +} \ No newline at end of file