From 2fefae3fc5837a0d0a3e42afe1de4ac5ff019c57 Mon Sep 17 00:00:00 2001 From: JINSONG WANG Date: Tue, 18 Jun 2024 15:53:32 -0700 Subject: [PATCH] feat: add watsonx ai provider Signed-off-by: JINSONG WANG --- README.md | 1 + go.mod | 1 + go.sum | 2 ++ pkg/ai/iai.go | 4 ++- pkg/ai/watsonxai.go | 84 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 91 insertions(+), 1 deletion(-) create mode 100644 pkg/ai/watsonxai.go diff --git a/README.md b/README.md index cf9b66b7d9..92e61bad27 100644 --- a/README.md +++ b/README.md @@ -316,6 +316,7 @@ Unused: > huggingface > noopai > googlevertexai +> watsonxai ``` For detailed documentation on how to configure and use each provider see [here](https://docs.k8sgpt.ai/reference/providers/backend/). diff --git a/go.mod b/go.mod index 533075d733..f20d94b064 100644 --- a/go.mod +++ b/go.mod @@ -62,6 +62,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect + github.com/IBM/watsonx-go v1.0.0 // indirect github.com/Microsoft/hcsshim v0.12.4 // indirect github.com/alecthomas/units v0.0.0-20231202071711-9a357b53e9c9 // indirect github.com/anchore/go-struct-converter v0.0.0-20230627203149-c72ef8859ca9 // indirect diff --git a/go.sum b/go.sum index 9c702a3541..9e864c6aac 100644 --- a/go.sum +++ b/go.sum @@ -1245,6 +1245,8 @@ github.com/Code-Hex/go-generics-cache v1.3.1 h1:i8rLwyhoyhaerr7JpjtYjJZUcCbWOdiY github.com/Code-Hex/go-generics-cache v1.3.1/go.mod h1:qxcC9kRVrct9rHeiYpFWSoW1vxyillCVzX13KZG8dl4= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/IBM/watsonx-go v1.0.0 h1:xG7xA2W9N0RsiztR26dwBI8/VxIX4wTBhdYmEis2Yl8= +github.com/IBM/watsonx-go v1.0.0/go.mod h1:8lzvpe/158JkrzvcoIcIj6OdNty5iC9co5nQHfkhRtM= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk= github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index 08caa02079..3e2ee96bde 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -30,6 +30,7 @@ var ( &HuggingfaceClient{}, &GoogleVertexAIClient{}, &OCIGenAIClient{}, + &WatsonxAIClient{}, } Backends = []string{ openAIClientName, @@ -43,6 +44,7 @@ var ( huggingfaceAIClientName, googleVertexAIClientName, ociClientName, + watsonxAIClientName, } ) @@ -170,7 +172,7 @@ func (p *AIProvider) GetOrganizationId() string { return p.OrganizationId } -var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"} +var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "watsonxai"} func NeedPassword(backend string) bool { for _, b := range passwordlessProviders { diff --git a/pkg/ai/watsonxai.go b/pkg/ai/watsonxai.go new file mode 100644 index 0000000000..f6ce81c1a6 --- /dev/null +++ b/pkg/ai/watsonxai.go @@ -0,0 +1,84 @@ +package ai + +import ( + "os" + "fmt" + "context" + "errors" + + wx "github.com/IBM/watsonx-go/pkg/models" +) + +const watsonxAIClientName = "watsonxai" + +type WatsonxAIClient struct { + nopCloser + + client *wx.Client + model string + temperature float32 + topP float32 + topK int32 + maxNewTokens int +} + +const ( + modelMetallama = "ibm/granite-13b-chat-v2" +) + +func (c *WatsonxAIClient) Configure(config IAIConfig) error { + if(config.GetModel() == "") { + c.model = config.GetModel() + } else { + c.model = modelMetallama + } + c.temperature = config.GetTemperature() + c.topP = config.GetTopP() + c.topK = config.GetTopK() + c.maxNewTokens = config.GetMaxTokens() + + // WatsonxAPIKeyEnvVarName = "WATSONX_API_KEY" + // WatsonxProjectIDEnvVarName = "WATSONX_PROJECT_ID" + apiKey, projectID := os.Getenv(wx.WatsonxAPIKeyEnvVarName), os.Getenv(wx.WatsonxProjectIDEnvVarName) + + if apiKey == "" { + return errors.New("No watsonx API key provided") + } + if projectID == "" { + return errors.New("No watsonx project ID provided") + } + + client, err := wx.NewClient( + wx.WithWatsonxAPIKey(apiKey), + wx.WithWatsonxProjectID(projectID), + ) + if err != nil { + return fmt.Errorf("Failed to create client for testing. Error: %v", err) + } + c.client = client + + return nil +} + +func (c *WatsonxAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) { + result, err := c.client.GenerateText( + c.model, + prompt, + wx.WithTemperature((float64)(c.temperature)), + wx.WithTopP((float64)(c.topP)), + wx.WithTopK((uint)(c.topK)), + wx.WithMaxNewTokens((uint)(c.maxNewTokens)), + ) + if err != nil { + return "", fmt.Errorf("Expected no error, but got an error: %v", err) + } + if result.Text == "" { + return "", errors.New("Expected a result, but got an empty string") + } + + return result.Text, nil +} + +func (c *WatsonxAIClient) GetName() string { + return watsonxAIClientName +}