From 87b56787202f9f2bd0eb07d6edaf0c5807c804fc Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Sat, 16 Dec 2023 11:58:32 +0800 Subject: [PATCH 1/3] gemini adapt failed --- api/chat_main_handler.go | 144 ++++++++++++++++++++++++++++++++- api/main.go | 2 - api/middleware_authenticate.go | 2 +- 3 files changed, 144 insertions(+), 4 deletions(-) diff --git a/api/chat_main_handler.go b/api/chat_main_handler.go index 9364f74b..598718ab 100644 --- a/api/chat_main_handler.go +++ b/api/chat_main_handler.go @@ -264,7 +264,10 @@ func (h *ChatHandler) chooseChatStreamFn(chat_session sqlc_queries.ChatSession, isClaude := strings.HasPrefix(model, "claude") isChatGPT := strings.HasPrefix(model, "gpt") isOllama := strings.HasPrefix(model, "ollama-") + isGemini := strings.HasPrefix(model, "gemini") + completionModel := mapset.NewSet[string]() + completionModel.Add(openai.GPT3TextDavinci003) completionModel.Add(openai.GPT3TextDavinci002) isCompletion := completionModel.Contains(model) @@ -280,6 +283,8 @@ func (h *ChatHandler) chooseChatStreamFn(chat_session sqlc_queries.ChatSession, chatStreamFn = h.chatOllamStram } else if isCompletion { chatStreamFn = h.CompletionStream + } else if isGemini { + chatStreamFn = h.chatStreamGemini } return chatStreamFn } @@ -722,7 +727,7 @@ func (h *ChatHandler) chatOllamStram(w http.ResponseWriter, chatSession sqlc_que prompt = formatNeuralChatPrompt(chat_compeletion_messages) } else if chatSession.Model == "ollama-minstral" { prompt = formatMinstralPrompt(chat_compeletion_messages) - } else if chatSession.Model =="ollama-openhermes-neural-chat" { + } else if chatSession.Model == "ollama-openhermes-neural-chat" { prompt = formatNeuralChatPrompt(chat_compeletion_messages) } else { prompt = formatNeuralChatPrompt(chat_compeletion_messages) @@ -1024,3 +1029,140 @@ func constructChatCompletionStreamReponse(answer_id string, answer string) opena } return resp } + +// Generated by curl-to-Go: https://mholt.github.io/curl-to-go + +// curl https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=$API_KEY \ +// -H 'Content-Type: application/json' \ +// -X POST \ +// -d '{ +// "contents": [{ +// "parts":[{ +// "text": "Write a story about a magic backpack."}]}]}' 2> /dev/null + +func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []Message, chatUuid string, regenerate bool) (string, string, bool) { + type Part struct { + Text string `json:"text"` + } + + type GeminiMessage struct { + Role string `json:"role"` + Parts []Part `json:"parts"` + } + + type Payload struct { + Contents []GeminiMessage `json:"contents"` + } + + payload := Payload{ + Contents: make([]GeminiMessage, len(chat_compeletion_messages)), + } + + for i, message := range chat_compeletion_messages { + geminiMessage := GeminiMessage{ + Role: message.Role, + Parts: []Part{ + {Text: message.Content}, + }, + } + + if message.Role == "assistant" { + geminiMessage.Role = "model" + } else if message.Role == "system" { + geminiMessage.Role = "model" + } + + payload.Contents[i] = geminiMessage + } + + payloadBytes, err := json.Marshal(payload) + fmt.Printf("%s\n", string(payloadBytes)) + if err != nil { + fmt.Println("Error marshalling payload:", err) + // handle err + return "", "", true + } + body := bytes.NewBuffer(payloadBytes) + url := os.ExpandEnv("https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=$GEMINI_API_KEY") + println(url) + req, err := http.NewRequest("POST", url, body) + + if err != nil { + // handle err + fmt.Println("Error while creating request: ", err) + RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "post to gemini api").Error(), err) + } + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + // handle err + fmt.Println("Error while do request: ", err) + RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "post to gemni api").Error(), err) + } + + defer resp.Body.Close() + + type Content struct { + Parts []struct { + Text string `json:"text"` + } `json:"parts"` + Role string `json:"role"` + } + + type SafetyRating struct { + Category string `json:"category"` + Probability string `json:"probability"` + } + + type Candidate struct { + Content Content `json:"content"` + FinishReason string `json:"finishReason"` + Index int `json:"index"` + SafetyRatings []SafetyRating `json:"safetyRatings"` + } + + type PromptFeedback struct { + SafetyRatings []SafetyRating `json:"safetyRatings"` + } + + type RequestBody struct { + Candidates []Candidate `json:"candidates"` + PromptFeedback PromptFeedback `json:"promptFeedback"` + } + + var requestBody RequestBody + decoder := json.NewDecoder(resp.Body) + if err := decoder.Decode(&requestBody); err != nil { + fmt.Println("Failed to parse request body:", err) + } + + // Access the parsed data + for _, candidate := range requestBody.Candidates { + fmt.Println("Content Role:", candidate.Content.Role) + fmt.Println("Finish Reason:", candidate.FinishReason) + fmt.Println("Index:", candidate.Index) + for _, part := range candidate.Content.Parts { + fmt.Println(part.Text) + } + + for _, safetyRating := range candidate.SafetyRatings { + fmt.Println("Safety Category:", safetyRating.Category) + fmt.Println("Safety Probability:", safetyRating.Probability) + } + } + + fmt.Println("Prompt Feedback Safety Ratings:") + for _, safetyRating := range requestBody.PromptFeedback.SafetyRatings { + fmt.Println("Safety Category:", safetyRating.Category) + fmt.Println("Safety Probability:", safetyRating.Probability) + } + var answer string + var answer_id string + + if regenerate { + answer_id = chatUuid + } + print(answer) + print(answer_id) + return "", "", false +} diff --git a/api/main.go b/api/main.go index 4ed69a6f..235c6b87 100644 --- a/api/main.go +++ b/api/main.go @@ -240,10 +240,8 @@ func main() { http.Redirect(w, r, "/static/", http.StatusMovedPermanently) }) - router.PathPrefix("/static/").Handler(http.StripPrefix("/static/", makeGzipHandler(cacheHandler))) - // fly.io if os.Getenv("FLY_APP_NAME") != "" { router.Use(UpdateLastRequestTime) diff --git a/api/middleware_authenticate.go b/api/middleware_authenticate.go index 066b33de..24179811 100644 --- a/api/middleware_authenticate.go +++ b/api/middleware_authenticate.go @@ -96,7 +96,7 @@ func IsAuthorizedMiddleware(handler http.Handler) http.Handler { "/": true, "/login": true, "/signup": true, - "/tts": true, + "/tts": true, } jwtSigningKey := []byte(jwtSecretAndAud.Secret) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { From 0c79092888f59a70a02ecc56ddb2de8f37b4b0e1 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Sun, 17 Dec 2023 18:51:39 +0800 Subject: [PATCH 2/3] works --- api/chat_main_handler.go | 51 ++++++++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/api/chat_main_handler.go b/api/chat_main_handler.go index 8d523c36..85a33107 100644 --- a/api/chat_main_handler.go +++ b/api/chat_main_handler.go @@ -11,6 +11,7 @@ import ( "io" "log" "net/http" + "net/http/httputil" "os" "strconv" "strings" @@ -698,7 +699,7 @@ type OllamaResponse struct { Model string `json:"model"` CreatedAt time.Time `json:"created_at"` Done bool `json:"done"` - Message Message `json:"message"` + Message Message `json:"message"` TotalDuration int64 `json:"total_duration"` LoadDuration int64 `json:"load_duration"` PromptEvalCount int `json:"prompt_eval_count"` @@ -715,7 +716,7 @@ func (h *ChatHandler) chatOllamStram(w http.ResponseWriter, chatSession sqlc_que return "", "", true } jsonData := map[string]interface{}{ - "model": strings.Replace(chatSession.Model, "ollama-", "", 1), + "model": strings.Replace(chatSession.Model, "ollama-", "", 1), "messages": chat_compeletion_messages, } // convert data to json format @@ -1034,6 +1035,7 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q } for i, message := range chat_compeletion_messages { + println(message.Role) geminiMessage := GeminiMessage{ Role: message.Role, Parts: []Part{ @@ -1044,7 +1046,7 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q if message.Role == "assistant" { geminiMessage.Role = "model" } else if message.Role == "system" { - geminiMessage.Role = "model" + geminiMessage.Role = "user" } payload.Contents[i] = geminiMessage @@ -1057,11 +1059,9 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q // handle err return "", "", true } - body := bytes.NewBuffer(payloadBytes) url := os.ExpandEnv("https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=$GEMINI_API_KEY") println(url) - req, err := http.NewRequest("POST", url, body) - + req, err := http.NewRequest("POST", url, bytes.NewBuffer(payloadBytes)) if err != nil { // handle err fmt.Println("Error while creating request: ", err) @@ -1077,6 +1077,20 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q defer resp.Body.Close() + setSSEHeader(w) + + flusher, ok := w.(http.Flusher) + if !ok { + RespondWithError(w, http.StatusInternalServerError, "Streaming unsupported!", nil) + return "", "", true + } + // respDump, err := httputil.DumpResponse(resp, true) + // if err != nil { + // log.Fatal(err) + // } + + // fmt.Printf("RESPONSE:\n%s", string(respDump)) + // println(resp.Status) type Content struct { Parts []struct { Text string `json:"text"` @@ -1111,13 +1125,19 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q fmt.Println("Failed to parse request body:", err) } + var answer string + answer_id := chatUuid + if !regenerate { + answer_id = uuid.NewString() + } + // Access the parsed data for _, candidate := range requestBody.Candidates { - fmt.Println("Content Role:", candidate.Content.Role) fmt.Println("Finish Reason:", candidate.FinishReason) fmt.Println("Index:", candidate.Index) for _, part := range candidate.Content.Parts { fmt.Println(part.Text) + answer += part.Text } for _, safetyRating := range candidate.SafetyRatings { @@ -1126,18 +1146,9 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q } } - fmt.Println("Prompt Feedback Safety Ratings:") - for _, safetyRating := range requestBody.PromptFeedback.SafetyRatings { - fmt.Println("Safety Category:", safetyRating.Category) - fmt.Println("Safety Probability:", safetyRating.Probability) - } - var answer string - var answer_id string + data, _ := json.Marshal(constructChatCompletionStreamReponse(answer_id, answer)) + fmt.Fprintf(w, "data: %v\n\n", string(data)) + flusher.Flush() + return answer, answer_id, false - if regenerate { - answer_id = chatUuid - } - print(answer) - print(answer_id) - return "", "", false } From 3bbc3eda8e5acd16491491c3ccf604a6f6f1fa1f Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Sun, 17 Dec 2023 18:57:35 +0800 Subject: [PATCH 3/3] update --- api/chat_main_handler.go | 1 - 1 file changed, 1 deletion(-) diff --git a/api/chat_main_handler.go b/api/chat_main_handler.go index 85a33107..b32b1188 100644 --- a/api/chat_main_handler.go +++ b/api/chat_main_handler.go @@ -11,7 +11,6 @@ import ( "io" "log" "net/http" - "net/http/httputil" "os" "strconv" "strings"