Skip to content

Commit

Permalink
gemini pro ok (#406)
Browse files Browse the repository at this point in the history
* gemini adapt failed

* works

* update
  • Loading branch information
swuecho authored Dec 17, 2023
1 parent 3062bad commit b8d5cd5
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 5 deletions.
156 changes: 154 additions & 2 deletions api/chat_main_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,10 @@ func (h *ChatHandler) chooseChatStreamFn(chat_session sqlc_queries.ChatSession,
isClaude := strings.HasPrefix(model, "claude")
isChatGPT := strings.HasPrefix(model, "gpt")
isOllama := strings.HasPrefix(model, "ollama-")
isGemini := strings.HasPrefix(model, "gemini")

completionModel := mapset.NewSet[string]()

completionModel.Add(openai.GPT3TextDavinci003)
completionModel.Add(openai.GPT3TextDavinci002)
isCompletion := completionModel.Contains(model)
Expand All @@ -280,6 +283,8 @@ func (h *ChatHandler) chooseChatStreamFn(chat_session sqlc_queries.ChatSession,
chatStreamFn = h.chatOllamStram
} else if isCompletion {
chatStreamFn = h.CompletionStream
} else if isGemini {
chatStreamFn = h.chatStreamGemini
}
return chatStreamFn
}
Expand Down Expand Up @@ -693,7 +698,7 @@ type OllamaResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Done bool `json:"done"`
Message Message `json:"message"`
Message Message `json:"message"`
TotalDuration int64 `json:"total_duration"`
LoadDuration int64 `json:"load_duration"`
PromptEvalCount int `json:"prompt_eval_count"`
Expand All @@ -710,7 +715,7 @@ func (h *ChatHandler) chatOllamStram(w http.ResponseWriter, chatSession sqlc_que
return "", "", true
}
jsonData := map[string]interface{}{
"model": strings.Replace(chatSession.Model, "ollama-", "", 1),
"model": strings.Replace(chatSession.Model, "ollama-", "", 1),
"messages": chat_compeletion_messages,
}
// convert data to json format
Expand Down Expand Up @@ -999,3 +1004,150 @@ func constructChatCompletionStreamReponse(answer_id string, answer string) opena
}
return resp
}

// Generated by curl-to-Go: https://mholt.github.io/curl-to-go

// curl https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=$API_KEY \
// -H 'Content-Type: application/json' \
// -X POST \
// -d '{
// "contents": [{
// "parts":[{
// "text": "Write a story about a magic backpack."}]}]}' 2> /dev/null

func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []Message, chatUuid string, regenerate bool) (string, string, bool) {
type Part struct {
Text string `json:"text"`
}

type GeminiMessage struct {
Role string `json:"role"`
Parts []Part `json:"parts"`
}

type Payload struct {
Contents []GeminiMessage `json:"contents"`
}

payload := Payload{
Contents: make([]GeminiMessage, len(chat_compeletion_messages)),
}

for i, message := range chat_compeletion_messages {
println(message.Role)
geminiMessage := GeminiMessage{
Role: message.Role,
Parts: []Part{
{Text: message.Content},
},
}

if message.Role == "assistant" {
geminiMessage.Role = "model"
} else if message.Role == "system" {
geminiMessage.Role = "user"
}

payload.Contents[i] = geminiMessage
}

payloadBytes, err := json.Marshal(payload)
fmt.Printf("%s\n", string(payloadBytes))
if err != nil {
fmt.Println("Error marshalling payload:", err)
// handle err
return "", "", true
}
url := os.ExpandEnv("https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=$GEMINI_API_KEY")
println(url)
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payloadBytes))
if err != nil {
// handle err
fmt.Println("Error while creating request: ", err)
RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "post to gemini api").Error(), err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
// handle err
fmt.Println("Error while do request: ", err)
RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "post to gemni api").Error(), err)
}

defer resp.Body.Close()

setSSEHeader(w)

flusher, ok := w.(http.Flusher)
if !ok {
RespondWithError(w, http.StatusInternalServerError, "Streaming unsupported!", nil)
return "", "", true
}
// respDump, err := httputil.DumpResponse(resp, true)
// if err != nil {
// log.Fatal(err)
// }

// fmt.Printf("RESPONSE:\n%s", string(respDump))
// println(resp.Status)
type Content struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
Role string `json:"role"`
}

type SafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}

type Candidate struct {
Content Content `json:"content"`
FinishReason string `json:"finishReason"`
Index int `json:"index"`
SafetyRatings []SafetyRating `json:"safetyRatings"`
}

type PromptFeedback struct {
SafetyRatings []SafetyRating `json:"safetyRatings"`
}

type RequestBody struct {
Candidates []Candidate `json:"candidates"`
PromptFeedback PromptFeedback `json:"promptFeedback"`
}

var requestBody RequestBody
decoder := json.NewDecoder(resp.Body)
if err := decoder.Decode(&requestBody); err != nil {
fmt.Println("Failed to parse request body:", err)
}

var answer string
answer_id := chatUuid
if !regenerate {
answer_id = uuid.NewString()
}

// Access the parsed data
for _, candidate := range requestBody.Candidates {
fmt.Println("Finish Reason:", candidate.FinishReason)
fmt.Println("Index:", candidate.Index)
for _, part := range candidate.Content.Parts {
fmt.Println(part.Text)
answer += part.Text
}

for _, safetyRating := range candidate.SafetyRatings {
fmt.Println("Safety Category:", safetyRating.Category)
fmt.Println("Safety Probability:", safetyRating.Probability)
}
}

data, _ := json.Marshal(constructChatCompletionStreamReponse(answer_id, answer))
fmt.Fprintf(w, "data: %v\n\n", string(data))
flusher.Flush()
return answer, answer_id, false

}
2 changes: 0 additions & 2 deletions api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,8 @@ func main() {
http.Redirect(w, r, "/static/", http.StatusMovedPermanently)
})


router.PathPrefix("/static/").Handler(http.StripPrefix("/static/", makeGzipHandler(cacheHandler)))


// fly.io
if os.Getenv("FLY_APP_NAME") != "" {
router.Use(UpdateLastRequestTime)
Expand Down
2 changes: 1 addition & 1 deletion api/middleware_authenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func IsAuthorizedMiddleware(handler http.Handler) http.Handler {
"/": true,
"/login": true,
"/signup": true,
"/tts": true,
"/tts": true,
}
jwtSigningKey := []byte(jwtSecretAndAud.Secret)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down

0 comments on commit b8d5cd5

Please sign in to comment.