From 2b7cb50bdb5d076e3a6211ea0961193f569a8613 Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Mon, 16 Dec 2024 16:27:29 -0500 Subject: [PATCH] improvements Signed-off-by: Grant Linville --- pkg/engine/daemon.go | 23 ++++++++++++++++------- pkg/engine/engine.go | 2 -- pkg/engine/http.go | 9 +++++---- pkg/gptscript/gptscript.go | 14 ++++---------- pkg/runner/runner.go | 7 +------ pkg/tests/tester/runner.go | 6 +----- 6 files changed, 27 insertions(+), 34 deletions(-) diff --git a/pkg/engine/daemon.go b/pkg/engine/daemon.go index 31f96018..899b27fc 100644 --- a/pkg/engine/daemon.go +++ b/pkg/engine/daemon.go @@ -36,7 +36,8 @@ type Ports struct { type Certs struct { daemonCerts map[string]certs.CertAndKey - daemonLock sync.Mutex + clientCert certs.CertAndKey + lock sync.Mutex } func IsDaemonRunning(url string) bool { @@ -157,8 +158,8 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { url = fmt.Sprintf("https://127.0.0.1:%d%s", port, path) // Generate a certificate for the daemon, unless one already exists. - certificates.daemonLock.Lock() - defer certificates.daemonLock.Unlock() + certificates.lock.Lock() + defer certificates.lock.Unlock() cert, exists := certificates.daemonCerts[tool.ID] if !exists { var err error @@ -173,12 +174,21 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { certificates.daemonCerts[tool.ID] = cert } + // Set the client certificate if there isn't one already. + if len(certificates.clientCert.Cert) == 0 { + gptscriptCert, err := certs.GenerateGPTScriptCert() + if err != nil { + return "", fmt.Errorf("failed to generate GPTScript certificate: %v", err) + } + certificates.clientCert = gptscriptCert + } + cmd, stop, err := e.newCommand(ctx, []string{ fmt.Sprintf("PORT=%d", port), fmt.Sprintf("CERT=%s", base64.StdEncoding.EncodeToString(cert.Cert)), fmt.Sprintf("PRIVATE_KEY=%s", base64.StdEncoding.EncodeToString(cert.Key)), fmt.Sprintf("GPTSCRIPT_PORT=%d", port), - fmt.Sprintf("GPTSCRIPT_CERT=%s", base64.StdEncoding.EncodeToString(e.GPTScriptCert.Cert)), + fmt.Sprintf("GPTSCRIPT_CERT=%s", base64.StdEncoding.EncodeToString(certificates.clientCert.Cert)), }, tool, "{}", @@ -241,7 +251,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { }() // Build HTTP client for checking the health of the daemon - clientCert, err := tls.X509KeyPair(e.GPTScriptCert.Cert, e.GPTScriptCert.Key) + tlsClientCert, err := tls.X509KeyPair(certificates.clientCert.Cert, certificates.clientCert.Key) if err != nil { return "", fmt.Errorf("failed to create client certificate: %v", err) } @@ -254,7 +264,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { httpClient := &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ - Certificates: []tls.Certificate{clientCert}, + Certificates: []tls.Certificate{tlsClientCert}, RootCAs: pool, InsecureSkipVerify: false, }, @@ -271,7 +281,6 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { }() return url, nil } - _ = resp.Body.Close() select { case <-killedCtx.Done(): return url, fmt.Errorf("daemon failed to start: %w", context.Cause(killedCtx)) diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 88cb07ae..a195a8b4 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -7,7 +7,6 @@ import ( "strings" "sync" - "github.com/gptscript-ai/gptscript/pkg/certs" "github.com/gptscript-ai/gptscript/pkg/counter" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/gptscript-ai/gptscript/pkg/version" @@ -23,7 +22,6 @@ type RuntimeManager interface { } type Engine struct { - GPTScriptCert certs.CertAndKey Model Model RuntimeManager RuntimeManager Env []string diff --git a/pkg/engine/http.go b/pkg/engine/http.go index d06c7169..d8bf0ef2 100644 --- a/pkg/engine/http.go +++ b/pkg/engine/http.go @@ -65,9 +65,10 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too toolURL = parsed.String() // Find the certificate corresponding to this daemon tool - certificates.daemonLock.Lock() + certificates.lock.Lock() daemonCert, exists := certificates.daemonCerts[referencedTool.ID] - certificates.daemonLock.Unlock() + clientCert := certificates.clientCert + certificates.lock.Unlock() if !exists { return nil, fmt.Errorf("missing daemon certificate for [%s]", referencedTool.ID) @@ -79,14 +80,14 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too return nil, fmt.Errorf("failed to append daemon certificate for [%s]", referencedTool.ID) } - clientCert, err := tls.X509KeyPair(e.GPTScriptCert.Cert, e.GPTScriptCert.Key) + tlsClientCert, err := tls.X509KeyPair(clientCert.Cert, clientCert.Key) if err != nil { return nil, fmt.Errorf("failed to create client certificate: %v", err) } // Create TLS config for use in the HTTP client later tlsConfigForDaemonRequest = &tls.Config{ - Certificates: []tls.Certificate{clientCert}, + Certificates: []tls.Certificate{tlsClientCert}, RootCAs: pool, InsecureSkipVerify: false, } diff --git a/pkg/gptscript/gptscript.go b/pkg/gptscript/gptscript.go index cac519a8..dfb1771a 100644 --- a/pkg/gptscript/gptscript.go +++ b/pkg/gptscript/gptscript.go @@ -12,7 +12,6 @@ import ( "github.com/gptscript-ai/gptscript/pkg/builtin" "github.com/gptscript-ai/gptscript/pkg/cache" - "github.com/gptscript-ai/gptscript/pkg/certs" "github.com/gptscript-ai/gptscript/pkg/config" context2 "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/credentials" @@ -108,12 +107,7 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) { opts.Runner.RuntimeManager = runtimes.Default(cacheClient.CacheDir(), opts.SystemToolsDir) } - gptscriptCert, err := certs.GenerateGPTScriptCert() - if err != nil { - return nil, err - } - - simplerRunner, err := newSimpleRunner(cacheClient, opts.Runner.RuntimeManager, opts.Env, gptscriptCert) + simplerRunner, err := newSimpleRunner(cacheClient, opts.Runner.RuntimeManager, opts.Env) if err != nil { return nil, err } @@ -146,7 +140,7 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) { opts.Runner.MonitorFactory = monitor.NewConsole(opts.Monitor, monitor.Options{DebugMessages: *opts.Quiet}) } - runner, err := runner.New(registry, credStore, gptscriptCert, opts.Runner) + runner, err := runner.New(registry, credStore, opts.Runner) if err != nil { return nil, err } @@ -291,8 +285,8 @@ type simpleRunner struct { env []string } -func newSimpleRunner(cache *cache.Client, rm engine.RuntimeManager, env []string, gptscriptCert certs.CertAndKey) (*simpleRunner, error) { - runner, err := runner.New(noopModel{}, credentials.NoopStore{}, gptscriptCert, runner.Options{ +func newSimpleRunner(cache *cache.Client, rm engine.RuntimeManager, env []string) (*simpleRunner, error) { + runner, err := runner.New(noopModel{}, credentials.NoopStore{}, runner.Options{ RuntimeManager: rm, MonitorFactory: simpleMonitorFactory{}, }) diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 931ab99b..fc5737ef 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -11,7 +11,6 @@ import ( "time" "github.com/gptscript-ai/gptscript/pkg/builtin" - "github.com/gptscript-ai/gptscript/pkg/certs" context2 "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/gptscript-ai/gptscript/pkg/engine" @@ -96,10 +95,9 @@ type Runner struct { credOverrides []string credStore credentials.CredentialStore sequential bool - gptscriptCert certs.CertAndKey } -func New(client engine.Model, credStore credentials.CredentialStore, gptscriptCert certs.CertAndKey, opts ...Options) (*Runner, error) { +func New(client engine.Model, credStore credentials.CredentialStore, opts ...Options) (*Runner, error) { opt := complete(opts...) runner := &Runner{ @@ -111,7 +109,6 @@ func New(client engine.Model, credStore credentials.CredentialStore, gptscriptCe credStore: credStore, sequential: opt.Sequential, auth: opt.Authorizer, - gptscriptCert: gptscriptCert, } if opt.StartPort != 0 { @@ -414,7 +411,6 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en RuntimeManager: runtimeWithLogger(callCtx, monitor, r.runtimeManager), Progress: progress, Env: env, - GPTScriptCert: r.gptscriptCert, } callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause) @@ -597,7 +593,6 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s RuntimeManager: runtimeWithLogger(callCtx, monitor, r.runtimeManager), Progress: progress, Env: env, - GPTScriptCert: r.gptscriptCert, } var contentInput string diff --git a/pkg/tests/tester/runner.go b/pkg/tests/tester/runner.go index 22095270..44ec4e3c 100644 --- a/pkg/tests/tester/runner.go +++ b/pkg/tests/tester/runner.go @@ -9,7 +9,6 @@ import ( "testing" "github.com/adrg/xdg" - "github.com/gptscript-ai/gptscript/pkg/certs" "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/gptscript-ai/gptscript/pkg/loader" "github.com/gptscript-ai/gptscript/pkg/repos/runtimes" @@ -199,10 +198,7 @@ func NewRunner(t *testing.T) *Runner { rm := runtimes.Default(cacheDir, "") - gptscriptCert, err := certs.GenerateGPTScriptCert() - require.NoError(t, err) - - run, err := runner.New(c, credentials.NoopStore{}, gptscriptCert, runner.Options{ + run, err := runner.New(c, credentials.NoopStore{}, runner.Options{ Sequential: true, RuntimeManager: rm, })