diff --git a/.github/workflows/sync-tag.yml b/.github/workflows/sync-tag.yml index 2a2255996cb..1cc553b12fa 100644 --- a/.github/workflows/sync-tag.yml +++ b/.github/workflows/sync-tag.yml @@ -17,6 +17,7 @@ jobs: uses: benc-uk/workflow-dispatch@v1 with: workflow: sync-tag.yml + ref: main repo: ${{ secrets.UPSTREAM_REPO }} token: ${{ secrets.NC_GITHUB_TOKEN }} inputs: '{ "tag": "${{ github.ref_name }}" }' \ No newline at end of file diff --git a/client/cmd/login.go b/client/cmd/login.go index ac79199e2d6..5af8c177509 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -51,7 +51,7 @@ var loginCmd = &cobra.Command{ AdminURL: adminURL, ConfigPath: configPath, } - if preSharedKey != "" { + if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { ic.PreSharedKey = &preSharedKey } @@ -151,13 +151,21 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C jwtToken = tokenInfo.GetTokenToUse() } + var lastError error + err = WithBackOff(func() error { err := internal.Login(ctx, config, setupKey, jwtToken) if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { + lastError = err return nil } return err }) + + if lastError != nil { + return fmt.Errorf("login failed: %v", lastError) + } + if err != nil { return fmt.Errorf("backoff cycle failed: %v", err) } diff --git a/client/cmd/root.go b/client/cmd/root.go index 24c027d0c0d..91a2c6861de 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -26,6 +26,7 @@ import ( const ( externalIPMapFlag = "external-ip-map" + preSharedKeyFlag = "preshared-key" dnsResolverAddress = "dns-resolver-address" ) @@ -94,7 +95,7 @@ func init() { rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level") rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout") rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)") - rootCmd.PersistentFlags().StringVar(&preSharedKey, "preshared-key", "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.") + rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device") rootCmd.AddCommand(serviceCmd) rootCmd.AddCommand(upCmd) diff --git a/client/cmd/status.go b/client/cmd/status.go index 74d2061ffac..1c5dfab26ae 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -66,13 +66,15 @@ type statusOutputOverview struct { } var ( - detailFlag bool - ipv4Flag bool - jsonFlag bool - yamlFlag bool - ipsFilter []string - statusFilter string - ipsFilterMap map[string]struct{} + detailFlag bool + ipv4Flag bool + jsonFlag bool + yamlFlag bool + ipsFilter []string + prefixNamesFilter []string + statusFilter string + ipsFilterMap map[string]struct{} + prefixNamesFilterMap map[string]struct{} ) var statusCmd = &cobra.Command{ @@ -83,12 +85,14 @@ var statusCmd = &cobra.Command{ func init() { ipsFilterMap = make(map[string]struct{}) + prefixNamesFilterMap = make(map[string]struct{}) statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format") statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format") statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format") statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33") statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4") statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200") + statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected") } @@ -172,8 +176,12 @@ func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse, } func parseFilters() error { + switch strings.ToLower(statusFilter) { case "", "disconnected", "connected": + if strings.ToLower(statusFilter) != "" { + enableDetailFlagWhenFilterFlag() + } default: return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter) } @@ -185,11 +193,26 @@ func parseFilters() error { return fmt.Errorf("got an invalid IP address in the filter: address %s, error %s", addr, err) } ipsFilterMap[addr] = struct{}{} + enableDetailFlagWhenFilterFlag() } } + + if len(prefixNamesFilter) > 0 { + for _, name := range prefixNamesFilter { + prefixNamesFilterMap[strings.ToLower(name)] = struct{}{} + } + enableDetailFlagWhenFilterFlag() + } + return nil } +func enableDetailFlagWhenFilterFlag() { + if !detailFlag && !jsonFlag && !yamlFlag { + detailFlag = true + } +} + func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview { pbFullStatus := resp.GetFullStatus() @@ -415,6 +438,7 @@ func parsePeers(peers peersStateOutput) string { func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool { statusEval := false ipEval := false + nameEval := false if statusFilter != "" { lowerStatusFilter := strings.ToLower(statusFilter) @@ -431,5 +455,15 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool { ipEval = true } } - return statusEval || ipEval + + if len(prefixNamesFilter) > 0 { + for prefixNameFilter := range prefixNamesFilterMap { + if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) { + nameEval = true + break + } + } + } + + return statusEval || ipEval || nameEval } diff --git a/client/cmd/up.go b/client/cmd/up.go index dd4c7290ee2..ebfcb2b9d0f 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -85,7 +85,8 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { NATExternalIPs: natExternalIPs, CustomDNSAddress: customDNSAddressConverted, } - if preSharedKey != "" { + + if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { ic.PreSharedKey = &preSharedKey } diff --git a/client/internal/config.go b/client/internal/config.go index 646848a2f66..8f433a0419f 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -215,12 +215,9 @@ func update(input ConfigInput) (*Config, error) { } if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey { - if *input.PreSharedKey != "" { - log.Infof("new pre-shared key provides, updated to %s (old value %s)", - *input.PreSharedKey, config.PreSharedKey) - config.PreSharedKey = *input.PreSharedKey - refresh = true - } + log.Infof("new pre-shared key provided, replacing old key") + config.PreSharedKey = *input.PreSharedKey + refresh = true } if config.SSHKey == "" { diff --git a/client/internal/config_test.go b/client/internal/config_test.go index 8bd8d8d6138..eeec9b516b1 100644 --- a/client/internal/config_test.go +++ b/client/internal/config_test.go @@ -6,8 +6,9 @@ import ( "path/filepath" "testing" - "github.com/netbirdio/netbird/util" "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/util" ) func TestGetConfig(t *testing.T) { @@ -60,22 +61,7 @@ func TestGetConfig(t *testing.T) { assert.Equal(t, config.ManagementURL.String(), managementURL) assert.Equal(t, config.PreSharedKey, preSharedKey) - // case 4: new empty pre-shared key config -> fetch it - newPreSharedKey := "" - config, err = UpdateOrCreateConfig(ConfigInput{ - ManagementURL: managementURL, - AdminURL: adminURL, - ConfigPath: path, - PreSharedKey: &newPreSharedKey, - }) - if err != nil { - return - } - - assert.Equal(t, config.ManagementURL.String(), managementURL) - assert.Equal(t, config.PreSharedKey, preSharedKey) - - // case 5: existing config, but new managementURL has been provided -> update config + // case 4: existing config, but new managementURL has been provided -> update config newManagementURL := "https://test.newManagement.url:33071" config, err = UpdateOrCreateConfig(ConfigInput{ ManagementURL: newManagementURL, diff --git a/management/cmd/management.go b/management/cmd/management.go index f05de4e4e81..54f672f51a1 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -83,7 +83,10 @@ var ( if err != nil { return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err) } - config.HttpConfig.IdpSignKeyRefreshEnabled = idpSignKeyRefreshEnabled + + if cmd.Flag(idpSignKeyRefreshEnabledFlagName).Changed { + config.HttpConfig.IdpSignKeyRefreshEnabled = idpSignKeyRefreshEnabled + } tlsEnabled := false if mgmtLetsencryptDomain != "" || (config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "") { diff --git a/management/cmd/root.go b/management/cmd/root.go index 1c9b95bfd2a..f5c533969e6 100644 --- a/management/cmd/root.go +++ b/management/cmd/root.go @@ -12,7 +12,8 @@ import ( const ( // ExitSetupFailed defines exit code - ExitSetupFailed = 1 + ExitSetupFailed = 1 + idpSignKeyRefreshEnabledFlagName = "idp-sign-key-refresh-enabled" ) var ( @@ -62,7 +63,7 @@ func init() { mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird") mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max length is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain)) - mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.") + mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, idpSignKeyRefreshEnabledFlagName, false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.") mgmtCmd.Flags().BoolVar(&userDeleteFromIDPEnabled, "user-delete-from-idp", false, "Allows to delete user from IDP when user is deleted from account") rootCmd.MarkFlagRequired("config") //nolint diff --git a/management/server/account.go b/management/server/account.go index 4c13c853587..f2d5e79d587 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -17,11 +17,12 @@ import ( "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" - "github.com/netbirdio/management-integrations/additions" gocache "github.com/patrickmn/go-cache" "github.com/rs/xid" log "github.com/sirupsen/logrus" + "github.com/netbirdio/management-integrations/additions" + "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" @@ -66,6 +67,7 @@ type AccountManager interface { GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) + CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) DeleteAccount(accountID, userID string) error MarkPATUsed(tokenID string) error @@ -164,6 +166,9 @@ type Settings struct { // JWTGroupsClaimName from which we extract groups name to add it to account groups JWTGroupsClaimName string + // JWTAllowGroups list of groups to which users are allowed access + JWTAllowGroups []string `gorm:"serializer:json"` + // Extra is a dictionary of Account settings Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"` } @@ -176,6 +181,7 @@ func (s *Settings) Copy() *Settings { JWTGroupsEnabled: s.JWTGroupsEnabled, JWTGroupsClaimName: s.JWTGroupsClaimName, GroupsPropagationEnabled: s.GroupsPropagationEnabled, + JWTAllowGroups: s.JWTAllowGroups, } if s.Extra != nil { settings.Extra = s.Extra.Copy() @@ -1693,6 +1699,39 @@ func (am *DefaultAccountManager) GetDNSDomain() string { return am.dnsDomain } +// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT +// group propagation and set the list of groups with access permissions. +func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error { + account, _, err := am.GetAccountFromToken(claims) + if err != nil { + return err + } + + // Ensures JWT group synchronization to the management is enabled before, + // filtering access based on the allowed groups. + if account.Settings != nil && account.Settings.JWTGroupsEnabled { + if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 { + userJWTGroups := make([]string, 0) + + if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok { + if claimGroups, ok := claim.([]interface{}); ok { + for _, g := range claimGroups { + if group, ok := g.(string); ok { + userJWTGroups = append(userJWTGroups, group) + } + } + } + } + + if !userHasAllowedGroup(allowedGroups, userJWTGroups) { + return fmt.Errorf("user does not belong to any of the allowed JWT groups") + } + } + } + + return nil +} + // addAllGroup to account object if it doesn't exists func addAllGroup(account *Account) error { if len(account.Groups) == 0 { @@ -1764,3 +1803,15 @@ func newAccountWithId(accountID, userID, domain string) *Account { } return acc } + +// userHasAllowedGroup checks if a user belongs to any of the allowed groups. +func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { + for _, userGroup := range userGroups { + for _, allowedGroup := range allowedGroups { + if userGroup == allowedGroup { + return true + } + } + } + return false +} diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 8d3d82661d0..d6463edd9b6 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -220,6 +220,10 @@ func (s *GRPCServer) validateToken(jwtToken string) (string, error) { return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) } + if err := s.accountManager.CheckUserAccessByJWTGroups(claims); err != nil { + return "", status.Errorf(codes.PermissionDenied, err.Error()) + } + return claims.UserId, nil } @@ -312,7 +316,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p userID, err = s.validateToken(loginReq.GetJwtToken()) if err != nil { log.Warnf("failed validating JWT token sent from peer %s", peerKey) - return nil, mapError(err) + return nil, err } } var sshKey []byte diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index c2751abd478..bab00219b86 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -91,6 +91,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) if req.Settings.JwtGroupsClaimName != nil { settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName } + if req.Settings.JwtAllowGroups != nil { + settings.JWTAllowGroups = *req.Settings.JwtAllowGroups + } updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings) if err != nil { @@ -128,12 +131,18 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) } func toAccountResponse(account *server.Account) *api.Account { + jwtAllowGroups := account.Settings.JWTAllowGroups + if jwtAllowGroups == nil { + jwtAllowGroups = []string{} + } + settings := api.AccountSettings{ PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()), PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled, GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled, JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled, JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName, + JwtAllowGroups: &jwtAllowGroups, } if account.Settings.Extra != nil { diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/accounts_handler_test.go index 08c98c830f6..fd2c4bfcd33 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/accounts_handler_test.go @@ -95,6 +95,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { GroupsPropagationEnabled: br(false), JwtGroupsClaimName: sr(""), JwtGroupsEnabled: br(false), + JwtAllowGroups: &[]string{}, }, expectedArray: true, expectedID: accountID, @@ -112,6 +113,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { GroupsPropagationEnabled: br(false), JwtGroupsClaimName: sr(""), JwtGroupsEnabled: br(false), + JwtAllowGroups: &[]string{}, }, expectedArray: false, expectedID: accountID, @@ -121,7 +123,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\"}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"]}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ PeerLoginExpiration: 15552000, @@ -129,6 +131,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { GroupsPropagationEnabled: br(false), JwtGroupsClaimName: sr("roles"), JwtGroupsEnabled: br(true), + JwtAllowGroups: &[]string{"test"}, }, expectedArray: false, expectedID: accountID, @@ -146,6 +149,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { GroupsPropagationEnabled: br(true), JwtGroupsClaimName: sr("groups"), JwtGroupsEnabled: br(true), + JwtAllowGroups: &[]string{}, }, expectedArray: false, expectedID: accountID, diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 3a125bdd7e7..1a049a0cff5 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -66,6 +66,12 @@ components: description: Name of the claim from which we extract groups names to add it to account groups. type: string example: "roles" + jwt_allow_groups: + description: List of groups to which users are allowed access + type: array + items: + type: string + example: Administrators extra: $ref: '#/components/schemas/AccountExtraSettings' required: diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 820cf5c48b2..329c6688482 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -160,6 +160,9 @@ type AccountSettings struct { // GroupsPropagationEnabled Allows propagate the new user auto groups to peers that belongs to the user GroupsPropagationEnabled *bool `json:"groups_propagation_enabled,omitempty"` + // JwtAllowGroups List of groups to which users are allowed access + JwtAllowGroups *[]string `json:"jwt_allow_groups,omitempty"` + // JwtGroupsClaimName Name of the claim from which we extract groups names to add it to account groups. JwtGroupsClaimName *string `json:"jwt_groups_claim_name,omitempty"` diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 8c77d27dc4a..c47eac5731f 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -34,12 +34,20 @@ type emptyObject struct { // APIHandler creates the Management service HTTP API handler registering all the available endpoints. func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { + claimsExtractor := jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ) + authMiddleware := middleware.NewAuthMiddleware( accountManager.GetAccountFromPAT, jwtValidator.ValidateAndParse, accountManager.MarkPATUsed, + accountManager.CheckUserAccessByJWTGroups, + claimsExtractor, authCfg.Audience, - authCfg.UserIDClaim) + authCfg.UserIDClaim, + ) corsMiddleware := cors.AllowAll() @@ -60,11 +68,6 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid AuthCfg: authCfg, } - claimsExtractor := jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ) - integrations.RegisterHandlers(api.Router, accountManager, claimsExtractor) api.addAccountsEndpoint() api.addPeersEndpoint() diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 99482bfb7f5..766a0c235c0 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -26,13 +26,18 @@ type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error) // MarkPATUsedFunc function type MarkPATUsedFunc func(token string) error +// CheckUserAccessByJWTGroupsFunc function +type CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error + // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens type AuthMiddleware struct { - getAccountFromPAT GetAccountFromPATFunc - validateAndParseToken ValidateAndParseTokenFunc - markPATUsed MarkPATUsedFunc - audience string - userIDClaim string + getAccountFromPAT GetAccountFromPATFunc + validateAndParseToken ValidateAndParseTokenFunc + markPATUsed MarkPATUsedFunc + checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc + claimsExtractor *jwtclaims.ClaimsExtractor + audience string + userIDClaim string } const ( @@ -40,16 +45,21 @@ const ( ) // NewAuthMiddleware instance constructor -func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string, userIdClaim string) *AuthMiddleware { +func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, + markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor, + audience string, userIdClaim string) *AuthMiddleware { if userIdClaim == "" { userIdClaim = jwtclaims.UserIDClaim } + return &AuthMiddleware{ - getAccountFromPAT: getAccountFromPAT, - validateAndParseToken: validateAndParseToken, - markPATUsed: markPATUsed, - audience: audience, - userIDClaim: userIdClaim, + getAccountFromPAT: getAccountFromPAT, + validateAndParseToken: validateAndParseToken, + markPATUsed: markPATUsed, + checkUserAccessByJWTGroups: checkUserAccessByJWTGroups, + claimsExtractor: claimsExtractor, + audience: audience, + userIDClaim: userIdClaim, } } @@ -107,6 +117,10 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ return nil } + if err := m.verifyUserAccess(validatedToken); err != nil { + return err + } + // If we get here, everything worked and we can set the // user property in context. newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint @@ -115,6 +129,14 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ return nil } +// verifyUserAccess checks if a user, based on a validated JWT token, +// is allowed access, particularly in cases where the admin enabled JWT +// group propagation and designated certain groups with access permissions. +func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error { + authClaims := m.claimsExtractor.FromToken(validatedToken) + return m.checkUserAccessByJWTGroups(authClaims) +} + // CheckPATFromRequest checks if the PAT is valid func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error { token, err := getTokenFromPATRequest(auth) diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 55e5de260f6..5fa73ea3aa5 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -10,6 +10,7 @@ import ( "github.com/golang-jwt/jwt" "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/jwtclaims" ) const ( @@ -54,7 +55,13 @@ func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server func mockValidateAndParseToken(token string) (*jwt.Token, error) { if token == JWT { - return &jwt.Token{}, nil + return &jwt.Token{ + Claims: jwt.MapClaims{ + userIDClaim: userID, + audience + jwtclaims.AccountIDSuffix: accountID, + }, + Valid: true, + }, nil } return nil, fmt.Errorf("JWT invalid") } @@ -66,6 +73,18 @@ func mockMarkPATUsed(token string) error { return fmt.Errorf("Should never get reached") } +func mockCheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error { + if testAccount.Id != claims.AccountId { + return fmt.Errorf("account with id %s does not exist", claims.AccountId) + } + + if _, ok := testAccount.Users[claims.UserId]; !ok { + return fmt.Errorf("user with id %s does not exist", claims.UserId) + } + + return nil +} + func TestAuthMiddleware_Handler(t *testing.T) { tt := []struct { name string @@ -108,7 +127,20 @@ func TestAuthMiddleware_Handler(t *testing.T) { // do nothing }) - authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience, userIDClaim) + claimsExtractor := jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(audience), + jwtclaims.WithUserIDClaim(userIDClaim), + ) + + authMiddleware := NewAuthMiddleware( + mockGetAccountFromPAT, + mockValidateAndParseToken, + mockMarkPATUsed, + mockCheckUserAccessByJWTGroups, + claimsExtractor, + audience, + userIDClaim, + ) handlerToTest := authMiddleware.Handler(nextHandler) diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go index b564e4f4eb0..f218c1aa9c6 100644 --- a/management/server/jwtclaims/jwtValidator.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -108,6 +108,8 @@ func NewJWTValidator(issuer string, audienceList []string, keysLocation string, refreshedKeys = keys } + log.Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC()) + keys = refreshedKeys } } @@ -179,7 +181,7 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) { // stillValid returns true if the JSONWebKey still valid and have enough time to be used func (jwks *Jwks) stillValid() bool { - return jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime) + return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime) } func getPemKeys(keysLocation string) (*Jwks, error) { diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index a349b35a9d1..4ca68ac216f 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -69,6 +69,7 @@ type MockAccountManager struct { ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error) CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) + CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error DeleteAccountFunc func(accountID, userID string) error GetDNSDomainFunc func() string StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.Activity, meta map[string]any) @@ -366,7 +367,7 @@ func (am *MockAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*se func (am *MockAccountManager) ListUsers(accountID string) ([]*server.User, error) { if am.ListUsersFunc != nil { - return am.ListUsers(accountID) + return am.ListUsersFunc(accountID) } return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented") } @@ -464,7 +465,7 @@ func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.Us // SaveOrAddUser mocks SaveOrAddUser of the AccountManager interface func (am *MockAccountManager) SaveOrAddUser(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) { - if am.SaveUserFunc != nil { + if am.SaveOrAddUserFunc != nil { return am.SaveOrAddUserFunc(accountID, userID, user, addIfNotExists) } return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented") @@ -543,9 +544,16 @@ func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.Authorization return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented") } +func (am *MockAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error { + if am.CheckUserAccessByJWTGroupsFunc != nil { + return am.CheckUserAccessByJWTGroupsFunc(claims) + } + return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented") +} + // GetPeers mocks GetPeers of the AccountManager interface func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) { - if am.GetAccountFromTokenFunc != nil { + if am.GetPeersFunc != nil { return am.GetPeersFunc(accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetPeers is not implemented") diff --git a/management/server/policy.go b/management/server/policy.go index 0eb2fb5385c..d7e27a1b556 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -493,7 +493,11 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string) ([] for _, p := range group.Peers { peer, ok := account.Peers[p] - if ok && peer != nil && peer.ID == peerID { + if !ok || peer == nil { + continue + } + + if peer.ID == peerID { peerInGroups = true continue }