diff --git a/pkg/engine/daemon.go b/pkg/engine/daemon.go index 899b27fc..59c93361 100644 --- a/pkg/engine/daemon.go +++ b/pkg/engine/daemon.go @@ -40,6 +40,29 @@ type Certs struct { lock sync.Mutex } +func GetClientCert() (certs.CertAndKey, error) { + certificates.lock.Lock() + defer certificates.lock.Unlock() + if len(certificates.clientCert.Cert) == 0 { + cert, err := certs.GenerateGPTScriptCert() + if err != nil { + return certs.CertAndKey{}, fmt.Errorf("failed to generate GPTScript certificate: %v", err) + } + certificates.clientCert = cert + } + return certificates.clientCert, nil +} + +func GetDaemonCert(toolID string) ([]byte, error) { + certificates.lock.Lock() + defer certificates.lock.Unlock() + cert, exists := certificates.daemonCerts[toolID] + if !exists { + return nil, fmt.Errorf("daemon certificate for [%s] not found", toolID) + } + return cert.Cert, nil +} + func IsDaemonRunning(url string) bool { ports.daemonLock.Lock() defer ports.daemonLock.Unlock() diff --git a/pkg/engine/http.go b/pkg/engine/http.go index d8bf0ef2..222d6977 100644 --- a/pkg/engine/http.go +++ b/pkg/engine/http.go @@ -14,6 +14,7 @@ import ( "slices" "strings" + "github.com/gptscript-ai/gptscript/pkg/certs" "github.com/gptscript-ai/gptscript/pkg/types" ) @@ -74,22 +75,22 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too return nil, fmt.Errorf("missing daemon certificate for [%s]", referencedTool.ID) } - // Create a pool for the certificate to treat as a CA - pool := x509.NewCertPool() - if !pool.AppendCertsFromPEM(daemonCert.Cert) { - return nil, fmt.Errorf("failed to append daemon certificate for [%s]", referencedTool.ID) - } - - tlsClientCert, err := tls.X509KeyPair(clientCert.Cert, clientCert.Key) + tlsConfigForDaemonRequest, err = getTLSConfig(clientCert, daemonCert.Cert) if err != nil { - return nil, fmt.Errorf("failed to create client certificate: %v", err) + return nil, err } + } else if isLocalhostHTTPS(toolURL) { + // This sometimes happens when talking to a model provider + certificates.lock.Lock() + daemonCert, exists := certificates.daemonCerts[tool.ID] + clientCert := certificates.clientCert + certificates.lock.Unlock() - // Create TLS config for use in the HTTP client later - tlsConfigForDaemonRequest = &tls.Config{ - Certificates: []tls.Certificate{tlsClientCert}, - RootCAs: pool, - InsecureSkipVerify: false, + if exists { + tlsConfigForDaemonRequest, err = getTLSConfig(clientCert, daemonCert.Cert) + if err != nil { + return nil, err + } } } @@ -185,3 +186,30 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too Result: &s, }, nil } + +func isLocalhostHTTPS(u string) bool { + parsed, err := url.Parse(u) + if err != nil { + return false + } + + return parsed.Scheme == "https" && (parsed.Hostname() == "localhost" || parsed.Hostname() == "127.0.0.1") +} + +func getTLSConfig(clientCert certs.CertAndKey, daemonCert []byte) (*tls.Config, error) { + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(daemonCert) { + return nil, fmt.Errorf("failed to append daemon certificate") + } + + tlsClientCert, err := tls.X509KeyPair(clientCert.Cert, clientCert.Key) + if err != nil { + return nil, fmt.Errorf("failed to create client certificate: %v", err) + } + + return &tls.Config{ + Certificates: []tls.Certificate{tlsClientCert}, + RootCAs: pool, + InsecureSkipVerify: false, + }, nil +} diff --git a/pkg/openai/client.go b/pkg/openai/client.go index 1894bdda..0c660b4c 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -2,9 +2,12 @@ package openai import ( "context" + "crypto/tls" + "crypto/x509" "errors" "io" "log/slog" + "net/http" "os" "slices" "sort" @@ -13,6 +16,7 @@ import ( openai "github.com/gptscript-ai/chat-completion-client" "github.com/gptscript-ai/gptscript/pkg/cache" + "github.com/gptscript-ai/gptscript/pkg/certs" "github.com/gptscript-ai/gptscript/pkg/counter" "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/gptscript-ai/gptscript/pkg/hash" @@ -51,13 +55,15 @@ type Client struct { } type Options struct { - BaseURL string `usage:"OpenAI base URL" name:"openai-base-url" env:"OPENAI_BASE_URL"` - APIKey string `usage:"OpenAI API KEY" name:"openai-api-key" env:"OPENAI_API_KEY"` - OrgID string `usage:"OpenAI organization ID" name:"openai-org-id" env:"OPENAI_ORG_ID"` - DefaultModel string `usage:"Default LLM model to use" default:"gpt-4o"` - ConfigFile string `usage:"Path to GPTScript config file" name:"config"` - SetSeed bool `usage:"-"` - CacheKey string `usage:"-"` + BaseURL string `usage:"OpenAI base URL" name:"openai-base-url" env:"OPENAI_BASE_URL"` + APIKey string `usage:"OpenAI API KEY" name:"openai-api-key" env:"OPENAI_API_KEY"` + OrgID string `usage:"OpenAI organization ID" name:"openai-org-id" env:"OPENAI_ORG_ID"` + DefaultModel string `usage:"Default LLM model to use" default:"gpt-4o"` + ConfigFile string `usage:"Path to GPTScript config file" name:"config"` + SetSeed bool `usage:"-"` + CacheKey string `usage:"-"` + ClientCert certs.CertAndKey `usage:"-"` + ServerCert []byte `usage:"-"` Cache *cache.Client } @@ -70,6 +76,14 @@ func Complete(opts ...Options) (result Options) { result.DefaultModel = types.FirstSet(opt.DefaultModel, result.DefaultModel) result.SetSeed = types.FirstSet(opt.SetSeed, result.SetSeed) result.CacheKey = types.FirstSet(opt.CacheKey, result.CacheKey) + + if len(opt.ClientCert.Cert) > 0 { + result.ClientCert = opt.ClientCert + } + + if len(opt.ServerCert) > 0 { + result.ServerCert = opt.ServerCert + } } return result @@ -116,6 +130,29 @@ func NewClient(ctx context.Context, credStore credentials.CredentialStore, opts cfg.BaseURL = types.FirstSet(opt.BaseURL, cfg.BaseURL) cfg.OrgID = types.FirstSet(opt.OrgID, cfg.OrgID) + // Set up for mTLS, if configured. + if opt.ServerCert != nil && len(opt.ClientCert.Cert) > 0 { + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(opt.ServerCert) { + return nil, errors.New("failed to append server cert to pool") + } + + clientCert, err := tls.X509KeyPair(opt.ClientCert.Cert, opt.ClientCert.Key) + if err != nil { + return nil, err + } + + cfg.HTTPClient = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{clientCert}, + RootCAs: pool, + InsecureSkipVerify: false, + }, + }, + } + } + cacheKeyBase := opt.CacheKey if cacheKeyBase == "" { cacheKeyBase = hash.ID(opt.APIKey, opt.BaseURL) diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 5542372b..e5132d78 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -166,10 +166,22 @@ func (c *Client) load(ctx context.Context, toolName string, env ...string) (*ope return nil, err } + clientCert, err := engine.GetClientCert() + if err != nil { + return nil, err + } + + serverCert, err := engine.GetDaemonCert(prg.EntryToolID) + if err != nil { + return nil, err + } + oClient, err := openai.NewClient(ctx, c.credStore, openai.Options{ - BaseURL: strings.TrimSuffix(url, "/") + "/v1", - Cache: c.cache, - CacheKey: prg.EntryToolID, + BaseURL: strings.TrimSuffix(url, "/") + "/v1", + Cache: c.cache, + CacheKey: prg.EntryToolID, + ClientCert: clientCert, + ServerCert: serverCert, }) if err != nil { return nil, err