diff --git a/CHANGELOG.adoc b/CHANGELOG.adoc index 476d9ac..d47b2c9 100644 --- a/CHANGELOG.adoc +++ b/CHANGELOG.adoc @@ -1,6 +1,10 @@ = Changelog :icons: font +== Unreleased + +- Add initial `docker exec` support (#37) + == 0.0.4 - Add support for container exit times (#21) diff --git a/containerd/container.go b/containerd/container.go index d912b68..c17dab9 100644 --- a/containerd/container.go +++ b/containerd/container.go @@ -1,15 +1,13 @@ package containerd import ( - "github.com/containerd/containerd/v2/api/types/task" + "github.com/containerd/containerd/v2/errdefs" "github.com/containerd/containerd/v2/mount" "github.com/containerd/containerd/v2/oci" "github.com/hashicorp/go-multierror" "golang.org/x/sys/unix" "os" - "os/exec" "sync" - "time" ) const unmountFlags = unix.MNT_FORCE @@ -20,28 +18,31 @@ type container struct { bundlePath string rootfs string dnsSocketPath string - io stdio - console *os.File - - mu sync.Mutex - cmd *exec.Cmd - waitblock chan struct{} - status task.Status - exitStatus uint32 - exitedAt time.Time + + mu sync.Mutex + + // primary is the primary process for the container. + // The lifetime of the container is tied to this process. + primary managedProcess + + // auxiliary is a map of additional processes that run in the jail. + auxiliary map[string]*managedProcess } func (c *container) destroy() (retErr error) { - if err := c.io.Close(); err != nil { - retErr = multierror.Append(retErr, err) - } + c.mu.Lock() + defer c.mu.Unlock() - if c.console != nil { - if err := c.console.Close(); err != nil { + for _, p := range c.auxiliary { + if err := p.destroy(); err != nil { retErr = multierror.Append(retErr, err) } } + if err := c.primary.destroy(); err != nil { + retErr = multierror.Append(retErr, err) + } + // Remove socket file to avoid continuity "failed to create irregular file" error during multiple Dockerfile `RUN` steps _ = os.Remove(c.dnsSocketPath) @@ -52,23 +53,23 @@ func (c *container) destroy() (retErr error) { return } -func (c *container) setStatusL(status task.Status) { +func (c *container) getProcessL(execID string) (*managedProcess, error) { c.mu.Lock() defer c.mu.Unlock() - c.status = status + return c.getProcess(execID) } -func (c *container) getStatusL() task.Status { - c.mu.Lock() - defer c.mu.Unlock() +func (c *container) getProcess(execID string) (*managedProcess, error) { + if execID == "" { + return &c.primary, nil + } - return c.status -} + p := c.auxiliary[execID] -func (c *container) getConsoleL() *os.File { - c.mu.Lock() - defer c.mu.Unlock() + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "exec not found: %s", execID) + } - return c.console + return p, nil } diff --git a/containerd/managed_process.go b/containerd/managed_process.go new file mode 100644 index 0000000..e9df76b --- /dev/null +++ b/containerd/managed_process.go @@ -0,0 +1,133 @@ +package containerd + +import ( + "context" + "github.com/containerd/containerd/v2/api/types/task" + "github.com/creack/pty" + "github.com/hashicorp/go-multierror" + "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" + "io" + "os" + "os/exec" + "sync" + "syscall" + "time" +) + +type managedProcess struct { + spec *specs.Process + io stdio + console *os.File + mu sync.Mutex + cmd *exec.Cmd + waitblock chan struct{} + status task.Status + exitStatus uint32 + exitedAt time.Time +} + +func (p *managedProcess) getConsoleL() *os.File { + p.mu.Lock() + defer p.mu.Unlock() + + return p.console +} + +func (p *managedProcess) destroy() (retErr error) { + // TODO: Do we care about error? + _ = p.kill(syscall.SIGKILL) + + if err := p.io.Close(); err != nil { + retErr = multierror.Append(retErr, err) + } + + if p.console != nil { + if err := p.console.Close(); err != nil { + retErr = multierror.Append(retErr, err) + } + } + + if p.status != task.Status_STOPPED { + p.status = task.Status_STOPPED + p.exitedAt = time.Now() + p.exitStatus = uint32(syscall.SIGKILL) + } + + return +} + +func (p *managedProcess) kill(signal syscall.Signal) error { + if p.cmd != nil { + if process := p.cmd.Process; p != nil { + return unix.Kill(-process.Pid, signal) + } + } + + return nil +} + +func (p *managedProcess) setup(ctx context.Context, rootfs string, stdin string, stdout string, stderr string) error { + var err error + + p.io, err = setupIO(ctx, stdin, stdout, stderr) + if err != nil { + return err + } + + if len(p.spec.Args) <= 0 { + // TODO: How to handle this properly? + p.spec.Args = []string{"/bin/sh"} + // return fmt.Errorf("args must not be empty") + } + + p.cmd = exec.Command(p.spec.Args[0]) + p.cmd.Args = p.spec.Args + p.cmd.Dir = p.spec.Cwd + p.cmd.Env = p.spec.Env + p.cmd.SysProcAttr = &syscall.SysProcAttr{ + Chroot: rootfs, + Credential: &syscall.Credential{ + Uid: p.spec.User.UID, + Gid: p.spec.User.GID, + }, + } + + return nil +} + +func (p *managedProcess) start() (err error) { + if p.spec.Terminal { + // TODO: I'd like to use containerd/console package instead + // But see https://github.com/containerd/console/issues/79 + var consoleSize *pty.Winsize + if p.spec.ConsoleSize != nil { + consoleSize = &pty.Winsize{ + Cols: uint16(p.spec.ConsoleSize.Width), + Rows: uint16(p.spec.ConsoleSize.Height), + } + } + + p.console, err = pty.StartWithSize(p.cmd, consoleSize) + if err != nil { + return err + } + + go io.Copy(p.console, p.io.stdin) + go io.Copy(p.io.stdout, p.console) + } else { + p.cmd.SysProcAttr.Setpgid = true + p.cmd.Stdin = p.io.stdin + p.cmd.Stdout = p.io.stdout + p.cmd.Stderr = p.io.stderr + + err = p.cmd.Start() + if err != nil { + return err + } + } + + p.status = task.Status_RUNNING + + return nil +} diff --git a/containerd/service.go b/containerd/service.go index f2c6d61..aec7655 100644 --- a/containerd/service.go +++ b/containerd/service.go @@ -17,12 +17,11 @@ import ( "github.com/containerd/containerd/v2/runtime/v2/shim" "github.com/containerd/log" "github.com/containerd/ttrpc" + "github.com/containerd/typeurl/v2" "github.com/creack/pty" - "golang.org/x/sys/unix" - "io" + "github.com/opencontainers/runtime-spec/specs-go" "net" "os" - "os/exec" "path" "path/filepath" "sync" @@ -84,31 +83,32 @@ func (s *service) State(ctx context.Context, request *taskAPI.StateRequest) (*ta log.G(ctx).WithField("request", request).Info("STATE") defer log.G(ctx).Info("STATE_DONE") - if request.ExecID != "" { - return nil, errdefs.ErrNotImplemented + c, err := s.getContainerL(request.ID) + if err != nil { + return nil, err } - c, err := s.getContainerL(request.ID) + p, err := c.getProcessL(request.ExecID) if err != nil { return nil, err } var pid int - if p := c.cmd.Process; p != nil { - pid = p.Pid + if process := p.cmd.Process; process != nil { + pid = process.Pid } return &taskAPI.StateResponse{ ID: request.ID, Bundle: c.bundlePath, Pid: uint32(pid), - Status: c.status, - Stdin: c.io.stdinPath, - Stdout: c.io.stdoutPath, - Stderr: c.io.stderrPath, + Status: p.status, + Stdin: p.io.stdinPath, + Stdout: p.io.stdoutPath, + Stderr: p.io.stderrPath, Terminal: c.spec.Process.Terminal, - ExitedAt: protobuf.ToTimestamp(c.exitedAt), - ExitStatus: c.exitStatus, + ExitedAt: protobuf.ToTimestamp(p.exitedAt), + ExitStatus: p.exitStatus, ExecID: request.ExecID, }, nil } @@ -147,8 +147,12 @@ func (s *service) Create(ctx context.Context, request *taskAPI.CreateTaskRequest bundlePath: request.Bundle, rootfs: rootfs, dnsSocketPath: dnsSocketPath, - waitblock: make(chan struct{}), - status: task.Status_CREATED, + primary: managedProcess{ + spec: spec.Process, + waitblock: make(chan struct{}), + status: task.Status_CREATED, + }, + auxiliary: make(map[string]*managedProcess), } defer func() { @@ -159,29 +163,10 @@ func (s *service) Create(ctx context.Context, request *taskAPI.CreateTaskRequest } }() - c.io, err = setupIO(ctx, request.Stdin, request.Stdout, request.Stderr) - if err != nil { + if err = c.primary.setup(ctx, c.rootfs, request.Stdin, request.Stdout, request.Stderr); err != nil { return nil, err } - if len(spec.Process.Args) <= 0 { - // TODO: How to handle this properly? - spec.Process.Args = []string{"/bin/sh"} - // return nil, fmt.Errorf("args must not be empty") - } - - c.cmd = exec.Command(c.spec.Process.Args[0]) - c.cmd.Args = c.spec.Process.Args - c.cmd.Dir = c.spec.Process.Cwd - c.cmd.Env = c.spec.Process.Env - c.cmd.SysProcAttr = &syscall.SysProcAttr{ - Chroot: c.rootfs, - Credential: &syscall.Credential{ - Uid: c.spec.Process.User.UID, - Gid: c.spec.Process.User.GID, - }, - } - var mounts []mount.Mount for _, m := range request.Rootfs { mm, err := processMount(c.rootfs, m.Type, m.Source, m.Target, m.Options) @@ -209,6 +194,7 @@ func (s *service) Create(ctx context.Context, request *taskAPI.CreateTaskRequest return nil, fmt.Errorf("failed to mount rootfs component: %w", err) } + // TODO: Check if container already exists? s.containers[request.ID] = c s.events <- &events.TaskCreate{ @@ -280,10 +266,6 @@ func (s *service) Start(ctx context.Context, request *taskAPI.StartRequest) (*ta log.G(ctx).WithField("request", request).Info("START") defer log.G(ctx).Info("START_DONE") - if request.ExecID != "" { - return nil, errdefs.ErrNotImplemented - } - s.mu.Lock() defer s.mu.Unlock() @@ -292,89 +274,93 @@ func (s *service) Start(ctx context.Context, request *taskAPI.StartRequest) (*ta return nil, err } - if err = os.MkdirAll(path.Dir(c.dnsSocketPath), 0o755); err != nil { - return nil, err - } - - dnsSocket, err := net.ListenUnix("unix", &net.UnixAddr{Name: c.dnsSocketPath, Net: "unix"}) - if err != nil { - return nil, err - } - - // TODO: We should stop this somehow? - go func() { - for { - con, err := dnsSocket.AcceptUnix() - if err != nil { - return - } - - pipe, err := net.DialUnix("unix", nil, &net.UnixAddr{Name: "/var/run/mDNSResponder", Net: "unix"}) - if err != nil { - return - } - go unixSocketCopy(con, pipe) - go unixSocketCopy(pipe, con) - } - }() - - if c.spec.Process.Terminal { - // TODO: I'd like to use containerd/console package instead - // But see https://github.com/containerd/console/issues/79 - var consoleSize *pty.Winsize - if c.spec.Process.ConsoleSize != nil { - consoleSize = &pty.Winsize{ - Cols: uint16(c.spec.Process.ConsoleSize.Width), - Rows: uint16(c.spec.Process.ConsoleSize.Height), - } - } + c.mu.Lock() + defer c.mu.Unlock() - c.console, err = pty.StartWithSize(c.cmd, consoleSize) - if err != nil { + if request.ExecID == "" { + if err = os.MkdirAll(path.Dir(c.dnsSocketPath), 0o755); err != nil { return nil, err } - go io.Copy(c.console, c.io.stdin) - go io.Copy(c.io.stdout, c.console) - } else { - c.cmd.SysProcAttr.Setpgid = true - c.cmd.Stdin = c.io.stdin - c.cmd.Stdout = c.io.stdout - c.cmd.Stderr = c.io.stderr - - err = c.cmd.Start() + dnsSocket, err := net.ListenUnix("unix", &net.UnixAddr{Name: c.dnsSocketPath, Net: "unix"}) if err != nil { return nil, err } + + // TODO: We should stop this somehow? + go func() { + for { + con, err := dnsSocket.AcceptUnix() + if err != nil { + return + } + + pipe, err := net.DialUnix("unix", nil, &net.UnixAddr{Name: "/var/run/mDNSResponder", Net: "unix"}) + if err != nil { + return + } + go unixSocketCopy(con, pipe) + go unixSocketCopy(pipe, con) + } + }() } - c.setStatusL(task.Status_RUNNING) + p, err := c.getProcess(request.ExecID) + if err != nil { + return nil, err + } - s.events <- &events.TaskStart{ - ContainerID: request.ID, - Pid: uint32(c.cmd.Process.Pid), + if err = p.start(); err != nil { + return nil, err } go func() { - w, _ := wait(c.cmd.Process) + var w *os.ProcessState + + if request.ExecID == "" { + w, _ = wait(p.cmd.Process) + } else { + w, _ = p.cmd.Process.Wait() + } + + p.exitedAt = time.Now() + p.exitStatus = uint32(w.ExitCode()) + p.status = task.Status_STOPPED + + _ = p.io.Close() + + // Madness... + id := request.ID + if request.ExecID != "" { + id = request.ExecID + } - c.exitedAt = time.Now() - c.exitStatus = uint32(w.ExitCode()) - c.setStatusL(task.Status_STOPPED) - _ = c.io.Close() s.events <- &events.TaskExit{ ContainerID: request.ID, - ID: request.ID, + ID: id, Pid: uint32(w.Pid()), - ExitedAt: protobuf.ToTimestamp(c.exitedAt), - ExitStatus: c.exitStatus, + ExitedAt: protobuf.ToTimestamp(p.exitedAt), + ExitStatus: p.exitStatus, } - close(c.waitblock) + close(p.waitblock) }() + if request.ExecID == "" { + s.events <- &events.TaskStart{ + ContainerID: request.ID, + Pid: uint32(p.cmd.Process.Pid), + } + } else { + s.events <- &events.TaskExecStarted{ + ContainerID: request.ID, + ExecID: request.ExecID, + Pid: uint32(p.cmd.Process.Pid), + } + } + return &taskAPI.StartResponse{ - Pid: uint32(c.cmd.Process.Pid), + Pid: uint32(p.cmd.Process.Pid), }, nil } @@ -382,10 +368,6 @@ func (s *service) Delete(ctx context.Context, request *taskAPI.DeleteRequest) (* log.G(ctx).WithField("request", request).Info("DELETE") defer log.G(ctx).Info("DELETE_DONE") - if request.ExecID != "" { - return nil, errdefs.ErrNotImplemented - } - s.mu.Lock() defer s.mu.Unlock() @@ -394,6 +376,26 @@ func (s *service) Delete(ctx context.Context, request *taskAPI.DeleteRequest) (* return nil, err } + if request.ExecID != "" { + c.mu.Lock() + defer c.mu.Unlock() + + p, err := c.getProcess(request.ExecID) + if err != nil { + return nil, err + } + + if err := p.destroy(); err != nil { + log.G(ctx).WithError(err).Warn("failed to destroy exec") + } + delete(c.auxiliary, request.ExecID) + + return &taskAPI.DeleteResponse{ + ExitedAt: protobuf.ToTimestamp(p.exitedAt), + ExitStatus: p.exitStatus, + }, nil + } + if err := c.destroy(); err != nil { log.G(ctx).WithError(err).Warn("failed to cleanup container") } @@ -401,21 +403,21 @@ func (s *service) Delete(ctx context.Context, request *taskAPI.DeleteRequest) (* delete(s.containers, request.ID) var pid uint32 - if p := c.cmd.Process; p != nil { + if p := c.primary.cmd.Process; p != nil { pid = uint32(p.Pid) } s.events <- &events.TaskDelete{ ContainerID: request.ID, - ExitedAt: protobuf.ToTimestamp(c.exitedAt), - ExitStatus: c.exitStatus, + ExitedAt: protobuf.ToTimestamp(c.primary.exitedAt), + ExitStatus: c.primary.exitStatus, ID: request.ID, Pid: pid, } return &taskAPI.DeleteResponse{ - ExitedAt: protobuf.ToTimestamp(c.exitedAt), - ExitStatus: c.exitStatus, + ExitedAt: protobuf.ToTimestamp(c.primary.exitedAt), + ExitStatus: c.primary.exitStatus, Pid: pid, }, nil } @@ -444,47 +446,92 @@ func (s *service) Kill(ctx context.Context, request *taskAPI.KillRequest) (*ptyp log.G(ctx).WithField("request", request).Info("KILL") defer log.G(ctx).Info("KILL_DONE") - if request.ExecID != "" { - return nil, errdefs.ErrNotImplemented - } - c, err := s.getContainerL(request.ID) if err != nil { return nil, err } - if p := c.cmd.Process; p != nil { - _ = unix.Kill(-p.Pid, syscall.Signal(request.Signal)) + p, err := c.getProcessL(request.ExecID) + if err != nil { + return nil, err } + // TODO: Do we care about error here? + _ = p.kill(syscall.Signal(request.Signal)) + return &ptypes.Empty{}, nil } -func (s *service) Exec(ctx context.Context, request *taskAPI.ExecProcessRequest) (*ptypes.Empty, error) { +func (s *service) Exec(ctx context.Context, request *taskAPI.ExecProcessRequest) (_ *ptypes.Empty, retErr error) { log.G(ctx).WithField("request", request).Info("EXEC") - return nil, errdefs.ErrNotImplemented + + specAny, err := typeurl.UnmarshalAny(request.Spec) + if err != nil { + log.G(ctx).WithError(err).Error("failed to unmarshal spec") + return nil, errdefs.ErrInvalidArgument + } + + spec, ok := specAny.(*specs.Process) + if !ok { + log.G(ctx).Error("mismatched type for spec") + return nil, errdefs.ErrInvalidArgument + } + + c, err := s.getContainerL(request.ID) + if err != nil { + return nil, err + } + + c.mu.Lock() + defer c.mu.Unlock() + + aux := &managedProcess{ + spec: spec, + waitblock: make(chan struct{}), + status: task.Status_CREATED, + } + + defer func() { + if retErr != nil { + if err := aux.destroy(); err != nil { + log.G(ctx).WithError(err).Warn("failed to cleanup aux") + } + } + }() + + if err = aux.setup(ctx, c.rootfs, request.Stdin, request.Stdout, request.Stderr); err != nil { + return nil, err + } + + // TODO: Check if aux already exists? + c.auxiliary[request.ExecID] = aux + + s.events <- &events.TaskExecAdded{ + ContainerID: request.ID, + ExecID: request.ExecID, + } + + return &ptypes.Empty{}, nil } func (s *service) ResizePty(ctx context.Context, request *taskAPI.ResizePtyRequest) (*ptypes.Empty, error) { log.G(ctx).WithField("request", request).Info("RESIZEPTY") defer log.G(ctx).Info("RESIZEPTY_DONE") - if request.ExecID != "" { - return nil, errdefs.ErrNotImplemented - } - c, err := s.getContainerL(request.ID) if err != nil { return nil, err } - con := c.getConsoleL() - if con == nil { - return &ptypes.Empty{}, nil + p, err := c.getProcessL(request.ExecID) + if err != nil { + return nil, err } - if err = pty.Setsize(con, &pty.Winsize{Cols: uint16(request.Width), Rows: uint16(request.Height)}); err != nil { - return nil, err + if con := p.getConsoleL(); con != nil { + if err = pty.Setsize(con, &pty.Winsize{Cols: uint16(request.Width), Rows: uint16(request.Height)}); err != nil { + return nil, err + } } return &ptypes.Empty{}, nil @@ -493,17 +540,18 @@ func (s *service) ResizePty(ctx context.Context, request *taskAPI.ResizePtyReque func (s *service) CloseIO(ctx context.Context, request *taskAPI.CloseIORequest) (*ptypes.Empty, error) { log.G(ctx).WithField("request", request).Info("CLOSEIO") - if request.ExecID != "" { - return nil, errdefs.ErrNotImplemented + c, err := s.getContainerL(request.ID) + if err != nil { + return nil, err } - c, err := s.getContainerL(request.ID) + p, err := c.getProcessL(request.ExecID) if err != nil { return nil, err } - if stdin := c.io.stdin; stdin != nil { - stdin.Close() + if stdin := p.io.stdin; stdin != nil { + _ = stdin.Close() } return &ptypes.Empty{}, nil @@ -518,20 +566,21 @@ func (s *service) Wait(ctx context.Context, request *taskAPI.WaitRequest) (*task log.G(ctx).WithField("request", request).Info("WAIT") defer log.G(ctx).Info("WAIT_DONE") - if request.ExecID != "" { - return nil, errdefs.ErrNotImplemented + c, err := s.getContainerL(request.ID) + if err != nil { + return nil, err } - c, err := s.getContainerL(request.ID) + p, err := c.getProcessL(request.ExecID) if err != nil { return nil, err } - <-c.waitblock + <-p.waitblock return &taskAPI.WaitResponse{ - ExitedAt: protobuf.ToTimestamp(c.exitedAt), - ExitStatus: c.exitStatus, + ExitedAt: protobuf.ToTimestamp(p.exitedAt), + ExitStatus: p.exitStatus, }, nil } @@ -546,7 +595,7 @@ func (s *service) Connect(ctx context.Context, request *taskAPI.ConnectRequest) var pid int if c, err := s.getContainerL(request.ID); err == nil { - if p := c.cmd.Process; p != nil { + if p := c.primary.cmd.Process; p != nil { pid = p.Pid } }