Skip to content

Commit

Permalink
```
Browse files Browse the repository at this point in the history
Refactor API handlers and improve auto_commit script:
- Increased timeout in auto_commit.py for API requests.
- Renamed OpenAIChatCompletionAPIWithStreamHandler to ChatCompletionHandler.
- Added ChatBotCompletionHandler for bot functionality.
- Updated SQL queries to include ChatSnapshotByUserIdAndUuid.
- Improved error handling and logging in ChatCompletionHandler.
- Fixed proxy variable name in auto_commit.py.
```
  • Loading branch information
swuecho committed Sep 7, 2024
1 parent 303d8cf commit 357e9e5
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 13 deletions.
127 changes: 120 additions & 7 deletions api/chat_main_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,18 @@ func NewChatHandler(sqlc_q *sqlc_queries.Queries) *ChatHandler {
}

func (h *ChatHandler) Register(router *mux.Router) {
router.HandleFunc("/chat_stream", h.OpenAIChatCompletionAPIWithStreamHandler).Methods(http.MethodPost)
router.HandleFunc("/chat_stream", h.ChatCompletionHandler).Methods(http.MethodPost)
// for bot
// given a chat_uuid, a user message, return the answer
//
router.HandleFunc("/chat_bot", h.ChatBotCompletionHandler).Methods(http.MethodPost)
}

type ChatOptions struct {
Uuid string
}
type ChatRequest struct {
Prompt string
SessionUuid string
ChatUuid string
Regenerate bool
Options ChatOptions
}

type ChatCompletionResponse struct {
Expand Down Expand Up @@ -86,8 +86,70 @@ func NewUserMessage(content string) openai.ChatCompletionMessage {
return openai.ChatCompletionMessage{Role: "user", Content: content}
}

// OpenAIChatCompletionAPIWithStreamHandler is an HTTP handler that sends the stream to the client as Server-Sent Events (SSE)
func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWriter, r *http.Request) {

type BotRequest struct {
Message string
SnapshotUuid string
Stream bool
}

// ChatCompletionHandler is an HTTP handler that sends the stream to the client as Server-Sent Events (SSE)
func (h *ChatHandler) ChatCompletionHandler(w http.ResponseWriter, r *http.Request) {
var req BotRequest
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
RespondWithError(w, http.StatusBadRequest, err.Error(), err)
return
}

snapshotUuid := req.SnapshotUuid
newQuestion := req.Message


log.Printf("snapshotUuid: %s", snapshotUuid)
log.Printf("newQuestion: %s", newQuestion)

ctx := r.Context()

userID, err := getUserID(ctx)
if err != nil {
RespondWithError(w, http.StatusBadRequest, err.Error(), err)
return
}

fmt.Printf("userID: %d", userID)

chatSnapshot, err := h.service.q.ChatSnapshotByUserIdAndUuid(ctx, sqlc_queries.ChatSnapshotByUserIdAndUuidParams{
UserID: userID,
Uuid: snapshotUuid,
})
if err != nil {
RespondWithError(w, http.StatusBadRequest, eris.Wrap(err, "fail to get chat snapshot").Error(), err)
return
}

fmt.Printf("chatSnapshot: %+v", chatSnapshot)

var session sqlc_queries.ChatSession
err = json.Unmarshal(chatSnapshot.Session, &session)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "fail to deserialize chat session").Error(), err)
return
}
var simpleChatMessages []SimpleChatMessage
err = json.Unmarshal(chatSnapshot.Conversation, &simpleChatMessages)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "fail to deserialize conversation").Error(), err)
return
}

genBotAnswer(h, w, session, simpleChatMessages, newQuestion, userID)

}

// ChatCompletionHandler is an HTTP handler that sends the stream to the client as Server-Sent Events (SSE)
func (h *ChatHandler) ChatBotCompletionHandler(w http.ResponseWriter, r *http.Request) {
var req ChatRequest
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
Expand Down Expand Up @@ -217,6 +279,57 @@ func genAnswer(h *ChatHandler, w http.ResponseWriter, chatSessionUuid string, ch
}
}

func genBotAnswer(h *ChatHandler, w http.ResponseWriter, session sqlc_queries.ChatSession, simpleChatMessages []SimpleChatMessage, newQuestion string, userID int32) {
chatModel, err := h.service.q.ChatModelByName(context.Background(), session.Model)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "get chat model").Error(), err)
return
}

baseURL, _ := getModelBaseUrl(chatModel.Url)

messages := simpleChatMessagesToMessages(simpleChatMessages)
messages = append(messages, models.Message{
Role: "user",
Content: newQuestion,
})
chatStreamFn := h.chooseChatStreamFn(session, messages)

answerText, answerID, shouldReturn := chatStreamFn(w, session, messages, "", false)
if shouldReturn {
return
}

if !isTest(messages) {
h.service.logChat(session, messages, answerText)
}

ctx := context.Background()
if _, err := h.service.CreateChatMessageSimple(ctx, session.Uuid, answerID, "assistant", answerText, userID, baseURL, session.SummarizeMode); err != nil {
RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "failed to create message").Error(), nil)
return
}
}

// Helper function to convert SimpleChatMessage to Message
func simpleChatMessagesToMessages(simpleChatMessages []SimpleChatMessage) []models.Message {
messages := make([]models.Message, len(simpleChatMessages))
for i, scm := range simpleChatMessages {
role := "user"
if scm.Inversion {
role = "assistant"
}
if i == 0 {
role = "system"
}
messages[i] = models.Message{
Role: role,
Content: scm.Text,
}
}
return messages
}

func regenerateAnswer(h *ChatHandler, w http.ResponseWriter, chatSessionUuid string, chatUuid string) {
ctx := context.Background()
chatSession, err := h.service.q.GetChatSessionByUUID(ctx, chatSessionUuid)
Expand Down
2 changes: 1 addition & 1 deletion api/chat_snapshot_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (h *ChatSnapshotHandler) CreateChatSnapshot(w http.ResponseWriter, r *http.
RespondWithError(w, http.StatusInternalServerError, err.Error(), err)
return
}
uuid, err := h.service.CreateChatSnapshot(r.Context(), chatSessionUuid, user_id)
uuid, err := h.service.CreateChatSnapshot(r.Context(), chatSessionUuid, user_id)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, err.Error(), err)
}
Expand Down
2 changes: 2 additions & 0 deletions api/sqlc/queries/chat_snapshot.sql
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ RETURNING *;
-- name: ChatSnapshotByUUID :one
SELECT * FROM chat_snapshot WHERE uuid = $1;

-- name: ChatSnapshotByUserIdAndUuid :one
SELECT * FROM chat_snapshot WHERE user_id = $1 AND uuid = $2;

-- name: ChatSnapshotMetaByUserID :many
SELECT uuid, title, summary, tags, created_at, typ
Expand Down
30 changes: 30 additions & 0 deletions api/sqlc_queries/chat_snapshot.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions scripts/auto_commit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
LLM_URL = "https://api.deepseek.com/v1/chat/completions"

# 设置你的 Proxy,默认使用HTTPS_PROXY环境变量
CURL_PROXY = os.getenv('HTTPS_PROXY', '')
HTTP_PROXY = os.getenv('HTTPS_PROXY', '')

def get_git_diff(diff_type):
try:
Expand All @@ -44,15 +44,15 @@ def generate_commit_message(diff):
"temperature": 0.7,
}
proxies = {
"https": CURL_PROXY,
} if CURL_PROXY else {}
"https": HTTP_PROXY,
} if HTTP_PROXY else {}

try:
response = requests.post(LLM_URL,
headers=headers,
proxies=proxies,
json=payload,
timeout=5)
timeout=20)
response.raise_for_status()
return response.json()['choices'][0]['message']['content']
except requests.RequestException as e:
Expand Down Expand Up @@ -94,7 +94,7 @@ def main():
return

# 提交代码
print("提交代码...")
print("git commit...")
try:
subprocess.run(['git', 'commit', '-m', commit_message], check=True)
except subprocess.CalledProcessError as e:
Expand Down

0 comments on commit 357e9e5

Please sign in to comment.