Skip to content

Commit

Permalink
Rework peer connection status based on the update channel existence (n…
Browse files Browse the repository at this point in the history
…etbirdio#1213)

With this change, we don't need to update all peers on startup. We will
check the existence of an update channel when returning a list or single peer on API.
Then after restarting of server consumers of API will see peer not
connected status till the creation of an updated channel which indicates
peer successful connection.
  • Loading branch information
surik authored Oct 11, 2023
1 parent 608a4eb commit f5cb207
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 11 deletions.
6 changes: 6 additions & 0 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ type AccountManager interface {
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
LoginPeer(login PeerLogin) (*Peer, *NetworkMap, error) // used by peer gRPC API
SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error)
}

type DefaultAccountManager struct {
Expand Down Expand Up @@ -1558,6 +1559,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
}
}

// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
return am.peersUpdateManager.GetAllConnectedPeers(), nil
}

func isDomainValid(domain string) bool {
re := regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
return re.Match([]byte(domain))
Expand Down
4 changes: 0 additions & 4 deletions management/server/file_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,6 @@ func restore(file string) (*FileStore, error) {
for _, peer := range account.Peers {
store.PeerKeyID2AccountID[peer.Key] = accountID
store.PeerID2AccountID[peer.ID] = accountID
// reset all peers to status = Disconnected
if peer.Status != nil && peer.Status.Connected {
peer.Status.Connected = false
}
}
for _, user := range account.Users {
store.UserID2AccountID[user.Id] = accountID
Expand Down
33 changes: 31 additions & 2 deletions management/server/http/peers_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,38 @@ func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Pee
}
}

func (h *PeersHandler) checkPeerStatus(peer *server.Peer) (*server.Peer, error) {
peerToReturn := peer.Copy()
if peer.Status.Connected {
statuses, err := h.accountManager.GetAllConnectedPeers()
if err != nil {
return peerToReturn, err
}

// Although we have online status in store we do not yet have an updated channel so have to show it as disconnected
// This may happen after server restart when not all peers are yet connected
if _, connected := statuses[peerToReturn.ID]; !connected {
peerToReturn.Status.Connected = false
}
}

return peerToReturn, nil
}

func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(account.Id, peerID, userID)
if err != nil {
util.WriteError(err, w)
return
}

util.WriteJSONObject(w, toPeerResponse(peer, account, h.accountManager.GetDNSDomain()))
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(err, w)
return
}

util.WriteJSONObject(w, toPeerResponse(peerToReturn, account, h.accountManager.GetDNSDomain()))
}

func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -120,7 +144,12 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {

respBody := []*api.Peer{}
for _, peer := range peers {
respBody = append(respBody, toPeerResponse(peer, account, dnsDomain))
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(err, w)
return
}
respBody = append(respBody, toPeerResponse(peerToReturn, account, dnsDomain))
}
util.WriteJSONObject(w, respBody)
return
Expand Down
56 changes: 51 additions & 5 deletions management/server/http/peers_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package http
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
Expand All @@ -23,19 +24,33 @@ import (
)

const testPeerID = "test_peer"
const noUpdateChannelTestPeerID = "no-update-channel"

func initTestMetaData(peers ...*server.Peer) *PeersHandler {
return &PeersHandler{
accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(accountID, userID string, update *server.Peer) (*server.Peer, error) {
p := peers[0].Copy()
var p *server.Peer
for _, peer := range peers {
if update.ID == peer.ID {
p = peer.Copy()
break
}
}
p.SSHEnabled = update.SSHEnabled
p.LoginExpirationEnabled = update.LoginExpirationEnabled
p.Name = update.Name
return p, nil
},
GetPeerFunc: func(accountID, peerID, userID string) (*server.Peer, error) {
return peers[0], nil
var p *server.Peer
for _, peer := range peers {
if peerID == peer.ID {
p = peer.Copy()
break
}
}
return p, nil
},
GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) {
return peers, nil
Expand All @@ -57,6 +72,16 @@ func initTestMetaData(peers ...*server.Peer) *PeersHandler {
},
}, user, nil
},
GetAllConnectedPeersFunc: func() (map[string]struct{}, error) {
statuses := make(map[string]struct{})
for _, peer := range peers {
if peer.ID == noUpdateChannelTestPeerID {
break
}
statuses[peer.ID] = struct{}{}
}
return statuses, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
Expand All @@ -79,7 +104,7 @@ func TestGetPeers(t *testing.T) {
Key: "key",
SetupKey: "setupkey",
IP: net.ParseIP("100.64.0.1"),
Status: &server.PeerStatus{},
Status: &server.PeerStatus{Connected: true},
Name: "PeerName",
LoginExpirationEnabled: false,
Meta: server.PeerSystemMeta{
Expand All @@ -93,11 +118,17 @@ func TestGetPeers(t *testing.T) {
},
}

peer1 := peer.Copy()
peer1.ID = noUpdateChannelTestPeerID

expectedUpdatedPeer := peer.Copy()
expectedUpdatedPeer.LoginExpirationEnabled = true
expectedUpdatedPeer.SSHEnabled = true
expectedUpdatedPeer.Name = "New Name"

expectedPeer1 := peer1.Copy()
expectedPeer1.Status.Connected = false

tt := []struct {
name string
expectedStatus int
Expand All @@ -116,13 +147,21 @@ func TestGetPeers(t *testing.T) {
expectedPeer: peer,
},
{
name: "GetPeer",
name: "GetPeer with update channel",
requestType: http.MethodGet,
requestPath: "/api/peers/" + testPeerID,
expectedStatus: http.StatusOK,
expectedArray: false,
expectedPeer: peer,
},
{
name: "GetPeer with no update channel",
requestType: http.MethodGet,
requestPath: "/api/peers/" + peer1.ID,
expectedStatus: http.StatusOK,
expectedArray: false,
expectedPeer: expectedPeer1,
},
{
name: "PutPeer",
requestType: http.MethodPut,
Expand All @@ -136,7 +175,7 @@ func TestGetPeers(t *testing.T) {

rr := httptest.NewRecorder()

p := initTestMetaData(peer)
p := initTestMetaData(peer, peer1)

for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
Expand Down Expand Up @@ -171,6 +210,10 @@ func TestGetPeers(t *testing.T) {
t.Fatalf("Sent content is not in correct json format; %v", err)
}

// hardcode this check for now as we only have two peers in this suite
assert.Equal(t, len(respBody), 2)
assert.Equal(t, respBody[1].Connected, false)

got = respBody[0]
} else {
got = &api.Peer{}
Expand All @@ -180,12 +223,15 @@ func TestGetPeers(t *testing.T) {
}
}

fmt.Println(got)

assert.Equal(t, got.Name, tc.expectedPeer.Name)
assert.Equal(t, got.Version, tc.expectedPeer.Meta.WtVersion)
assert.Equal(t, got.Ip, tc.expectedPeer.IP.String())
assert.Equal(t, got.Os, "OS core")
assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled)
assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled)
assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected)
})
}
}
9 changes: 9 additions & 0 deletions management/server/mock_server/account_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ type MockAccountManager struct {
LoginPeerFunc func(login server.PeerLogin) (*server.Peer, *server.NetworkMap, error)
SyncPeerFunc func(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error)
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
}

// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
Expand Down Expand Up @@ -583,3 +584,11 @@ func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*server.Peer, *ser
}
return nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
}

// GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface
func (am *MockAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
if am.GetAllConnectedPeersFunc != nil {
return am.GetAllConnectedPeersFunc()
}
return nil, status.Errorf(codes.Unimplemented, "method GetAllConnectedPeers is not implemented")
}

0 comments on commit f5cb207

Please sign in to comment.