Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into feature/ipv6-networ…
Browse files Browse the repository at this point in the history
…k-support
  • Loading branch information
pulsastrix committed Mar 11, 2024
2 parents e2d6dff + 5dde044 commit e49664f
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.

### Community projects
- [NetBird on OpenWRT](https://github.com/messense/openwrt-netbird)
- [NetBird installer script](https://github.com/physk/netbird-installer)
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)

**Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
Expand Down
21 changes: 21 additions & 0 deletions client/internal/peer/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import (
"sync"
"time"

"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"

"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/iface"
)
Expand Down Expand Up @@ -377,6 +380,24 @@ func (d *Status) GetManagementState() ManagementState {
}
}

// IsLoginRequired determines if a peer's login has expired.
func (d *Status) IsLoginRequired() bool {
d.mux.Lock()
defer d.mux.Unlock()

// if peer is connected to the management then login is not expired
if d.managementState {
return false
}

s, ok := gstatus.FromError(d.managementError)
if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return true

}
return false
}

func (d *Status) GetSignalState() SignalState {
return SignalState{
d.signalAddress,
Expand Down
82 changes: 82 additions & 0 deletions client/internal/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package internal

import (
"context"
"os/exec"
"strings"
"sync"
"time"

"github.com/netbirdio/netbird/client/internal/peer"
)

type SessionWatcher struct {
ctx context.Context
mutex sync.Mutex

peerStatusRecorder *peer.Status
watchTicker *time.Ticker

sendNotification bool
onExpireListener func()
}

// NewSessionWatcher creates a new instance of SessionWatcher.
func NewSessionWatcher(ctx context.Context, peerStatusRecorder *peer.Status) *SessionWatcher {
s := &SessionWatcher{
ctx: ctx,
peerStatusRecorder: peerStatusRecorder,
watchTicker: time.NewTicker(2 * time.Second),
}
go s.startWatcher()
return s
}

// SetOnExpireListener sets the callback func to be called when the session expires.
func (s *SessionWatcher) SetOnExpireListener(onExpire func()) {
s.mutex.Lock()
defer s.mutex.Unlock()
s.onExpireListener = onExpire
}

// startWatcher continuously checks if the session requires login and
// calls the onExpireListener if login is required.
func (s *SessionWatcher) startWatcher() {
for {
select {
case <-s.ctx.Done():
s.watchTicker.Stop()
return
case <-s.watchTicker.C:
managementState := s.peerStatusRecorder.GetManagementState()
if managementState.Connected {
s.sendNotification = true
}

isLoginRequired := s.peerStatusRecorder.IsLoginRequired()
if isLoginRequired && s.sendNotification && s.onExpireListener != nil {
s.mutex.Lock()
s.onExpireListener()
s.sendNotification = false
s.mutex.Unlock()
}
}
}
}

// CheckUIApp checks whether UI application is running.
func CheckUIApp() bool {
cmd := exec.Command("ps", "-ef")
output, err := cmd.Output()
if err != nil {
return false
}

lines := strings.Split(string(output), "\n")
for _, line := range lines {
if strings.Contains(line, "netbird-ui") && !strings.Contains(line, "grep") {
return true
}
}
return false
}
47 changes: 47 additions & 0 deletions client/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package server
import (
"context"
"fmt"
"os/exec"
"runtime"
"sync"
"time"

Expand Down Expand Up @@ -39,6 +41,7 @@ type Server struct {
proto.UnimplementedDaemonServiceServer

statusRecorder *peer.Status
sessionWatcher *internal.SessionWatcher

mgmProbe *internal.Probe
signalProbe *internal.Probe
Expand Down Expand Up @@ -116,6 +119,11 @@ func (s *Server) Start() error {
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)

if s.sessionWatcher == nil {
s.sessionWatcher = internal.NewSessionWatcher(s.rootCtx, s.statusRecorder)
s.sessionWatcher.SetOnExpireListener(s.onSessionExpire)
}

if !config.DisableAutoConnect {
go func() {
if err := internal.RunClientWithProbes(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe); err != nil {
Expand Down Expand Up @@ -542,6 +550,17 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto
}, nil
}

func (s *Server) onSessionExpire() {
if runtime.GOOS != "windows" {
isUIActive := internal.CheckUIApp()
if !isUIActive {
if err := sendTerminalNotification(); err != nil {
log.Errorf("send session expire terminal notification: %v", err)
}
}
}
}

func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
Expand Down Expand Up @@ -604,3 +623,31 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {

return &pbFullStatus
}

// sendTerminalNotification sends a terminal notification message
// to inform the user that the NetBird connection session has expired.
func sendTerminalNotification() error {
message := "NetBird connection session expired\n\nPlease re-authenticate to connect to the network."
echoCmd := exec.Command("echo", message)
wallCmd := exec.Command("sudo", "wall")

echoCmdStdout, err := echoCmd.StdoutPipe()
if err != nil {
return err
}
wallCmd.Stdin = echoCmdStdout

if err := echoCmd.Start(); err != nil {
return err
}

if err := wallCmd.Start(); err != nil {
return err
}

if err := echoCmd.Wait(); err != nil {
return err
}

return wallCmd.Wait()
}
68 changes: 56 additions & 12 deletions management/server/sqlite_store.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"errors"
"fmt"
"path/filepath"
"runtime"
Expand Down Expand Up @@ -255,7 +256,11 @@ func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer

result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID)
if result.Error != nil {
return status.Errorf(status.NotFound, "peer %s not found", peerID)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "peer %s not found", peerID)
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return status.Errorf(status.Internal, "issue getting peer from store")
}

peer.Status = &peerStatus
Expand All @@ -267,7 +272,11 @@ func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpee
var peer nbpeer.Peer
result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerWithLocation.ID)
if result.Error != nil {
return status.Errorf(status.NotFound, "peer %s not found", peer.ID)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "peer %s not found", peer.ID)
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return status.Errorf(status.Internal, "issue getting peer from store")
}

peer.Location = peerWithLocation.Location
Expand All @@ -291,7 +300,11 @@ func (s *SqliteStore) GetAccountByPrivateDomain(domain string) (*Account, error)
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}
log.Errorf("error when getting account from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}

// TODO: rework to not call GetAccount
Expand All @@ -302,7 +315,11 @@ func (s *SqliteStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
var key SetupKey
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting setup key from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting setup key from store")
}

if key.AccountID == "" {
Expand All @@ -316,7 +333,11 @@ func (s *SqliteStore) GetTokenIDByHashedToken(hashedToken string) (string, error
var token PersonalAccessToken
result := s.db.First(&token, "hashed_token = ?", hashedToken)
if result.Error != nil {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting token from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
}

return token.ID, nil
Expand All @@ -326,7 +347,11 @@ func (s *SqliteStore) GetUserByTokenID(tokenID string) (*User, error) {
var token PersonalAccessToken
result := s.db.First(&token, "id = ?", tokenID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting token from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}

if token.UserID == "" {
Expand Down Expand Up @@ -370,8 +395,11 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
Preload(clause.Associations).
First(&account, "id = ?", accountID)
if result.Error != nil {
log.Errorf("when getting account from the store: %s", result.Error)
return nil, status.Errorf(status.NotFound, "account not found")
log.Errorf("error when getting account from the store: %s", result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found")
}
return nil, status.Errorf(status.Internal, "issue getting account from store")
}

// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
Expand Down Expand Up @@ -431,7 +459,11 @@ func (s *SqliteStore) GetAccountByUser(userID string) (*Account, error) {
var user User
result := s.db.Select("account_id").First(&user, "id = ?", userID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting user from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}

if user.AccountID == "" {
Expand All @@ -445,7 +477,11 @@ func (s *SqliteStore) GetAccountByPeerID(peerID string) (*Account, error) {
var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, "id = ?", peerID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}

if peer.AccountID == "" {
Expand All @@ -460,7 +496,11 @@ func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {

result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}

if peer.AccountID == "" {
Expand All @@ -476,7 +516,11 @@ func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time

result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID)
if result.Error != nil {
return status.Errorf(status.NotFound, "user %s not found", userID)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "user %s not found", userID)
}
log.Errorf("error when getting user from the store: %s", result.Error)
return status.Errorf(status.Internal, "issue getting user from store")
}

user.LastLogin = lastLogin
Expand Down
Loading

0 comments on commit e49664f

Please sign in to comment.