Skip to content

Commit

Permalink
[management] Refactor nameserver groups to use store methods (#2888)
Browse files Browse the repository at this point in the history
  • Loading branch information
bcmmbaga authored Nov 26, 2024
1 parent 0e48a77 commit 9683da5
Show file tree
Hide file tree
Showing 6 changed files with 313 additions and 76 deletions.
201 changes: 128 additions & 73 deletions management/server/nameserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,34 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
return nil, err
}

if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}

return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}

return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
}

// CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {

unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}

if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}

newNSGroup := &nbdns.NameServerGroup{
ID: xid.New().String(),
AccountID: accountID,
Name: name,
Description: description,
NameServers: nameServerList,
Expand All @@ -54,26 +62,33 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
SearchDomainsEnabled: searchDomainEnabled,
}

err = validateNameServerGroup(false, newNSGroup, account)
if err != nil {
return nil, err
}
var updateAccountPeers bool

if account.NameServerGroups == nil {
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup)
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil {
return err
}

account.NameServerGroups[newNSGroup.ID] = newNSGroup
updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, newNSGroup.Groups)
if err != nil {
return err
}

account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}

return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup)
})
if err != nil {
return nil, err
}

if am.anyGroupHasPeers(account, newNSGroup.Groups) {
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())

if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())

return newNSGroup.Copy(), nil
}
Expand All @@ -87,58 +102,95 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
}

account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}

err = validateNameServerGroup(true, nsGroupToSave, account)
if err != nil {
return err
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}

oldNSGroup := account.NameServerGroups[nsGroupToSave.ID]
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
var updateAccountPeers bool

err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID)
if err != nil {
return err
}
nsGroupToSave.AccountID = accountID

if err = validateNameServerGroup(ctx, transaction, accountID, nsGroupToSave); err != nil {
return err
}

updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup)
if err != nil {
return err
}

if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}

account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave)
})
if err != nil {
return err
}

if am.areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())

if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())

return nil
}

// DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {

unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}

nsGroup := account.NameServerGroups[nsGroupID]
if nsGroup == nil {
return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID)
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
delete(account.NameServerGroups, nsGroupID)

account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
var nsGroup *nbdns.NameServerGroup
var updateAccountPeers bool

err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
nsGroup, err = transaction.GetNameServerGroupByID(ctx, LockingStrengthUpdate, accountID, nsGroupID)
if err != nil {
return err
}

updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, nsGroup.Groups)
if err != nil {
return err
}

if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}

return transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID)
})
if err != nil {
return err
}

if am.anyGroupHasPeers(account, nsGroup.Groups) {
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())

if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())

return nil
}
Expand All @@ -150,44 +202,62 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
return nil, err
}

if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}

if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}

return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
}

func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
nsGroupID := ""
if existingGroup {
nsGroupID = nameserverGroup.ID
_, found := account.NameServerGroups[nsGroupID]
if !found {
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
}
func validateNameServerGroup(ctx context.Context, transaction Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
if err != nil {
return err
}

err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
err = validateNSList(nameserverGroup.NameServers)
if err != nil {
return err
}

err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}

err = validateNSList(nameserverGroup.NameServers)
err = validateNSGroupName(nameserverGroup.Name, nameserverGroup.ID, nsServerGroups)
if err != nil {
return err
}

err = validateGroups(nameserverGroup.Groups, account.Groups)
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, nameserverGroup.Groups)
if err != nil {
return err
}

return nil
return validateGroups(nameserverGroup.Groups, groups)
}

// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false, nil
}

hasPeers, err := anyGroupHasPeers(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups)
if err != nil {
return false, err
}

if hasPeers {
return true, nil
}

return anyGroupHasPeers(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
}

func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
Expand All @@ -213,23 +283,23 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo
return nil
}

func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
func validateNSGroupName(name, nsGroupID string, groups []*nbdns.NameServerGroup) error {
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
}

for _, nsGroup := range nsGroupMap {
for _, nsGroup := range groups {
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name)
return status.Errorf(status.InvalidArgument, "nameserver group with name %s already exist", name)
}
}

return nil
}

func validateNSList(list []nbdns.NameServer) error {
nsListLenght := len(list)
if nsListLenght == 0 || nsListLenght > 3 {
nsListLength := len(list)
if nsListLength == 0 || nsListLength > 3 {
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list))
}
return nil
Expand All @@ -244,14 +314,7 @@ func validateGroups(list []string, groups map[string]*nbgroup.Group) error {
if id == "" {
return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
}
found := false
for groupID := range groups {
if id == groupID {
found = true
break
}
}
if !found {
if _, found := groups[id]; !found {
return status.Errorf(status.InvalidArgument, "group id %s not found", id)
}
}
Expand All @@ -277,11 +340,3 @@ func validateDomain(domain string) error {

return nil
}

// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false
}
return am.anyGroupHasPeers(account, newNSGroup.Groups) || am.anyGroupHasPeers(account, oldNSGroup.Groups)
}
49 changes: 46 additions & 3 deletions management/server/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -1498,12 +1498,55 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren

// GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
return getRecords[*nbdns.NameServerGroup](s.db, lockStrength, accountID)
var nsGroups []*nbdns.NameServerGroup
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get name server groups from store")
}

return nsGroups, nil
}

// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID)
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
var nsGroup *nbdns.NameServerGroup
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewNameServerGroupNotFoundError(nsGroupID)
}
log.WithContext(ctx).Errorf("failed to get name server group from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get name server group from store")
}

return nsGroup, nil
}

// SaveNameServerGroup saves a name server group to the database.
func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err)
return status.Errorf(status.Internal, "failed to save name server group to store")
}
return nil
}

// DeleteNameServerGroup deletes a name server group from the database.
func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nsGroupID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete name server group from store")
}

if result.RowsAffected == 0 {
return status.NewNameServerGroupNotFoundError(nsGroupID)
}

return nil
}

// getRecords retrieves records from the database based on the account ID.
Expand Down
Loading

0 comments on commit 9683da5

Please sign in to comment.