diff --git a/windows-agent/internal/config/config.go b/windows-agent/internal/config/config.go index 902297700..3a85db403 100644 --- a/windows-agent/internal/config/config.go +++ b/windows-agent/internal/config/config.go @@ -11,6 +11,7 @@ import ( "path/filepath" "sync" + "github.com/canonical/ubuntu-pro-for-windows/common" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/config/registry" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/contracts" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/distros/database" @@ -206,7 +207,7 @@ func (c *Config) LandscapeAgentUID(ctx context.Context) (string, error) { // FetchMicrosoftStoreSubscription contacts Ubuntu Pro's contract server and the Microsoft Store // to check if the user has an active subscription that provides a pro token. If so, that token is used. -func (c *Config) FetchMicrosoftStoreSubscription(ctx context.Context) (err error) { +func (c *Config) FetchMicrosoftStoreSubscription(ctx context.Context, args ...contracts.Option) (err error) { defer decorate.OnError(&err, "could not validate subscription against Microsoft Store") readOnly, err := c.IsReadOnly() @@ -219,11 +220,36 @@ func (c *Config) FetchMicrosoftStoreSubscription(ctx context.Context) (err error return fmt.Errorf("subscription cannot be user-managed") } - proToken, err := contracts.ProToken(ctx) + _, src, err := c.Subscription(ctx) + if err != nil { + return fmt.Errorf("could not get current subscription status: %v", err) + } + + // Shortcut to avoid spamming the contract server + // We don't need to request a new token if we have a non-expired one + if src == SourceMicrosoftStore { + valid, err := contracts.ValidSubscription(args...) + if err != nil { + return fmt.Errorf("could not obtain current subscription status: %v", err) + } + + if valid { + log.Debug(ctx, "Microsoft Store subscription is active") + return nil + } + + log.Debug(ctx, "No valid Microsoft Store subscription") + } + + proToken, err := contracts.NewProToken(ctx, args...) if err != nil { return fmt.Errorf("could not get ProToken from Microsoft Store: %v", err) } + if proToken != "" { + log.Debugf(ctx, "Obtained Ubuntu Pro token from the Microsoft Store: %q", common.Obfuscate(proToken)) + } + if err := c.setStoreSubscription(ctx, proToken); err != nil { return err } diff --git a/windows-agent/internal/config/config_test.go b/windows-agent/internal/config/config_test.go index b093693b3..9ef905819 100644 --- a/windows-agent/internal/config/config_test.go +++ b/windows-agent/internal/config/config_test.go @@ -3,14 +3,19 @@ package config_test import ( "context" "errors" + "fmt" "io/fs" + "net/url" "os" "path/filepath" "testing" + "time" "github.com/canonical/ubuntu-pro-for-windows/common/wsltestutils" + "github.com/canonical/ubuntu-pro-for-windows/mocks/contractserver/contractsmockserver" config "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/config" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/config/registry" + "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/contracts" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/distros/database" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/distros/distro" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/distros/task" @@ -412,21 +417,37 @@ func TestIsReadOnly(t *testing.T) { func TestFetchMicrosoftStoreSubscription(t *testing.T) { t.Parallel() + //nolint:gosec // These are not real credentials + const ( + proToken = "UBUNTU_PRO_TOKEN_456" + azureADToken = "AZURE_AD_TOKEN_789" + ) + testCases := map[string]struct { - settingsState settingsState + settingsState settingsState + subscriptionExpired bool + registryErr uint32 registryIsReadOnly bool + msStoreJWTErr bool + msStoreExpirationErr bool + wantToken string wantErr bool }{ - // TODO: Implement more test cases when the MS Store mock is available. There is no single successful test in here so far. - "Error when registry is read only": {settingsState: userTokenHasValue, registryIsReadOnly: true, wantToken: "user_token", wantErr: true}, - "Error when registry read-only check fails": {registryErr: registry.MockErrOnCreateKey, wantErr: true}, + // Tests where there is no pre-existing subscription + "Success": {wantToken: proToken}, - // Stub test-case: Must be replaced with Success/Error return values of contracts.ProToken - // when the Microsoft store dance is implemented. - "Error when the microsoft store is not yet implemented": {wantErr: true}, + "Error when registry is read only": {settingsState: userTokenHasValue, registryIsReadOnly: true, wantToken: "user_token", wantErr: true}, + "Error when registry read-only check fails": {registryErr: registry.MockErrOnCreateKey, wantErr: true}, + "Error when the Microsoft Store cannot provide the JWT": {msStoreJWTErr: true, wantErr: true}, + + // Tests where there is a pre-existing subscription + "Success when there is a store token already": {settingsState: storeTokenHasValue, wantToken: "store_token"}, + "Success when there is an expired store token": {settingsState: storeTokenHasValue, subscriptionExpired: true, wantToken: proToken}, + + "Error when the Microsoft Store cannot provide the expiration date": {settingsState: storeTokenHasValue, msStoreExpirationErr: true, wantToken: "store_token", wantErr: true}, } for name, tc := range testCases { @@ -439,7 +460,33 @@ func TestFetchMicrosoftStoreSubscription(t *testing.T) { r, dir := setUpMockSettings(t, tc.registryErr, tc.settingsState, tc.registryIsReadOnly, false) c := config.New(ctx, dir, config.WithRegistry(r)) - err := c.FetchMicrosoftStoreSubscription(ctx) + // Set up the mock Microsoft store + store := mockMSStore{ + expirationDate: time.Now().Add(24 * 365 * time.Hour), // Next year + expirationDateErr: tc.msStoreExpirationErr, + + jwt: "JWT_123", + jwtErr: tc.msStoreJWTErr, + } + + if tc.subscriptionExpired { + store.expirationDate = time.Now().Add(-24 * 365 * time.Hour) // Last year + } + + // Set up the mock contract server + csSettings := contractsmockserver.DefaultSettings() + csSettings.Token.OnSuccess.Value = azureADToken + csSettings.Subscription.OnSuccess.Value = proToken + server := contractsmockserver.NewServer(csSettings) + err := server.Serve(ctx, "localhost:0") + require.NoError(t, err, "Setup: Server should return no error") + //nolint:errcheck // Nothing we can do about it + defer server.Stop() + + csAddr, err := url.Parse(fmt.Sprintf("http://%s", server.Address())) + require.NoError(t, err, "Setup: Server URL should have been parsed with no issues") + + err = c.FetchMicrosoftStoreSubscription(ctx, contracts.WithProURL(csAddr), contracts.WithMockMicrosoftStore(store)) if tc.wantErr { require.Error(t, err, "FetchMicrosoftStoreSubscription should return an error") } else { @@ -455,6 +502,30 @@ func TestFetchMicrosoftStoreSubscription(t *testing.T) { } } +type mockMSStore struct { + jwt string + jwtErr bool + + expirationDate time.Time + expirationDateErr bool +} + +func (s mockMSStore) GenerateUserJWT(azureADToken string) (jwt string, err error) { + if s.jwtErr { + return "", errors.New("mock error") + } + + return s.jwt, nil +} + +func (s mockMSStore) GetSubscriptionExpirationDate() (tm time.Time, err error) { + if s.expirationDateErr { + return time.Time{}, errors.New("mock error") + } + + return s.expirationDate, nil +} + func TestUpdateRegistrySettings(t *testing.T) { if wsl.MockAvailable() { t.Parallel() diff --git a/windows-agent/internal/contracts/contracts.go b/windows-agent/internal/contracts/contracts.go index aa7794bd3..afb2ebc98 100644 --- a/windows-agent/internal/contracts/contracts.go +++ b/windows-agent/internal/contracts/contracts.go @@ -3,6 +3,7 @@ package contracts import ( "context" + "errors" "fmt" "net/http" "net/url" @@ -10,7 +11,6 @@ import ( "github.com/canonical/ubuntu-pro-for-windows/storeapi/go-wrapper/microsoftstore" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/contracts/contractclient" - "github.com/ubuntu/decorate" ) type options struct { @@ -52,10 +52,8 @@ func (msftStoreDLL) GetSubscriptionExpirationDate() (tm time.Time, err error) { return microsoftstore.GetSubscriptionExpirationDate() } -// ProToken directs the dance between the Microsoft Store and the Ubuntu Pro contract server to -// validate a store entitlement and obtain its associated pro token. If there is no entitlement, -// the token is returned as an empty string. -func ProToken(ctx context.Context, args ...Option) (token string, err error) { +// ValidSubscription returns true if there is a subscription via the Microsoft Store and it is not expired. +func ValidSubscription(args ...Option) (bool, error) { opts := options{ microsoftStore: msftStoreDLL{}, } @@ -64,37 +62,50 @@ func ProToken(ctx context.Context, args ...Option) (token string, err error) { f(&opts) } - if opts.proURL == nil { - url, err := defaultProBackendURL() - if err != nil { - return "", fmt.Errorf("could not parse default contract server URL: %v", err) + expiration, err := opts.microsoftStore.GetSubscriptionExpirationDate() + if err != nil { + var target microsoftstore.StoreAPIError + if errors.As(err, &target) && target == microsoftstore.ErrNotSubscribed { + // ValidSubscription -> false: we are not subscribed + return false, nil } - opts.proURL = url - } - contractClient := contractclient.New(opts.proURL, &http.Client{Timeout: 30 * time.Second}) + return false, fmt.Errorf("could not get subscription expiration date: %v", err) + } - token, err = proToken(ctx, contractClient, opts.microsoftStore) - if err != nil { - return "", err + if expiration.Before(time.Now()) { + // ValidSubscription -> false: the subscription is expired + return false, nil } - return token, nil + // ValidSubscription -> true: the subscription is not yet expired + return true, nil } -func proToken(ctx context.Context, serverClient *contractclient.Client, msftStore MicrosoftStore) (proToken string, err error) { - defer decorate.OnError(&err, "could not obtain pro token") +// NewProToken directs the dance between the Microsoft Store and the Ubuntu Pro contract server to +// validate a store entitlement and obtain its associated pro token. If there is no entitlement, +// the token is returned as an empty string. +func NewProToken(ctx context.Context, args ...Option) (token string, err error) { + opts := options{ + microsoftStore: msftStoreDLL{}, + } - expiration, err := msftStore.GetSubscriptionExpirationDate() - if err != nil { - return "", fmt.Errorf("could not get subscription expiration date: %v", err) + for _, f := range args { + f(&opts) } - if expiration.Before(time.Now()) { - return "", fmt.Errorf("the subscription has been expired since %s", expiration) + if opts.proURL == nil { + url, err := defaultProBackendURL() + if err != nil { + return "", fmt.Errorf("could not parse default contract server URL: %v", err) + } + opts.proURL = url } - adToken, err := serverClient.GetServerAccessToken(ctx) + contractClient := contractclient.New(opts.proURL, &http.Client{Timeout: 30 * time.Second}) + msftStore := opts.microsoftStore + + adToken, err := contractClient.GetServerAccessToken(ctx) if err != nil { return "", err } @@ -104,7 +115,7 @@ func proToken(ctx context.Context, serverClient *contractclient.Client, msftStor return "", err } - proToken, err = serverClient.GetProToken(ctx, storeToken) + proToken, err := contractClient.GetProToken(ctx, storeToken) if err != nil { return "", err } diff --git a/windows-agent/internal/contracts/contracts_test.go b/windows-agent/internal/contracts/contracts_test.go index cda47b03c..c014857e6 100644 --- a/windows-agent/internal/contracts/contracts_test.go +++ b/windows-agent/internal/contracts/contracts_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/canonical/ubuntu-pro-for-windows/mocks/contractserver/contractsmockserver" + "github.com/canonical/ubuntu-pro-for-windows/storeapi/go-wrapper/microsoftstore" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/contracts" "github.com/stretchr/testify/require" ) @@ -24,9 +25,7 @@ func TestProToken(t *testing.T) { testCases := map[string]struct { // Microsoft store - expired bool - jwtError bool - expDateError bool + jwtError bool // Contract server getServerAccessTokenErr bool @@ -36,9 +35,7 @@ func TestProToken(t *testing.T) { }{ "Success": {}, - "Error when the subscription has expired": {expired: true, wantErr: true}, "Error when the store's GenerateUserJWT fails": {jwtError: true, wantErr: true}, - "Error when the store's GetSubscriptionExpirationDate fails": {expDateError: true, wantErr: true}, "Error when the contract server's GetServerAccessToken fails": {getServerAccessTokenErr: true, wantErr: true}, "Error when the contract server's GetProToken fails": {getProTokenErr: true, wantErr: true}, } @@ -50,18 +47,13 @@ func TestProToken(t *testing.T) { ctx := context.Background() store := mockMSStore{ - expirationDate: time.Now().Add(24 * 365 * time.Hour), // Next year - expirationDateErr: tc.expDateError, + expirationDate: time.Now().Add(24 * 365 * time.Hour), // Next year jwt: "JWT_123", jwtWantADToken: azureADToken, jwtErr: tc.jwtError, } - if tc.expired { - store.expirationDate = time.Now().Add(-24 * 365 * time.Hour) // Last year - } - settings := contractsmockserver.DefaultSettings() settings.Token.OnSuccess.Value = azureADToken @@ -80,7 +72,7 @@ func TestProToken(t *testing.T) { url, err := url.Parse(fmt.Sprintf("http://%s", addr)) require.NoError(t, err, "Setup: Server URL should have been parsed with no issues") - token, err := contracts.ProToken(ctx, contracts.WithProURL(url), contracts.WithMockMicrosoftStore(store)) + token, err := contracts.NewProToken(ctx, contracts.WithProURL(url), contracts.WithMockMicrosoftStore(store)) if tc.wantErr { require.Error(t, err, "ProToken should return an error") return @@ -92,11 +84,68 @@ func TestProToken(t *testing.T) { } } +func TestValidSubscription(t *testing.T) { + t.Parallel() + + type subscriptionStatus int + const ( + subscribed subscriptionStatus = iota + expired + unsubscribed + ) + + testCases := map[string]struct { + status subscriptionStatus + expirationErr bool + + want bool + wantErr bool + }{ + "Succcess when the current subscription is active": {status: subscribed, want: true}, + "Succcess when the current subscription is expired": {status: expired, want: false}, + "Success when there is no subscription": {status: unsubscribed, want: false}, + + "Error when subscription validity cannot be ascertained": {status: subscribed, expirationErr: true, wantErr: true}, + } + + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + var store mockMSStore + + switch tc.status { + case subscribed: + store.expirationDate = time.Now().Add(time.Hour * 24 * 365) // Next year + case expired: + store.expirationDate = time.Now().Add(-time.Hour * 24 * 365) // Last year + case unsubscribed: + store.notSubscribed = true + } + + if tc.expirationErr { + store.expirationDateErr = true + } + + got, err := contracts.ValidSubscription(contracts.WithMockMicrosoftStore(store)) + if tc.wantErr { + require.Error(t, err, "contracts.ValidSubscription should have returned an error") + return + } + + require.NoError(t, err, "contracts.ValidSubscription should have returned no error") + require.Equal(t, tc.want, got, "Unexpected return from ValidSubscription") + }) + } +} + type mockMSStore struct { jwt string jwtWantADToken string jwtErr bool + notSubscribed bool expirationDate time.Time expirationDateErr bool } @@ -115,7 +164,11 @@ func (s mockMSStore) GenerateUserJWT(azureADToken string) (jwt string, err error func (s mockMSStore) GetSubscriptionExpirationDate() (tm time.Time, err error) { if s.expirationDateErr { - return time.Time{}, errors.New("mock error") + return time.Time{}, fmt.Errorf("mock error: %w", microsoftstore.ErrStoreAPI) + } + + if s.notSubscribed { + return time.Time{}, fmt.Errorf("mock error: %w", microsoftstore.ErrNotSubscribed) } return s.expirationDate, nil diff --git a/windows-agent/internal/proservices/ui/ui.go b/windows-agent/internal/proservices/ui/ui.go index 7a032f7a1..e5dfddd79 100644 --- a/windows-agent/internal/proservices/ui/ui.go +++ b/windows-agent/internal/proservices/ui/ui.go @@ -9,6 +9,7 @@ import ( agentapi "github.com/canonical/ubuntu-pro-for-windows/agentapi/go" "github.com/canonical/ubuntu-pro-for-windows/common" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/config" + "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/contracts" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/distros/database" log "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/grpc/logstreamer" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/tasks" @@ -19,7 +20,7 @@ type Config interface { SetUserSubscription(ctx context.Context, token string) error IsReadOnly() (bool, error) Subscription(context.Context) (string, config.Source, error) - FetchMicrosoftStoreSubscription(context.Context) error + FetchMicrosoftStoreSubscription(context.Context, ...contracts.Option) error } // Service it the UI GRPC service implementation. diff --git a/windows-agent/internal/proservices/ui/ui_test.go b/windows-agent/internal/proservices/ui/ui_test.go index 7f591006d..ca841151f 100644 --- a/windows-agent/internal/proservices/ui/ui_test.go +++ b/windows-agent/internal/proservices/ui/ui_test.go @@ -12,6 +12,7 @@ import ( "github.com/canonical/ubuntu-pro-for-windows/common/wsltestutils" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/config" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/config/registry" + "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/contracts" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/distros/database" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/distros/distro" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/proservices/ui" @@ -244,7 +245,11 @@ func (m mockConfig) Subscription(context.Context) (string, config.Source, error) } return m.token, m.source, nil } -func (m *mockConfig) FetchMicrosoftStoreSubscription(ctx context.Context) error { +func (m *mockConfig) FetchMicrosoftStoreSubscription(ctx context.Context, args ...contracts.Option) error { + if len(args) != 0 { + panic("The variadic argument exists solely to match the interface. Do not use.") + } + readOnly, err := m.IsReadOnly() if err != nil { return err