receive: do not leak grpc connections
Prevent leaking gRPC connections by garbage collecting old ones when
the hashring changes. For that purpose, I propose adding a `Nodes()
[]string` method so that it is possible to know which nodes no longer
exist in the hashring.

Signed-off-by: Giedrius Statkevičius <[email protected]>
GiedriusS committed Jan 4, 2024
1 parent b884c51 commit 3127d48
Showing 3 changed files with 159 additions and 23 deletions.
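
The idea of the change can be illustrated with a rough standalone sketch: diff the old and new hashring node sets, then drop (in the real handler, Close()) the cached connections to nodes that disappeared. The names `connCache`, `staleNodes`, and `reconcile` below are invented for illustration and are not part of this commit.

package main

import (
	"fmt"
	"sort"
)

// connCache is a hypothetical stand-in for the handler's per-address connection cache.
// In the real handler the values are *grpc.ClientConn; a bool is enough for the sketch.
type connCache map[string]bool

// staleNodes returns the addresses present in oldNodes but absent from newNodes,
// deduplicated and sorted, mirroring a set difference of hashring nodes.
func staleNodes(oldNodes, newNodes []string) []string {
	newSet := make(map[string]struct{}, len(newNodes))
	for _, n := range newNodes {
		newSet[n] = struct{}{}
	}
	seen := map[string]struct{}{}
	var gone []string
	for _, o := range oldNodes {
		if _, ok := newSet[o]; ok {
			continue
		}
		if _, dup := seen[o]; dup {
			continue
		}
		seen[o] = struct{}{}
		gone = append(gone, o)
	}
	sort.Strings(gone)
	return gone
}

// reconcile drops cached connections to nodes that left the hashring;
// in the real handler this is where grpc.ClientConn.Close() would be called.
func reconcile(cache connCache, oldNodes, newNodes []string) {
	for _, addr := range staleNodes(oldNodes, newNodes) {
		delete(cache, addr)
	}
}

func main() {
	cache := connCache{"a:10901": true, "b:10901": true, "c:10901": true}
	reconcile(cache, []string{"a:10901", "b:10901", "c:10901"}, []string{"a:10901", "c:10901"})
	fmt.Println(cache) // map[a:10901:true c:10901:true]; the stale connection to b:10901 is gone
}
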
77 changes: 68 additions & 9 deletions pkg/receive/handler.go
@@ -106,7 +106,7 @@ type Handler struct {

mtx sync.RWMutex
hashring Hashring
peers *peerGroup
peers peersContainer
expBackoff backoff.Backoff
peerStates map[string]*retryState
receiverMode ReceiverMode
@@ -253,11 +253,49 @@ func (h *Handler) Hashring(hashring Hashring) {
h.mtx.Lock()
defer h.mtx.Unlock()

if h.hashring != nil {
previousNodes := h.hashring.Nodes()
newNodes := hashring.Nodes()

disappearedNodes := getSortedStringSliceDiff(previousNodes, newNodes)
for _, node := range disappearedNodes {
if err := h.peers.close(node); err != nil {
level.Error(h.logger).Log("msg", "closing gRPC connection failed, we might have leaked a file descriptor", "addr", node, "err", err.Error())
}
}
}

h.hashring = hashring
h.expBackoff.Reset()
h.peerStates = make(map[string]*retryState)
}

// getSortedStringSliceDiff returns items which are in slice1 but not in slice2.
// The returned slice also only contains unique items i.e. it is a set.
func getSortedStringSliceDiff(slice1, slice2 []string) []string {
slice1Items := make(map[string]struct{}, len(slice1))
slice2Items := make(map[string]struct{}, len(slice2))

for _, s1 := range slice1 {
slice1Items[s1] = struct{}{}
}
for _, s2 := range slice2 {
slice2Items[s2] = struct{}{}
}

var difference = make([]string, 0)
for s1 := range slice1Items {
_, s2Contains := slice2Items[s1]
if s2Contains {
continue
}
difference = append(difference, s1)
}
sort.Strings(difference)

return difference
}

// Verifies whether the server is ready or not.
func (h *Handler) isReady() bool {
h.mtx.RLock()
@@ -1123,46 +1161,67 @@ func newReplicationErrors(threshold, numErrors int) []*replicationErrors {
return errs
}

func newPeerGroup(dialOpts ...grpc.DialOption) *peerGroup {
func newPeerGroup(dialOpts ...grpc.DialOption) peersContainer {
return &peerGroup{
dialOpts: dialOpts,
cache: map[string]storepb.WriteableStoreClient{},
cache: map[string]*grpc.ClientConn{},
m: sync.RWMutex{},
dialer: grpc.DialContext,
}
}

type peersContainer interface {
close(string) error
get(context.Context, string) (storepb.WriteableStoreClient, error)
}

type peerGroup struct {
dialOpts []grpc.DialOption
cache map[string]storepb.WriteableStoreClient
cache map[string]*grpc.ClientConn
m sync.RWMutex

// dialer is used for testing.
dialer func(ctx context.Context, target string, opts ...grpc.DialOption) (conn *grpc.ClientConn, err error)
}

func (p *peerGroup) close(addr string) error {
p.m.Lock()
defer p.m.Unlock()

c, ok := p.cache[addr]
if !ok {
return fmt.Errorf("address %s not found", addr)
}

if err := c.Close(); err != nil {
return fmt.Errorf("closing connection for %s", addr)
}

delete(p.cache, addr)
return nil
}

func (p *peerGroup) get(ctx context.Context, addr string) (storepb.WriteableStoreClient, error) {
// use a RLock first to prevent blocking if we don't need to.
p.m.RLock()
c, ok := p.cache[addr]
p.m.RUnlock()
if ok {
return c, nil
return storepb.NewWriteableStoreClient(c), nil
}

p.m.Lock()
defer p.m.Unlock()
// Make sure that another caller hasn't created the connection since obtaining the write lock.
c, ok = p.cache[addr]
if ok {
return c, nil
return storepb.NewWriteableStoreClient(c), nil
}
conn, err := p.dialer(ctx, addr, p.dialOpts...)
if err != nil {
return nil, errors.Wrap(err, "failed to dial peer")
}

client := storepb.NewWriteableStoreClient(conn)
p.cache[addr] = client
return client, nil
p.cache[addr] = conn
return storepb.NewWriteableStoreClient(conn), nil
}
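
As an aside, `peerGroup.get` above takes a read lock for the common cache-hit path and re-checks the cache under the write lock before dialing, so two goroutines racing on a miss only dial once. A minimal, self-contained sketch of that pattern follows; `lazyCache` and the string-valued "connection" are invented stand-ins for `peerGroup` and `*grpc.ClientConn`, not code from this commit.

package main

import (
	"fmt"
	"sync"
)

// lazyCache lazily creates and caches one "connection" (here just a string) per address.
type lazyCache struct {
	mu    sync.RWMutex
	conns map[string]string
	dial  func(addr string) string
}

func (c *lazyCache) get(addr string) string {
	// Fast path: read lock only, so concurrent cache hits do not serialize.
	c.mu.RLock()
	conn, ok := c.conns[addr]
	c.mu.RUnlock()
	if ok {
		return conn
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	// Re-check: another goroutine may have dialed while we waited for the write lock.
	if conn, ok := c.conns[addr]; ok {
		return conn
	}
	conn = c.dial(addr)
	c.conns[addr] = conn
	return conn
}

func main() {
	c := &lazyCache{
		conns: map[string]string{},
		dial:  func(addr string) string { return "conn-to-" + addr },
	}
	fmt.Println(c.get("node-a:10901")) // dials and caches
	fmt.Println(c.get("node-a:10901")) // served from the cache via the read-lock fast path
}
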
74 changes: 60 additions & 14 deletions pkg/receive/handler_test.go
@@ -166,24 +166,38 @@ func (f *fakeAppender) AppendCTZeroSample(ref storage.SeriesRef, l labels.Labels
panic("not implemented")
}

type fakePeersGroup struct {
clients map[string]storepb.WriteableStoreClient

closeCalled map[string]bool
}

func (g *fakePeersGroup) close(addr string) error {
if g.closeCalled == nil {
g.closeCalled = map[string]bool{}
}
g.closeCalled[addr] = true
return nil
}

func (g *fakePeersGroup) get(_ context.Context, addr string) (storepb.WriteableStoreClient, error) {
c, ok := g.clients[addr]
if !ok {
return nil, fmt.Errorf("client %s not found", addr)
}
return c, nil
}

var _ = (peersContainer)(&fakePeersGroup{})

func newTestHandlerHashring(appendables []*fakeAppendable, replicationFactor uint64, hashringAlgo HashringAlgorithm) ([]*Handler, Hashring, error) {
var (
cfg = []HashringConfig{{Hashring: "test"}}
handlers []*Handler
wOpts = &WriterOptions{}
)
// create a fake peer group where we manually fill the cache with fake addresses pointed to our handlers
// This removes the network from the tests and creates a more consistent testing harness.
peers := &peerGroup{
dialOpts: nil,
m: sync.RWMutex{},
cache: map[string]storepb.WriteableStoreClient{},
dialer: func(context.Context, string, ...grpc.DialOption) (*grpc.ClientConn, error) {
// dialer should never be called since we are creating fake clients with fake addresses
// this protects against some leaking test that may attempt to dial random IP addresses
// which may pose a security risk.
return nil, errors.New("unexpected dial called in testing")
},
fakePeers := &fakePeersGroup{
clients: map[string]storepb.WriteableStoreClient{},
}

ag := addrGen{}
@@ -198,11 +212,11 @@ func newTestHandlerHashring(appendables []*fakeAppendable, replicationFactor uin
Limiter: limiter,
})
handlers = append(handlers, h)
h.peers = peers
addr := ag.newAddr()
h.peers = fakePeers
fakePeers.clients[addr] = &fakeRemoteWriteGRPCServer{h: h}
h.options.Endpoint = addr
cfg[0].Endpoints = append(cfg[0].Endpoints, Endpoint{Address: h.options.Endpoint})
peers.cache[addr] = &fakeRemoteWriteGRPCServer{h: h}
}
// Use hashmod as default.
if hashringAlgo == "" {
@@ -1573,3 +1587,35 @@ func TestGetStatsLimitParameter(t *testing.T) {
testutil.Equals(t, limit, givenLimit)
})
}

func TestSortedSliceDiff(t *testing.T) {
testutil.Equals(t, []string{"a"}, getSortedStringSliceDiff([]string{"a", "a", "foo"}, []string{"b", "b", "foo"}))
testutil.Equals(t, []string{}, getSortedStringSliceDiff([]string{}, []string{"b", "b", "foo"}))
testutil.Equals(t, []string{}, getSortedStringSliceDiff([]string{}, []string{}))
}

func TestHashringChangeCallsClose(t *testing.T) {
appendables := []*fakeAppendable{
{
appender: newFakeAppender(nil, nil, nil),
},
{
appender: newFakeAppender(nil, nil, nil),
},
{
appender: newFakeAppender(nil, nil, nil),
},
}
allHandlers, _, err := newTestHandlerHashring(appendables, 3, AlgorithmHashmod)
testutil.Ok(t, err)

appendables = appendables[1:]

_, smallHashring, err := newTestHandlerHashring(appendables, 2, AlgorithmHashmod)
testutil.Ok(t, err)

allHandlers[0].Hashring(smallHashring)

pg := allHandlers[0].peers.(*fakePeersGroup)
testutil.Assert(t, len(pg.closeCalled) > 0)
}
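
The test above swaps the real `peerGroup` for a `fakePeersGroup` behind the `peersContainer` interface and merely records which addresses get closed. The same pattern in isolation, with invented `closer`, `fakeCloser`, and `dropStale` names that are not part of the commit, including a compile-time assertion in the spirit of `var _ = (peersContainer)(&fakePeersGroup{})`:

package main

import "fmt"

// closer mirrors the shape of the close method on the peersContainer interface.
type closer interface {
	close(addr string) error
}

// fakeCloser records close calls instead of touching real connections.
type fakeCloser struct {
	closed map[string]bool
}

// Compile-time check that fakeCloser satisfies closer.
var _ closer = (*fakeCloser)(nil)

func (f *fakeCloser) close(addr string) error {
	if f.closed == nil {
		f.closed = map[string]bool{}
	}
	f.closed[addr] = true
	return nil
}

// dropStale is the "code under test": it closes every address that is not kept.
func dropStale(c closer, all, keep []string) {
	keepSet := make(map[string]struct{}, len(keep))
	for _, k := range keep {
		keepSet[k] = struct{}{}
	}
	for _, a := range all {
		if _, ok := keepSet[a]; !ok {
			_ = c.close(a)
		}
	}
}

func main() {
	f := &fakeCloser{}
	dropStale(f, []string{"a:10901", "b:10901", "c:10901"}, []string{"a:10901"})
	fmt.Println(f.closed) // map[b:10901:true c:10901:true]
}
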
31 changes: 31 additions & 0 deletions pkg/receive/hashring.go
@@ -55,6 +55,9 @@ type Hashring interface {
Get(tenant string, timeSeries *prompb.TimeSeries) (string, error)
// GetN returns the nth node that should handle the given tenant and time series.
GetN(tenant string, timeSeries *prompb.TimeSeries, n uint64) (string, error)
// Nodes returns a sorted slice of nodes that are in this hashring. Addresses could be duplicated
// if, for example, the same address is used for multiple tenants in the multi-hashring.
Nodes() []string
}

// SingleNodeHashring always returns the same node.
@@ -65,6 +68,10 @@ func (s SingleNodeHashring) Get(tenant string, ts *prompb.TimeSeries) (string, e
return s.GetN(tenant, ts, 0)
}

func (s SingleNodeHashring) Nodes() []string {
return []string{string(s)}
}

// GetN implements the Hashring interface.
func (s SingleNodeHashring) GetN(_ string, _ *prompb.TimeSeries, n uint64) (string, error) {
if n > 0 {
@@ -84,9 +91,15 @@ func newSimpleHashring(endpoints []Endpoint) (Hashring, error) {
}
addresses[i] = endpoints[i].Address
}
sort.Strings(addresses)

return simpleHashring(addresses), nil
}

func (s simpleHashring) Nodes() []string {
return s
}

// Get returns a target to handle the given tenant and time series.
func (s simpleHashring) Get(tenant string, ts *prompb.TimeSeries) (string, error) {
return s.GetN(tenant, ts, 0)
@@ -120,6 +133,7 @@ type ketamaHashring struct {
endpoints []Endpoint
sections sections
numEndpoints uint64
nodes []string
}

func newKetamaHashring(endpoints []Endpoint, sectionsPerNode int, replicationFactor uint64) (*ketamaHashring, error) {
@@ -132,8 +146,11 @@ func newKetamaHashring(endpoints []Endpoint, sectionsPerNode int, replicationFac
hash := xxhash.New()
availabilityZones := make(map[string]struct{})
ringSections := make(sections, 0, numSections)

nodes := []string{}
for endpointIndex, endpoint := range endpoints {
availabilityZones[endpoint.AZ] = struct{}{}
nodes = append(nodes, endpoint.Address)
for i := 1; i <= sectionsPerNode; i++ {
_, _ = hash.Write([]byte(endpoint.Address + ":" + strconv.Itoa(i)))
n := &section{
@@ -148,15 +165,21 @@ }
}
}
sort.Sort(ringSections)
sort.Strings(nodes)
calculateSectionReplicas(ringSections, replicationFactor, availabilityZones)

return &ketamaHashring{
endpoints: endpoints,
sections: ringSections,
numEndpoints: uint64(len(endpoints)),
nodes: nodes,
}, nil
}

func (k *ketamaHashring) Nodes() []string {
return k.nodes
}

func sizeOfLeastOccupiedAZ(azSpread map[string]int64) int64 {
minValue := int64(math.MaxInt64)
for _, value := range azSpread {
@@ -232,6 +255,8 @@ type multiHashring struct {
// to the cache map, as this is both written to
// and read from.
mu sync.RWMutex

nodes []string
}

// Get returns a target to handle the given tenant and time series.
@@ -269,6 +294,10 @@ func (m *multiHashring) GetN(tenant string, ts *prompb.TimeSeries, n uint64) (st
return "", errors.New("no matching hashring to handle tenant")
}

func (m *multiHashring) Nodes() []string {
return m.nodes
}

// newMultiHashring creates a multi-tenant hashring for a given slice of
// groups.
// Which hashring to use for a tenant is determined
@@ -289,6 +318,7 @@ func NewMultiHashring(algorithm HashringAlgorithm, replicationFactor uint64, cfg
if err != nil {
return nil, err
}
m.nodes = append(m.nodes, hashring.Nodes()...)
m.hashrings = append(m.hashrings, hashring)
var t map[string]struct{}
if len(h.Tenants) != 0 {
@@ -299,6 +329,7 @@ }
}
m.tenantSets = append(m.tenantSets, t)
}
sort.Strings(m.nodes)
return m, nil
}
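
The contract of the new `Nodes()` method above is that the returned slice is sorted, and that in the multi-hashring case the same address may legitimately appear more than once when it serves several tenant groups. A small standalone illustration of that contract; `staticRing` and `multiRing` are invented toy types, not part of the commit.

package main

import (
	"fmt"
	"sort"
)

// nodeLister mirrors just the Nodes() method added to the Hashring interface.
type nodeLister interface {
	Nodes() []string
}

// staticRing is a toy hashring that simply holds a pre-sorted address list.
type staticRing []string

func (s staticRing) Nodes() []string { return s }

// multiRing concatenates the node lists of its sub-rings and sorts the result,
// so duplicates survive when one address serves several tenant groups.
type multiRing struct{ rings []nodeLister }

func (m multiRing) Nodes() []string {
	var out []string
	for _, r := range m.rings {
		out = append(out, r.Nodes()...)
	}
	sort.Strings(out)
	return out
}

func main() {
	m := multiRing{rings: []nodeLister{
		staticRing{"a:10901", "b:10901"}, // tenant group 1
		staticRing{"a:10901"},            // tenant group 2 reuses node a
	}}
	fmt.Println(m.Nodes()) // [a:10901 a:10901 b:10901]
}
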
