Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gemini pro ok #406

Merged
merged 4 commits into from
Dec 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 154 additions & 2 deletions api/chat_main_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,10 @@ func (h *ChatHandler) chooseChatStreamFn(chat_session sqlc_queries.ChatSession,
isClaude := strings.HasPrefix(model, "claude")
isChatGPT := strings.HasPrefix(model, "gpt")
isOllama := strings.HasPrefix(model, "ollama-")
isGemini := strings.HasPrefix(model, "gemini")

completionModel := mapset.NewSet[string]()

completionModel.Add(openai.GPT3TextDavinci003)
completionModel.Add(openai.GPT3TextDavinci002)
isCompletion := completionModel.Contains(model)
Expand All @@ -280,6 +283,8 @@ func (h *ChatHandler) chooseChatStreamFn(chat_session sqlc_queries.ChatSession,
chatStreamFn = h.chatOllamStram
} else if isCompletion {
chatStreamFn = h.CompletionStream
} else if isGemini {
chatStreamFn = h.chatStreamGemini
}
return chatStreamFn
}
Expand Down Expand Up @@ -693,7 +698,7 @@ type OllamaResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Done bool `json:"done"`
Message Message `json:"message"`
Message Message `json:"message"`
TotalDuration int64 `json:"total_duration"`
LoadDuration int64 `json:"load_duration"`
PromptEvalCount int `json:"prompt_eval_count"`
Expand All @@ -710,7 +715,7 @@ func (h *ChatHandler) chatOllamStram(w http.ResponseWriter, chatSession sqlc_que
return "", "", true
}
jsonData := map[string]interface{}{
"model": strings.Replace(chatSession.Model, "ollama-", "", 1),
"model": strings.Replace(chatSession.Model, "ollama-", "", 1),
"messages": chat_compeletion_messages,
}
// convert data to json format
Expand Down Expand Up @@ -999,3 +1004,150 @@ func constructChatCompletionStreamReponse(answer_id string, answer string) opena
}
return resp
}

// Generated by curl-to-Go: https://mholt.github.io/curl-to-go

// curl https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=$API_KEY \
// -H 'Content-Type: application/json' \
// -X POST \
// -d '{
// "contents": [{
// "parts":[{
// "text": "Write a story about a magic backpack."}]}]}' 2> /dev/null

// chatStreamGemini sends the accumulated conversation to Google's Gemini
// "gemini-pro:generateContent" REST endpoint and writes the aggregated answer
// back to the client as a single SSE data frame.
//
// It returns (answer, answerID, shouldReturn); shouldReturn is true when an
// error occurred and the caller must stop processing this request.
//
// NOTE(review): despite the "Stream" name this calls the non-streaming
// generateContent API, so the whole answer arrives in one SSE event.
func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []Message, chatUuid string, regenerate bool) (string, string, bool) {
	// Request payload types mirroring the Gemini generateContent schema.
	type Part struct {
		Text string `json:"text"`
	}

	type GeminiMessage struct {
		Role  string `json:"role"`
		Parts []Part `json:"parts"`
	}

	type Payload struct {
		Contents []GeminiMessage `json:"contents"`
	}

	payload := Payload{
		Contents: make([]GeminiMessage, len(chat_compeletion_messages)),
	}

	for i, message := range chat_compeletion_messages {
		// Gemini only accepts the roles "user" and "model": map OpenAI-style
		// "assistant" to "model" and fold "system" prompts into "user".
		role := message.Role
		switch role {
		case "assistant":
			role = "model"
		case "system":
			role = "user"
		}
		payload.Contents[i] = GeminiMessage{
			Role: role,
			Parts: []Part{
				{Text: message.Content},
			},
		}
	}

	payloadBytes, err := json.Marshal(payload)
	if err != nil {
		RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "marshal gemini payload").Error(), err)
		return "", "", true
	}

	// The API key is injected from the environment; never log the expanded
	// URL — it contains the credential.
	url := os.ExpandEnv("https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=$GEMINI_API_KEY")
	req, err := http.NewRequest("POST", url, bytes.NewBuffer(payloadBytes))
	if err != nil {
		RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "post to gemini api").Error(), err)
		return "", "", true
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		// Must return here: resp is nil on error, so falling through would
		// panic on resp.Body.Close().
		RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "post to gemini api").Error(), err)
		return "", "", true
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		RespondWithError(w, http.StatusInternalServerError, "gemini api returned status "+resp.Status, nil)
		return "", "", true
	}

	setSSEHeader(w)

	flusher, ok := w.(http.Flusher)
	if !ok {
		RespondWithError(w, http.StatusInternalServerError, "Streaming unsupported!", nil)
		return "", "", true
	}

	// Response types mirroring the Gemini generateContent response schema.
	type Content struct {
		Parts []struct {
			Text string `json:"text"`
		} `json:"parts"`
		Role string `json:"role"`
	}

	type SafetyRating struct {
		Category    string `json:"category"`
		Probability string `json:"probability"`
	}

	type Candidate struct {
		Content       Content        `json:"content"`
		FinishReason  string         `json:"finishReason"`
		Index         int            `json:"index"`
		SafetyRatings []SafetyRating `json:"safetyRatings"`
	}

	type PromptFeedback struct {
		SafetyRatings []SafetyRating `json:"safetyRatings"`
	}

	// Named ResponseBody (the original "RequestBody" was a misnomer — this
	// decodes the API *response*).
	type ResponseBody struct {
		Candidates     []Candidate    `json:"candidates"`
		PromptFeedback PromptFeedback `json:"promptFeedback"`
	}

	var responseBody ResponseBody
	if err := json.NewDecoder(resp.Body).Decode(&responseBody); err != nil {
		RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "decode gemini response").Error(), err)
		return "", "", true
	}

	// Reuse the original answer id when regenerating so the client replaces
	// the prior message instead of appending a new one.
	var answer string
	answer_id := chatUuid
	if !regenerate {
		answer_id = uuid.NewString()
	}

	// Concatenate every text part from every candidate into one answer.
	for _, candidate := range responseBody.Candidates {
		for _, part := range candidate.Content.Parts {
			answer += part.Text
		}
	}

	data, _ := json.Marshal(constructChatCompletionStreamReponse(answer_id, answer))
	fmt.Fprintf(w, "data: %v\n\n", string(data))
	flusher.Flush()
	return answer, answer_id, false
}
2 changes: 0 additions & 2 deletions api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,8 @@ func main() {
http.Redirect(w, r, "/static/", http.StatusMovedPermanently)
})


router.PathPrefix("/static/").Handler(http.StripPrefix("/static/", makeGzipHandler(cacheHandler)))


// fly.io
if os.Getenv("FLY_APP_NAME") != "" {
router.Use(UpdateLastRequestTime)
Expand Down
2 changes: 1 addition & 1 deletion api/middleware_authenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func IsAuthorizedMiddleware(handler http.Handler) http.Handler {
"/": true,
"/login": true,
"/signup": true,
"/tts": true,
"/tts": true,
}
jwtSigningKey := []byte(jwtSecretAndAud.Secret)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down