diff --git a/management/server/policy.go b/management/server/policy.go index 8fa1363925a..1e1374a2017 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -2,6 +2,7 @@ package server import ( _ "embed" + "fmt" "strconv" "strings" @@ -220,23 +221,11 @@ func (a *Account) getPeerConnectionResources(peerID string) ([]*nbpeer.Peer, []* continue } - peer, ok := a.Peers[peerID] - if !ok && peer == nil { - continue - } - - for _, postureChecksID := range policy.SourcePostureChecks { - postureChecks := getPostureCheck(a, postureChecksID) - if postureChecks == nil { - continue - } - - for _, check := range postureChecks.Checks { - if err := check.Check(*peer); err != nil { - log.Debugf("an error occurred on check %s: %s", check.Name(), err.Error()) - continue - } - } + // if peer validation fails, the peer should not be able to connect to the policy peer's + // we return an empty list of peers and firewall rule for that policy + err := a.validatePostureChecksOnPeer(policy.SourcePostureChecks, peerID) + if err != nil { + return nil, nil } for _, rule := range policy.Rules { @@ -294,6 +283,14 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in if peer == nil { continue } + + for _, policy := range a.Policies { + err := a.validatePostureChecksOnPeer(policy.SourcePostureChecks, peer.ID) + if err != nil { + continue + } + } + if _, ok := peersExists[peer.ID]; !ok { peers = append(peers, peer) peersExists[peer.ID] = struct{}{} @@ -533,7 +530,29 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string) ([] return filteredPeers, peerInGroups } -func getPostureCheck(account *Account, postureChecksID string) *posture.Checks { +func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, peerID string) error { + peer, ok := a.Peers[peerID] + if !ok && peer == nil { + return fmt.Errorf("peer %s does not exists", peerID) + } + + for _, postureChecksID := range sourcePostureChecksID { + postureChecks := getPostureChecks(a, postureChecksID) + if postureChecks == nil { + continue + } + + for _, check := range postureChecks.Checks { + if err := check.Check(*peer); err != nil { + return fmt.Errorf("an error occurred on check %s: %s", check.Name(), err.Error()) + } + } + } + + return nil +} + +func getPostureChecks(account *Account, postureChecksID string) *posture.Checks { for _, postureChecks := range account.PostureChecks { if postureChecks.ID == postureChecksID { return postureChecks