Skip to content

Commit

Permalink
Add vertexai
Browse files Browse the repository at this point in the history
Signed-off-by: junjie.jiang <[email protected]>
  • Loading branch information
junjiejiangjjj committed Dec 18, 2024
1 parent e591352 commit e7e5e7c
Show file tree
Hide file tree
Showing 21 changed files with 892 additions and 57 deletions.
6 changes: 3 additions & 3 deletions internal/models/openai/openai_embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,14 @@ func (c *AzureOpenAIEmbeddingClient) Embedding(modelName string, texts []string,
params.Add("api-version", c.apiVersion)
base.RawQuery = params.Encode()

ctx, cancel := context.WithTimeout(context.Background(), timeoutSec)
ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, base.String(), bytes.NewBuffer(data))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
req.Header.Set("api-key", c.apiKey)
body, err := utils.RetrySend(req, 3)
if err != nil {
return nil, err
Expand Down
62 changes: 60 additions & 2 deletions internal/models/openai/openai_embedding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,20 @@ func TestEmbeddingOK(t *testing.T) {
}

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
if r.URL.Path == "/" {
if r.Header["Authorization"][0] != "" {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusBadRequest)
}
} else {
if r.Header["Api-Key"][0] != "" {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusBadRequest)
}
}

data, _ := json.Marshal(res)
w.Write(data)
}))
Expand All @@ -84,7 +97,15 @@ func TestEmbeddingOK(t *testing.T) {
c := NewOpenAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
_, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
assert.True(t, err == nil)
assert.Equal(t, ret.Data[0].Index, 0)
assert.Equal(t, ret.Data[1].Index, 1)
}
{
c := NewAzureOpenAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
assert.True(t, err == nil)
assert.Equal(t, ret.Data[0].Index, 0)
Expand Down Expand Up @@ -148,6 +169,20 @@ func TestEmbeddingRetry(t *testing.T) {
assert.Equal(t, ret.Data[2], res.Data[0])
assert.Equal(t, atomic.LoadInt32(&count), int32(2))
}
{
c := NewAzureOpenAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
assert.True(t, err == nil)
assert.Equal(t, ret.Usage, res.Usage)
assert.Equal(t, ret.Object, res.Object)
assert.Equal(t, ret.Model, res.Model)
assert.Equal(t, ret.Data[0], res.Data[1])
assert.Equal(t, ret.Data[1], res.Data[2])
assert.Equal(t, ret.Data[2], res.Data[0])
assert.Equal(t, atomic.LoadInt32(&count), int32(2))
}
}

func TestEmbeddingFailed(t *testing.T) {
Expand All @@ -161,13 +196,23 @@ func TestEmbeddingFailed(t *testing.T) {
url := ts.URL

{
atomic.StoreInt32(&count, 0)
c := NewOpenAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
_, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
assert.True(t, err != nil)
assert.Equal(t, atomic.LoadInt32(&count), int32(3))
}
{
atomic.StoreInt32(&count, 0)
c := NewAzureOpenAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
_, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
assert.True(t, err != nil)
assert.Equal(t, atomic.LoadInt32(&count), int32(3))
}
}

func TestTimeout(t *testing.T) {
Expand All @@ -182,6 +227,7 @@ func TestTimeout(t *testing.T) {
url := ts.URL

{
atomic.StoreInt32(&st, 0)
c := NewOpenAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
Expand All @@ -191,4 +237,16 @@ func TestTimeout(t *testing.T) {
time.Sleep(3 * time.Second)
assert.Equal(t, atomic.LoadInt32(&st), int32(1))
}

{
atomic.StoreInt32(&st, 0)
c := NewAzureOpenAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
_, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 1)
assert.True(t, err != nil)
assert.Equal(t, atomic.LoadInt32(&st), int32(0))
time.Sleep(3 * time.Second)
assert.Equal(t, atomic.LoadInt32(&st), int32(1))
}
}
6 changes: 4 additions & 2 deletions internal/models/utils/embedding_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@ func send(req *http.Request) ([]byte, error) {
}

func RetrySend(req *http.Request, maxRetries int) ([]byte, error) {
var err error
var res []byte
for i := 0; i < maxRetries; i++ {
res, err := send(req)
res, err = send(req)
if err == nil {
return res, nil
}
}
return nil, nil
return nil, err
}
163 changes: 163 additions & 0 deletions internal/models/vertexai/vertexai_text_embedding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package vertexai

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"time"

"github.com/milvus-io/milvus/internal/models/utils"

"golang.org/x/oauth2/google"
)

type Instance struct {
TaskType string `json:"task_type,omitempty"`
Content string `json:"content"`
}

type Parameters struct {
OutputDimensionality int64 `json:"outputDimensionality,omitempty"`
}

type EmbeddingRequest struct {
Instances []Instance `json:"instances"`
Parameters Parameters `json:"parameters,omitempty"`
}

type Statistics struct {
Truncated bool `json:"truncated"`
TokenCount int `json:"token_count"`
}

type Embeddings struct {
Statistics Statistics `json:"statistics"`
Values []float32 `json:"values"`
}

type Prediction struct {
Embeddings Embeddings `json:"embeddings"`
}

type Metadata struct {
BillableCharacterCount int `json:"billableCharacterCount"`
}

type EmbeddingResponse struct {
Predictions []Prediction `json:"predictions"`
Metadata Metadata `json:"metadata"`
}

type ErrorInfo struct {
Code string `json:"code"`
Message string `json:"message"`
RequestID string `json:"request_id"`
}

type VertexAIEmbedding struct {
url string
jsonKey []byte
scopes string
token string
}

func NewVertexAIEmbedding(url string, jsonKey []byte, scopes string, token string) *VertexAIEmbedding {
return &VertexAIEmbedding{
url: url,
jsonKey: jsonKey,
scopes: scopes,
token: token,
}
}

func (c *VertexAIEmbedding) Check() error {
if c.url == "" {
return fmt.Errorf("VertexAI embedding url is empty")
}
if len(c.jsonKey) == 0 {
return fmt.Errorf("jsonKey is empty")
}
if c.scopes == "" {
return fmt.Errorf("Scopes param is empty")
}
return nil
}

func (c *VertexAIEmbedding) getAccessToken() (string, error) {
ctx := context.Background()
creds, err := google.CredentialsFromJSON(ctx, c.jsonKey, c.scopes)
if err != nil {
return "", fmt.Errorf("Failed to find credentials: %v", err)
}
token, err := creds.TokenSource.Token()
if err != nil {
return "", fmt.Errorf("Failed to get token: %v", err)
}
return token.AccessToken, nil
}

func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int64, taskType string, timeoutSec time.Duration) (*EmbeddingResponse, error) {
var r EmbeddingRequest
for _, text := range texts {
r.Instances = append(r.Instances, Instance{TaskType: taskType, Content: text})
}
if dim != 0 {
r.Parameters.OutputDimensionality = dim
}

data, err := json.Marshal(r)
if err != nil {
return nil, err
}

if timeoutSec <= 0 {
timeoutSec = 30
}

ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
if err != nil {
return nil, err
}
var token string
if c.token != "" {
token = c.token
} else {
token, err = c.getAccessToken()
if err != nil {
return nil, err
}
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
body, err := utils.RetrySend(req, 3)
if err != nil {
return nil, err
}
var res EmbeddingResponse
err = json.Unmarshal(body, &res)
if err != nil {
return nil, err
}
return &res, err
}
90 changes: 90 additions & 0 deletions internal/models/vertexai/vertexai_text_embedding_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package vertexai

import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func TestEmbeddingClientCheck(t *testing.T) {
mockJsonKey := []byte{1, 2, 3}
{
c := NewVertexAIEmbedding("mock_url", []byte{}, "mock_scopes", "")
err := c.Check()
assert.True(t, err != nil)
fmt.Println(err)
}

{
c := NewVertexAIEmbedding("", mockJsonKey, "", "")
err := c.Check()
assert.True(t, err != nil)
fmt.Println(err)
}

{
c := NewVertexAIEmbedding("mock_url", mockJsonKey, "mock_scopes", "")
err := c.Check()
assert.True(t, err == nil)
}
}

func TestEmbeddingOK(t *testing.T) {
var res EmbeddingResponse
repStr := `{"predictions": [{"embeddings": {"statistics": {"truncated": false, "token_count": 4}, "values": [-0.028420744463801384, 0.037183016538619995]}}, {"embeddings": {"statistics": {"truncated": false, "token_count": 8}, "values": [-0.04367655888199806, 0.03777721896767616, 0.0158217903226614]}}], "metadata": {"billableCharacterCount": 27}}`
err := json.Unmarshal([]byte(repStr), &res)
assert.NoError(t, err)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))

defer ts.Close()
url := ts.URL

{
c := NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock_scopes", "mock_token")
err := c.Check()
assert.True(t, err == nil)
_, err = c.Embedding("text-embedding-005", []string{"sentence"}, 0, "query", 0)
assert.True(t, err == nil)
}
}

func TestEmbeddingFailed(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}))

defer ts.Close()
url := ts.URL

{
c := NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock_scopes", "mock_token")
err := c.Check()
assert.True(t, err == nil)
_, err = c.Embedding("text-embedding-v2", []string{"sentence"}, 0, "query", 0)
assert.True(t, err != nil)
}
}
Loading

0 comments on commit e7e5e7c

Please sign in to comment.