Skip to content

Commit

Permalink
refactor: Load balancer concurrency (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidNix authored Sep 11, 2023
1 parent 69ae697 commit 322b9bc
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 75 deletions.
74 changes: 18 additions & 56 deletions privval/load_balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package privval

import (
"errors"
"sync"

cometlog "github.com/cometbft/cometbft/libs/log"
privvalproto "github.com/cometbft/cometbft/proto/tendermint/privval"
Expand All @@ -12,80 +11,43 @@ import (
type RemoteSignerLoadBalancer struct {
logger cometlog.Logger
listeners []SignerListener
avail chan SignerListener // Available listeners that are ready to accept requests.
}

func NewRemoteSignerLoadBalancer(logger cometlog.Logger, listeners []SignerListener) *RemoteSignerLoadBalancer {
ch := make(chan SignerListener, len(listeners))
for i := range listeners {
ch <- listeners[i]
}
return &RemoteSignerLoadBalancer{
logger: logger,
listeners: listeners,
avail: ch,
}
}

// SendRequest sends a request to the first available listener.
func (sl *RemoteSignerLoadBalancer) SendRequest(request privvalproto.Message) (*privvalproto.Message, error) {
var r racer
var res signerListenerEndpointResponse

r.wg.Add(1)

for _, listener := range sl.listeners {
go sl.sendRequestIfFirst(listener, &r, request, &res)
}
func (lb *RemoteSignerLoadBalancer) SendRequest(request privvalproto.Message) (*privvalproto.Message, error) {
lis := <-lb.avail
defer func() { lb.avail <- lis }()

r.wg.Wait()

return res.res, res.err
lb.logger.Debug("Sent request to listener", "address", lis.address)
return lis.SendRequest(request)
}

func (sl *RemoteSignerLoadBalancer) Start() error {
for _, listener := range sl.listeners {
func (lb *RemoteSignerLoadBalancer) Start() error {
for _, listener := range lb.listeners {
if err := listener.Start(); err != nil {
return err
}
}
return nil
}

func (sl *RemoteSignerLoadBalancer) Stop() error {
var errs []error
for _, listener := range sl.listeners {
if err := listener.Stop(); err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}

type signerListenerEndpointResponse struct {
res *privvalproto.Message
err error
}

func (l *RemoteSignerLoadBalancer) sendRequestIfFirst(listener SignerListener, r *racer, request privvalproto.Message, res *signerListenerEndpointResponse) {
listener.instanceMtx.Lock()
defer listener.instanceMtx.Unlock()
first := r.race()
if !first {
return
}
res.res, res.err = listener.SendRequestLocked(request)
r.wg.Done()
l.logger.Debug("Sent request to listener", "address", listener.address)
}

type racer struct {
mu sync.Mutex
wg sync.WaitGroup
handled bool
}

// returns true if first
func (r *racer) race() bool {
r.mu.Lock()
defer r.mu.Unlock()
if r.handled {
return false
func (lb *RemoteSignerLoadBalancer) Stop() error {
var err error
for _, listener := range lb.listeners {
err = errors.Join(err, listener.Stop())
}
r.handled = true
return true
return err
}
16 changes: 4 additions & 12 deletions privval/load_balancer_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package privval_test

import (
"io"
"net"
"testing"
"time"
Expand All @@ -14,12 +15,6 @@ import (
"github.com/strangelove-ventures/horcrux-proxy/privval"
)

type devNull struct{}

func (devNull) Write(p []byte) (int, error) {
return len(p), nil
}

func TestLoadBalancer(t *testing.T) {
var listenAddrs = []string{
"tcp://127.0.0.1:37321",
Expand All @@ -28,7 +23,7 @@ func TestLoadBalancer(t *testing.T) {
"tcp://127.0.0.1:37324",
}

logger := log.NewTMJSONLogger(devNull{})
logger := log.NewTMJSONLogger(io.Discard)

listeners := make([]privval.SignerListener, len(listenAddrs))
for i, addr := range listenAddrs {
Expand All @@ -37,13 +32,11 @@ func TestLoadBalancer(t *testing.T) {

lb := privval.NewRemoteSignerLoadBalancer(logger, listeners)

err := lb.Start()

t.Cleanup(func() {
_ = lb.Stop()
})

require.NoError(t, err)
require.NoError(t, lb.Start())

remoteSigners := make([]*MockRemoteSigner, len(listenAddrs))

Expand All @@ -70,8 +63,7 @@ func TestLoadBalancer(t *testing.T) {
})
}

err = eg.Wait()
require.NoError(t, err)
require.NoError(t, eg.Wait())

total := 0
for i := range listenAddrs {
Expand Down
4 changes: 2 additions & 2 deletions privval/remote_signer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ type MockRemoteSigner struct {
dialer net.Dialer
}

func (m *MockRemoteSigner) Counter() Counter {
return m.counter.Copy()
func (rs *MockRemoteSigner) Counter() Counter {
return rs.counter.Copy()
}

// NewMockRemoteSigner return a MockRemoteSigner that will dial using the given
Expand Down
5 changes: 0 additions & 5 deletions privval/signer_listener_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,6 @@ func (sl *SignerListenerEndpoint) SendRequest(request privvalproto.Message) (*pr
sl.instanceMtx.Lock()
defer sl.instanceMtx.Unlock()

return sl.SendRequestLocked(request)
}

// SendRequest ensures there is a connection, sends a request and waits for a response
func (sl *SignerListenerEndpoint) SendRequestLocked(request privvalproto.Message) (*privvalproto.Message, error) {
err := sl.ensureConnection(sl.timeoutAccept)
if err != nil {
return nil, err
Expand Down

0 comments on commit 322b9bc

Please sign in to comment.