From b8d5cd583ed429913c72e8c4aaf6a8cb0d4a38cb Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Sun, 17 Dec 2023 19:01:11 +0800 Subject: [PATCH] gemini pro ok (#406) * gemini adapt failed * works * update --- api/chat_main_handler.go | 156 ++++++++++++++++++++++++++++++++- api/main.go | 2 - api/middleware_authenticate.go | 2 +- 3 files changed, 155 insertions(+), 5 deletions(-) diff --git a/api/chat_main_handler.go b/api/chat_main_handler.go index 60ec4e1d..b32b1188 100644 --- a/api/chat_main_handler.go +++ b/api/chat_main_handler.go @@ -264,7 +264,10 @@ func (h *ChatHandler) chooseChatStreamFn(chat_session sqlc_queries.ChatSession, isClaude := strings.HasPrefix(model, "claude") isChatGPT := strings.HasPrefix(model, "gpt") isOllama := strings.HasPrefix(model, "ollama-") + isGemini := strings.HasPrefix(model, "gemini") + completionModel := mapset.NewSet[string]() + completionModel.Add(openai.GPT3TextDavinci003) completionModel.Add(openai.GPT3TextDavinci002) isCompletion := completionModel.Contains(model) @@ -280,6 +283,8 @@ func (h *ChatHandler) chooseChatStreamFn(chat_session sqlc_queries.ChatSession, chatStreamFn = h.chatOllamStram } else if isCompletion { chatStreamFn = h.CompletionStream + } else if isGemini { + chatStreamFn = h.chatStreamGemini } return chatStreamFn } @@ -693,7 +698,7 @@ type OllamaResponse struct { Model string `json:"model"` CreatedAt time.Time `json:"created_at"` Done bool `json:"done"` - Message Message `json:"message"` + Message Message `json:"message"` TotalDuration int64 `json:"total_duration"` LoadDuration int64 `json:"load_duration"` PromptEvalCount int `json:"prompt_eval_count"` @@ -710,7 +715,7 @@ func (h *ChatHandler) chatOllamStram(w http.ResponseWriter, chatSession sqlc_que return "", "", true } jsonData := map[string]interface{}{ - "model": strings.Replace(chatSession.Model, "ollama-", "", 1), + "model": strings.Replace(chatSession.Model, "ollama-", "", 1), "messages": chat_compeletion_messages, } // convert data to json format @@ -999,3 +1004,150 @@ func constructChatCompletionStreamReponse(answer_id string, answer string) opena } return resp } + +// Generated by curl-to-Go: https://mholt.github.io/curl-to-go + +// curl https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=$API_KEY \ +// -H 'Content-Type: application/json' \ +// -X POST \ +// -d '{ +// "contents": [{ +// "parts":[{ +// "text": "Write a story about a magic backpack."}]}]}' 2> /dev/null + +func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []Message, chatUuid string, regenerate bool) (string, string, bool) { + type Part struct { + Text string `json:"text"` + } + + type GeminiMessage struct { + Role string `json:"role"` + Parts []Part `json:"parts"` + } + + type Payload struct { + Contents []GeminiMessage `json:"contents"` + } + + payload := Payload{ + Contents: make([]GeminiMessage, len(chat_compeletion_messages)), + } + + for i, message := range chat_compeletion_messages { + println(message.Role) + geminiMessage := GeminiMessage{ + Role: message.Role, + Parts: []Part{ + {Text: message.Content}, + }, + } + + if message.Role == "assistant" { + geminiMessage.Role = "model" + } else if message.Role == "system" { + geminiMessage.Role = "user" + } + + payload.Contents[i] = geminiMessage + } + + payloadBytes, err := json.Marshal(payload) + fmt.Printf("%s\n", string(payloadBytes)) + if err != nil { + fmt.Println("Error marshalling payload:", err) + // handle err + return "", "", true + } + url := os.ExpandEnv("https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=$GEMINI_API_KEY") + println(url) + req, err := http.NewRequest("POST", url, bytes.NewBuffer(payloadBytes)) + if err != nil { + // handle err + fmt.Println("Error while creating request: ", err) + RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "post to gemini api").Error(), err) + } + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + // handle err + fmt.Println("Error while do request: ", err) + RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "post to gemni api").Error(), err) + } + + defer resp.Body.Close() + + setSSEHeader(w) + + flusher, ok := w.(http.Flusher) + if !ok { + RespondWithError(w, http.StatusInternalServerError, "Streaming unsupported!", nil) + return "", "", true + } + // respDump, err := httputil.DumpResponse(resp, true) + // if err != nil { + // log.Fatal(err) + // } + + // fmt.Printf("RESPONSE:\n%s", string(respDump)) + // println(resp.Status) + type Content struct { + Parts []struct { + Text string `json:"text"` + } `json:"parts"` + Role string `json:"role"` + } + + type SafetyRating struct { + Category string `json:"category"` + Probability string `json:"probability"` + } + + type Candidate struct { + Content Content `json:"content"` + FinishReason string `json:"finishReason"` + Index int `json:"index"` + SafetyRatings []SafetyRating `json:"safetyRatings"` + } + + type PromptFeedback struct { + SafetyRatings []SafetyRating `json:"safetyRatings"` + } + + type RequestBody struct { + Candidates []Candidate `json:"candidates"` + PromptFeedback PromptFeedback `json:"promptFeedback"` + } + + var requestBody RequestBody + decoder := json.NewDecoder(resp.Body) + if err := decoder.Decode(&requestBody); err != nil { + fmt.Println("Failed to parse request body:", err) + } + + var answer string + answer_id := chatUuid + if !regenerate { + answer_id = uuid.NewString() + } + + // Access the parsed data + for _, candidate := range requestBody.Candidates { + fmt.Println("Finish Reason:", candidate.FinishReason) + fmt.Println("Index:", candidate.Index) + for _, part := range candidate.Content.Parts { + fmt.Println(part.Text) + answer += part.Text + } + + for _, safetyRating := range candidate.SafetyRatings { + fmt.Println("Safety Category:", safetyRating.Category) + fmt.Println("Safety Probability:", safetyRating.Probability) + } + } + + data, _ := json.Marshal(constructChatCompletionStreamReponse(answer_id, answer)) + fmt.Fprintf(w, "data: %v\n\n", string(data)) + flusher.Flush() + return answer, answer_id, false + +} diff --git a/api/main.go b/api/main.go index 4ed69a6f..235c6b87 100644 --- a/api/main.go +++ b/api/main.go @@ -240,10 +240,8 @@ func main() { http.Redirect(w, r, "/static/", http.StatusMovedPermanently) }) - router.PathPrefix("/static/").Handler(http.StripPrefix("/static/", makeGzipHandler(cacheHandler))) - // fly.io if os.Getenv("FLY_APP_NAME") != "" { router.Use(UpdateLastRequestTime) diff --git a/api/middleware_authenticate.go b/api/middleware_authenticate.go index 066b33de..24179811 100644 --- a/api/middleware_authenticate.go +++ b/api/middleware_authenticate.go @@ -96,7 +96,7 @@ func IsAuthorizedMiddleware(handler http.Handler) http.Handler { "/": true, "/login": true, "/signup": true, - "/tts": true, + "/tts": true, } jwtSigningKey := []byte(jwtSecretAndAud.Secret) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {