Skip to content

Commit

Permalink
enhance: add mTLS between gptscript and daemon tools
Browse files Browse the repository at this point in the history
Signed-off-by: Grant Linville <[email protected]>
  • Loading branch information
g-linville committed Dec 4, 2024
1 parent e5fe428 commit 868dded
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 10 deletions.
63 changes: 63 additions & 0 deletions pkg/certs/certs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package certs

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"time"
)

// CertAndKey contains an x509 certificate (PEM format) and ECDSA private key (also PEM format)
type CertAndKey struct {
Cert []byte
Key []byte
}

func GenerateGPTScriptCert() (CertAndKey, error) {
return GenerateSelfSignedCert("gptscript server")
}

func GenerateSelfSignedCert(name string) (CertAndKey, error) {
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return CertAndKey{}, fmt.Errorf("failed to generate ECDSA key: %v", err)
}

marshalledPrivateKey, err := x509.MarshalECPrivateKey(privateKey)
if err != nil {
return CertAndKey{}, fmt.Errorf("failed to marshal ECDSA key: %v", err)
}

marshalledPrivateKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: marshalledPrivateKey})

template := &x509.Certificate{
SerialNumber: big.NewInt(time.Now().UnixNano()),
Subject: pkix.Name{
CommonName: name,
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(1, 0, 0), // a year from now
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
x509.ExtKeyUsageClientAuth,
},
IsCA: false,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}

cert, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
if err != nil {
return CertAndKey{}, fmt.Errorf("failed to create certificate: %v", err)
}

certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert})

return CertAndKey{Cert: certPEM, Key: marshalledPrivateKeyPEM}, nil
}
58 changes: 55 additions & 3 deletions pkg/engine/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package engine

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"io"
"math/rand"
Expand All @@ -11,11 +14,13 @@ import (
"sync"
"time"

"github.com/gptscript-ai/gptscript/pkg/certs"
"github.com/gptscript-ai/gptscript/pkg/system"
"github.com/gptscript-ai/gptscript/pkg/types"
)

var ports Ports
var certificates Certs

type Ports struct {
daemonPorts map[string]int64
Expand All @@ -29,6 +34,11 @@ type Ports struct {
daemonWG sync.WaitGroup
}

type Certs struct {
daemonCerts map[string]certs.CertAndKey
daemonLock sync.Mutex
}

func IsDaemonRunning(url string) bool {
ports.daemonLock.Lock()
defer ports.daemonLock.Unlock()
Expand Down Expand Up @@ -117,7 +127,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
tool.Instructions = types.CommandPrefix + instructions

port, ok := ports.daemonPorts[tool.ID]
url := fmt.Sprintf("http://127.0.0.1:%d%s", port, path)
url := fmt.Sprintf("https://127.0.0.1:%d%s", port, path)
if ok {
return url, nil
}
Expand All @@ -133,11 +143,31 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {

ctx := ports.daemonCtx
port = nextPort()
url = fmt.Sprintf("http://127.0.0.1:%d%s", port, path)
url = fmt.Sprintf("https://127.0.0.1:%d%s", port, path)

// Generate a certificate for the daemon, unless one already exists.
certificates.daemonLock.Lock()
defer certificates.daemonLock.Unlock()
cert, exists := certificates.daemonCerts[tool.ID]
if !exists {
var err error
cert, err = certs.GenerateSelfSignedCert(tool.ID)
if err != nil {
return "", fmt.Errorf("failed to generate certificate for daemon: %v", err)
}

if certificates.daemonCerts == nil {
certificates.daemonCerts = map[string]certs.CertAndKey{}
}
certificates.daemonCerts[tool.ID] = cert
}

cmd, stop, err := e.newCommand(ctx, []string{
fmt.Sprintf("PORT=%d", port),
fmt.Sprintf("CERT=%s", base64.StdEncoding.EncodeToString(cert.Cert)),
fmt.Sprintf("PRIVATE_KEY=%s", base64.StdEncoding.EncodeToString(cert.Key)),
fmt.Sprintf("GPTSCRIPT_PORT=%d", port),
fmt.Sprintf("GPTSCRIPT_CERT=%s", base64.StdEncoding.EncodeToString(e.GPTScriptCert.Cert)),
},
tool,
"{}",
Expand Down Expand Up @@ -199,8 +229,30 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
ports.daemonWG.Done()
}()

// Build HTTP client for checking the health of the daemon
clientCert, err := tls.X509KeyPair(e.GPTScriptCert.Cert, e.GPTScriptCert.Key)
if err != nil {
return "", fmt.Errorf("failed to create client certificate: %v", err)
}

pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(cert.Cert) {
return "", fmt.Errorf("failed to append daemon certificate for [%s]", tool.ID)
}

httpClient := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{clientCert},
RootCAs: pool,
InsecureSkipVerify: false,
},
},
}

// Check the health of the daemon
for i := 0; i < 120; i++ {
resp, err := http.Get(url)
resp, err := httpClient.Get(url)
if err == nil && resp.StatusCode == http.StatusOK {
go func() {
_, _ = io.ReadAll(resp.Body)
Expand Down
2 changes: 2 additions & 0 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"
"sync"

"github.com/gptscript-ai/gptscript/pkg/certs"
"github.com/gptscript-ai/gptscript/pkg/counter"
"github.com/gptscript-ai/gptscript/pkg/types"
"github.com/gptscript-ai/gptscript/pkg/version"
Expand All @@ -22,6 +23,7 @@ type RuntimeManager interface {
}

type Engine struct {
GPTScriptCert certs.CertAndKey
Model Model
RuntimeManager RuntimeManager
Env []string
Expand Down
43 changes: 42 additions & 1 deletion pkg/engine/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package engine

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -40,6 +42,7 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
return nil, err
}

var tlsConfigForDaemonRequest *tls.Config
if strings.HasSuffix(parsed.Hostname(), DaemonURLSuffix) {
referencedToolName := strings.TrimSuffix(parsed.Hostname(), DaemonURLSuffix)
referencedToolRefs, ok := tool.ToolMapping[referencedToolName]
Expand All @@ -60,6 +63,33 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
}
parsed.Host = toolURLParsed.Host
toolURL = parsed.String()

// Find the certificate corresponding to this daemon tool
certificates.daemonLock.Lock()
daemonCert, exists := certificates.daemonCerts[referencedTool.ID]
certificates.daemonLock.Unlock()

if !exists {
return nil, fmt.Errorf("missing daemon certificate for [%s]", referencedTool.ID)
}

// Create a pool for the certificate to treat as a CA
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(daemonCert.Cert) {
return nil, fmt.Errorf("failed to append daemon certificate for [%s]", referencedTool.ID)
}

clientCert, err := tls.X509KeyPair(e.GPTScriptCert.Cert, e.GPTScriptCert.Key)
if err != nil {
return nil, fmt.Errorf("failed to create client certificate: %v", err)
}

// Create TLS config for use in the HTTP client later
tlsConfigForDaemonRequest = &tls.Config{
Certificates: []tls.Certificate{clientCert},
RootCAs: pool,
InsecureSkipVerify: false,
}
}

if tool.Blocking {
Expand Down Expand Up @@ -112,7 +142,18 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
req.Header.Set("Content-Type", "text/plain")
}

resp, err := http.DefaultClient.Do(req)
var httpClient *http.Client
if tlsConfigForDaemonRequest != nil {
httpClient = &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfigForDaemonRequest,
},
}
} else {
httpClient = http.DefaultClient
}

resp, err := httpClient.Do(req)
if err != nil {
return nil, err
}
Expand Down
14 changes: 10 additions & 4 deletions pkg/gptscript/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/cache"
"github.com/gptscript-ai/gptscript/pkg/certs"
"github.com/gptscript-ai/gptscript/pkg/config"
context2 "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/credentials"
Expand Down Expand Up @@ -107,7 +108,12 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) {
opts.Runner.RuntimeManager = runtimes.Default(cacheClient.CacheDir(), opts.SystemToolsDir)
}

simplerRunner, err := newSimpleRunner(cacheClient, opts.Runner.RuntimeManager, opts.Env)
gptscriptCert, err := certs.GenerateGPTScriptCert()
if err != nil {
return nil, err
}

simplerRunner, err := newSimpleRunner(cacheClient, opts.Runner.RuntimeManager, opts.Env, gptscriptCert)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -140,7 +146,7 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) {
opts.Runner.MonitorFactory = monitor.NewConsole(opts.Monitor, monitor.Options{DebugMessages: *opts.Quiet})
}

runner, err := runner.New(registry, credStore, opts.Runner)
runner, err := runner.New(registry, credStore, gptscriptCert, opts.Runner)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -285,8 +291,8 @@ type simpleRunner struct {
env []string
}

func newSimpleRunner(cache *cache.Client, rm engine.RuntimeManager, env []string) (*simpleRunner, error) {
runner, err := runner.New(noopModel{}, credentials.NoopStore{}, runner.Options{
func newSimpleRunner(cache *cache.Client, rm engine.RuntimeManager, env []string, gptscriptCert certs.CertAndKey) (*simpleRunner, error) {
runner, err := runner.New(noopModel{}, credentials.NoopStore{}, gptscriptCert, runner.Options{
RuntimeManager: rm,
MonitorFactory: simpleMonitorFactory{},
})
Expand Down
7 changes: 6 additions & 1 deletion pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/certs"
context2 "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/gptscript-ai/gptscript/pkg/engine"
Expand Down Expand Up @@ -95,9 +96,10 @@ type Runner struct {
credOverrides []string
credStore credentials.CredentialStore
sequential bool
gptscriptCert certs.CertAndKey
}

func New(client engine.Model, credStore credentials.CredentialStore, opts ...Options) (*Runner, error) {
func New(client engine.Model, credStore credentials.CredentialStore, gptscriptCert certs.CertAndKey, opts ...Options) (*Runner, error) {
opt := complete(opts...)

runner := &Runner{
Expand All @@ -109,6 +111,7 @@ func New(client engine.Model, credStore credentials.CredentialStore, opts ...Opt
credStore: credStore,
sequential: opt.Sequential,
auth: opt.Authorizer,
gptscriptCert: gptscriptCert,
}

if opt.StartPort != 0 {
Expand Down Expand Up @@ -411,6 +414,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
RuntimeManager: runtimeWithLogger(callCtx, monitor, r.runtimeManager),
Progress: progress,
Env: env,
GPTScriptCert: r.gptscriptCert,
}

callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause)
Expand Down Expand Up @@ -593,6 +597,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
RuntimeManager: runtimeWithLogger(callCtx, monitor, r.runtimeManager),
Progress: progress,
Env: env,
GPTScriptCert: r.gptscriptCert,
}

var contentInput string
Expand Down
6 changes: 5 additions & 1 deletion pkg/tests/tester/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"

"github.com/adrg/xdg"
"github.com/gptscript-ai/gptscript/pkg/certs"
"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/gptscript-ai/gptscript/pkg/loader"
"github.com/gptscript-ai/gptscript/pkg/repos/runtimes"
Expand Down Expand Up @@ -198,7 +199,10 @@ func NewRunner(t *testing.T) *Runner {

rm := runtimes.Default(cacheDir, "")

run, err := runner.New(c, credentials.NoopStore{}, runner.Options{
gptscriptCert, err := certs.GenerateGPTScriptCert()
require.NoError(t, err)

run, err := runner.New(c, credentials.NoopStore{}, gptscriptCert, runner.Options{
Sequential: true,
RuntimeManager: rm,
})
Expand Down

0 comments on commit 868dded

Please sign in to comment.