Skip to content

Commit

Permalink
ollama chat api (#407)
Browse files Browse the repository at this point in the history
  • Loading branch information
swuecho authored Dec 17, 2023
1 parent 6527e0a commit 3062bad
Showing 1 changed file with 3 additions and 28 deletions.
31 changes: 3 additions & 28 deletions api/chat_main_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -692,9 +692,8 @@ func (h *ChatHandler) chatStreamClaude(w http.ResponseWriter, chatSession sqlc_q
type OllamaResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Response string `json:"response"`
Done bool `json:"done"`
Context []int `json:"context"`
Message Message `json:"message"`
TotalDuration int64 `json:"total_duration"`
LoadDuration int64 `json:"load_duration"`
PromptEvalCount int `json:"prompt_eval_count"`
Expand All @@ -710,33 +709,10 @@ func (h *ChatHandler) chatOllamStram(w http.ResponseWriter, chatSession sqlc_que
RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "get chat model").Error(), err)
return "", "", true
}

// OPENAI_API_KEY

// create a new strings.Builder
// iterate through the messages and format them
// print the user's question
// convert assistant's response to json format
var prompt string
if chatSession.Model == "ollama-neural-chat" {
prompt = formatNeuralChatPrompt(chat_compeletion_messages)
} else if chatSession.Model == "ollama-minstral" {
prompt = formatMinstralPrompt(chat_compeletion_messages)
} else if chatSession.Model =="ollama-openhermes-neural-chat" {
prompt = formatNeuralChatPrompt(chat_compeletion_messages)
} else {
prompt = formatNeuralChatPrompt(chat_compeletion_messages)
}
// create the json data
jsonData := map[string]interface{}{
"prompt": prompt,
"model": strings.Replace(chatSession.Model, "ollama-", "", 1),
"max_tokens_to_sample": chatSession.MaxTokens,
// "temperature": chatSession.Temperature,
// "stop_sequences": []string{"\n\nHuman:"},
// "stream": true,
"messages": chat_compeletion_messages,
}

// convert data to json format
jsonValue, _ := json.Marshal(jsonData)
// create the request
Expand Down Expand Up @@ -800,7 +776,6 @@ func (h *ChatHandler) chatOllamStram(w http.ResponseWriter, chatSession sqlc_que
break
}
line, err := ioreader.ReadBytes('\n')

if err != nil {
return "", "", true
}
Expand All @@ -809,7 +784,7 @@ func (h *ChatHandler) chatOllamStram(w http.ResponseWriter, chatSession sqlc_que
if err != nil {
return "", "", true
}
answer += strings.ReplaceAll(streamResp.Response, "<0x0A>", "\n")
answer += strings.ReplaceAll(streamResp.Message.Content, "<0x0A>", "\n")
if streamResp.Done {
// stream.isFinished = true
fmt.Println("DONE break")
Expand Down

0 comments on commit 3062bad

Please sign in to comment.