From eda5939c462e2df7a671d2805ee3d82cf0b117a0 Mon Sep 17 00:00:00 2001 From: Shahriyar Jalayeri Date: Fri, 25 Oct 2024 13:29:50 +0300 Subject: [PATCH 1/9] vTPM : refactor control socket communication and error handling This changes refactors the control socket communication and error handling in the vTPM (server) and KVM (client). The control socket communication is now handled by HTTP over UDS, and the error handling is improved, since the vTPM server now returns an error message when an error occurs. Signed-off-by: Shahriyar Jalayeri (cherry picked from commit d965fa1be5701606d0f0a84b4f3b9086885b4520) --- pkg/pillar/hypervisor/kvm.go | 107 ++++++----- pkg/vtpm/swtpm-vtpm/src/main.go | 262 +++++++++++++-------------- pkg/vtpm/swtpm-vtpm/src/vtpm_test.go | 134 ++++++++------ 3 files changed, 266 insertions(+), 237 deletions(-) diff --git a/pkg/pillar/hypervisor/kvm.go b/pkg/pillar/hypervisor/kvm.go index 8c4483fd15..a06655d158 100644 --- a/pkg/pillar/hypervisor/kvm.go +++ b/pkg/pillar/hypervisor/kvm.go @@ -4,8 +4,11 @@ package hypervisor import ( + "context" "fmt" + "io" "net" + "net/http" "os" "path/filepath" "runtime" @@ -28,16 +31,22 @@ import ( const ( // KVMHypervisorName is a name of kvm hypervisor - KVMHypervisorName = "kvm" - minUringKernelTag = uint64((5 << 16) | (4 << 8) | (72 << 0)) - swtpmTimeout = 10 // seconds - qemuTimeout = 3 // seconds - vtpmPurgePrefix = "purge;" - vtpmDeletePrefix = "terminate;" - vtpmLaunchPrefix = "launch;" + KVMHypervisorName = "kvm" + minUringKernelTag = uint64((5 << 16) | (4 << 8) | (72 << 0)) + swtpmTimeout = 10 // seconds + qemuTimeout = 3 // seconds + vtpmPurgeEndpoint = "purge" + vtpmTermEndpoint = "terminate" + vtpmLaunchEndpoint = "launch" ) -var clientCid = uint32(unix.VMADDR_CID_HOST + 1) +var ( + clientCid = uint32(unix.VMADDR_CID_HOST + 1) + vTPMClient = &http.Client{ + Transport: vtpmClientUDSTransport(), + Timeout: 2 * time.Second, + } +) // We build device model around PCIe topology according to best practices // https://github.com/qemu/qemu/blob/master/docs/pcie.txt @@ -1335,21 +1344,21 @@ func (ctx KvmContext) VirtualTPMSetup(domainName, agentName string, ps *pubsub.P wk := utils.NewWatchdogKick(ps, agentName, warnTime, errTime) domainUUID, _, _, err := types.DomainnameToUUID(domainName) if err != nil { - return fmt.Errorf("failed to extract UUID from domain name (vTPM setup): %v", err) + return fmt.Errorf("failed to extract UUID from domain name: %v", err) } - return requestVtpmLaunch(domainUUID, wk, swtpmTimeout) + return requestvTPMLaunch(domainUUID, wk, swtpmTimeout) } - return fmt.Errorf("invalid watchdog configuration (vTPM setup)") + return fmt.Errorf("invalid watchdog configuration") } // VirtualTPMTerminate terminates the vTPM instance func (ctx KvmContext) VirtualTPMTerminate(domainName string) error { domainUUID, _, _, err := types.DomainnameToUUID(domainName) if err != nil { - return fmt.Errorf("failed to extract UUID from domain name (vTPM terminate): %v", err) + return fmt.Errorf("failed to extract UUID from domain name: %v", err) } - if err := requestVtpmTermination(domainUUID); err != nil { + if err := requestvTPMTermination(domainUUID); err != nil { return fmt.Errorf("failed to terminate vTPM for domain %s: %w", domainName, err) } return nil @@ -1359,31 +1368,55 @@ func (ctx KvmContext) VirtualTPMTerminate(domainName string) error { func (ctx KvmContext) VirtualTPMTeardown(domainName string) error { domainUUID, _, _, err := types.DomainnameToUUID(domainName) if err != nil { - return fmt.Errorf("failed to extract UUID 
from domain name (vTPM teardown): %v", err) + return fmt.Errorf("failed to extract UUID from domain name: %v", err) } - if err := requestVtpmPurge(domainUUID); err != nil { + if err := requestvTPMPurge(domainUUID); err != nil { return fmt.Errorf("failed to purge vTPM for domain %s: %w", domainName, err) } return nil } -func requestVtpmLaunch(id uuid.UUID, wk *utils.WatchdogKick, timeoutSeconds uint) error { - conn, err := net.Dial("unix", types.VtpmdCtrlSocket) +// This is the Unix Domain Socket (UDS) transport for vTPM requests. +func vtpmClientUDSTransport() *http.Transport { + return &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", types.VtpmdCtrlSocket) + }, + } +} + +func makeRequest(client *http.Client, endpoint, id string) (string, error) { + url := fmt.Sprintf("http://unix/%s?id=%s", endpoint, id) + req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { - return fmt.Errorf("failed to connect to vTPM control socket: %w", err) + return "", fmt.Errorf("error when creating request: %v", err) } - defer conn.Close() + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("request failed: %v", err) + } + defer resp.Body.Close() - pidPath := fmt.Sprintf(types.SwtpmPidPath, id.String()) + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("error when reading response body: %v", err) + } + if resp.StatusCode != http.StatusOK { + return string(body), fmt.Errorf("received status code %d", resp.StatusCode) + } + + return string(body), nil +} - // Send the request to the vTPM control socket, ask it to launch a swtpm instance. - _, err = conn.Write([]byte(fmt.Sprintf("%s%s\n", vtpmLaunchPrefix, id.String()))) +func requestvTPMLaunch(id uuid.UUID, wk *utils.WatchdogKick, timeoutSeconds uint) error { + body, err := makeRequest(vTPMClient, vtpmLaunchEndpoint, id.String()) if err != nil { - return fmt.Errorf("failed to write to vTPM control socket: %w", err) + return fmt.Errorf("failed to launch vTPM instance: %w (%s)", err, body) } - // Loop and wait for SWTPM to start. + // Wait for SWTPM to start. + pidPath := fmt.Sprintf(types.SwtpmPidPath, id.String()) pid, err := utils.GetPidFromFileTimeout(pidPath, timeoutSeconds, wk) if err != nil { return fmt.Errorf("failed to get pid from file %s: %w", pidPath, err) @@ -1397,34 +1430,22 @@ func requestVtpmLaunch(id uuid.UUID, wk *utils.WatchdogKick, timeoutSeconds uint return nil } -func requestVtpmPurge(id uuid.UUID) error { - conn, err := net.Dial("unix", types.VtpmdCtrlSocket) - if err != nil { - return fmt.Errorf("failed to connect to vTPM control socket: %w", err) - } - defer conn.Close() - +func requestvTPMPurge(id uuid.UUID) error { // Send a request to vTPM control socket, ask it to purge the instance // and all its data. - _, err = conn.Write([]byte(fmt.Sprintf("%s%s\n", vtpmPurgePrefix, id.String()))) + body, err := makeRequest(vTPMClient, vtpmPurgeEndpoint, id.String()) if err != nil { - return fmt.Errorf("failed to write to vTPM control socket: %w", err) + return fmt.Errorf("failed to purge vTPM instance: %w (%s)", err, body) } return nil } -func requestVtpmTermination(id uuid.UUID) error { - conn, err := net.Dial("unix", types.VtpmdCtrlSocket) - if err != nil { - return fmt.Errorf("failed to connect to vTPM control socket: %w", err) - } - defer conn.Close() - - // Send a request to the vTPM control socket, ask it to delete the instance. 
- _, err = conn.Write([]byte(fmt.Sprintf("%s%s\n", vtpmDeletePrefix, id.String()))) +func requestvTPMTermination(id uuid.UUID) error { + // Send a request to vTPM control socket, ask it to terminate the instance. + body, err := makeRequest(vTPMClient, vtpmTermEndpoint, id.String()) if err != nil { - return fmt.Errorf("failed to write to vTPM control socket: %w", err) + return fmt.Errorf("failed to terminate vTPM instance: %w (%s)", err, body) } return nil diff --git a/pkg/vtpm/swtpm-vtpm/src/main.go b/pkg/vtpm/swtpm-vtpm/src/main.go index 7eec749b07..b3403adda2 100644 --- a/pkg/vtpm/swtpm-vtpm/src/main.go +++ b/pkg/vtpm/swtpm-vtpm/src/main.go @@ -7,14 +7,14 @@ package main import ( - "bufio" "fmt" "net" + "net/http" "os" "os/exec" - "os/signal" "strconv" "strings" + "sync" "syscall" "time" @@ -26,21 +26,18 @@ import ( ) const ( - swtpmPath = "/usr/bin/swtpm" - purgeReq = "purge" - terminateReq = "terminate" - launchReq = "launch" - maxInstances = 10 - maxIDLen = 128 - maxWaitTime = 3 //seconds + swtpmPath = "/usr/bin/swtpm" + maxInstances = 10 + maxPidWaitTime = 3 //seconds ) var ( liveInstances int + m sync.Mutex log *base.LogObject pids = make(map[string]int, 0) + // XXX : move the paths to types so we have everything EVE creates in one place. // These are defined as vars to be able to mock them in tests - // XXX : move this to types so we have everything EVE creates in one place. stateEncryptionKey = "/run/swtpm/%s.binkey" swtpmIsEncryptedPath = "/persist/swtpm/%s.encrypted" swtpmStatePath = "/persist/swtpm/tpm-state-%s" @@ -56,35 +53,11 @@ var ( } ) -func parseRequest(id string) (string, string, error) { - id = strings.TrimSpace(id) - if id == "" || len(id) > maxIDLen { - return "", "", fmt.Errorf("invalid SWTPM ID received") - } - - // breake the string and get the request type - split := strings.Split(id, ";") - if len(split) != 2 { - return "", "", fmt.Errorf("invalid SWTPM ID received (no request)") - } - - if split[1] == "" { - return "", "", fmt.Errorf("invalid SWTPM ID received (no id)") - } - - // request, id - return split[0], split[1], nil -} - func cleanupFiles(id string) { - statePath := fmt.Sprintf(swtpmStatePath, id) - ctrlSockPath := fmt.Sprintf(swtpmCtrlSockPath, id) - pidPath := fmt.Sprintf(swtpmPidPath, id) - isEncryptedPath := fmt.Sprintf(swtpmIsEncryptedPath, id) - os.RemoveAll(statePath) - os.Remove(ctrlSockPath) - os.Remove(pidPath) - os.Remove(isEncryptedPath) + os.RemoveAll(fmt.Sprintf(swtpmStatePath, id)) + os.Remove(fmt.Sprintf(swtpmCtrlSockPath, id)) + os.Remove(fmt.Sprintf(swtpmPidPath, id)) + os.Remove(fmt.Sprintf(swtpmIsEncryptedPath, id)) } func makeDirs(dir string) error { @@ -144,7 +117,7 @@ func runSwtpm(id string) (int, error) { // if SWTPM state for app marked as as encrypted, and TPM is not available // anymore, fail because this will corrupt the SWTPM state. if utils.FileExists(log, isEncryptedPath) { - return 0, fmt.Errorf("state encryption was enabled for app, but TPM is no longer available") + return 0, fmt.Errorf("state encryption was enabled for SWTPM, but TPM is no longer available") } cmd := exec.Command(swtpmPath, swtpmArgs...) 
@@ -159,7 +132,7 @@ func runSwtpm(id string) (int, error) { return 0, fmt.Errorf("failed to get SWTPM state encryption key : %w", err) } - // we are about to write the key to the disk, so mark the app SWTPM state + // we are about to write the key to the disk, so mark the SWTPM state // as encrypted if !utils.FileExists(log, isEncryptedPath) { if err := utils.WriteRename(isEncryptedPath, []byte("Y")); err != nil { @@ -183,121 +156,87 @@ func runSwtpm(id string) (int, error) { } } - pid, err := getSwtpmPid(pidPath, maxWaitTime) + pid, err := getSwtpmPid(pidPath, maxPidWaitTime) if err != nil { return 0, fmt.Errorf("failed to get SWTPM pid: %w", err) } - // Add to the list. + // Add it to the list. pids[id] = pid return pid, nil } -func main() { - log = base.NewSourceLogObject(logrus.StandardLogger(), "vtpm", os.Getpid()) - if log == nil { - fmt.Println("Failed to create log object") - os.Exit(1) - } - - serviceLoop() -} - -func serviceLoop() { - uds, err := net.Listen("unix", vtpmdCtrlSockPath) - if err != nil { - log.Errorf("failed to create vtpm control socket: %v", err) +// Domain manager is requesting to launch a new VM/App, run a new SWTPM +// instance with the given id. +func handleLaunch(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + err := fmt.Sprintf("Method %s not allowed", r.Method) + http.Error(w, err, http.StatusMethodNotAllowed) return } - defer uds.Close() - - // Make sure we remove the socket file on exit. - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) - go func() { - <-sigs - os.Remove(vtpmdCtrlSockPath) - os.Exit(0) - }() - for { - // xxx : later get peer creds (getpeereid) and check if the caller is - // the domain manager to avoid any other process from sending requests. - conn, err := uds.Accept() - if err != nil { - log.Errorf("failed to accept connection over vtpmd control socket: %v", err) - continue - } + // pids and liveInstances is shared, take care of them. + m.Lock() + defer m.Unlock() - id, err := bufio.NewReader(conn).ReadString('\n') - if err != nil { - log.Errorf("failed to read SWTPM ID from connection: %v", err) - continue - } - - // Close the connection as soon as we read the ID, - // handle one request at a time. - conn.Close() - - // Don't launch go routines, instead serve requests one by one to avoid - // using any locks, handle* functions access pids global variable. - err = handleRequest(id) - if err != nil { - log.Errorf("failed to handle request: %v", err) - } - } -} - -func handleRequest(id string) error { - request, id, err := parseRequest(id) - if err != nil { - return fmt.Errorf("failed to parse request: %w", err) - } - - switch request { - case launchReq: - // Domain manager is requesting to launch a new VM/App, run a new SWTPM - // instance with the given id. - return handleLaunch(id) - case purgeReq: - // VM/App is being deleted, domain manager is sending a purge request, - // delete the SWTPM instance and clean up all the files. - return handlePurge(id) - case terminateReq: - // Domain manager is sending a terminate request because it hit an error while - // starting the app (i.e. qemu crashed), so just remove kill SWTPM instance, - // remove it's pid for the list and decrease the liveInstances count. 
- return handleTerminate(id) - default: - return fmt.Errorf("invalid request received") + if liveInstances >= maxInstances { + err := fmt.Sprintf("vTPM max number of SWTPM instances reached %d", liveInstances) + http.Error(w, err, http.StatusTooManyRequests) + return } -} -func handleLaunch(id string) error { - if liveInstances >= maxInstances { - return fmt.Errorf("max number of vTPM instances reached %d", liveInstances) + id := r.URL.Query().Get("id") + if id == "" { + err := "vTPM launch request failed, id is required" + http.Error(w, err, http.StatusBadRequest) + return } // If we have SWTPM instance with the same id running, it means either the - // domain got rebooted or something went wrong on the dommain manager side!! - // it the later case it should have sent a delete request if VM crashed or + // domain got rebooted or something went wrong on the dommain manager side! + // it the later case it should have sent a terminate request if VM crashed or // there was any other VM related errors. Anyway, refuse to launch a new // instance with the same id as this might corrupt the state. if _, ok := pids[id]; ok { - return fmt.Errorf("SWTPM instance with id %s already running", id) + log.Warnf("SWTPM instance with id %s already running", id) + // technically request is satisfied + w.WriteHeader(http.StatusOK) + return } pid, err := runSwtpm(id) if err != nil { - return fmt.Errorf("failed to start SWTPM instance: %v", err) + err := fmt.Sprintf("vTPM failed to start SWTPM instance: %v", err) + http.Error(w, err, http.StatusFailedDependency) + return } - log.Noticef("SWTPM instance with id %s is running with pid %d", id, pid) + log.Noticef("vTPM launched SWTPM instance with id: %s, pid: %d", id, pid) + // Send a success response. liveInstances++ - return nil + w.WriteHeader(http.StatusOK) } -func handleTerminate(id string) error { +// Domain manager is sending a terminate request because it hit an error while +// starting the app (i.e. qemu crashed), so just kill SWTPM instance, +// remove it's pid from the list and decrease the liveInstances count. +func handleTerminate(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + err := fmt.Sprintf("Method %s not allowed", r.Method) + http.Error(w, err, http.StatusMethodNotAllowed) + return + } + + // pids and liveInstances is shared, take care of it. + m.Lock() + defer m.Unlock() + + id := r.URL.Query().Get("id") + if id == "" { + err := "vTPM terminate request failed, id is required" + http.Error(w, err, http.StatusBadRequest) + return + } // We expect the SWTPM to be terminated at this point, but just in case send // a term signal. pid, ok := pids[id] @@ -305,20 +244,43 @@ func handleTerminate(id string) error { if err := syscall.Kill(pid, syscall.SIGTERM); err != nil { if err != syscall.ESRCH { // This should not happen, but log it just in case. - log.Errorf("failed to kill SWTPM instance (terminate request): %v", err) + log.Errorf("vTPM failed to kill SWTPM instance (terminate request): %v", err) } } delete(pids, id) liveInstances-- } else { - return fmt.Errorf("terminate request failed, SWTPM instance with id %s not found", id) + err := fmt.Sprintf("vTPM terminate request failed, SWTPM instance with id %s not found", id) + http.Error(w, err, http.StatusNotFound) + return } - return nil + log.Noticef("vTPM terminated SWTPM instance with id: %s, pid: %d", id, pid) + + // send a success response. 
+ w.WriteHeader(http.StatusOK) } -func handlePurge(id string) error { - log.Noticef("Purging SWTPM instance with id: %s", id) +// VM/App is being deleted, domain manager is sending a purge request, +// delete the SWTPM instance and clean up all the files. +func handlePurge(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + err := fmt.Sprintf("Method %s not allowed", r.Method) + http.Error(w, err, http.StatusMethodNotAllowed) + return + } + + // pids and liveInstances is shared, take care of it. + m.Lock() + defer m.Unlock() + + id := r.URL.Query().Get("id") + if id == "" { + err := "vTPM purge request failed, id is required" + http.Error(w, err, http.StatusBadRequest) + return + } + // we actually expect the SWTPM to be terminated at this point, because qemu // either sends CMD_SHUTDOWN through the control socket or in case of qemu // crashing, SWTPM terminates itself when the control socket is closed since @@ -331,12 +293,44 @@ func handlePurge(id string) error { log.Errorf("failed to kill SWTPM instance (purge request): %v", err) } } - delete(pids, id) liveInstances-- } - // clean up files if exists + log.Noticef("vTPM purged SWTPM instance with id: %s, pid: %d", id, pid) + + // clean up the files and send a success response. cleanupFiles(id) - return nil + w.WriteHeader(http.StatusOK) +} + +func startServing() { + if _, err := os.Stat(vtpmdCtrlSockPath); err == nil { + os.Remove(vtpmdCtrlSockPath) + } + listener, err := net.Listen("unix", vtpmdCtrlSockPath) + if err != nil { + log.Fatalf("Error creating Unix socket: %v", err) + } + defer listener.Close() + + os.Chmod(vtpmdCtrlSockPath, 0600) + mux := http.NewServeMux() + mux.HandleFunc("/launch", handleLaunch) + mux.HandleFunc("/terminate", handleTerminate) + mux.HandleFunc("/purge", handlePurge) + + log.Noticef("vTPM server is listening on Unix socket: %s", vtpmdCtrlSockPath) + http.Serve(listener, mux) +} + +func main() { + log = base.NewSourceLogObject(logrus.StandardLogger(), "vtpm", os.Getpid()) + if log == nil { + fmt.Println("Failed to create log object") + os.Exit(1) + } + + // this never returns, ideally. 
+ startServing() } diff --git a/pkg/vtpm/swtpm-vtpm/src/vtpm_test.go b/pkg/vtpm/swtpm-vtpm/src/vtpm_test.go index a82ecfd314..8695a5b6ab 100644 --- a/pkg/vtpm/swtpm-vtpm/src/vtpm_test.go +++ b/pkg/vtpm/swtpm-vtpm/src/vtpm_test.go @@ -4,9 +4,12 @@ package main import ( + "context" "crypto/rand" "fmt" + "io" "net" + "net/http" "os" "testing" "time" @@ -17,6 +20,8 @@ import ( const baseDir = "/tmp/swtpm/test" +var client = &http.Client{} + func TestMain(m *testing.M) { log = base.NewSourceLogObject(logrus.StandardLogger(), "vtpm", os.Getpid()) os.MkdirAll(baseDir, 0755) @@ -28,60 +33,57 @@ func TestMain(m *testing.M) { swtpmPidPath = baseDir + "/%s.pid" vtpmdCtrlSockPath = baseDir + "/vtpmd.ctrl.sock" - go serviceLoop() - defer func() { - _ = os.Remove(vtpmdCtrlSockPath) - }() + client = &http.Client{ + Transport: UnixSocketTransport(vtpmdCtrlSockPath), + Timeout: 5 * time.Second, + } + go startServing() time.Sleep(1 * time.Second) m.Run() } -func sendLaunchRequest(id string) error { - conn, err := net.Dial("unix", vtpmdCtrlSockPath) - if err != nil { - return fmt.Errorf("failed to connect to vTPM control socket: %w", err) +func UnixSocketTransport(socketPath string) *http.Transport { + return &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", socketPath) + }, } - defer conn.Close() +} - _, err = conn.Write([]byte(fmt.Sprintf("%s;%s\n", launchReq, id))) +func makeRequest(client *http.Client, endpoint, id string) (string, int, error) { + url := fmt.Sprintf("http://unix/%s?id=%s", endpoint, id) + req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { - return fmt.Errorf("failed to write to vTPM control socket: %w", err) + return "", -1, fmt.Errorf("error when creating request: %v", err) } - - return nil -} - -func sendPurgeRequest(id string) error { - conn, err := net.Dial("unix", vtpmdCtrlSockPath) + resp, err := client.Do(req) if err != nil { - return fmt.Errorf("failed to connect to vTPM control socket: %w", err) + return "", -1, fmt.Errorf("request failed: %v", err) } - defer conn.Close() + defer resp.Body.Close() - _, err = conn.Write([]byte(fmt.Sprintf("%s;%s\n", purgeReq, id))) + body, err := io.ReadAll(resp.Body) if err != nil { - return fmt.Errorf("failed to write to vTPM control socket: %w", err) + return "", -1, fmt.Errorf("error when reading response body: %v", err) + } + if resp.StatusCode != http.StatusOK { + return string(body), resp.StatusCode, fmt.Errorf("received status code %d \n", resp.StatusCode) } - time.Sleep(1 * time.Second) - return nil + return string(body), resp.StatusCode, nil } -func sendTerminateRequest(id string) error { - conn, err := net.Dial("unix", vtpmdCtrlSockPath) - if err != nil { - return fmt.Errorf("failed to connect to vTPM control socket: %w", err) - } - defer conn.Close() +func sendLaunchRequest(id string) (string, int, error) { + return makeRequest(client, "launch", id) +} - _, err = conn.Write([]byte(fmt.Sprintf("%s;%s\n", terminateReq, id))) - if err != nil { - return fmt.Errorf("failed to write to vTPM control socket: %w", err) - } +func sendPurgeRequest(id string) (string, int, error) { + return makeRequest(client, "purge", id) +} - time.Sleep(1 * time.Second) - return nil +func sendTerminateRequest(id string) (string, int, error) { + return makeRequest(client, "terminate", id) } func testLaunchAndPurge(t *testing.T, id string) { @@ -90,9 +92,9 @@ func testLaunchAndPurge(t *testing.T, id string) { // 2. check number of live instances, it should be 1 // 3. 
send purge request // 4. check number of live instances, it should be 0 - err := sendLaunchRequest(id) + b, _, err := sendLaunchRequest(id) if err != nil { - t.Fatalf("failed to handle request: %v", err) + t.Fatalf("failed to send request: %v, body : %s", err, b) } time.Sleep(1 * time.Second) @@ -100,10 +102,11 @@ func testLaunchAndPurge(t *testing.T, id string) { t.Fatalf("expected liveInstances to be 1, got %d", liveInstances) } - err = sendPurgeRequest(id) + b, _, err = sendPurgeRequest(id) if err != nil { - t.Fatalf("failed to handle request: %v", err) + t.Fatalf("failed to send request: %v, body : %s", err, b) } + time.Sleep(1 * time.Second) if liveInstances != 0 { t.Fatalf("expected liveInstances to be 0, got %d", liveInstances) @@ -112,27 +115,28 @@ func testLaunchAndPurge(t *testing.T, id string) { func testExhaustSwtpmInstances(t *testing.T, id string) { for i := 0; i < maxInstances; i++ { - err := sendLaunchRequest(fmt.Sprintf("%s-%d", id, i)) + b, _, err := sendLaunchRequest(fmt.Sprintf("%s-%d", id, i)) if err != nil { - t.Errorf("failed to send request: %v", err) + t.Fatalf("failed to send request: %v, body : %s", err, b) } } time.Sleep(5 * time.Second) // this should have no effect as we have reached max instances - err := sendLaunchRequest(id) - if err != nil { - t.Errorf("failed to send request: %v", err) + b, res, err := sendLaunchRequest(id) + if res != http.StatusTooManyRequests { + t.Fatalf("expected status code to be %d, got %d, err : %v, body: %s", http.StatusTooManyRequests, res, err, b) } if liveInstances != maxInstances { t.Errorf("expected liveInstances to be %d, got %d", maxInstances, liveInstances) } - err = sendPurgeRequest(fmt.Sprintf("%s-0", id)) + b, _, err = sendPurgeRequest(fmt.Sprintf("%s-0", id)) if err != nil { - t.Errorf("failed to handle request: %v", err) + t.Fatalf("failed to send request: %v, body : %s", err, b) } + time.Sleep(1 * time.Second) if liveInstances != maxInstances-1 { t.Errorf("expected liveInstances to be %d, got %d", maxInstances-1, liveInstances) @@ -140,10 +144,11 @@ func testExhaustSwtpmInstances(t *testing.T, id string) { // clean up for i := 0; i < maxInstances; i++ { - err := sendPurgeRequest(fmt.Sprintf("%s-%d", id, i)) + b, _, err := sendPurgeRequest(fmt.Sprintf("%s-%d", id, i)) if err != nil { - t.Errorf("failed to handle request: %v", err) + t.Fatalf("failed to send request: %v, body : %s", err, b) } + time.Sleep(1 * time.Second) } } @@ -210,9 +215,15 @@ func TestSwtpmStateChange(t *testing.T) { } // this mark the instance to be encrypted - err := sendLaunchRequest(id) + b, _, err := sendLaunchRequest(id) if err != nil { - t.Fatalf("failed to handle request: %v", err) + t.Fatalf("failed to send request: %v, body : %s", err, b) + } + time.Sleep(1 * time.Second) + + b, _, err = sendTerminateRequest(id) + if err != nil { + t.Fatalf("failed to send request: %v, body : %s", err, b) } time.Sleep(1 * time.Second) @@ -221,18 +232,20 @@ func TestSwtpmStateChange(t *testing.T) { return false } - // this should fail since this id was marked as encrypted - err = sendLaunchRequest(id) - if err != nil { - t.Fatalf("failed to handle request: %v", err) + // this should fail since it was instance was marked as encrypted and now TPM is not available anymore + b, _, err = sendLaunchRequest(id) + if err == nil { + t.Fatalf("expected error, got nil") } + t.Logf("expected error: %v, body: %s", err, b) + if liveInstances > 1 { t.Fatalf("expected liveInstances to be 1, got %d", liveInstances) } - err = sendPurgeRequest(id) + b, _, err = 
sendPurgeRequest(id) if err != nil { - t.Fatalf("failed to handle request: %v", err) + t.Fatalf("failed to send request: %v, body : %s", err, b) } } @@ -245,15 +258,16 @@ func TestDeleteRequest(t *testing.T) { return false } - err := sendLaunchRequest(id) + b, _, err := sendLaunchRequest(id) if err != nil { - t.Fatalf("failed to handle request: %v", err) + t.Fatalf("failed to send request: %v, body : %s", err, b) } - err = sendTerminateRequest(id) + b, _, err = sendTerminateRequest(id) if err != nil { - t.Fatalf("failed to handle request: %v", err) + t.Fatalf("failed to send request: %v, body : %s", err, b) } + time.Sleep(1 * time.Second) if liveInstances != 0 { t.Fatalf("expected liveInstances to be 0, got %d", liveInstances) From 13ba1c33f5e5aca9f69c8b9a1dfe112938129c6a Mon Sep 17 00:00:00 2001 From: Shahriyar Jalayeri Date: Fri, 25 Oct 2024 13:32:39 +0300 Subject: [PATCH 2/9] Domainmgr : refactor virtual TPM setup and termination Use a defer function to ensure that the virtual TPM is always terminated when the domain manager hits an error during the setup process or boot process. Signed-off-by: Shahriyar Jalayeri (cherry picked from commit 20da6cddcba4829b4ccb6ce722a6bc58955cc9ae) --- pkg/pillar/cmd/domainmgr/domainmgr.go | 47 +++++++++++++++++---------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/pkg/pillar/cmd/domainmgr/domainmgr.go b/pkg/pillar/cmd/domainmgr/domainmgr.go index fcfabd3e86..f6c2eb30af 100644 --- a/pkg/pillar/cmd/domainmgr/domainmgr.go +++ b/pkg/pillar/cmd/domainmgr/domainmgr.go @@ -1119,21 +1119,28 @@ func maybeRetryBoot(ctx *domainContext, status *types.DomainStatus) { } defer file.Close() - if err := hyper.Task(status).VirtualTPMSetup(status.DomainName, agentName, ctx.ps, warningTime, errorTime); err != nil { - log.Errorf("Failed to setup virtual TPM for %s: %s", status.DomainName, err) - status.VirtualTPM = false - } else { + err = hyper.Task(status).VirtualTPMSetup(status.DomainName, agentName, ctx.ps, warningTime, errorTime) + if err == nil { status.VirtualTPM = true + defer func(status *types.DomainStatus) { + if status.BootFailed || status.HasError() { + log.Noticef("Terminating vTPM for domain %s, BootFailed : %v, HasError: %v", + status.DomainName, status.BootFailed, status.HasError()) + + if err := hyper.Task(status).VirtualTPMTerminate(status.DomainName); err != nil { + log.Errorf("Failed to terminate vTPM for %s: %s", status.DomainName, err) + } + } + }(status) + } else { + status.VirtualTPM = false + log.Errorf("Failed to setup vTPM for %s: %s", status.DomainName, err) } if err := hyper.Task(status).Setup(*status, *config, ctx.assignableAdapters, nil, file); err != nil { //it is retry, so omit error log.Errorf("Failed to create DomainStatus from %+v: %s", config, err) - - if err := hyper.Task(status).VirtualTPMTerminate(status.DomainName); err != nil { - log.Errorf("Failed to terminate virtual TPM for %s: %s", status.DomainName, err) - } } status.TriedCount++ @@ -1671,11 +1678,22 @@ func doActivate(ctx *domainContext, config types.DomainConfig, } defer file.Close() - if err := hyper.Task(status).VirtualTPMSetup(status.DomainName, agentName, ctx.ps, warningTime, errorTime); err != nil { - log.Errorf("Failed to setup virtual TPM for %s: %s", status.DomainName, err) - status.VirtualTPM = false - } else { + err = hyper.Task(status).VirtualTPMSetup(status.DomainName, agentName, ctx.ps, warningTime, errorTime) + if err == nil { status.VirtualTPM = true + defer func(status *types.DomainStatus) { + if status.BootFailed || status.HasError() { + 
log.Noticef("Terminating vTPM for domain %s, BootFailed : %v, HasError: %v", + status.DomainName, status.BootFailed, status.HasError()) + + if err := hyper.Task(status).VirtualTPMTerminate(status.DomainName); err != nil { + log.Errorf("Failed to terminate vTPM for %s: %s", status.DomainName, err) + } + } + }(status) + } else { + status.VirtualTPM = false + log.Errorf("Failed to setup vTPM for %s: %s", status.DomainName, err) } globalConfig := agentlog.GetGlobalConfig(log, ctx.subGlobalConfig) @@ -1684,11 +1702,6 @@ func doActivate(ctx *domainContext, config types.DomainConfig, config, err) status.SetErrorNow(err.Error()) releaseCPUs(ctx, &config, status) - - if err := hyper.Task(status).VirtualTPMTerminate(status.DomainName); err != nil { - log.Errorf("Failed to terminate virtual TPM for %s: %s", status.DomainName, err) - } - return } From f98019a0147de44993f5003b87e7d7434c82c6b0 Mon Sep 17 00:00:00 2001 From: Shahriyar Jalayeri Date: Fri, 25 Oct 2024 13:55:01 +0300 Subject: [PATCH 3/9] vTPM : fix bug when getting launch request When server gets a launch request, it checks if the the requested instance is already running, but it only checks the internal list and not actually the running instances. This can lead to server thinking the instance is running but client fails to get the PID with error "failed to get pid from file ...". Signed-off-by: Shahriyar Jalayeri (cherry picked from commit 25af0d65e99045c57d6872fcfebe250fa4cac3b1) --- pkg/vtpm/swtpm-vtpm/src/main.go | 73 ++++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 20 deletions(-) diff --git a/pkg/vtpm/swtpm-vtpm/src/main.go b/pkg/vtpm/swtpm-vtpm/src/main.go index b3403adda2..ae8b0c75f6 100644 --- a/pkg/vtpm/swtpm-vtpm/src/main.go +++ b/pkg/vtpm/swtpm-vtpm/src/main.go @@ -28,7 +28,7 @@ import ( const ( swtpmPath = "/usr/bin/swtpm" maxInstances = 10 - maxPidWaitTime = 3 //seconds + maxPidWaitTime = 1 //seconds ) var ( @@ -53,6 +53,13 @@ var ( } ) +func isAlive(pid int) bool { + if err := syscall.Kill(pid, 0); err != nil { + return false + } + return true +} + func cleanupFiles(id string) { os.RemoveAll(fmt.Sprintf(swtpmStatePath, id)) os.Remove(fmt.Sprintf(swtpmCtrlSockPath, id)) @@ -75,19 +82,28 @@ func makeDirs(dir string) error { return nil } +func readPidFile(pidPath string) (int, error) { + pid := 0 + pidStr, err := os.ReadFile(pidPath) + if err == nil { + pid, err = strconv.Atoi(strings.TrimSpace(string(pidStr))) + if err == nil { + return pid, nil + } + } + + return 0, fmt.Errorf("failed to read pid file: %w", err) +} + func getSwtpmPid(pidPath string, timeoutSeconds uint) (int, error) { startTime := time.Now() for { if time.Since(startTime).Seconds() >= float64(timeoutSeconds) { - return 0, fmt.Errorf("timeout reached") + return 0, fmt.Errorf("timeout reached after %d seconds", int(time.Since(startTime).Seconds())) } - pidStr, err := os.ReadFile(pidPath) - if err == nil { - pid, err := strconv.Atoi(strings.TrimSpace(string(pidStr))) - if err == nil { - return pid, nil - } + if pid, err := readPidFile(pidPath); err == nil { + return pid, nil } time.Sleep(500 * time.Millisecond) @@ -191,16 +207,35 @@ func handleLaunch(w http.ResponseWriter, r *http.Request) { http.Error(w, err, http.StatusBadRequest) return } - // If we have SWTPM instance with the same id running, it means either the - // domain got rebooted or something went wrong on the dommain manager side! - // it the later case it should have sent a terminate request if VM crashed or - // there was any other VM related errors. 
Anyway, refuse to launch a new + // If we have a record of SWTPM instance with the requested id, + // check if it's still alive. if it is alive, refuse to launch a new // instance with the same id as this might corrupt the state. if _, ok := pids[id]; ok { - log.Warnf("SWTPM instance with id %s already running", id) - // technically request is satisfied - w.WriteHeader(http.StatusOK) - return + pidPath := fmt.Sprintf(swtpmPidPath, id) + // if pid file does not exist, it means the SWTPM instance gracefully + // terminated and we can start a new one. + if _, err := os.Stat(pidPath); err == nil { + pid, err := getSwtpmPid(pidPath, maxPidWaitTime) + if err != nil { + err := fmt.Sprintf("vTPM failed to read pid file of SWTPM with id %s", id) + http.Error(w, err, http.StatusExpectationFailed) + return + } + + // if the SWTPM instance is still alive, we can move on. + if isAlive(pid) { + log.Noticef("vTPM SWTPM instance with id %s is already running with pid %d", id, pid) + w.WriteHeader(http.StatusOK) + return + } + } else if !os.IsNotExist(err) { + err := fmt.Sprintf("vTPM failed to check pid file of SWTPM with id %s: %v", id, err) + http.Error(w, err, http.StatusFailedDependency) + return + } + + liveInstances-- + delete(pids, id) } pid, err := runSwtpm(id) @@ -218,8 +253,8 @@ func handleLaunch(w http.ResponseWriter, r *http.Request) { } // Domain manager is sending a terminate request because it hit an error while -// starting the app (i.e. qemu crashed), so just kill SWTPM instance, -// remove it's pid from the list and decrease the liveInstances count. +// starting the app, so just kill SWTPM instance, remove it's pid from the list +// and decrease the liveInstances count. func handleTerminate(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { err := fmt.Sprintf("Method %s not allowed", r.Method) @@ -237,8 +272,6 @@ func handleTerminate(w http.ResponseWriter, r *http.Request) { http.Error(w, err, http.StatusBadRequest) return } - // We expect the SWTPM to be terminated at this point, but just in case send - // a term signal. pid, ok := pids[id] if ok { if err := syscall.Kill(pid, syscall.SIGTERM); err != nil { From 2bbdea13d0c24a78328f45738b927a94b854ff2e Mon Sep 17 00:00:00 2001 From: Shahriyar Jalayeri Date: Fri, 25 Oct 2024 14:23:41 +0300 Subject: [PATCH 4/9] vTPM : validate id before using it in the request Validate ID before using it in, it must be in form of a UUID. 
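
For review purposes, the validation added here boils down to the sketch below. It uses the
satori/go.uuid dependency this patch introduces; the parseID helper and its package name are
illustrative and not part of the diff:

    package vtpmsketch

    import (
        "fmt"
        "net/http"

        uuid "github.com/satori/go.uuid"
    )

    // parseID mirrors the check used by the handlers: read the "id" query
    // parameter and reject anything that is not a well-formed UUID.
    func parseID(r *http.Request) (uuid.UUID, error) {
        reqID := r.URL.Query().Get("id")
        id := uuid.FromStringOrNil(reqID)
        if id == uuid.Nil {
            return uuid.Nil, fmt.Errorf("id %q is not a valid UUID", reqID)
        }
        return id, nil
    }
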
Signed-off-by: Shahriyar Jalayeri (cherry picked from commit 7294cce896f70fbda286ea8ffae6f298dc30b8ec) --- pkg/vtpm/swtpm-vtpm/go.mod | 2 +- pkg/vtpm/swtpm-vtpm/src/main.go | 33 ++++++++++-------- pkg/vtpm/swtpm-vtpm/src/vtpm_test.go | 50 ++++++++++++++++------------ 3 files changed, 48 insertions(+), 37 deletions(-) diff --git a/pkg/vtpm/swtpm-vtpm/go.mod b/pkg/vtpm/swtpm-vtpm/go.mod index d8a0396f21..8c03d78048 100644 --- a/pkg/vtpm/swtpm-vtpm/go.mod +++ b/pkg/vtpm/swtpm-vtpm/go.mod @@ -4,6 +4,7 @@ go 1.22 require ( github.com/lf-edge/eve/pkg/pillar v0.0.0-20240820084217-3f317a9e801f + github.com/satori/go.uuid v1.2.1-0.20180404165556-75cca531ea76 github.com/sirupsen/logrus v1.9.3 ) @@ -20,7 +21,6 @@ require ( github.com/leodido/go-urn v1.2.4 // indirect github.com/lf-edge/eve-api/go v0.0.0-20240722173316-ed56da45126b // indirect github.com/lf-edge/eve/pkg/kube/cnirpc v0.0.0-20240315102754-0f6d1f182e0d // indirect - github.com/satori/go.uuid v1.2.1-0.20180404165556-75cca531ea76 // indirect github.com/vishvananda/netlink v1.2.1-beta.2 // indirect github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f // indirect golang.org/x/crypto v0.21.0 // indirect diff --git a/pkg/vtpm/swtpm-vtpm/src/main.go b/pkg/vtpm/swtpm-vtpm/src/main.go index ae8b0c75f6..c1f81d806a 100644 --- a/pkg/vtpm/swtpm-vtpm/src/main.go +++ b/pkg/vtpm/swtpm-vtpm/src/main.go @@ -22,6 +22,7 @@ import ( etpm "github.com/lf-edge/eve/pkg/pillar/evetpm" "github.com/lf-edge/eve/pkg/pillar/types" utils "github.com/lf-edge/eve/pkg/pillar/utils/file" + uuid "github.com/satori/go.uuid" "github.com/sirupsen/logrus" ) @@ -35,7 +36,7 @@ var ( liveInstances int m sync.Mutex log *base.LogObject - pids = make(map[string]int, 0) + pids = make(map[uuid.UUID]int, 0) // XXX : move the paths to types so we have everything EVE creates in one place. // These are defined as vars to be able to mock them in tests stateEncryptionKey = "/run/swtpm/%s.binkey" @@ -60,7 +61,8 @@ func isAlive(pid int) bool { return true } -func cleanupFiles(id string) { +func cleanupFiles(uuid uuid.UUID) { + id := uuid.String() os.RemoveAll(fmt.Sprintf(swtpmStatePath, id)) os.Remove(fmt.Sprintf(swtpmCtrlSockPath, id)) os.Remove(fmt.Sprintf(swtpmPidPath, id)) @@ -110,7 +112,8 @@ func getSwtpmPid(pidPath string, timeoutSeconds uint) (int, error) { } } -func runSwtpm(id string) (int, error) { +func runSwtpm(uuid uuid.UUID) (int, error) { + id := uuid.String() statePath := fmt.Sprintf(swtpmStatePath, id) ctrlSockPath := fmt.Sprintf(swtpmCtrlSockPath, id) binKeyPath := fmt.Sprintf(stateEncryptionKey, id) @@ -177,8 +180,6 @@ func runSwtpm(id string) (int, error) { return 0, fmt.Errorf("failed to get SWTPM pid: %w", err) } - // Add it to the list. - pids[id] = pid return pid, nil } @@ -201,9 +202,10 @@ func handleLaunch(w http.ResponseWriter, r *http.Request) { return } - id := r.URL.Query().Get("id") - if id == "" { - err := "vTPM launch request failed, id is required" + reqID := r.URL.Query().Get("id") + id := uuid.FromStringOrNil(reqID) + if id == uuid.Nil { + err := fmt.Sprintf("vTPM launch request failed, id \"%s\" is invalid", reqID) http.Error(w, err, http.StatusBadRequest) return } @@ -248,6 +250,7 @@ func handleLaunch(w http.ResponseWriter, r *http.Request) { log.Noticef("vTPM launched SWTPM instance with id: %s, pid: %d", id, pid) // Send a success response. 
+ pids[id] = pid liveInstances++ w.WriteHeader(http.StatusOK) } @@ -266,9 +269,10 @@ func handleTerminate(w http.ResponseWriter, r *http.Request) { m.Lock() defer m.Unlock() - id := r.URL.Query().Get("id") - if id == "" { - err := "vTPM terminate request failed, id is required" + reqID := r.URL.Query().Get("id") + id := uuid.FromStringOrNil(reqID) + if id == uuid.Nil { + err := fmt.Sprintf("vTPM launch request failed, id \"%s\" is invalid", reqID) http.Error(w, err, http.StatusBadRequest) return } @@ -307,9 +311,10 @@ func handlePurge(w http.ResponseWriter, r *http.Request) { m.Lock() defer m.Unlock() - id := r.URL.Query().Get("id") - if id == "" { - err := "vTPM purge request failed, id is required" + reqID := r.URL.Query().Get("id") + id := uuid.FromStringOrNil(reqID) + if id == uuid.Nil { + err := fmt.Sprintf("vTPM launch request failed, id \"%s\" is invalid", reqID) http.Error(w, err, http.StatusBadRequest) return } diff --git a/pkg/vtpm/swtpm-vtpm/src/vtpm_test.go b/pkg/vtpm/swtpm-vtpm/src/vtpm_test.go index 8695a5b6ab..3f7152abb7 100644 --- a/pkg/vtpm/swtpm-vtpm/src/vtpm_test.go +++ b/pkg/vtpm/swtpm-vtpm/src/vtpm_test.go @@ -15,6 +15,7 @@ import ( "time" "github.com/lf-edge/eve/pkg/pillar/base" + uuid "github.com/satori/go.uuid" "github.com/sirupsen/logrus" ) @@ -43,6 +44,11 @@ func TestMain(m *testing.M) { m.Run() } +func generateUUID() uuid.UUID { + id, _ := uuid.NewV4() + return id +} + func UnixSocketTransport(socketPath string) *http.Transport { return &http.Transport{ DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { @@ -74,19 +80,19 @@ func makeRequest(client *http.Client, endpoint, id string) (string, int, error) return string(body), resp.StatusCode, nil } -func sendLaunchRequest(id string) (string, int, error) { - return makeRequest(client, "launch", id) +func sendLaunchRequest(id uuid.UUID) (string, int, error) { + return makeRequest(client, "launch", id.String()) } -func sendPurgeRequest(id string) (string, int, error) { - return makeRequest(client, "purge", id) +func sendPurgeRequest(id uuid.UUID) (string, int, error) { + return makeRequest(client, "purge", id.String()) } -func sendTerminateRequest(id string) (string, int, error) { - return makeRequest(client, "terminate", id) +func sendTerminateRequest(id uuid.UUID) (string, int, error) { + return makeRequest(client, "terminate", id.String()) } -func testLaunchAndPurge(t *testing.T, id string) { +func testLaunchAndPurge(t *testing.T, id uuid.UUID) { // test logic : // 1. send launch request // 2. 
check number of live instances, it should be 1 @@ -113,17 +119,21 @@ func testLaunchAndPurge(t *testing.T, id string) { } } -func testExhaustSwtpmInstances(t *testing.T, id string) { +func testExhaustSwtpmInstances(t *testing.T) { + ids := make([]uuid.UUID, 0) for i := 0; i < maxInstances; i++ { - b, _, err := sendLaunchRequest(fmt.Sprintf("%s-%d", id, i)) + id := generateUUID() + b, _, err := sendLaunchRequest(id) if err != nil { t.Fatalf("failed to send request: %v, body : %s", err, b) } + defer cleanupFiles(id) + ids = append(ids, id) } time.Sleep(5 * time.Second) // this should have no effect as we have reached max instances - b, res, err := sendLaunchRequest(id) + b, res, err := sendLaunchRequest(generateUUID()) if res != http.StatusTooManyRequests { t.Fatalf("expected status code to be %d, got %d, err : %v, body: %s", http.StatusTooManyRequests, res, err, b) } @@ -132,7 +142,7 @@ func testExhaustSwtpmInstances(t *testing.T, id string) { t.Errorf("expected liveInstances to be %d, got %d", maxInstances, liveInstances) } - b, _, err = sendPurgeRequest(fmt.Sprintf("%s-0", id)) + b, _, err = sendPurgeRequest(ids[0]) if err != nil { t.Fatalf("failed to send request: %v, body : %s", err, b) } @@ -144,7 +154,7 @@ func testExhaustSwtpmInstances(t *testing.T, id string) { // clean up for i := 0; i < maxInstances; i++ { - b, _, err := sendPurgeRequest(fmt.Sprintf("%s-%d", id, i)) + b, _, err := sendPurgeRequest(ids[i]) if err != nil { t.Fatalf("failed to send request: %v, body : %s", err, b) } @@ -153,7 +163,7 @@ func testExhaustSwtpmInstances(t *testing.T, id string) { } func TestLaunchAndPurgeWithoutStateEncryption(t *testing.T) { - id := "test" + id := generateUUID() defer cleanupFiles(id) isTPMAvailable = func() bool { return false @@ -163,7 +173,7 @@ func TestLaunchAndPurgeWithoutStateEncryption(t *testing.T) { } func TestLaunchAndPurgeWithStateEncryption(t *testing.T) { - id := "test" + id := generateUUID() defer cleanupFiles(id) isTPMAvailable = func() bool { return true @@ -178,18 +188,14 @@ func TestLaunchAndPurgeWithStateEncryption(t *testing.T) { } func TestExhaustSwtpmInstancesWithoutStateEncryption(t *testing.T) { - id := "test" - defer cleanupFiles(id) isTPMAvailable = func() bool { return false } - testExhaustSwtpmInstances(t, id) + testExhaustSwtpmInstances(t) } func TestExhaustSwtpmInstancesWithStateEncryption(t *testing.T) { - id := "test" - defer cleanupFiles(id) isTPMAvailable = func() bool { return true } @@ -199,11 +205,11 @@ func TestExhaustSwtpmInstancesWithStateEncryption(t *testing.T) { return key, nil } - testExhaustSwtpmInstances(t, id) + testExhaustSwtpmInstances(t) } func TestSwtpmStateChange(t *testing.T) { - id := "test" + id := generateUUID() defer cleanupFiles(id) isTPMAvailable = func() bool { return true @@ -250,7 +256,7 @@ func TestSwtpmStateChange(t *testing.T) { } func TestDeleteRequest(t *testing.T) { - id := "test" + id := generateUUID() defer cleanupFiles(id) // this doesn't matter From 1e7fd4c7607c2d871364b43ed2579be8f788591a Mon Sep 17 00:00:00 2001 From: Shahriyar Jalayeri Date: Mon, 28 Oct 2024 11:03:03 +0200 Subject: [PATCH 5/9] proc utils : rename wd kicker Rename wd kicker in proc utils. 
Signed-off-by: Shahriyar Jalayeri (cherry picked from commit 9db3b4d549c96a94ed5cd4bf2e3cbcb298eed782) --- pkg/pillar/utils/proc.go | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/pkg/pillar/utils/proc.go b/pkg/pillar/utils/proc.go index aa436e7528..03d4ace6f7 100644 --- a/pkg/pillar/utils/proc.go +++ b/pkg/pillar/utils/proc.go @@ -15,18 +15,18 @@ import ( "github.com/lf-edge/eve/pkg/pillar/pubsub" ) -// WatchdogKick is used in some proc functions that have a timeout, +// WatchdogKicker is used in some proc functions that have a timeout, // to tell the watchdog agent is still alive. -type WatchdogKick struct { +type WatchdogKicker struct { ps *pubsub.PubSub agentName string warnTime time.Duration errTime time.Duration } -// NewWatchdogKick creates a new WatchdogKick. -func NewWatchdogKick(ps *pubsub.PubSub, agentName string, warnTime time.Duration, errTime time.Duration) *WatchdogKick { - return &WatchdogKick{ +// NewWatchdogKicker creates a new WatchdogKick. +func NewWatchdogKicker(ps *pubsub.PubSub, agentName string, warnTime time.Duration, errTime time.Duration) *WatchdogKicker { + return &WatchdogKicker{ ps: ps, agentName: agentName, warnTime: warnTime, @@ -34,6 +34,11 @@ func NewWatchdogKick(ps *pubsub.PubSub, agentName string, warnTime time.Duration } } +// KickWatchdog tells the watchdog agent is still alive. +func KickWatchdog(wk *WatchdogKicker) { + wk.ps.StillRunning(wk.agentName, wk.warnTime, wk.errTime) +} + // PkillArgs does a pkill func PkillArgs(log *base.LogObject, match string, printOnError bool, kill bool) { cmd := "pkill" @@ -85,7 +90,7 @@ func GetPidFromFile(pidFile string) (int, error) { } // GetPidFromFileTimeout reads a pid from a file with a timeout. -func GetPidFromFileTimeout(pidFile string, timeoutSeconds uint, wk *WatchdogKick) (int, error) { +func GetPidFromFileTimeout(pidFile string, timeoutSeconds uint, wk *WatchdogKicker) (int, error) { startTime := time.Now() for { if time.Since(startTime).Seconds() >= float64(timeoutSeconds) { @@ -118,7 +123,7 @@ func IsProcAlive(pid int) bool { } // IsProcAliveTimeout checks if a process is alive for a given timeout. -func IsProcAliveTimeout(pid int, timeoutSeconds uint, wk *WatchdogKick) bool { +func IsProcAliveTimeout(pid int, timeoutSeconds uint, wk *WatchdogKicker) bool { startTime := time.Now() for { if time.Since(startTime).Seconds() >= float64(timeoutSeconds) { From dbdeb146fbae60a5ea375d34fb20cfb5124b761a Mon Sep 17 00:00:00 2001 From: Shahriyar Jalayeri Date: Mon, 28 Oct 2024 12:04:55 +0200 Subject: [PATCH 6/9] domainmgr : call vTPM asynchronously and refactor setup functions Refactor vTPM setup/term/teardown functions to call the vTPM server endpoints asynchronously, this remove the timeout guessworks and make the vTPM setup more reliable. Refactor vTPM setup functions to accept all watchdog related parameters as struct. 
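
A minimal sketch of the resulting call pattern follows; the setupVTPM wrapper is illustrative
only, while types.Task, types.WatchdogParam and VirtualTPMSetup are the interfaces this patch
changes:

    package domainmgrsketch

    import (
        "fmt"
        "time"

        "github.com/lf-edge/eve/pkg/pillar/pubsub"
        "github.com/lf-edge/eve/pkg/pillar/types"
    )

    // setupVTPM bundles the watchdog-related values into a single
    // types.WatchdogParam and hands it to the hypervisor task, instead of
    // passing the agent name and timeouts as separate arguments.
    func setupVTPM(task types.Task, ps *pubsub.PubSub, agentName, domainName string,
        warnTime, errTime time.Duration) error {
        wp := &types.WatchdogParam{
            Ps:        ps,
            AgentName: agentName,
            WarnTime:  warnTime,
            ErrTime:   errTime,
        }
        if err := task.VirtualTPMSetup(domainName, wp); err != nil {
            return fmt.Errorf("failed to setup vTPM for %s: %w", domainName, err)
        }
        return nil
    }

On failure the caller marks status.VirtualTPM false; on success a deferred check terminates the
vTPM instance if the domain never activates, as shown in the domainmgr diff below.
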
Signed-off-by: Shahriyar Jalayeri (cherry picked from commit 18abc71c5305845cc8d52499b983c5c749878c39) --- pkg/pillar/cmd/domainmgr/domainmgr.go | 37 +++++----- pkg/pillar/hypervisor/containerd.go | 7 +- pkg/pillar/hypervisor/kubevirt.go | 7 +- pkg/pillar/hypervisor/kvm.go | 100 +++++++++++++++++--------- pkg/pillar/hypervisor/null.go | 8 +-- pkg/pillar/hypervisor/xen.go | 8 +-- pkg/pillar/types/domainmgrtypes.go | 15 +++- 7 files changed, 111 insertions(+), 71 deletions(-) diff --git a/pkg/pillar/cmd/domainmgr/domainmgr.go b/pkg/pillar/cmd/domainmgr/domainmgr.go index f6c2eb30af..5cf87227bb 100644 --- a/pkg/pillar/cmd/domainmgr/domainmgr.go +++ b/pkg/pillar/cmd/domainmgr/domainmgr.go @@ -1119,19 +1119,20 @@ func maybeRetryBoot(ctx *domainContext, status *types.DomainStatus) { } defer file.Close() - err = hyper.Task(status).VirtualTPMSetup(status.DomainName, agentName, ctx.ps, warningTime, errorTime) + wp := &types.WatchdogParam{Ps: ctx.ps, AgentName: agentName, WarnTime: warningTime, ErrTime: errorTime} + err = hyper.Task(status).VirtualTPMSetup(status.DomainName, wp) if err == nil { status.VirtualTPM = true - defer func(status *types.DomainStatus) { - if status.BootFailed || status.HasError() { - log.Noticef("Terminating vTPM for domain %s, BootFailed : %v, HasError: %v", - status.DomainName, status.BootFailed, status.HasError()) - - if err := hyper.Task(status).VirtualTPMTerminate(status.DomainName); err != nil { + defer func(status *types.DomainStatus, wp *types.WatchdogParam) { + // this means we failed to boot the VM. + if !status.Activated { + log.Noticef("Failed to activate domain: %s, terminating vTPM", status.DomainName) + if err := hyper.Task(status).VirtualTPMTerminate(status.DomainName, wp); err != nil { + // this is not a critical failure so just log it log.Errorf("Failed to terminate vTPM for %s: %s", status.DomainName, err) } } - }(status) + }(status, wp) } else { status.VirtualTPM = false log.Errorf("Failed to setup vTPM for %s: %s", status.DomainName, err) @@ -1678,19 +1679,20 @@ func doActivate(ctx *domainContext, config types.DomainConfig, } defer file.Close() - err = hyper.Task(status).VirtualTPMSetup(status.DomainName, agentName, ctx.ps, warningTime, errorTime) + wp := &types.WatchdogParam{Ps: ctx.ps, AgentName: agentName, WarnTime: warningTime, ErrTime: errorTime} + err = hyper.Task(status).VirtualTPMSetup(status.DomainName, wp) if err == nil { status.VirtualTPM = true - defer func(status *types.DomainStatus) { - if status.BootFailed || status.HasError() { - log.Noticef("Terminating vTPM for domain %s, BootFailed : %v, HasError: %v", - status.DomainName, status.BootFailed, status.HasError()) - - if err := hyper.Task(status).VirtualTPMTerminate(status.DomainName); err != nil { + defer func(status *types.DomainStatus, wp *types.WatchdogParam) { + // this means we failed to boot the VM. 
+ if !status.Activated { + log.Noticef("Failed to activate domain: %s, terminating vTPM", status.DomainName) + if err := hyper.Task(status).VirtualTPMTerminate(status.DomainName, wp); err != nil { + // this is not a critical failure so just log it log.Errorf("Failed to terminate vTPM for %s: %s", status.DomainName, err) } } - }(status) + }(status, wp) } else { status.VirtualTPM = false log.Errorf("Failed to setup vTPM for %s: %s", status.DomainName, err) @@ -2485,7 +2487,8 @@ func handleDelete(ctx *domainContext, key string, status *types.DomainStatus) { log.Errorln(err) } - if err := hyper.Task(status).VirtualTPMTeardown(status.DomainName); err != nil { + wp := &types.WatchdogParam{Ps: ctx.ps, AgentName: agentName, WarnTime: warningTime, ErrTime: errorTime} + if err := hyper.Task(status).VirtualTPMTeardown(status.DomainName, wp); err != nil { log.Errorln(err) } diff --git a/pkg/pillar/hypervisor/containerd.go b/pkg/pillar/hypervisor/containerd.go index fc8fbcc7c2..1dd1c8b610 100644 --- a/pkg/pillar/hypervisor/containerd.go +++ b/pkg/pillar/hypervisor/containerd.go @@ -11,7 +11,6 @@ import ( "time" "github.com/lf-edge/eve/pkg/pillar/containerd" - "github.com/lf-edge/eve/pkg/pillar/pubsub" "github.com/lf-edge/eve/pkg/pillar/types" "github.com/opencontainers/runtime-spec/specs-go" @@ -327,14 +326,14 @@ func (ctx ctrdContext) GetDomsCPUMem() (map[string]types.DomainMetric, error) { return res, nil } -func (ctx ctrdContext) VirtualTPMSetup(domainName, agentName string, ps *pubsub.PubSub, warnTime, errTime time.Duration) error { +func (ctx ctrdContext) VirtualTPMSetup(domainName string, wp *types.WatchdogParam) error { return fmt.Errorf("not implemented") } -func (ctx ctrdContext) VirtualTPMTerminate(domainName string) error { +func (ctx ctrdContext) VirtualTPMTerminate(domainName string, wp *types.WatchdogParam) error { return fmt.Errorf("not implemented") } -func (ctx ctrdContext) VirtualTPMTeardown(domainName string) error { +func (ctx ctrdContext) VirtualTPMTeardown(domainName string, wp *types.WatchdogParam) error { return fmt.Errorf("not implemented") } diff --git a/pkg/pillar/hypervisor/kubevirt.go b/pkg/pillar/hypervisor/kubevirt.go index 4edb6bc5fb..ae3c3e128f 100644 --- a/pkg/pillar/hypervisor/kubevirt.go +++ b/pkg/pillar/hypervisor/kubevirt.go @@ -22,7 +22,6 @@ import ( "time" "github.com/lf-edge/eve/pkg/pillar/base" - "github.com/lf-edge/eve/pkg/pillar/pubsub" "github.com/lf-edge/eve/pkg/pillar/types" netattdefv1 "github.com/k8snetworkplumbingwg/network-attachment-definition-client/pkg/apis/k8s.cni.cncf.io/v1" @@ -1362,14 +1361,14 @@ func (ctx kubevirtContext) PCISameController(id1 string, id2 string) bool { return PCISameControllerGeneric(id1, id2) } -func (ctx kubevirtContext) VirtualTPMSetup(domainName, agentName string, ps *pubsub.PubSub, warnTime, errTime time.Duration) error { +func (ctx kubevirtContext) VirtualTPMSetup(domainName string, wp *types.WatchdogParam) error { return fmt.Errorf("not implemented") } -func (ctx kubevirtContext) VirtualTPMTerminate(domainName string) error { +func (ctx kubevirtContext) VirtualTPMTerminate(domainName string, wp *types.WatchdogParam) error { return fmt.Errorf("not implemented") } -func (ctx kubevirtContext) VirtualTPMTeardown(domainName string) error { +func (ctx kubevirtContext) VirtualTPMTeardown(domainName string, wp *types.WatchdogParam) error { return fmt.Errorf("not implemented") } diff --git a/pkg/pillar/hypervisor/kvm.go b/pkg/pillar/hypervisor/kvm.go index a06655d158..043df98a30 100644 --- a/pkg/pillar/hypervisor/kvm.go +++ 
b/pkg/pillar/hypervisor/kvm.go @@ -20,7 +20,6 @@ import ( zconfig "github.com/lf-edge/eve-api/go/config" "github.com/lf-edge/eve/pkg/pillar/agentlog" "github.com/lf-edge/eve/pkg/pillar/containerd" - "github.com/lf-edge/eve/pkg/pillar/pubsub" "github.com/lf-edge/eve/pkg/pillar/types" "github.com/lf-edge/eve/pkg/pillar/utils" fileutils "github.com/lf-edge/eve/pkg/pillar/utils/file" @@ -44,10 +43,16 @@ var ( clientCid = uint32(unix.VMADDR_CID_HOST + 1) vTPMClient = &http.Client{ Transport: vtpmClientUDSTransport(), - Timeout: 2 * time.Second, + Timeout: 5 * time.Second, } ) +// vtpmRequestResult holds the result of a vTPM request. +type vtpmRequestResult struct { + Body string + Error error +} + // We build device model around PCIe topology according to best practices // https://github.com/qemu/qemu/blob/master/docs/pcie.txt // and @@ -1339,38 +1344,46 @@ func getQmpListenerSocket(domainName string) string { } // VirtualTPMSetup launches a vTPM instance for the domain -func (ctx KvmContext) VirtualTPMSetup(domainName, agentName string, ps *pubsub.PubSub, warnTime, errTime time.Duration) error { - if ps != nil { - wk := utils.NewWatchdogKick(ps, agentName, warnTime, errTime) - domainUUID, _, _, err := types.DomainnameToUUID(domainName) - if err != nil { - return fmt.Errorf("failed to extract UUID from domain name: %v", err) - } - return requestvTPMLaunch(domainUUID, wk, swtpmTimeout) +func (ctx KvmContext) VirtualTPMSetup(domainName string, wp *types.WatchdogParam) error { + if wp == nil { + return fmt.Errorf("invalid watchdog configuration") + } + + domainUUID, _, _, err := types.DomainnameToUUID(domainName) + if err != nil { + return fmt.Errorf("failed to extract UUID from domain name: %v", err) } + return requestvTPMLaunch(domainUUID, wp, swtpmTimeout) - return fmt.Errorf("invalid watchdog configuration") } // VirtualTPMTerminate terminates the vTPM instance -func (ctx KvmContext) VirtualTPMTerminate(domainName string) error { +func (ctx KvmContext) VirtualTPMTerminate(domainName string, wp *types.WatchdogParam) error { + if wp == nil { + return fmt.Errorf("invalid watchdog configuration") + } + domainUUID, _, _, err := types.DomainnameToUUID(domainName) if err != nil { return fmt.Errorf("failed to extract UUID from domain name: %v", err) } - if err := requestvTPMTermination(domainUUID); err != nil { + if err := requestvTPMTermination(domainUUID, wp); err != nil { return fmt.Errorf("failed to terminate vTPM for domain %s: %w", domainName, err) } return nil } // VirtualTPMTeardown purges the vTPM instance. 
-func (ctx KvmContext) VirtualTPMTeardown(domainName string) error { +func (ctx KvmContext) VirtualTPMTeardown(domainName string, wp *types.WatchdogParam) error { + if wp == nil { + return fmt.Errorf("invalid watchdog configuration") + } + domainUUID, _, _, err := types.DomainnameToUUID(domainName) if err != nil { return fmt.Errorf("failed to extract UUID from domain name: %v", err) } - if err := requestvTPMPurge(domainUUID); err != nil { + if err := requestvTPMPurge(domainUUID, wp); err != nil { return fmt.Errorf("failed to purge vTPM for domain %s: %w", domainName, err) } @@ -1386,54 +1399,74 @@ func vtpmClientUDSTransport() *http.Transport { } } -func makeRequest(client *http.Client, endpoint, id string) (string, error) { +func makeRequestAsync(client *http.Client, endpoint, id string, rChan chan<- vtpmRequestResult) { url := fmt.Sprintf("http://unix/%s?id=%s", endpoint, id) req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { - return "", fmt.Errorf("error when creating request: %v", err) + rChan <- vtpmRequestResult{Error: fmt.Errorf("error when creating request to %s endpoint: %v", url, err)} + return } resp, err := client.Do(req) if err != nil { - return "", fmt.Errorf("request failed: %v", err) + rChan <- vtpmRequestResult{Error: fmt.Errorf("error when sending request to %s endpoint: %v", url, err)} + return } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return "", fmt.Errorf("error when reading response body: %v", err) + rChan <- vtpmRequestResult{Error: fmt.Errorf("error when reading response body from %s endpoint: %v", url, err)} + return } if resp.StatusCode != http.StatusOK { - return string(body), fmt.Errorf("received status code %d", resp.StatusCode) + rChan <- vtpmRequestResult{Error: fmt.Errorf("received status code %d from %s endpoint", resp.StatusCode, url), Body: string(body)} + return } - return string(body), nil + rChan <- vtpmRequestResult{Body: string(body)} } -func requestvTPMLaunch(id uuid.UUID, wk *utils.WatchdogKick, timeoutSeconds uint) error { - body, err := makeRequest(vTPMClient, vtpmLaunchEndpoint, id.String()) +func makeRequest(client *http.Client, wk *utils.WatchdogKicker, endpoint, id string) (body string, err error) { + rChan := make(chan vtpmRequestResult) + go makeRequestAsync(client, endpoint, id, rChan) + + startTime := time.Now() + for { + select { + case res := <-rChan: + return res.Body, res.Error + default: + utils.KickWatchdog(wk) + if time.Since(startTime).Seconds() >= float64(client.Timeout.Seconds()) { + return "", fmt.Errorf("timeout") + } + time.Sleep(100 * time.Millisecond) + } + } +} + +func requestvTPMLaunch(id uuid.UUID, wp *types.WatchdogParam, timeoutSeconds uint) error { + wk := utils.NewWatchdogKicker(wp.Ps, wp.AgentName, wp.WarnTime, wp.ErrTime) + body, err := makeRequest(vTPMClient, wk, vtpmLaunchEndpoint, id.String()) if err != nil { return fmt.Errorf("failed to launch vTPM instance: %w (%s)", err, body) } // Wait for SWTPM to start. pidPath := fmt.Sprintf(types.SwtpmPidPath, id.String()) - pid, err := utils.GetPidFromFileTimeout(pidPath, timeoutSeconds, wk) + _, err = utils.GetPidFromFileTimeout(pidPath, timeoutSeconds, wk) if err != nil { return fmt.Errorf("failed to get pid from file %s: %w", pidPath, err) } - // One last time, check SWTPM is not dead right after launch. 
-	if !utils.IsProcAlive(pid) {
-		return fmt.Errorf("SWTPM (pid: %d) is dead", pid)
-	}
-
 	return nil
 }
 
-func requestvTPMPurge(id uuid.UUID) error {
+func requestvTPMPurge(id uuid.UUID, wp *types.WatchdogParam) error {
 	// Send a request to vTPM control socket, ask it to purge the instance
 	// and all its data.
-	body, err := makeRequest(vTPMClient, vtpmPurgeEndpoint, id.String())
+	wk := utils.NewWatchdogKicker(wp.Ps, wp.AgentName, wp.WarnTime, wp.ErrTime)
+	body, err := makeRequest(vTPMClient, wk, vtpmPurgeEndpoint, id.String())
 	if err != nil {
 		return fmt.Errorf("failed to purge vTPM instance: %w (%s)", err, body)
 	}
@@ -1441,9 +1474,10 @@ func requestvTPMPurge(id uuid.UUID) error {
 	return nil
 }
 
-func requestvTPMTermination(id uuid.UUID) error {
+func requestvTPMTermination(id uuid.UUID, wp *types.WatchdogParam) error {
 	// Send a request to vTPM control socket, ask it to terminate the instance.
-	body, err := makeRequest(vTPMClient, vtpmTermEndpoint, id.String())
+	wk := utils.NewWatchdogKicker(wp.Ps, wp.AgentName, wp.WarnTime, wp.ErrTime)
+	body, err := makeRequest(vTPMClient, wk, vtpmTermEndpoint, id.String())
 	if err != nil {
 		return fmt.Errorf("failed to terminate vTPM instance: %w (%s)", err, body)
 	}
diff --git a/pkg/pillar/hypervisor/null.go b/pkg/pillar/hypervisor/null.go
index 899c35ed27..f1daf5d5d5 100644
--- a/pkg/pillar/hypervisor/null.go
+++ b/pkg/pillar/hypervisor/null.go
@@ -6,9 +6,7 @@ package hypervisor
 import (
 	"fmt"
 	"os"
-	"time"
 
-	"github.com/lf-edge/eve/pkg/pillar/pubsub"
 	"github.com/lf-edge/eve/pkg/pillar/types"
 
 	uuid "github.com/satori/go.uuid"
@@ -174,14 +172,14 @@ func (ctx nullContext) GetDomsCPUMem() (map[string]types.DomainMetric, error) {
 	return nil, nil
 }
 
-func (ctx nullContext) VirtualTPMSetup(domainName, agentName string, ps *pubsub.PubSub, warnTime, errTime time.Duration) error {
+func (ctx nullContext) VirtualTPMSetup(domainName string, wp *types.WatchdogParam) error {
 	return fmt.Errorf("not implemented")
 }
 
-func (ctx nullContext) VirtualTPMTerminate(domainName string) error {
+func (ctx nullContext) VirtualTPMTerminate(domainName string, wp *types.WatchdogParam) error {
 	return fmt.Errorf("not implemented")
 }
 
-func (ctx nullContext) VirtualTPMTeardown(domainName string) error {
+func (ctx nullContext) VirtualTPMTeardown(domainName string, wp *types.WatchdogParam) error {
 	return fmt.Errorf("not implemented")
 }
diff --git a/pkg/pillar/hypervisor/xen.go b/pkg/pillar/hypervisor/xen.go
index 729f0d7432..323fa2233a 100644
--- a/pkg/pillar/hypervisor/xen.go
+++ b/pkg/pillar/hypervisor/xen.go
@@ -11,9 +11,7 @@ import (
 	"runtime"
 	"strconv"
 	"strings"
-	"time"
 
-	"github.com/lf-edge/eve/pkg/pillar/pubsub"
 	"github.com/lf-edge/eve/pkg/pillar/types"
 	"github.com/shirou/gopsutil/cpu"
 	"github.com/shirou/gopsutil/mem"
@@ -877,14 +875,14 @@ func fallbackDomainMetric() map[string]types.DomainMetric {
 	return dmList
 }
 
-func (ctx xenContext) VirtualTPMSetup(domainName, agentName string, ps *pubsub.PubSub, warnTime, errTime time.Duration) error {
+func (ctx xenContext) VirtualTPMSetup(domainName string, wp *types.WatchdogParam) error {
 	return fmt.Errorf("not implemented")
 }
 
-func (ctx xenContext) VirtualTPMTerminate(domainName string) error {
+func (ctx xenContext) VirtualTPMTerminate(domainName string, wp *types.WatchdogParam) error {
 	return fmt.Errorf("not implemented")
 }
 
-func (ctx xenContext) VirtualTPMTeardown(domainName string) error {
+func (ctx xenContext) VirtualTPMTeardown(domainName string, wp *types.WatchdogParam) error {
 	return fmt.Errorf("not implemented")
 }
diff --git a/pkg/pillar/types/domainmgrtypes.go b/pkg/pillar/types/domainmgrtypes.go
index 3e7bd762cc..2ea433d1f7 100644
--- a/pkg/pillar/types/domainmgrtypes.go
+++ b/pkg/pillar/types/domainmgrtypes.go
@@ -277,9 +277,9 @@ const (
 
 type Task interface {
 	Setup(DomainStatus, DomainConfig, *AssignableAdapters, *ConfigItemValueMap, *os.File) error
-	VirtualTPMSetup(domainName, agentName string, ps *pubsub.PubSub, warnTime, errTime time.Duration) error
-	VirtualTPMTerminate(domainName string) error
-	VirtualTPMTeardown(domainName string) error
+	VirtualTPMSetup(domainName string, wp *WatchdogParam) error
+	VirtualTPMTerminate(domainName string, wp *WatchdogParam) error
+	VirtualTPMTeardown(domainName string, wp *WatchdogParam) error
 	Create(string, string, *DomainConfig) (int, error)
 	Start(string) error
 	Stop(string, bool) error
@@ -580,3 +580,12 @@ type Capabilities struct {
 	CPUPinning bool // CPU Pinning support
 	UseVHost   bool // vHost support
 }
+
+// WatchdogParam is used by functions that may block or time out,
+// to tell the watchdog that the agent is still alive.
+type WatchdogParam struct {
+	Ps        *pubsub.PubSub
+	AgentName string
+	WarnTime  time.Duration
+	ErrTime   time.Duration
+}

From aa24ad54ebdd74816b405f48ade07cedf2cb5393 Mon Sep 17 00:00:00 2001
From: Shahriyar Jalayeri
Date: Mon, 28 Oct 2024 12:09:02 +0200
Subject: [PATCH 7/9] vTPM : more relaxed timeout

The domainmanager calls the vTPM server asynchronously, so we don't
need to keep the wait time low to return quickly and avoid a watchdog
kill on pillar.

Signed-off-by: Shahriyar Jalayeri
(cherry picked from commit bd856c7fe1a0771897d7a71e81913f785919026b)
---
 pkg/vtpm/swtpm-vtpm/src/main.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pkg/vtpm/swtpm-vtpm/src/main.go b/pkg/vtpm/swtpm-vtpm/src/main.go
index c1f81d806a..8f47ee32b6 100644
--- a/pkg/vtpm/swtpm-vtpm/src/main.go
+++ b/pkg/vtpm/swtpm-vtpm/src/main.go
@@ -29,7 +29,7 @@ import (
 const (
 	swtpmPath      = "/usr/bin/swtpm"
 	maxInstances   = 10
-	maxPidWaitTime = 1 //seconds
+	maxPidWaitTime = 5 //seconds
 )
 
 var (

From aa9611fbb9772fb6d2f87e3c2dfca58a1a7b222b Mon Sep 17 00:00:00 2001
From: Shahriyar Jalayeri
Date: Tue, 29 Oct 2024 15:20:40 +0200
Subject: [PATCH 8/9] add vtpm vendor directory to .spdxignore

Add the vtpm vendor directory to .spdxignore.

Signed-off-by: Shahriyar Jalayeri
(cherry picked from commit 5d4f771e66c41e2f4f9a543ed71b5fbc75f217e2)
Signed-off-by: Shahriyar Jalayeri
---
 .spdxignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.spdxignore b/.spdxignore
index c089d6d0f4..82455009bd 100644
--- a/.spdxignore
+++ b/.spdxignore
@@ -10,3 +10,4 @@ pkg/rngd/cmd/rngd/vendor/
 pkg/wwan/mmagent/vendor/
 tools/get-deps/vendor/
 pkg/installer/vendor/
+pkg/vtpm/swtpm-vtpm/vendor/

From ec97f9745c64eeb89c9add8c7d0a22371f4e86f1 Mon Sep 17 00:00:00 2001
From: Shahriyar Jalayeri
Date: Mon, 4 Nov 2024 12:25:20 +0200
Subject: [PATCH 9/9] Add new vTPM tests

The TestSwtpmAbruptTerminationRequest function verifies that if swtpm
is terminated without the vTPM server noticing, no stale id is left in
the vTPM server's internal bookkeeping and a new instance can be
launched with the same id.

The TestSwtpmMultipleLaunchRequest function verifies that if swtpm is
launched multiple times with the same id, only one instance is created
and the other requests are ignored.
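Both tests drive the vTPM server through existing test helpers
(sendLaunchRequest, sendTerminateRequest) that are not part of this
change. Conceptually each helper is just an HTTP GET against one
endpoint of the control server over its Unix-domain socket, roughly
like the sketch below; the helper name, package name and socket path
here are illustrative, not the actual test code:

package sketch // illustrative only, not part of this patch

import (
	"context"
	"fmt"
	"io"
	"net"
	"net/http"

	uuid "github.com/satori/go.uuid"
)

// sendRequest approximates what the test helpers do: GET an endpoint of
// the vTPM control server over its Unix-domain socket and return the
// response body, status code and error. The socket path is a placeholder.
func sendRequest(endpoint string, id uuid.UUID) (string, int, error) {
	client := &http.Client{
		Transport: &http.Transport{
			DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
				return (&net.Dialer{}).DialContext(ctx, "unix", "/path/to/vtpm-ctrl.sock")
			},
		},
	}
	resp, err := client.Get(fmt.Sprintf("http://unix/%s?id=%s", endpoint, id.String()))
	if err != nil {
		return "", 0, err
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	return string(body), resp.StatusCode, err
}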
Signed-off-by: Shahriyar Jalayeri
(cherry picked from commit bc80a42d9ea733bf5297d7220f6cc3050d16e4f0)
---
 pkg/vtpm/swtpm-vtpm/src/vtpm_test.go | 89 ++++++++++++++++++++++++++++
 1 file changed, 89 insertions(+)

diff --git a/pkg/vtpm/swtpm-vtpm/src/vtpm_test.go b/pkg/vtpm/swtpm-vtpm/src/vtpm_test.go
index 3f7152abb7..a875e876dd 100644
--- a/pkg/vtpm/swtpm-vtpm/src/vtpm_test.go
+++ b/pkg/vtpm/swtpm-vtpm/src/vtpm_test.go
@@ -11,6 +11,7 @@ import (
 	"net"
 	"net/http"
 	"os"
+	"syscall"
 	"testing"
 	"time"
 
@@ -279,3 +280,91 @@ func TestDeleteRequest(t *testing.T) {
 		t.Fatalf("expected liveInstances to be 0, got %d", liveInstances)
 	}
 }
+
+func TestSwtpmAbruptTerminationRequest(t *testing.T) {
+	// This test verifies that if swtpm is terminated without the vTPM server
+	// noticing, no stale id is left in the vTPM internal bookkeeping and a
+	// new instance can be launched with the same id.
+	// Test logic:
+	// 1. send a launch request
+	// 2. read the swtpm pid file and terminate the process
+	// 3. send the launch request again; this should not fail
+	id := generateUUID()
+	defer cleanupFiles(id)
+
+	// TPM hardware availability does not matter for this test.
+	isTPMAvailable = func() bool {
+		return false
+	}
+
+	b, _, err := sendLaunchRequest(id)
+	if err != nil {
+		t.Fatalf("failed to send request: %v, body : %s", err, b)
+	}
+
+	pid, err := readPidFile(fmt.Sprintf(swtpmPidPath, id))
+	if err != nil {
+		t.Fatalf("failed to read pid file: %v", err)
+	}
+	if err := syscall.Kill(pid, syscall.SIGTERM); err != nil {
+		t.Fatalf("failed to kill process: %v", err)
+
+	}
+
+	// This should not fail.
+	b, _, err = sendLaunchRequest(id)
+	if err != nil {
+		t.Fatalf("failed to send request: %v, body : %s", err, b)
+	}
+
+	b, _, err = sendTerminateRequest(id)
+	if err != nil {
+		t.Fatalf("failed to send request: %v, body : %s", err, b)
+	}
+	time.Sleep(1 * time.Second)
+
+	if liveInstances != 0 {
+		t.Fatalf("expected liveInstances to be 0, got %d", liveInstances)
+	}
+}
+
+func TestSwtpmMultipleLaunchRequest(t *testing.T) {
+	// This test verifies that if swtpm is launched multiple times with the
+	// same id, only one instance is created and the other requests are ignored.
+	// Test logic:
+	// 1. send the launch request multiple times; all should succeed
+	// 2. clean up
+	id := generateUUID()
+	defer cleanupFiles(id)
+
+	// TPM hardware availability does not matter for this test.
+	isTPMAvailable = func() bool {
+		return false
+	}
+
+	for i := 0; i < 5; i++ {
+		b, _, err := sendLaunchRequest(id)
+		if err != nil {
+			t.Fatalf("failed to send request: %v, body : %s", err, b)
+		}
+	}
+
+	pid, err := readPidFile(fmt.Sprintf(swtpmPidPath, id))
+	if err != nil {
+		t.Fatalf("failed to read pid file: %v", err)
+	}
+	if err := syscall.Kill(pid, syscall.SIGTERM); err != nil {
+		t.Fatalf("failed to kill process: %v", err)
+
+	}
+
+	b, _, err := sendTerminateRequest(id)
+	if err != nil {
+		t.Fatalf("failed to send request: %v, body : %s", err, b)
+	}
+	time.Sleep(1 * time.Second)
+
+	if liveInstances != 0 {
+		t.Fatalf("expected liveInstances to be 0, got %d", liveInstances)
+	}
+}
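For reference, a minimal caller-side sketch of the interface change made
earlier in this series: vTPM setup, terminate and teardown now take a
single types.WatchdogParam instead of separate pubsub, agent-name and
timing arguments. This is an illustration, not code from the patches;
the helper names, agent name and interval values are placeholders.

package sketch // illustrative caller-side example, not part of the patches

import (
	"fmt"
	"time"

	"github.com/lf-edge/eve/pkg/pillar/pubsub"
	"github.com/lf-edge/eve/pkg/pillar/types"
)

// setupDomainVTPM shows the new calling convention: the watchdog
// parameters travel in one types.WatchdogParam value.
func setupDomainVTPM(task types.Task, ps *pubsub.PubSub, domainName string) error {
	wp := &types.WatchdogParam{
		Ps:        ps,
		AgentName: "domainmgr",      // placeholder agent name
		WarnTime:  40 * time.Second, // placeholder warn interval
		ErrTime:   2 * time.Minute,  // placeholder error interval
	}
	// Launch the per-domain SWTPM instance; the watchdog is kicked while
	// the request to the vTPM control socket is in flight.
	if err := task.VirtualTPMSetup(domainName, wp); err != nil {
		return fmt.Errorf("vTPM setup for %s failed: %w", domainName, err)
	}
	return nil
}

// teardownDomainVTPM stops the instance and then purges its state,
// following the Terminate/Teardown split in the Task interface.
func teardownDomainVTPM(task types.Task, wp *types.WatchdogParam, domainName string) error {
	if err := task.VirtualTPMTerminate(domainName, wp); err != nil {
		return fmt.Errorf("vTPM terminate for %s failed: %w", domainName, err)
	}
	return task.VirtualTPMTeardown(domainName, wp)
}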