From 389c9619afe8b6d129a12d137bf332491503bb83 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 00:31:41 +0300 Subject: [PATCH 01/60] Refactor setup key handling to use store methods Signed-off-by: bcmmbaga --- management/server/setupkey.go | 179 +++++++++++++++++----------- management/server/sql_store.go | 83 ++++++++----- management/server/sql_store_test.go | 4 +- management/server/status/error.go | 15 ++- management/server/store.go | 7 +- 5 files changed, 178 insertions(+), 110 deletions(-) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 43b6e02c936..f54eafdc1fd 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -4,7 +4,6 @@ import ( "context" "crypto/sha256" b64 "encoding/base64" - "fmt" "hash/fnv" "strconv" "strings" @@ -12,9 +11,8 @@ import ( "unicode/utf8" "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" ) @@ -226,34 +224,49 @@ func Hash(s string) uint32 { // and adds it to the specified account. A list of autoGroups IDs can be empty. func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil { - return nil, err + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - setupKey, plainKey := GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) - account.SetupKeys[setupKey.Key] = setupKey - err = am.Store.SaveAccount(ctx, account) + var accountGroups []*nbgroup.Group + var setupKey *SetupKey + var plainKey string + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + accountGroups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + if err = validateSetupKeyAutoGroups(accountGroups, autoGroups); err != nil { + return err + } + + setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) + setupKey.AccountID = accountID + + return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey) + }) if err != nil { - return nil, status.Errorf(status.Internal, "failed adding account key") + return nil, err } am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta()) + groupMap := make(map[string]*nbgroup.Group, len(accountGroups)) + for _, g := range accountGroups { + groupMap[g.ID] = g + } for _, g := range setupKey.AutoGroups { - group := account.GetGroup(g) - if group != nil { + group, ok := groupMap[g] + if ok { am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name}) - } else { - log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) } } @@ -268,43 +281,48 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s // (e.g. the key itself, creation date, ID, etc). // These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - if keyToSave == nil { return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") } - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + var accountGroups []*nbgroup.Group var oldKey *SetupKey - for _, key := range account.SetupKeys { - if key.Id == keyToSave.Id { - oldKey = key.Copy() - break + var newKey *SetupKey + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + accountGroups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + return err } - } - if oldKey == nil { - return nil, status.Errorf(status.NotFound, "setup key not found") - } - if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil { - return nil, err - } + if err = validateSetupKeyAutoGroups(accountGroups, keyToSave.AutoGroups); err != nil { + return err + } - // only auto groups, revoked status, and name can be updated for now - newKey := oldKey.Copy() - newKey.Name = keyToSave.Name - newKey.AutoGroups = keyToSave.AutoGroups - newKey.Revoked = keyToSave.Revoked - newKey.UpdatedAt = time.Now().UTC() + oldKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id) + if err != nil { + return err + } - account.SetupKeys[newKey.Key] = newKey + // only auto groups, revoked status, and name can be updated for now + newKey = oldKey.Copy() + newKey.Name = keyToSave.Name + newKey.AutoGroups = keyToSave.AutoGroups + newKey.Revoked = keyToSave.Revoked + newKey.UpdatedAt = time.Now().UTC() - if err = am.Store.SaveAccount(ctx, account); err != nil { + return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, newKey) + }) + if err != nil { return nil, err } @@ -315,24 +333,25 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str defer func() { addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) + + groupMap := make(map[string]*nbgroup.Group, len(accountGroups)) + for _, g := range accountGroups { + groupMap[g.ID] = g + } + for _, g := range removedGroups { - group := account.GetGroup(g) - if group != nil { + group, ok := groupMap[g] + if ok { am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) - } else { - log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) } - } for _, g := range addedGroups { - group := account.GetGroup(g) - if group != nil { + group, ok := groupMap[g] + if ok { am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) - } else { - log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) } } }() @@ -347,16 +366,15 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.NewUnauthorizedToViewSetupKeysError() + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) - if err != nil { - return nil, err + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } - return setupKeys, nil + return am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) } // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. @@ -366,8 +384,12 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.NewUnauthorizedToViewSetupKeysError() + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) @@ -387,21 +409,29 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return fmt.Errorf("failed to get user: %w", err) + return err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return status.NewUnauthorizedToViewSetupKeysError() + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - deletedSetupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) - if err != nil { - return fmt.Errorf("failed to get setup key: %w", err) + if user.IsRegularUser() { + return status.NewAdminPermissionError() } - err = am.Store.DeleteSetupKey(ctx, accountID, keyID) + var deletedSetupKey *SetupKey + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) + if err != nil { + return err + } + + return transaction.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID) + }) if err != nil { - return fmt.Errorf("failed to delete setup key: %w", err) + return err } am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta()) @@ -409,15 +439,22 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return nil } -func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error { - for _, group := range autoGroups { - g, ok := account.Groups[group] - if !ok { - return status.Errorf(status.NotFound, "group %s doesn't exist", group) +func validateSetupKeyAutoGroups(groups []*nbgroup.Group, autoGroups []string) error { + groupMap := make(map[string]*nbgroup.Group, len(groups)) + for _, g := range groups { + groupMap[g.ID] = g + } + + for _, groupID := range autoGroups { + g, exists := groupMap[groupID] + if !exists { + return status.Errorf(status.NotFound, "group %s doesn't exist", groupID) } + if g.Name == "All" { - return status.Errorf(status.InvalidArgument, "can't add All group to the setup key") + return status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key") } } + return nil } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 646184578eb..a11370e4f9d 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -633,11 +633,11 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*Us return users, nil } -func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { +func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) { startTime := time.Now() var groups []*nbgroup.Group - result := s.db.Find(&groups, accountIDCondition, accountID) + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -645,8 +645,8 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*n if errors.Is(result.Error, context.Canceled) { return nil, status.NewStoreContextCanceledError(time.Since(startTime)) } - log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting groups from store") + log.WithContext(ctx).Errorf("failed to get account groups from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get account groups from the store") } return groups, nil @@ -1404,12 +1404,59 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt // GetAccountSetupKeys retrieves setup keys for an account. func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { - return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID) + var setupKeys []*SetupKey + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&setupKeys, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get setup keys from store") + } + + return setupKeys, nil } // GetSetupKeyByID retrieves a setup key by its ID and account ID. -func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) { - return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID) +func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) { + var setupKey *SetupKey + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "setup key not found") + } + log.WithContext(ctx).Errorf("failed to get setup key from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get setup key from store") + } + + return setupKey, nil +} + +// SaveSetupKey saves a setup key to the database. +func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error { + result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}). + Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error) + return status.Errorf(status.Internal, "failed to save setup key to store") + } + + return nil +} + +// DeleteSetupKey deletes a setup key from the database. +func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error) + return status.Errorf(status.Internal, "failed to delete setup key from store") + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "setup key not found") + } + + return nil } // GetAccountNameServerGroups retrieves name server groups for an account. @@ -1422,10 +1469,6 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID) } -func (s *SqlStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error { - return deleteRecordByID[SetupKey](s.db.WithContext(ctx), LockingStrengthUpdate, keyID, accountID) -} - // getRecords retrieves records from the database based on the account ID. func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { var record []T @@ -1458,21 +1501,3 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a } return &record, nil } - -// deleteRecordByID deletes a record by its ID and account ID from the database. -func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) error { - var record T - result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(record, accountAndIDQueryCondition, accountID, recordID) - if err := result.Error; err != nil { - parts := strings.Split(fmt.Sprintf("%T", record), ".") - recordType := parts[len(parts)-1] - - return status.Errorf(status.Internal, "failed to delete %s from store: %v", recordType, err) - } - - if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "record not found") - } - - return nil -} diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index b371e231319..3f3b2a453d4 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1274,7 +1274,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" - err = store.DeleteSetupKey(context.Background(), accountID, setupKeyID) + err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, setupKeyID) require.NoError(t, err) _, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID) @@ -1290,6 +1290,6 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" nonExistingKeyID := "non-existing-key-id" - err = store.DeleteSetupKey(context.Background(), accountID, nonExistingKeyID) + err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID) require.Error(t, err) } diff --git a/management/server/status/error.go b/management/server/status/error.go index a145edf8002..5a75c94b1c1 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -111,11 +111,21 @@ func NewGetAccountFromStoreError(err error) error { return Errorf(Internal, "issue getting account from store: %s", err) } +// NewUserNotPartOfAccountError creates a new Error with PermissionDenied type for a user not being part of an account +func NewUserNotPartOfAccountError() error { + return Errorf(PermissionDenied, "user is not part of this account") +} + // NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store func NewGetUserFromStoreError() error { return Errorf(Internal, "issue getting user from store") } +// NewAdminPermissionError creates a new Error with PermissionDenied type for actions requiring admin role. +func NewAdminPermissionError() error { + return Errorf(PermissionDenied, "admin role required to perform this action") +} + // NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context func NewStoreContextCanceledError(duration time.Duration) error { return Errorf(Internal, "store access: context canceled after %v", duration) @@ -125,8 +135,3 @@ func NewStoreContextCanceledError(duration time.Duration) error { func NewInvalidKeyIDError() error { return Errorf(InvalidArgument, "invalid key ID") } - -// NewUnauthorizedToViewSetupKeysError creates a new Error with Unauthorized type for an issue getting a setup key -func NewUnauthorizedToViewSetupKeysError() error { - return Errorf(Unauthorized, "only users with admin power can view setup keys") -} diff --git a/management/server/store.go b/management/server/store.go index 087c9884763..73c9ef6a692 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -70,7 +70,7 @@ type Store interface { DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error - GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) + GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error @@ -96,7 +96,9 @@ type Store interface { GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) - GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) + GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) + SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error + DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) @@ -124,7 +126,6 @@ type Store interface { // This is also a method of metrics.DataSource interface. GetStoreEngine() StoreEngine ExecuteInTransaction(ctx context.Context, f func(store Store) error) error - DeleteSetupKey(ctx context.Context, accountID, keyID string) error } type StoreEngine string From 78044c226d9240edcdd5bb180aaab1da86f442e4 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 00:32:14 +0300 Subject: [PATCH 02/60] add lock to get account groups Signed-off-by: bcmmbaga --- management/server/account.go | 4 ++-- management/server/account_test.go | 2 +- management/server/group.go | 2 +- management/server/peer.go | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index aa7609388c0..583853f2504 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2029,7 +2029,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error getting user: %w", err) } - groups, err := transaction.GetAccountGroups(ctx, accountID) + groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } @@ -2059,7 +2059,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - groups, err = transaction.GetAccountGroups(ctx, accountID) + groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } diff --git a/management/server/account_test.go b/management/server/account_test.go index 6a2d85fe8f7..fdf004a3b8a 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2773,7 +2773,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID") + groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID") assert.NoError(t, err) assert.Len(t, groups, 3, "new group3 should be added") diff --git a/management/server/group.go b/management/server/group.go index bdb569e377f..b2ec88cc0d2 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -59,7 +59,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us return nil, err } - return am.Store.GetAccountGroups(ctx, accountID) + return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers diff --git a/management/server/peer.go b/management/server/peer.go index 9c5ab571bab..8ced2a1deb0 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -765,7 +765,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } } - groups, err := am.Store.GetAccountGroups(ctx, accountID) + groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return nil, nil, nil, err } From 1a5f3c653c4b78a5c52bca9bba74c966fbd7495c Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 00:37:47 +0300 Subject: [PATCH 03/60] add check for regular user Signed-off-by: bcmmbaga --- management/server/user.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/management/server/user.go b/management/server/user.go index 9fdd3a6eeea..1368b76b121 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -103,6 +103,11 @@ func (u *User) IsAdminOrServiceUser() bool { return u.HasAdminPower() || u.IsServiceUser } +// IsRegularUser checks if the user is a regular user. +func (u *User) IsRegularUser() bool { + return !u.HasAdminPower() && !u.IsServiceUser +} + // ToUserInfo converts a User object to a UserInfo object. func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { autoGroups := u.AutoGroups From 931521d505b012f45dcf5bcb5de0ee07f0c5b876 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 00:59:37 +0300 Subject: [PATCH 04/60] get only required groups for auto-group validation Signed-off-by: bcmmbaga --- management/server/group/group.go | 5 ++++ management/server/setupkey.go | 46 +++++++++++++------------------- 2 files changed, 23 insertions(+), 28 deletions(-) diff --git a/management/server/group/group.go b/management/server/group/group.go index d293e1afc6f..e98e5ecc4b5 100644 --- a/management/server/group/group.go +++ b/management/server/group/group.go @@ -49,3 +49,8 @@ func (g *Group) Copy() *Group { func (g *Group) HasPeers() bool { return len(g.Peers) > 0 } + +// IsGroupAll checks if the group is a default "All" group. +func (g *Group) IsGroupAll() bool { + return g.Name == "All" +} diff --git a/management/server/setupkey.go b/management/server/setupkey.go index f54eafdc1fd..da248be25d6 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -233,20 +233,16 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s return nil, status.NewUserNotPartOfAccountError() } - var accountGroups []*nbgroup.Group + var groups []*nbgroup.Group var setupKey *SetupKey var plainKey string err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - accountGroups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups) if err != nil { return err } - if err = validateSetupKeyAutoGroups(accountGroups, autoGroups); err != nil { - return err - } - setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) setupKey.AccountID = accountID @@ -257,8 +253,8 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s } am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta()) - groupMap := make(map[string]*nbgroup.Group, len(accountGroups)) - for _, g := range accountGroups { + groupMap := make(map[string]*nbgroup.Group, len(groups)) + for _, g := range groups { groupMap[g.ID] = g } @@ -294,20 +290,16 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.NewUserNotPartOfAccountError() } - var accountGroups []*nbgroup.Group + var groups []*nbgroup.Group var oldKey *SetupKey var newKey *SetupKey err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - accountGroups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups) if err != nil { return err } - if err = validateSetupKeyAutoGroups(accountGroups, keyToSave.AutoGroups); err != nil { - return err - } - oldKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id) if err != nil { return err @@ -334,8 +326,8 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) - groupMap := make(map[string]*nbgroup.Group, len(accountGroups)) - for _, g := range accountGroups { + groupMap := make(map[string]*nbgroup.Group, len(groups)) + for _, g := range groups { groupMap[g.ID] = g } @@ -439,22 +431,20 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return nil } -func validateSetupKeyAutoGroups(groups []*nbgroup.Group, autoGroups []string) error { - groupMap := make(map[string]*nbgroup.Group, len(groups)) - for _, g := range groups { - groupMap[g.ID] = g - } +func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) ([]*nbgroup.Group, error) { + autoGroups := make([]*nbgroup.Group, 0, len(autoGroupIDs)) - for _, groupID := range autoGroups { - g, exists := groupMap[groupID] - if !exists { - return status.Errorf(status.NotFound, "group %s doesn't exist", groupID) + for _, groupID := range autoGroupIDs { + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) + if err != nil { + return nil, err } - if g.Name == "All" { - return status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key") + if group.IsGroupAll() { + return nil, status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key") } + autoGroups = append(autoGroups, group) } - return nil + return autoGroups, nil } From f8b5eedd382d8a218517cf7c7b552f3a0dd8ee3d Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 10:14:13 +0300 Subject: [PATCH 05/60] add account lock and return auto groups map on validation Signed-off-by: bcmmbaga --- management/server/setupkey.go | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index da248be25d6..65d7796f1a0 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -224,6 +224,9 @@ func Hash(s string) uint32 { // and adds it to the specified account. A list of autoGroups IDs can be empty. func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err @@ -233,7 +236,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s return nil, status.NewUserNotPartOfAccountError() } - var groups []*nbgroup.Group + var groups map[string]*nbgroup.Group var setupKey *SetupKey var plainKey string @@ -253,13 +256,9 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s } am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta()) - groupMap := make(map[string]*nbgroup.Group, len(groups)) - for _, g := range groups { - groupMap[g.ID] = g - } for _, g := range setupKey.AutoGroups { - group, ok := groupMap[g] + group, ok := groups[g] if ok { am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name}) @@ -281,6 +280,9 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") } + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err @@ -290,7 +292,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.NewUserNotPartOfAccountError() } - var groups []*nbgroup.Group + var groups map[string]*nbgroup.Group var oldKey *SetupKey var newKey *SetupKey @@ -326,13 +328,8 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) - groupMap := make(map[string]*nbgroup.Group, len(groups)) - for _, g := range groups { - groupMap[g.ID] = g - } - for _, g := range removedGroups { - group, ok := groupMap[g] + group, ok := groups[g] if ok { am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) @@ -340,7 +337,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str } for _, g := range addedGroups { - group, ok := groupMap[g] + group, ok := groups[g] if ok { am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) @@ -431,8 +428,8 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return nil } -func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) ([]*nbgroup.Group, error) { - autoGroups := make([]*nbgroup.Group, 0, len(autoGroupIDs)) +func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) (map[string]*nbgroup.Group, error) { + autoGroups := map[string]*nbgroup.Group{} for _, groupID := range autoGroupIDs { group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) @@ -443,7 +440,7 @@ func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountI if group.IsGroupAll() { return nil, status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key") } - autoGroups = append(autoGroups, group) + autoGroups[group.ID] = group } return autoGroups, nil From 106fc759365d535db529d93c9c1ad0324b2ccff6 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 18:38:32 +0300 Subject: [PATCH 06/60] refactor account peers update Signed-off-by: bcmmbaga --- management/server/account.go | 25 ++++++++++++------------- management/server/dns.go | 2 +- management/server/nameserver.go | 6 +++--- management/server/peer.go | 23 ++++++++++++++--------- management/server/peer_test.go | 2 +- management/server/policy.go | 4 ++-- management/server/posture_checks.go | 2 +- management/server/route.go | 6 +++--- management/server/user.go | 8 ++++---- 9 files changed, 41 insertions(+), 37 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 583853f2504..2b18c344101 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -110,7 +110,6 @@ type AccountManager interface { SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error - ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error) GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) @@ -1435,7 +1434,7 @@ func isNil(i idp.Manager) bool { // addAccountIDToIDPAppMeta update user's app metadata in idp manager func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) { - accountUsers, err := am.Store.GetAccountUsers(ctx, accountID) + accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) if err != nil { return err } @@ -2083,7 +2082,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error saving groups: %w", err) } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf("error incrementing network serial: %w", err) } } @@ -2127,14 +2126,19 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } if settings.GroupsPropagationEnabled { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + removedGroupAffectsPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, removeOldGroups) if err != nil { - return fmt.Errorf("error getting account: %w", err) + return err } - if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) { + newGroupsAffectsPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, addNewGroups) + if err != nil { + return err + } + + if removedGroupAffectsPeers || newGroupsAffectsPeers { log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } } @@ -2398,12 +2402,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID) - updatedAccount, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err) - return - } - am.updateAccountPeers(ctx, updatedAccount) + am.updateAccountPeers(ctx, accountID) } func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { diff --git a/management/server/dns.go b/management/server/dns.go index 256b8b12512..4551be5ab92 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -146,7 +146,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 5ebd263dcc2..957008714e5 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -71,7 +71,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco } if anyGroupHasPeers(account, newNSGroup.Groups) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) @@ -106,7 +106,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun } if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) @@ -136,7 +136,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco } if anyGroupHasPeers(account, nsGroup.Groups) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) diff --git a/management/server/peer.go b/management/server/peer.go index 8ced2a1deb0..994cc02879c 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -131,7 +131,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK if expired { // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } return nil @@ -267,7 +267,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } if peerLabelUpdated || requiresPeerUpdates { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return peer, nil @@ -344,7 +344,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil @@ -551,7 +551,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return fmt.Errorf("failed to add peer to account: %w", err) } - err = transaction.IncrementNetworkSerial(ctx, accountID) + err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } @@ -597,7 +597,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s groupsToAdd = append(groupsToAdd, allGroup.ID) if areGroupChangesAffectPeers(account, groupsToAdd) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } approvedPeersMap, err := am.GetValidatedPeers(account) @@ -661,7 +661,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } if sync.UpdateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } } @@ -680,7 +680,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } if isStatusChanged { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } validPeersMap, err := am.GetValidatedPeers(account) @@ -811,7 +811,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } if updateRemotePeers || isStatusChanged { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer) @@ -974,7 +974,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, // updateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. -func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) { +func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) { start := time.Now() defer func() { if am.metrics != nil { @@ -982,6 +982,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account } }() + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err) + return + } peers := account.GetPeers() approvedPeersMap, err := am.GetValidatedPeers(account) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 78885ea1b72..4e2dcb2c313 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -877,7 +877,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { start := time.Now() for i := 0; i < b.N; i++ { - manager.updateAccountPeers(ctx, account) + manager.updateAccountPeers(ctx, account.Id) } duration := time.Since(start) diff --git a/management/server/policy.go b/management/server/policy.go index 43a925f8850..8a5733f011c 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -377,7 +377,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil @@ -406,7 +406,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) if anyGroupHasPeers(account, policy.ruleGroups()) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 2dccd8f590c..096cff3f5c9 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -69,7 +69,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/route.go b/management/server/route.go index 1cf00b37c46..dcf2cb0d32c 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -238,7 +238,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri } if isRouteChangeAffectPeers(account, &newRoute) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) @@ -324,7 +324,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI } if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) @@ -356,7 +356,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) if isRouteChangeAffectPeers(account, routy) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/user.go b/management/server/user.go index 1368b76b121..38b820cb41b 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -492,7 +492,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } return nil @@ -833,7 +833,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } for _, storeEvent := range eventsToStore { @@ -1124,7 +1124,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service am.peersUpdateManager.CloseChannels(ctx, peerIDs) - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } return nil } @@ -1232,7 +1232,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account } if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } for targetUserID, meta := range deletedUsersMeta { From 0a70e4c5d45292223c78427984fb470aaf0a9a40 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 18:39:36 +0300 Subject: [PATCH 07/60] Refactor groups to use store methods Signed-off-by: bcmmbaga --- management/server/group.go | 390 ++++++++++++------ management/server/integrated_validator.go | 27 +- management/server/mock_server/account_mock.go | 9 - management/server/sql_store.go | 81 +++- management/server/store.go | 7 +- 5 files changed, 355 insertions(+), 159 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index b2ec88cc0d2..da4c0fb9415 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -37,8 +37,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco return err } - if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, "groups are blocked for users") + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() && settings.RegularUsersViewBlocked { + return status.NewAdminPermissionError() } return nil @@ -49,8 +53,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - - return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) + return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) } // GetAllGroups returns all groups in an account @@ -58,13 +61,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { - return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID) + return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName) } // SaveGroup object of the peers @@ -78,12 +80,19 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error { - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - var eventsToStore []func() + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + var ( + eventsToStore []func() + groupsToSave []*nbgroup.Group + ) for _, newGroup := range newGroups { if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { @@ -91,7 +100,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { - existingGroup, err := account.FindGroupByName(newGroup.Name) + existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name) if err != nil { s, ok := status.FromError(err) if !ok || s.ErrorType != status.NotFound { @@ -109,15 +118,15 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } for _, peerID := range newGroup.Peers { - if account.Peers[peerID] == nil { + if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID); err != nil { return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) } } - oldGroup := account.Groups[newGroup.ID] - account.Groups[newGroup.ID] = newGroup + newGroup.AccountID = accountID + groupsToSave = append(groupsToSave, newGroup) - events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account) + events := am.prepareGroupEvents(ctx, userID, accountID, newGroup) eventsToStore = append(eventsToStore, events...) } @@ -126,30 +135,45 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user newGroupIDs = append(newGroupIDs, newGroup.ID) } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, newGroupIDs) + if err != nil { return err } - if areGroupChangesAffectPeers(account, newGroupIDs) { - am.updateAccountPeers(ctx, account) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave); err != nil { + return fmt.Errorf("failed to save groups: %w", err) + } + return nil + }) + if err != nil { + return err } for _, storeEvent := range eventsToStore { storeEvent() } + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) + } + return nil } // prepareGroupEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() { +func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup *nbgroup.Group) []func() { var eventsToStore []func() addedPeers := make([]string, 0) removedPeers := make([]string, 0) - if oldGroup != nil { + oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID) + if err == nil && oldGroup != nil { addedPeers = difference(newGroup.Peers, oldGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers) } else { @@ -159,12 +183,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID }) } - for _, p := range addedPeers { - peer := account.Peers[p] - if peer == nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) + for _, peerID := range addedPeers { + peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + if err != nil { + log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID) continue } + peerCopy := peer // copy to avoid closure issues eventsToStore = append(eventsToStore, func() { am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer, @@ -175,12 +200,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID }) } - for _, p := range removedPeers { - peer := account.Peers[p] - if peer == nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) + for _, peerID := range removedPeers { + peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + if err != nil { + log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID) continue } + peerCopy := peer // copy to avoid closure issues eventsToStore = append(eventsToStore, func() { am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer, @@ -210,119 +236,108 @@ func difference(a, b []string) []string { } // DeleteGroup object of the peers. -func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountId) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountId) +func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - group, ok := account.Groups[groupID] - if !ok { - return nil + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - allGroup, err := account.GetGroupAll() + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) if err != nil { return err } - if allGroup.ID == groupID { + if group.Name == "All" { return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") } - if err = validateDeleteGroup(account, group, userId); err != nil { + if err = am.validateDeleteGroup(ctx, group, userID); err != nil { return err } - delete(account.Groups, groupID) - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + if err = transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID); err != nil { + return fmt.Errorf("failed to delete group: %w", err) + } + return nil + }) + if err != nil { return err } - am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta()) + am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta()) return nil } // DeleteGroups deletes groups from an account. -// Note: This function does not acquire the global lock. -// It is the caller's responsibility to ensure proper locking is in place before invoking this method. -// -// If an error occurs while deleting a group, the function skips it and continues deleting other groups. -// Errors are collected and returned at the end. -func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error { - account, err := am.Store.GetAccount(ctx, accountId) +func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - var allErrors error + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + var ( + allErrors error + groupIDsToDelete []string + deletedGroups []*nbgroup.Group + ) - deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs)) for _, groupID := range groupIDs { - group, ok := account.Groups[groupID] - if !ok { + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + if err != nil { continue } - if err := validateDeleteGroup(account, group, userId); err != nil { + if err := am.validateDeleteGroup(ctx, group, userID); err != nil { allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) continue } - delete(account.Groups, groupID) + groupIDsToDelete = append(groupIDsToDelete, groupID) deletedGroups = append(deletedGroups, group) } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - for _, g := range deletedGroups { - am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta()) - } - - return allErrors -} - -// ListGroups objects of the peers -func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } - account, err := am.Store.GetAccount(ctx, accountID) + if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil { + return fmt.Errorf("failed to delete group: %w", err) + } + return nil + }) if err != nil { - return nil, err + return err } - groups := make([]*nbgroup.Group, 0, len(account.Groups)) - for _, item := range account.Groups { - groups = append(groups, item) + for _, group := range deletedGroups { + am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta()) } - return groups, nil + return allErrors } // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) if err != nil { return err } - group, ok := account.Groups[groupID] - if !ok { - return status.Errorf(status.NotFound, "group with ID %s not found", groupID) - } - add := true for _, itemID := range group.Peers { if itemID == peerID { @@ -334,13 +349,27 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr group.Peers = append(group.Peers, peerID) } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID}) + if err != nil { + return err + } + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil { + return fmt.Errorf("failed to save group: %w", err) + } + return nil + }) + if err != nil { return err } - if areGroupChangesAffectPeers(account, []string{group.ID}) { - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) } return nil @@ -348,41 +377,55 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) if err != nil { return err } - group, ok := account.Groups[groupID] - if !ok { - return status.Errorf(status.NotFound, "group with ID %s not found", groupID) - } - - account.Network.IncSerial() + updated := false for i, itemID := range group.Peers { if itemID == peerID { group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) - if err := am.Store.SaveAccount(ctx, account); err != nil { - return err - } + updated = true + break + } + } + + if !updated { + return nil + } + + updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID}) + if err != nil { + return err + } + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil { + return fmt.Errorf("failed to save group: %w", err) } + return nil + }) + if err != nil { + return err } - if areGroupChangesAffectPeers(account, []string{group.ID}) { - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) } return nil } -func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error { +func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group *nbgroup.Group, userID string) error { // disable a deleting integration group if the initiator is not an admin service user if group.Issued == nbgroup.GroupIssuedIntegration { - executingUser := account.Users[userID] - if executingUser == nil { + executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { return status.Errorf(status.NotFound, "user not found") } if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { @@ -390,32 +433,42 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) } } - if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked { + if isLinked, linkedRoute := am.isGroupLinkedToRoute(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } - if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked { + if isLinked, linkedDns := am.isGroupLinkedToDns(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"name server groups", linkedDns.Name} } - if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked { + if isLinked, linkedPolicy := am.isGroupLinkedToPolicy(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"policy", linkedPolicy.Name} } - if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked { + if isLinked, linkedSetupKey := am.isGroupLinkedToSetupKey(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"setup key", linkedSetupKey.Name} } - if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked { + if isLinked, linkedUser := am.isGroupLinkedToUser(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"user", linkedUser.Id} } - if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) { + dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) + if err != nil { + return err + } + + if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) { return &GroupLinkError{"disabled DNS management groups", group.Name} } - if account.Settings.Extra != nil { - if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) { + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID) + if err != nil { + return err + } + + if settings.Extra != nil { + if slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) { return &GroupLinkError{"integrated validator", group.Name} } } @@ -424,17 +477,30 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) } // isGroupLinkedToRoute checks if a group is linked to any route in the account. -func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { +func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, accountID string, groupID string) (bool, *route.Route) { + routes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) + return false, nil + } + for _, r := range routes { if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { return true, r } } + return false, nil } // isGroupLinkedToPolicy checks if a group is linked to any policy in the account. -func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { +func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, accountID string, groupID string) (bool, *Policy) { + policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) + return false, nil + } + for _, policy := range policies { for _, rule := range policy.Rules { if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { @@ -446,7 +512,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { } // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. -func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) { +func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { + nameServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) + return false, nil + } + for _, dns := range nameServerGroups { for _, g := range dns.Groups { if g == groupID { @@ -454,11 +526,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou } } } + return false, nil } // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. -func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) { +func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, accountID string, groupID string) (bool, *SetupKey) { + setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) + return false, nil + } + for _, setupKey := range setupKeys { if slices.Contains(setupKey.AutoGroups, groupID) { return true, setupKey @@ -468,7 +547,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo } // isGroupLinkedToUser checks if a group is linked to any user in the account. -func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { +func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, accountID string, groupID string) (bool, *User) { + users, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) + return false, nil + } + for _, user := range users { if slices.Contains(user.AutoGroups, groupID) { return true, user @@ -477,6 +562,69 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { return false, nil } +// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers. +func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context, accountID string, groupIDs []string) (bool, error) { + if len(groupIDs) == 0 { + return false, nil + } + + dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return false, err + } + + for _, groupID := range groupIDs { + if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) { + return true, nil + } + if linked, _ := am.isGroupLinkedToDns(ctx, accountID, groupID); linked { + return true, nil + } + if linked, _ := am.isGroupLinkedToPolicy(ctx, accountID, groupID); linked { + return true, nil + } + if linked, _ := am.isGroupLinkedToRoute(ctx, accountID, groupID); linked { + return true, nil + } + } + + return false, nil +} + +// isGroupLinkedToRoute checks if a group is linked to any route in the account. +func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { + for _, r := range routes { + if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { + return true, r + } + } + return false, nil +} + +// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. +func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { + for _, policy := range policies { + for _, rule := range policy.Rules { + if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { + return true, policy + } + } + } + return false, nil +} + +// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. +func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) { + for _, dns := range nameServerGroups { + for _, g := range dns.Groups { + if g == groupID { + return true, dns + } + } + } + return false, nil +} + // anyGroupHasPeers checks if any of the given groups in the account have peers. func anyGroupHasPeers(account *Account, groupIDs []string) bool { for _, groupID := range groupIDs { diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 99e6b204c2b..0c70b702a01 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -52,25 +52,22 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con return am.Store.SaveAccount(ctx, a) } -func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) { - if len(groups) == 0 { +func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) { + if len(groupIDs) == 0 { return true, nil } - accountsGroups, err := am.ListGroups(ctx, accountId) - if err != nil { - return false, err - } - for _, group := range groups { - var found bool - for _, accountGroup := range accountsGroups { - if accountGroup.ID == group { - found = true - break + + err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + for _, groupID := range groupIDs { + _, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + if err != nil { + return err } } - if !found { - return false, nil - } + return nil + }) + if err != nil { + return false, err } return true, nil diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index d7139bb2a5f..aa6a47b152e 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -45,7 +45,6 @@ type MockAccountManager struct { SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error - ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error) GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error @@ -354,14 +353,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented") } -// ListGroups mock implementation of ListGroups from server.AccountManager interface -func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) { - if am.ListGroupsFunc != nil { - return am.ListGroupsFunc(ctx, accountID) - } - return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented") -} - // GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { if am.GroupAddPeerFunc != nil { diff --git a/management/server/sql_store.go b/management/server/sql_store.go index a11370e4f9d..506142453e6 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -614,11 +614,11 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } -func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { +func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) { startTime := time.Now() var users []*User - result := s.db.Find(&users, accountIDCondition, accountID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -1240,10 +1240,27 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro return nil } -func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { +// GetPeerByID retrieves a peer by its ID and account ID. +func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) { + var peer *nbpeer.Peer + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&peer, accountAndIDQueryCondition, accountID, peerID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "peer not found") + } + log.WithContext(ctx).Errorf("failed to get peer from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peer from store") + } + + return peer, nil +} + +func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { startTime := time.Now() - result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { if errors.Is(result.Error, context.Canceled) { return status.NewStoreContextCanceledError(time.Since(startTime)) @@ -1336,42 +1353,82 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength } // GetGroupByID retrieves a group by ID and account ID. -func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) { - return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID) +func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) { + var group *nbgroup.Group + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "group not found") + } + log.WithContext(ctx).Errorf("failed to get group from store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get group from store") + } + + return group, nil } // GetGroupByName retrieves a group by name and account ID. -func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { +func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) { var group nbgroup.Group // TODO: This fix is accepted for now, but if we need to handle this more frequently // we may need to reconsider changing the types. - query := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations) + query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations) if s.storeEngine == PostgresStoreEngine { query = query.Order("json_array_length(peers::json) DESC") } else { query = query.Order("json_array_length(peers) DESC") } - result := query.First(&group, "name = ? and account_id = ?", groupName, accountID) + result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "group not found") } - return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error) + log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get group by name from store") } return &group, nil } // SaveGroup saves a group to the store. func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) if result.Error != nil { - return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error) + log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save group to store") } return nil } +// DeleteGroup deletes a group from the database. +func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error) + return status.Errorf(status.Internal, "failed to delete group from store") + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "group not found") + } + + return nil +} + +// DeleteGroups deletes groups from the database. +func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { + result := s.db.Clauses(clause.Locking{Strength: string(strength)}). + Delete(&nbgroup.Group{}, " account_id = ? AND id IN ?", accountID, groupIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error) + } + + return nil +} + // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) diff --git a/management/server/store.go b/management/server/store.go index 73c9ef6a692..cb3c533dd09 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -62,7 +62,7 @@ type Store interface { GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) - GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) + GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) SaveUsers(accountID string, users map[string]*User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error @@ -75,6 +75,8 @@ type Store interface { GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error + DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error + DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) @@ -89,6 +91,7 @@ type Store interface { AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) + GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error @@ -107,7 +110,7 @@ type Store interface { GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) - IncrementNetworkSerial(ctx context.Context, accountId string) error + IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) GetInstallationID() string From 8126d953166ddfa79950469f42d0a8dc5084ce71 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 18:58:04 +0300 Subject: [PATCH 08/60] refactor GetGroupByID and add NewGroupNotFoundError Signed-off-by: bcmmbaga --- management/server/account.go | 4 ++-- management/server/setupkey.go | 2 +- management/server/sql_store.go | 8 ++++---- management/server/status/error.go | 5 +++++ 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 2b18c344101..2902bc9521c 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2100,7 +2100,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range addNewGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2113,7 +2113,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range removeOldGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 65d7796f1a0..a3330bba888 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -432,7 +432,7 @@ func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountI autoGroups := map[string]*nbgroup.Group{} for _, groupID := range autoGroupIDs { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) if err != nil { return nil, err } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 506142453e6..3707aa9cec1 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1196,7 +1196,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return status.Errorf(status.NotFound, "group not found for account") + return status.NewGroupNotFoundError(groupID) } if errors.Is(result.Error, context.Canceled) { return status.NewStoreContextCanceledError(time.Since(startTime)) @@ -1358,7 +1358,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "group not found") + return nil, status.NewGroupNotFoundError(groupID) } log.WithContext(ctx).Errorf("failed to get group from store: %s", err) return nil, status.Errorf(status.Internal, "failed to get group from store") @@ -1383,7 +1383,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "group not found") + return nil, status.NewGroupNotFoundError(groupName) } log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error) return nil, status.Errorf(status.Internal, "failed to get group by name from store") @@ -1411,7 +1411,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength } if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "group not found") + return status.NewGroupNotFoundError(groupID) } return nil diff --git a/management/server/status/error.go b/management/server/status/error.go index 5a75c94b1c1..00be347ada4 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -135,3 +135,8 @@ func NewStoreContextCanceledError(duration time.Duration) error { func NewInvalidKeyIDError() error { return Errorf(InvalidArgument, "invalid key ID") } + +// NewGroupNotFoundError creates a new Error with NotFound type for a missing group +func NewGroupNotFoundError(groupID string) error { + return Errorf(NotFound, "group: %s not found", groupID) +} From ac05f69131651fded5f6a304b7dbe2b517a72b31 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 18:58:19 +0300 Subject: [PATCH 09/60] fix tests Signed-off-by: bcmmbaga --- management/server/account_test.go | 12 +++++++----- management/server/route_test.go | 2 +- management/server/sql_store_test.go | 8 ++++---- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/management/server/account_test.go b/management/server/account_test.go index fdf004a3b8a..97e0d45f016 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1413,11 +1413,13 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - group := group.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - } + }) + + require.NoError(t, err, "failed to save group") policy := Policy{ Enabled: true, @@ -1460,7 +1462,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { return } - if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { + if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil { t.Errorf("delete group: %v", err) return } @@ -2714,7 +2716,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 0) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) @@ -2734,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) diff --git a/management/server/route_test.go b/management/server/route_test.go index 4893e19b9f3..5c848f68c7b 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1091,7 +1091,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") - groups, err := am.ListGroups(context.Background(), account.Id) + groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id) require.NoError(t, err) var groupHA1, groupHA2 *nbgroup.Group for _, group := range groups { diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 3f3b2a453d4..20409798b0e 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1181,7 +1181,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { t.Fatal("failed to save group") return err } - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID) + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID) if err != nil { t.Fatal("failed to get group") return err @@ -1201,7 +1201,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" account, err := store.GetAccount(context.Background(), accountID) require.NoError(t, err) - users, err := store.GetAccountUsers(context.Background(), accountID) + users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID) require.NoError(t, err) require.Len(t, users, len(account.Users)) } @@ -1260,9 +1260,9 @@ func TestSqlite_GetGroupByName(t *testing.T) { } accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID) + group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All") require.NoError(t, err) - require.Equal(t, "All", group.Name) + require.True(t, group.IsGroupAll()) } func Test_DeleteSetupKeySuccessfully(t *testing.T) { From 7100be83cdd002c25b2bf824687f14f1b183f770 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Sat, 9 Nov 2024 01:14:30 +0300 Subject: [PATCH 10/60] Add AddPeer and RemovePeer methods to Group struct Signed-off-by: bcmmbaga --- management/server/group/group.go | 29 +++++++++ management/server/group/group_test.go | 90 +++++++++++++++++++++++++++ 2 files changed, 119 insertions(+) create mode 100644 management/server/group/group_test.go diff --git a/management/server/group/group.go b/management/server/group/group.go index e98e5ecc4b5..bb0f5b7b6e2 100644 --- a/management/server/group/group.go +++ b/management/server/group/group.go @@ -54,3 +54,32 @@ func (g *Group) HasPeers() bool { func (g *Group) IsGroupAll() bool { return g.Name == "All" } + +// AddPeer adds peerID to Peers if not already present, +// returning true if added. +func (g *Group) AddPeer(peerID string) bool { + if peerID == "" { + return false + } + + for _, itemID := range g.Peers { + if itemID == peerID { + return false + } + } + + g.Peers = append(g.Peers, peerID) + return true +} + +// RemovePeer removes peerID from Peers if present, +// returning true if removed. +func (g *Group) RemovePeer(peerID string) bool { + for i, itemID := range g.Peers { + if itemID == peerID { + g.Peers = append(g.Peers[:i], g.Peers[i+1:]...) + return true + } + } + return false +} diff --git a/management/server/group/group_test.go b/management/server/group/group_test.go new file mode 100644 index 00000000000..cb002f8d9e1 --- /dev/null +++ b/management/server/group/group_test.go @@ -0,0 +1,90 @@ +package group + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAddPeer(t *testing.T) { + t.Run("add new peer to empty slice", func(t *testing.T) { + group := &Group{Peers: []string{}} + peerID := "peer1" + assert.True(t, group.AddPeer(peerID)) + assert.Contains(t, group.Peers, peerID) + }) + + t.Run("add new peer to nil slice", func(t *testing.T) { + group := &Group{Peers: nil} + peerID := "peer1" + assert.True(t, group.AddPeer(peerID)) + assert.Contains(t, group.Peers, peerID) + }) + + t.Run("add new peer to non-empty slice", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "peer3" + assert.True(t, group.AddPeer(peerID)) + assert.Contains(t, group.Peers, peerID) + }) + + t.Run("add duplicate peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "peer1" + assert.False(t, group.AddPeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) + + t.Run("add empty peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "" + assert.False(t, group.AddPeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) +} + +func TestRemovePeer(t *testing.T) { + t.Run("remove existing peer from slice", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2", "peer3"}} + peerID := "peer2" + assert.True(t, group.RemovePeer(peerID)) + assert.NotContains(t, group.Peers, peerID) + assert.Equal(t, 2, len(group.Peers)) + }) + + t.Run("remove peer from empty slice", func(t *testing.T) { + group := &Group{Peers: []string{}} + peerID := "peer1" + assert.False(t, group.RemovePeer(peerID)) + assert.Equal(t, 0, len(group.Peers)) + }) + + t.Run("remove peer from nil slice", func(t *testing.T) { + group := &Group{Peers: nil} + peerID := "peer1" + assert.False(t, group.RemovePeer(peerID)) + assert.Nil(t, group.Peers) + }) + + t.Run("remove non-existent peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "peer3" + assert.False(t, group.RemovePeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) + + t.Run("remove peer from single-item slice", func(t *testing.T) { + group := &Group{Peers: []string{"peer1"}} + peerID := "peer1" + assert.True(t, group.RemovePeer(peerID)) + assert.Equal(t, 0, len(group.Peers)) + assert.NotContains(t, group.Peers, peerID) + }) + + t.Run("remove empty peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "" + assert.False(t, group.RemovePeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) +} From 6dc185e141c4e10c64c8879f54a8338fd4e4c01d Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Sat, 9 Nov 2024 01:16:03 +0300 Subject: [PATCH 11/60] Preserve store engine in SqlStore transactions Signed-off-by: bcmmbaga --- management/server/sql_store.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index df0f2b3178b..8a0f432e6ae 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1116,7 +1116,8 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor func (s *SqlStore) withTx(tx *gorm.DB) Store { return &SqlStore{ - db: tx, + db: tx, + storeEngine: s.storeEngine, } } From bdeb95c58c2081b0a77776692398bc6c20be2b60 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Sat, 9 Nov 2024 01:17:01 +0300 Subject: [PATCH 12/60] Run groups ops in transaction Signed-off-by: bcmmbaga --- management/server/account.go | 4 +- management/server/group.go | 384 ++++++++++++++--------------------- management/server/peer.go | 18 +- 3 files changed, 173 insertions(+), 233 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 2902bc9521c..043b797ab41 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2126,12 +2126,12 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } if settings.GroupsPropagationEnabled { - removedGroupAffectsPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, removeOldGroups) + removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups) if err != nil { return err } - newGroupsAffectsPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, addNewGroups) + newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups) if err != nil { return err } diff --git a/management/server/group.go b/management/server/group.go index da4c0fb9415..c49bb247186 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -79,7 +79,7 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI // SaveGroups adds new groups to the account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. -func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error { +func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err @@ -89,66 +89,35 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user return status.NewUserNotPartOfAccountError() } - var ( - eventsToStore []func() - groupsToSave []*nbgroup.Group - ) - - for _, newGroup := range newGroups { - if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { - return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) - } + var eventsToStore []func() + var groupsToSave []*nbgroup.Group + var updateAccountPeers bool - if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { - existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name) - if err != nil { - s, ok := status.FromError(err) - if !ok || s.ErrorType != status.NotFound { - return err - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + groupIDs := make([]string, 0, len(groups)) + for _, newGroup := range groups { + if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { + return err } - // Avoid duplicate groups only for the API issued groups. - // Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of. - if existingGroup != nil { - return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name) - } + newGroup.AccountID = accountID + groupsToSave = append(groupsToSave, newGroup) + groupIDs = append(groupIDs, newGroup.ID) - newGroup.ID = xid.New().String() + events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + eventsToStore = append(eventsToStore, events...) } - for _, peerID := range newGroup.Peers { - if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID); err != nil { - return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) - } + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs) + if err != nil { + return err } - newGroup.AccountID = accountID - groupsToSave = append(groupsToSave, newGroup) - - events := am.prepareGroupEvents(ctx, userID, accountID, newGroup) - eventsToStore = append(eventsToStore, events...) - } - - newGroupIDs := make([]string, 0, len(newGroups)) - for _, newGroup := range newGroups { - newGroupIDs = append(newGroupIDs, newGroup.ID) - } - - updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, newGroupIDs) - if err != nil { - return err - } - - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return err } - if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave); err != nil { - return fmt.Errorf("failed to save groups: %w", err) - } - return nil + return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave) }) if err != nil { return err @@ -166,13 +135,13 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } // prepareGroupEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup *nbgroup.Group) []func() { +func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() { var eventsToStore []func() addedPeers := make([]string, 0) removedPeers := make([]string, 0) - oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID) + oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID) if err == nil && oldGroup != nil { addedPeers = difference(newGroup.Peers, oldGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers) @@ -184,36 +153,34 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID } for _, peerID := range addedPeers { - peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + peer, err := transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) if err != nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID) + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: %v", peerID, err) continue } - peerCopy := peer // copy to avoid closure issues + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer, - map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), - "peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), - }) + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta) }) } for _, peerID := range removedPeers { - peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + peer, err := transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) if err != nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID) + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: %v", peerID, err) continue } - peerCopy := peer // copy to avoid closure issues + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer, - map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), - "peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), - }) + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta) }) } @@ -246,28 +213,27 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use return status.NewUserNotPartOfAccountError() } - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - return err - } + var group *nbgroup.Group - if group.Name == "All" { - return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + group, err = transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + if err != nil { + return err + } - if err = am.validateDeleteGroup(ctx, group, userID); err != nil { - return err - } + if group.IsGroupAll() { + return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") + } - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = validateDeleteGroup(ctx, transaction, group, userID); err != nil { return err } - if err = transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID); err != nil { - return fmt.Errorf("failed to delete group: %w", err) + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err } - return nil + + return transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID) }) if err != nil { return err @@ -279,6 +245,11 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use } // DeleteGroups deletes groups from an account. +// Note: This function does not acquire the global lock. +// It is the caller's responsibility to ensure proper locking is in place before invoking this method. +// +// If an error occurs while deleting a group, the function skips it and continues deleting other groups. +// Errors are collected and returned at the end. func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { @@ -289,36 +260,31 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us return status.NewUserNotPartOfAccountError() } - var ( - allErrors error - groupIDsToDelete []string - deletedGroups []*nbgroup.Group - ) + var allErrors error + var groupIDsToDelete []string + var deletedGroups []*nbgroup.Group - for _, groupID := range groupIDs { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - continue - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + for _, groupID := range groupIDs { + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + if err != nil { + continue + } - if err := am.validateDeleteGroup(ctx, group, userID); err != nil { - allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) - continue - } + if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil { + allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) + continue + } - groupIDsToDelete = append(groupIDsToDelete, groupID) - deletedGroups = append(deletedGroups, group) - } + groupIDsToDelete = append(groupIDsToDelete, groupID) + deletedGroups = append(deletedGroups, group) + } - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return err } - if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil { - return fmt.Errorf("failed to delete group: %w", err) - } - return nil + return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete) }) if err != nil { return err @@ -333,36 +299,30 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - return err - } + var group *nbgroup.Group + var updateAccountPeers bool + var err error - add := true - for _, itemID := range group.Peers { - if itemID == peerID { - add = false - break + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + if err != nil { + return err } - } - if add { - group.Peers = append(group.Peers, peerID) - } - updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID}) - if err != nil { - return err - } + if updated := group.AddPeer(peerID); !updated { + return nil + } - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { return err } - if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil { - return fmt.Errorf("failed to save group: %w", err) + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err } - return nil + + return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) }) if err != nil { return err @@ -377,38 +337,30 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - return err - } + var group *nbgroup.Group + var updateAccountPeers bool + var err error - updated := false - for i, itemID := range group.Peers { - if itemID == peerID { - group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) - updated = true - break + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + if err != nil { + return err } - } - if !updated { - return nil - } + if updated := group.RemovePeer(peerID); !updated { + return nil + } - updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID}) - if err != nil { - return err - } + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return err } - if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil { - return fmt.Errorf("failed to save group: %w", err) - } - return nil + return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) }) if err != nil { return err @@ -421,10 +373,43 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return nil } -func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group *nbgroup.Group, userID string) error { +// validateNewGroup validates the new group for existence and required fields. +func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error { + if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { + return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) + } + + if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { + existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name) + if err != nil { + if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound { + return err + } + } + + // Prevent duplicate groups for API-issued groups. + // Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of. + if existingGroup != nil { + return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name) + } + + newGroup.ID = xid.New().String() + } + + for _, peerID := range newGroup.Peers { + _, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + if err != nil { + return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) + } + } + + return nil +} + +func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error { // disable a deleting integration group if the initiator is not an admin service user if group.Issued == nbgroup.GroupIssuedIntegration { - executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return status.Errorf(status.NotFound, "user not found") } @@ -433,27 +418,27 @@ func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group } } - if isLinked, linkedRoute := am.isGroupLinkedToRoute(ctx, group.AccountID, group.ID); isLinked { + if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } - if isLinked, linkedDns := am.isGroupLinkedToDns(ctx, group.AccountID, group.ID); isLinked { + if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"name server groups", linkedDns.Name} } - if isLinked, linkedPolicy := am.isGroupLinkedToPolicy(ctx, group.AccountID, group.ID); isLinked { + if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"policy", linkedPolicy.Name} } - if isLinked, linkedSetupKey := am.isGroupLinkedToSetupKey(ctx, group.AccountID, group.ID); isLinked { + if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"setup key", linkedSetupKey.Name} } - if isLinked, linkedUser := am.isGroupLinkedToUser(ctx, group.AccountID, group.ID); isLinked { + if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"user", linkedUser.Id} } - dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) if err != nil { return err } @@ -462,7 +447,7 @@ func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group return &GroupLinkError{"disabled DNS management groups", group.Name} } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID) + settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID) if err != nil { return err } @@ -477,8 +462,8 @@ func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group } // isGroupLinkedToRoute checks if a group is linked to any route in the account. -func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, accountID string, groupID string) (bool, *route.Route) { - routes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) { + routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) return false, nil @@ -494,8 +479,8 @@ func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, accou } // isGroupLinkedToPolicy checks if a group is linked to any policy in the account. -func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, accountID string, groupID string) (bool, *Policy) { - policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) return false, nil @@ -512,8 +497,8 @@ func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, acco } // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. -func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { - nameServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { + nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) return false, nil @@ -531,8 +516,8 @@ func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, account } // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. -func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, accountID string, groupID string) (bool, *SetupKey) { - setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) { + setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) return false, nil @@ -547,8 +532,8 @@ func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, ac } // isGroupLinkedToUser checks if a group is linked to any user in the account. -func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, accountID string, groupID string) (bool, *User) { - users, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) { + users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) return false, nil @@ -563,12 +548,12 @@ func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, accoun } // areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers. -func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context, accountID string, groupIDs []string) (bool, error) { +func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { if len(groupIDs) == 0 { return false, nil } - dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) if err != nil { return false, err } @@ -577,13 +562,13 @@ func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context, if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) { return true, nil } - if linked, _ := am.isGroupLinkedToDns(ctx, accountID, groupID); linked { + if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked { return true, nil } - if linked, _ := am.isGroupLinkedToPolicy(ctx, accountID, groupID); linked { + if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked { return true, nil } - if linked, _ := am.isGroupLinkedToRoute(ctx, accountID, groupID); linked { + if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked { return true, nil } } @@ -591,40 +576,6 @@ func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context, return false, nil } -// isGroupLinkedToRoute checks if a group is linked to any route in the account. -func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { - for _, r := range routes { - if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { - return true, r - } - } - return false, nil -} - -// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. -func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { - for _, policy := range policies { - for _, rule := range policy.Rules { - if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { - return true, policy - } - } - } - return false, nil -} - -// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. -func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) { - for _, dns := range nameServerGroups { - for _, g := range dns.Groups { - if g == groupID { - return true, dns - } - } - } - return false, nil -} - // anyGroupHasPeers checks if any of the given groups in the account have peers. func anyGroupHasPeers(account *Account, groupIDs []string) bool { for _, groupID := range groupIDs { @@ -634,22 +585,3 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool { } return false } - -func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool { - for _, groupID := range groupIDs { - if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) { - return true - } - if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked { - return true - } - if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked { - return true - } - if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked { - return true - } - } - - return false -} diff --git a/management/server/peer.go b/management/server/peer.go index 994cc02879c..33f27d8c7e0 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -331,7 +331,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - updateAccountPeers := isPeerInActiveGroup(account, peerID) + updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID) + if err != nil { + return err + } err = am.deletePeers(ctx, account, []string{peerID}, userID) if err != nil { @@ -594,9 +597,14 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s if err != nil { return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err) } - groupsToAdd = append(groupsToAdd, allGroup.ID) - if areGroupChangesAffectPeers(account, groupsToAdd) { + + newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd) + if err != nil { + return nil, nil, nil, err + } + + if newGroupsAffectsPeers { am.updateAccountPeers(ctx, accountID) } @@ -1033,12 +1041,12 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} { // IsPeerInActiveGroup checks if the given peer is part of a group that is used // in an active DNS, route, or ACL configuration. -func isPeerInActiveGroup(account *Account, peerID string) bool { +func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) { peerGroupIDs := make([]string, 0) for _, group := range account.Groups { if slices.Contains(group.Peers, peerID) { peerGroupIDs = append(peerGroupIDs, group.ID) } } - return areGroupChangesAffectPeers(account, peerGroupIDs) + return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs) } From 3ed8b9cee93e7d45f3d27210606536c06169ab06 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Sat, 9 Nov 2024 01:48:28 +0300 Subject: [PATCH 13/60] fix missing group removed from setup key activity Signed-off-by: bcmmbaga --- management/server/setupkey.go | 95 +++++++++++++++++++---------------- 1 file changed, 53 insertions(+), 42 deletions(-) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 65d7796f1a0..2e8230d1ccb 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -12,8 +12,8 @@ import ( "github.com/google/uuid" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" + log "github.com/sirupsen/logrus" ) const ( @@ -236,19 +236,21 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s return nil, status.NewUserNotPartOfAccountError() } - var groups map[string]*nbgroup.Group var setupKey *SetupKey var plainKey string + var eventsToStore []func() err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - groups, err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups) - if err != nil { + if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil { return err } setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) setupKey.AccountID = accountID + events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey) + eventsToStore = append(eventsToStore, events...) + return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey) }) if err != nil { @@ -256,13 +258,8 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s } am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta()) - - for _, g := range setupKey.AutoGroups { - group, ok := groups[g] - if ok { - am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey, - map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name}) - } + for _, storeEvent := range eventsToStore { + storeEvent() } // for the creation return the plain key to the caller @@ -292,13 +289,12 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.NewUserNotPartOfAccountError() } - var groups map[string]*nbgroup.Group var oldKey *SetupKey var newKey *SetupKey + var eventsToStore []func() err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - groups, err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups) - if err != nil { + if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil { return err } @@ -314,6 +310,12 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str newKey.Revoked = keyToSave.Revoked newKey.UpdatedAt = time.Now().UTC() + addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) + removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) + + events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey) + eventsToStore = append(eventsToStore, events...) + return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, newKey) }) if err != nil { @@ -324,26 +326,9 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str am.StoreEvent(ctx, userID, newKey.Id, accountID, activity.SetupKeyRevoked, newKey.EventMeta()) } - defer func() { - addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) - removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) - - for _, g := range removedGroups { - group, ok := groups[g] - if ok { - am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey, - map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) - } - } - - for _, g := range addedGroups { - group, ok := groups[g] - if ok { - am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey, - map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) - } - } - }() + for _, storeEvent := range eventsToStore { + storeEvent() + } return newKey, nil } @@ -412,7 +397,7 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, var deletedSetupKey *SetupKey err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) + deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID) if err != nil { return err } @@ -428,20 +413,46 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return nil } -func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) (map[string]*nbgroup.Group, error) { - autoGroups := map[string]*nbgroup.Group{} - +func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error { for _, groupID := range autoGroupIDs { group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) if err != nil { - return nil, err + return err } if group.IsGroupAll() { - return nil, status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key") + return status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key") + } + } + + return nil +} + +// prepareSetupKeyEvents prepares a list of event functions to be stored. +func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() { + var eventsToStore []func() + + for _, g := range removedGroups { + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + if err != nil { + log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err) + continue + } + + meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name} + am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta) + } + + for _, g := range addedGroups { + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + if err != nil { + log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err) + continue } - autoGroups[group.ID] = group + + meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name} + am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta) } - return autoGroups, nil + return eventsToStore } From 871500c5cc0523cfb5a0032a2a56bfd10366edf3 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Sat, 9 Nov 2024 01:52:09 +0300 Subject: [PATCH 14/60] fix merge Signed-off-by: bcmmbaga --- management/server/setupkey.go | 4 ++-- management/server/store.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 160e934482f..d6e92fe3ab1 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -433,7 +433,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran var eventsToStore []func() for _, g := range removedGroups { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err) continue @@ -444,7 +444,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran } for _, g := range addedGroups { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err) continue diff --git a/management/server/store.go b/management/server/store.go index cb3c533dd09..68b57204b55 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -71,7 +71,7 @@ type Store interface { DeleteTokenID2UserIDIndex(tokenID string) error GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) - GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) + GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error From 174e07fefda60632effa26df6e04e53f09eb1bbe Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 11 Nov 2024 12:37:19 +0300 Subject: [PATCH 15/60] Refactor posture checks to remove get and save account Signed-off-by: bcmmbaga --- management/server/account.go | 2 +- .../server/http/posture_checks_handler.go | 3 +- management/server/mock_server/account_mock.go | 6 +- management/server/posture/checks.go | 6 - management/server/posture_checks.go | 303 +++++++++++------- management/server/posture_checks_test.go | 211 +++++++----- management/server/sql_store.go | 54 +++- management/server/status/error.go | 5 + management/server/store.go | 4 +- 9 files changed, 377 insertions(+), 217 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 043b797ab41..8ebbb0fa0a0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -139,7 +139,7 @@ type AccountManager interface { HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManager() idp.Manager diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go index 1d020e9bcb7..2c820429278 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/posture_checks_handler.go @@ -169,7 +169,8 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http. return } - if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil { + postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks) + if err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index aa6a47b152e..673ed33bb9b 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -96,7 +96,7 @@ type MockAccountManager struct { HasConnectedChannelFunc func(peerID string) bool GetExternalCacheManagerFunc func() server.ExternalCacheManager GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManagerFunc func() idp.Manager @@ -730,11 +730,11 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p } // SavePostureChecks mocks SavePostureChecks of the AccountManager interface -func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { +func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { if am.SavePostureChecksFunc != nil { return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks) } - return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") } // DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface diff --git a/management/server/posture/checks.go b/management/server/posture/checks.go index f2739dddf8d..b2f308d76e2 100644 --- a/management/server/posture/checks.go +++ b/management/server/posture/checks.go @@ -7,8 +7,6 @@ import ( "regexp" "github.com/hashicorp/go-version" - "github.com/rs/xid" - "github.com/netbirdio/netbird/management/server/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" @@ -172,10 +170,6 @@ func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureCh } func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) { - if postureChecksID == "" { - postureChecksID = xid.New().String() - } - postureChecks := Checks{ ID: postureChecksID, Name: name, diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 096cff3f5c9..d7b5a79a23b 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -2,16 +2,15 @@ package server import ( "context" + "fmt" "slices" "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" -) - -const ( - errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" ) func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { @@ -20,85 +19,104 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID return nil, err } - if !user.HasAdminPower() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID) -} + if !user.HasAdminPower() { + return nil, status.NewAdminPermissionError() + } -func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() + return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) +} - account, err := am.Store.GetAccount(ctx, accountID) +// SavePostureChecks saves a posture check. +func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return err + return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return err + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + return nil, status.NewAdminPermissionError() } - if err := postureChecks.Validate(); err != nil { - return status.Errorf(status.InvalidArgument, err.Error()) //nolint - } + var updateAccountPeers bool + var isUpdate = postureChecks.ID != "" + var action = activity.PostureCheckCreated - exists, uniqName := am.savePostureChecks(account, postureChecks) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil { + return err + } - // we do not allow create new posture checks with non uniq name - if !exists && !uniqName { - return status.Errorf(status.PreconditionFailed, "Posture check name should be unique") - } + if isUpdate { + updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID) + if err != nil { + return err + } - action := activity.PostureCheckCreated - if exists { - action = activity.PostureCheckUpdated - account.Network.IncSerial() - } + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err + action = activity.PostureCheckUpdated + } + + postureChecks.AccountID = accountID + return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks) + }) + if err != nil { + return nil, err } am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) - if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) { + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - return nil + return postureChecks, nil } +// DeletePostureChecks deletes a posture check by ID. func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - user, err := account.FindUser(userID) - if err != nil { - return err + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + return status.NewAdminPermissionError() } - postureChecks, err := am.deletePostureChecks(account, postureChecksID) - if err != nil { - return err - } + var postureChecks *posture.Checks - if err = am.Store.SaveAccount(ctx, account); err != nil { + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) + if err != nil { + return err + } + + if err = isPostureCheckLinkedToPolicy(ctx, transaction, postureChecksID, accountID); err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID) + }) + if err != nil { return err } @@ -107,132 +125,173 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun return nil } +// ListPostureChecks returns a list of posture checks. func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !user.HasAdminPower() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if !user.HasAdminPower() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) } -func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) { - uniqName = true - for i, p := range account.PostureChecks { - if !exists && p.ID == postureChecks.ID { - account.PostureChecks[i] = postureChecks - exists = true +// getPeerPostureChecks returns the posture checks applied for a given peer. +func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) { + peerPostureChecks := make(map[string]*posture.Checks) + + err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + if len(postureChecks) == 0 { + return nil } - if p.Name == postureChecks.Name { - uniqName = false + + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + for _, policy := range policies { + if !policy.Enabled { + continue + } + + if err = addPolicyPostureChecks(ctx, transaction, accountID, peerID, policy, peerPostureChecks); err != nil { + return err + } } + + return nil + }) + if err != nil { + return nil, err + } + + return maps.Values(peerPostureChecks), nil +} + +// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. +func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return false, err } - if !exists { - account.PostureChecks = append(account.PostureChecks, postureChecks) + + for _, policy := range policies { + if slices.Contains(policy.SourcePostureChecks, postureCheckID) { + hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups()) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + } } - return + + return false, nil } -func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) { - postureChecksIdx := -1 - for i, postureChecks := range account.PostureChecks { - if postureChecks.ID == postureChecksID { - postureChecksIdx = i - break +// validatePostureChecks validates the posture checks. +func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error { + if err := postureChecks.Validate(); err != nil { + return status.Errorf(status.InvalidArgument, err.Error()) //nolint + } + + // If the posture check already has an ID, verify its existence in the store. + if postureChecks.ID != "" { + if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil { + return err } + return nil } - if postureChecksIdx < 0 { - return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID) + + // For new posture checks, ensure no duplicates by name. + checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + if err != nil { + return err } - // Check if posture check is linked to any policy - if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked { - return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name) + for _, check := range checks { + if check.Name == postureChecks.Name && check.ID != postureChecks.ID { + return status.Errorf(status.InvalidArgument, "posture checks with name %s already exists", postureChecks.Name) + } } - postureChecks := account.PostureChecks[postureChecksIdx] - account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...) + postureChecks.ID = xid.New().String() - return postureChecks, nil + return nil } -// getPeerPostureChecks returns the posture checks applied for a given peer. -func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks { - peerPostureChecks := make(map[string]posture.Checks) +// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. +func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error { + isInGroup, err := isPeerInPolicySourceGroups(ctx, transaction, accountID, peerID, policy) + if err != nil { + return err + } - if len(account.PostureChecks) == 0 { + if !isInGroup { return nil } - for _, policy := range account.Policies { - if !policy.Enabled { - continue - } - - if isPeerInPolicySourceGroups(peer.ID, account, policy) { - addPolicyPostureChecks(account, policy, peerPostureChecks) + for _, sourcePostureCheckID := range policy.SourcePostureChecks { + postureCheck, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID) + if err != nil { + return err } + peerPostureChecks[sourcePostureCheckID] = postureCheck } - postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks)) - for _, check := range peerPostureChecks { - checkCopy := check - postureChecksList = append(postureChecksList, &checkCopy) - } - - return postureChecksList + return nil } // isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. -func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool { +func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy) (bool, error) { for _, rule := range policy.Rules { if !rule.Enabled { continue } for _, sourceGroup := range rule.Sources { - group, ok := account.Groups[sourceGroup] - if ok && slices.Contains(group.Peers, peerID) { - return true + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup) + if err != nil { + log.WithContext(ctx).Debugf("failed to check peer in policy source group: %v", err) + return false, fmt.Errorf("failed to check peer in policy source group: %w", err) } - } - } - return false -} - -func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) { - for _, sourcePostureCheckID := range policy.SourcePostureChecks { - for _, postureCheck := range account.PostureChecks { - if postureCheck.ID == sourcePostureCheckID { - peerPostureChecks[sourcePostureCheckID] = *postureCheck + if slices.Contains(group.Peers, peerID) { + return true, nil } } } -} -func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) { - for _, policy := range account.Policies { - if slices.Contains(policy.SourcePostureChecks, postureChecksID) { - return true, policy - } - } return false, nil } -// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers. -func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool { - if !exists { - return false +// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. +func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return err } - isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID) - if !isLinked { - return false + for _, policy := range policies { + if slices.Contains(policy.SourcePostureChecks, postureChecksID) { + return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name) + } } - return anyGroupHasPeers(account, linkedPolicy.ruleGroups()) + + return nil } diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index c63538b9d52..3c5c5fc79e6 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -7,6 +7,7 @@ import ( "github.com/rs/xid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/group" @@ -16,7 +17,6 @@ import ( const ( adminUserID = "adminUserID" regularUserID = "regularUserID" - postureCheckID = "existing-id" postureCheckName = "Existing check" ) @@ -33,7 +33,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { t.Run("Generic posture check flow", func(t *testing.T) { // regular users can not create checks - err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) + _, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) assert.Error(t, err) // regular users cannot list check @@ -41,8 +41,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Error(t, err) // should be possible to create posture check with uniq name - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ - ID: postureCheckID, + postureCheck, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ Name: postureCheckName, Checks: posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ @@ -58,8 +57,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Len(t, checks, 1) // should not be possible to create posture check with non uniq name - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ - ID: "new-id", + _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ Name: postureCheckName, Checks: posture.ChecksDefinition{ GeoLocationCheck: &posture.GeoLocationCheck{ @@ -74,23 +72,20 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Error(t, err) // admins can update posture checks - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ - ID: postureCheckID, - Name: postureCheckName, - Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "0.27.0", - }, + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.27.0", }, - }) + } + _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck) assert.NoError(t, err) // users should not be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID) + err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, regularUserID) assert.Error(t, err) // admin should be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID) + err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, adminUserID) assert.NoError(t, err) checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID) assert.NoError(t, err) @@ -150,9 +145,22 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) - postureCheck := posture.Checks{ - ID: "postureCheck", - Name: "postureCheck", + postureCheckA := &posture.Checks{ + Name: "postureCheckA", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + {LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"}, + }, + }, + }, + } + postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA) + require.NoError(t, err) + + postureCheckB := &posture.Checks{ + Name: "postureCheckB", AccountID: account.Id, Checks: posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ @@ -169,7 +177,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -187,12 +195,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -215,7 +223,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } // Linking posture check to policy should trigger update account peers and send peer update @@ -238,7 +246,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked posture checks should update account peers and send peer update t.Run("updating linked to posture check with peers", func(t *testing.T) { - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, @@ -255,7 +263,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -293,7 +301,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID) + err := manager.DeletePostureChecks(context.Background(), account.Id, postureCheckA.ID, userID) assert.NoError(t, err) select { @@ -303,7 +311,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update @@ -321,7 +329,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) assert.NoError(t, err) @@ -332,12 +340,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -367,7 +375,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) @@ -379,12 +387,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -409,7 +417,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) assert.NoError(t, err) @@ -420,7 +428,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ ProcessCheck: &posture.ProcessCheck{ Processes: []posture.Process{ { @@ -429,7 +437,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -440,80 +448,123 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }) } -func TestArePostureCheckChangesAffectingPeers(t *testing.T) { - account := &Account{ - Policies: []*Policy{ - { - ID: "policyA", - Rules: []*PolicyRule{ - { - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupA"}, - }, - }, - SourcePostureChecks: []string{"checkA"}, - }, +func TestArePostureCheckChangesAffectPeers(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err, "failed to create account manager") + + account, err := initTestPostureChecksAccount(manager) + require.NoError(t, err, "failed to init testing account") + + groupA := &group.Group{ + ID: "groupA", + AccountID: account.Id, + Peers: []string{"peer1"}, + } + + groupB := &group.Group{ + ID: "groupB", + AccountID: account.Id, + Peers: []string{}, + } + err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB}) + require.NoError(t, err, "failed to save groups") + + postureCheckA := &posture.Checks{ + Name: "checkA", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"}, }, - Groups: map[string]*group.Group{ - "groupA": { - ID: "groupA", - Peers: []string{"peer1"}, - }, - "groupB": { - ID: "groupB", - Peers: []string{}, - }, + } + postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA) + require.NoError(t, err, "failed to save postureCheckA") + + postureCheckB := &posture.Checks{ + Name: "checkB", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"}, }, - PostureChecks: []*posture.Checks{ - { - ID: "checkA", - }, + } + postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB) + require.NoError(t, err, "failed to save postureCheckB") + + policy := &Policy{ + ID: "policyA", + AccountID: account.Id, + Rules: []*PolicyRule{ { - ID: "checkB", + ID: "ruleA", + PolicyID: "policyA", + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, }, }, + SourcePostureChecks: []string{postureCheckA.ID}, } + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, false) + require.NoError(t, err, "failed to save policy") + t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.True(t, result) }) t.Run("posture check exists but is not linked to any policy", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "checkB", true) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID) + require.NoError(t, err) assert.False(t, result) }) t.Run("posture check does not exist", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "unknown", false) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown") + require.NoError(t, err) assert.False(t, result) }) t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"groupB"} - account.Policies[0].Rules[0].Destinations = []string{"groupA"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + policy.Rules[0].Sources = []string{"groupB"} + policy.Rules[0].Destinations = []string{"groupA"} + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + require.NoError(t, err, "failed to update policy") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.True(t, result) }) t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"groupA"} - account.Policies[0].Rules[0].Destinations = []string{"groupB"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + policy.Rules[0].Sources = []string{"groupA"} + policy.Rules[0].Destinations = []string{"groupB"} + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + require.NoError(t, err, "failed to update policy") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.True(t, result) }) - t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"} - account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { + groupA.Peers = []string{} + err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA) + require.NoError(t, err, "failed to save groups") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.False(t, result) }) - t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { - account.Groups["groupA"].Peers = []string{} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { + policy.Rules[0].Sources = []string{"nonExistentGroup"} + policy.Rules[0].Destinations = []string{"nonExistentGroup"} + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + require.NoError(t, err, "failed to update policy") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.False(t, result) }) } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 8a0f432e6ae..466d36aff92 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1257,12 +1257,60 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng // GetAccountPostureChecks retrieves posture checks for an account. func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { - return getRecords[*posture.Checks](s.db, lockStrength, accountID) + var postureChecks []*posture.Checks + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get posture checks from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get posture checks from store") + } + + return postureChecks, nil } // GetPostureChecksByID retrieves posture checks by their ID and account ID. -func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) { - return getRecordByID[posture.Checks](s.db, lockStrength, postureCheckID, accountID) +func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) { + var postureCheck *posture.Checks + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewPostureChecksNotFoundError(postureChecksID) + } + log.WithContext(ctx).Errorf("failed to get posture check from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get posture check from store") + } + + return postureCheck, nil +} + +// SavePostureChecks saves a posture checks to the database. +func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrDuplicatedKey) { + return status.Errorf(status.InvalidArgument, "name should be unique") + } + log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error) + return status.Errorf(status.Internal, "failed to save posture checks to store") + } + + return nil +} + +// DeletePostureChecks deletes a posture checks from the database. +func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete posture checks from store: %s", result.Error) + return status.Errorf(status.Internal, "failed to delete posture checks from store") + } + + if result.RowsAffected == 0 { + return status.NewPostureChecksNotFoundError(postureChecksID) + } + + return nil } // GetAccountRoutes retrieves network routes for an account. diff --git a/management/server/status/error.go b/management/server/status/error.go index 00be347ada4..bdf5c754946 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -140,3 +140,8 @@ func NewInvalidKeyIDError() error { func NewGroupNotFoundError(groupID string) error { return Errorf(NotFound, "group: %s not found", groupID) } + +// NewPostureChecksNotFoundError creates a new Error with NotFound type for a missing posture checks +func NewPostureChecksNotFoundError(postureChecksID string) error { + return Errorf(NotFound, "posture checks: %s not found", postureChecksID) +} diff --git a/management/server/store.go b/management/server/store.go index 68b57204b55..7e258104558 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -83,7 +83,9 @@ type Store interface { GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) - GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) + GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) + SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error + DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error From d54b6967ce28b07ff799c08a8d4d789b0dfde322 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 11 Nov 2024 12:38:34 +0300 Subject: [PATCH 16/60] fix refactor Signed-off-by: bcmmbaga --- management/server/dns.go | 2 +- management/server/group.go | 19 +++++++++++++++++-- management/server/nameserver.go | 10 +++++----- management/server/peer.go | 25 +++++++++++++++++++++---- management/server/policy.go | 6 +++--- management/server/route.go | 10 +++++----- 6 files changed, 52 insertions(+), 20 deletions(-) diff --git a/management/server/dns.go b/management/server/dns.go index 4551be5ab92..e52be601639 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -145,7 +145,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) } - if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { + if am.anyGroupHasPeers(account, addedGroups) || am.anyGroupHasPeers(account, removedGroups) { am.updateAccountPeers(ctx, accountID) } diff --git a/management/server/group.go b/management/server/group.go index c49bb247186..ee42b0064a7 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -576,8 +576,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI return false, nil } -// anyGroupHasPeers checks if any of the given groups in the account have peers. -func anyGroupHasPeers(account *Account, groupIDs []string) bool { +func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool { for _, groupID := range groupIDs { if group, exists := account.Groups[groupID]; exists && group.HasPeers() { return true @@ -585,3 +584,19 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool { } return false } + +// anyGroupHasPeers checks if any of the given groups in the account have peers. +func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { + for _, groupID := range groupIDs { + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + if err != nil { + return false, err + } + + if group.HasPeers() { + return true, nil + } + } + + return false, nil +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 957008714e5..9119a3dec72 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -70,7 +70,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco return nil, err } - if anyGroupHasPeers(account, newNSGroup.Groups) { + if am.anyGroupHasPeers(account, newNSGroup.Groups) { am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) @@ -105,7 +105,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return err } - if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { + if am.areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) @@ -135,7 +135,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco return err } - if anyGroupHasPeers(account, nsGroup.Groups) { + if am.anyGroupHasPeers(account, nsGroup.Groups) { am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) @@ -279,9 +279,9 @@ func validateDomain(domain string) error { } // areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. -func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool { +func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool { if !newNSGroup.Enabled && !oldNSGroup.Enabled { return false } - return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups) + return am.anyGroupHasPeers(account, newNSGroup.Groups) || am.anyGroupHasPeers(account, oldNSGroup.Groups) } diff --git a/management/server/peer.go b/management/server/peer.go index 33f27d8c7e0..873b460ebae 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -613,7 +613,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil, nil, nil, err } - postureChecks := am.getPeerPostureChecks(account, newPeer) + postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID) + if err != nil { + return nil, nil, nil, err + } + customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) return newPeer, networkMap, postureChecks, nil @@ -695,7 +699,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac if err != nil { return nil, nil, nil, err } - postureChecks = am.getPeerPostureChecks(account, peer) + + postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) + if err != nil { + return nil, nil, nil, err + } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil @@ -868,7 +876,11 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is if err != nil { return nil, nil, nil, err } - postureChecks = am.getPeerPostureChecks(account, peer) + + postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) + if err != nil { + return nil, nil, nil, err + } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil @@ -1021,7 +1033,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account defer wg.Done() defer func() { <-semaphore }() - postureChecks := am.getPeerPostureChecks(account, p) + postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err) + return + } + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) diff --git a/management/server/policy.go b/management/server/policy.go index 8a5733f011c..c7872591d5e 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -405,7 +405,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) - if anyGroupHasPeers(account, policy.ruleGroups()) { + if am.anyGroupHasPeers(account, policy.ruleGroups()) { am.updateAccountPeers(ctx, accountID) } @@ -469,7 +469,7 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli if !policyToSave.Enabled && !oldPolicy.Enabled { return false, nil } - updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups()) + updateAccountPeers := am.anyGroupHasPeers(account, oldPolicy.ruleGroups()) || am.anyGroupHasPeers(account, policyToSave.ruleGroups()) return updateAccountPeers, nil } @@ -477,7 +477,7 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli // Add the new policy to the account account.Policies = append(account.Policies, policyToSave) - return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil + return am.anyGroupHasPeers(account, policyToSave.ruleGroups()), nil } func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { diff --git a/management/server/route.go b/management/server/route.go index dcf2cb0d32c..ecb562645e6 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -237,7 +237,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, err } - if isRouteChangeAffectPeers(account, &newRoute) { + if am.isRouteChangeAffectPeers(account, &newRoute) { am.updateAccountPeers(ctx, accountID) } @@ -323,7 +323,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } - if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { + if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) { am.updateAccountPeers(ctx, accountID) } @@ -355,7 +355,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) - if isRouteChangeAffectPeers(account, routy) { + if am.isRouteChangeAffectPeers(account, routy) { am.updateAccountPeers(ctx, accountID) } @@ -651,6 +651,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { // isRouteChangeAffectPeers checks if a given route affects peers by determining // if it has a routing peer, distribution, or peer groups that include peers -func isRouteChangeAffectPeers(account *Account, route *route.Route) bool { - return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" +func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool { + return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" } From 601d429d8299302026e775a303792f38364eecaa Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 11 Nov 2024 16:26:12 +0300 Subject: [PATCH 17/60] fix tests Signed-off-by: bcmmbaga --- management/server/http/posture_checks_handler_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/posture_checks_handler_test.go index 02f0f0d8308..f400cec8154 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/posture_checks_handler_test.go @@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH } return p, nil }, - SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error { + SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { postureChecks.ID = "postureCheck" testPostureChecks[postureChecks.ID] = postureChecks if err := postureChecks.Validate(); err != nil { - return status.Errorf(status.InvalidArgument, err.Error()) //nolint + return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint } - return nil + return postureChecks, nil }, DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error { _, ok := testPostureChecks[postureChecksID] From 664d1388aab5f283684b09ad0e47560dfab48df7 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 11 Nov 2024 22:29:59 +0300 Subject: [PATCH 18/60] fix merge Signed-off-by: bcmmbaga --- management/server/sql_store.go | 22 ++++++++++++---------- management/server/status/error.go | 1 - 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 730fb990059..502a83f2e32 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -33,12 +33,13 @@ import ( ) const ( - storeSqliteFileName = "store.db" - idQueryCondition = "id = ?" - keyQueryCondition = "key = ?" - accountAndIDQueryCondition = "account_id = ? and id = ?" - accountIDCondition = "account_id = ?" - peerNotFoundFMT = "peer %s not found" + storeSqliteFileName = "store.db" + idQueryCondition = "id = ?" + keyQueryCondition = "key = ?" + accountAndIDQueryCondition = "account_id = ? and id = ?" + accountAndIDsQueryCondition = "account_id = ? AND id IN ?" + accountIDCondition = "account_id = ?" + peerNotFoundFMT = "peer %s not found" ) // SqlStore represents an account storage backed by a Sql DB persisted to disk @@ -1095,10 +1096,11 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength } func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { - return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error) + log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) + return status.Errorf(status.Internal, "failed to increment network serial count in store") } return nil } @@ -1213,7 +1215,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren // GetGroupsByIDs retrieves groups by their IDs and account ID. func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) { var groups []*nbgroup.Group - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, "account_id = ? AND id in ?", accountID, groupIDs) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store") @@ -1256,7 +1258,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength // DeleteGroups deletes groups from the database. func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { result := s.db.Clauses(clause.Locking{Strength: string(strength)}). - Delete(&nbgroup.Group{}, " account_id = ? AND id IN ?", accountID, groupIDs) + Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error) diff --git a/management/server/status/error.go b/management/server/status/error.go index 6957a7e0558..db6e4c2fb5a 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -3,7 +3,6 @@ package status import ( "errors" "fmt" - "time" ) const ( From ab00c41dada6f97d13a3f53f3937071b21621c90 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 11 Nov 2024 22:38:24 +0300 Subject: [PATCH 19/60] fix sonar Signed-off-by: bcmmbaga --- management/server/group.go | 23 +++++++++++++++++++---- management/server/group/group.go | 6 ++---- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index c49bb247186..1afb8f3c5e9 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -89,6 +89,10 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user return status.NewUserNotPartOfAccountError() } + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + var eventsToStore []func() var groupsToSave []*nbgroup.Group var updateAccountPeers bool @@ -213,6 +217,10 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use return status.NewUserNotPartOfAccountError() } + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + var group *nbgroup.Group err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { @@ -260,6 +268,10 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us return status.NewUserNotPartOfAccountError() } + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + var allErrors error var groupIDsToDelete []string var deletedGroups []*nbgroup.Group @@ -438,6 +450,11 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup. return &GroupLinkError{"user", linkedUser.Id} } + return checkGroupLinkedToSettings(ctx, transaction, group) +} + +// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. +func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error { dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) if err != nil { return err @@ -452,10 +469,8 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup. return err } - if settings.Extra != nil { - if slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) { - return &GroupLinkError{"integrated validator", group.Name} - } + if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) { + return &GroupLinkError{"integrated validator", group.Name} } return nil diff --git a/management/server/group/group.go b/management/server/group/group.go index bb0f5b7b6e2..24c60d3ceef 100644 --- a/management/server/group/group.go +++ b/management/server/group/group.go @@ -55,8 +55,7 @@ func (g *Group) IsGroupAll() bool { return g.Name == "All" } -// AddPeer adds peerID to Peers if not already present, -// returning true if added. +// AddPeer adds peerID to Peers if not present, returning true if added. func (g *Group) AddPeer(peerID string) bool { if peerID == "" { return false @@ -72,8 +71,7 @@ func (g *Group) AddPeer(peerID string) bool { return true } -// RemovePeer removes peerID from Peers if present, -// returning true if removed. +// RemovePeer removes peerID from Peers if present, returning true if removed. func (g *Group) RemovePeer(peerID string) bool { for i, itemID := range g.Peers { if itemID == peerID { From 113c21b0e1b56f2c2cc9dd9dc2fd8b3a74365632 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 11 Nov 2024 22:57:24 +0300 Subject: [PATCH 20/60] Change setup key log level to debug for missing group Signed-off-by: bcmmbaga --- management/server/setupkey.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 554c66ba4fc..f055d877fe2 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -449,14 +449,14 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran modifiedGroups := slices.Concat(addedGroups, removedGroups) groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) if err != nil { - log.WithContext(ctx).Errorf("issue getting groups for setup key events: %v", err) + log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err) return nil } for _, g := range removedGroups { group, ok := groups[g] if !ok { - log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err) + log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: group not found", g) continue } @@ -469,7 +469,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran for _, g := range addedGroups { group, ok := groups[g] if !ok { - log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err) + log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: group not found", g) continue } From d23b5c892b923ad5b3a8f45368da13de26ef64ea Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 11 Nov 2024 22:58:22 +0300 Subject: [PATCH 21/60] Retrieve modified peers once for group events Signed-off-by: bcmmbaga --- management/server/group.go | 35 ++++++++++++++++++++-------------- management/server/sql_store.go | 17 +++++++++++++++++ management/server/store.go | 1 + 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index 1afb8f3c5e9..57960e7f94a 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -156,34 +156,41 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac }) } + modifiedPeers := slices.Concat(addedPeers, removedPeers) + peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers) + if err != nil { + log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err) + return nil + } + for _, peerID := range addedPeers { - peer, err := transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) - if err != nil { - log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: %v", peerID, err) + peer, ok := peers[peerID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID) continue } - meta := map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, - "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), - } eventsToStore = append(eventsToStore, func() { + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta) }) } for _, peerID := range removedPeers { - peer, err := transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) - if err != nil { - log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: %v", peerID, err) + peer, ok := peers[peerID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID) continue } - meta := map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, - "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), - } eventsToStore = append(eventsToStore, func() { + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta) }) } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 502a83f2e32..7c741d35c8e 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1095,6 +1095,23 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength return peer, nil } +// GetPeersByIDs retrieves peers by their IDs and account ID. +func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get peers by ID's from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peers by ID's from the store") + } + + peersMap := make(map[string]*nbpeer.Peer) + for _, peer := range peers { + peersMap[peer.ID] = peer + } + + return peersMap, nil +} + func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) diff --git a/management/server/store.go b/management/server/store.go index 2a0c44c678d..71b0d457b4c 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -93,6 +93,7 @@ type Store interface { GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) + GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error From 0c0fd380bd0a5a6deda52ebf9d903aee36da14e0 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 12 Nov 2024 11:17:16 +0300 Subject: [PATCH 22/60] Refactor policy get and save account to use store methods Signed-off-by: bcmmbaga --- management/server/account.go | 2 +- management/server/account_test.go | 57 ++-- management/server/http/policies_handler.go | 19 +- management/server/mock_server/account_mock.go | 8 +- management/server/policy.go | 275 +++++++++++------- management/server/sql_store.go | 63 +++- management/server/status/error.go | 5 + management/server/store.go | 5 +- 8 files changed, 283 insertions(+), 151 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 8ebbb0fa0a0..114489c3453 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -113,7 +113,7 @@ type AccountManager interface { GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) - SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error + SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) diff --git a/management/server/account_test.go b/management/server/account_test.go index 97e0d45f016..c8c2d59410b 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1238,8 +1238,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { return } - policy := Policy{ - ID: "policy", + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1250,8 +1249,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -1320,19 +1318,6 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - policy := Policy{ - Enabled: true, - Rules: []*PolicyRule{ - { - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupA"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -1345,7 +1330,19 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { } }() - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + }) + if err != nil { t.Errorf("delete default rule: %v", err) return } @@ -1366,7 +1363,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { return } - policy := Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1377,9 +1374,8 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + }) + if err != nil { t.Errorf("save policy: %v", err) return } @@ -1421,7 +1417,12 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { require.NoError(t, err, "failed to save group") - policy := Policy{ + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1432,14 +1433,8 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { - t.Errorf("delete default rule: %v", err) - return - } - - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + }) + if err != nil { t.Errorf("save policy: %v", err) return } diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 73f3803b5ed..8255e489648 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -6,10 +6,8 @@ import ( "strconv" "github.com/gorilla/mux" - nbgroup "github.com/netbirdio/netbird/management/server/group" - "github.com/rs/xid" - "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -122,14 +120,9 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } - isUpdate := policyID != "" - - if policyID == "" { - policyID = xid.New().String() - } - - policy := server.Policy{ + policy := &server.Policy{ ID: policyID, + AccountID: accountID, Name: req.Name, Enabled: req.Enabled, Description: req.Description, @@ -137,6 +130,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID for _, rule := range req.Rules { pr := server.PolicyRule{ ID: policyID, // TODO: when policy can contain multiple rules, need refactor + PolicyID: policyID, Name: rule.Name, Destinations: rule.Destinations, Sources: rule.Sources, @@ -225,7 +219,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID policy.SourcePostureChecks = *req.SourcePostureChecks } - if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil { + policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy) + if err != nil { util.WriteError(r.Context(), err, w) return } @@ -236,7 +231,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } - resp := toPolicyResponse(allGroups, &policy) + resp := toPolicyResponse(allGroups, policy) if len(resp.Rules) == 0 { util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 673ed33bb9b..46a4fbc1faf 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -49,7 +49,7 @@ type MockAccountManager struct { GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) - SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) @@ -386,11 +386,11 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID } // SavePolicy mock implementation of SavePolicy from server.AccountManager interface -func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error { +func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) { if am.SavePolicyFunc != nil { - return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate) + return am.SavePolicyFunc(ctx, accountID, userID, policy) } - return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") } // DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface diff --git a/management/server/policy.go b/management/server/policy.go index c7872591d5e..eb44a0436ee 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -3,13 +3,13 @@ package server import ( "context" _ "embed" - "slices" "strconv" "strings" + "github.com/netbirdio/netbird/management/proto" + "github.com/rs/xid" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -125,6 +125,7 @@ type PolicyRule struct { func (pm *PolicyRule) Copy() *PolicyRule { rule := &PolicyRule{ ID: pm.ID, + PolicyID: pm.PolicyID, Name: pm.Name, Description: pm.Description, Enabled: pm.Enabled, @@ -171,6 +172,7 @@ type Policy struct { func (p *Policy) Copy() *Policy { c := &Policy{ ID: p.ID, + AccountID: p.AccountID, Name: p.Name, Description: p.Description, Enabled: p.Enabled, @@ -343,44 +345,72 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } - return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) + return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) } // SavePolicy in the store -func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error { +func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return err + return nil, err } - updateAccountPeers, err := am.savePolicy(account, policy, isUpdate) - if err != nil { - return err + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } - action := activity.PolicyAdded - if isUpdate { - action = activity.PolicyUpdated + var isUpdate = policy.ID != "" + var updateAccountPeers bool + var action = activity.PolicyAdded + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validatePolicy(ctx, transaction, accountID, policy); err != nil { + return err + } + + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + saveFunc := transaction.CreatePolicy + if isUpdate { + action = activity.PolicyUpdated + saveFunc = transaction.SavePolicy + } + + return saveFunc(ctx, LockingStrengthUpdate, policy) + }) + if err != nil { + return nil, err } + am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - return nil + return policy, nil } // DeletePolicy from the store @@ -388,112 +418,136 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - policy, err := am.deletePolicy(account, policyID) - if err != nil { - return err + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return status.NewAdminPermissionError() } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + var policy *Policy + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + policy, err = transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) + if err != nil { + return err + } + + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID) + }) + if err != nil { return err } - am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) + am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) - if am.anyGroupHasPeers(account, policy.ruleGroups()) { + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } return nil } -// ListPolicies from the store +// ListPolicies from the store. func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) } -func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) { - policyIdx := -1 - for i, policy := range account.Policies { - if policy.ID == policyID { - policyIdx = i - break +// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. +func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) { + if isUpdate { + existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + if err != nil { + return false, err } - } - if policyIdx < 0 { - return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID) - } - policy := account.Policies[policyIdx] - account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...) - return policy, nil -} - -// savePolicy saves or updates a policy in the given account. -// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy. -func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) { - for index, rule := range policyToSave.Rules { - rule.Sources = filterValidGroupIDs(account, rule.Sources) - rule.Destinations = filterValidGroupIDs(account, rule.Destinations) - policyToSave.Rules[index] = rule - } + if !policy.Enabled && !existingPolicy.Enabled { + return false, nil + } - if policyToSave.SourcePostureChecks != nil { - policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks) - } + hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups()) + if err != nil { + return false, err + } - if isUpdate { - policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) - if policyIdx < 0 { - return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) + if hasPeers { + return true, nil } - oldPolicy := account.Policies[policyIdx] - // Update the existing policy - account.Policies[policyIdx] = policyToSave + return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) + } - if !policyToSave.Enabled && !oldPolicy.Enabled { - return false, nil + return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) +} + +// validatePolicy validates the policy and its rules. +func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error { + if policy.ID != "" { + _, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + if err != nil { + return err } - updateAccountPeers := am.anyGroupHasPeers(account, oldPolicy.ruleGroups()) || am.anyGroupHasPeers(account, policyToSave.ruleGroups()) + } else { + policy.ID = xid.New().String() + policy.AccountID = accountID + } - return updateAccountPeers, nil + groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + return err } - // Add the new policy to the account - account.Policies = append(account.Policies, policyToSave) + postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } - return am.anyGroupHasPeers(account, policyToSave.ruleGroups()), nil -} + for i, rule := range policy.Rules { + ruleCopy := rule.Copy() + if ruleCopy.ID == "" { + ruleCopy.ID = xid.New().String() + ruleCopy.PolicyID = policy.ID + } -func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { - result := make([]*proto.FirewallRule, len(rules)) - for i := range rules { - rule := rules[i] + ruleCopy.Sources = getValidGroupIDs(groups, ruleCopy.Sources) + ruleCopy.Destinations = getValidGroupIDs(groups, ruleCopy.Destinations) + policy.Rules[i] = ruleCopy + } - result[i] = &proto.FirewallRule{ - PeerIP: rule.PeerIP, - Direction: getProtoDirection(rule.Direction), - Action: getProtoAction(rule.Action), - Protocol: getProtoProtocol(rule.Protocol), - Port: rule.Port, - } + if policy.SourcePostureChecks != nil { + policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks) } - return result + + return nil } // getAllPeersFromGroups for given peer ID and list of groups @@ -574,27 +628,52 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { return nil } -// filterValidPostureChecks filters and returns the posture check IDs from the given list -// that are valid within the provided account. -func filterValidPostureChecks(account *Account, postureChecksIds []string) []string { - result := make([]string, 0, len(postureChecksIds)) +// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. +func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds []string) []string { + validPostureCheckIDs := make(map[string]struct{}) + for _, check := range postureChecks { + validPostureCheckIDs[check.ID] = struct{}{} + } + + validIDs := make([]string, 0, len(postureChecksIds)) for _, id := range postureChecksIds { - for _, postureCheck := range account.PostureChecks { - if id == postureCheck.ID { - result = append(result, id) - continue - } + if _, exists := validPostureCheckIDs[id]; exists { + validIDs = append(validIDs, id) } } - return result + + return validIDs +} + +// getValidGroupIDs filters and returns only the valid group IDs from the provided list. +func getValidGroupIDs(groups []*nbgroup.Group, groupIDs []string) []string { + validGroupIDs := make(map[string]struct{}) + for _, group := range groups { + validGroupIDs[group.ID] = struct{}{} + } + + validIDs := make([]string, 0, len(groupIDs)) + for _, id := range groupIDs { + if _, exists := validGroupIDs[id]; exists { + validIDs = append(validIDs, id) + } + } + + return validIDs } -// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map. -func filterValidGroupIDs(account *Account, groupIDs []string) []string { - result := make([]string, 0, len(groupIDs)) - for _, groupID := range groupIDs { - if _, exists := account.Groups[groupID]; exists { - result = append(result, groupID) +// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. +func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + + result[i] = &proto.FirewallRule{ + PeerIP: rule.PeerIP, + Direction: getProtoDirection(rule.Direction), + Action: getProtoAction(rule.Action), + Protocol: getProtoProtocol(rule.Protocol), + Port: rule.Port, } } return result diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 81dc704c213..2cd7ac7fd37 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1286,12 +1286,67 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { - return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID) + var policies []*Policy + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Preload(clause.Associations).Find(&policies, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get policies from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get policies from store") + } + + return policies, nil } // GetPolicyByID retrieves a policy by its ID and account ID. -func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) { - return getRecordByID[Policy](s.db.Preload(clause.Associations), lockStrength, policyID, accountID) +func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) { + var policy *Policy + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). + First(&policy, accountAndIDQueryCondition, accountID, policyID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewPolicyNotFoundError(policyID) + } + log.WithContext(ctx).Errorf("failed to get policy from store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get policy from store") + } + + return policy, nil +} + +func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error) + return status.Errorf(status.Internal, "failed to create policy in store") + } + + return nil +} + +// SavePolicy saves a policy to the database. +func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { + result := s.db.Session(&gorm.Session{FullSaveAssociations: true}). + Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err) + return status.Errorf(status.Internal, "failed to save policy to store") + } + return nil +} + +func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err) + return status.Errorf(status.Internal, "failed to delete policy from store") + } + + if result.RowsAffected == 0 { + return status.NewPolicyNotFoundError(policyID) + } + + return nil } // GetAccountPostureChecks retrieves posture checks for an account. @@ -1324,7 +1379,7 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin // SavePostureChecks saves a posture checks to the database. func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) if result.Error != nil { if errors.Is(result.Error, gorm.ErrDuplicatedKey) { return status.Errorf(status.InvalidArgument, "name should be unique") diff --git a/management/server/status/error.go b/management/server/status/error.go index ba9e01c4fd7..bef1f5143a5 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -139,3 +139,8 @@ func NewGroupNotFoundError(groupID string) error { func NewPostureChecksNotFoundError(postureChecksID string) error { return Errorf(NotFound, "posture checks: %s not found", postureChecksID) } + +// NewPolicyNotFoundError creates a new Error with NotFound type for a missing policy +func NewPolicyNotFoundError(policyID string) error { + return Errorf(NotFound, "policy: %s not found", policyID) +} diff --git a/management/server/store.go b/management/server/store.go index 03b5821e7a1..108b262b171 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -80,7 +80,10 @@ type Store interface { DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) - GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) + GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) + CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error + SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error + DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) From 2d7f08c6099a4e264413054dda46ee68ee01a8c9 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 12 Nov 2024 11:18:16 +0300 Subject: [PATCH 23/60] Fix tests Signed-off-by: bcmmbaga --- management/server/group_test.go | 5 +- .../server/http/policies_handler_test.go | 4 +- management/server/peer_test.go | 31 ++-- management/server/policy_test.go | 165 ++++++------------ management/server/posture_checks_test.go | 43 ++--- management/server/route_test.go | 3 +- management/server/setupkey_test.go | 5 +- management/server/user_test.go | 5 +- 8 files changed, 93 insertions(+), 168 deletions(-) diff --git a/management/server/group_test.go b/management/server/group_test.go index 89184e81927..0515b9698ee 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -500,8 +500,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { }) // adding a group to policy - err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ - ID: "policy", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -512,7 +511,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - }, false) + }) assert.NoError(t, err) // Saving a group linked to policy should update account peers and send peer update diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index 228ebcbceef..f8a897eb27b 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -38,12 +38,12 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { } return policy, nil }, - SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error { + SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) { if !strings.HasPrefix(policy.ID, "id-") { policy.ID = "id-was-set" policy.Rules[0].ID = "id-was-set" } - return nil + return policy, nil }, GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 4e2dcb2c313..e410fa8923e 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -283,14 +283,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { var ( group1 nbgroup.Group group2 nbgroup.Group - policy Policy ) group1.ID = xid.New().String() group2.ID = xid.New().String() group1.Name = "src" group2.Name = "dst" - policy.ID = xid.New().String() group1.Peers = append(group1.Peers, peer1.ID) group2.Peers = append(group2.Peers, peer2.ID) @@ -305,18 +303,20 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - policy.Name = "test" - policy.Enabled = true - policy.Rules = []*PolicyRule{ - { - Enabled: true, - Sources: []string{group1.ID}, - Destinations: []string{group2.ID}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, + policy := &Policy{ + Name: "test", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{group1.ID}, + Destinations: []string{group2.ID}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return @@ -364,7 +364,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { } policy.Enabled = false - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return @@ -1445,8 +1445,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { // Adding peer to group linked with policy should update account peers and send peer update t.Run("adding peer to group linked with policy", func(t *testing.T) { - err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ - ID: "policy", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1457,7 +1456,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - }, false) + }) require.NoError(t, err) done := make(chan struct{}) diff --git a/management/server/policy_test.go b/management/server/policy_test.go index e7f0f9cd2f1..62d80f46e7f 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/rs/xid" "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" @@ -859,14 +858,23 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) + var policyWithGroupRulesNoPeers *Policy + var policyWithDestinationPeersOnly *Policy + var policyWithSourceAndDestinationPeers *Policy + // Saving policy with rule groups with no peers should not update account's peers and not send peer update t.Run("saving policy with rule groups with no peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-rule-groups-no-peers", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupC"}, @@ -874,15 +882,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -895,12 +895,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Saving policy with source group containing peers, but destination group without peers should // update account's peers and send peer update t.Run("saving policy where source has peers but destination does not", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-has-peers-destination-none", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupB"}, @@ -909,15 +914,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -930,13 +927,18 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Saving policy with destination group containing peers, but source group without peers should // update account's peers and send peer update t.Run("saving policy where destination has peers but source does not", func(t *testing.T) { - policy := Policy{ - ID: "policy-destination-has-peers-source-none", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), - Enabled: false, + Enabled: true, Sources: []string{"groupC"}, Destinations: []string{"groupD"}, Bidirectional: true, @@ -944,15 +946,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -965,12 +959,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Saving policy with destination and source groups containing peers should update account's peers // and send peer update t.Run("saving policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupD"}, @@ -978,15 +977,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -999,28 +990,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Disabling policy with destination and source groups containing peers should update account's peers // and send peer update t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Enabled: false, - Rules: []*PolicyRule{ - { - ID: xid.New().String(), - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupD"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + policyWithSourceAndDestinationPeers.Enabled = false + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers) assert.NoError(t, err) select { @@ -1033,29 +1010,15 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Updating disabled policy with destination and source groups containing peers should not update account's peers // or send peer update t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Description: "updated description", - Enabled: false, - Rules: []*PolicyRule{ - { - ID: xid.New().String(), - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupA"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - done := make(chan struct{}) go func() { peerShouldNotReceiveUpdate(t, updMsg) close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + policyWithSourceAndDestinationPeers.Description = "updated description" + policyWithSourceAndDestinationPeers.Rules[0].Destinations = []string{"groupA"} + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers) assert.NoError(t, err) select { @@ -1068,28 +1031,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Enabling policy with destination and source groups containing peers should update account's peers // and send peer update t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Enabled: true, - Rules: []*PolicyRule{ - { - ID: xid.New().String(), - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupD"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + policyWithSourceAndDestinationPeers.Enabled = true + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers) assert.NoError(t, err) select { @@ -1101,15 +1050,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Deleting policy should trigger account peers update and send peer update t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) { - policyID := "policy-source-destination-peers" - done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + err := manager.DeletePolicy(context.Background(), account.Id, policyWithSourceAndDestinationPeers.ID, userID) assert.NoError(t, err) select { @@ -1123,14 +1070,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Deleting policy with destination group containing peers, but source group without peers should // update account's peers and send peer update t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) { - policyID := "policy-destination-has-peers-source-none" done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + err := manager.DeletePolicy(context.Background(), account.Id, policyWithDestinationPeersOnly.ID, userID) assert.NoError(t, err) select { @@ -1142,14 +1088,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Deleting policy with no peers in groups should not update account's peers and not send peer update t.Run("deleting policy with no peers in groups", func(t *testing.T) { - policyID := "policy-rule-groups-no-peers" done := make(chan struct{}) go func() { peerShouldNotReceiveUpdate(t, updMsg) close(done) }() - err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + err := manager.DeletePolicy(context.Background(), account.Id, policyWithGroupRulesNoPeers.ID, userID) assert.NoError(t, err) select { diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 3c5c5fc79e6..93e5741cf28 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -5,7 +5,6 @@ import ( "testing" "time" - "github.com/rs/xid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -210,12 +209,10 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - policy := Policy{ - ID: "policyA", + policy := &Policy{ Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, @@ -234,7 +231,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) assert.NoError(t, err) select { @@ -282,8 +279,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }() policy.SourcePostureChecks = []string{} - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + _, err := manager.SavePolicy(context.Background(), account.Id, userID, policy) assert.NoError(t, err) select { @@ -316,12 +312,10 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update t.Run("updating linked posture check to policy with no peers", func(t *testing.T) { - policy = Policy{ - ID: "policyB", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupC"}, @@ -330,8 +324,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, SourcePostureChecks: []string{postureCheckB.ID}, - } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) done := make(chan struct{}) @@ -362,12 +355,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { t.Cleanup(func() { manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) }) - policy = Policy{ - ID: "policyB", + + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupA"}, @@ -376,9 +368,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, SourcePostureChecks: []string{postureCheckB.ID}, - } - - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + }) assert.NoError(t, err) done := make(chan struct{}) @@ -405,8 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked client posture check to policy where source has peers but destination does not, // should trigger account peers update and send peer update t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { - policy = Policy{ - ID: "policyB", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -418,8 +407,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, SourcePostureChecks: []string{postureCheckB.ID}, - } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + }) assert.NoError(t, err) done := make(chan struct{}) @@ -490,12 +478,9 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { require.NoError(t, err, "failed to save postureCheckB") policy := &Policy{ - ID: "policyA", AccountID: account.Id, Rules: []*PolicyRule{ { - ID: "ruleA", - PolicyID: "policyA", Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, @@ -504,7 +489,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { SourcePostureChecks: []string{postureCheckA.ID}, } - err = manager.SavePolicy(context.Background(), account.Id, userID, policy, false) + policy, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) require.NoError(t, err, "failed to save policy") t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { @@ -528,7 +513,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { policy.Rules[0].Sources = []string{"groupB"} policy.Rules[0].Destinations = []string{"groupA"} - err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) require.NoError(t, err, "failed to update policy") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) @@ -539,7 +524,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { policy.Rules[0].Sources = []string{"groupA"} policy.Rules[0].Destinations = []string{"groupB"} - err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) require.NoError(t, err, "failed to update policy") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) @@ -560,7 +545,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { policy.Rules[0].Sources = []string{"nonExistentGroup"} policy.Rules[0].Destinations = []string{"nonExistentGroup"} - err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) require.NoError(t, err, "failed to update policy") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) diff --git a/management/server/route_test.go b/management/server/route_test.go index 5c848f68c7b..108f791e02c 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1214,12 +1214,11 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { defaultRule := rules[0] newPolicy := defaultRule.Copy() - newPolicy.ID = xid.New().String() newPolicy.Name = "peer1 only" newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID} - err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false) + _, err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy) require.NoError(t, err) err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 2ed8aef95c6..ea239ec0c63 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -390,8 +390,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { }) assert.NoError(t, err) - policy := Policy{ - ID: "policy", + policy := &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -403,7 +402,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { }, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) diff --git a/management/server/user_test.go b/management/server/user_test.go index d4f560a54c7..498017afa1d 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1279,8 +1279,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { }) require.NoError(t, err) - policy := Policy{ - ID: "policy", + policy := &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1292,7 +1291,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { }, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) From 2806d7316100fb59f6498ebd6aae6975faf62477 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 12 Nov 2024 13:38:34 +0300 Subject: [PATCH 24/60] Add tests Signed-off-by: bcmmbaga --- management/server/sql_store_test.go | 277 +++++++++++++++++++++++++++- 1 file changed, 274 insertions(+), 3 deletions(-) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 20409798b0e..114da1ee6f6 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -14,11 +14,10 @@ import ( "time" "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" route2 "github.com/netbirdio/netbird/route" @@ -1293,3 +1292,275 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID) require.Error(t, err) } + +func TestSqlStore_GetGroupsByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupIDs []string + expectedCount int + }{ + { + name: "retrieve existing groups by existing IDs", + groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"}, + expectedCount: 2, + }, + { + name: "empty group IDs list", + groupIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing group IDs", + groupIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing group IDs", + groupIDs: []string{"cfefqs706sqkneg59g4g", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + groups, err := store.GetGroupsByIDs(context.Background(), LockingStrengthShare, accountID, tt.groupIDs) + require.NoError(t, err) + require.Len(t, groups, tt.expectedCount) + }) + } +} + +func TestSqlStore_SaveGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + group := &nbgroup.Group{ + ID: "group-id", + AccountID: accountID, + Issued: "api", + Peers: []string{"peer1", "peer2"}, + } + err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group) + require.NoError(t, err) + + savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id") + require.NoError(t, err) + require.Equal(t, savedGroup, group) +} + +func TestSqlStore_SaveGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + groups := []*nbgroup.Group{ + { + ID: "group-1", + AccountID: accountID, + Issued: "api", + Peers: []string{"peer1", "peer2"}, + }, + { + ID: "group-2", + AccountID: accountID, + Issued: "integration", + Peers: []string{"peer3", "peer4"}, + }, + } + err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups) + require.NoError(t, err) +} + +func TestSqlStore_DeleteGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupID string + expectError bool + }{ + { + name: "delete existing group", + groupID: "cfefqs706sqkneg59g4g", + expectError: false, + }, + { + name: "delete non-existing group", + groupID: "non-existing-group-id", + expectError: true, + }, + { + name: "delete with empty group ID", + groupID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := store.DeleteGroup(context.Background(), LockingStrengthUpdate, accountID, tt.groupID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + } else { + require.NoError(t, err) + + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, tt.groupID) + require.Error(t, err) + require.Nil(t, group) + } + }) + } +} + +func TestSqlStore_DeleteGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupIDs []string + expectError bool + }{ + { + name: "delete multiple existing groups", + groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"}, + expectError: false, + }, + { + name: "delete non-existing groups", + groupIDs: []string{"non-existing-id-1", "non-existing-id-2"}, + expectError: false, + }, + { + name: "delete with empty group IDs list", + groupIDs: []string{}, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := store.DeleteGroups(context.Background(), LockingStrengthUpdate, accountID, tt.groupIDs) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + + for _, groupID := range tt.groupIDs { + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.Error(t, err) + require.Nil(t, group) + } + } + }) + } +} + +func TestSqlStore_GetPeerByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + peerID string + expectError bool + }{ + { + name: "retrieve existing peer", + peerID: "cfefqs706sqkneg59g4g", + expectError: false, + }, + { + name: "retrieve non-existing peer", + peerID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty peer ID", + peerID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, tt.peerID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, peer) + } else { + require.NoError(t, err) + require.NotNil(t, peer) + require.Equal(t, tt.peerID, peer.ID) + } + }) + } +} + +func TestSqlStore_GetPeersByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + peerIDs []string + expectedCount int + }{ + { + name: "retrieve existing peers by existing IDs", + peerIDs: []string{"cfefqs706sqkneg59g4g", "cfeg6sf06sqkneg59g50"}, + expectedCount: 2, + }, + { + name: "empty peer IDs list", + peerIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing peer IDs", + peerIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing peer IDs", + peerIDs: []string{"cfeg6sf06sqkneg59g50", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetPeersByIDs(context.Background(), LockingStrengthShare, accountID, tt.peerIDs) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} From a3abc211b3ceee7da582721a6917bf46d23ce19e Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 12 Nov 2024 17:11:56 +0300 Subject: [PATCH 25/60] Add tests Signed-off-by: bcmmbaga --- management/server/sql_store.go | 5 +- management/server/sql_store_test.go | 135 ++++++++++++++++++ management/server/testdata/extended-store.sql | 1 + 3 files changed, 137 insertions(+), 4 deletions(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 81dc704c213..f971f830088 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1324,11 +1324,8 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin // SavePostureChecks saves a posture checks to the database. func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) if result.Error != nil { - if errors.Is(result.Error, gorm.ErrDuplicatedKey) { - return status.Errorf(status.InvalidArgument, "name should be unique") - } log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error) return status.Errorf(status.Internal, "failed to save posture checks to store") } diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 114da1ee6f6..94c4da6a82c 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -16,6 +16,7 @@ import ( "github.com/google/uuid" nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/posture" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1564,3 +1565,137 @@ func TestSqlStore_GetPeersByIDs(t *testing.T) { }) } } + +func TestSqlStore_GetPostureChecksByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + postureChecksID string + expectError bool + }{ + { + name: "retrieve existing posture checks", + postureChecksID: "csplshq7qv948l48f7t0", + expectError: false, + }, + { + name: "retrieve non-existing posture checks", + postureChecksID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty posture checks ID", + postureChecksID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peer, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, peer) + } else { + require.NoError(t, err) + require.NotNil(t, peer) + require.Equal(t, tt.postureChecksID, peer.ID) + } + }) + } +} + +func TestSqlStore_SavePostureChecks(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + postureChecks := &posture.Checks{ + ID: "posture-checks-id", + AccountID: accountID, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.31.0", + }, + OSVersionCheck: &posture.OSVersionCheck{ + Ios: &posture.MinVersionCheck{ + MinVersion: "13.0.1", + }, + Linux: &posture.MinKernelVersionCheck{ + MinKernelVersion: "5.3.3-dev", + }, + }, + GeoLocationCheck: &posture.GeoLocationCheck{ + Locations: []posture.Location{ + { + CountryCode: "DE", + CityName: "Berlin", + }, + }, + Action: posture.CheckActionAllow, + }, + }, + } + err = store.SavePostureChecks(context.Background(), LockingStrengthUpdate, postureChecks) + require.NoError(t, err) + + savePostureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, "posture-checks-id") + require.NoError(t, err) + require.Equal(t, savePostureChecks, postureChecks) +} + +func TestSqlStore_DeletePostureChecks(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + postureChecksID string + expectError bool + }{ + { + name: "delete existing posture checks", + postureChecksID: "csplshq7qv948l48f7t0", + expectError: false, + }, + { + name: "delete non-existing posture checks", + postureChecksID: "non-existing-posture-checks-id", + expectError: true, + }, + { + name: "delete with empty posture checks ID", + postureChecksID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err = store.DeletePostureChecks(context.Background(), LockingStrengthUpdate, accountID, tt.postureChecksID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + } else { + require.NoError(t, err) + group, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID) + require.Error(t, err) + require.Nil(t, group) + } + }) + } +} diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index b522741e7e0..1646ff4da6c 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -34,4 +34,5 @@ INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003' INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); +INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); INSERT INTO installations VALUES(1,''); From 32d1b2d60210ce39064c6a25f8c81f1b81615d59 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 12 Nov 2024 18:53:10 +0300 Subject: [PATCH 26/60] Retrieve policy groups and posture checks once for validation Signed-off-by: bcmmbaga --- management/server/policy.go | 22 ++++++---------------- management/server/sql_store.go | 21 +++++++++++++++++++-- management/server/store.go | 1 + 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/management/server/policy.go b/management/server/policy.go index eb44a0436ee..6dcb963162b 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -521,12 +521,12 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po policy.AccountID = accountID } - groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups()) if err != nil { return err } - postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks) if err != nil { return err } @@ -629,15 +629,10 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { } // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. -func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds []string) []string { - validPostureCheckIDs := make(map[string]struct{}) - for _, check := range postureChecks { - validPostureCheckIDs[check.ID] = struct{}{} - } - +func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string { validIDs := make([]string, 0, len(postureChecksIds)) for _, id := range postureChecksIds { - if _, exists := validPostureCheckIDs[id]; exists { + if _, exists := postureChecks[id]; exists { validIDs = append(validIDs, id) } } @@ -646,15 +641,10 @@ func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds [ } // getValidGroupIDs filters and returns only the valid group IDs from the provided list. -func getValidGroupIDs(groups []*nbgroup.Group, groupIDs []string) []string { - validGroupIDs := make(map[string]struct{}) - for _, group := range groups { - validGroupIDs[group.ID] = struct{}{} - } - +func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string { validIDs := make([]string, 0, len(groupIDs)) for _, id := range groupIDs { - if _, exists := validGroupIDs[id]; exists { + if _, exists := groups[id]; exists { validIDs = append(validIDs, id) } } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index e7a2e50d874..a4191de9f2c 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1234,8 +1234,8 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren var groups []*nbgroup.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { - log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store") + log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store") } groupsMap := make(map[string]*nbgroup.Group) @@ -1377,6 +1377,23 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin return postureCheck, nil } +// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID. +func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) { + var postureChecks []*posture.Checks + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get posture checks by ID's from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get posture checks by ID's from store") + } + + postureChecksMap := make(map[string]*posture.Checks) + for _, postureCheck := range postureChecks { + postureChecksMap[postureCheck.ID] = postureCheck + } + + return postureChecksMap, nil +} + // SavePostureChecks saves a posture checks to the database. func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) diff --git a/management/server/store.go b/management/server/store.go index 108b262b171..ba61d552d72 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -88,6 +88,7 @@ type Store interface { GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) + GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error From bbaee18cd56cbb5c35cbeda907c9ed4a05d6b482 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 12 Nov 2024 19:05:57 +0300 Subject: [PATCH 27/60] Fix typo Signed-off-by: bcmmbaga --- management/server/sql_store_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 94c4da6a82c..de939e8d0e9 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1596,17 +1596,17 @@ func TestSqlStore_GetPostureChecksByID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - peer, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID) + postureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID) if tt.expectError { require.Error(t, err) sErr, ok := status.FromError(err) require.True(t, ok) require.Equal(t, sErr.Type(), status.NotFound) - require.Nil(t, peer) + require.Nil(t, postureChecks) } else { require.NoError(t, err) - require.NotNil(t, peer) - require.Equal(t, tt.postureChecksID, peer.ID) + require.NotNil(t, postureChecks) + require.Equal(t, tt.postureChecksID, postureChecks.ID) } }) } From 3a915decd7c5d136ca9128723fc8d692db791421 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 12 Nov 2024 20:15:47 +0300 Subject: [PATCH 28/60] Add policy tests Signed-off-by: bcmmbaga --- management/server/sql_store_test.go | 156 ++++++++++++++++++ management/server/testdata/extended-store.sql | 1 + 2 files changed, 157 insertions(+) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index de939e8d0e9..8931008d7ff 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1612,6 +1612,49 @@ func TestSqlStore_GetPostureChecksByID(t *testing.T) { } } +func TestSqlStore_GetPostureChecksByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + postureCheckIDs []string + expectedCount int + }{ + { + name: "retrieve existing posture checks by existing IDs", + postureCheckIDs: []string{"csplshq7qv948l48f7t0", "cspnllq7qv95uq1r4k90"}, + expectedCount: 2, + }, + { + name: "empty posture check IDs list", + postureCheckIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing posture check IDs", + postureCheckIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing posture check IDs", + postureCheckIDs: []string{"cspnllq7qv95uq1r4k90", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + groups, err := store.GetPostureChecksByIDs(context.Background(), LockingStrengthShare, accountID, tt.postureCheckIDs) + require.NoError(t, err) + require.Len(t, groups, tt.expectedCount) + }) + } +} + func TestSqlStore_SavePostureChecks(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) @@ -1699,3 +1742,116 @@ func TestSqlStore_DeletePostureChecks(t *testing.T) { }) } } + +func TestSqlStore_GetPolicyByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + policyID string + expectError bool + }{ + { + name: "retrieve existing policy", + policyID: "cs1tnh0hhcjnqoiuebf0", + expectError: false, + }, + { + name: "retrieve non-existing policy checks", + policyID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty policy ID", + policyID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, tt.policyID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, policy) + } else { + require.NoError(t, err) + require.NotNil(t, policy) + require.Equal(t, tt.policyID, policy.ID) + } + }) + } +} + +func TestSqlStore_CreatePolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + policy := &Policy{ + ID: "policy-id", + AccountID: accountID, + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupC"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + err = store.CreatePolicy(context.Background(), LockingStrengthUpdate, policy) + require.NoError(t, err) + + savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID) + require.NoError(t, err) + require.Equal(t, savePolicy, policy) + +} + +func TestSqlStore_SavePolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + policyID := "cs1tnh0hhcjnqoiuebf0" + + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID) + require.NoError(t, err) + + policy.Enabled = false + policy.Description = "policy" + err = store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) + require.NoError(t, err) + + savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID) + require.NoError(t, err) + require.Equal(t, savePolicy, policy) +} + +func TestSqlStore_DeletePolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + policyID := "cs1tnh0hhcjnqoiuebf0" + + err = store.DeletePolicy(context.Background(), LockingStrengthShare, accountID, policyID) + require.NoError(t, err) + + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID) + require.Error(t, err) + require.Nil(t, policy) +} diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index 1646ff4da6c..37db2731625 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -35,4 +35,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-3465 INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); +INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}'); INSERT INTO installations VALUES(1,''); From 9872bee41db34da97e075c2f4491b4db9d57d9b8 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 12 Nov 2024 23:53:29 +0300 Subject: [PATCH 29/60] Refactor anyGroupHasPeers to retrieve all groups once Signed-off-by: bcmmbaga --- management/server/group.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index 5d301416902..758b28b760d 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -609,12 +609,12 @@ func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []s // anyGroupHasPeers checks if any of the given groups in the account have peers. func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { - for _, groupID := range groupIDs { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - return false, err - } + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs) + if err != nil { + return false, err + } + for _, group := range groups { if group.HasPeers() { return true, nil } From 560190519d3de176b2cdaea5b1bb9dde5b6a46c8 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 13 Nov 2024 13:15:47 +0300 Subject: [PATCH 30/60] Refactor dns settings to use store methods Signed-off-by: bcmmbaga --- management/server/dns.go | 142 +++++++++++++++++++++++++-------- management/server/sql_store.go | 21 ++++- management/server/store.go | 1 + 3 files changed, 130 insertions(+), 34 deletions(-) diff --git a/management/server/dns.go b/management/server/dns.go index e52be601639..be7caea4eff 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -3,6 +3,7 @@ package server import ( "context" "fmt" + "slices" "strconv" "sync" @@ -85,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) @@ -94,64 +99,137 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s // SaveDNSSettings validates a user role and updates the account's DNS settings func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() + if dnsSettingsToSave == nil { + return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") + } - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - user, err := account.FindUser(userID) - if err != nil { - return err + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings") + return status.NewAdminPermissionError() } - if dnsSettingsToSave == nil { - return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") - } + var updateAccountPeers bool + var eventsToStore []func() - if len(dnsSettingsToSave.DisabledManagementGroups) != 0 { - err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil { + return err + } + + oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID) if err != nil { return err } - } - oldSettings := account.DNSSettings.Copy() - account.DNSSettings = dnsSettingsToSave.Copy() + addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) + removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) - addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) - removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) + updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups) + if err != nil { + return err + } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } + events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups) + eventsToStore = append(eventsToStore, events...) + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } - for _, id := range addedGroups { - group := account.GetGroup(id) - meta := map[string]any{"group": group.Name, "group_id": group.ID} - am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) + return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave) + }) + if err != nil { + return err } - for _, id := range removedGroups { - group := account.GetGroup(id) - meta := map[string]any{"group": group.Name, "group_id": group.ID} - am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) + for _, storeEvent := range eventsToStore { + storeEvent() } - if am.anyGroupHasPeers(account, addedGroups) || am.anyGroupHasPeers(account, removedGroups) { + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } return nil } +// prepareGroupEvents prepares a list of event functions to be stored. +func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() { + var eventsToStore []func() + + modifiedGroups := slices.Concat(addedGroups, removedGroups) + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) + if err != nil { + log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err) + return nil + } + + for _, groupID := range addedGroups { + group, ok := groups[groupID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToDisabledManagementGroups activity", groupID) + continue + } + + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{"group": group.Name, "group_id": group.ID} + am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) + }) + + } + + for _, groupID := range removedGroups { + group, ok := groups[groupID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromDisabledManagementGroups activity", groupID) + continue + } + + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{"group": group.Name, "group_id": group.ID} + am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) + }) + } + + return eventsToStore +} + +// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers. +func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) { + hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + + return anyGroupHasPeers(ctx, transaction, accountID, removedGroups) +} + +// validateDNSSettings validates the DNS settings. +func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error { + if len(settings.DisabledManagementGroups) == 0 { + return nil + } + + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups) + if err != nil { + return err + } + + return validateGroups(settings.DisabledManagementGroups, groups) +} + // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig { protoUpdate := &proto.DNSConfig{ diff --git a/management/server/sql_store.go b/management/server/sql_store.go index a4191de9f2c..4289bfc0180 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1153,9 +1153,10 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki First(&accountDNSSettings, idQueryCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "dns settings not found") + return nil, status.NewAccountNotFoundError(accountID) } - return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error) + log.WithContext(ctx).Errorf("failed to get dns settings from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get dns settings from store") } return &accountDNSSettings.DNSSettings, nil } @@ -1528,3 +1529,19 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a } return &record, nil } + +// SaveDNSSettings saves the DNS settings to the store. +func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings}) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save dns settings to store") + } + + if result.RowsAffected == 0 { + return status.NewAccountNotFoundError(accountID) + } + + return nil +} diff --git a/management/server/store.go b/management/server/store.go index ba61d552d72..cca014b5214 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -59,6 +59,7 @@ type Store interface { SaveAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error + SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) From 4b943c34b7a7dfc1232204b297e388a2b60ef48b Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 13 Nov 2024 13:16:32 +0300 Subject: [PATCH 31/60] Add tests Signed-off-by: bcmmbaga --- management/server/sql_store_test.go | 63 +++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 8931008d7ff..9a40739e665 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1855,3 +1855,66 @@ func TestSqlStore_DeletePolicy(t *testing.T) { require.Error(t, err) require.Nil(t, policy) } + +func TestSqlStore_GetDNSSettings(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectError bool + }{ + { + name: "retrieve existing account dns settings", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectError: false, + }, + { + name: "retrieve non-existing account dns settings", + accountID: "non-existing", + expectError: true, + }, + { + name: "retrieve dns settings with empty account ID", + accountID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, tt.accountID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, dnsSettings) + } else { + require.NoError(t, err) + require.NotNil(t, dnsSettings) + } + }) + } +} + +func TestSqlStore_SaveDNSSettings(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + + dnsSettings.DisabledManagementGroups = []string{"groupA", "groupB"} + err = store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, dnsSettings) + require.NoError(t, err) + + saveDNSSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + require.Equal(t, saveDNSSettings, dnsSettings) +} From ed047ec9dda048120edf4f074162a27136ac3cd6 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 13 Nov 2024 16:16:30 +0300 Subject: [PATCH 32/60] Add account locking and merge group deletion methods Signed-off-by: bcmmbaga --- management/server/group.go | 66 ++++++++++------------------------ management/server/sql_store.go | 2 +- 2 files changed, 20 insertions(+), 48 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index 57960e7f94a..154a33b1350 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -215,48 +215,9 @@ func difference(a, b []string) []string { // DeleteGroup object of the peers. func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) - if err != nil { - return err - } - - if user.AccountID != accountID { - return status.NewUserNotPartOfAccountError() - } - - if user.IsRegularUser() { - return status.NewAdminPermissionError() - } - - var group *nbgroup.Group - - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - return err - } - - if group.IsGroupAll() { - return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") - } - - if err = validateDeleteGroup(ctx, transaction, group, userID); err != nil { - return err - } - - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { - return err - } - - return transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID) - }) - if err != nil { - return err - } - - am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta()) - - return nil + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + return am.DeleteGroups(ctx, accountID, userID, []string{groupID}) } // DeleteGroups deletes groups from an account. @@ -285,13 +246,14 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { for _, groupID := range groupIDs { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID) if err != nil { + allErrors = errors.Join(allErrors, err) continue } if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil { - allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) + allErrors = errors.Join(allErrors, err) continue } @@ -318,12 +280,15 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + var group *nbgroup.Group var updateAccountPeers bool var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) if err != nil { return err } @@ -356,12 +321,15 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + var group *nbgroup.Group var updateAccountPeers bool var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) if err != nil { return err } @@ -430,13 +398,17 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup. if group.Issued == nbgroup.GroupIssuedIntegration { executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return status.Errorf(status.NotFound, "user not found") + return err } if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group") } } + if group.IsGroupAll() { + return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") + } + if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 7c741d35c8e..0ebda6440c1 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1278,7 +1278,7 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) - return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete groups from store") } return nil From a4d905ffe77881b682a4798d5564b89860404a0a Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 13 Nov 2024 16:56:22 +0300 Subject: [PATCH 33/60] Fix tests Signed-off-by: bcmmbaga --- management/server/group_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/group_test.go b/management/server/group_test.go index 89184e81927..59094a23e92 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { { name: "delete non-existent group", groupIDs: []string{"non-existent-group"}, - expectedDeleted: []string{"non-existent-group"}, + expectedReasons: []string{"group: non-existent-group not found"}, }, { name: "delete multiple groups with mixed results", From 218345e0ffa6c209505ee5a1893c7d119fadfc4c Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 13 Nov 2024 20:41:30 +0300 Subject: [PATCH 34/60] Refactor name server groups to use store methods Signed-off-by: bcmmbaga --- management/server/nameserver.go | 201 +++++++++++++++++++----------- management/server/sql_store.go | 49 +++++++- management/server/status/error.go | 5 + management/server/store.go | 2 + 4 files changed, 181 insertions(+), 76 deletions(-) diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 9119a3dec72..e7a5387a142 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -24,26 +24,34 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID) + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + + return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID) } // CreateNameServerGroup creates and saves a new nameserver group func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + newNSGroup := &nbdns.NameServerGroup{ ID: xid.New().String(), + AccountID: accountID, Name: name, Description: description, NameServers: nameServerList, @@ -54,26 +62,33 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco SearchDomainsEnabled: searchDomainEnabled, } - err = validateNameServerGroup(false, newNSGroup, account) - if err != nil { - return nil, err - } + var updateAccountPeers bool - if account.NameServerGroups == nil { - account.NameServerGroups = make(map[string]*nbdns.NameServerGroup) - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil { + return err + } - account.NameServerGroups[newNSGroup.ID] = newNSGroup + updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, newNSGroup.Groups) + if err != nil { + return err + } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup) + }) + if err != nil { return nil, err } - if am.anyGroupHasPeers(account, newNSGroup.Groups) { + am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) + + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) return newNSGroup.Copy(), nil } @@ -87,58 +102,95 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") } - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - err = validateNameServerGroup(true, nsGroupToSave, account) - if err != nil { - return err + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - oldNSGroup := account.NameServerGroups[nsGroupToSave.ID] - account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID) + if err != nil { + return err + } + nsGroupToSave.AccountID = accountID + + if err = validateNameServerGroup(ctx, transaction, accountID, nsGroupToSave); err != nil { + return err + } + + updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave) + }) + if err != nil { return err } - if am.areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { + am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) + + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) return nil } // DeleteNameServerGroup deletes nameserver group with nsGroupID func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - nsGroup := account.NameServerGroups[nsGroupID] - if nsGroup == nil { - return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID) + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - delete(account.NameServerGroups, nsGroupID) - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + var nsGroup *nbdns.NameServerGroup + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + nsGroup, err = transaction.GetNameServerGroupByID(ctx, LockingStrengthUpdate, accountID, nsGroupID) + if err != nil { + return err + } + + updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, nsGroup.Groups) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID) + }) + if err != nil { return err } - if am.anyGroupHasPeers(account, nsGroup.Groups) { + am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) + + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) return nil } @@ -150,44 +202,62 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) } -func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error { - nsGroupID := "" - if existingGroup { - nsGroupID = nameserverGroup.ID - _, found := account.NameServerGroups[nsGroupID] - if !found { - return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID) - } +func validateNameServerGroup(ctx context.Context, transaction Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error { + err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled) + if err != nil { + return err } - err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled) + err = validateNSList(nameserverGroup.NameServers) if err != nil { return err } - err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups) + nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) if err != nil { return err } - err = validateNSList(nameserverGroup.NameServers) + err = validateNSGroupName(nameserverGroup.Name, nameserverGroup.ID, nsServerGroups) if err != nil { return err } - err = validateGroups(nameserverGroup.Groups, account.Groups) + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, nameserverGroup.Groups) if err != nil { return err } - return nil + return validateGroups(nameserverGroup.Groups, groups) +} + +// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. +func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) { + if !newNSGroup.Enabled && !oldNSGroup.Enabled { + return false, nil + } + + hasPeers, err := anyGroupHasPeers(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + + return anyGroupHasPeers(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups) } func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error { @@ -213,14 +283,14 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo return nil } -func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error { +func validateNSGroupName(name, nsGroupID string, groups []*nbdns.NameServerGroup) error { if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" { return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar) } - for _, nsGroup := range nsGroupMap { + for _, nsGroup := range groups { if name == nsGroup.Name && nsGroup.ID != nsGroupID { - return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name) + return status.Errorf(status.InvalidArgument, "nameserver group with name %s already exist", name) } } @@ -228,8 +298,8 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na } func validateNSList(list []nbdns.NameServer) error { - nsListLenght := len(list) - if nsListLenght == 0 || nsListLenght > 3 { + nsListLength := len(list) + if nsListLength == 0 || nsListLength > 3 { return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list)) } return nil @@ -244,14 +314,7 @@ func validateGroups(list []string, groups map[string]*nbgroup.Group) error { if id == "" { return status.Errorf(status.InvalidArgument, "group ID should not be empty string") } - found := false - for groupID := range groups { - if id == groupID { - found = true - break - } - } - if !found { + if _, found := groups[id]; !found { return status.Errorf(status.InvalidArgument, "group id %s not found", id) } } @@ -277,11 +340,3 @@ func validateDomain(domain string) error { return nil } - -// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. -func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool { - if !newNSGroup.Enabled && !oldNSGroup.Enabled { - return false - } - return am.anyGroupHasPeers(account, newNSGroup.Groups) || am.anyGroupHasPeers(account, oldNSGroup.Groups) -} diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 4289bfc0180..2f951cd2e1f 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1489,12 +1489,55 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren // GetAccountNameServerGroups retrieves name server groups for an account. func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { - return getRecords[*nbdns.NameServerGroup](s.db, lockStrength, accountID) + var nsGroups []*nbdns.NameServerGroup + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get name server groups from store") + } + + return nsGroups, nil } // GetNameServerGroupByID retrieves a name server group by its ID and account ID. -func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) { - return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID) +func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { + var nsGroup *nbdns.NameServerGroup + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewNameServerGroupNotFoundError(nsGroupID) + } + log.WithContext(ctx).Errorf("failed to get name server group from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get name server group from store") + } + + return nsGroup, nil +} + +// SaveNameServerGroup saves a name server group to the database. +func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err) + return status.Errorf(status.Internal, "failed to save name server group to store") + } + return nil +} + +// DeleteNameServerGroup deletes a name server group from the database. +func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nsGroupID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete name server group from store") + } + + if result.RowsAffected == 0 { + return status.NewNameServerGroupNotFoundError(nsGroupID) + } + + return nil } // getRecords retrieves records from the database based on the account ID. diff --git a/management/server/status/error.go b/management/server/status/error.go index 0fff5355994..59f436f5b19 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -149,3 +149,8 @@ func NewPostureChecksNotFoundError(postureChecksID string) error { func NewPolicyNotFoundError(policyID string) error { return Errorf(NotFound, "policy: %s not found", policyID) } + +// NewNameServerGroupNotFoundError creates a new Error with NotFound type for a missing name server group +func NewNameServerGroupNotFoundError(nsGroupID string) error { + return Errorf(NotFound, "nameserver group: %s not found", nsGroupID) +} diff --git a/management/server/store.go b/management/server/store.go index cca014b5214..b16ad8a1aa4 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -117,6 +117,8 @@ type Store interface { GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) + SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error + DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error From ef55b9eccc22bf01743b5476738d85bef70e1688 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 13 Nov 2024 20:41:41 +0300 Subject: [PATCH 35/60] Add tests Signed-off-by: bcmmbaga --- management/server/sql_store_test.go | 131 ++++++++++++++++++ management/server/testdata/extended-store.sql | 1 + 2 files changed, 132 insertions(+) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 9a40739e665..b568b7fe03a 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1918,3 +1918,134 @@ func TestSqlStore_SaveDNSSettings(t *testing.T) { require.NoError(t, err) require.Equal(t, saveDNSSettings, dnsSettings) } + +func TestSqlStore_GetAccountNameServerGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "retrieve name server groups by existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + { + name: "non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "empty account ID", + accountID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetAccountNameServerGroups(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } + +} + +func TestSqlStore_GetNameServerByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + nsGroupID string + expectError bool + }{ + { + name: "retrieve existing nameserver group", + nsGroupID: "csqdelq7qv97ncu7d9t0", + expectError: false, + }, + { + name: "retrieve non-existing nameserver group", + nsGroupID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty nameserver group ID", + nsGroupID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, tt.nsGroupID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, nsGroup) + } else { + require.NoError(t, err) + require.NotNil(t, nsGroup) + require.Equal(t, tt.nsGroupID, nsGroup.ID) + } + }) + } +} + +func TestSqlStore_SaveNameServerGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + nsGroup := &nbdns.NameServerGroup{ + ID: "ns-group-id", + AccountID: accountID, + Name: "NS Group", + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: 1, + Port: 53, + }, + }, + Groups: []string{"groupA"}, + Primary: true, + Enabled: true, + SearchDomainsEnabled: false, + } + + err = store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nsGroup) + require.NoError(t, err) + + saveNSGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, nsGroup.ID) + require.NoError(t, err) + require.Equal(t, saveNSGroup, nsGroup) +} + +func TestSqlStore_DeleteNameServerGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + nsGroupID := "csqdelq7qv97ncu7d9t0" + + err = store.DeleteNameServerGroup(context.Background(), LockingStrengthShare, accountID, nsGroupID) + require.NoError(t, err) + + nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, nsGroupID) + require.Error(t, err) + require.Nil(t, nsGroup) +} diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index 37db2731625..455111439ea 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -36,4 +36,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-3465 INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}'); +INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0); INSERT INTO installations VALUES(1,''); From 63156440652afb5eaf8beaadfb068a4c51036ca6 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 14 Nov 2024 13:04:36 +0300 Subject: [PATCH 36/60] Add peer store methods Signed-off-by: bcmmbaga --- management/server/sql_store.go | 89 +++++++++++++++++++++++++++++++++- management/server/store.go | 5 ++ 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 2f951cd2e1f..979c7842d61 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1068,7 +1068,15 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId // GetUserPeers retrieves peers for a user. func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { - return getRecords[*nbpeer.Peer](s.db.Where("user_id = ?", userID), lockStrength, accountID) + var peers []*nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&peers, "account_id = ? AND user_id = ?", accountID, userID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get peers from store") + } + + return peers, nil } func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { @@ -1112,6 +1120,85 @@ func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStreng return peersMap, nil } +// GetAccountPeerDNSLabels retrieves all unique DNS labels for peers associated with a specified account. +func (s *SqlStore) GetAccountPeerDNSLabels(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { + var labels []string + + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). + Where(accountIDCondition, accountID).Pluck("dns_label", &labels) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "no peers found for the account") + } + log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting dns labels from store") + } + + return labels, nil +} + +// GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user. +func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). + Find(&peers, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peers with expiration from store") + } + + return peers, nil +} + +// GetAccountPeersWithInactivity retrieves a list of peers that have login expiration enabled and added by a user. +func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Where("inactivity_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). + Find(&peers, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get peers with inactivity from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peers with inactivity from store") + } + + return peers, nil +} + +// GetAllEphemeralPeers retrieves all peers with Ephemeral set to true across all accounts, optimized for batch processing. +func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) { + var allEphemeralPeers, batchPeers []*nbpeer.Peer + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Where("ephemeral = ?", true). + FindInBatches(&batchPeers, 1000, func(tx *gorm.DB, batch int) error { + allEphemeralPeers = append(allEphemeralPeers, batchPeers...) + return nil + }) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to retrieve ephemeral peers: %s", result.Error) + return nil, fmt.Errorf("failed to retrieve ephemeral peers") + } + + return allEphemeralPeers, nil +} + +// DeletePeer removes a peer from the store. +func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&nbpeer.Peer{}, accountAndIDQueryCondition, accountID, peerID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete peer from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete peer from store") + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "peer not found") + } + + return nil +} + func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) diff --git a/management/server/store.go b/management/server/store.go index b16ad8a1aa4..6e49a494b66 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -94,6 +94,7 @@ type Store interface { DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) + GetAccountPeerDNSLabels(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error @@ -101,9 +102,13 @@ type Store interface { GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) + GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) + GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) + GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error + DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error From 8420a525633ce9542be4257e93cadc6ed933b7da Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 14 Nov 2024 13:04:49 +0300 Subject: [PATCH 37/60] Refactor ephemeral peers Signed-off-by: bcmmbaga --- management/server/ephemeral.go | 44 ++++++++++++++--------------- management/server/ephemeral_test.go | 18 +++++------- 2 files changed, 28 insertions(+), 34 deletions(-) diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 590b1d708bc..6e245ec5ac8 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -20,10 +20,10 @@ var ( ) type ephemeralPeer struct { - id string - account *Account - deadline time.Time - next *ephemeralPeer + id string + accountID string + deadline time.Time + next *ephemeralPeer } // todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it @@ -104,12 +104,6 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID) - a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID) - if err != nil { - log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err) - return - } - e.peersLock.Lock() defer e.peersLock.Unlock() @@ -117,7 +111,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. return } - e.addPeer(peer.ID, a, newDeadLine()) + e.addPeer(peer.AccountID, peer.ID, newDeadLine()) if e.timer == nil { e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() { e.cleanup(ctx) @@ -126,17 +120,21 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. } func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { - accounts := e.store.GetAllAccounts(context.Background()) + peers, err := e.store.GetAllEphemeralPeers(ctx, LockingStrengthShare) + if err != nil { + log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err) + return + } + t := newDeadLine() count := 0 - for _, a := range accounts { - for id, p := range a.Peers { - if p.Ephemeral { - count++ - e.addPeer(id, a, t) - } + for _, p := range peers { + if p.Ephemeral { + count++ + e.addPeer(p.AccountID, p.ID, t) } } + log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count) } @@ -170,18 +168,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { for id, p := range deletePeers { log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id) - err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator) + err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator) if err != nil { log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) } } } -func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) { +func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) { ep := &ephemeralPeer{ - id: id, - account: account, - deadline: deadline, + id: peerID, + accountID: accountID, + deadline: deadline, } if e.headPeer == nil { diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index 1390352a5d0..00e5d777a79 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -7,7 +7,6 @@ import ( "time" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/status" ) type MockStore struct { @@ -15,17 +14,14 @@ type MockStore struct { account *Account } -func (s *MockStore) GetAllAccounts(_ context.Context) []*Account { - return []*Account{s.account} -} - -func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) { - _, ok := s.account.Peers[peerId] - if ok { - return s.account, nil +func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ LockingStrength) ([]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + for _, v := range s.account.Peers { + if v.Ephemeral { + peers = append(peers, v) + } } - - return nil, status.NewPeerNotFoundError(peerId) + return peers, nil } type MocAccountManager struct { From f5e7449d01ab7e0906b6630c596fcc0a7fd4557c Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 14 Nov 2024 19:24:51 +0300 Subject: [PATCH 38/60] Add lock for peer store methods Signed-off-by: bcmmbaga --- management/server/sql_store.go | 29 +++++++++++++++++++++-------- management/server/store.go | 9 +++++---- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 979c7842d61..b921ed47d3c 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -300,12 +300,12 @@ func (s *SqlStore) GetInstallationID() string { return installation.InstallationIDValue } -func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error { +func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error { // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields. peerCopy := peer.Copy() peerCopy.AccountID = accountID - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Transaction(func(tx *gorm.DB) error { // check if peer exists before saving var peerID string result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID) @@ -355,7 +355,7 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID return nil } -func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { +func (s *SqlStore) SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, peerStatus nbpeer.PeerStatus) error { var peerCopy nbpeer.Peer peerCopy.Status = &peerStatus @@ -363,7 +363,7 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe "peer_status_last_seen", "peer_status_connected", "peer_status_login_expired", "peer_status_required_approval", } - result := s.db.Model(&nbpeer.Peer{}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). Select(fieldsToUpdate). Where(accountAndIDQueryCondition, accountID, peerID). Updates(&peerCopy) @@ -378,14 +378,14 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe return nil } -func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error { +func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peerWithLocation *nbpeer.Peer) error { // To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields. var peerCopy nbpeer.Peer // Since the location field has been migrated to JSON serialization, // updating the struct ensures the correct data format is inserted into the database. peerCopy.Location = peerWithLocation.Location - result := s.db.Model(&nbpeer.Peer{}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID). Updates(peerCopy) @@ -740,9 +740,10 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) return accountID, nil } -func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { +func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) { var accountID string - result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&User{}). + Select("account_id").Where(idQueryCondition, userID).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -1066,6 +1067,18 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId return nil } +// GetAccountPeers retrieves peers for an account. +func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get peers from store") + } + + return peers, nil +} + // GetUserPeers retrieves peers for a user. func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { var peers []*nbpeer.Peer diff --git a/management/server/store.go b/management/server/store.go index 6e49a494b66..9ecb9c1698f 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -48,7 +48,7 @@ type Store interface { GetAccountByUser(ctx context.Context, userID string) (*Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) - GetAccountIDByUserID(userID string) (string, error) + GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later @@ -99,15 +99,16 @@ type Store interface { AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) + GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) - SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error - SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error - SavePeerLocation(accountID string, peer *nbpeer.Peer) error + SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error + SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error + SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) From 7d849a92c0ce8436f52f1ea721a62b6d7f5534d3 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 14 Nov 2024 19:32:34 +0300 Subject: [PATCH 39/60] Refactor peer handlers Signed-off-by: bcmmbaga --- management/server/http/peers_handler.go | 95 ++++++------- management/server/http/peers_handler_test.go | 141 +++++++++++-------- 2 files changed, 126 insertions(+), 110 deletions(-) diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index a5856a0e43c..235e744b351 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -48,8 +48,8 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) return peerToReturn, nil } -func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { - peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID) +func (h *PeersHandler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) { + peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID) if err != nil { util.WriteError(ctx, err, w) return @@ -62,11 +62,16 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee } dnsDomain := h.accountManager.GetDNSDomain() - groupsInfo := toGroupsInfo(account.Groups, peer.ID) + peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) + if err != nil { + util.WriteError(ctx, err, w) + return + } + groupsInfo := toGroupsInfo(peerGroups) - validPeers, err := h.accountManager.GetValidatedPeers(account) + validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(ctx).Errorf("failed to list approved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) return } @@ -75,7 +80,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) } -func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *PeersHandler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -99,16 +104,21 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, } } - peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update) + peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update) if err != nil { util.WriteError(ctx, err, w) return } dnsDomain := h.accountManager.GetDNSDomain() - groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) + peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) + if err != nil { + util.WriteError(ctx, err, w) + return + } + groupMinimumInfo := toGroupsInfo(peerGroups) - validPeers, err := h.accountManager.GetValidatedPeers(account) + validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) @@ -149,18 +159,11 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { case http.MethodDelete: h.deletePeer(r.Context(), accountID, userID, peerID, w) return - case http.MethodGet, http.MethodPut: - account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - if r.Method == http.MethodGet { - h.getPeer(r.Context(), account, peerID, userID, w) - } else { - h.updatePeer(r.Context(), account, userID, peerID, w, r) - } + case http.MethodGet: + h.getPeer(r.Context(), accountID, peerID, userID, w) + return + case http.MethodPut: + h.updatePeer(r.Context(), accountID, userID, peerID, w, r) return default: util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) @@ -176,7 +179,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { return } - account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) + peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -184,19 +187,25 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() - respBody := make([]*api.PeerBatch, 0, len(account.Peers)) - for _, peer := range account.Peers { + respBody := make([]*api.PeerBatch, 0, len(peers)) + for _, peer := range peers { peerToReturn, err := h.checkPeerStatus(peer) if err != nil { util.WriteError(r.Context(), err, w) return } - groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) + + peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), accountID, peer.ID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + groupMinimumInfo := toGroupsInfo(peerGroups) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) } - validPeersMap, err := h.accountManager.GetValidatedPeers(account) + validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) @@ -259,16 +268,16 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request } } - dnsDomain := h.accountManager.GetDNSDomain() - - validPeers, err := h.accountManager.GetValidatedPeers(account) + validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } - customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain()) + dnsDomain := h.accountManager.GetDNSDomain() + + customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) @@ -303,26 +312,14 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { - var groupsInfo []api.GroupMinimum - groupsChecked := make(map[string]struct{}) +func toGroupsInfo(groups []*nbgroup.Group) []api.GroupMinimum { + groupsInfo := make([]api.GroupMinimum, 0, len(groups)) for _, group := range groups { - _, ok := groupsChecked[group.ID] - if ok { - continue - } - groupsChecked[group.ID] = struct{}{} - for _, pk := range group.Peers { - if pk == peerID { - info := api.GroupMinimum{ - Id: group.ID, - Name: group.Name, - PeersCount: len(group.Peers), - } - groupsInfo = append(groupsInfo, info) - break - } - } + groupsInfo = append(groupsInfo, api.GroupMinimum{ + Id: group.ID, + Name: group.Name, + PeersCount: len(group.Peers), + }) } return groupsInfo } diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index dd49c03b848..9279fc5361b 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -39,6 +39,68 @@ const ( ) func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { + + peersMap := make(map[string]*nbpeer.Peer) + for _, peer := range peers { + peersMap[peer.ID] = peer.Copy() + } + + policy := &server.Policy{ + ID: "policy", + AccountID: "test_id", + Name: "policy", + Enabled: true, + Rules: []*server.PolicyRule{ + { + ID: "rule", + Name: "rule", + Enabled: true, + Action: "accept", + Destinations: []string{"group1"}, + Sources: []string{"group1"}, + Bidirectional: true, + Protocol: "all", + Ports: []string{"80"}, + }, + }, + } + + srvUser := server.NewRegularUser(serviceUser) + srvUser.IsServiceUser = true + + account := &server.Account{ + Id: "test_id", + Domain: "hotmail.com", + Peers: peersMap, + Users: map[string]*server.User{ + adminUser: server.NewAdminUser(adminUser), + regularUser: server.NewRegularUser(regularUser), + serviceUser: srvUser, + }, + Groups: map[string]*nbgroup.Group{ + "group1": { + ID: "group1", + AccountID: "test_id", + Name: "group1", + Issued: "api", + Peers: maps.Keys(peersMap), + }, + }, + Settings: &server.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: time.Hour, + }, + Policies: []*server.Policy{policy}, + Network: &server.Network{ + Identifier: "ciclqisab2ss43jdn8q0", + Net: net.IPNet{ + IP: net.ParseIP("100.67.0.0"), + Mask: net.IPv4Mask(255, 255, 0, 0), + }, + Serial: 51, + }, + } + return &PeersHandler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { @@ -67,74 +129,31 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { return peers, nil }, + GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) { + peersID := make([]string, len(peers)) + for _, peer := range peers { + peersID = append(peersID, peer.ID) + } + return []*nbgroup.Group{ + { + ID: "group1", + AccountID: accountID, + Name: "group1", + Issued: "api", + Peers: peersID, + }, + }, nil + }, GetDNSDomainFunc: func() string { return "netbird.selfhosted" }, GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, + GetAccountFunc: func(ctx context.Context, accountID string) (*server.Account, error) { + return account, nil + }, GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { - peersMap := make(map[string]*nbpeer.Peer) - for _, peer := range peers { - peersMap[peer.ID] = peer.Copy() - } - - policy := &server.Policy{ - ID: "policy", - AccountID: accountID, - Name: "policy", - Enabled: true, - Rules: []*server.PolicyRule{ - { - ID: "rule", - Name: "rule", - Enabled: true, - Action: "accept", - Destinations: []string{"group1"}, - Sources: []string{"group1"}, - Bidirectional: true, - Protocol: "all", - Ports: []string{"80"}, - }, - }, - } - - srvUser := server.NewRegularUser(serviceUser) - srvUser.IsServiceUser = true - - account := &server.Account{ - Id: accountID, - Domain: "hotmail.com", - Peers: peersMap, - Users: map[string]*server.User{ - adminUser: server.NewAdminUser(adminUser), - regularUser: server.NewRegularUser(regularUser), - serviceUser: srvUser, - }, - Groups: map[string]*nbgroup.Group{ - "group1": { - ID: "group1", - AccountID: accountID, - Name: "group1", - Issued: "api", - Peers: maps.Keys(peersMap), - }, - }, - Settings: &server.Settings{ - PeerLoginExpirationEnabled: true, - PeerLoginExpiration: time.Hour, - }, - Policies: []*server.Policy{policy}, - Network: &server.Network{ - Identifier: "ciclqisab2ss43jdn8q0", - Net: net.IPNet{ - IP: net.ParseIP("100.67.0.0"), - Mask: net.IPv4Mask(255, 255, 0, 0), - }, - Serial: 51, - }, - } - return account, nil }, HasConnectedChannelFunc: func(peerID string) bool { From c557c983908ca2eaf9cec3943c511332626860cf Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 14 Nov 2024 19:33:57 +0300 Subject: [PATCH 40/60] Refactor peer to use store methods Signed-off-by: bcmmbaga --- management/server/account.go | 105 ++- management/server/integrated_validator.go | 39 +- management/server/mock_server/account_mock.go | 24 +- management/server/peer.go | 664 +++++++++++------- management/server/peer/peer.go | 2 +- management/server/user.go | 36 +- 6 files changed, 543 insertions(+), 327 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 5e9d6ebbc1e..4222179d95f 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -92,7 +92,7 @@ type AccountManager interface { GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) ListUsers(ctx context.Context, accountID string) ([]*User, error) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) - MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *Account) error + MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) @@ -112,6 +112,7 @@ type AccountManager interface { DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error + GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error @@ -134,7 +135,7 @@ type AccountManager interface { GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API - SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API + SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager @@ -145,7 +146,7 @@ type AccountManager interface { GetIdpManager() idp.Manager UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) - GetValidatedPeers(account *Account) (map[string]struct{}, error) + GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error @@ -1160,17 +1161,17 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco event = activity.AccountPeerLoginExpirationDisabled am.peerLoginExpiry.Cancel(ctx, []string{accountID}) } else { - am.checkAndSchedulePeerLoginExpiration(ctx, account) + am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } am.StoreEvent(ctx, userID, accountID, accountID, event, nil) } if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) - am.checkAndSchedulePeerLoginExpiration(ctx, account) + am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } - err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID) + err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) if err != nil { return nil, err } @@ -1185,21 +1186,21 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return updatedAccount, nil } -func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { +func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error { if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { event := activity.AccountPeerInactivityExpirationEnabled if !newSettings.PeerInactivityExpirationEnabled { event = activity.AccountPeerInactivityExpirationDisabled am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) } else { - am.checkAndSchedulePeerInactivityExpiration(ctx, account) + am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } am.StoreEvent(ctx, userID, accountID, accountID, event, nil) } if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) - am.checkAndSchedulePeerInactivityExpiration(ctx, account) + am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } return nil @@ -1207,73 +1208,64 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context. func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + expiredPeers, err := am.getExpiredPeers(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed getting account %s expiring peers", accountID) - return account.GetNextPeerExpiration() + return 0, false } - expiredPeers := account.GetExpiredPeers() var peerIDs []string for _, peer := range expiredPeers { peerIDs = append(peerIDs, peer.ID) } - log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) + log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), accountID) - if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { - log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", account.Id) - return account.GetNextPeerExpiration() + if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil { + log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", accountID) + return 0, false } - return account.GetNextPeerExpiration() + return am.getNextPeerExpiration(ctx, accountID) } } -func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *Account) { - am.peerLoginExpiry.Cancel(ctx, []string{account.Id}) - if nextRun, ok := account.GetNextPeerExpiration(); ok { - go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id)) +func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) { + am.peerLoginExpiry.Cancel(ctx, []string{accountID}) + if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok { + go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID)) } } // peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + inactivePeers, err := am.getInactivePeers(ctx, accountID) if err != nil { - log.Errorf("failed getting account %s expiring peers", accountID) - return account.GetNextInactivePeerExpiration() + log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID) + return 0, false } - expiredPeers := account.GetInactivePeers() var peerIDs []string - for _, peer := range expiredPeers { + for _, peer := range inactivePeers { peerIDs = append(peerIDs, peer.ID) } - log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) + log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), accountID) - if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { - log.Errorf("failed updating account peers while expiring peers for account %s", account.Id) - return account.GetNextInactivePeerExpiration() + if err := am.expireAndUpdatePeers(ctx, accountID, inactivePeers); err != nil { + log.Errorf("failed updating account peers while expiring peers for account %s", accountID) + return 0, false } - return account.GetNextInactivePeerExpiration() + return am.getNextInactivePeerExpiration(ctx, accountID) } } // checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions -func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) { - am.peerInactivityExpiry.Cancel(ctx, []string{account.Id}) - if nextRun, ok := account.GetNextInactivePeerExpiration(); ok { - go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id)) +func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, accountID string) { + am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) + if nextRun, ok := am.getNextInactivePeerExpiration(ctx, accountID); ok { + go am.peerInactivityExpiry.Schedule(ctx, nextRun, accountID, am.peerInactivityExpirationJob(ctx, accountID)) } } @@ -1409,7 +1401,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI return "", status.Errorf(status.NotFound, "no valid userID provided") } - accountID, err := am.Store.GetAccountIDByUserID(userID) + accountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) @@ -2188,7 +2180,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return "", err } - userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err @@ -2235,7 +2227,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont } func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { - userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err @@ -2292,17 +2284,12 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) defer peerUnlock() - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, nil, nil, status.NewGetAccountError(err) - } - - peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account) + peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err) } - err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account) + err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } @@ -2316,12 +2303,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) defer peerUnlock() - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return status.NewGetAccountError(err) - } - - err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account) + err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } @@ -2339,12 +2321,7 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st unlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - _, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, account) + _, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID) if err != nil { return mapError(ctx, err) } diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 0c70b702a01..1692507dad6 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -4,6 +4,8 @@ import ( "context" "errors" + nbgroup "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" @@ -73,6 +75,39 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } -func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) { - return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) +func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { + var err error + var groups []*nbgroup.Group + var peers []*nbpeer.Peer + var settings *Settings + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + peers, err = transaction.GetAccountPeers(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + settings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) + return err + }) + if err != nil { + return nil, err + } + + groupsMap := make(map[string]*nbgroup.Group, len(groups)) + for _, group := range groups { + groupsMap[group.ID] = group + } + + peersMap := make(map[string]*nbpeer.Peer, len(peers)) + for _, peer := range peers { + peersMap[peer.ID] = peer + } + + return am.integratedPeerValidator.GetValidatedPeers(accountID, groupsMap, peersMap, settings.Extra) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 46a4fbc1faf..e1a84b4f9c8 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -47,6 +47,7 @@ type MockAccountManager struct { DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error + GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*group.Group, error) DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) @@ -90,7 +91,7 @@ type MockAccountManager struct { GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) - SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + SyncPeerFunc func(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error GetAllConnectedPeersFunc func() (map[string]struct{}, error) HasConnectedChannelFunc func(peerID string) bool @@ -130,7 +131,12 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st panic("implement me") } -func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) { +func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { + account, err := am.GetAccountFunc(ctx, accountID) + if err != nil { + return nil, err + } + approvedPeers := make(map[string]struct{}) for id := range account.Peers { approvedPeers[id] = struct{}{} @@ -221,7 +227,7 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, } // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface -func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *server.Account) error { +func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error { if am.MarkPeerConnectedFunc != nil { return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP) } @@ -682,9 +688,9 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLo } // SyncPeer mocks SyncPeer of the AccountManager interface -func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { if am.SyncPeerFunc != nil { - return am.SyncPeerFunc(ctx, sync, account) + return am.SyncPeerFunc(ctx, sync, accountID) } return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") } @@ -831,3 +837,11 @@ func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) } return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented") } + +// GetPeerGroups mocks GetPeerGroups of the AccountManager interface +func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*group.Group, error) { + if am.GetPeerGroupsFunc != nil { + return am.GetPeerGroupsFunc(ctx, accountID, peerID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented") +} diff --git a/management/server/peer.go b/management/server/peer.go index a941f404fc4..ba79a5b4808 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -11,8 +11,10 @@ import ( "sync" "time" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/rs/xid" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/posture" @@ -53,43 +55,55 @@ type PeerLogin struct { // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // the current user is not an admin. func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } - approvedPeersMap, err := am.GetValidatedPeers(account) + if user.IsRegularUser() && settings.RegularUsersViewBlocked { + return []*nbpeer.Peer{}, nil + } + + accountPeers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } + peers := make([]*nbpeer.Peer, 0) peersMap := make(map[string]*nbpeer.Peer) - regularUser := !user.HasAdminPower() && !user.IsServiceUser - - if regularUser && account.Settings.RegularUsersViewBlocked { - return peers, nil - } - - for _, peer := range account.Peers { - if regularUser && user.Id != peer.UserID { + for _, peer := range accountPeers { + if user.IsRegularUser() && user.Id != peer.UserID { // only display peers that belong to the current user if the current user is not an admin continue } - p := peer.Copy() - peers = append(peers, p) - peersMap[peer.ID] = p + peers = append(peers, peer) + peersMap[peer.ID] = peer } - if !regularUser { + if user.IsAdminOrServiceUser() { return peers, nil } + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, err + } + + approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID) + if err != nil { + return nil, err + } + // fetch all the peers that have access to the user's peers for _, peer := range peers { aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap) @@ -98,48 +112,46 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID } } - peers = make([]*nbpeer.Peer, 0, len(peersMap)) - for _, peer := range peersMap { - peers = append(peers, peer) - } - - return peers, nil + return maps.Values(peersMap), nil } // MarkPeerConnected marks peer as connected (true) or disconnected (false) -func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error { - peer, err := account.FindPeerByPubKey(peerPubKey) +func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error { + peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peerPubKey) if err != nil { - return fmt.Errorf("failed to find peer by pub key: %w", err) + return err } - expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account) + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { - return fmt.Errorf("failed to update peer status and location: %w", err) + return err } - log.WithContext(ctx).Debugf("mark peer %s connected: %t", peer.ID, connected) + expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, accountID) + if err != nil { + return err + } if peer.AddedWithSSOLogin() { - if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(ctx, account) + if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { + am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } - if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { - am.checkAndSchedulePeerInactivityExpiration(ctx, account) + if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { + am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } } if expired { // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. - am.updateAccountPeers(ctx, account.Id) + am.updateAccountPeers(ctx, accountID) } return nil } -func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) { +func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus newStatus.LastSeen = time.Now().UTC() @@ -159,18 +171,16 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context peer.Location.CountryCode = location.Country.ISOCode peer.Location.CityName = location.City.Names.En peer.Location.GeoNameID = location.City.GeonameID - err = am.Store.SavePeerLocation(account.Id, peer) + err = am.Store.SavePeerLocation(ctx, LockingStrengthUpdate, accountID, peer) if err != nil { log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) } } } - account.UpdatePeer(peer) - - err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) + err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *newStatus) if err != nil { - return false, fmt.Errorf("failed to save peer status: %w", err) + return false, err } return oldStatus.LoginExpired, nil @@ -181,37 +191,51 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - peer := account.GetPeer(update.ID) - if peer == nil { - return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, update.ID) + if err != nil { + return nil, err + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, update.ID) + if err != nil { + return nil, err } var requiresPeerUpdates bool - update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra) if err != nil { return nil, err } + var sshChanged, peerLabelChanged, loginExpirationChanged, inactivityExpirationChanged bool + if peer.SSHEnabled != update.SSHEnabled { peer.SSHEnabled = update.SSHEnabled - event := activity.PeerSSHEnabled - if !update.SSHEnabled { - event = activity.PeerSSHDisabled - } - am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + sshChanged = true } - peerLabelUpdated := peer.Name != update.Name - - if peerLabelUpdated { + if peer.Name != update.Name { peer.Name = update.Name + peerLabelChanged = true - existingLabels := account.getPeerDNSLabels() + existingLabels, err := am.getPeerDNSLabels(ctx, accountID) + if err != nil { + return nil, err + } newLabel, err := getPeerHostLabel(peer.Name, existingLabels) if err != nil { @@ -219,108 +243,69 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } peer.DNSLabel = newLabel - - am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain())) } if peer.LoginExpirationEnabled != update.LoginExpirationEnabled { - if !peer.AddedWithSSOLogin() { return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated") } - peer.LoginExpirationEnabled = update.LoginExpirationEnabled - - event := activity.PeerLoginExpirationEnabled - if !update.LoginExpirationEnabled { - event = activity.PeerLoginExpirationDisabled - } - am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) - - if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(ctx, account) - } + loginExpirationChanged = true } if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled { - if !peer.AddedWithSSOLogin() { - return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated") + return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the inactivity expiration can't be updated") } - peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled - - event := activity.PeerInactivityExpirationEnabled - if !update.InactivityExpirationEnabled { - event = activity.PeerInactivityExpirationDisabled - } - am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) - - if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { - am.checkAndSchedulePeerInactivityExpiration(ctx, account) - } + inactivityExpirationChanged = true } - account.UpdatePeer(peer) - - err = am.Store.SaveAccount(ctx, account) - if err != nil { + if err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil { return nil, err } - if peerLabelUpdated || requiresPeerUpdates { - am.updateAccountPeers(ctx, accountID) + if sshChanged { + event := activity.PeerSSHEnabled + if !peer.SSHEnabled { + event = activity.PeerSSHDisabled + } + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) } - return peer, nil -} - -// deletePeers will delete all specified peers and send updates to the remote peers. Don't call without acquiring account lock -func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Account, peerIDs []string, userID string) error { + if peerLabelChanged { + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain())) + } - // the first loop is needed to ensure all peers present under the account before modifying, otherwise - // we might have some inconsistencies - peers := make([]*nbpeer.Peer, 0, len(peerIDs)) - for _, peerID := range peerIDs { + if loginExpirationChanged { + event := activity.PeerLoginExpirationEnabled + if !peer.LoginExpirationEnabled { + event = activity.PeerLoginExpirationDisabled + } + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) - peer := account.GetPeer(peerID) - if peer == nil { - return status.Errorf(status.NotFound, "peer %s not found", peerID) + if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { + am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } - peers = append(peers, peer) } - // the 2nd loop performs the actual modification - for _, peer := range peers { + if inactivityExpirationChanged { + event := activity.PeerInactivityExpirationEnabled + if !peer.InactivityExpirationEnabled { + event = activity.PeerInactivityExpirationDisabled + } + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) - err := am.integratedPeerValidator.PeerDeleted(ctx, account.Id, peer.ID) - if err != nil { - return err + if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { + am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } + } - account.DeletePeer(peer.ID) - am.peersUpdateManager.SendUpdate(ctx, peer.ID, - &UpdateMessage{ - Update: &proto.SyncResponse{ - // fill those field for backward compatibility - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - // new field - NetworkMap: &proto.NetworkMap{ - Serial: account.Network.CurrentSerial(), - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - FirewallRules: []*proto.FirewallRule{}, - FirewallRulesIsEmpty: true, - }, - }, - NetworkMap: &NetworkMap{}, - }) - am.peersUpdateManager.CloseChannel(ctx, peer.ID) - am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) + if peerLabelChanged || requiresPeerUpdates { + am.updateAccountPeers(ctx, accountID) } - return nil + return peer, nil } // DeletePeer removes peer from the account by its IP @@ -328,24 +313,30 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, peerID) if err != nil { return err } - updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID) - if err != nil { - return err - } + var peer *nbpeer.Peer + var addPeerRemovedEvents []func() - err = am.deletePeers(ctx, account, []string{peerID}, userID) - if err != nil { - return err - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + peer, err = transaction.GetPeerByID(ctx, LockingStrengthUpdate, accountID, peerID) + if err != nil { + return err + } - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return err + addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) + if err != nil { + return err + } + + return transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) + }) + + for _, addPeerRemovedEvent := range addPeerRemovedEvents { + addPeerRemovedEvent() } if updateAccountPeers { @@ -411,7 +402,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s addedByUser := false if len(userID) > 0 { addedByUser = true - accountID, err = am.Store.GetAccountIDByUserID(userID) + accountID, err = am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID) } else { accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey) } @@ -442,12 +433,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } var newPeer *nbpeer.Peer - var groupsToAdd []string err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { var setupKeyID string var setupKeyName string var ephemeral bool + var groupsToAdd []string if addedByUser { user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID) if err != nil { @@ -590,39 +581,16 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s unlock() unlock = nil - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, status.NewGetAccountError(err) - } - - allGroup, err := account.GetGroupAll() - if err != nil { - return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err) - } - groupsToAdd = append(groupsToAdd, allGroup.ID) - - newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd) + updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, newPeer.ID) if err != nil { return nil, nil, nil, err } - if newGroupsAffectsPeers { + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - approvedPeersMap, err := am.GetValidatedPeers(account) - if err != nil { - return nil, nil, nil, err - } - - postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID) - if err != nil { - return nil, nil, nil, err - } - - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) - return newPeer, networkMap, postureChecks, nil + return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) } func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) { @@ -645,16 +613,16 @@ func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, acc } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible -func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { - peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) +func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { + peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, sync.WireGuardPubKey) if err != nil { return nil, nil, nil, status.NewPeerNotRegisteredError() } if peer.UserID != "" { - user, err := account.FindUser(peer.UserID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID) if err != nil { - return nil, nil, nil, fmt.Errorf("failed to get user: %w", err) + return nil, nil, nil, err } err = checkIfPeerOwnerIsBlocked(peer, user) @@ -663,52 +631,38 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } } - if peerLoginExpired(ctx, peer, account.Settings) { - return nil, nil, nil, status.NewPeerLoginExpiredError() + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, nil, nil, err } - updated := peer.UpdateMetaIfNew(sync.Meta) - if updated { - err = am.Store.SavePeer(ctx, account.Id, peer) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err) - } - - if sync.UpdateAccountPeers { - am.updateAccountPeers(ctx, account.Id) - } + if peerLoginExpired(ctx, peer, settings) { + return nil, nil, nil, status.NewPeerLoginExpiredError() } - peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, peer.ID) if err != nil { - return nil, nil, nil, fmt.Errorf("failed to validate peer: %w", err) - } - - var postureChecks []*posture.Checks - - if peerNotValid { - emptyMap := &NetworkMap{ - Network: account.Network.Copy(), - } - return peer, emptyMap, postureChecks, nil + return nil, nil, nil, err } - if isStatusChanged { - am.updateAccountPeers(ctx, account.Id) + peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupList, settings.Extra) + if err != nil { + return nil, nil, nil, err } - validPeersMap, err := am.GetValidatedPeers(account) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err) + updated := peer.UpdateMetaIfNew(sync.Meta) + if updated { + err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) + if err != nil { + return nil, nil, nil, err + } } - postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) - if err != nil { - return nil, nil, nil, err + if isStatusChanged || (updated && sync.UpdateAccountPeers) { + am.updateAccountPeers(ctx, accountID) } - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil + return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) } // LoginPeer logs in or registers a peer. @@ -814,7 +768,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } if shouldStorePeer { - err = am.Store.SavePeer(ctx, accountID, peer) + err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) if err != nil { return nil, nil, nil, err } @@ -823,16 +777,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) unlockPeer() unlockPeer = nil - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, err - } - if updateRemotePeers || isStatusChanged { am.updateAccountPeers(ctx, accountID) } - return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer) + return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) } // checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO @@ -864,22 +813,30 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } -func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *Account, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { - var postureChecks []*posture.Checks - +func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { if isRequiresApproval { + network, err := am.Store.GetAccountNetwork(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, nil, nil, err + } + emptyMap := &NetworkMap{ - Network: account.Network.Copy(), + Network: network.Copy(), } return peer, emptyMap, nil, nil } - approvedPeersMap, err := am.GetValidatedPeers(account) + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { return nil, nil, nil, err } - postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) + approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id) + if err != nil { + return nil, nil, nil, err + } + + postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, peer.ID) if err != nil { return nil, nil, nil, err } @@ -896,7 +853,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us // If peer was expired before and if it reached this point, it is re-authenticated. // UserID is present, meaning that JWT validation passed successfully in the API layer. peer = peer.UpdateLastLogin() - err = am.Store.SavePeer(ctx, peer.AccountID, peer) + err = am.Store.SavePeer(ctx, LockingStrengthUpdate, peer.AccountID, peer) if err != nil { return err } @@ -943,41 +900,47 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings // GetPeer for a given accountID, peerID and userID error if not found. func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } - if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + if user.IsRegularUser() && settings.RegularUsersViewBlocked { return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID) } - peer := account.GetPeer(peerID) - if peer == nil { - return nil, status.Errorf(status.NotFound, "peer with %s not found under account %s", peerID, accountID) + peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + if err != nil { + return nil, err } // if admin or user owns this peer, return peer - if user.HasAdminPower() || user.IsServiceUser || peer.UserID == userID { + if user.IsAdminOrServiceUser() || peer.UserID == userID { return peer, nil } // it is also possible that user doesn't own the peer but some of his peers have access to it, // this is a valid case, show the peer as well. - userPeers, err := account.FindUserPeers(userID) + userPeers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID) + if err != nil { + return nil, err + } + + approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID) if err != nil { return nil, err } - approvedPeersMap, err := am.GetValidatedPeers(account) + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { return nil, err } @@ -1006,12 +969,13 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err) + log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err) return } + peers := account.GetPeers() - approvedPeersMap, err := am.GetValidatedPeers(account) + approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id) if err != nil { log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err) return @@ -1037,7 +1001,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID) if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err) + log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err) return } @@ -1050,22 +1014,240 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account wg.Wait() } -func ConvertSliceToMap(existingLabels []string) map[string]struct{} { - labelMap := make(map[string]struct{}, len(existingLabels)) - for _, label := range existingLabels { - labelMap[label] = struct{}{} +// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are connected. +func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { + peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err) + return 0, false } - return labelMap + + if len(peersWithExpiry) == 0 { + return 0, false + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account settings: %v", err) + return 0, false + } + + var nextExpiry *time.Duration + for _, peer := range peersWithExpiry { + // consider only connected peers because others will require login on connecting to the management server + if peer.Status.LoginExpired || !peer.Status.Connected { + continue + } + _, duration := peer.LoginExpired(settings.PeerLoginExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are not connected. +func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { + peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err) + return 0, false + } + + if len(peersWithInactivity) == 0 { + return 0, false + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account settings: %v", err) + return 0, false + } + + var nextExpiry *time.Duration + for _, peer := range peersWithInactivity { + if peer.Status.LoginExpired || peer.Status.Connected { + continue + } + _, duration := peer.SessionExpired(settings.PeerInactivityExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// getExpiredPeers returns peers that have been expired. +func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { + peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + var peers []*nbpeer.Peer + for _, peer := range peersWithExpiry { + expired, _ := peer.LoginExpired(settings.PeerLoginExpiration) + if expired { + peers = append(peers, peer) + } + } + + return peers, nil +} + +// getInactivePeers returns peers that have been expired by inactivity +func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { + peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + var peers []*nbpeer.Peer + for _, inactivePeer := range peersWithInactivity { + inactive, _ := inactivePeer.SessionExpired(settings.PeerInactivityExpiration) + if inactive { + peers = append(peers, inactivePeer) + } + } + + return peers, nil +} + +// GetPeerGroups returns groups that the peer is part of. +func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) { + groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + peerGroups := make([]*nbgroup.Group, 0) + for _, group := range groups { + if slices.Contains(group.Peers, peerID) { + peerGroups = append(peerGroups, group) + } + } + + return peerGroups, nil +} + +// getPeerGroupIDs returns the IDs of the groups that the peer is part of. +func (am *DefaultAccountManager) getPeerGroupIDs(ctx context.Context, accountID string, peerID string) ([]string, error) { + groups, err := am.GetPeerGroups(ctx, accountID, peerID) + if err != nil { + return nil, err + } + + groupIDs := make([]string, 0, len(groups)) + for _, group := range groups { + groupIDs = append(groupIDs, group.ID) + } + + return groupIDs, err +} + +func (am *DefaultAccountManager) getPeerDNSLabels(ctx context.Context, accountID string) (lookupMap, error) { + dnsLabels, err := am.Store.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + existingLabels := make(lookupMap) + for _, label := range dnsLabels { + existingLabels[label] = struct{}{} + } + return existingLabels, nil } // IsPeerInActiveGroup checks if the given peer is part of a group that is used // in an active DNS, route, or ACL configuration. -func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) { - peerGroupIDs := make([]string, 0) - for _, group := range account.Groups { - if slices.Contains(group.Peers, peerID) { - peerGroupIDs = append(peerGroupIDs, group.ID) +func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, accountID, peerID string) (bool, error) { + peerGroupIDs, err := am.getPeerGroupIDs(ctx, accountID, peerID) + if err != nil { + return false, err + } + return areGroupChangesAffectPeers(ctx, am.Store, accountID, peerGroupIDs) // TODO: use transaction +} + +// deletePeers deletes all specified peers and sends updates to the remote peers. +// Returns a slice of functions to save events after successful peer deletion. +func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { + var peerDeletedEvents []func() + + for _, peer := range peers { + if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID); err != nil { + return nil, err + } + + network, err := transaction.GetAccountNetwork(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err } + + if err = transaction.DeletePeer(ctx, LockingStrengthUpdate, accountID, peer.ID); err != nil { + return nil, err + } + + am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{ + Update: &proto.SyncResponse{ + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + NetworkMap: &proto.NetworkMap{ + Serial: network.CurrentSerial(), + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + FirewallRules: []*proto.FirewallRule{}, + FirewallRulesIsEmpty: true, + }, + }, + NetworkMap: &NetworkMap{}, + }) + am.peersUpdateManager.CloseChannel(ctx, peer.ID) + peerDeletedEvents = append(peerDeletedEvents, func() { + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) + }) + } + + return peerDeletedEvents, nil +} + +func ConvertSliceToMap(existingLabels []string) map[string]struct{} { + labelMap := make(map[string]struct{}, len(existingLabels)) + for _, label := range existingLabels { + labelMap[label] = struct{}{} } - return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs) + return labelMap } diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 34d7918446b..146af886178 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -44,7 +44,7 @@ type Peer struct { // CreatedAt records the time the peer was created CreatedAt time.Time // Indicate ephemeral peer attribute - Ephemeral bool + Ephemeral bool `gorm:"index"` // Geo location based on connection IP Location Location `gorm:"embedded;embeddedPrefix:location_"` } diff --git a/management/server/user.go b/management/server/user.go index 74062112af6..823eaa311c1 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -487,6 +487,10 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account } delete(account.Users, targetUserID) + if updateAccountPeers { + account.Network.IncSerial() + } + err = am.Store.SaveAccount(ctx, account) if err != nil { return err @@ -511,12 +515,16 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU return false, nil } - peerIDs := make([]string, 0, len(peers)) - for _, peer := range peers { - peerIDs = append(peerIDs, peer.ID) + eventsToStore, err := deletePeers(ctx, am, am.Store, account.Id, initiatorUserID, peers) + if err != nil { + return false, err } - return hadPeers, am.deletePeers(ctx, account, peerIDs, initiatorUserID) + for _, storeEvent := range eventsToStore { + storeEvent() + } + + return hadPeers, nil } // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. @@ -823,7 +831,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } if len(expiredPeers) > 0 { - if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { + if err := am.expireAndUpdatePeers(ctx, account.Id, expiredPeers); err != nil { log.WithContext(ctx).Errorf("failed update expired peers: %s", err) return nil, err } @@ -1104,7 +1112,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun } // expireAndUpdatePeers expires all peers of the given user and updates them in the account -func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error { +func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error { var peerIDs []string for _, peer := range peers { // nolint:staticcheck @@ -1115,16 +1123,13 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou } peerIDs = append(peerIDs, peer.ID) peer.MarkLoginExpired(true) - account.UpdatePeer(peer) - if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil { - return fmt.Errorf("failed saving peer status for peer %s: %s", peer.ID, err) - } - - log.WithContext(ctx).Tracef("mark peer %s login expired", peer.ID) + if err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *peer.Status); err != nil { + return err + } am.StoreEvent( ctx, - peer.UserID, peer.ID, account.Id, + peer.UserID, peer.ID, accountID, activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()), ) } @@ -1132,7 +1137,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service am.peersUpdateManager.CloseChannels(ctx, peerIDs) - am.updateAccountPeers(ctx, account.Id) + am.updateAccountPeers(ctx, accountID) } return nil } @@ -1234,6 +1239,9 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account deletedUsersMeta[targetUserID] = meta } + if updateAccountPeers { + account.Network.IncSerial() + } err = am.Store.SaveAccount(ctx, account) if err != nil { return fmt.Errorf("failed to delete users: %w", err) From f6f7260897ac9b84e348fc34e15fe07fcb041586 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 14 Nov 2024 19:34:05 +0300 Subject: [PATCH 41/60] Fix tests Signed-off-by: bcmmbaga --- management/server/account_test.go | 19 ++++++------------- management/server/sql_store_test.go | 20 ++++++++++---------- 2 files changed, 16 insertions(+), 23 deletions(-) diff --git a/management/server/account_test.go b/management/server/account_test.go index c8c2d59410b..a13b89f3354 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1472,7 +1472,6 @@ func TestAccountManager_DeletePeer(t *testing.T) { return } - userID := "account_creator" account, err := createAccount(manager, "test_account", userID, "netbird.cloud") if err != nil { t.Fatal(err) @@ -1501,7 +1500,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { return } - err = manager.DeletePeer(context.Background(), account.Id, peerKey, userID) + err = manager.DeletePeer(context.Background(), account.Id, peer.ID, userID) if err != nil { return } @@ -1523,7 +1522,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { assert.Equal(t, peer.Name, ev.Meta["name"]) assert.Equal(t, peer.FQDN(account.Domain), ev.Meta["fqdn"]) assert.Equal(t, userID, ev.InitiatorID) - assert.Equal(t, peer.IP.String(), ev.TargetID) + assert.Equal(t, peer.ID, ev.TargetID) assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) } @@ -1853,13 +1852,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") - account, err := manager.Store.GetAccount(context.Background(), accountID) - require.NoError(t, err, "unable to get the account") - - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) require.NoError(t, err, "unable to mark peer connected") - account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + account, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1927,11 +1923,8 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") - account, err := manager.Store.GetAccount(context.Background(), accountID) - require.NoError(t, err, "unable to get the account") - // when we mark peer as connected, the peer login expiration routine should trigger - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) require.NoError(t, err, "unable to mark peer connected") failed := waitTimeout(wg, time.Second) @@ -1962,7 +1955,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) require.NoError(t, err, "unable to mark peer connected") wg := &sync.WaitGroup{} diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index b568b7fe03a..7f36eb50612 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -400,7 +400,7 @@ func TestSqlite_SavePeer(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } ctx := context.Background() - err = store.SavePeer(ctx, account.Id, peer) + err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, peer) assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -416,7 +416,7 @@ func TestSqlite_SavePeer(t *testing.T) { updatedPeer.Status.Connected = false updatedPeer.Meta.Hostname = "updatedpeer" - err = store.SavePeer(ctx, account.Id, updatedPeer) + err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, updatedPeer) require.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) @@ -442,7 +442,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { // save status of non-existing peer newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()} - err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) + err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus) assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -461,7 +461,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { err = store.SaveAccount(context.Background(), account) require.NoError(t, err) - err = store.SavePeerStatus(account.Id, "testpeer", newStatus) + err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus) require.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) @@ -472,7 +472,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { newStatus.Connected = true - err = store.SavePeerStatus(account.Id, "testpeer", newStatus) + err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus) require.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) @@ -507,7 +507,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { Meta: nbpeer.PeerSystemMeta{}, } // error is expected as peer is not in store yet - err = store.SavePeerLocation(account.Id, peer) + err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, peer) assert.Error(t, err) account.Peers[peer.ID] = peer @@ -519,7 +519,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { peer.Location.CityName = "Berlin" peer.Location.GeoNameID = 2950159 - err = store.SavePeerLocation(account.Id, account.Peers[peer.ID]) + err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, account.Peers[peer.ID]) assert.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) @@ -529,7 +529,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { assert.Equal(t, peer.Location, actual) peer.ID = "non-existing-peer" - err = store.SavePeerLocation(account.Id, peer) + err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, peer) assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -908,7 +908,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { // save status of non-existing peer newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()} - err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) + err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus) assert.Error(t, err) // save new status of existing peer @@ -924,7 +924,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { err = store.SaveAccount(context.Background(), account) require.NoError(t, err) - err = store.SavePeerStatus(account.Id, "testpeer", newStatus) + err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus) require.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) From 4ef3890bf757f149da7bce3588f62bfbf85b9d8c Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 15 Nov 2024 17:48:00 +0300 Subject: [PATCH 42/60] Fix typo Signed-off-by: bcmmbaga --- management/server/dns.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/dns.go b/management/server/dns.go index be7caea4eff..8df211b0b0b 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -161,7 +161,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID return nil } -// prepareGroupEvents prepares a list of event functions to be stored. +// prepareDNSSettingsEvents prepares a list of event functions to be stored. func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() { var eventsToStore []func() From 51c1ec283cb9d9dacc9ec18ab6d98b64d954d362 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 15 Nov 2024 19:34:57 +0300 Subject: [PATCH 43/60] Add locks and remove log Signed-off-by: bcmmbaga --- management/server/posture_checks.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index d7b5a79a23b..59e726c4165 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -9,7 +9,6 @@ import ( "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" "github.com/rs/xid" - log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" ) @@ -32,6 +31,9 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID // SavePostureChecks saves a posture check. func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err @@ -85,6 +87,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI // DeletePostureChecks deletes a posture check by ID. func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err @@ -267,7 +272,6 @@ func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountI for _, sourceGroup := range rule.Sources { group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup) if err != nil { - log.WithContext(ctx).Debugf("failed to check peer in policy source group: %v", err) return false, fmt.Errorf("failed to check peer in policy source group: %w", err) } From a61e9da3e9fbf86274807b79adb5c290ca19f09e Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 18 Nov 2024 15:06:25 +0300 Subject: [PATCH 44/60] run peer ops in transaction Signed-off-by: bcmmbaga --- management/server/account.go | 6 +- management/server/peer.go | 344 +++++++++++++++++------------- management/server/sql_store.go | 14 ++ management/server/status/error.go | 5 + management/server/store.go | 1 + management/server/user.go | 4 + 6 files changed, 222 insertions(+), 152 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 4222179d95f..726ef01732e 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2390,8 +2390,8 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey) } -func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID) +func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction Store, peer *nbpeer.Peer, settings *Settings) (bool, error) { + user, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID) if err != nil { return false, err } @@ -2402,7 +2402,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee } if peerLoginExpired(ctx, peer, settings) { - err = am.handleExpiredPeer(ctx, user, peer) + err = am.handleExpiredPeer(ctx, transaction, user, peer) if err != nil { return false, err } diff --git a/management/server/peer.go b/management/server/peer.go index ba79a5b4808..a93f62a1d73 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -11,6 +11,7 @@ import ( "sync" "time" + "github.com/netbirdio/netbird/management/server/geolocation" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/rs/xid" log "github.com/sirupsen/logrus" @@ -117,17 +118,25 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID // MarkPeerConnected marks peer as connected (true) or disconnected (false) func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error { - peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peerPubKey) - if err != nil { - return err - } + var peer *nbpeer.Peer + var settings *Settings + var expired bool + var err error - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) - if err != nil { - return err - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, peerPubKey) + if err != nil { + return err + } + + settings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } - expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, accountID) + expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID) + return err + }) if err != nil { return err } @@ -151,7 +160,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return nil } -func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) { +func updatePeerStatusAndLocation(ctx context.Context, geo *geolocation.Geolocation, transaction Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus newStatus.LastSeen = time.Now().UTC() @@ -162,8 +171,8 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context } peer.Status = newStatus - if am.geo != nil && realIP != nil { - location, err := am.geo.Lookup(realIP) + if geo != nil && realIP != nil { + location, err := geo.Lookup(realIP) if err != nil { log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err) } else { @@ -171,14 +180,14 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context peer.Location.CountryCode = location.Country.ISOCode peer.Location.CityName = location.City.Names.En peer.Location.GeoNameID = location.City.GeonameID - err = am.Store.SavePeerLocation(ctx, LockingStrengthUpdate, accountID, peer) + err = transaction.SavePeerLocation(ctx, LockingStrengthUpdate, accountID, peer) if err != nil { log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) } } } - err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *newStatus) + err := transaction.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *newStatus) if err != nil { return false, err } @@ -200,23 +209,49 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return nil, status.NewUserNotPartOfAccountError() } - peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, update.ID) - if err != nil { - return nil, err - } + var peer *nbpeer.Peer + var settings *Settings + var peerGroupList []string + var requiresPeerUpdates bool + var newLabel string - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) - if err != nil { - return nil, err - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + peer, err = transaction.GetPeerByID(ctx, LockingStrengthUpdate, accountID, update.ID) + if err != nil { + return err + } - peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, update.ID) - if err != nil { - return nil, err - } + settings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } - var requiresPeerUpdates bool - update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra) + peerGroupList, err = getPeerGroupIDs(ctx, am.Store, accountID, update.ID) + if err != nil { + return err + } + + update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra) + if err != nil { + return err + } + + if peer.Name != update.Name { + existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID) + if err != nil { + return err + } + + newLabel, err = getPeerHostLabel(update.Name, existingLabels) + if err != nil { + return err + } + + peer.DNSLabel = newLabel + } + + return transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) + }) if err != nil { return nil, err } @@ -231,18 +266,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if peer.Name != update.Name { peer.Name = update.Name peerLabelChanged = true - - existingLabels, err := am.getPeerDNSLabels(ctx, accountID) - if err != nil { - return nil, err - } - - newLabel, err := getPeerHostLabel(peer.Name, existingLabels) - if err != nil { - return nil, err - } - - peer.DNSLabel = newLabel } if peer.LoginExpirationEnabled != update.LoginExpirationEnabled { @@ -261,10 +284,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user inactivityExpirationChanged = true } - if err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil { - return nil, err - } - if sshChanged { event := activity.PeerSSHEnabled if !peer.SSHEnabled { @@ -313,13 +332,18 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, peerID) + peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, LockingStrengthShare, peerID) if err != nil { return err } + if peerAccountID != accountID { + return status.NewPeerNotPartOfAccountError() + } + var peer *nbpeer.Peer - var addPeerRemovedEvents []func() + var updateAccountPeers bool + var eventsToStore []func() err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { peer, err = transaction.GetPeerByID(ctx, LockingStrengthUpdate, accountID, peerID) @@ -327,16 +351,21 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) + updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID) if err != nil { return err } - return transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) + return err }) - for _, addPeerRemovedEvent := range addPeerRemovedEvents { - addPeerRemovedEvent() + for _, storeEvent := range eventsToStore { + storeEvent() } if updateAccountPeers { @@ -433,6 +462,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } var newPeer *nbpeer.Peer + var updateAccountPeers bool err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { var setupKeyID string @@ -480,7 +510,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return fmt.Errorf("failed to get free DNS label: %w", err) } - freeIP, err := am.getFreeIP(ctx, transaction, accountID) + freeIP, err := getFreeIP(ctx, transaction, accountID) if err != nil { return fmt.Errorf("failed to get free IP: %w", err) } @@ -564,6 +594,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } } + updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, newPeer.ID) + if err != nil { + return err + } + log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID) return nil }) @@ -581,11 +616,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s unlock() unlock = nil - updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, newPeer.ID) - if err != nil { - return nil, nil, nil, err - } - if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } @@ -593,13 +623,13 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) } -func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) { - takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID) +func getFreeIP(ctx context.Context, transaction Store, accountID string) (net.IP, error) { + takenIps, err := transaction.GetTakenIPs(ctx, LockingStrengthShare, accountID) if err != nil { return nil, fmt.Errorf("failed to get taken IPs: %w", err) } - network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID) + network, err := transaction.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID) if err != nil { return nil, fmt.Errorf("failed getting network: %w", err) } @@ -614,50 +644,61 @@ func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, acc // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { - peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, sync.WireGuardPubKey) - if err != nil { - return nil, nil, nil, status.NewPeerNotRegisteredError() - } + var peer *nbpeer.Peer + var peerNotValid bool + var isStatusChanged bool + var updated bool + var err error - if peer.UserID != "" { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, sync.WireGuardPubKey) if err != nil { - return nil, nil, nil, err + return status.NewPeerNotRegisteredError() } - err = checkIfPeerOwnerIsBlocked(peer, user) + if peer.UserID != "" { + user, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID) + if err != nil { + return err + } + + if err = checkIfPeerOwnerIsBlocked(peer, user); err != nil { + return err + } + } + + settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { - return nil, nil, nil, err + return err } - } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) - if err != nil { - return nil, nil, nil, err - } + if peerLoginExpired(ctx, peer, settings) { + return status.NewPeerLoginExpiredError() + } - if peerLoginExpired(ctx, peer, settings) { - return nil, nil, nil, status.NewPeerLoginExpiredError() - } + peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peer.ID) + if err != nil { + return err + } - peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, peer.ID) - if err != nil { - return nil, nil, nil, err - } + peerNotValid, isStatusChanged, err = am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra) + if err != nil { + return err + } - peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupList, settings.Extra) + updated = peer.UpdateMetaIfNew(sync.Meta) + if updated { + err = transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) + if err != nil { + return err + } + } + return nil + }) if err != nil { return nil, nil, nil, err } - updated := peer.UpdateMetaIfNew(sync.Meta) - if updated { - err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) - if err != nil { - return nil, nil, nil, err - } - } - if isStatusChanged || (updated && sync.UpdateAccountPeers) { am.updateAccountPeers(ctx, accountID) } @@ -707,73 +748,73 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } }() - peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey) - if err != nil { - return nil, nil, nil, err - } + var peer *nbpeer.Peer + var updateRemotePeers bool + var isRequiresApproval bool + var isStatusChanged bool - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) - if err != nil { - return nil, nil, nil, err - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey) + if err != nil { + return err + } + + settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + // this flag prevents unnecessary calls to the persistent store. + shouldStorePeer := false + + if login.UserID != "" { + if peer.UserID != login.UserID { + log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID) + return status.Errorf(status.Unauthenticated, "invalid user") + } - // this flag prevents unnecessary calls to the persistent store. - shouldStorePeer := false - updateRemotePeers := false + changed, err := am.handleUserPeer(ctx, transaction, peer, settings) + if err != nil { + return err + } - if login.UserID != "" { - if peer.UserID != login.UserID { - log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID) - return nil, nil, nil, status.Errorf(status.Unauthenticated, "invalid user") + if changed { + shouldStorePeer = true + updateRemotePeers = true + } } - changed, err := am.handleUserPeer(ctx, peer, settings) + peerGroupIDs, err := getPeerGroupIDs(ctx, am.Store, accountID, peer.ID) if err != nil { - return nil, nil, nil, err + return err } - if changed { + + isRequiresApproval, isStatusChanged, err = am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra) + if err != nil { + return err + } + + updated := peer.UpdateMetaIfNew(login.Meta) + if updated { shouldStorePeer = true - updateRemotePeers = true } - } - groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) - if err != nil { - return nil, nil, nil, err - } + if peer.SSHKey != login.SSHKey { + peer.SSHKey = login.SSHKey + shouldStorePeer = true + } - var grps []string - for _, group := range groups { - for _, id := range group.Peers { - if id == peer.ID { - grps = append(grps, group.ID) - break + if shouldStorePeer { + if err = transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil { + return err } } - } - isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, grps, settings.Extra) + return nil + }) if err != nil { return nil, nil, nil, err } - updated := peer.UpdateMetaIfNew(login.Meta) - if updated { - shouldStorePeer = true - } - - if peer.SSHKey != login.SSHKey { - peer.SSHKey = login.SSHKey - shouldStorePeer = true - } - - if shouldStorePeer { - err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) - if err != nil { - return nil, nil, nil, err - } - } - unlockPeer() unlockPeer = nil @@ -845,7 +886,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil } -func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error { +func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transaction Store, user *User, peer *nbpeer.Peer) error { err := checkAuth(ctx, user.Id, peer) if err != nil { return err @@ -853,12 +894,12 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us // If peer was expired before and if it reached this point, it is re-authenticated. // UserID is present, meaning that JWT validation passed successfully in the API layer. peer = peer.UpdateLastLogin() - err = am.Store.SavePeer(ctx, LockingStrengthUpdate, peer.AccountID, peer) + err = transaction.SavePeer(ctx, LockingStrengthUpdate, peer.AccountID, peer) if err != nil { return err } - err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin) + err = transaction.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin) if err != nil { return err } @@ -1149,7 +1190,12 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID // GetPeerGroups returns groups that the peer is part of. func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) { - groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) + return getPeerGroups(ctx, am.Store, accountID, peerID) +} + +// getPeerGroups returns the IDs of the groups that the peer is part of. +func getPeerGroups(ctx context.Context, transaction Store, accountID, peerID string) ([]*nbgroup.Group, error) { + groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } @@ -1165,8 +1211,8 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p } // getPeerGroupIDs returns the IDs of the groups that the peer is part of. -func (am *DefaultAccountManager) getPeerGroupIDs(ctx context.Context, accountID string, peerID string) ([]string, error) { - groups, err := am.GetPeerGroups(ctx, accountID, peerID) +func getPeerGroupIDs(ctx context.Context, transaction Store, accountID string, peerID string) ([]string, error) { + groups, err := getPeerGroups(ctx, transaction, accountID, peerID) if err != nil { return nil, err } @@ -1179,8 +1225,8 @@ func (am *DefaultAccountManager) getPeerGroupIDs(ctx context.Context, accountID return groupIDs, err } -func (am *DefaultAccountManager) getPeerDNSLabels(ctx context.Context, accountID string) (lookupMap, error) { - dnsLabels, err := am.Store.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID) +func getPeerDNSLabels(ctx context.Context, transaction Store, accountID string) (lookupMap, error) { + dnsLabels, err := transaction.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } @@ -1194,12 +1240,12 @@ func (am *DefaultAccountManager) getPeerDNSLabels(ctx context.Context, accountID // IsPeerInActiveGroup checks if the given peer is part of a group that is used // in an active DNS, route, or ACL configuration. -func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, accountID, peerID string) (bool, error) { - peerGroupIDs, err := am.getPeerGroupIDs(ctx, accountID, peerID) +func isPeerInActiveGroup(ctx context.Context, transaction Store, accountID, peerID string) (bool, error) { + peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID) if err != nil { return false, err } - return areGroupChangesAffectPeers(ctx, am.Store, accountID, peerGroupIDs) // TODO: use transaction + return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction } // deletePeers deletes all specified peers and sends updates to the remote peers. diff --git a/management/server/sql_store.go b/management/server/sql_store.go index b921ed47d3c..a672e7e6fc7 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -754,6 +754,20 @@ func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength Lockin return accountID, nil } +func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) { + var accountID string + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). + Select("account_id").Where(idQueryCondition, peerID).First(&accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return "", status.Errorf(status.NotFound, "peer %s account not found", peerID) + } + return "", status.NewGetAccountFromStoreError(result.Error) + } + + return accountID, nil +} + func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { var accountID string result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) diff --git a/management/server/status/error.go b/management/server/status/error.go index 59f436f5b19..505f874ad1f 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -86,6 +86,11 @@ func NewAccountNotFoundError(accountKey string) error { return Errorf(NotFound, "account not found: %s", accountKey) } +// NewPeerNotPartOfAccountError creates a new Error with PermissionDenied type for a peer not being part of an account +func NewPeerNotPartOfAccountError() error { + return Errorf(PermissionDenied, "peer is not part of this account") +} + // NewUserNotFoundError creates a new Error with NotFound type for a missing user func NewUserNotFoundError(userKey string) error { return Errorf(NotFound, "user not found: %s", userKey) diff --git a/management/server/store.go b/management/server/store.go index 9ecb9c1698f..94324fcc307 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -50,6 +50,7 @@ type Store interface { GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) + GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) diff --git a/management/server/user.go b/management/server/user.go index 823eaa311c1..45cd45d4966 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -524,6 +524,10 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU storeEvent() } + for _, peer := range peers { + account.DeletePeer(peer.ID) + } + return hadPeers, nil } From a2fb274b86641f1569e1b288ad0f8704c519e7a0 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 18 Nov 2024 15:09:30 +0300 Subject: [PATCH 45/60] remove duplicate store method Signed-off-by: bcmmbaga --- management/server/peer.go | 2 +- management/server/sql_store.go | 17 ----------------- management/server/store.go | 1 - 3 files changed, 1 insertion(+), 19 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index a93f62a1d73..e71587948da 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -1226,7 +1226,7 @@ func getPeerGroupIDs(ctx context.Context, transaction Store, accountID string, p } func getPeerDNSLabels(ctx context.Context, transaction Store, accountID string) (lookupMap, error) { - dnsLabels, err := transaction.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID) + dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index a672e7e6fc7..6e4f1d3960a 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1147,23 +1147,6 @@ func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStreng return peersMap, nil } -// GetAccountPeerDNSLabels retrieves all unique DNS labels for peers associated with a specified account. -func (s *SqlStore) GetAccountPeerDNSLabels(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { - var labels []string - - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). - Where(accountIDCondition, accountID).Pluck("dns_label", &labels) - if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "no peers found for the account") - } - log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting dns labels from store") - } - - return labels, nil -} - // GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user. func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { var peers []*nbpeer.Peer diff --git a/management/server/store.go b/management/server/store.go index 94324fcc307..5b48de37823 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -95,7 +95,6 @@ type Store interface { DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) - GetAccountPeerDNSLabels(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error From a2a49bdd47c363f8ef7e5ae9fc8d49061d10142b Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 18 Nov 2024 16:43:09 +0300 Subject: [PATCH 46/60] fix peer fields updated after save Signed-off-by: bcmmbaga --- management/server/peer.go | 62 +++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index e71587948da..8e368ec4edf 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -213,7 +213,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user var settings *Settings var peerGroupList []string var requiresPeerUpdates bool - var newLabel string + var peerLabelChanged bool + var sshChanged bool + var loginExpirationChanged bool + var inactivityExpirationChanged bool err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { peer, err = transaction.GetPeerByID(ctx, LockingStrengthUpdate, accountID, update.ID) @@ -226,7 +229,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return err } - peerGroupList, err = getPeerGroupIDs(ctx, am.Store, accountID, update.ID) + peerGroupList, err = getPeerGroupIDs(ctx, transaction, accountID, update.ID) if err != nil { return err } @@ -242,46 +245,41 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return err } - newLabel, err = getPeerHostLabel(update.Name, existingLabels) + newLabel, err := getPeerHostLabel(update.Name, existingLabels) if err != nil { return err } + peer.Name = update.Name peer.DNSLabel = newLabel + peerLabelChanged = true } - return transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) - }) - if err != nil { - return nil, err - } - - var sshChanged, peerLabelChanged, loginExpirationChanged, inactivityExpirationChanged bool - - if peer.SSHEnabled != update.SSHEnabled { - peer.SSHEnabled = update.SSHEnabled - sshChanged = true - } - - if peer.Name != update.Name { - peer.Name = update.Name - peerLabelChanged = true - } + if peer.SSHEnabled != update.SSHEnabled { + peer.SSHEnabled = update.SSHEnabled + sshChanged = true + } - if peer.LoginExpirationEnabled != update.LoginExpirationEnabled { - if !peer.AddedWithSSOLogin() { - return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated") + if peer.LoginExpirationEnabled != update.LoginExpirationEnabled { + if !peer.AddedWithSSOLogin() { + return status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated") + } + peer.LoginExpirationEnabled = update.LoginExpirationEnabled + loginExpirationChanged = true } - peer.LoginExpirationEnabled = update.LoginExpirationEnabled - loginExpirationChanged = true - } - if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled { - if !peer.AddedWithSSOLogin() { - return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the inactivity expiration can't be updated") + if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled { + if !peer.AddedWithSSOLogin() { + return status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the inactivity expiration can't be updated") + } + peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled + inactivityExpirationChanged = true } - peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled - inactivityExpirationChanged = true + + return transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) + }) + if err != nil { + return nil, err } if sshChanged { @@ -783,7 +781,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } } - peerGroupIDs, err := getPeerGroupIDs(ctx, am.Store, accountID, peer.ID) + peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peer.ID) if err != nil { return err } From 48edfa601f6a569bf805b0da21775643cd20403c Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 18 Nov 2024 16:43:19 +0300 Subject: [PATCH 47/60] add tests Signed-off-by: bcmmbaga --- management/server/management_proto_test.go | 2 +- management/server/sql_store_test.go | 151 ++++++++++++++++-- .../testdata/store_with_expired_peers.sql | 9 +- management/server/testdata/storev1.sql | 2 +- 4 files changed, 143 insertions(+), 21 deletions(-) diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index dc8765e197f..57ad968b3d7 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -246,7 +246,7 @@ func Test_SyncProtocol(t *testing.T) { t.Fatal("expecting SyncResponse to have non-nil NetworkMap") } - if len(networkMap.GetRemotePeers()) != 3 { + if len(networkMap.GetRemotePeers()) != 4 { t.Fatalf("expecting SyncResponse to have NetworkMap with 3 remote peers, got %d", len(networkMap.GetRemotePeers())) } diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 7f36eb50612..f08d4bac78d 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -378,11 +378,6 @@ func TestSqlite_GetAccount(t *testing.T) { } func TestSqlite_SavePeer(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -428,11 +423,6 @@ func TestSqlite_SavePeer(t *testing.T) { } func TestSqlite_SavePeerStatus(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -483,11 +473,6 @@ func TestSqlite_SavePeerStatus(t *testing.T) { } func TestSqlite_SavePeerLocation(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -2049,3 +2034,139 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) { require.Error(t, err) require.Nil(t, nsGroup) } + +func TestSqlStore_GetAccountPeers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "retrieve peers by existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 4, + }, + { + name: "non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "empty account ID", + accountID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } + +} + +func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "retrieve peers with expiration by existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + { + name: "non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "empty account ID", + accountID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "retrieve peers with inactivity by existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + { + name: "non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "empty account ID", + accountID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetAccountPeersWithInactivity(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetAllEphemeralPeers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/storev1.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + peers, err := store.GetAllEphemeralPeers(context.Background(), LockingStrengthShare) + require.NoError(t, err) + require.Len(t, peers, 1) + require.True(t, peers[0].Ephemeral) +} + +func TestSqlStore_DeletePeer(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + peerID := "csrnkiq7qv9d8aitqd50" + + err = store.DeletePeer(context.Background(), LockingStrengthUpdate, accountID, peerID) + require.NoError(t, err) + + peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) + require.Error(t, err) + require.Nil(t, peer) +} diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql index 100a6470f43..54b946b5ab7 100644 --- a/management/server/testdata/store_with_expired_peers.sql +++ b/management/server/testdata/store_with_expired_peers.sql @@ -1,6 +1,6 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`inactivity_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); @@ -27,9 +27,10 @@ CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 17:00:32.527528+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,3600000000000,0,0,0,'',NULL,NULL,NULL); INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); -INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); -INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); -INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/storev1.sql b/management/server/testdata/storev1.sql index 69194d62391..281fdac8a3b 100644 --- a/management/server/testdata/storev1.sql +++ b/management/server/testdata/storev1.sql @@ -34,6 +34,6 @@ INSERT INTO setup_keys VALUES('3504804807','google-oauth2|103201118415301331038' INSERT INTO peers VALUES('oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','auth0|61bf82ddeab084006aa1bccd','oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','EB51E9EB-A11F-4F6E-8E49-C982891B405A','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:13:11.244342541+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0); INSERT INTO peers VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|61bf82ddeab084006aa1bccd','xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:12:49.089339333+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0); INSERT INTO peers VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0); -INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0); +INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',1,'""','','',0); INSERT INTO installations VALUES(1,''); From ec6438e643c3ef0b0388c05c2fad5beed99fb4fe Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 18 Nov 2024 17:12:13 +0300 Subject: [PATCH 48/60] Use update strength and simplify check Signed-off-by: bcmmbaga --- management/server/policy.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/management/server/policy.go b/management/server/policy.go index 6dcb963162b..693ae2872bf 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -435,7 +435,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po var updateAccountPeers bool err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - policy, err = transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) + policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID) if err != nil { return err } @@ -502,8 +502,6 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, account if hasPeers { return true, nil } - - return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) } return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) From df98c67ac8100ac75995fecfaa32e76691db36c0 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 18 Nov 2024 18:46:52 +0300 Subject: [PATCH 49/60] prevent changing ruleID when not empty Signed-off-by: bcmmbaga --- management/server/http/policies_handler.go | 7 ++++++- management/server/sql_store_test.go | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 8255e489648..ca256a18351 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -128,8 +128,13 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID Description: req.Description, } for _, rule := range req.Rules { + ruleID := policyID // TODO: when policy can contain multiple rules, need refactor + if rule.Id != nil { + ruleID = *rule.Id + } + pr := server.PolicyRule{ - ID: policyID, // TODO: when policy can contain multiple rules, need refactor + ID: ruleID, PolicyID: policyID, Name: rule.Name, Destinations: rule.Destinations, diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 8931008d7ff..c05793fc624 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1832,6 +1832,8 @@ func TestSqlStore_SavePolicy(t *testing.T) { policy.Enabled = false policy.Description = "policy" + policy.Rules[0].Sources = []string{"group"} + policy.Rules[0].Ports = []string{"80", "443"} err = store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) require.NoError(t, err) From b60e2c32614615d5f7c44d95794338ade30d9287 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 18 Nov 2024 22:48:38 +0300 Subject: [PATCH 50/60] prevent duplicate rules during updates Signed-off-by: bcmmbaga --- management/server/http/policies_handler.go | 2 +- management/server/policy.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index ca256a18351..eff9092d45e 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -128,7 +128,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID Description: req.Description, } for _, rule := range req.Rules { - ruleID := policyID // TODO: when policy can contain multiple rules, need refactor + var ruleID string if rule.Id != nil { ruleID = *rule.Id } diff --git a/management/server/policy.go b/management/server/policy.go index 693ae2872bf..2d3abc3f1e2 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -532,7 +532,7 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po for i, rule := range policy.Rules { ruleCopy := rule.Copy() if ruleCopy.ID == "" { - ruleCopy.ID = xid.New().String() + ruleCopy.ID = policy.ID // TODO: when policy can contain multiple rules, need refactor ruleCopy.PolicyID = policy.ID } From 20fc8e879e9d47bc5be8b722ad6130c71dd6a9e2 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 19 Nov 2024 00:54:07 +0300 Subject: [PATCH 51/60] fix tests Signed-off-by: bcmmbaga --- management/server/sql_store_test.go | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index f08d4bac78d..39c42bfa637 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -392,7 +392,7 @@ func TestSqlite_SavePeer(t *testing.T) { IP: net.IP{127, 0, 0, 1}, Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().Local()}, } ctx := context.Background() err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, peer) @@ -418,8 +418,11 @@ func TestSqlite_SavePeer(t *testing.T) { require.NoError(t, err) actual := account.Peers[peer.ID] - assert.Equal(t, updatedPeer.Status, actual.Status) assert.Equal(t, updatedPeer.Meta, actual.Meta) + assert.Equal(t, updatedPeer.Status.Connected, actual.Status.Connected) + assert.Equal(t, updatedPeer.Status.LoginExpired, actual.Status.LoginExpired) + assert.Equal(t, updatedPeer.Status.RequiresApproval, actual.Status.RequiresApproval) + assert.WithinDurationf(t, updatedPeer.Status.LastSeen, actual.Status.LastSeen, time.Millisecond, "LastSeen should be equal") } func TestSqlite_SavePeerStatus(t *testing.T) { @@ -431,7 +434,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { require.NoError(t, err) // save status of non-existing peer - newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()} + newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().Local()} err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus) assert.Error(t, err) parsedErr, ok := status.FromError(err) @@ -445,7 +448,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { IP: net.IP{127, 0, 0, 1}, Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().Local()}, } err = store.SaveAccount(context.Background(), account) @@ -458,7 +461,10 @@ func TestSqlite_SavePeerStatus(t *testing.T) { require.NoError(t, err) actual := account.Peers["testpeer"].Status - assert.Equal(t, newStatus, *actual) + assert.Equal(t, newStatus.Connected, actual.Connected) + assert.Equal(t, newStatus.LoginExpired, actual.LoginExpired) + assert.Equal(t, newStatus.RequiresApproval, actual.RequiresApproval) + assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen, time.Millisecond, "LastSeen should be equal") newStatus.Connected = true @@ -469,7 +475,10 @@ func TestSqlite_SavePeerStatus(t *testing.T) { require.NoError(t, err) actual = account.Peers["testpeer"].Status - assert.Equal(t, newStatus, *actual) + assert.Equal(t, newStatus.Connected, actual.Connected) + assert.Equal(t, newStatus.LoginExpired, actual.LoginExpired) + assert.Equal(t, newStatus.RequiresApproval, actual.RequiresApproval) + assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen, time.Millisecond, "LastSeen should be equal") } func TestSqlite_SavePeerLocation(t *testing.T) { From 0ee56e14d92213662bba845f59f9f1683c384030 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 19 Nov 2024 10:47:26 +0300 Subject: [PATCH 52/60] fix lint Signed-off-by: bcmmbaga --- management/server/sql_store.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 6e4f1d3960a..28a5d754d4e 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -846,7 +846,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "peer not found") + return nil, status.NewPeerNotFoundError(peerKey) } return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error) } @@ -1121,7 +1121,7 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength First(&peer, accountAndIDQueryCondition, accountID, peerID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "peer not found") + return nil, status.NewPeerNotFoundError(peerID) } log.WithContext(ctx).Errorf("failed to get peer from store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get peer from store") @@ -1203,7 +1203,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength, } if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "peer not found") + return status.NewPeerNotFoundError(peerID) } return nil From 82746d93ee93c5f1cbcd3d36f28d9dc89377e6b4 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 21 Nov 2024 17:15:07 +0300 Subject: [PATCH 53/60] Use UTC time in test Signed-off-by: bcmmbaga --- management/server/sql_store_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 39c42bfa637..ad1a39cdad6 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -392,7 +392,7 @@ func TestSqlite_SavePeer(t *testing.T) { IP: net.IP{127, 0, 0, 1}, Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().Local()}, + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } ctx := context.Background() err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, peer) @@ -422,7 +422,7 @@ func TestSqlite_SavePeer(t *testing.T) { assert.Equal(t, updatedPeer.Status.Connected, actual.Status.Connected) assert.Equal(t, updatedPeer.Status.LoginExpired, actual.Status.LoginExpired) assert.Equal(t, updatedPeer.Status.RequiresApproval, actual.Status.RequiresApproval) - assert.WithinDurationf(t, updatedPeer.Status.LastSeen, actual.Status.LastSeen, time.Millisecond, "LastSeen should be equal") + assert.WithinDurationf(t, updatedPeer.Status.LastSeen, actual.Status.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") } func TestSqlite_SavePeerStatus(t *testing.T) { @@ -434,7 +434,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { require.NoError(t, err) // save status of non-existing peer - newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().Local()} + newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()} err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus) assert.Error(t, err) parsedErr, ok := status.FromError(err) @@ -448,7 +448,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { IP: net.IP{127, 0, 0, 1}, Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().Local()}, + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account) @@ -464,7 +464,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { assert.Equal(t, newStatus.Connected, actual.Connected) assert.Equal(t, newStatus.LoginExpired, actual.LoginExpired) assert.Equal(t, newStatus.RequiresApproval, actual.RequiresApproval) - assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen, time.Millisecond, "LastSeen should be equal") + assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") newStatus.Connected = true @@ -478,7 +478,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { assert.Equal(t, newStatus.Connected, actual.Connected) assert.Equal(t, newStatus.LoginExpired, actual.LoginExpired) assert.Equal(t, newStatus.RequiresApproval, actual.RequiresApproval) - assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen, time.Millisecond, "LastSeen should be equal") + assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") } func TestSqlite_SavePeerLocation(t *testing.T) { From fde9f2ffdafda92802753afeef2ac029a932302b Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 28 Nov 2024 12:18:02 +0300 Subject: [PATCH 54/60] Add store locks and prevent fetching setup keys peers when retrieving user peers with empty userID Signed-off-by: bcmmbaga --- management/server/peer.go | 6 +++--- management/server/sql_store.go | 24 ++++++++++++++++-------- management/server/store.go | 6 +++--- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index ff5bc23d57e..586f6d9196c 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -558,21 +558,21 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) - err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID) + err = transaction.AddPeerToAllGroup(ctx, LockingStrengthUpdate, accountID, newPeer.ID) if err != nil { return fmt.Errorf("failed adding peer to All group: %w", err) } if len(groupsToAdd) > 0 { for _, g := range groupsToAdd { - err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g) + err = transaction.AddPeerToGroup(ctx, LockingStrengthUpdate, accountID, newPeer.ID, g) if err != nil { return err } } } - err = transaction.AddPeerToAccount(ctx, newPeer) + err = transaction.AddPeerToAccount(ctx, LockingStrengthUpdate, newPeer) if err != nil { return fmt.Errorf("failed to add peer to account: %w", err) } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 84c7ab8a9fe..1280cc88889 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1030,9 +1030,10 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string return nil } -func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { +func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error { var group nbgroup.Group - result := s.db.Where("account_id = ? AND name = ?", accountID, "All").First(&group) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&group, "account_id = ? AND name = ?", accountID, "All") if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group 'All' not found for account") @@ -1048,16 +1049,17 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer group.Peers = append(group.Peers, peerID) - if err := s.db.Save(&group).Error; err != nil { + if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil { return status.Errorf(status.Internal, "issue updating group 'All': %s", err) } return nil } -func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { +func (s *SqlStore) AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error { var group nbgroup.Group - result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountAndIDQueryCondition, accountId, groupID). + First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.NewGroupNotFoundError(groupID) @@ -1074,7 +1076,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId group.Peers = append(group.Peers, peerId) - if err := s.db.Save(&group).Error; err != nil { + if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil { return status.Errorf(status.Internal, "issue updating group: %s", err) } @@ -1096,6 +1098,12 @@ func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStre // GetUserPeers retrieves peers for a user. func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { var peers []*nbpeer.Peer + + // Exclude peers added via setup keys, as they are not user-specific and have an empty user_id. + if userID == "" { + return peers, nil + } + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Find(&peers, "account_id = ? AND user_id = ?", accountID, userID) if err := result.Error; err != nil { @@ -1106,8 +1114,8 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt return peers, nil } -func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { - if err := s.db.Create(peer).Error; err != nil { +func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error { + if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } diff --git a/management/server/store.go b/management/server/store.go index 5b48de37823..852ca691142 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -95,9 +95,9 @@ type Store interface { DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) - AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error - AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error - AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error + AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error + AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error + AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) From a22d5041e3287182252ce7b6d8ab6dc76a796c01 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 28 Nov 2024 12:21:15 +0300 Subject: [PATCH 55/60] Add missing tests Signed-off-by: bcmmbaga --- management/server/sql_store_test.go | 313 +++++++++++++++++- .../server/testdata/store_policy_migrate.sql | 1 + 2 files changed, 310 insertions(+), 4 deletions(-) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index c19eb1117e1..6ee161c7973 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1005,7 +1005,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) { AccountID: existingAccountID, IP: net.IP{1, 1, 1, 1}, } - err = store.AddPeerToAccount(context.Background(), peer1) + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) require.NoError(t, err) takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) @@ -1018,7 +1018,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) { AccountID: existingAccountID, IP: net.IP{2, 2, 2, 2}, } - err = store.AddPeerToAccount(context.Background(), peer2) + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) require.NoError(t, err) takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) @@ -1050,7 +1050,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { AccountID: existingAccountID, DNSLabel: "peer1.domain.test", } - err = store.AddPeerToAccount(context.Background(), peer1) + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) require.NoError(t, err) labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) @@ -1062,7 +1062,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { AccountID: existingAccountID, DNSLabel: "peer2.domain.test", } - err = store.AddPeerToAccount(context.Background(), peer2) + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) require.NoError(t, err) labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) @@ -2045,3 +2045,308 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) { require.Error(t, err) require.Nil(t, nsGroup) } + +func TestSqlStore_AddPeerToGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + peerID := "cfefqs706sqkneg59g4g" + groupID := "cfefqs706sqkneg59g4h" + + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err, "failed to get group") + require.Len(t, group.Peers, 0, "group should have 0 peers") + + err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, groupID) + require.NoError(t, err, "failed to add peer to group") + + group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err, "failed to get group") + require.Len(t, group.Peers, 1, "group should have 1 peers") + require.Contains(t, group.Peers, peerID) +} + +func TestSqlStore_AddPeerToAllGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + groupID := "cfefqs706sqkneg59g3g" + + peer := &nbpeer.Peer{ + ID: "peer1", + AccountID: accountID, + DNSLabel: "peer1.domain.test", + } + + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err, "failed to get group") + require.Len(t, group.Peers, 2, "group should have 2 peers") + require.NotContains(t, group.Peers, peer.ID) + + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer) + require.NoError(t, err, "failed to add peer to account") + + err = store.AddPeerToAllGroup(context.Background(), LockingStrengthUpdate, accountID, peer.ID) + require.NoError(t, err, "failed to add peer to all group") + + group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err, "failed to get group") + require.Len(t, group.Peers, 3, "group should have peers") + require.Contains(t, group.Peers, peer.ID) +} + +func TestSqlStore_AddPeerToAccount(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + peer := &nbpeer.Peer{ + ID: "peer1", + AccountID: accountID, + Key: "key", + IP: net.IP{1, 1, 1, 1}, + Meta: nbpeer.PeerSystemMeta{ + Hostname: "hostname", + GoOS: "linux", + Kernel: "Linux", + Core: "21.04", + Platform: "x86_64", + OS: "Ubuntu", + WtVersion: "development", + UIVersion: "development", + }, + Name: "peer.test", + DNSLabel: "peer", + Status: &nbpeer.PeerStatus{ + LastSeen: time.Now().UTC(), + Connected: true, + LoginExpired: false, + RequiresApproval: false, + }, + SSHKey: "ssh-key", + SSHEnabled: false, + LoginExpirationEnabled: true, + InactivityExpirationEnabled: false, + LastLogin: time.Now().UTC(), + CreatedAt: time.Now().UTC(), + Ephemeral: true, + } + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer) + require.NoError(t, err, "failed to add peer to account") + + storedPeer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peer.ID) + require.NoError(t, err, "failed to get peer") + + assert.Equal(t, peer.ID, storedPeer.ID) + assert.Equal(t, peer.AccountID, storedPeer.AccountID) + assert.Equal(t, peer.Key, storedPeer.Key) + assert.Equal(t, peer.IP.String(), storedPeer.IP.String()) + assert.Equal(t, peer.Meta, storedPeer.Meta) + assert.Equal(t, peer.Name, storedPeer.Name) + assert.Equal(t, peer.DNSLabel, storedPeer.DNSLabel) + assert.Equal(t, peer.SSHKey, storedPeer.SSHKey) + assert.Equal(t, peer.SSHEnabled, storedPeer.SSHEnabled) + assert.Equal(t, peer.LoginExpirationEnabled, storedPeer.LoginExpirationEnabled) + assert.Equal(t, peer.InactivityExpirationEnabled, storedPeer.InactivityExpirationEnabled) + assert.WithinDurationf(t, peer.LastLogin, storedPeer.LastLogin.UTC(), time.Millisecond, "LastLogin should be equal") + assert.WithinDurationf(t, peer.CreatedAt, storedPeer.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal") + assert.Equal(t, peer.Ephemeral, storedPeer.Ephemeral) + assert.Equal(t, peer.Status.Connected, storedPeer.Status.Connected) + assert.Equal(t, peer.Status.LoginExpired, storedPeer.Status.LoginExpired) + assert.Equal(t, peer.Status.RequiresApproval, storedPeer.Status.RequiresApproval) + assert.WithinDurationf(t, peer.Status.LastSeen, storedPeer.Status.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") +} + +func TestSqlStore_GetAccountPeers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "should retrieve peers for an existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 4, + }, + { + name: "should return no peers for a non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "should return no peers for an empty account ID", + accountID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } + +} + +func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "should retrieve peers with expiration for an existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + { + name: "should return no peers with expiration for a non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "should return no peers with expiration for a empty account ID", + accountID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "should retrieve peers with inactivity for an existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + { + name: "should return no peers with inactivity for a non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "should return no peers with inactivity for an empty account ID", + accountID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetAccountPeersWithInactivity(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetAllEphemeralPeers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/storev1.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + peers, err := store.GetAllEphemeralPeers(context.Background(), LockingStrengthShare) + require.NoError(t, err) + require.Len(t, peers, 1) + require.True(t, peers[0].Ephemeral) +} + +func TestSqlStore_GetUserPeers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + userID string + expectedCount int + }{ + { + name: "should retrieve peers for existing account ID and user ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + userID: "f4f6d672-63fb-11ec-90d6-0242ac120003", + expectedCount: 1, + }, + { + name: "should return no peers for non-existing account ID with existing user ID", + accountID: "nonexistent", + userID: "f4f6d672-63fb-11ec-90d6-0242ac120003", + expectedCount: 0, + }, + { + name: "should return no peers for non-existing user ID with existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + userID: "nonexistent_user", + expectedCount: 0, + }, + { + name: "should retrieve peers for another valid account ID and user ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + userID: "edafee4e-63fb-11ec-90d6-0242ac120003", + expectedCount: 2, + }, + { + name: "should return no peers for existing account ID with empty user ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + userID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetUserPeers(context.Background(), LockingStrengthShare, tt.accountID, tt.userID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} + +func TestSqlStore_DeletePeer(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + peerID := "csrnkiq7qv9d8aitqd50" + + err = store.DeletePeer(context.Background(), LockingStrengthUpdate, accountID, peerID) + require.NoError(t, err) + + peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) + require.Error(t, err) + require.Nil(t, peer) +} diff --git a/management/server/testdata/store_policy_migrate.sql b/management/server/testdata/store_policy_migrate.sql index a9360e9d65c..15917f391ad 100644 --- a/management/server/testdata/store_policy_migrate.sql +++ b/management/server/testdata/store_policy_migrate.sql @@ -32,4 +32,5 @@ INSERT INTO peers VALUES('cfeg6sf06sqkneg59g50','bf1c8084-ba50-4ce7-9439-3465300 INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,''); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfefqs706sqkneg59g4g","cfeg6sf06sqkneg59g50"]',0,''); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4h','bf1c8084-ba50-4ce7-9439-34653001fc3b','groupA','api','',0,''); INSERT INTO installations VALUES(1,''); From cde0e51c720883fa308c050dabfd060557c992eb Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 28 Nov 2024 12:30:38 +0300 Subject: [PATCH 56/60] Refactor test names and remove duplicate TestPostgresql_SavePeerStatus Signed-off-by: bcmmbaga --- management/server/sql_store_test.go | 47 ++--------------------------- 1 file changed, 3 insertions(+), 44 deletions(-) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 6ee161c7973..eddd628bd14 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -377,7 +377,7 @@ func TestSqlite_GetAccount(t *testing.T) { require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") } -func TestSqlite_SavePeer(t *testing.T) { +func TestSqlStore_SavePeer(t *testing.T) { store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -425,7 +425,7 @@ func TestSqlite_SavePeer(t *testing.T) { assert.WithinDurationf(t, updatedPeer.Status.LastSeen, actual.Status.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") } -func TestSqlite_SavePeerStatus(t *testing.T) { +func TestSqlStore_SavePeerStatus(t *testing.T) { store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -481,7 +481,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") } -func TestSqlite_SavePeerLocation(t *testing.T) { +func TestSqlStore_SavePeerLocation(t *testing.T) { store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -887,47 +887,6 @@ func TestPostgresql_DeleteAccount(t *testing.T) { } -func TestPostgresql_SavePeerStatus(t *testing.T) { - if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { - t.Skip("skip CI tests on darwin and windows") - } - - t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) - - account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") - require.NoError(t, err) - - // save status of non-existing peer - newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()} - err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus) - assert.Error(t, err) - - // save new status of existing peer - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, - } - - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus) - require.NoError(t, err) - - account, err = store.GetAccount(context.Background(), account.Id) - require.NoError(t, err) - - actual := account.Peers["testpeer"].Status - assert.Equal(t, newStatus.Connected, actual.Connected) -} - func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { t.Skip("skip CI tests on darwin and windows") From f87bc601c60562a7a0e05627c60d4320860f5b7a Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 28 Nov 2024 14:03:08 +0300 Subject: [PATCH 57/60] Add account locks and remove redundant ephemeral check Signed-off-by: bcmmbaga --- management/server/account.go | 8 ++++++-- management/server/ephemeral.go | 8 ++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index f4256954224..d25fe5b7991 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1244,10 +1244,11 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context. return nil } - - func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + expiredPeers, err := am.getExpiredPeers(ctx, accountID) if err != nil { return 0, false @@ -1279,6 +1280,9 @@ func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context // peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + inactivePeers, err := am.getInactivePeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID) diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 6e245ec5ac8..111d5e3fc81 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -127,15 +127,11 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { } t := newDeadLine() - count := 0 for _, p := range peers { - if p.Ephemeral { - count++ - e.addPeer(p.AccountID, p.ID, t) - } + e.addPeer(p.AccountID, p.ID, t) } - log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count) + log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers)) } func (e *EphemeralManager) cleanup(ctx context.Context) { From 1ba6eb62a607fe3da9bf09abd79fe2cd1e3ec142 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 28 Nov 2024 15:01:44 +0300 Subject: [PATCH 58/60] Retrieve all groups for peers and restrict groups for regular users Signed-off-by: bcmmbaga --- management/server/http/peers_handler.go | 47 ++++++++++++++----------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 235e744b351..4d0bdec2d70 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -62,12 +62,8 @@ func (h *PeersHandler) getPeer(ctx context.Context, accountID, peerID, userID st } dnsDomain := h.accountManager.GetDNSDomain() - peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) - if err != nil { - util.WriteError(ctx, err, w) - return - } - groupsInfo := toGroupsInfo(peerGroups) + groups, _ := h.accountManager.GetAllGroups(ctx, accountID, userID) + groupsInfo := toGroupsInfo(groups, peerID) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { @@ -116,7 +112,7 @@ func (h *PeersHandler) updatePeer(ctx context.Context, accountID, userID, peerID util.WriteError(ctx, err, w) return } - groupMinimumInfo := toGroupsInfo(peerGroups) + groupMinimumInfo := toGroupsInfo(peerGroups, peer.ID) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { @@ -187,6 +183,8 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() + groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + respBody := make([]*api.PeerBatch, 0, len(peers)) for _, peer := range peers { peerToReturn, err := h.checkPeerStatus(peer) @@ -195,12 +193,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { return } - peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), accountID, peer.ID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - groupMinimumInfo := toGroupsInfo(peerGroups) + groupMinimumInfo := toGroupsInfo(groups, peer.ID) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) } @@ -312,14 +305,28 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toGroupsInfo(groups []*nbgroup.Group) []api.GroupMinimum { - groupsInfo := make([]api.GroupMinimum, 0, len(groups)) +func toGroupsInfo(groups []*nbgroup.Group, peerID string) []api.GroupMinimum { + groupsInfo := []api.GroupMinimum{} + groupsChecked := make(map[string]struct{}) + for _, group := range groups { - groupsInfo = append(groupsInfo, api.GroupMinimum{ - Id: group.ID, - Name: group.Name, - PeersCount: len(group.Peers), - }) + _, ok := groupsChecked[group.ID] + if ok { + continue + } + + groupsChecked[group.ID] = struct{}{} + for _, pk := range group.Peers { + if pk == peerID { + info := api.GroupMinimum{ + Id: group.ID, + Name: group.Name, + PeersCount: len(group.Peers), + } + groupsInfo = append(groupsInfo, info) + break + } + } } return groupsInfo } From d66140fc82af4a55c522df851876832c931807e9 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 28 Nov 2024 15:08:42 +0300 Subject: [PATCH 59/60] Fix merge Signed-off-by: bcmmbaga --- management/server/peer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/peer.go b/management/server/peer.go index 0547cbbbc99..9360ce29f3f 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -880,7 +880,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } - postureChecks, err = am.getPeerPostureChecks(account, peer.ID) + postureChecks, err := am.getPeerPostureChecks(account, peer.ID) if err != nil { return nil, nil, nil, err } From 9a96b91d9d679e4bf6081c76a3e341e946f340f2 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 9 Dec 2024 14:21:28 +0100 Subject: [PATCH 60/60] Fix merge Signed-off-by: bcmmbaga --- management/server/peer.go | 85 +++++++++++++++++++++++++++++++++------ 1 file changed, 73 insertions(+), 12 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index f6e4de7b7f5..a174bc41529 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -649,6 +649,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac var isStatusChanged bool var updated bool var err error + var postureChecks []*posture.Checks err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, sync.WireGuardPubKey) @@ -690,7 +691,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac if updated { am.metrics.AccountManagerMetrics().CountPeerMetUpdate() log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) - err = transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) + if err = transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil { + return err + } + + postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID) if err != nil { return err } @@ -701,12 +706,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac return nil, nil, nil, err } - postureChecks, err := am.getPeerPostureChecks(account, peer.ID) - if err != nil { - return nil, nil, nil, err - } - - if isStatusChanged || (updated && sync.UpdateAccountPeers) || (updated && len(postureChecks) > 0){ + if isStatusChanged || sync.UpdateAccountPeers || (updated && len(postureChecks) > 0) { am.updateAccountPeers(ctx, accountID) } @@ -764,6 +764,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) var isRequiresApproval bool var isStatusChanged bool var isPeerUpdated bool + var postureChecks []*posture.Checks err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey) @@ -809,6 +810,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) if isPeerUpdated { am.metrics.AccountManagerMetrics().CountPeerMetUpdate() shouldStorePeer = true + + postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID) + if err != nil { + return err + } } if peer.SSHKey != login.SSHKey { @@ -831,11 +837,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) unlockPeer() unlockPeer = nil - postureChecks, err := am.getPeerPostureChecks(account, peer.ID) - if err != nil { - return nil, nil, nil, err - } - if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { am.updateAccountPeers(ctx, accountID) } @@ -843,6 +844,66 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) } +// getPeerPostureChecks returns the posture checks for the peer. +func getPeerPostureChecks(ctx context.Context, transaction Store, accountID, peerID string) ([]*posture.Checks, error) { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + if len(policies) == 0 { + return nil, nil + } + + var peerPostureChecksIDs []string + + for _, policy := range policies { + if !policy.Enabled || len(policy.SourcePostureChecks) == 0 { + continue + } + + postureChecksIDs, err := processPeerPostureChecks(ctx, transaction, policy, accountID, peerID) + if err != nil { + return nil, err + } + + peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...) + } + + peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, peerPostureChecksIDs) + if err != nil { + return nil, err + } + + return maps.Values(peerPostureChecks), nil +} + +// processPeerPostureChecks checks if the peer is in the source group of the policy and returns the posture checks. +func processPeerPostureChecks(ctx context.Context, transaction Store, policy *Policy, accountID, peerID string) ([]string, error) { + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + sourceGroups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, rule.Sources) + if err != nil { + return nil, err + } + + for _, sourceGroup := range rule.Sources { + group, ok := sourceGroups[sourceGroup] + if !ok { + return nil, fmt.Errorf("failed to check peer in policy source group") + } + + if slices.Contains(group.Peers, peerID) { + return policy.SourcePostureChecks, nil + } + } + } + return nil, nil +} + // checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO // and if the peer login is expired. // The NetBird client doesn't have a way to check if the peer needs login besides sending a login request