diff --git a/pkg/ticker/ticker.go b/pkg/ticker/ticker.go new file mode 100644 index 0000000000..566fc03f8b --- /dev/null +++ b/pkg/ticker/ticker.go @@ -0,0 +1,140 @@ +// Package ticker provides a dynamic ticker that can change its interval at runtime. +// The ticker can be stopped gracefully and handles context-based termination. +// +// This package is useful for scenarios where periodic execution of a function is needed +// and the interval might need to change dynamically based on runtime conditions. +// +// It also invokes a first tick immediately after the ticker starts. It's safe to use it concurrently. +// +// It also terminates gracefully when the context is done (return ctx.Err()) or when the stop signal is received. +// +// Example usage: +// +// ticker := New(time.Second, func(ctx context.Context, t *Ticker) error { +// resp, err := client.GetPrice(ctx) +// if err != nil { +// logger.Err(err).Error().Msg("failed to get price") +// return nil +// } +// +// observer.SetPrice(resp.GasPrice) +// t.SetInterval(resp.GasPriceInterval) +// +// return nil +// }) +// +// err := ticker.Run(ctx) +package ticker + +import ( + "context" + "fmt" + "sync" + "time" + + "cosmossdk.io/errors" +) + +// Ticker represents a ticker that will run a function periodically. +// It also invokes BEFORE ticker starts. +type Ticker struct { + interval time.Duration + ticker *time.Ticker + task Task + signalChan chan struct{} + + // runnerMu is a mutex to prevent double run + runnerMu sync.Mutex + + // stateMu is a mutex to prevent concurrent SetInterval calls + stateMu sync.Mutex + + stopped bool +} + +// Task is a function that will be called by the Ticker +type Task func(ctx context.Context, t *Ticker) error + +// New creates a new Ticker. +func New(interval time.Duration, runner Task) *Ticker { + return &Ticker{interval: interval, task: runner} +} + +// Run creates and runs a new Ticker. +func Run(ctx context.Context, interval time.Duration, task Task) error { + return New(interval, task).Run(ctx) +} + +// SecondsFromUint64 converts uint64 to time.Duration in seconds. +func SecondsFromUint64(d uint64) time.Duration { + return time.Duration(d) * time.Second +} + +// Run runs the ticker by blocking current goroutine. It also invokes BEFORE ticker starts. +// Stops when (if any): +// - context is done (returns ctx.Err()) +// - task returns an error or panics +// - shutdown signal is received +func (t *Ticker) Run(ctx context.Context) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic during ticker run: %v", r) + } + }() + + // prevent concurrent runs + t.runnerMu.Lock() + defer t.runnerMu.Unlock() + + // setup + t.ticker = time.NewTicker(t.interval) + t.signalChan = make(chan struct{}) + t.stopped = false + + // initial run + if err := t.task(ctx, t); err != nil { + return errors.Wrap(err, "ticker task failed") + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-t.ticker.C: + if err := t.task(ctx, t); err != nil { + return errors.Wrap(err, "ticker task failed") + } + case <-t.signalChan: + return nil + } + } +} + +// SetInterval updates the interval of the ticker. +func (t *Ticker) SetInterval(interval time.Duration) { + t.stateMu.Lock() + defer t.stateMu.Unlock() + + // noop + if t.interval == interval || t.ticker == nil { + return + } + + t.interval = interval + t.ticker.Reset(interval) +} + +// Stop stops the ticker. Safe to call concurrently or multiple times. +func (t *Ticker) Stop() { + t.stateMu.Lock() + defer t.stateMu.Unlock() + + // noop + if t.stopped || t.signalChan == nil { + return + } + + close(t.signalChan) + t.stopped = true + t.ticker.Stop() +} diff --git a/pkg/ticker/ticker_test.go b/pkg/ticker/ticker_test.go new file mode 100644 index 0000000000..671091c71f --- /dev/null +++ b/pkg/ticker/ticker_test.go @@ -0,0 +1,173 @@ +package ticker + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTicker(t *testing.T) { + const ( + dur = time.Millisecond * 100 + durSmall = dur / 10 + ) + + t.Run("Basic case with context", func(t *testing.T) { + // ARRANGE + // Given a counter + var counter int + + // And a context + ctx, cancel := context.WithTimeout(context.Background(), dur+durSmall) + defer cancel() + + // And a ticker + ticker := New(dur, func(_ context.Context, t *Ticker) error { + counter++ + + return nil + }) + + // ACT + err := ticker.Run(ctx) + + // ASSERT + assert.ErrorIs(t, err, context.DeadlineExceeded) + + // two runs: start run + 1 tick + assert.Equal(t, 2, counter) + }) + + t.Run("Halts when error occurred", func(t *testing.T) { + // ARRANGE + // Given a counter + var counter int + + ctx := context.Background() + + // And a ticker func that returns an error after 10 runs + ticker := New(durSmall, func(_ context.Context, t *Ticker) error { + counter++ + if counter > 9 { + return fmt.Errorf("oops") + } + + return nil + }) + + // ACT + err := ticker.Run(ctx) + + // ASSERT + assert.ErrorContains(t, err, "oops") + assert.Equal(t, 10, counter) + }) + + t.Run("Dynamic interval update", func(t *testing.T) { + // ARRANGE + // Given a counter + var counter int + + // Given duration + duration := dur * 10 + + ctx, cancel := context.WithTimeout(context.Background(), duration) + defer cancel() + + // And a ticker what decreases the interval by 2 each time + ticker := New(durSmall, func(_ context.Context, ticker *Ticker) error { + t.Logf("Counter: %d, Duration: %s", counter, duration.String()) + + counter++ + duration /= 2 + + ticker.SetInterval(duration) + + return nil + }) + + // ACT + err := ticker.Run(ctx) + + // ASSERT + assert.ErrorIs(t, err, context.DeadlineExceeded) + + // It should have run at 2 times with ctxTimeout = tickerDuration (start + 1 tick), + // But it should have run more than that because of the interval decrease + assert.GreaterOrEqual(t, counter, 2) + }) + + t.Run("Stop ticker", func(t *testing.T) { + // ARRANGE + // Given a counter + var counter int + + // And a context + ctx := context.Background() + + // And a ticker + ticker := New(durSmall, func(_ context.Context, _ *Ticker) error { + counter++ + return nil + }) + + // And a function with a stop signal + go func() { + time.Sleep(dur) + ticker.Stop() + }() + + // ACT + err := ticker.Run(ctx) + + // ASSERT + assert.NoError(t, err) + assert.Greater(t, counter, 8) + + t.Run("Stop ticker for the second time", func(t *testing.T) { + ticker.Stop() + }) + }) + + t.Run("Panic", func(t *testing.T) { + // ARRANGE + // Given a context + ctx := context.Background() + + // And a ticker + ticker := New(durSmall, func(_ context.Context, _ *Ticker) error { + panic("oops") + }) + + // ACT + err := ticker.Run(ctx) + + // ASSERT + assert.ErrorContains(t, err, "panic during ticker run: oops") + }) + + t.Run("Run as a single call", func(t *testing.T) { + // ARRANGE + // Given a counter + var counter int + + // Given a context + ctx, cancel := context.WithTimeout(context.Background(), dur+durSmall) + defer cancel() + + tick := func(ctx context.Context, t *Ticker) error { + counter++ + return nil + } + + // ACT + err := Run(ctx, dur, tick) + + // ASSERT + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.Equal(t, 2, counter) + }) +} diff --git a/zetaclient/chains/evm/observer/inbound.go b/zetaclient/chains/evm/observer/inbound.go index 19ad1f14d5..abf21e7e5b 100644 --- a/zetaclient/chains/evm/observer/inbound.go +++ b/zetaclient/chains/evm/observer/inbound.go @@ -20,9 +20,11 @@ import ( "github.com/zeta-chain/protocol-contracts/pkg/contracts/evm/erc20custody.sol" "github.com/zeta-chain/protocol-contracts/pkg/contracts/evm/zetaconnector.non-eth.sol" + "github.com/zeta-chain/zetacore/pkg/bg" "github.com/zeta-chain/zetacore/pkg/chains" "github.com/zeta-chain/zetacore/pkg/coin" "github.com/zeta-chain/zetacore/pkg/constant" + "github.com/zeta-chain/zetacore/pkg/ticker" "github.com/zeta-chain/zetacore/x/crosschain/types" "github.com/zeta-chain/zetacore/zetaclient/chains/evm" "github.com/zeta-chain/zetacore/zetaclient/compliance" @@ -36,42 +38,46 @@ import ( // WatchInbound watches evm chain for incoming txs and post votes to zetacore // TODO(revamp): move ticker function to a separate file func (ob *Observer) WatchInbound(ctx context.Context) error { - app, err := zctx.FromContext(ctx) - if err != nil { - return err + sampledLogger := ob.Logger().Inbound.Sample(&zerolog.BasicSampler{N: 10}) + interval := ticker.SecondsFromUint64(ob.GetChainParams().InboundTicker) + task := func(ctx context.Context, t *ticker.Ticker) error { + return ob.watchInboundOnce(ctx, t, sampledLogger) } - ticker, err := clienttypes.NewDynamicTicker( - fmt.Sprintf("EVM_WatchInbound_%d", ob.Chain().ChainId), - ob.GetChainParams().InboundTicker, - ) + t := ticker.New(interval, task) + + bg.Work(ctx, func(_ context.Context) error { + <-ob.StopChannel() + t.Stop() + ob.Logger().Inbound.Info().Msg("WatchInbound stopped") + return nil + }) + + ob.Logger().Inbound.Info().Msgf("WatchInbound started") + + return t.Run(ctx) +} + +func (ob *Observer) watchInboundOnce(ctx context.Context, t *ticker.Ticker, sampledLogger zerolog.Logger) error { + app, err := zctx.FromContext(ctx) if err != nil { - ob.Logger().Inbound.Error().Err(err).Msg("error creating ticker") return err } - defer ticker.Stop() - ob.Logger().Inbound.Info().Msgf("WatchInbound started for chain %d", ob.Chain().ChainId) - sampledLogger := ob.Logger().Inbound.Sample(&zerolog.BasicSampler{N: 10}) + // noop + if !app.IsInboundObservationEnabled() { + ob.Logger().Inbound.Warn().Msg("WatchInbound: inbound observation is disabled") + return nil + } - for { - select { - case <-ticker.C(): - if !app.IsInboundObservationEnabled() { - sampledLogger.Info(). - Msgf("WatchInbound: inbound observation is disabled for chain %d", ob.Chain().ChainId) - continue - } - err := ob.ObserveInbound(ctx, sampledLogger) - if err != nil { - ob.Logger().Inbound.Err(err).Msg("WatchInbound: observeInbound error") - } - ticker.UpdateInterval(ob.GetChainParams().InboundTicker, ob.Logger().Inbound) - case <-ob.StopChannel(): - ob.Logger().Inbound.Info().Msgf("WatchInbound stopped for chain %d", ob.Chain().ChainId) - return nil - } + if err := ob.ObserveInbound(ctx, sampledLogger); err != nil { + ob.Logger().Inbound.Err(err).Msg("WatchInbound: observeInbound error") } + + newInterval := ticker.SecondsFromUint64(ob.GetChainParams().InboundTicker) + t.SetInterval(newInterval) + + return nil } // WatchInboundTracker gets a list of Inbound tracker suggestions from zeta-core at each tick and tries to check if the in-tx was confirmed.