Skip to content

Commit

Permalink
Improve login performance (netbirdio#2061)
Browse files Browse the repository at this point in the history
  • Loading branch information
pascal-fischer authored May 31, 2024
1 parent f9ec0a9 commit 521f7dd
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 26 deletions.
5 changes: 5 additions & 0 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,11 @@ type Account struct {
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
}

// Subclass used in gorm to only load settings and not whole account
type AccountSettings struct {
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
}

type UserPermissions struct {
DashboardView string `json:"dashboard_view"`
}
Expand Down
69 changes: 64 additions & 5 deletions management/server/file_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {

accountID, ok := s.UserID2AccountID[userID]
if !ok {
return nil, status.Errorf(status.NotFound, "account not found")
return nil, status.NewUserNotFoundError(userID)
}

account, err := s.getAccount(accountID)
Expand Down Expand Up @@ -540,7 +540,7 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
if _, ok := account.Peers[peerID]; !ok {
delete(s.PeerID2AccountID, peerID)
log.Warnf("removed stale peerID %s to accountID %s index", peerID, accountID)
return nil, status.Errorf(status.NotFound, "provided peer doesn't exists %s", peerID)
return nil, status.NewPeerNotFoundError(peerID)
}

return account.Copy(), nil
Expand All @@ -553,7 +553,7 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {

accountID, ok := s.PeerKeyID2AccountID[peerKey]
if !ok {
return nil, status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
return nil, status.NewPeerNotFoundError(peerKey)
}

account, err := s.getAccount(accountID)
Expand All @@ -573,7 +573,7 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
if stale {
delete(s.PeerKeyID2AccountID, peerKey)
log.Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID)
return nil, status.Errorf(status.NotFound, "provided peer doesn't exists %s", peerKey)
return nil, status.NewPeerNotFoundError(peerKey)
}

return account.Copy(), nil
Expand All @@ -585,12 +585,71 @@ func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {

accountID, ok := s.PeerKeyID2AccountID[peerKey]
if !ok {
return "", status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
return "", status.NewPeerNotFoundError(peerKey)
}

return accountID, nil
}

func (s *FileStore) GetAccountIDByUserID(userID string) (string, error) {
s.mux.Lock()
defer s.mux.Unlock()

accountID, ok := s.UserID2AccountID[userID]
if !ok {
return "", status.NewUserNotFoundError(userID)
}

return accountID, nil
}

func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) {
s.mux.Lock()
defer s.mux.Unlock()

accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok {
return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
}

return accountID, nil
}

func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) {
s.mux.Lock()
defer s.mux.Unlock()

accountID, ok := s.PeerKeyID2AccountID[peerKey]
if !ok {
return nil, status.NewPeerNotFoundError(peerKey)
}

account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}

for _, peer := range account.Peers {
if peer.Key == peerKey {
return peer.Copy(), nil
}
}

return nil, status.NewPeerNotFoundError(peerKey)
}

func (s *FileStore) GetAccountSettings(accountID string) (*Settings, error) {
s.mux.Lock()
defer s.mux.Unlock()

account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}

return account.Settings.Copy(), nil
}

// GetInstallationID returns the installation ID from the store
func (s *FileStore) GetInstallationID() string {
return s.InstallationID
Expand Down
93 changes: 74 additions & 19 deletions management/server/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,24 +335,29 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
}

upperKey := strings.ToUpper(setupKey)
var account *Account
var accountID string
var err error
addedByUser := false
if len(userID) > 0 {
addedByUser = true
account, err = am.Store.GetAccountByUser(userID)
accountID, err = am.Store.GetAccountIDByUserID(userID)
} else {
account, err = am.Store.GetAccountBySetupKey(setupKey)
accountID, err = am.Store.GetAccountIDBySetupKey(setupKey)
}
if err != nil {
return nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
}

unlock := am.Store.AcquireAccountWriteLock(account.Id)
defer unlock()
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer func() {
if unlock != nil {
unlock()
}
}()

var account *Account
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(account.Id)
account, err = am.Store.GetAccount(accountID)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -485,6 +490,10 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
return nil, nil, err
}

// Account is saved, we can release the lock
unlock()
unlock = nil

opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
if !addedByUser {
Expand All @@ -507,15 +516,15 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) {
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
if err != nil {
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
return nil, nil, status.NewPeerNotRegisteredError()
}

err = checkIfPeerOwnerIsBlocked(peer, account)
if err != nil {
return nil, nil, err
}

if peerLoginExpired(peer, account) {
if peerLoginExpired(peer, account.Settings) {
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
}

Expand Down Expand Up @@ -545,7 +554,7 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp
// LoginPeer logs in or registers a peer.
// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so.
func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) {
account, err := am.Store.GetAccountByPeerPubKey(login.WireGuardPubKey)
accountID, err := am.Store.GetAccountIDByPeerPubKey(login.WireGuardPubKey)
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
// we couldn't find this peer by its public key which can mean that peer hasn't been registered yet.
Expand Down Expand Up @@ -574,19 +583,59 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
return nil, nil, status.Errorf(status.Internal, "failed while logging in peer")
}

// we found the peer, and we follow a normal login flow
unlock := am.Store.AcquireAccountWriteLock(account.Id)
defer unlock()
peer, err := am.Store.GetPeerByPeerPubKey(login.WireGuardPubKey)
if err != nil {
return nil, nil, status.NewPeerNotRegisteredError()
}

accSettings, err := am.Store.GetAccountSettings(accountID)
if err != nil {
return nil, nil, status.Errorf(status.Internal, "failed to get account settings: %s", err)
}

var isWriteLock bool

// duplicated logic from after the lock to have an early exit
expired := peerLoginExpired(peer, accSettings)
switch {
case expired:
if err := checkAuth(login.UserID, peer); err != nil {
return nil, nil, err
}
isWriteLock = true
log.Debugf("peer login expired, acquiring write lock")

case peer.UpdateMetaIfNew(login.Meta):
isWriteLock = true
log.Debugf("peer changed meta, acquiring write lock")

default:
isWriteLock = false
log.Debugf("peer meta is the same, acquiring read lock")
}

var unlock func()

if isWriteLock {
unlock = am.Store.AcquireAccountWriteLock(accountID)
} else {
unlock = am.Store.AcquireAccountReadLock(accountID)
}
defer func() {
if unlock != nil {
unlock()
}
}()

// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
account, err = am.Store.GetAccount(account.Id)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, nil, err
}

peer, err := account.FindPeerByPubKey(login.WireGuardPubKey)
peer, err = account.FindPeerByPubKey(login.WireGuardPubKey)
if err != nil {
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
return nil, nil, status.NewPeerNotRegisteredError()
}

err = checkIfPeerOwnerIsBlocked(peer, account)
Expand All @@ -597,7 +646,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
// this flag prevents unnecessary calls to the persistent store.
shouldStoreAccount := false
updateRemotePeers := false
if peerLoginExpired(peer, account) {
if peerLoginExpired(peer, account.Settings) {
err = checkAuth(login.UserID, peer)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -633,11 +682,17 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
}

if shouldStoreAccount {
if !isWriteLock {
log.Errorf("account %s should be stored but is not write locked", accountID)
return nil, nil, status.Errorf(status.Internal, "account should be stored but is not write locked")
}
err = am.Store.SaveAccount(account)
if err != nil {
return nil, nil, err
}
}
unlock()
unlock = nil

if updateRemotePeers || isStatusChanged {
am.updateAccountPeers(account)
Expand Down Expand Up @@ -683,9 +738,9 @@ func checkAuth(loginUserID string, peer *nbpeer.Peer) error {
return nil
}

func peerLoginExpired(peer *nbpeer.Peer, account *Account) bool {
expired, expiresIn := peer.LoginExpired(account.Settings.PeerLoginExpiration)
expired = account.Settings.PeerLoginExpirationEnabled && expired
func peerLoginExpired(peer *nbpeer.Peer, settings *Settings) bool {
expired, expiresIn := peer.LoginExpired(settings.PeerLoginExpiration)
expired = settings.PeerLoginExpirationEnabled && expired
if expired || peer.Status.LoginExpired {
log.Debugf("peer's %s login expired %v ago", peer.ID, expiresIn)
return true
Expand Down
57 changes: 55 additions & 2 deletions management/server/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,6 @@ func (s *SqlStore) GetAccountByUser(userID string) (*Account, error) {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting user from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}

Expand Down Expand Up @@ -521,6 +520,61 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
return accountID, nil
}

func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
var user User
var accountID string
result := s.db.Model(&user).Select("account_id").Where("id = ?", userID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return "", status.Errorf(status.Internal, "issue getting account from store")
}

return accountID, nil
}

func (s *SqlStore) GetAccountIDBySetupKey(setupKey string) (string, error) {
var key SetupKey
var accountID string
result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting setup key from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting setup key from store")
}

return accountID, nil
}

func (s *SqlStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) {
var peer nbpeer.Peer
result := s.db.First(&peer, "key = ?", peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting peer from store")
}

return &peer, nil
}

func (s *SqlStore) GetAccountSettings(accountID string) (*Settings, error) {
var accountSettings AccountSettings
if err := s.db.Model(&Account{}).Where("id = ?", accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found")
}
log.Errorf("error when getting settings from the store: %s", err)
return nil, status.Errorf(status.Internal, "issue getting settings from store")
}
return accountSettings.Settings, nil
}

// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
var user User
Expand All @@ -530,7 +584,6 @@ func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Ti
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "user %s not found", userID)
}
log.Errorf("error when getting user from the store: %s", result.Error)
return status.Errorf(status.Internal, "issue getting user from store")
}

Expand Down
Loading

0 comments on commit 521f7dd

Please sign in to comment.