Skip to content

Commit

Permalink
works
Browse files Browse the repository at this point in the history
  • Loading branch information
swuecho committed Dec 17, 2023
1 parent 3fb5014 commit 0c79092
Showing 1 changed file with 31 additions and 20 deletions.
51 changes: 31 additions & 20 deletions api/chat_main_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"io"
"log"
"net/http"
"net/http/httputil"
"os"
"strconv"
"strings"
Expand Down Expand Up @@ -698,7 +699,7 @@ type OllamaResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Done bool `json:"done"`
Message Message `json:"message"`
Message Message `json:"message"`
TotalDuration int64 `json:"total_duration"`
LoadDuration int64 `json:"load_duration"`
PromptEvalCount int `json:"prompt_eval_count"`
Expand All @@ -715,7 +716,7 @@ func (h *ChatHandler) chatOllamStram(w http.ResponseWriter, chatSession sqlc_que
return "", "", true
}
jsonData := map[string]interface{}{
"model": strings.Replace(chatSession.Model, "ollama-", "", 1),
"model": strings.Replace(chatSession.Model, "ollama-", "", 1),
"messages": chat_compeletion_messages,
}
// convert data to json format
Expand Down Expand Up @@ -1034,6 +1035,7 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q
}

for i, message := range chat_compeletion_messages {
println(message.Role)
geminiMessage := GeminiMessage{
Role: message.Role,
Parts: []Part{
Expand All @@ -1044,7 +1046,7 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q
if message.Role == "assistant" {
geminiMessage.Role = "model"
} else if message.Role == "system" {
geminiMessage.Role = "model"
geminiMessage.Role = "user"
}

payload.Contents[i] = geminiMessage
Expand All @@ -1057,11 +1059,9 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q
// handle err
return "", "", true
}
body := bytes.NewBuffer(payloadBytes)
url := os.ExpandEnv("https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=$GEMINI_API_KEY")
println(url)
req, err := http.NewRequest("POST", url, body)

req, err := http.NewRequest("POST", url, bytes.NewBuffer(payloadBytes))
if err != nil {
// handle err
fmt.Println("Error while creating request: ", err)
Expand All @@ -1077,6 +1077,20 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q

defer resp.Body.Close()

setSSEHeader(w)

flusher, ok := w.(http.Flusher)
if !ok {
RespondWithError(w, http.StatusInternalServerError, "Streaming unsupported!", nil)
return "", "", true
}
// respDump, err := httputil.DumpResponse(resp, true)
// if err != nil {
// log.Fatal(err)
// }

// fmt.Printf("RESPONSE:\n%s", string(respDump))
// println(resp.Status)
type Content struct {
Parts []struct {
Text string `json:"text"`
Expand Down Expand Up @@ -1111,13 +1125,19 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q
fmt.Println("Failed to parse request body:", err)
}

var answer string
answer_id := chatUuid
if !regenerate {
answer_id = uuid.NewString()
}

// Access the parsed data
for _, candidate := range requestBody.Candidates {
fmt.Println("Content Role:", candidate.Content.Role)
fmt.Println("Finish Reason:", candidate.FinishReason)
fmt.Println("Index:", candidate.Index)
for _, part := range candidate.Content.Parts {
fmt.Println(part.Text)
answer += part.Text
}

for _, safetyRating := range candidate.SafetyRatings {
Expand All @@ -1126,18 +1146,9 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q
}
}

fmt.Println("Prompt Feedback Safety Ratings:")
for _, safetyRating := range requestBody.PromptFeedback.SafetyRatings {
fmt.Println("Safety Category:", safetyRating.Category)
fmt.Println("Safety Probability:", safetyRating.Probability)
}
var answer string
var answer_id string
data, _ := json.Marshal(constructChatCompletionStreamReponse(answer_id, answer))
fmt.Fprintf(w, "data: %v\n\n", string(data))
flusher.Flush()
return answer, answer_id, false

if regenerate {
answer_id = chatUuid
}
print(answer)
print(answer_id)
return "", "", false
}

0 comments on commit 0c79092

Please sign in to comment.