Skip to content

Commit

Permalink
[management] Refactor DNS settings to use store methods (#2883)
Browse files Browse the repository at this point in the history
  • Loading branch information
bcmmbaga authored Nov 26, 2024
1 parent f118d81 commit 0e48a77
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 34 deletions.
142 changes: 110 additions & 32 deletions management/server/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"context"
"fmt"
"slices"
"strconv"
"sync"

Expand Down Expand Up @@ -85,73 +86,150 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
return nil, err
}

if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}

if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}

return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
}

// SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if dnsSettingsToSave == nil {
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
}

account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}

user, err := account.FindUser(userID)
if err != nil {
return err
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}

if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
return status.NewAdminPermissionError()
}

if dnsSettingsToSave == nil {
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
}
var updateAccountPeers bool
var eventsToStore []func()

if len(dnsSettingsToSave.DisabledManagementGroups) != 0 {
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups)
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
return err
}

oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return err
}
}

oldSettings := account.DNSSettings.Copy()
account.DNSSettings = dnsSettingsToSave.Copy()
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)

addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
if err != nil {
return err
}

account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
eventsToStore = append(eventsToStore, events...)

if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}

for _, id := range addedGroups {
group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave)
})
if err != nil {
return err
}

for _, id := range removedGroups {
group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
for _, storeEvent := range eventsToStore {
storeEvent()
}

if am.anyGroupHasPeers(account, addedGroups) || am.anyGroupHasPeers(account, removedGroups) {
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}

return nil
}

// prepareDNSSettingsEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() {
var eventsToStore []func()

modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
return nil
}

for _, groupID := range addedGroups {
group, ok := groups[groupID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToDisabledManagementGroups activity", groupID)
continue
}

eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
})

}

for _, groupID := range removedGroups {
group, ok := groups[groupID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromDisabledManagementGroups activity", groupID)
continue
}

eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
})
}

return eventsToStore
}

// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups)
if err != nil {
return false, err
}

if hasPeers {
return true, nil
}

return anyGroupHasPeers(ctx, transaction, accountID, removedGroups)
}

// validateDNSSettings validates the DNS settings.
func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error {
if len(settings.DisabledManagementGroups) == 0 {
return nil
}

groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups)
if err != nil {
return err
}

return validateGroups(settings.DisabledManagementGroups, groups)
}

// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
Expand Down
21 changes: 19 additions & 2 deletions management/server/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -1162,9 +1162,10 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
First(&accountDNSSettings, idQueryCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "dns settings not found")
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
log.WithContext(ctx).Errorf("failed to get dns settings from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get dns settings from store")
}
return &accountDNSSettings.DNSSettings, nil
}
Expand Down Expand Up @@ -1537,3 +1538,19 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
}
return &record, nil
}

// SaveDNSSettings saves the DNS settings to the store.
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save dns settings to store")
}

if result.RowsAffected == 0 {
return status.NewAccountNotFoundError(accountID)
}

return nil
}
63 changes: 63 additions & 0 deletions management/server/sql_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1857,3 +1857,66 @@ func TestSqlStore_DeletePolicy(t *testing.T) {
require.Error(t, err)
require.Nil(t, policy)
}

func TestSqlStore_GetDNSSettings(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)

tests := []struct {
name string
accountID string
expectError bool
}{
{
name: "retrieve existing account dns settings",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectError: false,
},
{
name: "retrieve non-existing account dns settings",
accountID: "non-existing",
expectError: true,
},
{
name: "retrieve dns settings with empty account ID",
accountID: "",
expectError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, tt.accountID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, dnsSettings)
} else {
require.NoError(t, err)
require.NotNil(t, dnsSettings)
}
})
}
}

func TestSqlStore_SaveDNSSettings(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)

accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"

dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)

dnsSettings.DisabledManagementGroups = []string{"groupA", "groupB"}
err = store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, dnsSettings)
require.NoError(t, err)

saveDNSSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
require.Equal(t, saveDNSSettings, dnsSettings)
}
1 change: 1 addition & 0 deletions management/server/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ type Store interface {
SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error

GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
Expand Down

0 comments on commit 0e48a77

Please sign in to comment.