diff --git a/pkg/receive/handler.go b/pkg/receive/handler.go
index e632c9788f9..4cdd4b47304 100644
--- a/pkg/receive/handler.go
+++ b/pkg/receive/handler.go
@@ -106,7 +106,7 @@ type Handler struct {
 
 	mtx          sync.RWMutex
 	hashring     Hashring
-	peers        *peerGroup
+	peers        peersContainer
 	expBackoff   backoff.Backoff
 	peerStates   map[string]*retryState
 	receiverMode ReceiverMode
@@ -253,11 +253,49 @@ func (h *Handler) Hashring(hashring Hashring) {
 	h.mtx.Lock()
 	defer h.mtx.Unlock()
 
+	if h.hashring != nil {
+		previousNodes := h.hashring.Nodes()
+		newNodes := hashring.Nodes()
+
+		disappearedNodes := getSortedStringSliceDiff(previousNodes, newNodes)
+		for _, node := range disappearedNodes {
+			if err := h.peers.close(node); err != nil {
+				level.Error(h.logger).Log("msg", "closing gRPC connection failed, we might have leaked a file descriptor", "addr", node, "err", err.Error())
+			}
+		}
+	}
+
 	h.hashring = hashring
 	h.expBackoff.Reset()
 	h.peerStates = make(map[string]*retryState)
 }
 
+// getSortedStringSliceDiff returns items which are in slice1 but not in slice2.
+// The returned slice only contains unique items, i.e. it is a set.
+func getSortedStringSliceDiff(slice1, slice2 []string) []string {
+	slice1Items := make(map[string]struct{}, len(slice1))
+	slice2Items := make(map[string]struct{}, len(slice2))
+
+	for _, s1 := range slice1 {
+		slice1Items[s1] = struct{}{}
+	}
+	for _, s2 := range slice2 {
+		slice2Items[s2] = struct{}{}
+	}
+
+	difference := make([]string, 0)
+	for s1 := range slice1Items {
+		_, s2Contains := slice2Items[s1]
+		if s2Contains {
+			continue
+		}
+		difference = append(difference, s1)
+	}
+	sort.Strings(difference)
+
+	return difference
+}
+
 // Verifies whether the server is ready or not.
 func (h *Handler) isReady() bool {
 	h.mtx.RLock()
@@ -1123,31 +1161,53 @@ func newReplicationErrors(threshold, numErrors int) []*replicationErrors {
 	return errs
 }
 
-func newPeerGroup(dialOpts ...grpc.DialOption) *peerGroup {
+func newPeerGroup(dialOpts ...grpc.DialOption) peersContainer {
 	return &peerGroup{
 		dialOpts: dialOpts,
-		cache:    map[string]storepb.WriteableStoreClient{},
+		cache:    map[string]*grpc.ClientConn{},
 		m:        sync.RWMutex{},
 		dialer:   grpc.DialContext,
 	}
 }
 
+type peersContainer interface {
+	close(string) error
+	get(context.Context, string) (storepb.WriteableStoreClient, error)
+}
+
 type peerGroup struct {
 	dialOpts []grpc.DialOption
-	cache    map[string]storepb.WriteableStoreClient
+	cache    map[string]*grpc.ClientConn
 	m        sync.RWMutex
 
 	// dialer is used for testing.
 	dialer func(ctx context.Context, target string, opts ...grpc.DialOption) (conn *grpc.ClientConn, err error)
 }
 
+func (p *peerGroup) close(addr string) error {
+	p.m.Lock()
+	defer p.m.Unlock()
+
+	c, ok := p.cache[addr]
+	if !ok {
+		return fmt.Errorf("address %s not found", addr)
+	}
+
+	if err := c.Close(); err != nil {
+		return fmt.Errorf("closing connection for %s: %w", addr, err)
+	}
+
+	delete(p.cache, addr)
+	return nil
+}
+
 func (p *peerGroup) get(ctx context.Context, addr string) (storepb.WriteableStoreClient, error) {
 	// use a RLock first to prevent blocking if we don't need to.
 	p.m.RLock()
 	c, ok := p.cache[addr]
 	p.m.RUnlock()
 	if ok {
-		return c, nil
+		return storepb.NewWriteableStoreClient(c), nil
 	}
 
 	p.m.Lock()
@@ -1155,14 +1215,13 @@ func (p *peerGroup) get(ctx context.Context, addr string) (storepb.WriteableStor
 	// Make sure that another caller hasn't created the connection since obtaining the write lock.
 	c, ok = p.cache[addr]
 	if ok {
-		return c, nil
+		return storepb.NewWriteableStoreClient(c), nil
 	}
 
 	conn, err := p.dialer(ctx, addr, p.dialOpts...)
 	if err != nil {
 		return nil, errors.Wrap(err, "failed to dial peer")
 	}
-	client := storepb.NewWriteableStoreClient(conn)
-	p.cache[addr] = client
-	return client, nil
+	p.cache[addr] = conn
+	return storepb.NewWriteableStoreClient(conn), nil
 }
diff --git a/pkg/receive/handler_test.go b/pkg/receive/handler_test.go
index b71e438edf8..de511dc8b69 100644
--- a/pkg/receive/handler_test.go
+++ b/pkg/receive/handler_test.go
@@ -166,24 +166,38 @@ func (f *fakeAppender) AppendCTZeroSample(ref storage.SeriesRef, l labels.Labels
 	panic("not implemented")
 }
 
+type fakePeersGroup struct {
+	clients map[string]storepb.WriteableStoreClient
+
+	closeCalled map[string]bool
+}
+
+func (g *fakePeersGroup) close(addr string) error {
+	if g.closeCalled == nil {
+		g.closeCalled = map[string]bool{}
+	}
+	g.closeCalled[addr] = true
+	return nil
+}
+
+func (g *fakePeersGroup) get(_ context.Context, addr string) (storepb.WriteableStoreClient, error) {
+	c, ok := g.clients[addr]
+	if !ok {
+		return nil, fmt.Errorf("client %s not found", addr)
+	}
+	return c, nil
+}
+
+var _ peersContainer = &fakePeersGroup{}
+
 func newTestHandlerHashring(appendables []*fakeAppendable, replicationFactor uint64, hashringAlgo HashringAlgorithm) ([]*Handler, Hashring, error) {
 	var (
 		cfg      = []HashringConfig{{Hashring: "test"}}
 		handlers []*Handler
 		wOpts    = &WriterOptions{}
 	)
-	// create a fake peer group where we manually fill the cache with fake addresses pointed to our handlers
-	// This removes the network from the tests and creates a more consistent testing harness.
-	peers := &peerGroup{
-		dialOpts: nil,
-		m:        sync.RWMutex{},
-		cache:    map[string]storepb.WriteableStoreClient{},
-		dialer: func(context.Context, string, ...grpc.DialOption) (*grpc.ClientConn, error) {
-			// dialer should never be called since we are creating fake clients with fake addresses
-			// this protects against some leaking test that may attempt to dial random IP addresses
-			// which may pose a security risk.
-			return nil, errors.New("unexpected dial called in testing")
-		},
+	fakePeers := &fakePeersGroup{
+		clients: map[string]storepb.WriteableStoreClient{},
 	}
 
 	ag := addrGen{}
@@ -198,11 +212,11 @@ func newTestHandlerHashring(appendables []*fakeAppendable, replicationFactor uin
 			Limiter:           limiter,
 		})
 		handlers = append(handlers, h)
-		h.peers = peers
 		addr := ag.newAddr()
+		h.peers = fakePeers
+		fakePeers.clients[addr] = &fakeRemoteWriteGRPCServer{h: h}
 		h.options.Endpoint = addr
 		cfg[0].Endpoints = append(cfg[0].Endpoints, Endpoint{Address: h.options.Endpoint})
-		peers.cache[addr] = &fakeRemoteWriteGRPCServer{h: h}
 	}
 	// Use hashmod as default.
if hashringAlgo == "" { @@ -1573,3 +1587,35 @@ func TestGetStatsLimitParameter(t *testing.T) { testutil.Equals(t, limit, givenLimit) }) } + +func TestSortedSliceDiff(t *testing.T) { + testutil.Equals(t, []string{"a"}, getSortedStringSliceDiff([]string{"a", "a", "foo"}, []string{"b", "b", "foo"})) + testutil.Equals(t, []string{}, getSortedStringSliceDiff([]string{}, []string{"b", "b", "foo"})) + testutil.Equals(t, []string{}, getSortedStringSliceDiff([]string{}, []string{})) +} + +func TestHashringChangeCallsClose(t *testing.T) { + appendables := []*fakeAppendable{ + { + appender: newFakeAppender(nil, nil, nil), + }, + { + appender: newFakeAppender(nil, nil, nil), + }, + { + appender: newFakeAppender(nil, nil, nil), + }, + } + allHandlers, _, err := newTestHandlerHashring(appendables, 3, AlgorithmHashmod) + testutil.Ok(t, err) + + appendables = appendables[1:] + + _, smallHashring, err := newTestHandlerHashring(appendables, 2, AlgorithmHashmod) + testutil.Ok(t, err) + + allHandlers[0].Hashring(smallHashring) + + pg := allHandlers[0].peers.(*fakePeersGroup) + testutil.Assert(t, len(pg.closeCalled) > 0) +} diff --git a/pkg/receive/hashring.go b/pkg/receive/hashring.go index 18925cc4cc2..0d7c2dc10c5 100644 --- a/pkg/receive/hashring.go +++ b/pkg/receive/hashring.go @@ -55,6 +55,9 @@ type Hashring interface { Get(tenant string, timeSeries *prompb.TimeSeries) (string, error) // GetN returns the nth node that should handle the given tenant and time series. GetN(tenant string, timeSeries *prompb.TimeSeries, n uint64) (string, error) + // Nodes returns a sorted slice of nodes that are in this hashring. Addresses could be duplicated + // if, for example, the same address is used for multiple tenants in the multi-hashring. + Nodes() []string } // SingleNodeHashring always returns the same node. @@ -65,6 +68,10 @@ func (s SingleNodeHashring) Get(tenant string, ts *prompb.TimeSeries) (string, e return s.GetN(tenant, ts, 0) } +func (s SingleNodeHashring) Nodes() []string { + return []string{string(s)} +} + // GetN implements the Hashring interface. func (s SingleNodeHashring) GetN(_ string, _ *prompb.TimeSeries, n uint64) (string, error) { if n > 0 { @@ -84,9 +91,15 @@ func newSimpleHashring(endpoints []Endpoint) (Hashring, error) { } addresses[i] = endpoints[i].Address } + sort.Strings(addresses) + return simpleHashring(addresses), nil } +func (s simpleHashring) Nodes() []string { + return s +} + // Get returns a target to handle the given tenant and time series. 
 func (s simpleHashring) Get(tenant string, ts *prompb.TimeSeries) (string, error) {
 	return s.GetN(tenant, ts, 0)
@@ -120,6 +133,7 @@ type ketamaHashring struct {
 	endpoints    []Endpoint
 	sections     sections
 	numEndpoints uint64
+	nodes        []string
 }
 
 func newKetamaHashring(endpoints []Endpoint, sectionsPerNode int, replicationFactor uint64) (*ketamaHashring, error) {
@@ -132,8 +146,11 @@ func newKetamaHashring(endpoints []Endpoint, sectionsPerNode int, replicationFac
 	hash := xxhash.New()
 	availabilityZones := make(map[string]struct{})
 	ringSections := make(sections, 0, numSections)
+
+	nodes := []string{}
 	for endpointIndex, endpoint := range endpoints {
 		availabilityZones[endpoint.AZ] = struct{}{}
+		nodes = append(nodes, endpoint.Address)
 		for i := 1; i <= sectionsPerNode; i++ {
 			_, _ = hash.Write([]byte(endpoint.Address + ":" + strconv.Itoa(i)))
 			n := &section{
@@ -148,15 +165,21 @@ func newKetamaHashring(endpoints []Endpoint, sectionsPerNode int, replicationFac
 		}
 	}
 	sort.Sort(ringSections)
+	sort.Strings(nodes)
 	calculateSectionReplicas(ringSections, replicationFactor, availabilityZones)
 
 	return &ketamaHashring{
 		endpoints:    endpoints,
 		sections:     ringSections,
 		numEndpoints: uint64(len(endpoints)),
+		nodes:        nodes,
 	}, nil
 }
 
+func (k *ketamaHashring) Nodes() []string {
+	return k.nodes
+}
+
 func sizeOfLeastOccupiedAZ(azSpread map[string]int64) int64 {
 	minValue := int64(math.MaxInt64)
 	for _, value := range azSpread {
@@ -232,6 +255,8 @@ type multiHashring struct {
 	// to the cache map, as this is both written to
 	// and read from.
 	mu sync.RWMutex
+
+	nodes []string
 }
 
 // Get returns a target to handle the given tenant and time series.
@@ -269,6 +294,10 @@ func (m *multiHashring) GetN(tenant string, ts *prompb.TimeSeries, n uint64) (st
 	return "", errors.New("no matching hashring to handle tenant")
 }
 
+func (m *multiHashring) Nodes() []string {
+	return m.nodes
+}
+
 // newMultiHashring creates a multi-tenant hashring for a given slice of
 // groups.
 // Which hashring to use for a tenant is determined
@@ -289,6 +318,7 @@ func NewMultiHashring(algorithm HashringAlgorithm, replicationFactor uint64, cfg
 		if err != nil {
 			return nil, err
 		}
+		m.nodes = append(m.nodes, hashring.Nodes()...)
 		m.hashrings = append(m.hashrings, hashring)
 		var t map[string]struct{}
 		if len(h.Tenants) != 0 {
@@ -299,6 +329,7 @@ func NewMultiHashring(algorithm HashringAlgorithm, replicationFactor uint64, cfg
 		}
 		m.tenantSets = append(m.tenantSets, t)
 	}
+	sort.Strings(m.nodes)
 	return m, nil
 }
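
For context, here is a minimal, self-contained Go sketch (not part of the patch) of the reload flow the handler.go hunk implements: compute the sorted set difference between the old and new hashring nodes, then close the cached connection of every peer that disappeared. The names diff and connCache, and the node addresses, are illustrative stand-ins; the real code uses getSortedStringSliceDiff, peerGroup, and *grpc.ClientConn.

package main

import (
	"fmt"
	"sort"
)

// diff mirrors getSortedStringSliceDiff above: it returns the sorted,
// de-duplicated set of items present in prev but not in curr.
func diff(prev, curr []string) []string {
	currSet := make(map[string]struct{}, len(curr))
	for _, s := range curr {
		currSet[s] = struct{}{}
	}
	seen := make(map[string]struct{}, len(prev))
	out := make([]string, 0)
	for _, s := range prev {
		if _, ok := currSet[s]; ok {
			continue
		}
		if _, ok := seen[s]; ok {
			continue
		}
		seen[s] = struct{}{}
		out = append(out, s)
	}
	sort.Strings(out)
	return out
}

// connCache is an illustrative stand-in for peerGroup: it maps peer
// addresses to open "connections" (booleans here, *grpc.ClientConn there).
type connCache struct {
	cache map[string]bool
}

func (c *connCache) close(addr string) error {
	if _, ok := c.cache[addr]; !ok {
		return fmt.Errorf("address %s not found", addr)
	}
	delete(c.cache, addr)
	return nil
}

func main() {
	conns := &connCache{cache: map[string]bool{
		"node-a:10901": true,
		"node-b:10901": true,
		"node-c:10901": true,
	}}

	oldNodes := []string{"node-a:10901", "node-b:10901", "node-c:10901"}
	newNodes := []string{"node-b:10901", "node-c:10901"} // node-a left the hashring

	// Same shape as Handler.Hashring: close connections only for peers that
	// disappeared from the ring, keeping the remaining connections warm.
	for _, node := range diff(oldNodes, newNodes) {
		if err := conns.close(node); err != nil {
			fmt.Println("close failed:", err)
			continue
		}
		fmt.Println("closed connection to", node) // closed connection to node-a:10901
	}
}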