Skip to content

Commit

Permalink
[management] Refactor setup key to use store methods (#2861)
Browse files Browse the repository at this point in the history
* Refactor setup key handling to use store methods

Signed-off-by: bcmmbaga <[email protected]>

* add lock to get account groups

Signed-off-by: bcmmbaga <[email protected]>

* add check for regular user

Signed-off-by: bcmmbaga <[email protected]>

* get only required groups for auto-group validation

Signed-off-by: bcmmbaga <[email protected]>

* add account lock and return auto groups map on validation

Signed-off-by: bcmmbaga <[email protected]>

* fix missing group removed from setup key activity

Signed-off-by: bcmmbaga <[email protected]>

* Remove context from DB queries

Signed-off-by: bcmmbaga <[email protected]>

* Add user permission check and add setup events into events to store slice

Signed-off-by: bcmmbaga <[email protected]>

* Retrieve all groups once during setup key auto-group validation

Signed-off-by: bcmmbaga <[email protected]>

* Fix lint

Signed-off-by: bcmmbaga <[email protected]>

* Fix sonar

Signed-off-by: bcmmbaga <[email protected]>

---------

Signed-off-by: bcmmbaga <[email protected]>
  • Loading branch information
bcmmbaga authored Nov 11, 2024
1 parent e0bed2b commit 6cb697e
Show file tree
Hide file tree
Showing 11 changed files with 261 additions and 141 deletions.
4 changes: 2 additions & 2 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -2029,7 +2029,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error getting user: %w", err)
}

groups, err := transaction.GetAccountGroups(ctx, accountID)
groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
Expand Down Expand Up @@ -2059,7 +2059,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st

// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
groups, err = transaction.GetAccountGroups(ctx, accountID)
groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion management/server/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2773,7 +2773,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")

groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID")
groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID")
assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added")

Expand Down
2 changes: 1 addition & 1 deletion management/server/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
return nil, err
}

return am.Store.GetAccountGroups(ctx, accountID)
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
}

// GetGroupByName filters all groups in an account by name and returns the one with the most peers
Expand Down
5 changes: 5 additions & 0 deletions management/server/group/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,8 @@ func (g *Group) Copy() *Group {
func (g *Group) HasPeers() bool {
return len(g.Peers) > 0
}

// IsGroupAll checks if the group is a default "All" group.
func (g *Group) IsGroupAll() bool {
return g.Name == "All"
}
2 changes: 1 addition & 1 deletion management/server/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
}

groups, err := am.Store.GetAccountGroups(ctx, accountID)
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
Expand Down
236 changes: 148 additions & 88 deletions management/server/setupkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,17 @@ import (
"context"
"crypto/sha256"
b64 "encoding/base64"
"fmt"
"hash/fnv"
"slices"
"strconv"
"strings"
"time"
"unicode/utf8"

"github.com/google/uuid"
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
)

const (
Expand Down Expand Up @@ -229,32 +228,43 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}

if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil {
return nil, err
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}

setupKey, plainKey := GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
account.SetupKeys[setupKey.Key] = setupKey
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, status.Errorf(status.Internal, "failed adding account key")
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}

am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta())
var setupKey *SetupKey
var plainKey string
var eventsToStore []func()

for _, g := range setupKey.AutoGroups {
group := account.GetGroup(g)
if group != nil {
am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey,
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil {
return err
}

setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
setupKey.AccountID = accountID

events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey)
eventsToStore = append(eventsToStore, events...)

return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey)
})
if err != nil {
return nil, err
}

am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta())
for _, storeEvent := range eventsToStore {
storeEvent()
}

// for the creation return the plain key to the caller
Expand All @@ -268,74 +278,66 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
// (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

if keyToSave == nil {
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
}

account, err := am.Store.GetAccount(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}

var oldKey *SetupKey
for _, key := range account.SetupKeys {
if key.Id == keyToSave.Id {
oldKey = key.Copy()
break
}
}
if oldKey == nil {
return nil, status.Errorf(status.NotFound, "setup key not found")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}

if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil {
return nil, err
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}

// only auto groups, revoked status, and name can be updated for now
newKey := oldKey.Copy()
newKey.Name = keyToSave.Name
newKey.AutoGroups = keyToSave.AutoGroups
newKey.Revoked = keyToSave.Revoked
newKey.UpdatedAt = time.Now().UTC()
var oldKey *SetupKey
var newKey *SetupKey
var eventsToStore []func()

err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil {
return err
}

account.SetupKeys[newKey.Key] = newKey
oldKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id)
if err != nil {
return err
}

if err = am.Store.SaveAccount(ctx, account); err != nil {
// only auto groups, revoked status, and name can be updated for now
newKey = oldKey.Copy()
newKey.Name = keyToSave.Name
newKey.AutoGroups = keyToSave.AutoGroups
newKey.Revoked = keyToSave.Revoked
newKey.UpdatedAt = time.Now().UTC()

addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups)
removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups)

events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey)
eventsToStore = append(eventsToStore, events...)

return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, newKey)
})
if err != nil {
return nil, err
}

if !oldKey.Revoked && newKey.Revoked {
am.StoreEvent(ctx, userID, newKey.Id, accountID, activity.SetupKeyRevoked, newKey.EventMeta())
}

defer func() {
addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups)
removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups)
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey,
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
}

}

for _, g := range addedGroups {
group := account.GetGroup(g)
if group != nil {
am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey,
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
}
}
}()
for _, storeEvent := range eventsToStore {
storeEvent()
}

return newKey, nil
}
Expand All @@ -347,16 +349,15 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
return nil, err
}

if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.NewUnauthorizedToViewSetupKeysError()
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}

setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}

return setupKeys, nil
return am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
}

// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
Expand All @@ -366,8 +367,12 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return nil, err
}

if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.NewUnauthorizedToViewSetupKeysError()
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}

if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}

setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
Expand All @@ -387,37 +392,92 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
return err
}

if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return status.NewUnauthorizedToViewSetupKeysError()
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}

deletedSetupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
if err != nil {
return fmt.Errorf("failed to get setup key: %w", err)
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}

err = am.Store.DeleteSetupKey(ctx, accountID, keyID)
var deletedSetupKey *SetupKey

err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
if err != nil {
return err
}

return transaction.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID)
})
if err != nil {
return fmt.Errorf("failed to delete setup key: %w", err)
return err
}

am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta())

return nil
}

func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
for _, group := range autoGroups {
g, ok := account.Groups[group]
func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error {
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, autoGroupIDs)
if err != nil {
return err
}

for _, groupID := range autoGroupIDs {
group, ok := groups[groupID]
if !ok {
return status.Errorf(status.NotFound, "group %s doesn't exist", group)
return status.Errorf(status.NotFound, "group not found: %s", groupID)
}
if g.Name == "All" {
return status.Errorf(status.InvalidArgument, "can't add All group to the setup key")

if group.IsGroupAll() {
return status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key")
}
}

return nil
}

// prepareSetupKeyEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() {
var eventsToStore []func()

modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
if err != nil {
log.WithContext(ctx).Errorf("issue getting groups for setup key events: %v", err)
return nil
}

for _, g := range removedGroups {
group, ok := groups[g]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err)
continue
}

eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta)
})
}

for _, g := range addedGroups {
group, ok := groups[g]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err)
continue
}

eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta)
})
}

return eventsToStore
}
Loading

0 comments on commit 6cb697e

Please sign in to comment.