diff --git a/api/chat_main_handler.go b/api/chat_main_handler.go index 8d523c36..85a33107 100644 --- a/api/chat_main_handler.go +++ b/api/chat_main_handler.go @@ -11,6 +11,7 @@ import ( "io" "log" "net/http" + "net/http/httputil" "os" "strconv" "strings" @@ -698,7 +699,7 @@ type OllamaResponse struct { Model string `json:"model"` CreatedAt time.Time `json:"created_at"` Done bool `json:"done"` - Message Message `json:"message"` + Message Message `json:"message"` TotalDuration int64 `json:"total_duration"` LoadDuration int64 `json:"load_duration"` PromptEvalCount int `json:"prompt_eval_count"` @@ -715,7 +716,7 @@ func (h *ChatHandler) chatOllamStram(w http.ResponseWriter, chatSession sqlc_que return "", "", true } jsonData := map[string]interface{}{ - "model": strings.Replace(chatSession.Model, "ollama-", "", 1), + "model": strings.Replace(chatSession.Model, "ollama-", "", 1), "messages": chat_compeletion_messages, } // convert data to json format @@ -1034,6 +1035,7 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q } for i, message := range chat_compeletion_messages { + println(message.Role) geminiMessage := GeminiMessage{ Role: message.Role, Parts: []Part{ @@ -1044,7 +1046,7 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q if message.Role == "assistant" { geminiMessage.Role = "model" } else if message.Role == "system" { - geminiMessage.Role = "model" + geminiMessage.Role = "user" } payload.Contents[i] = geminiMessage @@ -1057,11 +1059,9 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q // handle err return "", "", true } - body := bytes.NewBuffer(payloadBytes) url := os.ExpandEnv("https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=$GEMINI_API_KEY") println(url) - req, err := http.NewRequest("POST", url, body) - + req, err := http.NewRequest("POST", url, bytes.NewBuffer(payloadBytes)) if err != nil { // handle err fmt.Println("Error while creating request: ", err) @@ -1077,6 +1077,20 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q defer resp.Body.Close() + setSSEHeader(w) + + flusher, ok := w.(http.Flusher) + if !ok { + RespondWithError(w, http.StatusInternalServerError, "Streaming unsupported!", nil) + return "", "", true + } + // respDump, err := httputil.DumpResponse(resp, true) + // if err != nil { + // log.Fatal(err) + // } + + // fmt.Printf("RESPONSE:\n%s", string(respDump)) + // println(resp.Status) type Content struct { Parts []struct { Text string `json:"text"` @@ -1111,13 +1125,19 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q fmt.Println("Failed to parse request body:", err) } + var answer string + answer_id := chatUuid + if !regenerate { + answer_id = uuid.NewString() + } + // Access the parsed data for _, candidate := range requestBody.Candidates { - fmt.Println("Content Role:", candidate.Content.Role) fmt.Println("Finish Reason:", candidate.FinishReason) fmt.Println("Index:", candidate.Index) for _, part := range candidate.Content.Parts { fmt.Println(part.Text) + answer += part.Text } for _, safetyRating := range candidate.SafetyRatings { @@ -1126,18 +1146,9 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q } } - fmt.Println("Prompt Feedback Safety Ratings:") - for _, safetyRating := range requestBody.PromptFeedback.SafetyRatings { - fmt.Println("Safety Category:", safetyRating.Category) - fmt.Println("Safety Probability:", safetyRating.Probability) - } - var answer string - var answer_id string + data, _ := json.Marshal(constructChatCompletionStreamReponse(answer_id, answer)) + fmt.Fprintf(w, "data: %v\n\n", string(data)) + flusher.Flush() + return answer, answer_id, false - if regenerate { - answer_id = chatUuid - } - print(answer) - print(answer_id) - return "", "", false }