diff --git a/pkg/engine/cmd.go b/pkg/engine/cmd.go index b0c1ab4b..010c1ace 100644 --- a/pkg/engine/cmd.go +++ b/pkg/engine/cmd.go @@ -292,9 +292,8 @@ func (e *Engine) newCommand(ctx context.Context, extraEnv []string, tool types.T args = args[1:] } - var ( - stop = func() {} - ) + ctx, cancel := context.WithCancel(ctx) + stop := cancel if strings.TrimSpace(rest) != "" { f, err := os.CreateTemp(env.Getenv("GPTSCRIPT_TMPDIR", envvars), version.ProgramName+requiredFileExtensions[args[0]]) @@ -303,6 +302,7 @@ func (e *Engine) newCommand(ctx context.Context, extraEnv []string, tool types.T } stop = func() { _ = os.Remove(f.Name()) + cancel() } _, err = f.Write([]byte(rest)) diff --git a/pkg/engine/daemon.go b/pkg/engine/daemon.go index f0a1c10c..b7877da3 100644 --- a/pkg/engine/daemon.go +++ b/pkg/engine/daemon.go @@ -19,7 +19,7 @@ var ports Ports type Ports struct { daemonPorts map[string]int64 - daemonsRunning map[string]struct{} + daemonsRunning map[string]func() daemonLock sync.Mutex startPort, endPort int64 @@ -57,6 +57,17 @@ func CloseDaemons() { ports.daemonWG.Wait() } +func StopDaemon(url string) { + ports.daemonLock.Lock() + defer ports.daemonLock.Unlock() + + if stop := ports.daemonsRunning[url]; stop != nil { + stop() + } + + delete(ports.daemonsRunning, url) +} + func nextPort() int64 { if ports.startPort == 0 { ports.startPort = 10240 @@ -118,7 +129,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { port, ok := ports.daemonPorts[tool.ID] url := fmt.Sprintf("http://127.0.0.1:%d%s", port, path) - if ok { + if ok && ports.daemonsRunning[url] != nil { return url, nil } @@ -172,13 +183,13 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { if ports.daemonPorts == nil { ports.daemonPorts = map[string]int64{} - ports.daemonsRunning = map[string]struct{}{} + ports.daemonsRunning = map[string]func(){} } ports.daemonPorts[tool.ID] = port - ports.daemonsRunning[url] = struct{}{} + ports.daemonsRunning[url] = stop - killedCtx, cancel := context.WithCancelCause(ctx) - defer cancel(nil) + killedCtx, killedCancel := context.WithCancelCause(ctx) + defer killedCancel(nil) ports.daemonWG.Add(1) go func() { @@ -189,7 +200,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { _ = r.Close() _ = w.Close() - cancel(err) + killedCancel(err) stop() ports.daemonLock.Lock() defer ports.daemonLock.Unlock()