diff --git a/storeapi/go-wrapper/microsoftstore/errors.go b/storeapi/go-wrapper/microsoftstore/errors.go index c48d5a29b..287ad07f6 100644 --- a/storeapi/go-wrapper/microsoftstore/errors.go +++ b/storeapi/go-wrapper/microsoftstore/errors.go @@ -2,7 +2,7 @@ package microsoftstore // StoreAPIError are the error constants in the store api. -type StoreAPIError int +type StoreAPIError int64 // Keep up-to-date with `storeapi\base\Exception.hpp`. const ( @@ -26,7 +26,7 @@ const ( ) // NewStoreAPIError creates StoreAPIError from the result of a call to the storeAPI DLL. -func NewStoreAPIError(hresult uintptr) error { +func NewStoreAPIError(hresult int64) error { if err := StoreAPIError(hresult); err < ErrSuccess { return err } diff --git a/storeapi/go-wrapper/microsoftstore/export_test.go b/storeapi/go-wrapper/microsoftstore/export_test.go index d893d5df4..ee7467f26 100644 --- a/storeapi/go-wrapper/microsoftstore/export_test.go +++ b/storeapi/go-wrapper/microsoftstore/export_test.go @@ -2,3 +2,6 @@ package microsoftstore // FindWorkspaceRoot climbs up the current working directory until the Go workspace root is found. var FindWorkspaceRoot = findWorkspaceRoot + +// CheckError inspects the values of hres and err to determine what kind of error we have, if any, according to the rules of syscall/dll_windows.go. +var CheckError = checkError diff --git a/storeapi/go-wrapper/microsoftstore/store.go b/storeapi/go-wrapper/microsoftstore/store.go index ef792e710..0f0c1db33 100644 --- a/storeapi/go-wrapper/microsoftstore/store.go +++ b/storeapi/go-wrapper/microsoftstore/store.go @@ -2,8 +2,10 @@ package microsoftstore import ( "errors" + "fmt" "os" "path/filepath" + "syscall" ) // findWorkspaceRoot climbs up the current working directory until the Go workspace root is found. @@ -26,3 +28,40 @@ func findWorkspaceRoot() (string, error) { } } } + +// checkError inspects the values of hres and err to determine what kind of error we have, if any, according to the rules of syscall/dll_windows.go. +func checkError(hres int64, err error) (int64, error) { + // From syscall/dll_windows.go (*Proc).Call doc: + // > Callers must inspect the primary return value to decide whether an + // error occurred [...] before consulting the error. + // There is no possibility of nil error, the `err` return value is always constructed with the + // result of `GetLastError()` which could have been set by something completely + // unrelated to our code some time in the past, as well as it could be `ERROR_SUCCESS` which is the `Errno(0)`. + // If the act of calling the API fails (not the function we're calling, but the attempt to call it), then we'd + // have a meaningful `syscall.Errno` object via the `err` parameter, related to the actual failure (like a function not found in this DLL) + // Since our implementation of the store API doesn't touch errno the call should return `hres` + // in our predefined range plus garbage in the `err` argument, thus we only care about the `hres` in this case. + if e := NewStoreAPIError(hres); e != nil { + return hres, fmt.Errorf("storeApi returned error code %d: %w", hres, e) + } + + // Supposedly unreachable: proc.Call must always return a non-nil syscall.Errno + if err == nil { + return hres, nil + } + + var target syscall.Errno + if b := errors.As(err, &target); !b { + // Supposedly unreachable: proc.Call must always return a non-nil syscall.Errno + return hres, err + } + + // The act of calling our API didn't succeed, function not found in the DLL for example: + if target != syscall.Errno(0) { + return hres, fmt.Errorf("failed syscall to storeApi: %v (syscall errno %d)", target, err) + } + + // A non-error value in hres plus ERROR_SUCCESS in err. + // This shouldn't happen in the current store API implementation anyway. + return hres, nil +} diff --git a/storeapi/go-wrapper/microsoftstore/store_test.go b/storeapi/go-wrapper/microsoftstore/store_test.go index b502a797b..12f862de4 100644 --- a/storeapi/go-wrapper/microsoftstore/store_test.go +++ b/storeapi/go-wrapper/microsoftstore/store_test.go @@ -2,12 +2,14 @@ package microsoftstore_test import ( "context" + "errors" "fmt" "log/slog" "os" "os/exec" "path/filepath" "runtime" + "syscall" "testing" "time" @@ -87,6 +89,39 @@ func TestGetSubscriptionExpirationDate(t *testing.T) { require.ErrorIs(t, gotErr, wantErr, "GetSubscriptionExpirationDate should have returned code %d", wantErr) } +func TestErrorVerification(t *testing.T) { + t.Parallel() + testcases := map[string]struct { + hresult int64 + err error + + wantErr bool + }{ + "Success": {}, + // If HRESULT is not in the Store API error range and err is not a syscall.Errno then we don't have an error. + "With an unknown value (not an error)": {hresult: 1, wantErr: false}, + + "Upper bound of the Store API enum range": {hresult: -1, wantErr: true}, + "Lower bound of the Store API enum range": {hresult: int64(microsoftstore.ErrNotSubscribed), wantErr: true}, + "With a system error (errno)": {hresult: 32 /*garbage*/, err: syscall.Errno(2) /*E_FILE_NOT_FOUND*/, wantErr: true}, + "With a generic (unreachable) error": {hresult: 1, err: errors.New("test error"), wantErr: true}, + // This would mean an API call returning a non-error hresult plus GetLastError() returning ERROR_SUCCESS + // This shouldn't happen in the current store API implementation anyway. + "With weird successful error": {hresult: 1, err: syscall.Errno(0) /*ERROR_SUCCESS*/}, + } + for name, tc := range testcases { + t.Run(name, func(t *testing.T) { + t.Parallel() + res, err := microsoftstore.CheckError(tc.hresult, tc.err) + if tc.wantErr { + require.Error(t, err, "CheckError should have returned an error for value: %v, returned value was: %v", tc.hresult, res) + return + } + require.NoError(t, err, "CheckError should have not returned an error for value: %v, returned value was: %v", tc.hresult, res) + }) + } +} + func buildStoreAPI(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, 5*time.Minute) defer cancel() diff --git a/storeapi/go-wrapper/microsoftstore/store_windows.go b/storeapi/go-wrapper/microsoftstore/store_windows.go index d9bf10754..2a79967d2 100644 --- a/storeapi/go-wrapper/microsoftstore/store_windows.go +++ b/storeapi/go-wrapper/microsoftstore/store_windows.go @@ -2,7 +2,6 @@ package microsoftstore import ( - "errors" "fmt" "sync" "syscall" @@ -81,7 +80,7 @@ func GetSubscriptionExpirationDate() (tm time.Time, err error) { // Use this instead of proc.Call to avoid panics. // //nolint:unparam // Return value is provided to follow convention. -func call(proc *syscall.LazyProc, args ...uintptr) (int, error) { +func call(proc *syscall.LazyProc, args ...uintptr) (int64, error) { if err := loadDll(); err != nil { return 0, err } @@ -92,29 +91,8 @@ func call(proc *syscall.LazyProc, args ...uintptr) (int, error) { } hresult, _, err := proc.Call(args...) - - // From syscall/dll_windows.go (*Proc).Call doc: - // > Callers must inspect the primary return value to decide whether an - // error occurred [...] before consulting the error. - if err := NewStoreAPIError(hresult); err != nil { - return int(hresult), fmt.Errorf("storeApi returned error code %d: %w", int(hresult), err) - } - - if err == nil { - return int(hresult), nil - } - - var target syscall.Errno - if b := errors.As(err, &target); !b { - // Supposedly unrechable: proc.Call must always return a syscall.Errno - return int(hresult), err - } - - if target != syscall.Errno(0) { - return int(hresult), fmt.Errorf("failed syscall to storeApi: %v (syscall errno %d)", target, err) - } - - return int(hresult), nil + //nolint:gosec // Windows HRESULTS are guaranteed to be 32-bit vlaue, thus they surely fit inside a int64 without overflow. + return checkError(int64(hresult), err) } // loadDll finds the dll and ensures it loads. diff --git a/windows-agent/internal/daemon/daemon_test.go b/windows-agent/internal/daemon/daemon_test.go index 08fe629f7..2fd4e1ff4 100644 --- a/windows-agent/internal/daemon/daemon_test.go +++ b/windows-agent/internal/daemon/daemon_test.go @@ -169,10 +169,11 @@ func TestServeWSLIP(t *testing.T) { "With mirrored networking mode": {netmode: "mirrored", withAdapters: daemontestutils.MultipleHyperVAdaptersInList}, "With no access to the system distro but net mode is the default (NAT)": {netmode: "error", withAdapters: daemontestutils.MultipleHyperVAdaptersInList}, - "Error when the networking mode is unknown": {netmode: "unknown", wantErr: true}, - "Error when the list of adapters is empty": {withAdapters: daemontestutils.EmptyList, wantErr: true}, - "Error when there is no Hyper-V adapter the list": {withAdapters: daemontestutils.NoHyperVAdapterInList, wantErr: true}, - "Error when retrieving adapters information fails": {withAdapters: daemontestutils.MockError, wantErr: true}, + "Error when the networking mode is unknown": {netmode: "unknown", wantErr: true}, + "Error when the list of adapters is empty": {withAdapters: daemontestutils.EmptyList, wantErr: true}, + "Error when listing adapters requires too much memory": {withAdapters: daemontestutils.RequiresTooMuchMem, wantErr: true}, + "Error when there is no Hyper-V adapter the list": {withAdapters: daemontestutils.NoHyperVAdapterInList, wantErr: true}, + "Error when retrieving adapters information fails": {withAdapters: daemontestutils.MockError, wantErr: true}, } for name, tc := range testcases { diff --git a/windows-agent/internal/daemon/daemontestutils/networking_mock.go b/windows-agent/internal/daemon/daemontestutils/networking_mock.go index 1360c023c..9d0bd9721 100644 --- a/windows-agent/internal/daemon/daemontestutils/networking_mock.go +++ b/windows-agent/internal/daemon/daemontestutils/networking_mock.go @@ -2,6 +2,7 @@ package daemontestutils import ( "errors" + "math" "net" "unsafe" @@ -15,6 +16,9 @@ const ( // MockError is a state that causes the GetAdaptersAddresses to always return an error. MockError MockIPAdaptersState = iota + // RequiresTooMuchMem is a state that causes the GetAdaptersAddresses to request allocation of MaxUint32 (over the capacity of the real Win32 API). + RequiresTooMuchMem + // EmptyList is a state that causes the GetAdaptersAddresses to return an empty list of adapters. EmptyList @@ -96,6 +100,9 @@ func (m *MockIPConfig) GetAdaptersAddresses(_, _ uint32, _ uintptr, adapterAddre switch m.state { case MockError: return errors.New("mock error") + case RequiresTooMuchMem: + *sizePointer = math.MaxUint32 + return ERROR_BUFFER_OVERFLOW case EmptyList: return nil default: @@ -105,24 +112,29 @@ func (m *MockIPConfig) GetAdaptersAddresses(_, _ uint32, _ uintptr, adapterAddre // fillBufferFromTemplate fills a pre-allocated buffer of ipAdapterAddresses with the data from the mockIPAddrsTemplate. func fillBufferFromTemplate(adaptersAddresses *IPAdapterAddresses, sizePointer *uint32, mockIPAddrsTemplate []MockIPAddrsTemplate) error { - count := uint32(len(mockIPAddrsTemplate)) - objSize := uint32(unsafe.Sizeof(IPAdapterAddresses{})) + count := len(mockIPAddrsTemplate) + objSize := int(unsafe.Sizeof(IPAdapterAddresses{})) bufSizeNeeded := count * objSize - if *sizePointer < bufSizeNeeded { + if bufSizeNeeded >= math.MaxUint32 || bufSizeNeeded < 0 { + return errors.New("buffer size limit reached") + } + //nolint:gosec // Value guaranteed to fit inside uint32. + bufSz := uint32(bufSizeNeeded) + if *sizePointer < bufSz { return ERROR_BUFFER_OVERFLOW } //nolint:gosec // Using unsafe to manipulate pointers mimicking the Win32 API, only used in tests. begin := unsafe.Pointer(adaptersAddresses) for _, addr := range mockIPAddrsTemplate { - next := unsafe.Add(begin, int(objSize)) // next = ++begin + next := unsafe.Add(begin, objSize) // next = ++begin ptr := (*IPAdapterAddresses)(begin) fillFromTemplate(&addr, ptr, (*IPAdapterAddresses)(next)) begin = next } - *sizePointer = bufSizeNeeded + *sizePointer = bufSz return nil } diff --git a/windows-agent/internal/daemon/networking.go b/windows-agent/internal/daemon/networking.go index a7c2bd37e..ed916caac 100644 --- a/windows-agent/internal/daemon/networking.go +++ b/windows-agent/internal/daemon/networking.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "math" "net" "os/exec" "reflect" @@ -139,15 +140,20 @@ func getAddrList(opts options) (head *ipAdapterAddresses, err error) { // the buffer around while we're using it, invalidating the NEXT pointers. var buff buffer[ipAdapterAddresses] - // Win32 API docs recommend a buff size of 15KB. - buff.resizeBytes(15 * kilobyte) + // Win32 API docs recommend a buff size of 15KB to start. + size, err := buff.resizeBytes(15 * kilobyte) + // This error condition should be impossible. + if err != nil { + return nil, err + } for range 10 { - size := buff.byteCount() - err := opts.getAdaptersAddresses(family, flags, 0, &buff.data[0], &size) + err = opts.getAdaptersAddresses(family, flags, 0, &buff.data[0], &size) if errors.Is(err, ERROR_BUFFER_OVERFLOW) { // Buffer too small, try again with the returned size. - buff.resizeBytes(size) + if size, err = buff.resizeBytes(size); err != nil { + return nil, err + } continue } if err != nil { @@ -171,27 +177,31 @@ type buffer[T any] struct { data []T } -// byteCount returns the number of bytes in the buffer. -func (b buffer[T]) byteCount() uint32 { - var t T - sizeOf := uint32(reflect.TypeOf(t).Size()) - n := uint32(len(b.data)) - return n * sizeOf -} - // ResizeBytes resizes the buffer to the given number of bytes, rounded UP to fit an integer element size. -func (b *buffer[T]) resizeBytes(n uint32) { +func (b *buffer[T]) resizeBytes(n uint32) (uint32, error) { var t T - sizeOf := uint32(reflect.TypeOf(t).Size()) + n64 := uint64(n) + sizeOf := uint64(reflect.TypeOf(t).Size()) - newLen := int(n / sizeOf) - if n%sizeOf != 0 { + newLen := n64 / sizeOf + if n64%sizeOf != 0 { newLen++ } - if newLen > len(b.data) { + // the sizes the Win32 API GetAdaptersAddresses works with are uint32, thus we cannot allocate + // more than MaxUint32 bytes after all. + newSize := newLen * sizeOf + if newSize >= math.MaxUint32 { + return 0, errors.New("buffer allocated size limit reached") + } + + if newLen > uint64(len(b.data)) { b.data = make([]T, newLen) + // Since make() guarantees len(b.data) == newLen, there is no need to recompute it. } + + //nolint:gosec //uint64 -> uint32 conversion is safe because we checked that newSize < MaxUint32. + return uint32(newSize), nil } // ptr returns a pointer to the start of the buffer. diff --git a/windows-agent/internal/proservices/landscape/distroinstall/distroinstall.go b/windows-agent/internal/proservices/landscape/distroinstall/distroinstall.go index 075c325f5..ef5ed8402 100644 --- a/windows-agent/internal/proservices/landscape/distroinstall/distroinstall.go +++ b/windows-agent/internal/proservices/landscape/distroinstall/distroinstall.go @@ -69,6 +69,7 @@ func CreateUser(ctx context.Context, d gowsl.Distro, userName string, userFullNa return 0, fmt.Errorf("could not parse uid %q: %v", string(out), err) } + //nolint:gosec // strconv.ParseUint with bitSize 32 ensures the value of id64 fits inside uint32. return uint32(id64), nil } diff --git a/wsl-pro-service/internal/system/networking.go b/wsl-pro-service/internal/system/networking.go index d0c09ab11..1d27b516d 100644 --- a/wsl-pro-service/internal/system/networking.go +++ b/wsl-pro-service/internal/system/networking.go @@ -101,6 +101,7 @@ func (s *System) defaultGateway() (ip net.IP, err error) { } b := make([]byte, 4) + //nolint:gosec // Value is guaranteed by strconv.ParseUint to fit in uint32 (due the bitSize argument) binary.LittleEndian.PutUint32(b, uint32(gatewayRaw)) return net.IP(b), nil