diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 9119a3dec72..e7a5387a142 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -24,26 +24,34 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID) + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + + return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID) } // CreateNameServerGroup creates and saves a new nameserver group func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + newNSGroup := &nbdns.NameServerGroup{ ID: xid.New().String(), + AccountID: accountID, Name: name, Description: description, NameServers: nameServerList, @@ -54,26 +62,33 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco SearchDomainsEnabled: searchDomainEnabled, } - err = validateNameServerGroup(false, newNSGroup, account) - if err != nil { - return nil, err - } + var updateAccountPeers bool - if account.NameServerGroups == nil { - account.NameServerGroups = make(map[string]*nbdns.NameServerGroup) - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil { + return err + } - account.NameServerGroups[newNSGroup.ID] = newNSGroup + updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, newNSGroup.Groups) + if err != nil { + return err + } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup) + }) + if err != nil { return nil, err } - if am.anyGroupHasPeers(account, newNSGroup.Groups) { + am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) + + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) return newNSGroup.Copy(), nil } @@ -87,58 +102,95 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") } - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - err = validateNameServerGroup(true, nsGroupToSave, account) - if err != nil { - return err + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - oldNSGroup := account.NameServerGroups[nsGroupToSave.ID] - account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID) + if err != nil { + return err + } + nsGroupToSave.AccountID = accountID + + if err = validateNameServerGroup(ctx, transaction, accountID, nsGroupToSave); err != nil { + return err + } + + updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave) + }) + if err != nil { return err } - if am.areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { + am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) + + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) return nil } // DeleteNameServerGroup deletes nameserver group with nsGroupID func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - nsGroup := account.NameServerGroups[nsGroupID] - if nsGroup == nil { - return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID) + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - delete(account.NameServerGroups, nsGroupID) - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + var nsGroup *nbdns.NameServerGroup + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + nsGroup, err = transaction.GetNameServerGroupByID(ctx, LockingStrengthUpdate, accountID, nsGroupID) + if err != nil { + return err + } + + updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, nsGroup.Groups) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID) + }) + if err != nil { return err } - if am.anyGroupHasPeers(account, nsGroup.Groups) { + am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) + + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) return nil } @@ -150,44 +202,62 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) } -func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error { - nsGroupID := "" - if existingGroup { - nsGroupID = nameserverGroup.ID - _, found := account.NameServerGroups[nsGroupID] - if !found { - return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID) - } +func validateNameServerGroup(ctx context.Context, transaction Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error { + err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled) + if err != nil { + return err } - err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled) + err = validateNSList(nameserverGroup.NameServers) if err != nil { return err } - err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups) + nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) if err != nil { return err } - err = validateNSList(nameserverGroup.NameServers) + err = validateNSGroupName(nameserverGroup.Name, nameserverGroup.ID, nsServerGroups) if err != nil { return err } - err = validateGroups(nameserverGroup.Groups, account.Groups) + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, nameserverGroup.Groups) if err != nil { return err } - return nil + return validateGroups(nameserverGroup.Groups, groups) +} + +// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. +func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) { + if !newNSGroup.Enabled && !oldNSGroup.Enabled { + return false, nil + } + + hasPeers, err := anyGroupHasPeers(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + + return anyGroupHasPeers(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups) } func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error { @@ -213,14 +283,14 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo return nil } -func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error { +func validateNSGroupName(name, nsGroupID string, groups []*nbdns.NameServerGroup) error { if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" { return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar) } - for _, nsGroup := range nsGroupMap { + for _, nsGroup := range groups { if name == nsGroup.Name && nsGroup.ID != nsGroupID { - return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name) + return status.Errorf(status.InvalidArgument, "nameserver group with name %s already exist", name) } } @@ -228,8 +298,8 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na } func validateNSList(list []nbdns.NameServer) error { - nsListLenght := len(list) - if nsListLenght == 0 || nsListLenght > 3 { + nsListLength := len(list) + if nsListLength == 0 || nsListLength > 3 { return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list)) } return nil @@ -244,14 +314,7 @@ func validateGroups(list []string, groups map[string]*nbgroup.Group) error { if id == "" { return status.Errorf(status.InvalidArgument, "group ID should not be empty string") } - found := false - for groupID := range groups { - if id == groupID { - found = true - break - } - } - if !found { + if _, found := groups[id]; !found { return status.Errorf(status.InvalidArgument, "group id %s not found", id) } } @@ -277,11 +340,3 @@ func validateDomain(domain string) error { return nil } - -// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. -func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool { - if !newNSGroup.Enabled && !oldNSGroup.Enabled { - return false - } - return am.anyGroupHasPeers(account, newNSGroup.Groups) || am.anyGroupHasPeers(account, oldNSGroup.Groups) -} diff --git a/management/server/sql_store.go b/management/server/sql_store.go index f58ceb1ad85..1fd8ae2aabe 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1498,12 +1498,55 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren // GetAccountNameServerGroups retrieves name server groups for an account. func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { - return getRecords[*nbdns.NameServerGroup](s.db, lockStrength, accountID) + var nsGroups []*nbdns.NameServerGroup + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get name server groups from store") + } + + return nsGroups, nil } // GetNameServerGroupByID retrieves a name server group by its ID and account ID. -func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) { - return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID) +func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { + var nsGroup *nbdns.NameServerGroup + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewNameServerGroupNotFoundError(nsGroupID) + } + log.WithContext(ctx).Errorf("failed to get name server group from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get name server group from store") + } + + return nsGroup, nil +} + +// SaveNameServerGroup saves a name server group to the database. +func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err) + return status.Errorf(status.Internal, "failed to save name server group to store") + } + return nil +} + +// DeleteNameServerGroup deletes a name server group from the database. +func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nsGroupID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete name server group from store") + } + + if result.RowsAffected == 0 { + return status.NewNameServerGroupNotFoundError(nsGroupID) + } + + return nil } // getRecords retrieves records from the database based on the account ID. diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index df5294d7391..6064b019f29 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1920,3 +1920,134 @@ func TestSqlStore_SaveDNSSettings(t *testing.T) { require.NoError(t, err) require.Equal(t, saveDNSSettings, dnsSettings) } + +func TestSqlStore_GetAccountNameServerGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "retrieve name server groups by existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + { + name: "non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "empty account ID", + accountID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetAccountNameServerGroups(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } + +} + +func TestSqlStore_GetNameServerByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + nsGroupID string + expectError bool + }{ + { + name: "retrieve existing nameserver group", + nsGroupID: "csqdelq7qv97ncu7d9t0", + expectError: false, + }, + { + name: "retrieve non-existing nameserver group", + nsGroupID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty nameserver group ID", + nsGroupID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, tt.nsGroupID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, nsGroup) + } else { + require.NoError(t, err) + require.NotNil(t, nsGroup) + require.Equal(t, tt.nsGroupID, nsGroup.ID) + } + }) + } +} + +func TestSqlStore_SaveNameServerGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + nsGroup := &nbdns.NameServerGroup{ + ID: "ns-group-id", + AccountID: accountID, + Name: "NS Group", + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: 1, + Port: 53, + }, + }, + Groups: []string{"groupA"}, + Primary: true, + Enabled: true, + SearchDomainsEnabled: false, + } + + err = store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nsGroup) + require.NoError(t, err) + + saveNSGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, nsGroup.ID) + require.NoError(t, err) + require.Equal(t, saveNSGroup, nsGroup) +} + +func TestSqlStore_DeleteNameServerGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + nsGroupID := "csqdelq7qv97ncu7d9t0" + + err = store.DeleteNameServerGroup(context.Background(), LockingStrengthShare, accountID, nsGroupID) + require.NoError(t, err) + + nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, nsGroupID) + require.Error(t, err) + require.Nil(t, nsGroup) +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 0fff5355994..59f436f5b19 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -149,3 +149,8 @@ func NewPostureChecksNotFoundError(postureChecksID string) error { func NewPolicyNotFoundError(policyID string) error { return Errorf(NotFound, "policy: %s not found", policyID) } + +// NewNameServerGroupNotFoundError creates a new Error with NotFound type for a missing name server group +func NewNameServerGroupNotFoundError(nsGroupID string) error { + return Errorf(NotFound, "nameserver group: %s not found", nsGroupID) +} diff --git a/management/server/store.go b/management/server/store.go index cca014b5214..b16ad8a1aa4 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -117,6 +117,8 @@ type Store interface { GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) + SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error + DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index 37db2731625..455111439ea 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -36,4 +36,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-3465 INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}'); +INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0); INSERT INTO installations VALUES(1,'');