diff --git a/.spdxignore b/.spdxignore index c089d6d0f4..82455009bd 100644 --- a/.spdxignore +++ b/.spdxignore @@ -10,3 +10,4 @@ pkg/rngd/cmd/rngd/vendor/ pkg/wwan/mmagent/vendor/ tools/get-deps/vendor/ pkg/installer/vendor/ +pkg/vtpm/swtpm-vtpm/vendor/ diff --git a/pkg/pillar/cmd/domainmgr/domainmgr.go b/pkg/pillar/cmd/domainmgr/domainmgr.go index fcfabd3e86..5cf87227bb 100644 --- a/pkg/pillar/cmd/domainmgr/domainmgr.go +++ b/pkg/pillar/cmd/domainmgr/domainmgr.go @@ -1119,21 +1119,29 @@ func maybeRetryBoot(ctx *domainContext, status *types.DomainStatus) { } defer file.Close() - if err := hyper.Task(status).VirtualTPMSetup(status.DomainName, agentName, ctx.ps, warningTime, errorTime); err != nil { - log.Errorf("Failed to setup virtual TPM for %s: %s", status.DomainName, err) - status.VirtualTPM = false - } else { + wp := &types.WatchdogParam{Ps: ctx.ps, AgentName: agentName, WarnTime: warningTime, ErrTime: errorTime} + err = hyper.Task(status).VirtualTPMSetup(status.DomainName, wp) + if err == nil { status.VirtualTPM = true + defer func(status *types.DomainStatus, wp *types.WatchdogParam) { + // this means we failed to boot the VM. + if !status.Activated { + log.Noticef("Failed to activate domain: %s, terminating vTPM", status.DomainName) + if err := hyper.Task(status).VirtualTPMTerminate(status.DomainName, wp); err != nil { + // this is not a critical failure so just log it + log.Errorf("Failed to terminate vTPM for %s: %s", status.DomainName, err) + } + } + }(status, wp) + } else { + status.VirtualTPM = false + log.Errorf("Failed to setup vTPM for %s: %s", status.DomainName, err) } if err := hyper.Task(status).Setup(*status, *config, ctx.assignableAdapters, nil, file); err != nil { //it is retry, so omit error log.Errorf("Failed to create DomainStatus from %+v: %s", config, err) - - if err := hyper.Task(status).VirtualTPMTerminate(status.DomainName); err != nil { - log.Errorf("Failed to terminate virtual TPM for %s: %s", status.DomainName, err) - } } status.TriedCount++ @@ -1671,11 +1679,23 @@ func doActivate(ctx *domainContext, config types.DomainConfig, } defer file.Close() - if err := hyper.Task(status).VirtualTPMSetup(status.DomainName, agentName, ctx.ps, warningTime, errorTime); err != nil { - log.Errorf("Failed to setup virtual TPM for %s: %s", status.DomainName, err) - status.VirtualTPM = false - } else { + wp := &types.WatchdogParam{Ps: ctx.ps, AgentName: agentName, WarnTime: warningTime, ErrTime: errorTime} + err = hyper.Task(status).VirtualTPMSetup(status.DomainName, wp) + if err == nil { status.VirtualTPM = true + defer func(status *types.DomainStatus, wp *types.WatchdogParam) { + // this means we failed to boot the VM. 
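+			// The deferred check runs when doActivate returns: if Setup or the
+			// boot failed and the domain never reached Activated, the vTPM
+			// launched above is terminated instead of being leaked.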
+ if !status.Activated { + log.Noticef("Failed to activate domain: %s, terminating vTPM", status.DomainName) + if err := hyper.Task(status).VirtualTPMTerminate(status.DomainName, wp); err != nil { + // this is not a critical failure so just log it + log.Errorf("Failed to terminate vTPM for %s: %s", status.DomainName, err) + } + } + }(status, wp) + } else { + status.VirtualTPM = false + log.Errorf("Failed to setup vTPM for %s: %s", status.DomainName, err) } globalConfig := agentlog.GetGlobalConfig(log, ctx.subGlobalConfig) @@ -1684,11 +1704,6 @@ func doActivate(ctx *domainContext, config types.DomainConfig, config, err) status.SetErrorNow(err.Error()) releaseCPUs(ctx, &config, status) - - if err := hyper.Task(status).VirtualTPMTerminate(status.DomainName); err != nil { - log.Errorf("Failed to terminate virtual TPM for %s: %s", status.DomainName, err) - } - return } @@ -2472,7 +2487,8 @@ func handleDelete(ctx *domainContext, key string, status *types.DomainStatus) { log.Errorln(err) } - if err := hyper.Task(status).VirtualTPMTeardown(status.DomainName); err != nil { + wp := &types.WatchdogParam{Ps: ctx.ps, AgentName: agentName, WarnTime: warningTime, ErrTime: errorTime} + if err := hyper.Task(status).VirtualTPMTeardown(status.DomainName, wp); err != nil { log.Errorln(err) } diff --git a/pkg/pillar/hypervisor/containerd.go b/pkg/pillar/hypervisor/containerd.go index fc8fbcc7c2..1dd1c8b610 100644 --- a/pkg/pillar/hypervisor/containerd.go +++ b/pkg/pillar/hypervisor/containerd.go @@ -11,7 +11,6 @@ import ( "time" "github.com/lf-edge/eve/pkg/pillar/containerd" - "github.com/lf-edge/eve/pkg/pillar/pubsub" "github.com/lf-edge/eve/pkg/pillar/types" "github.com/opencontainers/runtime-spec/specs-go" @@ -327,14 +326,14 @@ func (ctx ctrdContext) GetDomsCPUMem() (map[string]types.DomainMetric, error) { return res, nil } -func (ctx ctrdContext) VirtualTPMSetup(domainName, agentName string, ps *pubsub.PubSub, warnTime, errTime time.Duration) error { +func (ctx ctrdContext) VirtualTPMSetup(domainName string, wp *types.WatchdogParam) error { return fmt.Errorf("not implemented") } -func (ctx ctrdContext) VirtualTPMTerminate(domainName string) error { +func (ctx ctrdContext) VirtualTPMTerminate(domainName string, wp *types.WatchdogParam) error { return fmt.Errorf("not implemented") } -func (ctx ctrdContext) VirtualTPMTeardown(domainName string) error { +func (ctx ctrdContext) VirtualTPMTeardown(domainName string, wp *types.WatchdogParam) error { return fmt.Errorf("not implemented") } diff --git a/pkg/pillar/hypervisor/kubevirt.go b/pkg/pillar/hypervisor/kubevirt.go index 4edb6bc5fb..ae3c3e128f 100644 --- a/pkg/pillar/hypervisor/kubevirt.go +++ b/pkg/pillar/hypervisor/kubevirt.go @@ -22,7 +22,6 @@ import ( "time" "github.com/lf-edge/eve/pkg/pillar/base" - "github.com/lf-edge/eve/pkg/pillar/pubsub" "github.com/lf-edge/eve/pkg/pillar/types" netattdefv1 "github.com/k8snetworkplumbingwg/network-attachment-definition-client/pkg/apis/k8s.cni.cncf.io/v1" @@ -1362,14 +1361,14 @@ func (ctx kubevirtContext) PCISameController(id1 string, id2 string) bool { return PCISameControllerGeneric(id1, id2) } -func (ctx kubevirtContext) VirtualTPMSetup(domainName, agentName string, ps *pubsub.PubSub, warnTime, errTime time.Duration) error { +func (ctx kubevirtContext) VirtualTPMSetup(domainName string, wp *types.WatchdogParam) error { return fmt.Errorf("not implemented") } -func (ctx kubevirtContext) VirtualTPMTerminate(domainName string) error { +func (ctx kubevirtContext) VirtualTPMTerminate(domainName string, wp 
*types.WatchdogParam) error { return fmt.Errorf("not implemented") } -func (ctx kubevirtContext) VirtualTPMTeardown(domainName string) error { +func (ctx kubevirtContext) VirtualTPMTeardown(domainName string, wp *types.WatchdogParam) error { return fmt.Errorf("not implemented") } diff --git a/pkg/pillar/hypervisor/kvm.go b/pkg/pillar/hypervisor/kvm.go index 8c4483fd15..043df98a30 100644 --- a/pkg/pillar/hypervisor/kvm.go +++ b/pkg/pillar/hypervisor/kvm.go @@ -4,8 +4,11 @@ package hypervisor import ( + "context" "fmt" + "io" "net" + "net/http" "os" "path/filepath" "runtime" @@ -17,7 +20,6 @@ import ( zconfig "github.com/lf-edge/eve-api/go/config" "github.com/lf-edge/eve/pkg/pillar/agentlog" "github.com/lf-edge/eve/pkg/pillar/containerd" - "github.com/lf-edge/eve/pkg/pillar/pubsub" "github.com/lf-edge/eve/pkg/pillar/types" "github.com/lf-edge/eve/pkg/pillar/utils" fileutils "github.com/lf-edge/eve/pkg/pillar/utils/file" @@ -28,16 +30,28 @@ import ( const ( // KVMHypervisorName is a name of kvm hypervisor - KVMHypervisorName = "kvm" - minUringKernelTag = uint64((5 << 16) | (4 << 8) | (72 << 0)) - swtpmTimeout = 10 // seconds - qemuTimeout = 3 // seconds - vtpmPurgePrefix = "purge;" - vtpmDeletePrefix = "terminate;" - vtpmLaunchPrefix = "launch;" + KVMHypervisorName = "kvm" + minUringKernelTag = uint64((5 << 16) | (4 << 8) | (72 << 0)) + swtpmTimeout = 10 // seconds + qemuTimeout = 3 // seconds + vtpmPurgeEndpoint = "purge" + vtpmTermEndpoint = "terminate" + vtpmLaunchEndpoint = "launch" ) -var clientCid = uint32(unix.VMADDR_CID_HOST + 1) +var ( + clientCid = uint32(unix.VMADDR_CID_HOST + 1) + vTPMClient = &http.Client{ + Transport: vtpmClientUDSTransport(), + Timeout: 5 * time.Second, + } +) + +// vtpmRequestResult holds the result of a vTPM request. +type vtpmRequestResult struct { + Body string + Error error +} // We build device model around PCIe topology according to best practices // https://github.com/qemu/qemu/blob/master/docs/pcie.txt @@ -1330,101 +1344,142 @@ func getQmpListenerSocket(domainName string) string { } // VirtualTPMSetup launches a vTPM instance for the domain -func (ctx KvmContext) VirtualTPMSetup(domainName, agentName string, ps *pubsub.PubSub, warnTime, errTime time.Duration) error { - if ps != nil { - wk := utils.NewWatchdogKick(ps, agentName, warnTime, errTime) - domainUUID, _, _, err := types.DomainnameToUUID(domainName) - if err != nil { - return fmt.Errorf("failed to extract UUID from domain name (vTPM setup): %v", err) - } - return requestVtpmLaunch(domainUUID, wk, swtpmTimeout) +func (ctx KvmContext) VirtualTPMSetup(domainName string, wp *types.WatchdogParam) error { + if wp == nil { + return fmt.Errorf("invalid watchdog configuration") + } + + domainUUID, _, _, err := types.DomainnameToUUID(domainName) + if err != nil { + return fmt.Errorf("failed to extract UUID from domain name: %v", err) } + return requestvTPMLaunch(domainUUID, wp, swtpmTimeout) - return fmt.Errorf("invalid watchdog configuration (vTPM setup)") } // VirtualTPMTerminate terminates the vTPM instance -func (ctx KvmContext) VirtualTPMTerminate(domainName string) error { +func (ctx KvmContext) VirtualTPMTerminate(domainName string, wp *types.WatchdogParam) error { + if wp == nil { + return fmt.Errorf("invalid watchdog configuration") + } + domainUUID, _, _, err := types.DomainnameToUUID(domainName) if err != nil { - return fmt.Errorf("failed to extract UUID from domain name (vTPM terminate): %v", err) + return fmt.Errorf("failed to extract UUID from domain name: %v", err) } - if err := 
requestVtpmTermination(domainUUID); err != nil { + if err := requestvTPMTermination(domainUUID, wp); err != nil { return fmt.Errorf("failed to terminate vTPM for domain %s: %w", domainName, err) } return nil } // VirtualTPMTeardown purges the vTPM instance. -func (ctx KvmContext) VirtualTPMTeardown(domainName string) error { +func (ctx KvmContext) VirtualTPMTeardown(domainName string, wp *types.WatchdogParam) error { + if wp == nil { + return fmt.Errorf("invalid watchdog configuration") + } + domainUUID, _, _, err := types.DomainnameToUUID(domainName) if err != nil { - return fmt.Errorf("failed to extract UUID from domain name (vTPM teardown): %v", err) + return fmt.Errorf("failed to extract UUID from domain name: %v", err) } - if err := requestVtpmPurge(domainUUID); err != nil { + if err := requestvTPMPurge(domainUUID, wp); err != nil { return fmt.Errorf("failed to purge vTPM for domain %s: %w", domainName, err) } return nil } -func requestVtpmLaunch(id uuid.UUID, wk *utils.WatchdogKick, timeoutSeconds uint) error { - conn, err := net.Dial("unix", types.VtpmdCtrlSocket) - if err != nil { - return fmt.Errorf("failed to connect to vTPM control socket: %w", err) +// This is the Unix Domain Socket (UDS) transport for vTPM requests. +func vtpmClientUDSTransport() *http.Transport { + return &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", types.VtpmdCtrlSocket) + }, } - defer conn.Close() - - pidPath := fmt.Sprintf(types.SwtpmPidPath, id.String()) +} - // Send the request to the vTPM control socket, ask it to launch a swtpm instance. - _, err = conn.Write([]byte(fmt.Sprintf("%s%s\n", vtpmLaunchPrefix, id.String()))) +func makeRequestAsync(client *http.Client, endpoint, id string, rChan chan<- vtpmRequestResult) { + url := fmt.Sprintf("http://unix/%s?id=%s", endpoint, id) + req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { - return fmt.Errorf("failed to write to vTPM control socket: %w", err) + rChan <- vtpmRequestResult{Error: fmt.Errorf("error when creating request to %s endpoint: %v", url, err)} + return } - - // Loop and wait for SWTPM to start. - pid, err := utils.GetPidFromFileTimeout(pidPath, timeoutSeconds, wk) + resp, err := client.Do(req) if err != nil { - return fmt.Errorf("failed to get pid from file %s: %w", pidPath, err) + rChan <- vtpmRequestResult{Error: fmt.Errorf("error when sending request to %s endpoint: %v", url, err)} + return } + defer resp.Body.Close() - // One last time, check SWTPM is not dead right after launch. 
- if !utils.IsProcAlive(pid) { - return fmt.Errorf("SWTPM (pid: %d) is dead", pid) + body, err := io.ReadAll(resp.Body) + if err != nil { + rChan <- vtpmRequestResult{Error: fmt.Errorf("error when reading response body from %s endpoint: %v", url, err)} + return + } + if resp.StatusCode != http.StatusOK { + rChan <- vtpmRequestResult{Error: fmt.Errorf("received status code %d from %s endpoint", resp.StatusCode, url), Body: string(body)} + return } - return nil + rChan <- vtpmRequestResult{Body: string(body)} +} + +func makeRequest(client *http.Client, wk *utils.WatchdogKicker, endpoint, id string) (body string, err error) { + rChan := make(chan vtpmRequestResult) + go makeRequestAsync(client, endpoint, id, rChan) + + startTime := time.Now() + for { + select { + case res := <-rChan: + return res.Body, res.Error + default: + utils.KickWatchdog(wk) + if time.Since(startTime).Seconds() >= float64(client.Timeout.Seconds()) { + return "", fmt.Errorf("timeout") + } + time.Sleep(100 * time.Millisecond) + } + } } -func requestVtpmPurge(id uuid.UUID) error { - conn, err := net.Dial("unix", types.VtpmdCtrlSocket) +func requestvTPMLaunch(id uuid.UUID, wp *types.WatchdogParam, timeoutSeconds uint) error { + wk := utils.NewWatchdogKicker(wp.Ps, wp.AgentName, wp.WarnTime, wp.ErrTime) + body, err := makeRequest(vTPMClient, wk, vtpmLaunchEndpoint, id.String()) if err != nil { - return fmt.Errorf("failed to connect to vTPM control socket: %w", err) + return fmt.Errorf("failed to launch vTPM instance: %w (%s)", err, body) } - defer conn.Close() - // Send a request to vTPM control socket, ask it to purge the instance - // and all its data. - _, err = conn.Write([]byte(fmt.Sprintf("%s%s\n", vtpmPurgePrefix, id.String()))) + // Wait for SWTPM to start. + pidPath := fmt.Sprintf(types.SwtpmPidPath, id.String()) + _, err = utils.GetPidFromFileTimeout(pidPath, timeoutSeconds, wk) if err != nil { - return fmt.Errorf("failed to write to vTPM control socket: %w", err) + return fmt.Errorf("failed to get pid from file %s: %w", pidPath, err) } return nil } -func requestVtpmTermination(id uuid.UUID) error { - conn, err := net.Dial("unix", types.VtpmdCtrlSocket) +func requestvTPMPurge(id uuid.UUID, wp *types.WatchdogParam) error { + // Send a request to vTPM control socket, ask it to purge the instance + // and all its data. + wk := utils.NewWatchdogKicker(wp.Ps, wp.AgentName, wp.WarnTime, wp.ErrTime) + body, err := makeRequest(vTPMClient, wk, vtpmPurgeEndpoint, id.String()) if err != nil { - return fmt.Errorf("failed to connect to vTPM control socket: %w", err) + return fmt.Errorf("failed to purge vTPM instance: %w (%s)", err, body) } - defer conn.Close() - // Send a request to the vTPM control socket, ask it to delete the instance. - _, err = conn.Write([]byte(fmt.Sprintf("%s%s\n", vtpmDeletePrefix, id.String()))) + return nil +} + +func requestvTPMTermination(id uuid.UUID, wp *types.WatchdogParam) error { + // Send a request to vTPM control socket, ask it to terminate the instance. 
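+	// Unlike purge, terminate only stops the SWTPM process; the on-disk
+	// state is kept so the instance can be relaunched later.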
+ wk := utils.NewWatchdogKicker(wp.Ps, wp.AgentName, wp.WarnTime, wp.ErrTime) + body, err := makeRequest(vTPMClient, wk, vtpmTermEndpoint, id.String()) if err != nil { - return fmt.Errorf("failed to write to vTPM control socket: %w", err) + return fmt.Errorf("failed to terminate vTPM instance: %w (%s)", err, body) } return nil diff --git a/pkg/pillar/hypervisor/null.go b/pkg/pillar/hypervisor/null.go index 899c35ed27..f1daf5d5d5 100644 --- a/pkg/pillar/hypervisor/null.go +++ b/pkg/pillar/hypervisor/null.go @@ -6,9 +6,7 @@ package hypervisor import ( "fmt" "os" - "time" - "github.com/lf-edge/eve/pkg/pillar/pubsub" "github.com/lf-edge/eve/pkg/pillar/types" uuid "github.com/satori/go.uuid" @@ -174,14 +172,14 @@ func (ctx nullContext) GetDomsCPUMem() (map[string]types.DomainMetric, error) { return nil, nil } -func (ctx nullContext) VirtualTPMSetup(domainName, agentName string, ps *pubsub.PubSub, warnTime, errTime time.Duration) error { +func (ctx nullContext) VirtualTPMSetup(domainName string, wp *types.WatchdogParam) error { return fmt.Errorf("not implemented") } -func (ctx nullContext) VirtualTPMTerminate(domainName string) error { +func (ctx nullContext) VirtualTPMTerminate(domainName string, wp *types.WatchdogParam) error { return fmt.Errorf("not implemented") } -func (ctx nullContext) VirtualTPMTeardown(domainName string) error { +func (ctx nullContext) VirtualTPMTeardown(domainName string, wp *types.WatchdogParam) error { return fmt.Errorf("not implemented") } diff --git a/pkg/pillar/hypervisor/xen.go b/pkg/pillar/hypervisor/xen.go index 729f0d7432..323fa2233a 100644 --- a/pkg/pillar/hypervisor/xen.go +++ b/pkg/pillar/hypervisor/xen.go @@ -11,9 +11,7 @@ import ( "runtime" "strconv" "strings" - "time" - "github.com/lf-edge/eve/pkg/pillar/pubsub" "github.com/lf-edge/eve/pkg/pillar/types" "github.com/shirou/gopsutil/cpu" "github.com/shirou/gopsutil/mem" @@ -877,14 +875,14 @@ func fallbackDomainMetric() map[string]types.DomainMetric { return dmList } -func (ctx xenContext) VirtualTPMSetup(domainName, agentName string, ps *pubsub.PubSub, warnTime, errTime time.Duration) error { +func (ctx xenContext) VirtualTPMSetup(domainName string, wp *types.WatchdogParam) error { return fmt.Errorf("not implemented") } -func (ctx xenContext) VirtualTPMTerminate(domainName string) error { +func (ctx xenContext) VirtualTPMTerminate(domainName string, wp *types.WatchdogParam) error { return fmt.Errorf("not implemented") } -func (ctx xenContext) VirtualTPMTeardown(domainName string) error { +func (ctx xenContext) VirtualTPMTeardown(domainName string, wp *types.WatchdogParam) error { return fmt.Errorf("not implemented") } diff --git a/pkg/pillar/types/domainmgrtypes.go b/pkg/pillar/types/domainmgrtypes.go index 3e7bd762cc..2ea433d1f7 100644 --- a/pkg/pillar/types/domainmgrtypes.go +++ b/pkg/pillar/types/domainmgrtypes.go @@ -277,9 +277,9 @@ const ( type Task interface { Setup(DomainStatus, DomainConfig, *AssignableAdapters, *ConfigItemValueMap, *os.File) error - VirtualTPMSetup(domainName, agentName string, ps *pubsub.PubSub, warnTime, errTime time.Duration) error - VirtualTPMTerminate(domainName string) error - VirtualTPMTeardown(domainName string) error + VirtualTPMSetup(domainName string, wp *WatchdogParam) error + VirtualTPMTerminate(domainName string, wp *WatchdogParam) error + VirtualTPMTeardown(domainName string, wp *WatchdogParam) error Create(string, string, *DomainConfig) (int, error) Start(string) error Stop(string, bool) error @@ -580,3 +580,12 @@ type Capabilities struct { CPUPinning bool // CPU 
Pinning support
 	UseVHost   bool // vHost support
 }
+
+// WatchdogParam is used in some proc functions that have a timeout,
+// to tell the watchdog that the agent is still alive.
+type WatchdogParam struct {
+	Ps        *pubsub.PubSub
+	AgentName string
+	WarnTime  time.Duration
+	ErrTime   time.Duration
+}
diff --git a/pkg/pillar/utils/proc.go b/pkg/pillar/utils/proc.go
index aa436e7528..03d4ace6f7 100644
--- a/pkg/pillar/utils/proc.go
+++ b/pkg/pillar/utils/proc.go
@@ -15,18 +15,18 @@ import (
 	"github.com/lf-edge/eve/pkg/pillar/pubsub"
 )
 
-// WatchdogKick is used in some proc functions that have a timeout,
+// WatchdogKicker is used in some proc functions that have a timeout,
 // to tell the watchdog agent is still alive.
-type WatchdogKick struct {
+type WatchdogKicker struct {
 	ps        *pubsub.PubSub
 	agentName string
 	warnTime  time.Duration
 	errTime   time.Duration
 }
 
-// NewWatchdogKick creates a new WatchdogKick.
-func NewWatchdogKick(ps *pubsub.PubSub, agentName string, warnTime time.Duration, errTime time.Duration) *WatchdogKick {
-	return &WatchdogKick{
+// NewWatchdogKicker creates a new WatchdogKicker.
+func NewWatchdogKicker(ps *pubsub.PubSub, agentName string, warnTime time.Duration, errTime time.Duration) *WatchdogKicker {
+	return &WatchdogKicker{
 		ps:        ps,
 		agentName: agentName,
 		warnTime:  warnTime,
@@ -34,6 +34,11 @@ func NewWatchdogKick(ps *pubsub.PubSub, agentName string, warnTime time.Duration
 	}
 }
 
+// KickWatchdog tells the watchdog that the agent is still alive.
+func KickWatchdog(wk *WatchdogKicker) {
+	wk.ps.StillRunning(wk.agentName, wk.warnTime, wk.errTime)
+}
+
 // PkillArgs does a pkill
 func PkillArgs(log *base.LogObject, match string, printOnError bool, kill bool) {
 	cmd := "pkill"
@@ -85,7 +90,7 @@ func GetPidFromFile(pidFile string) (int, error) {
 }
 
 // GetPidFromFileTimeout reads a pid from a file with a timeout.
-func GetPidFromFileTimeout(pidFile string, timeoutSeconds uint, wk *WatchdogKick) (int, error) {
+func GetPidFromFileTimeout(pidFile string, timeoutSeconds uint, wk *WatchdogKicker) (int, error) {
 	startTime := time.Now()
 	for {
 		if time.Since(startTime).Seconds() >= float64(timeoutSeconds) {
@@ -118,7 +123,7 @@ func IsProcAlive(pid int) bool {
 }
 
 // IsProcAliveTimeout checks if a process is alive for a given timeout.
-func IsProcAliveTimeout(pid int, timeoutSeconds uint, wk *WatchdogKick) bool { +func IsProcAliveTimeout(pid int, timeoutSeconds uint, wk *WatchdogKicker) bool { startTime := time.Now() for { if time.Since(startTime).Seconds() >= float64(timeoutSeconds) { diff --git a/pkg/vtpm/swtpm-vtpm/go.mod b/pkg/vtpm/swtpm-vtpm/go.mod index d8a0396f21..8c03d78048 100644 --- a/pkg/vtpm/swtpm-vtpm/go.mod +++ b/pkg/vtpm/swtpm-vtpm/go.mod @@ -4,6 +4,7 @@ go 1.22 require ( github.com/lf-edge/eve/pkg/pillar v0.0.0-20240820084217-3f317a9e801f + github.com/satori/go.uuid v1.2.1-0.20180404165556-75cca531ea76 github.com/sirupsen/logrus v1.9.3 ) @@ -20,7 +21,6 @@ require ( github.com/leodido/go-urn v1.2.4 // indirect github.com/lf-edge/eve-api/go v0.0.0-20240722173316-ed56da45126b // indirect github.com/lf-edge/eve/pkg/kube/cnirpc v0.0.0-20240315102754-0f6d1f182e0d // indirect - github.com/satori/go.uuid v1.2.1-0.20180404165556-75cca531ea76 // indirect github.com/vishvananda/netlink v1.2.1-beta.2 // indirect github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f // indirect golang.org/x/crypto v0.21.0 // indirect diff --git a/pkg/vtpm/swtpm-vtpm/src/main.go b/pkg/vtpm/swtpm-vtpm/src/main.go index 7eec749b07..8f47ee32b6 100644 --- a/pkg/vtpm/swtpm-vtpm/src/main.go +++ b/pkg/vtpm/swtpm-vtpm/src/main.go @@ -7,14 +7,14 @@ package main import ( - "bufio" "fmt" "net" + "net/http" "os" "os/exec" - "os/signal" "strconv" "strings" + "sync" "syscall" "time" @@ -22,25 +22,23 @@ import ( etpm "github.com/lf-edge/eve/pkg/pillar/evetpm" "github.com/lf-edge/eve/pkg/pillar/types" utils "github.com/lf-edge/eve/pkg/pillar/utils/file" + uuid "github.com/satori/go.uuid" "github.com/sirupsen/logrus" ) const ( - swtpmPath = "/usr/bin/swtpm" - purgeReq = "purge" - terminateReq = "terminate" - launchReq = "launch" - maxInstances = 10 - maxIDLen = 128 - maxWaitTime = 3 //seconds + swtpmPath = "/usr/bin/swtpm" + maxInstances = 10 + maxPidWaitTime = 5 //seconds ) var ( liveInstances int + m sync.Mutex log *base.LogObject - pids = make(map[string]int, 0) + pids = make(map[uuid.UUID]int, 0) + // XXX : move the paths to types so we have everything EVE creates in one place. // These are defined as vars to be able to mock them in tests - // XXX : move this to types so we have everything EVE creates in one place. 
stateEncryptionKey = "/run/swtpm/%s.binkey" swtpmIsEncryptedPath = "/persist/swtpm/%s.encrypted" swtpmStatePath = "/persist/swtpm/tpm-state-%s" @@ -56,35 +54,19 @@ var ( } ) -func parseRequest(id string) (string, string, error) { - id = strings.TrimSpace(id) - if id == "" || len(id) > maxIDLen { - return "", "", fmt.Errorf("invalid SWTPM ID received") +func isAlive(pid int) bool { + if err := syscall.Kill(pid, 0); err != nil { + return false } - - // breake the string and get the request type - split := strings.Split(id, ";") - if len(split) != 2 { - return "", "", fmt.Errorf("invalid SWTPM ID received (no request)") - } - - if split[1] == "" { - return "", "", fmt.Errorf("invalid SWTPM ID received (no id)") - } - - // request, id - return split[0], split[1], nil + return true } -func cleanupFiles(id string) { - statePath := fmt.Sprintf(swtpmStatePath, id) - ctrlSockPath := fmt.Sprintf(swtpmCtrlSockPath, id) - pidPath := fmt.Sprintf(swtpmPidPath, id) - isEncryptedPath := fmt.Sprintf(swtpmIsEncryptedPath, id) - os.RemoveAll(statePath) - os.Remove(ctrlSockPath) - os.Remove(pidPath) - os.Remove(isEncryptedPath) +func cleanupFiles(uuid uuid.UUID) { + id := uuid.String() + os.RemoveAll(fmt.Sprintf(swtpmStatePath, id)) + os.Remove(fmt.Sprintf(swtpmCtrlSockPath, id)) + os.Remove(fmt.Sprintf(swtpmPidPath, id)) + os.Remove(fmt.Sprintf(swtpmIsEncryptedPath, id)) } func makeDirs(dir string) error { @@ -102,26 +84,36 @@ func makeDirs(dir string) error { return nil } +func readPidFile(pidPath string) (int, error) { + pid := 0 + pidStr, err := os.ReadFile(pidPath) + if err == nil { + pid, err = strconv.Atoi(strings.TrimSpace(string(pidStr))) + if err == nil { + return pid, nil + } + } + + return 0, fmt.Errorf("failed to read pid file: %w", err) +} + func getSwtpmPid(pidPath string, timeoutSeconds uint) (int, error) { startTime := time.Now() for { if time.Since(startTime).Seconds() >= float64(timeoutSeconds) { - return 0, fmt.Errorf("timeout reached") + return 0, fmt.Errorf("timeout reached after %d seconds", int(time.Since(startTime).Seconds())) } - pidStr, err := os.ReadFile(pidPath) - if err == nil { - pid, err := strconv.Atoi(strings.TrimSpace(string(pidStr))) - if err == nil { - return pid, nil - } + if pid, err := readPidFile(pidPath); err == nil { + return pid, nil } time.Sleep(500 * time.Millisecond) } } -func runSwtpm(id string) (int, error) { +func runSwtpm(uuid uuid.UUID) (int, error) { + id := uuid.String() statePath := fmt.Sprintf(swtpmStatePath, id) ctrlSockPath := fmt.Sprintf(swtpmCtrlSockPath, id) binKeyPath := fmt.Sprintf(stateEncryptionKey, id) @@ -144,7 +136,7 @@ func runSwtpm(id string) (int, error) { // if SWTPM state for app marked as as encrypted, and TPM is not available // anymore, fail because this will corrupt the SWTPM state. if utils.FileExists(log, isEncryptedPath) { - return 0, fmt.Errorf("state encryption was enabled for app, but TPM is no longer available") + return 0, fmt.Errorf("state encryption was enabled for SWTPM, but TPM is no longer available") } cmd := exec.Command(swtpmPath, swtpmArgs...) 
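Note: readPidFile, getSwtpmPid and isAlive above are the liveness primitives
the HTTP handlers below build on. A minimal usage sketch (identifiers as
defined in this file; the 2-second timeout is illustrative only):

	pidPath := fmt.Sprintf(swtpmPidPath, id.String())
	if pid, err := getSwtpmPid(pidPath, 2); err == nil && isAlive(pid) {
		// an SWTPM instance for this UUID is already running
	}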
@@ -159,7 +151,7 @@ func runSwtpm(id string) (int, error) { return 0, fmt.Errorf("failed to get SWTPM state encryption key : %w", err) } - // we are about to write the key to the disk, so mark the app SWTPM state + // we are about to write the key to the disk, so mark the SWTPM state // as encrypted if !utils.FileExists(log, isEncryptedPath) { if err := utils.WriteRename(isEncryptedPath, []byte("Y")); err != nil { @@ -183,142 +175,150 @@ func runSwtpm(id string) (int, error) { } } - pid, err := getSwtpmPid(pidPath, maxWaitTime) + pid, err := getSwtpmPid(pidPath, maxPidWaitTime) if err != nil { return 0, fmt.Errorf("failed to get SWTPM pid: %w", err) } - // Add to the list. - pids[id] = pid return pid, nil } -func main() { - log = base.NewSourceLogObject(logrus.StandardLogger(), "vtpm", os.Getpid()) - if log == nil { - fmt.Println("Failed to create log object") - os.Exit(1) - } - - serviceLoop() -} - -func serviceLoop() { - uds, err := net.Listen("unix", vtpmdCtrlSockPath) - if err != nil { - log.Errorf("failed to create vtpm control socket: %v", err) +// Domain manager is requesting to launch a new VM/App, run a new SWTPM +// instance with the given id. +func handleLaunch(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + err := fmt.Sprintf("Method %s not allowed", r.Method) + http.Error(w, err, http.StatusMethodNotAllowed) return } - defer uds.Close() - - // Make sure we remove the socket file on exit. - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) - go func() { - <-sigs - os.Remove(vtpmdCtrlSockPath) - os.Exit(0) - }() - - for { - // xxx : later get peer creds (getpeereid) and check if the caller is - // the domain manager to avoid any other process from sending requests. - conn, err := uds.Accept() - if err != nil { - log.Errorf("failed to accept connection over vtpmd control socket: %v", err) - continue - } - - id, err := bufio.NewReader(conn).ReadString('\n') - if err != nil { - log.Errorf("failed to read SWTPM ID from connection: %v", err) - continue - } - - // Close the connection as soon as we read the ID, - // handle one request at a time. - conn.Close() - - // Don't launch go routines, instead serve requests one by one to avoid - // using any locks, handle* functions access pids global variable. - err = handleRequest(id) - if err != nil { - log.Errorf("failed to handle request: %v", err) - } - } -} -func handleRequest(id string) error { - request, id, err := parseRequest(id) - if err != nil { - return fmt.Errorf("failed to parse request: %w", err) - } + // pids and liveInstances is shared, take care of them. + m.Lock() + defer m.Unlock() - switch request { - case launchReq: - // Domain manager is requesting to launch a new VM/App, run a new SWTPM - // instance with the given id. - return handleLaunch(id) - case purgeReq: - // VM/App is being deleted, domain manager is sending a purge request, - // delete the SWTPM instance and clean up all the files. - return handlePurge(id) - case terminateReq: - // Domain manager is sending a terminate request because it hit an error while - // starting the app (i.e. qemu crashed), so just remove kill SWTPM instance, - // remove it's pid for the list and decrease the liveInstances count. 
-		return handleTerminate(id)
-	default:
-		return fmt.Errorf("invalid request received")
+	if liveInstances >= maxInstances {
+		err := fmt.Sprintf("vTPM max number of SWTPM instances reached %d", liveInstances)
+		http.Error(w, err, http.StatusTooManyRequests)
+		return
 	}
-}
 
-func handleLaunch(id string) error {
-	if liveInstances >= maxInstances {
-		return fmt.Errorf("max number of vTPM instances reached %d", liveInstances)
+	reqID := r.URL.Query().Get("id")
+	id := uuid.FromStringOrNil(reqID)
+	if id == uuid.Nil {
+		err := fmt.Sprintf("vTPM launch request failed, id \"%s\" is invalid", reqID)
+		http.Error(w, err, http.StatusBadRequest)
+		return
 	}
 
-	// If we have SWTPM instance with the same id running, it means either the
-	// domain got rebooted or something went wrong on the dommain manager side!!
-	// it the later case it should have sent a delete request if VM crashed or
-	// there was any other VM related errors. Anyway, refuse to launch a new
+	// If we have a record of an SWTPM instance with the requested id,
+	// check whether it is still alive. If it is, refuse to launch a new
 	// instance with the same id as this might corrupt the state.
 	if _, ok := pids[id]; ok {
-		return fmt.Errorf("SWTPM instance with id %s already running", id)
+		pidPath := fmt.Sprintf(swtpmPidPath, id)
+		// If the pid file is gone, the SWTPM instance terminated gracefully
+		// and we can start a new one; if it still exists, check the process.
+		if _, err := os.Stat(pidPath); err == nil {
+			pid, err := getSwtpmPid(pidPath, maxPidWaitTime)
+			if err != nil {
+				err := fmt.Sprintf("vTPM failed to read pid file of SWTPM with id %s", id)
+				http.Error(w, err, http.StatusExpectationFailed)
+				return
+			}
+
+			// If the SWTPM instance is still alive, there is nothing to do.
+			if isAlive(pid) {
+				log.Noticef("vTPM SWTPM instance with id %s is already running with pid %d", id, pid)
+				w.WriteHeader(http.StatusOK)
+				return
+			}
+		} else if !os.IsNotExist(err) {
+			err := fmt.Sprintf("vTPM failed to check pid file of SWTPM with id %s: %v", id, err)
+			http.Error(w, err, http.StatusFailedDependency)
+			return
+		}
+
+		liveInstances--
+		delete(pids, id)
 	}
 
 	pid, err := runSwtpm(id)
 	if err != nil {
-		return fmt.Errorf("failed to start SWTPM instance: %v", err)
+		err := fmt.Sprintf("vTPM failed to start SWTPM instance: %v", err)
+		http.Error(w, err, http.StatusFailedDependency)
+		return
 	}
 
-	log.Noticef("SWTPM instance with id %s is running with pid %d", id, pid)
+	log.Noticef("vTPM launched SWTPM instance with id: %s, pid: %d", id, pid)
+
+	// Record the new instance and send a success response.
+	pids[id] = pid
 	liveInstances++
-	return nil
+	w.WriteHeader(http.StatusOK)
 }
 
-func handleTerminate(id string) error {
-	// We expect the SWTPM to be terminated at this point, but just in case send
-	// a term signal.
+// Domain manager is sending a terminate request because it hit an error while
+// starting the app, so just kill the SWTPM instance, remove its pid from the
+// list and decrease the liveInstances count.
+func handleTerminate(w http.ResponseWriter, r *http.Request) {
+	if r.Method != http.MethodGet {
+		err := fmt.Sprintf("Method %s not allowed", r.Method)
+		http.Error(w, err, http.StatusMethodNotAllowed)
+		return
+	}
+
+	// pids and liveInstances are shared, take care of them.
+	m.Lock()
+	defer m.Unlock()
+
+	reqID := r.URL.Query().Get("id")
+	id := uuid.FromStringOrNil(reqID)
+	if id == uuid.Nil {
+		err := fmt.Sprintf("vTPM terminate request failed, id \"%s\" is invalid", reqID)
+		http.Error(w, err, http.StatusBadRequest)
+		return
+	}
 	pid, ok := pids[id]
 	if ok {
 		if err := syscall.Kill(pid, syscall.SIGTERM); err != nil {
 			if err != syscall.ESRCH {
 				// This should not happen, but log it just in case.
-				log.Errorf("failed to kill SWTPM instance (terminate request): %v", err)
+				log.Errorf("vTPM failed to kill SWTPM instance (terminate request): %v", err)
 			}
 		}
 		delete(pids, id)
 		liveInstances--
 	} else {
-		return fmt.Errorf("terminate request failed, SWTPM instance with id %s not found", id)
+		err := fmt.Sprintf("vTPM terminate request failed, SWTPM instance with id %s not found", id)
+		http.Error(w, err, http.StatusNotFound)
+		return
 	}
 
-	return nil
+	log.Noticef("vTPM terminated SWTPM instance with id: %s, pid: %d", id, pid)
+
+	// Send a success response.
+	w.WriteHeader(http.StatusOK)
 }
 
-func handlePurge(id string) error {
-	log.Noticef("Purging SWTPM instance with id: %s", id)
+// VM/App is being deleted, domain manager is sending a purge request,
+// delete the SWTPM instance and clean up all the files.
+func handlePurge(w http.ResponseWriter, r *http.Request) {
+	if r.Method != http.MethodGet {
+		err := fmt.Sprintf("Method %s not allowed", r.Method)
+		http.Error(w, err, http.StatusMethodNotAllowed)
+		return
+	}
+
+	// pids and liveInstances are shared, take care of them.
+	m.Lock()
+	defer m.Unlock()
+
+	reqID := r.URL.Query().Get("id")
+	id := uuid.FromStringOrNil(reqID)
+	if id == uuid.Nil {
+		err := fmt.Sprintf("vTPM purge request failed, id \"%s\" is invalid", reqID)
+		http.Error(w, err, http.StatusBadRequest)
+		return
+	}
+
 	// we actually expect the SWTPM to be terminated at this point, because qemu
 	// either sends CMD_SHUTDOWN through the control socket or in case of qemu
 	// crashing, SWTPM terminates itself when the control socket is closed since
@@ -331,12 +331,44 @@
 				log.Errorf("failed to kill SWTPM instance (purge request): %v", err)
 			}
 		}
-
 		delete(pids, id)
 		liveInstances--
 	}
 
-	// clean up files if exists
+	log.Noticef("vTPM purged SWTPM instance with id: %s, pid: %d", id, pid)
+
+	// clean up the files and send a success response.
 	cleanupFiles(id)
-	return nil
+	w.WriteHeader(http.StatusOK)
+}
+
+func startServing() {
+	if _, err := os.Stat(vtpmdCtrlSockPath); err == nil {
+		os.Remove(vtpmdCtrlSockPath)
+	}
+	listener, err := net.Listen("unix", vtpmdCtrlSockPath)
+	if err != nil {
+		log.Fatalf("Error creating Unix socket: %v", err)
+	}
+	defer listener.Close()
+
+	os.Chmod(vtpmdCtrlSockPath, 0600)
+	mux := http.NewServeMux()
+	mux.HandleFunc("/launch", handleLaunch)
+	mux.HandleFunc("/terminate", handleTerminate)
+	mux.HandleFunc("/purge", handlePurge)
+
+	log.Noticef("vTPM server is listening on Unix socket: %s", vtpmdCtrlSockPath)
+	http.Serve(listener, mux)
+}
+
+func main() {
+	log = base.NewSourceLogObject(logrus.StandardLogger(), "vtpm", os.Getpid())
+	if log == nil {
+		fmt.Println("Failed to create log object")
+		os.Exit(1)
+	}
+
+	// Serve requests; ideally this never returns.
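+	// The endpoints can also be exercised manually with curl over the Unix
+	// socket (illustrative invocation; substitute the actual socket path):
+	//   curl --unix-socket /run/vtpmd.ctrl.sock 'http://localhost/launch?id=<uuid>'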
+ startServing() } diff --git a/pkg/vtpm/swtpm-vtpm/src/vtpm_test.go b/pkg/vtpm/swtpm-vtpm/src/vtpm_test.go index a82ecfd314..a875e876dd 100644 --- a/pkg/vtpm/swtpm-vtpm/src/vtpm_test.go +++ b/pkg/vtpm/swtpm-vtpm/src/vtpm_test.go @@ -4,19 +4,26 @@ package main import ( + "context" "crypto/rand" "fmt" + "io" "net" + "net/http" "os" + "syscall" "testing" "time" "github.com/lf-edge/eve/pkg/pillar/base" + uuid "github.com/satori/go.uuid" "github.com/sirupsen/logrus" ) const baseDir = "/tmp/swtpm/test" +var client = &http.Client{} + func TestMain(m *testing.M) { log = base.NewSourceLogObject(logrus.StandardLogger(), "vtpm", os.Getpid()) os.MkdirAll(baseDir, 0755) @@ -28,71 +35,73 @@ func TestMain(m *testing.M) { swtpmPidPath = baseDir + "/%s.pid" vtpmdCtrlSockPath = baseDir + "/vtpmd.ctrl.sock" - go serviceLoop() - defer func() { - _ = os.Remove(vtpmdCtrlSockPath) - }() + client = &http.Client{ + Transport: UnixSocketTransport(vtpmdCtrlSockPath), + Timeout: 5 * time.Second, + } + go startServing() time.Sleep(1 * time.Second) m.Run() } -func sendLaunchRequest(id string) error { - conn, err := net.Dial("unix", vtpmdCtrlSockPath) - if err != nil { - return fmt.Errorf("failed to connect to vTPM control socket: %w", err) - } - defer conn.Close() +func generateUUID() uuid.UUID { + id, _ := uuid.NewV4() + return id +} - _, err = conn.Write([]byte(fmt.Sprintf("%s;%s\n", launchReq, id))) - if err != nil { - return fmt.Errorf("failed to write to vTPM control socket: %w", err) +func UnixSocketTransport(socketPath string) *http.Transport { + return &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", socketPath) + }, } - - return nil } -func sendPurgeRequest(id string) error { - conn, err := net.Dial("unix", vtpmdCtrlSockPath) +func makeRequest(client *http.Client, endpoint, id string) (string, int, error) { + url := fmt.Sprintf("http://unix/%s?id=%s", endpoint, id) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return "", -1, fmt.Errorf("error when creating request: %v", err) + } + resp, err := client.Do(req) if err != nil { - return fmt.Errorf("failed to connect to vTPM control socket: %w", err) + return "", -1, fmt.Errorf("request failed: %v", err) } - defer conn.Close() + defer resp.Body.Close() - _, err = conn.Write([]byte(fmt.Sprintf("%s;%s\n", purgeReq, id))) + body, err := io.ReadAll(resp.Body) if err != nil { - return fmt.Errorf("failed to write to vTPM control socket: %w", err) + return "", -1, fmt.Errorf("error when reading response body: %v", err) + } + if resp.StatusCode != http.StatusOK { + return string(body), resp.StatusCode, fmt.Errorf("received status code %d \n", resp.StatusCode) } - time.Sleep(1 * time.Second) - return nil + return string(body), resp.StatusCode, nil } -func sendTerminateRequest(id string) error { - conn, err := net.Dial("unix", vtpmdCtrlSockPath) - if err != nil { - return fmt.Errorf("failed to connect to vTPM control socket: %w", err) - } - defer conn.Close() +func sendLaunchRequest(id uuid.UUID) (string, int, error) { + return makeRequest(client, "launch", id.String()) +} - _, err = conn.Write([]byte(fmt.Sprintf("%s;%s\n", terminateReq, id))) - if err != nil { - return fmt.Errorf("failed to write to vTPM control socket: %w", err) - } +func sendPurgeRequest(id uuid.UUID) (string, int, error) { + return makeRequest(client, "purge", id.String()) +} - time.Sleep(1 * time.Second) - return nil +func sendTerminateRequest(id uuid.UUID) (string, int, error) { + return makeRequest(client, 
"terminate", id.String()) } -func testLaunchAndPurge(t *testing.T, id string) { +func testLaunchAndPurge(t *testing.T, id uuid.UUID) { // test logic : // 1. send launch request // 2. check number of live instances, it should be 1 // 3. send purge request // 4. check number of live instances, it should be 0 - err := sendLaunchRequest(id) + b, _, err := sendLaunchRequest(id) if err != nil { - t.Fatalf("failed to handle request: %v", err) + t.Fatalf("failed to send request: %v, body : %s", err, b) } time.Sleep(1 * time.Second) @@ -100,39 +109,45 @@ func testLaunchAndPurge(t *testing.T, id string) { t.Fatalf("expected liveInstances to be 1, got %d", liveInstances) } - err = sendPurgeRequest(id) + b, _, err = sendPurgeRequest(id) if err != nil { - t.Fatalf("failed to handle request: %v", err) + t.Fatalf("failed to send request: %v, body : %s", err, b) } + time.Sleep(1 * time.Second) if liveInstances != 0 { t.Fatalf("expected liveInstances to be 0, got %d", liveInstances) } } -func testExhaustSwtpmInstances(t *testing.T, id string) { +func testExhaustSwtpmInstances(t *testing.T) { + ids := make([]uuid.UUID, 0) for i := 0; i < maxInstances; i++ { - err := sendLaunchRequest(fmt.Sprintf("%s-%d", id, i)) + id := generateUUID() + b, _, err := sendLaunchRequest(id) if err != nil { - t.Errorf("failed to send request: %v", err) + t.Fatalf("failed to send request: %v, body : %s", err, b) } + defer cleanupFiles(id) + ids = append(ids, id) } time.Sleep(5 * time.Second) // this should have no effect as we have reached max instances - err := sendLaunchRequest(id) - if err != nil { - t.Errorf("failed to send request: %v", err) + b, res, err := sendLaunchRequest(generateUUID()) + if res != http.StatusTooManyRequests { + t.Fatalf("expected status code to be %d, got %d, err : %v, body: %s", http.StatusTooManyRequests, res, err, b) } if liveInstances != maxInstances { t.Errorf("expected liveInstances to be %d, got %d", maxInstances, liveInstances) } - err = sendPurgeRequest(fmt.Sprintf("%s-0", id)) + b, _, err = sendPurgeRequest(ids[0]) if err != nil { - t.Errorf("failed to handle request: %v", err) + t.Fatalf("failed to send request: %v, body : %s", err, b) } + time.Sleep(1 * time.Second) if liveInstances != maxInstances-1 { t.Errorf("expected liveInstances to be %d, got %d", maxInstances-1, liveInstances) @@ -140,15 +155,16 @@ func testExhaustSwtpmInstances(t *testing.T, id string) { // clean up for i := 0; i < maxInstances; i++ { - err := sendPurgeRequest(fmt.Sprintf("%s-%d", id, i)) + b, _, err := sendPurgeRequest(ids[i]) if err != nil { - t.Errorf("failed to handle request: %v", err) + t.Fatalf("failed to send request: %v, body : %s", err, b) } + time.Sleep(1 * time.Second) } } func TestLaunchAndPurgeWithoutStateEncryption(t *testing.T) { - id := "test" + id := generateUUID() defer cleanupFiles(id) isTPMAvailable = func() bool { return false @@ -158,7 +174,7 @@ func TestLaunchAndPurgeWithoutStateEncryption(t *testing.T) { } func TestLaunchAndPurgeWithStateEncryption(t *testing.T) { - id := "test" + id := generateUUID() defer cleanupFiles(id) isTPMAvailable = func() bool { return true @@ -173,18 +189,14 @@ func TestLaunchAndPurgeWithStateEncryption(t *testing.T) { } func TestExhaustSwtpmInstancesWithoutStateEncryption(t *testing.T) { - id := "test" - defer cleanupFiles(id) isTPMAvailable = func() bool { return false } - testExhaustSwtpmInstances(t, id) + testExhaustSwtpmInstances(t) } func TestExhaustSwtpmInstancesWithStateEncryption(t *testing.T) { - id := "test" - defer cleanupFiles(id) isTPMAvailable = 
func() bool {
 		return true
 	}
@@ -194,11 +206,11 @@ func TestExhaustSwtpmInstancesWithStateEncryption(t *testing.T) {
 		return key, nil
 	}
 
-	testExhaustSwtpmInstances(t, id)
+	testExhaustSwtpmInstances(t)
 }
 
 func TestSwtpmStateChange(t *testing.T) {
-	id := "test"
+	id := generateUUID()
 	defer cleanupFiles(id)
 	isTPMAvailable = func() bool {
 		return true
@@ -210,9 +222,15 @@
 	}
 
 	// this mark the instance to be encrypted
-	err := sendLaunchRequest(id)
+	b, _, err := sendLaunchRequest(id)
+	if err != nil {
+		t.Fatalf("failed to send request: %v, body : %s", err, b)
+	}
+	time.Sleep(1 * time.Second)
+
+	b, _, err = sendTerminateRequest(id)
 	if err != nil {
-		t.Fatalf("failed to handle request: %v", err)
+		t.Fatalf("failed to send request: %v, body : %s", err, b)
 	}
 	time.Sleep(1 * time.Second)
@@ -221,23 +239,25 @@
 		return false
 	}
 
-	// this should fail since this id was marked as encrypted
-	err = sendLaunchRequest(id)
-	if err != nil {
-		t.Fatalf("failed to handle request: %v", err)
+	// this should fail since the instance was marked as encrypted and the TPM is no longer available
+	b, _, err = sendLaunchRequest(id)
+	if err == nil {
+		t.Fatalf("expected error, got nil")
 	}
+	t.Logf("expected error: %v, body: %s", err, b)
+
 	if liveInstances > 1 {
 		t.Fatalf("expected liveInstances to be 1, got %d", liveInstances)
 	}
 
-	err = sendPurgeRequest(id)
+	b, _, err = sendPurgeRequest(id)
 	if err != nil {
-		t.Fatalf("failed to handle request: %v", err)
+		t.Fatalf("failed to send request: %v, body : %s", err, b)
 	}
 }
 
 func TestDeleteRequest(t *testing.T) {
-	id := "test"
+	id := generateUUID()
 	defer cleanupFiles(id)
 
 	// this doesn't matter
@@ -245,15 +265,104 @@ func TestDeleteRequest(t *testing.T) {
 		return false
 	}
 
-	err := sendLaunchRequest(id)
+	b, _, err := sendLaunchRequest(id)
 	if err != nil {
-		t.Fatalf("failed to handle request: %v", err)
+		t.Fatalf("failed to send request: %v, body : %s", err, b)
 	}
 
-	err = sendTerminateRequest(id)
+	b, _, err = sendTerminateRequest(id)
 	if err != nil {
-		t.Fatalf("failed to handle request: %v", err)
+		t.Fatalf("failed to send request: %v, body : %s", err, b)
 	}
+	time.Sleep(1 * time.Second)
+
+	if liveInstances != 0 {
+		t.Fatalf("expected liveInstances to be 0, got %d", liveInstances)
+	}
+}
+
+func TestSwtpmAbruptTerminationRequest(t *testing.T) {
+	// this test verifies that if swtpm is terminated without vTPM noticing,
+	// no stale id is left in vTPM's internal bookkeeping and vTPM
+	// can launch a new instance with the same id.
+	// test logic :
+	// 1. send launch request
+	// 2. read swtpm pid file and terminate it
+	// 3. 
send launch request again, this should not fail
+	id := generateUUID()
+	defer cleanupFiles(id)
+
+	// this doesn't matter
+	isTPMAvailable = func() bool {
+		return false
+	}
+
+	b, _, err := sendLaunchRequest(id)
+	if err != nil {
+		t.Fatalf("failed to send request: %v, body : %s", err, b)
+	}
+
+	pid, err := readPidFile(fmt.Sprintf(swtpmPidPath, id))
+	if err != nil {
+		t.Fatalf("failed to read pid file: %v", err)
+	}
+	if err := syscall.Kill(pid, syscall.SIGTERM); err != nil {
+		t.Fatalf("failed to kill process: %v", err)
+	}
+
+	// this should not fail
+	b, _, err = sendLaunchRequest(id)
+	if err != nil {
+		t.Fatalf("failed to send request: %v, body : %s", err, b)
+	}
+
+	b, _, err = sendTerminateRequest(id)
+	if err != nil {
+		t.Fatalf("failed to send request: %v, body : %s", err, b)
+	}
+	time.Sleep(1 * time.Second)
+
+	if liveInstances != 0 {
+		t.Fatalf("expected liveInstances to be 0, got %d", liveInstances)
+	}
+}
+
+func TestSwtpmMultipleLaunchRequest(t *testing.T) {
+	// this test verifies that if swtpm is launched multiple times with the same id,
+	// only one instance is created and the other requests are ignored.
+	// test logic :
+	// 1. send launch request multiple times, they all should succeed
+	// 2. clean up
+	id := generateUUID()
+	defer cleanupFiles(id)
+
+	// this doesn't matter
+	isTPMAvailable = func() bool {
+		return false
+	}
+
+	for i := 0; i < 5; i++ {
+		b, _, err := sendLaunchRequest(id)
+		if err != nil {
+			t.Fatalf("failed to send request: %v, body : %s", err, b)
+		}
+	}
+
+	pid, err := readPidFile(fmt.Sprintf(swtpmPidPath, id))
+	if err != nil {
+		t.Fatalf("failed to read pid file: %v", err)
+	}
+	if err := syscall.Kill(pid, syscall.SIGTERM); err != nil {
+		t.Fatalf("failed to kill process: %v", err)
+	}
+
+	b, _, err := sendTerminateRequest(id)
+	if err != nil {
+		t.Fatalf("failed to send request: %v, body : %s", err, b)
+	}
+	time.Sleep(1 * time.Second)
 
 	if liveInstances != 0 {
 		t.Fatalf("expected liveInstances to be 0, got %d", liveInstances)