diff --git a/management/server/account.go b/management/server/account.go index 5eaa8320b09..984139a12d6 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -242,6 +242,11 @@ type Account struct { Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` } +// Subclass used in gorm to only load settings and not whole account +type AccountSettings struct { + Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` +} + type UserPermissions struct { DashboardView string `json:"dashboard_view"` } diff --git a/management/server/file_store.go b/management/server/file_store.go index d1997a3bf0e..60497824caf 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -509,7 +509,7 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) { accountID, ok := s.UserID2AccountID[userID] if !ok { - return nil, status.Errorf(status.NotFound, "account not found") + return nil, status.NewUserNotFoundError(userID) } account, err := s.getAccount(accountID) @@ -540,7 +540,7 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) { if _, ok := account.Peers[peerID]; !ok { delete(s.PeerID2AccountID, peerID) log.Warnf("removed stale peerID %s to accountID %s index", peerID, accountID) - return nil, status.Errorf(status.NotFound, "provided peer doesn't exists %s", peerID) + return nil, status.NewPeerNotFoundError(peerID) } return account.Copy(), nil @@ -553,7 +553,7 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { accountID, ok := s.PeerKeyID2AccountID[peerKey] if !ok { - return nil, status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey) + return nil, status.NewPeerNotFoundError(peerKey) } account, err := s.getAccount(accountID) @@ -573,7 +573,7 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { if stale { delete(s.PeerKeyID2AccountID, peerKey) log.Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID) - return nil, status.Errorf(status.NotFound, "provided peer doesn't exists %s", peerKey) + return nil, status.NewPeerNotFoundError(peerKey) } return account.Copy(), nil @@ -585,12 +585,71 @@ func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) { accountID, ok := s.PeerKeyID2AccountID[peerKey] if !ok { - return "", status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey) + return "", status.NewPeerNotFoundError(peerKey) } return accountID, nil } +func (s *FileStore) GetAccountIDByUserID(userID string) (string, error) { + s.mux.Lock() + defer s.mux.Unlock() + + accountID, ok := s.UserID2AccountID[userID] + if !ok { + return "", status.NewUserNotFoundError(userID) + } + + return accountID, nil +} + +func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) { + s.mux.Lock() + defer s.mux.Unlock() + + accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] + if !ok { + return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists") + } + + return accountID, nil +} + +func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) { + s.mux.Lock() + defer s.mux.Unlock() + + accountID, ok := s.PeerKeyID2AccountID[peerKey] + if !ok { + return nil, status.NewPeerNotFoundError(peerKey) + } + + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + for _, peer := range account.Peers { + if peer.Key == peerKey { + return peer.Copy(), nil + } + } + + return nil, status.NewPeerNotFoundError(peerKey) +} + +func (s *FileStore) GetAccountSettings(accountID string) (*Settings, error) { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + return account.Settings.Copy(), nil +} + // GetInstallationID returns the installation ID from the store func (s *FileStore) GetInstallationID() string { return s.InstallationID diff --git a/management/server/peer.go b/management/server/peer.go index 08455a0f90f..13ac3801daa 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -335,24 +335,29 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P } upperKey := strings.ToUpper(setupKey) - var account *Account + var accountID string var err error addedByUser := false if len(userID) > 0 { addedByUser = true - account, err = am.Store.GetAccountByUser(userID) + accountID, err = am.Store.GetAccountIDByUserID(userID) } else { - account, err = am.Store.GetAccountBySetupKey(setupKey) + accountID, err = am.Store.GetAccountIDBySetupKey(setupKey) } if err != nil { return nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found") } - unlock := am.Store.AcquireAccountWriteLock(account.Id) - defer unlock() + unlock := am.Store.AcquireAccountWriteLock(accountID) + defer func() { + if unlock != nil { + unlock() + } + }() + var account *Account // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) - account, err = am.Store.GetAccount(account.Id) + account, err = am.Store.GetAccount(accountID) if err != nil { return nil, nil, err } @@ -485,6 +490,10 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P return nil, nil, err } + // Account is saved, we can release the lock + unlock() + unlock = nil + opEvent.TargetID = newPeer.ID opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) if !addedByUser { @@ -507,7 +516,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) { peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) if err != nil { - return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered") + return nil, nil, status.NewPeerNotRegisteredError() } err = checkIfPeerOwnerIsBlocked(peer, account) @@ -515,7 +524,7 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp return nil, nil, err } - if peerLoginExpired(peer, account) { + if peerLoginExpired(peer, account.Settings) { return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") } @@ -545,7 +554,7 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp // LoginPeer logs in or registers a peer. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) { - account, err := am.Store.GetAccountByPeerPubKey(login.WireGuardPubKey) + accountID, err := am.Store.GetAccountIDByPeerPubKey(login.WireGuardPubKey) if err != nil { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. @@ -574,19 +583,59 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw return nil, nil, status.Errorf(status.Internal, "failed while logging in peer") } - // we found the peer, and we follow a normal login flow - unlock := am.Store.AcquireAccountWriteLock(account.Id) - defer unlock() + peer, err := am.Store.GetPeerByPeerPubKey(login.WireGuardPubKey) + if err != nil { + return nil, nil, status.NewPeerNotRegisteredError() + } + + accSettings, err := am.Store.GetAccountSettings(accountID) + if err != nil { + return nil, nil, status.Errorf(status.Internal, "failed to get account settings: %s", err) + } + + var isWriteLock bool + + // duplicated logic from after the lock to have an early exit + expired := peerLoginExpired(peer, accSettings) + switch { + case expired: + if err := checkAuth(login.UserID, peer); err != nil { + return nil, nil, err + } + isWriteLock = true + log.Debugf("peer login expired, acquiring write lock") + + case peer.UpdateMetaIfNew(login.Meta): + isWriteLock = true + log.Debugf("peer changed meta, acquiring write lock") + + default: + isWriteLock = false + log.Debugf("peer meta is the same, acquiring read lock") + } + + var unlock func() + + if isWriteLock { + unlock = am.Store.AcquireAccountWriteLock(accountID) + } else { + unlock = am.Store.AcquireAccountReadLock(accountID) + } + defer func() { + if unlock != nil { + unlock() + } + }() // fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies - account, err = am.Store.GetAccount(account.Id) + account, err := am.Store.GetAccount(accountID) if err != nil { return nil, nil, err } - peer, err := account.FindPeerByPubKey(login.WireGuardPubKey) + peer, err = account.FindPeerByPubKey(login.WireGuardPubKey) if err != nil { - return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered") + return nil, nil, status.NewPeerNotRegisteredError() } err = checkIfPeerOwnerIsBlocked(peer, account) @@ -597,7 +646,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw // this flag prevents unnecessary calls to the persistent store. shouldStoreAccount := false updateRemotePeers := false - if peerLoginExpired(peer, account) { + if peerLoginExpired(peer, account.Settings) { err = checkAuth(login.UserID, peer) if err != nil { return nil, nil, err @@ -633,11 +682,17 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw } if shouldStoreAccount { + if !isWriteLock { + log.Errorf("account %s should be stored but is not write locked", accountID) + return nil, nil, status.Errorf(status.Internal, "account should be stored but is not write locked") + } err = am.Store.SaveAccount(account) if err != nil { return nil, nil, err } } + unlock() + unlock = nil if updateRemotePeers || isStatusChanged { am.updateAccountPeers(account) @@ -683,9 +738,9 @@ func checkAuth(loginUserID string, peer *nbpeer.Peer) error { return nil } -func peerLoginExpired(peer *nbpeer.Peer, account *Account) bool { - expired, expiresIn := peer.LoginExpired(account.Settings.PeerLoginExpiration) - expired = account.Settings.PeerLoginExpirationEnabled && expired +func peerLoginExpired(peer *nbpeer.Peer, settings *Settings) bool { + expired, expiresIn := peer.LoginExpired(settings.PeerLoginExpiration) + expired = settings.PeerLoginExpirationEnabled && expired if expired || peer.Status.LoginExpired { log.Debugf("peer's %s login expired %v ago", peer.ID, expiresIn) return true diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 6f6a70ee35e..56136327a59 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -458,7 +458,6 @@ func (s *SqlStore) GetAccountByUser(userID string) (*Account, error) { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.Errorf("error when getting user from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting account from store") } @@ -521,6 +520,61 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) { return accountID, nil } +func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { + var user User + var accountID string + result := s.db.Model(&user).Select("account_id").Where("id = ?", userID).First(&accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return "", status.Errorf(status.NotFound, "account not found: index lookup failed") + } + return "", status.Errorf(status.Internal, "issue getting account from store") + } + + return accountID, nil +} + +func (s *SqlStore) GetAccountIDBySetupKey(setupKey string) (string, error) { + var key SetupKey + var accountID string + result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return "", status.Errorf(status.NotFound, "account not found: index lookup failed") + } + log.Errorf("error when getting setup key from the store: %s", result.Error) + return "", status.Errorf(status.Internal, "issue getting setup key from store") + } + + return accountID, nil +} + +func (s *SqlStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) { + var peer nbpeer.Peer + result := s.db.First(&peer, "key = ?", peerKey) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "peer not found") + } + log.Errorf("error when getting peer from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting peer from store") + } + + return &peer, nil +} + +func (s *SqlStore) GetAccountSettings(accountID string) (*Settings, error) { + var accountSettings AccountSettings + if err := s.db.Model(&Account{}).Where("id = ?", accountID).First(&accountSettings).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "settings not found") + } + log.Errorf("error when getting settings from the store: %s", err) + return nil, status.Errorf(status.Internal, "issue getting settings from store") + } + return accountSettings.Settings, nil +} + // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { var user User @@ -530,7 +584,6 @@ func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Ti if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "user %s not found", userID) } - log.Errorf("error when getting user from the store: %s", result.Error) return status.Errorf(status.Internal, "issue getting user from store") } diff --git a/management/server/status/error.go b/management/server/status/error.go index 66e46151948..39cd6c613e2 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -75,3 +75,23 @@ func FromError(err error) (s *Error, ok bool) { } return nil, false } + +// NewPeerNotFoundError creates a new Error with NotFound type for a missing peer +func NewPeerNotFoundError(peerKey string) error { + return Errorf(NotFound, "peer not found: %s", peerKey) +} + +// NewAccountNotFoundError creates a new Error with NotFound type for a missing account +func NewAccountNotFoundError(accountKey string) error { + return Errorf(NotFound, "account not found: %s", accountKey) +} + +// NewUserNotFoundError creates a new Error with NotFound type for a missing user +func NewUserNotFoundError(userKey string) error { + return Errorf(NotFound, "user not found: %s", userKey) +} + +// NewPeerNotRegisteredError creates a new Error with NotFound type for a missing peer +func NewPeerNotRegisteredError() error { + return Errorf(Unauthenticated, "peer is not registered") +} diff --git a/management/server/store.go b/management/server/store.go index a1824351da5..5210f1210c8 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -27,6 +27,8 @@ type Store interface { GetAccountByUser(userID string) (*Account, error) GetAccountByPeerPubKey(peerKey string) (*Account, error) GetAccountIDByPeerPubKey(peerKey string) (string, error) + GetAccountIDByUserID(peerKey string) (string, error) + GetAccountIDBySetupKey(peerKey string) (string, error) GetAccountByPeerID(peerID string) (*Account, error) GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later GetAccountByPrivateDomain(domain string) (*Account, error) @@ -52,6 +54,8 @@ type Store interface { // GetStoreEngine should return StoreEngine of the current store implementation. // This is also a method of metrics.DataSource interface. GetStoreEngine() StoreEngine + GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) + GetAccountSettings(accountID string) (*Settings, error) } type StoreEngine string