// Model names accepted by OpenAIEmbeddingClient.Check.
const (
	TextEmbeddingAda002 string = "text-embedding-ada-002"
	TextEmbedding3Small string = "text-embedding-3-small"
	TextEmbedding3Large string = "text-embedding-3-large"
)

// EmbeddingRequest is the JSON request body built by
// OpenAIEmbeddingClient.Embedding.
type EmbeddingRequest struct {
	// ID of the model to use.
	Model string `json:"model"`

	// Input texts to embed, encoded as a list of strings.
	Input []string `json:"input"`

	// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
	User string `json:"user,omitempty"`

	// The format to return the embeddings in. Can be either float or base64.
	EncodingFormat string `json:"encoding_format,omitempty"`

	// The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
	Dimensions int `json:"dimensions,omitempty"`
}

// Usage carries the token accounting attached to an EmbeddingResponse.
type Usage struct {
	// The number of tokens used by the prompt.
	PromptTokens int `json:"prompt_tokens"`

	// The total number of tokens used by the request.
	TotalTokens int `json:"total_tokens"`
}

// EmbeddingData is one entry of EmbeddingResponse.Data.
type EmbeddingData struct {
	// The object type, which is always "embedding".
	Object string `json:"object"`

	// The embedding vector, which is a list of floats.
	Embedding []float32 `json:"embedding"`

	// The index of the embedding in the list of embeddings.
	Index int `json:"index"`
}

// EmbeddingResponse is the JSON body decoded from an HTTP 200 reply.
type EmbeddingResponse struct {
	// The object type, which is always "list".
	Object string `json:"object"`

	// The list of embeddings generated by the model.
	Data []EmbeddingData `json:"data"`

	// The name of the model used to generate the embedding.
	Model string `json:"model"`

	// The usage information for the request.
	Usage Usage `json:"usage"`
}

// ErrorInfo mirrors the fields of the "error" object in an error payload.
type ErrorInfo struct {
	Code    string `json:"code"`
	Message string `json:"message"`
	Param   string `json:"param,omitempty"`
	Type    string `json:"type"`
}

// EmbedddingError is the JSON envelope around ErrorInfo. The triple "d" in
// the name is a typo that is kept so existing references keep compiling.
type EmbedddingError struct {
	Error ErrorInfo `json:"error"`
}

// OpenAIEmbeddingClient posts embedding requests to a configured endpoint.
type OpenAIEmbeddingClient struct {
	api_key    string // sent in the "api-key" request header
	uri        string // URL the request is POSTed to
	model_name string // must be one of the TextEmbedding* constants
}

// Check validates the client configuration. It returns an error when the
// model is not one of the supported constants, or when the API key or the
// URI is empty; otherwise it returns nil.
func (c *OpenAIEmbeddingClient) Check() error {
	switch c.model_name {
	case TextEmbeddingAda002, TextEmbedding3Small, TextEmbedding3Large:
		// supported
	default:
		return fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]",
			c.model_name, TextEmbeddingAda002, TextEmbedding3Small, TextEmbedding3Large)
	}

	if c.api_key == "" {
		return fmt.Errorf("OpenAI api key is empty")
	}

	if c.uri == "" {
		return fmt.Errorf("OpenAI embedding uri is empty")
	}
	return nil
}

// send performs a single HTTP round trip and decodes a 200 response into res.
// Any other status is reported as an error whose message is the response body
// verbatim, or the HTTP status line when the body is empty.
func (c *OpenAIEmbeddingClient) send(client *http.Client, req *http.Request, res *EmbeddingResponse) error {
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return err
	}

	if resp.StatusCode != http.StatusOK {
		if len(body) == 0 {
			return fmt.Errorf("openai embedding request failed: %s", resp.Status)
		}
		// The body is data, never a format string: a '%' in the server's
		// message must reach the caller unchanged.
		return fmt.Errorf("%s", body)
	}

	return json.Unmarshal(body, res)
}

// sendWithRetry calls send up to max_retries times (at least once) and
// returns nil on the first success, otherwise the last error.
//
// The request body is rewound through req.GetBody before every retry, because
// the previous attempt has already consumed it. Timeout errors are returned
// immediately: the caller's timeout is the time budget for the call, and
// retrying would multiply it.
func (c *OpenAIEmbeddingClient) sendWithRetry(client *http.Client, req *http.Request, res *EmbeddingResponse, max_retries int) error {
	if max_retries < 1 {
		max_retries = 1
	}

	var err error
	for i := 0; i < max_retries; i++ {
		if i > 0 && req.GetBody != nil {
			body, bodyErr := req.GetBody()
			if bodyErr != nil {
				return bodyErr
			}
			req.Body = body
		}

		if err = c.send(client, req, res); err == nil {
			return nil
		}
		if te, ok := err.(interface{ Timeout() bool }); ok && te.Timeout() {
			return err
		}
	}
	return err
}

// Embedding requests embeddings for texts and returns the decoded response.
//
//   - dim: requested output dimensions; 0 omits the field from the request.
//   - user: end-user identifier; "" omits the field from the request.
//   - timeout_sec: HTTP client timeout as a NUMBER OF SECONDS (the value is
//     multiplied by time.Second despite its Duration type); values <= 0 fall
//     back to 30.
//
// On failure the returned response is the zero value and the error is non-nil.
func (c *OpenAIEmbeddingClient) Embedding(texts []string, dim int, user string, timeout_sec time.Duration) (EmbeddingResponse, error) {
	// User and Dimensions are tagged omitempty, so zero values are dropped
	// from the JSON without explicit checks.
	payload, err := json.Marshal(EmbeddingRequest{
		Model:          c.model_name,
		Input:          texts,
		User:           user,
		EncodingFormat: "float",
		Dimensions:     dim,
	})
	if err != nil {
		return EmbeddingResponse{}, err
	}

	// Largest seconds count whose product with time.Second still fits in a
	// Duration; larger inputs would overflow into a negative timeout.
	const maxTimeoutSec = time.Duration(1<<63-1) / time.Second
	switch {
	case timeout_sec <= 0:
		timeout_sec = 30
	case timeout_sec > maxTimeoutSec:
		timeout_sec = maxTimeoutSec
	}
	client := &http.Client{
		Timeout: timeout_sec * time.Second,
	}

	req, err := http.NewRequest(http.MethodPost, c.uri, bytes.NewReader(payload))
	if err != nil {
		return EmbeddingResponse{}, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("api-key", c.api_key)

	var res EmbeddingResponse
	if err := c.sendWithRetry(client, req, &res, 3); err != nil {
		return EmbeddingResponse{}, err
	}
	return res, nil
}
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package models + +import ( + // "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestEmbeddingClientCheck(t *testing.T) { + { + c := OpenAIEmbeddingClient{"mock_key", "mock_uri", "unknow_model"} + err := c.Check(); + assert.True(t, err != nil) + fmt.Println(err) + } + + { + c := OpenAIEmbeddingClient{"", "mock_uri", TextEmbeddingAda002} + err := c.Check(); + assert.True(t, err != nil) + fmt.Println(err) + } + + { + c := OpenAIEmbeddingClient{"mock_key", "", TextEmbedding3Small} + err := c.Check(); + assert.True(t, err != nil) + fmt.Println(err) + } + + { + c := OpenAIEmbeddingClient{"mock_key", "mock_uri", TextEmbedding3Small} + err := c.Check(); + assert.True(t, err == nil) + } +} + + +func TestEmbeddingOK(t *testing.T) { + var res EmbeddingResponse + res.Object = "list" + res.Model = TextEmbedding3Small + res.Data = []EmbeddingData{ + { + Object: "embedding", + Embedding: []float32{1.1, 2.2, 3.3, 4.4}, + Index: 0, + }, + } + res.Usage = Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + url := ts.URL + + { + c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small} + err := c.Check(); + assert.True(t, err == nil) + ret, err := c.Embedding([]string{"sentence"}, 0, "", 0) + assert.True(t, err == nil) + assert.Equal(t, ret, res) + } +} + + +func 
TestEmbeddingRetry(t *testing.T) { + var res EmbeddingResponse + res.Object = "list" + res.Model = TextEmbedding3Small + res.Data = []EmbeddingData{ + { + Object: "embedding", + Embedding: []float32{1.1, 2.2, 3.3, 4.4}, + Index: 0, + }, + } + res.Usage = Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + + var count = 0 + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if count < 2 { + count += 1 + w.WriteHeader(http.StatusUnauthorized) + } else { + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + } + })) + + defer ts.Close() + url := ts.URL + + { + c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small} + err := c.Check(); + assert.True(t, err == nil) + ret, err := c.Embedding([]string{"sentence"}, 0, "", 0) + assert.True(t, err == nil) + assert.Equal(t, ret, res) + assert.Equal(t, count, 2) + } +} + + +func TestEmbeddingFailed(t *testing.T) { + var count = 0 + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count += 1 + w.WriteHeader(http.StatusUnauthorized) + })) + + defer ts.Close() + url := ts.URL + + { + c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small} + err := c.Check(); + assert.True(t, err == nil) + _, err = c.Embedding([]string{"sentence"}, 0, "", 0) + assert.True(t, err != nil) + assert.Equal(t, count, 3) + } +} + +func TestTimeout(t *testing.T) { + var st = "Doing" + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(3 * time.Second) + st = "Done" + w.WriteHeader(http.StatusUnauthorized) + + })) + + defer ts.Close() + url := ts.URL + + { + c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small} + err := c.Check(); + assert.True(t, err == nil) + _, err = c.Embedding([]string{"sentence"}, 0, "", 1) + assert.True(t, err != nil) + assert.Equal(t, st, "Doing") + time.Sleep(3 * time.Second) + assert.Equal(t, st, "Done") + } +}