Skip to content

Commit

Permalink
Add concurrency to the memberlist transport's WriteTo method (#525)
Browse files Browse the repository at this point in the history
* make WriteTo non-blocking

* Try to make this PR ready to go
- Create goroutines and keep them while the TCPTransport is alive. End them on the `Shutdown` function
- Add `TestTCPTransportWriteToUnreachableAddr` test to check that writing is not blocking anymore (without this PR, it takes `writeCt * timeout` to run and it fails)

* Add CHANGELOG

* Address PR comments
- Rename CHANGELOG
- Mutex lock on shutdown rather than write
- Wait when workers are ended rather than for each write

* Address PR comments
- Move variables around
- Add timeout before dropping requests. This prevents blocking on the `WriteTo` function

---------

Co-authored-by: Julien Duchesne <[email protected]>
  • Loading branch information
aldernero and julienduchesne authored Oct 10, 2024
1 parent f4d4811 commit 879ff5a
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@
* [ENHANCEMENT] Adapt `metrics.SendSumOfGaugesPerTenant` to use `metrics.MetricOption`. #584
* [ENHANCEMENT] Cache: Add `.Add()` and `.Set()` methods to cache clients. #591
* [ENHANCEMENT] Cache: Add `.Advance()` methods to mock cache clients for easier testing of TTLs. #601
* [ENHANCEMENT] Memberlist: Add concurrency to the transport's WriteTo method. #525
* [CHANGE] Backoff: added `Backoff.ErrCause()` which is like `Backoff.Err()` but returns the context cause if backoff is terminated because the context has been canceled. #538
* [BUGFIX] spanlogger: Support multiple tenant IDs. #59
* [BUGFIX] Memberlist: fixed corrupted packets when sending compound messages with more than 255 messages or messages bigger than 64KB. #85
Expand Down
99 changes: 81 additions & 18 deletions kv/memberlist/tcp_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"go.uber.org/atomic"

dstls "github.com/grafana/dskit/crypto/tls"
"github.com/grafana/dskit/flagext"
Expand Down Expand Up @@ -52,7 +51,13 @@ type TCPTransportConfig struct {
// Timeout for writing packet data. Zero = no timeout.
PacketWriteTimeout time.Duration `yaml:"packet_write_timeout" category:"advanced"`

// Transport logs lot of messages at debug level, so it deserves an extra flag for turning it on
// Maximum number of concurrent writes to other nodes.
MaxConcurrentWrites int `yaml:"max_concurrent_writes" category:"advanced"`

// Timeout for acquiring one of the concurrent write slots.
AcquireWriterTimeout time.Duration `yaml:"acquire_writer_timeout" category:"advanced"`

// Transport logs lots of messages at debug level, so it deserves an extra flag for turning it on
TransportDebug bool `yaml:"-" category:"advanced"`

// Where to put custom metrics. nil = don't register.
Expand All @@ -73,12 +78,19 @@ func (cfg *TCPTransportConfig) RegisterFlagsWithPrefix(f *flag.FlagSet, prefix s
f.IntVar(&cfg.BindPort, prefix+"memberlist.bind-port", 7946, "Port to listen on for gossip messages.")
f.DurationVar(&cfg.PacketDialTimeout, prefix+"memberlist.packet-dial-timeout", 2*time.Second, "Timeout used when connecting to other nodes to send packet.")
f.DurationVar(&cfg.PacketWriteTimeout, prefix+"memberlist.packet-write-timeout", 5*time.Second, "Timeout for writing 'packet' data.")
f.IntVar(&cfg.MaxConcurrentWrites, prefix+"memberlist.max-concurrent-writes", 3, "Maximum number of concurrent writes to other nodes.")
f.DurationVar(&cfg.AcquireWriterTimeout, prefix+"memberlist.acquire-writer-timeout", 250*time.Millisecond, "Timeout for acquiring one of the concurrent write slots. After this time, the message will be dropped.")
f.BoolVar(&cfg.TransportDebug, prefix+"memberlist.transport-debug", false, "Log debug transport messages. Note: global log.level must be at debug level as well.")

f.BoolVar(&cfg.TLSEnabled, prefix+"memberlist.tls-enabled", false, "Enable TLS on the memberlist transport layer.")
cfg.TLS.RegisterFlagsWithPrefix(prefix+"memberlist", f)
}

type writeRequest struct {
b []byte
addr string
}

// TCPTransport is a memberlist.Transport implementation that uses TCP for both packet and stream
// operations ("packet" and "stream" are terms used by memberlist).
// It uses a new TCP connections for each operation. There is no connection reuse.
Expand All @@ -91,7 +103,11 @@ type TCPTransport struct {
tcpListeners []net.Listener
tlsConfig *tls.Config

shutdown atomic.Int32
shutdownMu sync.RWMutex
shutdown bool
writeCh chan writeRequest // this channel is protected by shutdownMu

writeWG sync.WaitGroup

advertiseMu sync.RWMutex
advertiseAddr string
Expand Down Expand Up @@ -119,11 +135,21 @@ func NewTCPTransport(config TCPTransportConfig, logger log.Logger, registerer pr

// Build out the new transport.
var ok bool
concurrentWrites := config.MaxConcurrentWrites
if concurrentWrites <= 0 {
concurrentWrites = 1
}
t := TCPTransport{
cfg: config,
logger: log.With(logger, "component", "memberlist TCPTransport"),
packetCh: make(chan *memberlist.Packet),
connCh: make(chan net.Conn),
writeCh: make(chan writeRequest),
}

for i := 0; i < concurrentWrites; i++ {
t.writeWG.Add(1)
go t.writeWorker()
}

var err error
Expand Down Expand Up @@ -205,7 +231,10 @@ func (t *TCPTransport) tcpListen(tcpLn net.Listener) {
for {
conn, err := tcpLn.Accept()
if err != nil {
if s := t.shutdown.Load(); s == 1 {
t.shutdownMu.RLock()
isShuttingDown := t.shutdown
t.shutdownMu.RUnlock()
if isShuttingDown {
break
}

Expand Down Expand Up @@ -424,29 +453,49 @@ func (t *TCPTransport) getAdvertisedAddr() string {
// WriteTo is a packet-oriented interface that fires off the given
// payload to the given address.
func (t *TCPTransport) WriteTo(b []byte, addr string) (time.Time, error) {
t.sentPackets.Inc()
t.sentPacketsBytes.Add(float64(len(b)))
t.shutdownMu.RLock()
defer t.shutdownMu.RUnlock() // Unlock at the end to protect the chan
if t.shutdown {
return time.Time{}, errors.New("transport is shutting down")
}

err := t.writeTo(b, addr)
if err != nil {
// Send the packet to the write workers
// If this blocks for too long (as configured), abort and log an error.
select {
case <-time.After(t.cfg.AcquireWriterTimeout):
level.Warn(t.logger).Log("msg", "WriteTo failed to acquire a writer. Dropping message", "timeout", t.cfg.AcquireWriterTimeout, "addr", addr)
t.sentPacketsErrors.Inc()

logLevel := level.Warn(t.logger)
if strings.Contains(err.Error(), "connection refused") {
// The connection refused is a common error that could happen during normal operations when a node
// shutdown (or crash). It shouldn't be considered a warning condition on the sender side.
logLevel = t.debugLog()
}
logLevel.Log("msg", "WriteTo failed", "addr", addr, "err", err)

// WriteTo is used to send "UDP" packets. Since we use TCP, we can detect more errors,
// but memberlist library doesn't seem to cope with that very well. That is why we return nil instead.
return time.Now(), nil
case t.writeCh <- writeRequest{b: b, addr: addr}:
// OK
}

return time.Now(), nil
}

func (t *TCPTransport) writeWorker() {
defer t.writeWG.Done()
for req := range t.writeCh {
b, addr := req.b, req.addr
t.sentPackets.Inc()
t.sentPacketsBytes.Add(float64(len(b)))
err := t.writeTo(b, addr)
if err != nil {
t.sentPacketsErrors.Inc()

logLevel := level.Warn(t.logger)
if strings.Contains(err.Error(), "connection refused") {
// The connection refused is a common error that could happen during normal operations when a node
// shutdown (or crash). It shouldn't be considered a warning condition on the sender side.
logLevel = t.debugLog()
}
logLevel.Log("msg", "WriteTo failed", "addr", addr, "err", err)
}
}
}

func (t *TCPTransport) writeTo(b []byte, addr string) error {
// Open connection, write packet header and data, data hash, close. Simple.
c, err := t.getConnection(addr, t.cfg.PacketDialTimeout)
Expand Down Expand Up @@ -559,17 +608,31 @@ func (t *TCPTransport) StreamCh() <-chan net.Conn {

// Shutdown is called when memberlist is shutting down; this gives the
// transport a chance to clean up any listeners.
// This will avoid log spam about errors when we shut down.
func (t *TCPTransport) Shutdown() error {
t.shutdownMu.Lock()
// This will avoid log spam about errors when we shut down.
t.shutdown.Store(1)
if t.shutdown {
t.shutdownMu.Unlock()
return nil // already shut down
}

// Set the shutdown flag and close the write channel.
t.shutdown = true
close(t.writeCh)
t.shutdownMu.Unlock()

// Rip through all the connections and shut them down.
for _, conn := range t.tcpListeners {
_ = conn.Close()
}

// Wait until all write workers have finished.
t.writeWG.Wait()

// Block until all the listener threads have died.
t.wg.Wait()

return nil
}

Expand Down
89 changes: 89 additions & 0 deletions kv/memberlist/tcp_transport_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
package memberlist

import (
"net"
"strings"
"sync"
"testing"
"time"

"github.com/go-kit/log"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/grafana/dskit/concurrency"
"github.com/grafana/dskit/crypto/tls"
"github.com/grafana/dskit/flagext"
)

Expand Down Expand Up @@ -51,6 +56,8 @@ func TestTCPTransport_WriteTo_ShouldNotLogAsWarningExpectedFailures(t *testing.T
_, err = transport.WriteTo([]byte("test"), testData.remoteAddr)
require.NoError(t, err)

require.NoError(t, transport.Shutdown())

if testData.expectedLogs != "" {
assert.Contains(t, logs.String(), testData.expectedLogs)
}
Expand All @@ -61,6 +68,88 @@ func TestTCPTransport_WriteTo_ShouldNotLogAsWarningExpectedFailures(t *testing.T
}
}

type timeoutReader struct{}

func (f *timeoutReader) ReadSecret(_ string) ([]byte, error) {
time.Sleep(1 * time.Second)
return nil, nil
}

func TestTCPTransportWriteToUnreachableAddr(t *testing.T) {
writeCt := 50

// Listen for TCP connections on a random port
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()

logs := &concurrency.SyncBuffer{}
logger := log.NewLogfmtLogger(logs)

cfg := TCPTransportConfig{}
flagext.DefaultValues(&cfg)
cfg.MaxConcurrentWrites = writeCt
cfg.PacketDialTimeout = 500 * time.Millisecond
transport, err := NewTCPTransport(cfg, logger, nil)
require.NoError(t, err)

// Configure TLS only for writes. The dialing should timeout (because of the timeoutReader)
transport.cfg.TLSEnabled = true
transport.cfg.TLS = tls.ClientConfig{
Reader: &timeoutReader{},
CertPath: "fake",
KeyPath: "fake",
CAPath: "fake",
}

timeStart := time.Now()

for i := 0; i < writeCt; i++ {
_, err = transport.WriteTo([]byte("test"), listener.Addr().String())
require.NoError(t, err)
}

require.NoError(t, transport.Shutdown())

gotErrorCt := strings.Count(logs.String(), "context deadline exceeded")
assert.Equal(t, writeCt, gotErrorCt, "expected %d errors, got %d", writeCt, gotErrorCt)
assert.GreaterOrEqual(t, time.Since(timeStart), 500*time.Millisecond, "expected to take at least 500ms (timeout duration)")
assert.LessOrEqual(t, time.Since(timeStart), 2*time.Second, "expected to take less than 2s (timeout + a good margin), writing to unreachable addresses should not block")
}

func TestTCPTransportWriterAcquireTimeout(t *testing.T) {
// Listen for TCP connections on a random port
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()

logs := &concurrency.SyncBuffer{}
logger := log.NewLogfmtLogger(logs)

cfg := TCPTransportConfig{}
flagext.DefaultValues(&cfg)
cfg.MaxConcurrentWrites = 1
cfg.AcquireWriterTimeout = 1 * time.Millisecond // very short timeout
transport, err := NewTCPTransport(cfg, logger, nil)
require.NoError(t, err)

writeCt := 100
var reqWg sync.WaitGroup
for i := 0; i < writeCt; i++ {
reqWg.Add(1)
go func() {
defer reqWg.Done()
transport.WriteTo([]byte("test"), listener.Addr().String()) // nolint:errcheck
}()
}
reqWg.Wait()

require.NoError(t, transport.Shutdown())
gotErrorCt := strings.Count(logs.String(), "WriteTo failed to acquire a writer. Dropping message")
assert.Less(t, gotErrorCt, writeCt, "expected to have less errors (%d) than total writes (%d). Some writes should pass.", gotErrorCt, writeCt)
assert.NotZero(t, gotErrorCt, "expected errors, got none")
}

func TestFinalAdvertiseAddr(t *testing.T) {
tests := map[string]struct {
advertiseAddr string
Expand Down

0 comments on commit 879ff5a

Please sign in to comment.