Skip to content
This repository has been archived by the owner on Dec 28, 2024. It is now read-only.

Commit

Permalink
fix: watch nats epoch changes
Browse files Browse the repository at this point in the history
When restarting a cluster, there can be a race condition in epoch calculation and writing, so we must keep it in sync
  • Loading branch information
palkan committed Nov 2, 2023
1 parent ed7ea37 commit 967a28a
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 20 deletions.
1 change: 1 addition & 0 deletions broker/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ type LocalBroker interface {
Start() error
Shutdown(ctx context.Context) error
SetEpoch(epoch string)
GetEpoch() string
HistoryFrom(stream string, epoch string, offset uint64) ([]common.StreamMessage, error)
HistorySince(stream string, ts int64) ([]common.StreamMessage, error)
Store(stream string, msg []byte, seq uint64, ts time.Time) (uint64, error)
Expand Down
22 changes: 16 additions & 6 deletions broker/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ type Memory struct {

streamsMu sync.RWMutex
sessionsMu sync.RWMutex
epochMu sync.RWMutex
}

var _ Broker = (*Memory)(nil)
Expand All @@ -201,18 +202,24 @@ func NewMemoryBroker(node Broadcaster, config *Config) *Memory {
func (b *Memory) Announce() string {
return fmt.Sprintf(
"Using in-memory broker (epoch: %s, history limit: %d, history ttl: %ds, sessions ttl: %ds)",
b.epoch,
b.GetEpoch(),
b.config.HistoryLimit,
b.config.HistoryTTL,
b.config.SessionsTTL,
)
}

func (b *Memory) GetEpoch() string {
b.epochMu.RLock()
defer b.epochMu.RUnlock()

return b.epoch
}

func (b *Memory) SetEpoch(v string) {
b.epochMu.Lock()
defer b.epochMu.Unlock()

b.epoch = v
}

Expand All @@ -229,7 +236,7 @@ func (b *Memory) Shutdown(ctx context.Context) error {
func (b *Memory) HandleBroadcast(msg *common.StreamMessage) {
offset := b.add(msg.Stream, msg.Data)

msg.Epoch = b.epoch
msg.Epoch = b.GetEpoch()
msg.Offset = offset

if b.tracker.Has(msg.Stream) {
Expand Down Expand Up @@ -264,8 +271,10 @@ func (b *Memory) Unsubscribe(stream string) string {
}

func (b *Memory) HistoryFrom(name string, epoch string, offset uint64) ([]common.StreamMessage, error) {
if b.epoch != epoch {
return nil, fmt.Errorf("Unknown epoch: %s, current: %s", epoch, b.epoch)
bepoch := b.GetEpoch()

if bepoch != epoch {
return nil, fmt.Errorf("Unknown epoch: %s, current: %s", epoch, bepoch)
}

stream := b.get(name)
Expand All @@ -281,7 +290,7 @@ func (b *Memory) HistoryFrom(name string, epoch string, offset uint64) ([]common
Stream: name,
Data: entry.data,
Offset: entry.offset,
Epoch: b.epoch,
Epoch: bepoch,
})
})

Expand All @@ -299,14 +308,15 @@ func (b *Memory) HistorySince(name string, ts int64) ([]common.StreamMessage, er
return nil, nil
}

bepoch := b.GetEpoch()
history := []common.StreamMessage{}

err := stream.filterByTime(ts, func(entry *entry) {
history = append(history, common.StreamMessage{
Stream: name,
Data: entry.data,
Offset: entry.offset,
Epoch: b.epoch,
Epoch: bepoch,
})
})

Expand Down
79 changes: 65 additions & 14 deletions broker/nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ type NATS struct {
nconf *natsconfig.NATSConfig
conn *nats.Conn

js jetstream.JetStream
kv jetstream.KeyValue
js jetstream.JetStream
kv jetstream.KeyValue
epochKV jetstream.KeyValue

jstreams *lru[string]
jconsumers *lru[jetstream.Consumer]
Expand All @@ -36,9 +37,13 @@ type NATS struct {
local LocalBroker

clientMu sync.RWMutex
epochMu sync.RWMutex

epoch string

shutdownCtx context.Context
shutdownFn func()

log *log.Entry
}

Expand All @@ -61,10 +66,14 @@ func WithNATSLocalBroker(b LocalBroker) NATSOption {
}

func NewNATSBroker(broadcaster Broadcaster, c *Config, nc *natsconfig.NATSConfig, opts ...NATSOption) *NATS {
shutdownCtx, shutdownFn := context.WithCancel(context.Background())

n := NATS{
broadcaster: broadcaster,
conf: c,
nconf: nc,
shutdownCtx: shutdownCtx,
shutdownFn: shutdownFn,
tracker: NewStreamsTracker(),
streamSync: newStreamsSynchronizer(),
jstreams: newLRU[string](time.Duration(c.HistoryTTL * int64(time.Second))),
Expand Down Expand Up @@ -134,17 +143,20 @@ func (n *NATS) Start() error {
return errorx.Decorate(err, "Failed to calculate epoch")
}

n.epoch = epoch

n.local.SetEpoch(epoch)

n.writeEpoch(epoch)
err = n.local.Start()

if err != nil {
return errorx.Decorate(err, "Failed to start internal memory broker")
}

n.log.Debugf("Current epoch: %s", n.epoch)
err = n.watchEpoch(n.shutdownCtx)

if err != nil {
n.log.Warnf("failed to set up epoch watcher: %s", err)
}

n.log.Debugf("Current epoch: %s", epoch)

return nil
}
Expand All @@ -153,6 +165,8 @@ func (n *NATS) Shutdown(ctx context.Context) error {
n.clientMu.Lock()
defer n.clientMu.Unlock()

n.shutdownFn()

if n.conn != nil {
n.conn.Close()
n.conn = nil
Expand All @@ -172,8 +186,8 @@ func (n *NATS) Announce() string {
}

func (n *NATS) Epoch() string {
n.clientMu.RLock()
defer n.clientMu.RUnlock()
n.epochMu.RLock()
defer n.epochMu.RUnlock()

return n.epoch
}
Expand All @@ -193,13 +207,19 @@ func (n *NATS) SetEpoch(epoch string) error {
return err
}

n.epoch = epoch
n.writeEpoch(epoch)

return nil
}

func (n *NATS) writeEpoch(val string) {
n.epochMu.Lock()
defer n.epochMu.Unlock()

n.epoch = val
if n.local != nil {
n.local.SetEpoch(epoch)
n.local.SetEpoch(val)
}

return nil
}

func (n *NATS) HandleBroadcast(msg *common.StreamMessage) {
Expand All @@ -210,7 +230,7 @@ func (n *NATS) HandleBroadcast(msg *common.StreamMessage) {
return
}

msg.Epoch = n.epoch
msg.Epoch = n.Epoch()
msg.Offset = offset

if n.tracker.Has(msg.Stream) {
Expand Down Expand Up @@ -487,6 +507,8 @@ fetchEpoch:
return "", errorx.Decorate(err, "failed to connect to JetStream KV")
}

n.epochKV = kv

_, err = kv.Create(context.Background(), epochKey, []byte(maybeNewEpoch))

if err != nil && strings.Contains(err.Error(), "key exists") {
Expand All @@ -508,6 +530,35 @@ fetchEpoch:
return maybeNewEpoch, nil
}

func (n *NATS) watchEpoch(ctx context.Context) error {
watcher, err := n.epochKV.Watch(context.Background(), epochKey, jetstream.IgnoreDeletes())

if err != nil {
return err
}

go func() {
for {
select {
case <-ctx.Done():
watcher.Stop() // nolint:errcheck
return
case entry := <-watcher.Updates():
if entry != nil {
newEpoch := string(entry.Value())

if n.Epoch() != newEpoch {
n.log.Warnf("epoch updated: %s", newEpoch)
n.writeEpoch(newEpoch)
}
}
}
}
}()

return nil
}

func (n *NATS) fetchBucketWithTTL(key string, ttl time.Duration) (jetstream.KeyValue, error) {
var bucket jetstream.KeyValue
newBucket := true
Expand Down
23 changes: 23 additions & 0 deletions broker/nats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,29 @@ func TestNATSBroker_Epoch(t *testing.T) {
defer anotherBroker.Shutdown(context.Background()) // nolint: errcheck

assert.Equal(t, epoch, anotherBroker.Epoch())

// Now let's test that epoch changes are picked up
require.NoError(t, anotherBroker.SetEpoch("new-epoch"))

assert.Equal(t, "new-epoch", anotherBroker.Epoch())
assert.Equal(t, "new-epoch", anotherBroker.local.GetEpoch())

timer := time.After(2 * time.Second)

wait:
for {
select {
case <-timer:
assert.Fail(t, "Epoch change wasn't picked up")
return
default:
if broker.Epoch() == "new-epoch" {
break wait
}

time.Sleep(100 * time.Millisecond)
}
}
}

func buildNATSServer(t *testing.T, addr string) *enats.Service {
Expand Down

0 comments on commit 967a28a

Please sign in to comment.