From f5cb2077c063fd535fb7e759e3f1dfddd85a57ec Mon Sep 17 00:00:00 2001 From: Yury Gargay Date: Wed, 11 Oct 2023 18:11:45 +0200 Subject: [PATCH] Rework peer connection status based on the update channel existence (#1213) With this change, we don't need to update all peers on startup. We will check the existence of an update channel when returning a list or single peer on API. Then after restarting of server consumers of API will see peer not connected status till the creation of an updated channel which indicates peer successful connection. --- management/server/account.go | 6 ++ management/server/file_store.go | 4 -- management/server/http/peers_handler.go | 33 ++++++++++- management/server/http/peers_handler_test.go | 56 +++++++++++++++++-- management/server/mock_server/account_mock.go | 9 +++ 5 files changed, 97 insertions(+), 11 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index d35ad2566e0..8e453a1fede 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -103,6 +103,7 @@ type AccountManager interface { UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error) LoginPeer(login PeerLogin) (*Peer, *NetworkMap, error) // used by peer gRPC API SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error) // used by peer gRPC API + GetAllConnectedPeers() (map[string]struct{}, error) } type DefaultAccountManager struct { @@ -1558,6 +1559,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla } } +// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers() +func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) { + return am.peersUpdateManager.GetAllConnectedPeers(), nil +} + func isDomainValid(domain string) bool { re := regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) return re.Match([]byte(domain)) diff --git a/management/server/file_store.go b/management/server/file_store.go index ecd02ba9910..b90b1d607eb 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -111,10 +111,6 @@ func restore(file string) (*FileStore, error) { for _, peer := range account.Peers { store.PeerKeyID2AccountID[peer.Key] = accountID store.PeerID2AccountID[peer.ID] = accountID - // reset all peers to status = Disconnected - if peer.Status != nil && peer.Status.Connected { - peer.Status.Connected = false - } } for _, user := range account.Users { store.UserID2AccountID[user.Id] = accountID diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index adf4a972102..a485d6ccf96 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -31,6 +31,24 @@ func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Pee } } +func (h *PeersHandler) checkPeerStatus(peer *server.Peer) (*server.Peer, error) { + peerToReturn := peer.Copy() + if peer.Status.Connected { + statuses, err := h.accountManager.GetAllConnectedPeers() + if err != nil { + return peerToReturn, err + } + + // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected + // This may happen after server restart when not all peers are yet connected + if _, connected := statuses[peerToReturn.ID]; !connected { + peerToReturn.Status.Connected = false + } + } + + return peerToReturn, nil +} + func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w http.ResponseWriter) { peer, err := h.accountManager.GetPeer(account.Id, peerID, userID) if err != nil { @@ -38,7 +56,13 @@ func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w return } - util.WriteJSONObject(w, toPeerResponse(peer, account, h.accountManager.GetDNSDomain())) + peerToReturn, err := h.checkPeerStatus(peer) + if err != nil { + util.WriteError(err, w) + return + } + + util.WriteJSONObject(w, toPeerResponse(peerToReturn, account, h.accountManager.GetDNSDomain())) } func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { @@ -120,7 +144,12 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { respBody := []*api.Peer{} for _, peer := range peers { - respBody = append(respBody, toPeerResponse(peer, account, dnsDomain)) + peerToReturn, err := h.checkPeerStatus(peer) + if err != nil { + util.WriteError(err, w) + return + } + respBody = append(respBody, toPeerResponse(peerToReturn, account, dnsDomain)) } util.WriteJSONObject(w, respBody) return diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index 7fe732f2fc2..1856861d549 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -3,6 +3,7 @@ package http import ( "bytes" "encoding/json" + "fmt" "io" "net" "net/http" @@ -23,19 +24,33 @@ import ( ) const testPeerID = "test_peer" +const noUpdateChannelTestPeerID = "no-update-channel" func initTestMetaData(peers ...*server.Peer) *PeersHandler { return &PeersHandler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(accountID, userID string, update *server.Peer) (*server.Peer, error) { - p := peers[0].Copy() + var p *server.Peer + for _, peer := range peers { + if update.ID == peer.ID { + p = peer.Copy() + break + } + } p.SSHEnabled = update.SSHEnabled p.LoginExpirationEnabled = update.LoginExpirationEnabled p.Name = update.Name return p, nil }, GetPeerFunc: func(accountID, peerID, userID string) (*server.Peer, error) { - return peers[0], nil + var p *server.Peer + for _, peer := range peers { + if peerID == peer.ID { + p = peer.Copy() + break + } + } + return p, nil }, GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) { return peers, nil @@ -57,6 +72,16 @@ func initTestMetaData(peers ...*server.Peer) *PeersHandler { }, }, user, nil }, + GetAllConnectedPeersFunc: func() (map[string]struct{}, error) { + statuses := make(map[string]struct{}) + for _, peer := range peers { + if peer.ID == noUpdateChannelTestPeerID { + break + } + statuses[peer.ID] = struct{}{} + } + return statuses, nil + }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { @@ -79,7 +104,7 @@ func TestGetPeers(t *testing.T) { Key: "key", SetupKey: "setupkey", IP: net.ParseIP("100.64.0.1"), - Status: &server.PeerStatus{}, + Status: &server.PeerStatus{Connected: true}, Name: "PeerName", LoginExpirationEnabled: false, Meta: server.PeerSystemMeta{ @@ -93,11 +118,17 @@ func TestGetPeers(t *testing.T) { }, } + peer1 := peer.Copy() + peer1.ID = noUpdateChannelTestPeerID + expectedUpdatedPeer := peer.Copy() expectedUpdatedPeer.LoginExpirationEnabled = true expectedUpdatedPeer.SSHEnabled = true expectedUpdatedPeer.Name = "New Name" + expectedPeer1 := peer1.Copy() + expectedPeer1.Status.Connected = false + tt := []struct { name string expectedStatus int @@ -116,13 +147,21 @@ func TestGetPeers(t *testing.T) { expectedPeer: peer, }, { - name: "GetPeer", + name: "GetPeer with update channel", requestType: http.MethodGet, requestPath: "/api/peers/" + testPeerID, expectedStatus: http.StatusOK, expectedArray: false, expectedPeer: peer, }, + { + name: "GetPeer with no update channel", + requestType: http.MethodGet, + requestPath: "/api/peers/" + peer1.ID, + expectedStatus: http.StatusOK, + expectedArray: false, + expectedPeer: expectedPeer1, + }, { name: "PutPeer", requestType: http.MethodPut, @@ -136,7 +175,7 @@ func TestGetPeers(t *testing.T) { rr := httptest.NewRecorder() - p := initTestMetaData(peer) + p := initTestMetaData(peer, peer1) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -171,6 +210,10 @@ func TestGetPeers(t *testing.T) { t.Fatalf("Sent content is not in correct json format; %v", err) } + // hardcode this check for now as we only have two peers in this suite + assert.Equal(t, len(respBody), 2) + assert.Equal(t, respBody[1].Connected, false) + got = respBody[0] } else { got = &api.Peer{} @@ -180,12 +223,15 @@ func TestGetPeers(t *testing.T) { } } + fmt.Println(got) + assert.Equal(t, got.Name, tc.expectedPeer.Name) assert.Equal(t, got.Version, tc.expectedPeer.Meta.WtVersion) assert.Equal(t, got.Ip, tc.expectedPeer.IP.String()) assert.Equal(t, got.Os, "OS core") assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled) assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled) + assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected) }) } } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 5432b201bfb..ab3748c01bd 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -75,6 +75,7 @@ type MockAccountManager struct { LoginPeerFunc func(login server.PeerLogin) (*server.Peer, *server.NetworkMap, error) SyncPeerFunc func(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error) InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error + GetAllConnectedPeersFunc func() (map[string]struct{}, error) } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface @@ -583,3 +584,11 @@ func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*server.Peer, *ser } return nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") } + +// GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface +func (am *MockAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) { + if am.GetAllConnectedPeersFunc != nil { + return am.GetAllConnectedPeersFunc() + } + return nil, status.Errorf(codes.Unimplemented, "method GetAllConnectedPeers is not implemented") +}