From 357e9e5286f13ef67f5b36dd179b77b634fa95c0 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Sat, 7 Sep 2024 11:43:32 +0800 Subject: [PATCH] ``` Refactor API handlers and improve auto_commit script: - Increased timeout in auto_commit.py for API requests. - Renamed OpenAIChatCompletionAPIWithStreamHandler to ChatCompletionHandler. - Added ChatBotCompletionHandler for bot functionality. - Updated SQL queries to include ChatSnapshotByUserIdAndUuid. - Improved error handling and logging in ChatCompletionHandler. - Fixed proxy variable name in auto_commit.py. ``` --- api/chat_main_handler.go | 127 ++++++++++++++++++++++++-- api/chat_snapshot_handler.go | 2 +- api/sqlc/queries/chat_snapshot.sql | 2 + api/sqlc_queries/chat_snapshot.sql.go | 30 ++++++ scripts/auto_commit.py | 10 +- 5 files changed, 158 insertions(+), 13 deletions(-) diff --git a/api/chat_main_handler.go b/api/chat_main_handler.go index 264c01d9..64574e30 100644 --- a/api/chat_main_handler.go +++ b/api/chat_main_handler.go @@ -44,18 +44,18 @@ func NewChatHandler(sqlc_q *sqlc_queries.Queries) *ChatHandler { } func (h *ChatHandler) Register(router *mux.Router) { - router.HandleFunc("/chat_stream", h.OpenAIChatCompletionAPIWithStreamHandler).Methods(http.MethodPost) + router.HandleFunc("/chat_stream", h.ChatCompletionHandler).Methods(http.MethodPost) + // for bot + // given a chat_uuid, a user message, return the answer + // + router.HandleFunc("/chat_bot", h.ChatBotCompletionHandler).Methods(http.MethodPost) } -type ChatOptions struct { - Uuid string -} type ChatRequest struct { Prompt string SessionUuid string ChatUuid string Regenerate bool - Options ChatOptions } type ChatCompletionResponse struct { @@ -86,8 +86,70 @@ func NewUserMessage(content string) openai.ChatCompletionMessage { return openai.ChatCompletionMessage{Role: "user", Content: content} } -// OpenAIChatCompletionAPIWithStreamHandler is an HTTP handler that sends the stream to the client as Server-Sent Events (SSE) -func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWriter, r *http.Request) { + +type BotRequest struct { + Message string + SnapshotUuid string + Stream bool +} + +// ChatCompletionHandler is an HTTP handler that sends the stream to the client as Server-Sent Events (SSE) +func (h *ChatHandler) ChatCompletionHandler(w http.ResponseWriter, r *http.Request) { + var req BotRequest + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + RespondWithError(w, http.StatusBadRequest, err.Error(), err) + return + } + + snapshotUuid := req.SnapshotUuid + newQuestion := req.Message + + + log.Printf("snapshotUuid: %s", snapshotUuid) + log.Printf("newQuestion: %s", newQuestion) + + ctx := r.Context() + + userID, err := getUserID(ctx) + if err != nil { + RespondWithError(w, http.StatusBadRequest, err.Error(), err) + return + } + + fmt.Printf("userID: %d", userID) + + chatSnapshot, err := h.service.q.ChatSnapshotByUserIdAndUuid(ctx, sqlc_queries.ChatSnapshotByUserIdAndUuidParams{ + UserID: userID, + Uuid: snapshotUuid, + }) + if err != nil { + RespondWithError(w, http.StatusBadRequest, eris.Wrap(err, "fail to get chat snapshot").Error(), err) + return + } + + fmt.Printf("chatSnapshot: %+v", chatSnapshot) + + var session sqlc_queries.ChatSession + err = json.Unmarshal(chatSnapshot.Session, &session) + if err != nil { + RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "fail to deserialize chat session").Error(), err) + return + } + var simpleChatMessages []SimpleChatMessage + err = json.Unmarshal(chatSnapshot.Conversation, &simpleChatMessages) + if err != nil { + RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "fail to deserialize conversation").Error(), err) + return + } + + genBotAnswer(h, w, session, simpleChatMessages, newQuestion, userID) + +} + +// ChatCompletionHandler is an HTTP handler that sends the stream to the client as Server-Sent Events (SSE) +func (h *ChatHandler) ChatBotCompletionHandler(w http.ResponseWriter, r *http.Request) { var req ChatRequest err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -217,6 +279,57 @@ func genAnswer(h *ChatHandler, w http.ResponseWriter, chatSessionUuid string, ch } } +func genBotAnswer(h *ChatHandler, w http.ResponseWriter, session sqlc_queries.ChatSession, simpleChatMessages []SimpleChatMessage, newQuestion string, userID int32) { + chatModel, err := h.service.q.ChatModelByName(context.Background(), session.Model) + if err != nil { + RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "get chat model").Error(), err) + return + } + + baseURL, _ := getModelBaseUrl(chatModel.Url) + + messages := simpleChatMessagesToMessages(simpleChatMessages) + messages = append(messages, models.Message{ + Role: "user", + Content: newQuestion, + }) + chatStreamFn := h.chooseChatStreamFn(session, messages) + + answerText, answerID, shouldReturn := chatStreamFn(w, session, messages, "", false) + if shouldReturn { + return + } + + if !isTest(messages) { + h.service.logChat(session, messages, answerText) + } + + ctx := context.Background() + if _, err := h.service.CreateChatMessageSimple(ctx, session.Uuid, answerID, "assistant", answerText, userID, baseURL, session.SummarizeMode); err != nil { + RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "failed to create message").Error(), nil) + return + } +} + +// Helper function to convert SimpleChatMessage to Message +func simpleChatMessagesToMessages(simpleChatMessages []SimpleChatMessage) []models.Message { + messages := make([]models.Message, len(simpleChatMessages)) + for i, scm := range simpleChatMessages { + role := "user" + if scm.Inversion { + role = "assistant" + } + if i == 0 { + role = "system" + } + messages[i] = models.Message{ + Role: role, + Content: scm.Text, + } + } + return messages +} + func regenerateAnswer(h *ChatHandler, w http.ResponseWriter, chatSessionUuid string, chatUuid string) { ctx := context.Background() chatSession, err := h.service.q.GetChatSessionByUUID(ctx, chatSessionUuid) diff --git a/api/chat_snapshot_handler.go b/api/chat_snapshot_handler.go index 9912ce50..f0a48623 100644 --- a/api/chat_snapshot_handler.go +++ b/api/chat_snapshot_handler.go @@ -36,7 +36,7 @@ func (h *ChatSnapshotHandler) CreateChatSnapshot(w http.ResponseWriter, r *http. RespondWithError(w, http.StatusInternalServerError, err.Error(), err) return } - uuid, err := h.service.CreateChatSnapshot(r.Context(), chatSessionUuid, user_id) +uuid, err := h.service.CreateChatSnapshot(r.Context(), chatSessionUuid, user_id) if err != nil { RespondWithError(w, http.StatusInternalServerError, err.Error(), err) } diff --git a/api/sqlc/queries/chat_snapshot.sql b/api/sqlc/queries/chat_snapshot.sql index 03fe0636..6074dad1 100644 --- a/api/sqlc/queries/chat_snapshot.sql +++ b/api/sqlc/queries/chat_snapshot.sql @@ -30,6 +30,8 @@ RETURNING *; -- name: ChatSnapshotByUUID :one SELECT * FROM chat_snapshot WHERE uuid = $1; +-- name: ChatSnapshotByUserIdAndUuid :one +SELECT * FROM chat_snapshot WHERE user_id = $1 AND uuid = $2; -- name: ChatSnapshotMetaByUserID :many SELECT uuid, title, summary, tags, created_at, typ diff --git a/api/sqlc_queries/chat_snapshot.sql.go b/api/sqlc_queries/chat_snapshot.sql.go index 0de857dd..453421db 100644 --- a/api/sqlc_queries/chat_snapshot.sql.go +++ b/api/sqlc_queries/chat_snapshot.sql.go @@ -61,6 +61,36 @@ func (q *Queries) ChatSnapshotByUUID(ctx context.Context, uuid string) (ChatSnap return i, err } +const chatSnapshotByUserIdAndUuid = `-- name: ChatSnapshotByUserIdAndUuid :one +SELECT id, typ, uuid, user_id, title, summary, model, tags, session, conversation, created_at, text, search_vector FROM chat_snapshot WHERE user_id = $1 AND uuid = $2 +` + +type ChatSnapshotByUserIdAndUuidParams struct { + UserID int32 `json:"userID"` + Uuid string `json:"uuid"` +} + +func (q *Queries) ChatSnapshotByUserIdAndUuid(ctx context.Context, arg ChatSnapshotByUserIdAndUuidParams) (ChatSnapshot, error) { + row := q.db.QueryRowContext(ctx, chatSnapshotByUserIdAndUuid, arg.UserID, arg.Uuid) + var i ChatSnapshot + err := row.Scan( + &i.ID, + &i.Typ, + &i.Uuid, + &i.UserID, + &i.Title, + &i.Summary, + &i.Model, + &i.Tags, + &i.Session, + &i.Conversation, + &i.CreatedAt, + &i.Text, + &i.SearchVector, + ) + return i, err +} + const chatSnapshotMetaByUserID = `-- name: ChatSnapshotMetaByUserID :many SELECT uuid, title, summary, tags, created_at, typ FROM chat_snapshot WHERE user_id = $1 diff --git a/scripts/auto_commit.py b/scripts/auto_commit.py index ee1bd623..5b931162 100644 --- a/scripts/auto_commit.py +++ b/scripts/auto_commit.py @@ -17,7 +17,7 @@ LLM_URL = "https://api.deepseek.com/v1/chat/completions" # 设置你的 Proxy,默认使用HTTPS_PROXY环境变量 -CURL_PROXY = os.getenv('HTTPS_PROXY', '') +HTTP_PROXY = os.getenv('HTTPS_PROXY', '') def get_git_diff(diff_type): try: @@ -44,15 +44,15 @@ def generate_commit_message(diff): "temperature": 0.7, } proxies = { - "https": CURL_PROXY, - } if CURL_PROXY else {} + "https": HTTP_PROXY, + } if HTTP_PROXY else {} try: response = requests.post(LLM_URL, headers=headers, proxies=proxies, json=payload, - timeout=5) + timeout=20) response.raise_for_status() return response.json()['choices'][0]['message']['content'] except requests.RequestException as e: @@ -94,7 +94,7 @@ def main(): return # 提交代码 - print("提交代码...") + print("git commit...") try: subprocess.run(['git', 'commit', '-m', commit_message], check=True) except subprocess.CalledProcessError as e: