Skip to content

Commit

Permalink
basic ollam support (#395)
Browse files Browse the repository at this point in the history
  • Loading branch information
swuecho authored Dec 1, 2023
1 parent d736f91 commit 8b4ea0b
Showing 1 changed file with 137 additions and 0 deletions.
137 changes: 137 additions & 0 deletions api/chat_main_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ func (h *ChatHandler) chooseChatStreamFn(chat_session sqlc_queries.ChatSession,
isTestChat := isTest(msgs)
isClaude := strings.HasPrefix(model, "claude")
isChatGPT := strings.HasPrefix(model, "gpt")
isOllama := strings.HasPrefix(model, "ollama-")
completionModel := mapset.NewSet[string]()
completionModel.Add(openai.GPT3TextDavinci003)
completionModel.Add(openai.GPT3TextDavinci002)
Expand All @@ -275,6 +276,8 @@ func (h *ChatHandler) chooseChatStreamFn(chat_session sqlc_queries.ChatSession,
chatStreamFn = h.chatStreamTest
} else if isChatGPT {
chatStreamFn = h.chatStream
} else if isOllama {
chatStreamFn = h.chatOllamStram
} else if isCompletion {
chatStreamFn = h.CompletionStream
}
Expand Down Expand Up @@ -686,6 +689,140 @@ func (h *ChatHandler) chatStreamClaude(w http.ResponseWriter, chatSession sqlc_q
return answer, answer_id, false
}

type OllamaResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Response string `json:"response"`
Done bool `json:"done"`
Context []int `json:"context"`
TotalDuration int64 `json:"total_duration"`
LoadDuration int64 `json:"load_duration"`
PromptEvalCount int `json:"prompt_eval_count"`
PromptEvalDuration int64 `json:"prompt_eval_duration"`
EvalCount int `json:"eval_count"`
EvalDuration int64 `json:"eval_duration"`
}

func (h *ChatHandler) chatOllamStram(w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []Message, chatUuid string, regenerate bool) (string, string, bool) {
// set the api key
chatModel, err := h.service.q.ChatModelByName(context.Background(), chatSession.Model)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "get chat model").Error(), err)
return "", "", true
}

// OPENAI_API_KEY

// create a new strings.Builder
// iterate through the messages and format them
// print the user's question
// convert assistant's response to json format
prompt := formatClaudePrompt(chat_compeletion_messages)
// create the json data
jsonData := map[string]interface{}{
"prompt": prompt,
"model": strings.Replace(chatSession.Model, "ollama-", "", 1),
"max_tokens_to_sample": chatSession.MaxTokens,
// "temperature": chatSession.Temperature,
// "stop_sequences": []string{"\n\nHuman:"},
// "stream": true,
}

// convert data to json format
jsonValue, _ := json.Marshal(jsonData)
// create the request
req, err := http.NewRequest("POST", chatModel.Url, bytes.NewBuffer(jsonValue))

if err != nil {
RespondWithError(w, http.StatusInternalServerError, "error.fail_to_make_request", err)
return "", "", true
}

// add headers to the request
apiKey := os.Getenv(chatModel.ApiAuthKey)
authHeaderName := chatModel.ApiAuthHeader
if authHeaderName != "" {
req.Header.Set(authHeaderName, apiKey)
}

req.Header.Set("Content-Type", "application/json")

// set the streaming flag
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")

// create the http client and send the request
client := &http.Client{
Timeout: 2 * time.Minute,
}
resp, err := client.Do(req)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, "error.fail_to_do_request", err)
return "", "", true
}

ioreader := bufio.NewReader(resp.Body)

// read the response body
defer resp.Body.Close()
// loop over the response body and print data

setSSEHeader(w)

flusher, ok := w.(http.Flusher)
if !ok {
RespondWithError(w, http.StatusInternalServerError, "Streaming unsupported!", nil)
return "", "", true
}

var answer string
var answer_id string

if regenerate {
answer_id = chatUuid
}

count := 0
for {
count++
// prevent infinite loop
if count > 10000 {
break
}
line, err := ioreader.ReadBytes('\n')

if err != nil {
return "", "", true
}
var streamResp OllamaResponse
err = json.Unmarshal(line, &streamResp)
if err != nil {
return "", "", true
}
answer += streamResp.Response
if streamResp.Done {
// stream.isFinished = true
fmt.Println("DONE break")
data, _ := json.Marshal(constructChatCompletionStreamReponse(answer_id, answer))
fmt.Fprintf(w, "data: %v\n\n", string(data))
flusher.Flush()
break
}
if answer_id == "" {
answer_id = uuid.NewString()
}

if len(answer) < 200 || len(answer)%2 == 0 {
data, _ := json.Marshal(constructChatCompletionStreamReponse(answer_id, answer))
fmt.Fprintf(w, "data: %v\n\n", string(data))
flusher.Flush()
}
}

return answer, answer_id, false
}

type CustomModelResponse struct {
Completion string `json:"completion"`
Stop string `json:"stop"`
Expand Down

0 comments on commit 8b4ea0b

Please sign in to comment.