diff --git a/peer.go b/peer.go index bd268ca4..186a089c 100644 --- a/peer.go +++ b/peer.go @@ -21,7 +21,6 @@ package tchannel import ( - "container/heap" "errors" "strings" "sync" @@ -63,19 +62,21 @@ type Connectable interface { type PeerList struct { sync.RWMutex - parent *RootPeerList - peersByHostPort map[string]*peerScore - peerHeap *peerHeap - scoreCalculator ScoreCalculator - lastSelected uint64 + parent *RootPeerList + peersByHostPort map[string]*peerScore + peerHeap *peerHeap + scoreCalculator ScoreCalculator + peerConnectionCount uint32 + lastSelected uint64 } func newPeerList(root *RootPeerList) *PeerList { return &PeerList{ - parent: root, - peersByHostPort: make(map[string]*peerScore), - scoreCalculator: newPreferIncomingCalculator(), - peerHeap: newPeerHeap(), + parent: root, + peersByHostPort: make(map[string]*peerScore), + scoreCalculator: newPreferIncomingCalculator(), + peerHeap: newPeerHeap(), + peerConnectionCount: 1, } } @@ -91,6 +92,20 @@ func (l *PeerList) SetStrategy(sc ScoreCalculator) { } } +// SetPeerConnectionCount sets the number of peer connections to be used in +// combination with the ScoreCalculator to achieve a random load balancing +// of a single client node to `peerConnectionCount` number of server nodes +func (l *PeerList) SetPeerConnectionCount(peerConnectionCount uint32) error { + l.Lock() + defer l.Unlock() + + if peerConnectionCount == 0 { + return errors.New("peer connection count must be greater than 0") + } + l.peerConnectionCount = peerConnectionCount + return nil +} + // Siblings don't share peer lists (though they take care not to double-connect // to the same hosts). func (l *PeerList) newSibling() *PeerList { @@ -175,8 +190,8 @@ func (l *PeerList) Remove(hostPort string) error { return nil } func (l *PeerList) choosePeer(prevSelected map[string]struct{}, avoidHost bool) *Peer { - var psPopList []*peerScore - var ps *peerScore + var chosenPSList = make([]*peerScore, 0, l.peerConnectionCount) + var poppedList = make([]*peerScore, 0, l.peerConnectionCount) canChoosePeer := func(hostPort string) bool { if _, ok := prevSelected[hostPort]; ok { @@ -191,29 +206,39 @@ func (l *PeerList) choosePeer(prevSelected map[string]struct{}, avoidHost bool) } size := l.peerHeap.Len() + + var connectionCount uint32 for i := 0; i < size; i++ { popped := l.peerHeap.popPeer() + poppedList = append(poppedList, popped) if canChoosePeer(popped.HostPort()) { - ps = popped - break + chosenPSList = append(chosenPSList, popped) + connectionCount++ + if connectionCount >= l.peerConnectionCount { + break + } } - psPopList = append(psPopList, popped) - } - for _, p := range psPopList { - heap.Push(l.peerHeap, p) } - if ps == nil { + for _, p := range poppedList { + l.peerHeap.pushPeer(p) + } + if len(chosenPSList) == 0 { return nil } - l.peerHeap.pushPeer(ps) + ps := randomSampling(chosenPSList) ps.chosenCount.Inc() return ps.Peer } +func randomSampling(psList []*peerScore) *peerScore { + r := peerRng.Intn(len(psList)) + return psList[r] +} + // GetOrAdd returns a peer for the given hostPort, creating one if it doesn't yet exist. func (l *PeerList) GetOrAdd(hostPort string) *Peer { if ps, ok := l.exists(hostPort); ok { diff --git a/peer_test.go b/peer_test.go index e5bc3fcc..9b841666 100644 --- a/peer_test.go +++ b/peer_test.go @@ -697,6 +697,98 @@ func TestPeerSelectionRanking(t *testing.T) { } } +func TestZeroPeerConnectionCount(t *testing.T) { + ch := testutils.NewClient(t, nil) + defer ch.Close() + err := ch.Peers().SetPeerConnectionCount(0) + require.Error(t, err, "peerConnectionCount should not accept 0") +} + +func TestPeerRandomSampling(t *testing.T) { + const numIterations = 1000 + + testCases := []struct { + numPeers int + peerConnectionCount uint32 + distMin float64 + distMax float64 + }{ + // the higher `peerConnectionCount` is, the smoother the impact of uneven scores + // become as we are random sampling among `peerConnectionCount` peers + {numPeers: 10, peerConnectionCount: 1, distMin: 1000, distMax: 1000}, + {numPeers: 10, peerConnectionCount: 2, distMin: 470, distMax: 530}, + {numPeers: 10, peerConnectionCount: 5, distMin: 160, distMax: 240}, + {numPeers: 10, peerConnectionCount: 10, distMin: 50, distMax: 150}, + {numPeers: 10, peerConnectionCount: 15, distMin: 50, distMax: 150}, + } + + for _, tc := range testCases { + // Selected is a map from rank -> [peer, count] + // It tracks how often a peer gets selected at a specific rank. + selected := make([]map[string]int, tc.numPeers) + for i := 0; i < tc.numPeers; i++ { + selected[i] = make(map[string]int) + } + + for i := 0; i < numIterations; i++ { + ch := testutils.NewClient(t, nil) + defer ch.Close() + ch.SetRandomSeed(int64(i * 100)) + // Using a strategy that has uneven scores + strategy, _ := createScoreStrategy(0, 1) + ch.Peers().SetStrategy(strategy) + ch.Peers().SetPeerConnectionCount(tc.peerConnectionCount) + + for i := 0; i < tc.numPeers; i++ { + hp := fmt.Sprintf("127.0.0.1:60%v", i) + ch.Peers().Add(hp) + } + + for i := 0; i < tc.numPeers; i++ { + peer, err := ch.Peers().Get(nil) + require.NoError(t, err, "Peers.Get failed") + selected[i][peer.HostPort()]++ + } + } + + for _, m := range selected { + testDistribution(t, m, tc.distMin, tc.distMax) + } + } + +} + +func BenchmarkGetPeerWithPeerConnectionCount1(b *testing.B) { + doBenchmarkGetPeerWithPeerConnectionCount(b, 10, uint32(1)) +} + +func BenchmarkGetPeerWithPeerConnectionCount10(b *testing.B) { + doBenchmarkGetPeerWithPeerConnectionCount(b, 10, uint32(10)) +} + +func doBenchmarkGetPeerWithPeerConnectionCount(b *testing.B, numPeers int, peerConnectionCount uint32) { + ch := testutils.NewClient(b, nil) + defer ch.Close() + ch.SetRandomSeed(int64(100)) + // Using a strategy that has uneven scores + strategy, _ := createScoreStrategy(0, 1) + ch.Peers().SetStrategy(strategy) + ch.Peers().SetPeerConnectionCount(peerConnectionCount) + + for i := 0; i < numPeers; i++ { + hp := fmt.Sprintf("127.0.0.1:60%v", i) + ch.Peers().Add(hp) + } + b.ResetTimer() + + for i := 0; i < b.N; i++ { + peer, _ := ch.Peers().Get(nil) + if peer == nil { + b.Fatal("Just a dummy check to guard against compiler optimization") + } + } +} + func createScoreStrategy(initial, delta int64) (calc ScoreCalculator, retCount *atomic.Uint64) { var ( count atomic.Uint64