Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Grant Linville <[email protected]>
  • Loading branch information
g-linville committed Dec 17, 2024
1 parent 2b7cb50 commit aa5ef57
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 23 deletions.
23 changes: 23 additions & 0 deletions pkg/engine/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,29 @@ type Certs struct {
lock sync.Mutex
}

func GetClientCert() (certs.CertAndKey, error) {
certificates.lock.Lock()
defer certificates.lock.Unlock()
if len(certificates.clientCert.Cert) == 0 {
cert, err := certs.GenerateGPTScriptCert()
if err != nil {
return certs.CertAndKey{}, fmt.Errorf("failed to generate GPTScript certificate: %v", err)
}
certificates.clientCert = cert
}
return certificates.clientCert, nil
}

func GetDaemonCert(toolID string) ([]byte, error) {
certificates.lock.Lock()
defer certificates.lock.Unlock()
cert, exists := certificates.daemonCerts[toolID]
if !exists {
return nil, fmt.Errorf("daemon certificate for [%s] not found", toolID)
}
return cert.Cert, nil
}

func IsDaemonRunning(url string) bool {
ports.daemonLock.Lock()
defer ports.daemonLock.Unlock()
Expand Down
54 changes: 41 additions & 13 deletions pkg/engine/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"slices"
"strings"

"github.com/gptscript-ai/gptscript/pkg/certs"
"github.com/gptscript-ai/gptscript/pkg/types"
)

Expand Down Expand Up @@ -74,22 +75,22 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
return nil, fmt.Errorf("missing daemon certificate for [%s]", referencedTool.ID)
}

// Create a pool for the certificate to treat as a CA
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(daemonCert.Cert) {
return nil, fmt.Errorf("failed to append daemon certificate for [%s]", referencedTool.ID)
}

tlsClientCert, err := tls.X509KeyPair(clientCert.Cert, clientCert.Key)
tlsConfigForDaemonRequest, err = getTLSConfig(clientCert, daemonCert.Cert)
if err != nil {
return nil, fmt.Errorf("failed to create client certificate: %v", err)
return nil, err
}
} else if isLocalhostHTTPS(toolURL) {
// This sometimes happens when talking to a model provider
certificates.lock.Lock()
daemonCert, exists := certificates.daemonCerts[tool.ID]
clientCert := certificates.clientCert
certificates.lock.Unlock()

// Create TLS config for use in the HTTP client later
tlsConfigForDaemonRequest = &tls.Config{
Certificates: []tls.Certificate{tlsClientCert},
RootCAs: pool,
InsecureSkipVerify: false,
if exists {
tlsConfigForDaemonRequest, err = getTLSConfig(clientCert, daemonCert.Cert)
if err != nil {
return nil, err
}
}
}

Expand Down Expand Up @@ -185,3 +186,30 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
Result: &s,
}, nil
}

func isLocalhostHTTPS(u string) bool {
parsed, err := url.Parse(u)
if err != nil {
return false
}

return parsed.Scheme == "https" && (parsed.Hostname() == "localhost" || parsed.Hostname() == "127.0.0.1")
}

func getTLSConfig(clientCert certs.CertAndKey, daemonCert []byte) (*tls.Config, error) {
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(daemonCert) {
return nil, fmt.Errorf("failed to append daemon certificate")
}

tlsClientCert, err := tls.X509KeyPair(clientCert.Cert, clientCert.Key)
if err != nil {
return nil, fmt.Errorf("failed to create client certificate: %v", err)
}

return &tls.Config{
Certificates: []tls.Certificate{tlsClientCert},
RootCAs: pool,
InsecureSkipVerify: false,
}, nil
}
51 changes: 44 additions & 7 deletions pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package openai

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"io"
"log/slog"
"net/http"
"os"
"slices"
"sort"
Expand All @@ -13,6 +16,7 @@ import (

openai "github.com/gptscript-ai/chat-completion-client"
"github.com/gptscript-ai/gptscript/pkg/cache"
"github.com/gptscript-ai/gptscript/pkg/certs"
"github.com/gptscript-ai/gptscript/pkg/counter"
"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/gptscript-ai/gptscript/pkg/hash"
Expand Down Expand Up @@ -51,13 +55,15 @@ type Client struct {
}

type Options struct {
BaseURL string `usage:"OpenAI base URL" name:"openai-base-url" env:"OPENAI_BASE_URL"`
APIKey string `usage:"OpenAI API KEY" name:"openai-api-key" env:"OPENAI_API_KEY"`
OrgID string `usage:"OpenAI organization ID" name:"openai-org-id" env:"OPENAI_ORG_ID"`
DefaultModel string `usage:"Default LLM model to use" default:"gpt-4o"`
ConfigFile string `usage:"Path to GPTScript config file" name:"config"`
SetSeed bool `usage:"-"`
CacheKey string `usage:"-"`
BaseURL string `usage:"OpenAI base URL" name:"openai-base-url" env:"OPENAI_BASE_URL"`
APIKey string `usage:"OpenAI API KEY" name:"openai-api-key" env:"OPENAI_API_KEY"`
OrgID string `usage:"OpenAI organization ID" name:"openai-org-id" env:"OPENAI_ORG_ID"`
DefaultModel string `usage:"Default LLM model to use" default:"gpt-4o"`
ConfigFile string `usage:"Path to GPTScript config file" name:"config"`
SetSeed bool `usage:"-"`
CacheKey string `usage:"-"`
ClientCert certs.CertAndKey `usage:"-"`
ServerCert []byte `usage:"-"`
Cache *cache.Client
}

Expand All @@ -70,6 +76,14 @@ func Complete(opts ...Options) (result Options) {
result.DefaultModel = types.FirstSet(opt.DefaultModel, result.DefaultModel)
result.SetSeed = types.FirstSet(opt.SetSeed, result.SetSeed)
result.CacheKey = types.FirstSet(opt.CacheKey, result.CacheKey)

if len(opt.ClientCert.Cert) > 0 {
result.ClientCert = opt.ClientCert
}

if len(opt.ServerCert) > 0 {
result.ServerCert = opt.ServerCert
}
}

return result
Expand Down Expand Up @@ -116,6 +130,29 @@ func NewClient(ctx context.Context, credStore credentials.CredentialStore, opts
cfg.BaseURL = types.FirstSet(opt.BaseURL, cfg.BaseURL)
cfg.OrgID = types.FirstSet(opt.OrgID, cfg.OrgID)

// Set up for mTLS, if configured.
if opt.ServerCert != nil && len(opt.ClientCert.Cert) > 0 {
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(opt.ServerCert) {
return nil, errors.New("failed to append server cert to pool")
}

clientCert, err := tls.X509KeyPair(opt.ClientCert.Cert, opt.ClientCert.Key)
if err != nil {
return nil, err
}

cfg.HTTPClient = &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{clientCert},
RootCAs: pool,
InsecureSkipVerify: false,
},
},
}
}

cacheKeyBase := opt.CacheKey
if cacheKeyBase == "" {
cacheKeyBase = hash.ID(opt.APIKey, opt.BaseURL)
Expand Down
18 changes: 15 additions & 3 deletions pkg/remote/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,22 @@ func (c *Client) load(ctx context.Context, toolName string, env ...string) (*ope
return nil, err
}

clientCert, err := engine.GetClientCert()
if err != nil {
return nil, err
}

serverCert, err := engine.GetDaemonCert(prg.EntryToolID)
if err != nil {
return nil, err
}

oClient, err := openai.NewClient(ctx, c.credStore, openai.Options{
BaseURL: strings.TrimSuffix(url, "/") + "/v1",
Cache: c.cache,
CacheKey: prg.EntryToolID,
BaseURL: strings.TrimSuffix(url, "/") + "/v1",
Cache: c.cache,
CacheKey: prg.EntryToolID,
ClientCert: clientCert,
ServerCert: serverCert,
})
if err != nil {
return nil, err
Expand Down

0 comments on commit aa5ef57

Please sign in to comment.