Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for AzureAI #38

Merged
merged 1 commit into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions app/cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,30 @@ func NewSetConfigCmd() *cobra.Command {
}

pieces := strings.Split(args[0], "=")
if len(pieces) < 2 {
return errors.New("Invalid usage; set expects an argument in the form <NAME>=<VALUE>")
}
cfgName := pieces[0]
cfgValue := pieces[1]
viper.Set(cfgName, cfgValue)
fConfig := config.GetConfig()

var fConfig *config.Config
switch cfgName {
case "azureOpenAI.deployments":
if len(pieces) != 3 {
return errors.New("Invalid argument; argument is not in the form azureOpenAI.deployments=<model>=<deployment>")
}

d := config.AzureDeployment{
Model: pieces[1],
Deployment: pieces[2],
}

fConfig = config.GetConfig()
config.SetAzureDeployment(fConfig, d)
default:
if len(pieces) < 2 {
return errors.New("Invalid usage; set expects an argument in the form <NAME>=<VALUE>")
}
cfgValue := pieces[1]
viper.Set(cfgName, cfgValue)
fConfig = config.GetConfig()
}

file := viper.ConfigFileUsed()
if file == "" {
Expand Down
2 changes: 1 addition & 1 deletion app/pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (a *Agent) completeWithRetries(ctx context.Context, req *v1alpha1.GenerateR
},
}
request := openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo0125,
Model: oai.DefaultModel,
Messages: messages,
MaxTokens: 2000,
Temperature: temperature,
Expand Down
19 changes: 19 additions & 0 deletions app/pkg/config/azure.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package config

func SetAzureDeployment(cfg *Config, d AzureDeployment) {
if cfg.AzureOpenAI == nil {
cfg.AzureOpenAI = &AzureOpenAIConfig{}
}
if cfg.AzureOpenAI.Deployments == nil {
cfg.AzureOpenAI.Deployments = make([]AzureDeployment, 0, 1)
}
// First check if there is a deployment for the model and if there is update it
for i := range cfg.AzureOpenAI.Deployments {
if cfg.AzureOpenAI.Deployments[i].Model == d.Model {
cfg.AzureOpenAI.Deployments[i].Deployment = d.Deployment
return
}
}

cfg.AzureOpenAI.Deployments = append(cfg.AzureOpenAI.Deployments, d)
}
27 changes: 27 additions & 0 deletions app/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ type Config struct {
Server ServerConfig `json:"server" yaml:"server"`
Assets AssetConfig `json:"assets" yaml:"assets"`
OpenAI OpenAIConfig `json:"openai" yaml:"openai"`
// AzureOpenAI contains configuration for Azure OpenAI. A non nil value means use Azure OpenAI.
AzureOpenAI *AzureOpenAIConfig `json:"azureOpenAI,omitempty" yaml:"azureOpenAI,omitempty"`
}

// ServerConfig configures the server
Expand Down Expand Up @@ -71,6 +73,31 @@ type OpenAIConfig struct {
APIKeyFile string `json:"apiKeyFile" yaml:"apiKeyFile"`
}

type AzureOpenAIConfig struct {
// APIKeyFile is the path to the file containing the API key
APIKeyFile string `json:"apiKeyFile" yaml:"apiKeyFile"`

// BaseURL is the baseURL for the API.
// This can be obtained using the Azure CLI with the command:
// az cognitiveservices account show \
// --name <myResourceName> \
// --resource-group <myResourceGroupName> \
// | jq -r .properties.endpoint
BaseURL string `json:"baseURL" yaml:"baseURL"`

// Deployments is a list of Azure deployments of various models.
Deployments []AzureDeployment `json:"deployments" yaml:"deployments"`
}

type AzureDeployment struct {
// Deployment is the Azure Deployment name
Deployment string `json:"deployment" yaml:"deployment"`

// Model is the OpenAI name for this model
// This is used to map OpenAI models to Azure deployments
Model string `json:"model" yaml:"model"`
}

type CorsConfig struct {
// AllowedOrigins is a list of origins allowed to make cross-origin requests.
AllowedOrigins []string `json:"allowedOrigins" yaml:"allowedOrigins"`
Expand Down
120 changes: 107 additions & 13 deletions app/pkg/oai/client.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
package oai

import (
"net/url"
"strings"

"github.com/go-logr/zapr"
"go.uber.org/zap"

"github.com/hashicorp/go-retryablehttp"
"github.com/jlewi/foyle/app/pkg/config"
"github.com/jlewi/hydros/pkg/files"
"github.com/pkg/errors"
"github.com/sashabaranov/go-openai"
)

const (
DefaultModel = openai.GPT3Dot5Turbo0125

// AzureOpenAIVersion is the version of the Azure OpenAI API to use.
// For a list of versions see:
// https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
AzureOpenAIVersion = "2024-02-01"
)

// NewClient helper function to create a new OpenAI client from a config
func NewClient(cfg config.Config) (*openai.Client, error) {
if cfg.OpenAI.APIKeyFile == "" {
return nil, errors.New("OpenAI APIKeyFile is required")
}
apiKeyBytes, err := files.Read(cfg.OpenAI.APIKeyFile)
if err != nil {
return nil, errors.Wrapf(err, "could not read OpenAI APIKeyFile: %v", cfg.OpenAI.APIKeyFile)
}
// make sure there is no leading or trailing whitespace
apiKey := strings.TrimSpace(string(apiKeyBytes))

log := zapr.NewLogger(zap.L())
// ************************************************************************
// Setup middleware
// ************************************************************************
Expand All @@ -32,9 +36,99 @@ func NewClient(cfg config.Config) (*openai.Client, error) {
retryClient := retryablehttp.NewClient()
httpClient := retryClient.StandardClient()

clientCfg := openai.DefaultConfig(apiKey)
clientCfg.HTTPClient = httpClient
client := openai.NewClientWithConfig(clientCfg)
var clientConfig openai.ClientConfig
if cfg.AzureOpenAI != nil {
var clientErr error
clientConfig, clientErr = buildAzureConfig(cfg)

if clientErr != nil {
return nil, clientErr
}
} else {
log.Info("Configuring OpenAI client")
apiKey, err := readAPIKey(cfg.OpenAI.APIKeyFile)
if err != nil {
return nil, err
}
clientConfig = openai.DefaultConfig(apiKey)
}
clientConfig.HTTPClient = httpClient
client := openai.NewClientWithConfig(clientConfig)

return client, nil
}

// buildAzureConfig helper function to create a new Azure OpenAI client config
func buildAzureConfig(cfg config.Config) (openai.ClientConfig, error) {
apiKey, err := readAPIKey(cfg.AzureOpenAI.APIKeyFile)
if err != nil {
return openai.ClientConfig{}, err
}
u, err := url.Parse(cfg.AzureOpenAI.BaseURL)
if err != nil {
return openai.ClientConfig{}, errors.Wrapf(err, "could not parse Azure OpenAI BaseURL: %v", cfg.AzureOpenAI.BaseURL)
}

if u.Scheme != "https" {
return openai.ClientConfig{}, errors.Errorf("Azure BaseURL %s is not valid; it must use the scheme https", cfg.AzureOpenAI.BaseURL)
}

// Check that all required models are deployed
required := map[string]bool{
DefaultModel: true,
}

for _, d := range cfg.AzureOpenAI.Deployments {
delete(required, d.Model)
}

if len(required) > 0 {
models := make([]string, 0, len(required))
for m := range required {
models = append(models, m)
}
return openai.ClientConfig{}, errors.Errorf("Missing Azure deployments for for OpenAI models %v; update AzureOpenAIConfig.deployments in your configuration to specify deployments for these models ", strings.Join(models, ", "))
}
log := zapr.NewLogger(zap.L())
log.Info("Configuring Azure OpenAI", "baseURL", cfg.AzureOpenAI.BaseURL, "deployments", cfg.AzureOpenAI.Deployments)
clientConfig := openai.DefaultAzureConfig(apiKey, cfg.AzureOpenAI.BaseURL)
clientConfig.APIVersion = AzureOpenAIVersion
mapper := AzureModelMapper{
modelToDeployment: make(map[string]string),
}
for _, m := range cfg.AzureOpenAI.Deployments {
mapper.modelToDeployment[m.Model] = m.Deployment
}
clientConfig.AzureModelMapperFunc = mapper.Map

return clientConfig, nil
}

// AzureModelMapper maps OpenAI models to Azure deployments
type AzureModelMapper struct {
modelToDeployment map[string]string
}

// Map maps an OpenAI model to an Azure deployment
func (m AzureModelMapper) Map(model string) string {
log := zapr.NewLogger(zap.L())
deployment, ok := m.modelToDeployment[model]
if !ok {
log.Error(errors.Errorf("No AzureAI deployment found for model %v", model), "missing deployment", "model", model)
return "missing-deployment"
}
return deployment
}

func readAPIKey(apiKeyFile string) (string, error) {
if apiKeyFile == "" {
return "", errors.New("APIKeyFile is required")
}
apiKeyBytes, err := files.Read(apiKeyFile)
if err != nil {
return "", errors.Wrapf(err, "could not read APIKeyFile: %v", apiKeyFile)
}
// make sure there is no leading or trailing whitespace
apiKey := strings.TrimSpace(string(apiKeyBytes))
return apiKey, nil
}
45 changes: 45 additions & 0 deletions app/pkg/oai/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package oai

import (
"os"
"testing"

"github.com/jlewi/foyle/app/pkg/config"
)

func Test_BuildAzureAIConfig(t *testing.T) {
f, err := os.CreateTemp("", "key.txt")
if err != nil {
t.Fatalf("Error creating temp file: %v", err)
}
if _, err := f.WriteString("somekey"); err != nil {
t.Fatalf("Error writing to temp file: %v", err)
}

cfg := &config.Config{
AzureOpenAI: &config.AzureOpenAIConfig{
APIKeyFile: f.Name(),
BaseURL: "https://someurl.com",
Deployments: []config.AzureDeployment{
{
Model: DefaultModel,
Deployment: "somedeployment",
},
},
},
}

if err := f.Close(); err != nil {
t.Fatalf("Error closing temp file: %v", err)
}
defer os.Remove(f.Name())

clientConfig, err := buildAzureConfig(*cfg)
if err != nil {
t.Fatalf("Error building Azure config: %v", err)
}

if clientConfig.BaseURL != "https://someurl.com" {
t.Fatalf("Expected BaseURL to be https://someurl.com but got %v", clientConfig.BaseURL)
}
}
Loading
Loading