diff --git a/management/server/dns.go b/management/server/dns.go index e52be601639..8df211b0b0b 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -3,6 +3,7 @@ package server import ( "context" "fmt" + "slices" "strconv" "sync" @@ -85,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) @@ -94,64 +99,137 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s // SaveDNSSettings validates a user role and updates the account's DNS settings func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() + if dnsSettingsToSave == nil { + return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") + } - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - user, err := account.FindUser(userID) - if err != nil { - return err + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings") + return status.NewAdminPermissionError() } - if dnsSettingsToSave == nil { - return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") - } + var updateAccountPeers bool + var eventsToStore []func() - if len(dnsSettingsToSave.DisabledManagementGroups) != 0 { - err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil { + return err + } + + oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID) if err != nil { return err } - } - oldSettings := account.DNSSettings.Copy() - account.DNSSettings = dnsSettingsToSave.Copy() + addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) + removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) - addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) - removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) + updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups) + if err != nil { + return err + } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } + events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups) + eventsToStore = append(eventsToStore, events...) + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } - for _, id := range addedGroups { - group := account.GetGroup(id) - meta := map[string]any{"group": group.Name, "group_id": group.ID} - am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) + return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave) + }) + if err != nil { + return err } - for _, id := range removedGroups { - group := account.GetGroup(id) - meta := map[string]any{"group": group.Name, "group_id": group.ID} - am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) + for _, storeEvent := range eventsToStore { + storeEvent() } - if am.anyGroupHasPeers(account, addedGroups) || am.anyGroupHasPeers(account, removedGroups) { + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } return nil } +// prepareDNSSettingsEvents prepares a list of event functions to be stored. +func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() { + var eventsToStore []func() + + modifiedGroups := slices.Concat(addedGroups, removedGroups) + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) + if err != nil { + log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err) + return nil + } + + for _, groupID := range addedGroups { + group, ok := groups[groupID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToDisabledManagementGroups activity", groupID) + continue + } + + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{"group": group.Name, "group_id": group.ID} + am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) + }) + + } + + for _, groupID := range removedGroups { + group, ok := groups[groupID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromDisabledManagementGroups activity", groupID) + continue + } + + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{"group": group.Name, "group_id": group.ID} + am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) + }) + } + + return eventsToStore +} + +// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers. +func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) { + hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + + return anyGroupHasPeers(ctx, transaction, accountID, removedGroups) +} + +// validateDNSSettings validates the DNS settings. +func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error { + if len(settings.DisabledManagementGroups) == 0 { + return nil + } + + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups) + if err != nil { + return err + } + + return validateGroups(settings.DisabledManagementGroups, groups) +} + // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig { protoUpdate := &proto.DNSConfig{ diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 9a24857d10b..f58ceb1ad85 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1162,9 +1162,10 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki First(&accountDNSSettings, idQueryCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "dns settings not found") + return nil, status.NewAccountNotFoundError(accountID) } - return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error) + log.WithContext(ctx).Errorf("failed to get dns settings from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get dns settings from store") } return &accountDNSSettings.DNSSettings, nil } @@ -1537,3 +1538,19 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a } return &record, nil } + +// SaveDNSSettings saves the DNS settings to the store. +func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings}) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save dns settings to store") + } + + if result.RowsAffected == 0 { + return status.NewAccountNotFoundError(accountID) + } + + return nil +} diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index c05793fc624..df5294d7391 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1857,3 +1857,66 @@ func TestSqlStore_DeletePolicy(t *testing.T) { require.Error(t, err) require.Nil(t, policy) } + +func TestSqlStore_GetDNSSettings(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectError bool + }{ + { + name: "retrieve existing account dns settings", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectError: false, + }, + { + name: "retrieve non-existing account dns settings", + accountID: "non-existing", + expectError: true, + }, + { + name: "retrieve dns settings with empty account ID", + accountID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, tt.accountID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, dnsSettings) + } else { + require.NoError(t, err) + require.NotNil(t, dnsSettings) + } + }) + } +} + +func TestSqlStore_SaveDNSSettings(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + + dnsSettings.DisabledManagementGroups = []string{"groupA", "groupB"} + err = store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, dnsSettings) + require.NoError(t, err) + + saveDNSSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + require.Equal(t, saveDNSSettings, dnsSettings) +} diff --git a/management/server/store.go b/management/server/store.go index ba61d552d72..cca014b5214 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -59,6 +59,7 @@ type Store interface { SaveAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error + SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)