diff --git a/CHANGELOG.md b/CHANGELOG.md index 3288acdeb..7857018e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -233,6 +233,7 @@ * [ENHANCEMENT] Adapt `metrics.SendSumOfGaugesPerTenant` to use `metrics.MetricOption`. #584 * [ENHANCEMENT] Cache: Add `.Add()` and `.Set()` methods to cache clients. #591 * [ENHANCEMENT] Cache: Add `.Advance()` methods to mock cache clients for easier testing of TTLs. #601 +* [ENHANCEMENT] Memberlist: Add concurrency to the transport's WriteTo method. #525 * [CHANGE] Backoff: added `Backoff.ErrCause()` which is like `Backoff.Err()` but returns the context cause if backoff is terminated because the context has been canceled. #538 * [BUGFIX] spanlogger: Support multiple tenant IDs. #59 * [BUGFIX] Memberlist: fixed corrupted packets when sending compound messages with more than 255 messages or messages bigger than 64KB. #85 diff --git a/kv/memberlist/tcp_transport.go b/kv/memberlist/tcp_transport.go index 751ad1163..2010d3919 100644 --- a/kv/memberlist/tcp_transport.go +++ b/kv/memberlist/tcp_transport.go @@ -19,7 +19,6 @@ import ( "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - "go.uber.org/atomic" dstls "github.com/grafana/dskit/crypto/tls" "github.com/grafana/dskit/flagext" @@ -52,7 +51,13 @@ type TCPTransportConfig struct { // Timeout for writing packet data. Zero = no timeout. PacketWriteTimeout time.Duration `yaml:"packet_write_timeout" category:"advanced"` - // Transport logs lot of messages at debug level, so it deserves an extra flag for turning it on + // Maximum number of concurrent writes to other nodes. + MaxConcurrentWrites int `yaml:"max_concurrent_writes" category:"advanced"` + + // Timeout for acquiring one of the concurrent write slots. + AcquireWriterTimeout time.Duration `yaml:"acquire_writer_timeout" category:"advanced"` + + // Transport logs lots of messages at debug level, so it deserves an extra flag for turning it on TransportDebug bool `yaml:"-" category:"advanced"` // Where to put custom metrics. nil = don't register. @@ -73,12 +78,19 @@ func (cfg *TCPTransportConfig) RegisterFlagsWithPrefix(f *flag.FlagSet, prefix s f.IntVar(&cfg.BindPort, prefix+"memberlist.bind-port", 7946, "Port to listen on for gossip messages.") f.DurationVar(&cfg.PacketDialTimeout, prefix+"memberlist.packet-dial-timeout", 2*time.Second, "Timeout used when connecting to other nodes to send packet.") f.DurationVar(&cfg.PacketWriteTimeout, prefix+"memberlist.packet-write-timeout", 5*time.Second, "Timeout for writing 'packet' data.") + f.IntVar(&cfg.MaxConcurrentWrites, prefix+"memberlist.max-concurrent-writes", 3, "Maximum number of concurrent writes to other nodes.") + f.DurationVar(&cfg.AcquireWriterTimeout, prefix+"memberlist.acquire-writer-timeout", 250*time.Millisecond, "Timeout for acquiring one of the concurrent write slots. After this time, the message will be dropped.") f.BoolVar(&cfg.TransportDebug, prefix+"memberlist.transport-debug", false, "Log debug transport messages. Note: global log.level must be at debug level as well.") f.BoolVar(&cfg.TLSEnabled, prefix+"memberlist.tls-enabled", false, "Enable TLS on the memberlist transport layer.") cfg.TLS.RegisterFlagsWithPrefix(prefix+"memberlist", f) } +type writeRequest struct { + b []byte + addr string +} + // TCPTransport is a memberlist.Transport implementation that uses TCP for both packet and stream // operations ("packet" and "stream" are terms used by memberlist). // It uses a new TCP connections for each operation. There is no connection reuse. @@ -91,7 +103,11 @@ type TCPTransport struct { tcpListeners []net.Listener tlsConfig *tls.Config - shutdown atomic.Int32 + shutdownMu sync.RWMutex + shutdown bool + writeCh chan writeRequest // this channel is protected by shutdownMu + + writeWG sync.WaitGroup advertiseMu sync.RWMutex advertiseAddr string @@ -119,11 +135,21 @@ func NewTCPTransport(config TCPTransportConfig, logger log.Logger, registerer pr // Build out the new transport. var ok bool + concurrentWrites := config.MaxConcurrentWrites + if concurrentWrites <= 0 { + concurrentWrites = 1 + } t := TCPTransport{ cfg: config, logger: log.With(logger, "component", "memberlist TCPTransport"), packetCh: make(chan *memberlist.Packet), connCh: make(chan net.Conn), + writeCh: make(chan writeRequest), + } + + for i := 0; i < concurrentWrites; i++ { + t.writeWG.Add(1) + go t.writeWorker() } var err error @@ -205,7 +231,10 @@ func (t *TCPTransport) tcpListen(tcpLn net.Listener) { for { conn, err := tcpLn.Accept() if err != nil { - if s := t.shutdown.Load(); s == 1 { + t.shutdownMu.RLock() + isShuttingDown := t.shutdown + t.shutdownMu.RUnlock() + if isShuttingDown { break } @@ -424,29 +453,49 @@ func (t *TCPTransport) getAdvertisedAddr() string { // WriteTo is a packet-oriented interface that fires off the given // payload to the given address. func (t *TCPTransport) WriteTo(b []byte, addr string) (time.Time, error) { - t.sentPackets.Inc() - t.sentPacketsBytes.Add(float64(len(b))) + t.shutdownMu.RLock() + defer t.shutdownMu.RUnlock() // Unlock at the end to protect the chan + if t.shutdown { + return time.Time{}, errors.New("transport is shutting down") + } - err := t.writeTo(b, addr) - if err != nil { + // Send the packet to the write workers + // If this blocks for too long (as configured), abort and log an error. + select { + case <-time.After(t.cfg.AcquireWriterTimeout): + level.Warn(t.logger).Log("msg", "WriteTo failed to acquire a writer. Dropping message", "timeout", t.cfg.AcquireWriterTimeout, "addr", addr) t.sentPacketsErrors.Inc() - - logLevel := level.Warn(t.logger) - if strings.Contains(err.Error(), "connection refused") { - // The connection refused is a common error that could happen during normal operations when a node - // shutdown (or crash). It shouldn't be considered a warning condition on the sender side. - logLevel = t.debugLog() - } - logLevel.Log("msg", "WriteTo failed", "addr", addr, "err", err) - // WriteTo is used to send "UDP" packets. Since we use TCP, we can detect more errors, // but memberlist library doesn't seem to cope with that very well. That is why we return nil instead. return time.Now(), nil + case t.writeCh <- writeRequest{b: b, addr: addr}: + // OK } return time.Now(), nil } +func (t *TCPTransport) writeWorker() { + defer t.writeWG.Done() + for req := range t.writeCh { + b, addr := req.b, req.addr + t.sentPackets.Inc() + t.sentPacketsBytes.Add(float64(len(b))) + err := t.writeTo(b, addr) + if err != nil { + t.sentPacketsErrors.Inc() + + logLevel := level.Warn(t.logger) + if strings.Contains(err.Error(), "connection refused") { + // The connection refused is a common error that could happen during normal operations when a node + // shutdown (or crash). It shouldn't be considered a warning condition on the sender side. + logLevel = t.debugLog() + } + logLevel.Log("msg", "WriteTo failed", "addr", addr, "err", err) + } + } +} + func (t *TCPTransport) writeTo(b []byte, addr string) error { // Open connection, write packet header and data, data hash, close. Simple. c, err := t.getConnection(addr, t.cfg.PacketDialTimeout) @@ -559,17 +608,31 @@ func (t *TCPTransport) StreamCh() <-chan net.Conn { // Shutdown is called when memberlist is shutting down; this gives the // transport a chance to clean up any listeners. +// This will avoid log spam about errors when we shut down. func (t *TCPTransport) Shutdown() error { + t.shutdownMu.Lock() // This will avoid log spam about errors when we shut down. - t.shutdown.Store(1) + if t.shutdown { + t.shutdownMu.Unlock() + return nil // already shut down + } + + // Set the shutdown flag and close the write channel. + t.shutdown = true + close(t.writeCh) + t.shutdownMu.Unlock() // Rip through all the connections and shut them down. for _, conn := range t.tcpListeners { _ = conn.Close() } + // Wait until all write workers have finished. + t.writeWG.Wait() + // Block until all the listener threads have died. t.wg.Wait() + return nil } diff --git a/kv/memberlist/tcp_transport_test.go b/kv/memberlist/tcp_transport_test.go index 310e11ecb..1803c8280 100644 --- a/kv/memberlist/tcp_transport_test.go +++ b/kv/memberlist/tcp_transport_test.go @@ -1,7 +1,11 @@ package memberlist import ( + "net" + "strings" + "sync" "testing" + "time" "github.com/go-kit/log" "github.com/prometheus/client_golang/prometheus" @@ -9,6 +13,7 @@ import ( "github.com/stretchr/testify/require" "github.com/grafana/dskit/concurrency" + "github.com/grafana/dskit/crypto/tls" "github.com/grafana/dskit/flagext" ) @@ -51,6 +56,8 @@ func TestTCPTransport_WriteTo_ShouldNotLogAsWarningExpectedFailures(t *testing.T _, err = transport.WriteTo([]byte("test"), testData.remoteAddr) require.NoError(t, err) + require.NoError(t, transport.Shutdown()) + if testData.expectedLogs != "" { assert.Contains(t, logs.String(), testData.expectedLogs) } @@ -61,6 +68,88 @@ func TestTCPTransport_WriteTo_ShouldNotLogAsWarningExpectedFailures(t *testing.T } } +type timeoutReader struct{} + +func (f *timeoutReader) ReadSecret(_ string) ([]byte, error) { + time.Sleep(1 * time.Second) + return nil, nil +} + +func TestTCPTransportWriteToUnreachableAddr(t *testing.T) { + writeCt := 50 + + // Listen for TCP connections on a random port + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + logs := &concurrency.SyncBuffer{} + logger := log.NewLogfmtLogger(logs) + + cfg := TCPTransportConfig{} + flagext.DefaultValues(&cfg) + cfg.MaxConcurrentWrites = writeCt + cfg.PacketDialTimeout = 500 * time.Millisecond + transport, err := NewTCPTransport(cfg, logger, nil) + require.NoError(t, err) + + // Configure TLS only for writes. The dialing should timeout (because of the timeoutReader) + transport.cfg.TLSEnabled = true + transport.cfg.TLS = tls.ClientConfig{ + Reader: &timeoutReader{}, + CertPath: "fake", + KeyPath: "fake", + CAPath: "fake", + } + + timeStart := time.Now() + + for i := 0; i < writeCt; i++ { + _, err = transport.WriteTo([]byte("test"), listener.Addr().String()) + require.NoError(t, err) + } + + require.NoError(t, transport.Shutdown()) + + gotErrorCt := strings.Count(logs.String(), "context deadline exceeded") + assert.Equal(t, writeCt, gotErrorCt, "expected %d errors, got %d", writeCt, gotErrorCt) + assert.GreaterOrEqual(t, time.Since(timeStart), 500*time.Millisecond, "expected to take at least 500ms (timeout duration)") + assert.LessOrEqual(t, time.Since(timeStart), 2*time.Second, "expected to take less than 2s (timeout + a good margin), writing to unreachable addresses should not block") +} + +func TestTCPTransportWriterAcquireTimeout(t *testing.T) { + // Listen for TCP connections on a random port + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + logs := &concurrency.SyncBuffer{} + logger := log.NewLogfmtLogger(logs) + + cfg := TCPTransportConfig{} + flagext.DefaultValues(&cfg) + cfg.MaxConcurrentWrites = 1 + cfg.AcquireWriterTimeout = 1 * time.Millisecond // very short timeout + transport, err := NewTCPTransport(cfg, logger, nil) + require.NoError(t, err) + + writeCt := 100 + var reqWg sync.WaitGroup + for i := 0; i < writeCt; i++ { + reqWg.Add(1) + go func() { + defer reqWg.Done() + transport.WriteTo([]byte("test"), listener.Addr().String()) // nolint:errcheck + }() + } + reqWg.Wait() + + require.NoError(t, transport.Shutdown()) + gotErrorCt := strings.Count(logs.String(), "WriteTo failed to acquire a writer. Dropping message") + assert.Less(t, gotErrorCt, writeCt, "expected to have less errors (%d) than total writes (%d). Some writes should pass.", gotErrorCt, writeCt) + assert.NotZero(t, gotErrorCt, "expected errors, got none") +} + func TestFinalAdvertiseAddr(t *testing.T) { tests := map[string]struct { advertiseAddr string