Skip to content

Commit

Permalink
fix: Data race in sentries (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidNix authored Sep 13, 2023
1 parent 322b9bc commit 7c1fefd
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 87 deletions.
15 changes: 3 additions & 12 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,16 @@ package cmd
import (
"os"

cometlog "github.com/cometbft/cometbft/libs/log"
"github.com/spf13/cobra"
"github.com/strangelove-ventures/horcrux-proxy/privval"
"github.com/strangelove-ventures/horcrux-proxy/signer"
)

type appState struct {
logger cometlog.Logger
loadBalancer *privval.RemoteSignerLoadBalancer
sentries map[string]*signer.ReconnRemoteSigner
}

func rootCmd(a *appState) *cobra.Command {
func rootCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "horcrux-proxy",
Short: "A tendermint remote signer proxy",
}

cmd.AddCommand(startCmd(a))
cmd.AddCommand(startCmd())
cmd.AddCommand(versionCmd())

return cmd
Expand All @@ -30,7 +21,7 @@ func rootCmd(a *appState) *cobra.Command {
// Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
if err := rootCmd(new(appState)).Execute(); err != nil {
if err := rootCmd().Execute(); err != nil {
// Cobra will print the error
os.Exit(1)
}
Expand Down
44 changes: 21 additions & 23 deletions cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/spf13/cobra"

"github.com/strangelove-ventures/horcrux-proxy/privval"
"github.com/strangelove-ventures/horcrux-proxy/signer"
)

const (
Expand All @@ -17,7 +16,7 @@ const (
flagAll = "all"
)

func startCmd(a *appState) *cobra.Command {
func startCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "start",
Short: "Start horcrux-proxy process",
Expand All @@ -32,31 +31,33 @@ func startCmd(a *appState) *cobra.Command {
return fmt.Errorf("failed to parse log level: %w", err)
}

a.logger = cometlog.NewFilter(cometlog.NewTMLogger(cometlog.NewSyncWriter(out)), logLevelOpt).With("module", "validator")

a.logger.Info("Horcrux Proxy")
logger := cometlog.NewFilter(cometlog.NewTMLogger(cometlog.NewSyncWriter(out)), logLevelOpt).With("module", "validator")
logger.Info("Horcrux Proxy")

listenAddrs, _ := cmd.Flags().GetStringArray(flagListen)
all, _ := cmd.Flags().GetBool(flagAll)

listeners := make([]privval.SignerListener, len(listenAddrs))
for i, addr := range listenAddrs {
listeners[i] = privval.NewSignerListener(a.logger, addr)
listeners[i] = privval.NewSignerListener(logger, addr)
}

a.loadBalancer = privval.NewRemoteSignerLoadBalancer(a.logger, listeners)

if err := a.loadBalancer.Start(); err != nil {
loadBalancer := privval.NewRemoteSignerLoadBalancer(logger, listeners)
if err = loadBalancer.Start(); err != nil {
return fmt.Errorf("failed to start listener(s): %w", err)
}
defer logIfErr(logger, loadBalancer.Stop)

a.sentries = make(map[string]*signer.ReconnRemoteSigner)
ctx := cmd.Context()

if err := watchForChangedSentries(cmd.Context(), a, all); err != nil {
watcher, err := NewSentryWatcher(ctx, logger, all, loadBalancer)
if err != nil {
return err
}
defer logIfErr(logger, watcher.Stop)
go watcher.Watch(ctx)

waitAndTerminate(a)
waitForSignals(logger)

return nil
},
Expand All @@ -69,18 +70,15 @@ func startCmd(a *appState) *cobra.Command {
return cmd
}

func waitAndTerminate(a *appState) {
func logIfErr(logger cometlog.Logger, fn func() error) {
if err := fn(); err != nil {
logger.Error("Error", "err", err)
}
}

func waitForSignals(logger cometlog.Logger) {
done := make(chan struct{})
cometos.TrapSignal(a.logger, func() {
for _, s := range a.sentries {
err := s.Stop()
if err != nil {
panic(err)
}
}
if err := a.loadBalancer.Stop(); err != nil {
panic(err)
}
cometos.TrapSignal(logger, func() {
close(done)
})
<-done
Expand Down
117 changes: 78 additions & 39 deletions cmd/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package cmd

import (
"context"
"errors"
"fmt"
"net"
"os"
"time"

cometlog "github.com/cometbft/cometbft/libs/log"
"github.com/strangelove-ventures/horcrux-proxy/privval"
"github.com/strangelove-ventures/horcrux-proxy/signer"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand All @@ -20,69 +23,105 @@ const (
labelCosmosSentry = "app.kubernetes.io/component=cosmos-sentry"
)

func watchForChangedSentries(
type SentryWatcher struct {
all bool
client *kubernetes.Clientset
lb *privval.RemoteSignerLoadBalancer
log cometlog.Logger
node string
sentries map[string]*signer.ReconnRemoteSigner

stop chan struct{}
done chan struct{}
}

func NewSentryWatcher(
ctx context.Context,
a *appState,
logger cometlog.Logger,
all bool, // should we connect to sentries on all nodes, or just this node?
) error {
lb *privval.RemoteSignerLoadBalancer,
) (*SentryWatcher, error) {
config, err := rest.InClusterConfig()
if err != nil {
return fmt.Errorf("failed to get in cluster config: %w", err)
return nil, fmt.Errorf("failed to get in cluster config: %w", err)
}
// creates the clientset
clientset, err := kubernetes.NewForConfig(config)
if err != nil {
return fmt.Errorf("failed to create kube clientset: %w", err)
return nil, fmt.Errorf("failed to create kube clientset: %w", err)
}

thisNode := ""
var thisNode string
if !all {
// need to determine which node this pod is on so we can only connect to sentries on this node

nsbz, err := os.ReadFile(namespaceFile)
if err != nil {
return fmt.Errorf("failed to read namespace from service account: %w", err)
return nil, fmt.Errorf("failed to read namespace from service account: %w", err)
}
ns := string(nsbz)

thisPod, err := clientset.CoreV1().Pods(ns).Get(ctx, os.Getenv("HOSTNAME"), metav1.GetOptions{})
if err != nil {
return fmt.Errorf("failed to get this pod: %w", err)
return nil, fmt.Errorf("failed to get this pod: %w", err)
}

thisNode = thisPod.Spec.NodeName
}

t := time.NewTimer(30 * time.Second)
return &SentryWatcher{
all: all,
client: clientset,
done: make(chan struct{}),
lb: lb,
log: logger,
node: thisNode,
sentries: make(map[string]*signer.ReconnRemoteSigner),
stop: make(chan struct{}),
}, nil
}

go func() {
defer t.Stop()
for {
if err := reconcileSentries(ctx, a, thisNode, clientset, all); err != nil {
a.logger.Error("Failed to reconcile sentries with kube api", "error", err)
}
select {
case <-ctx.Done():
return
case <-t.C:
t.Reset(30 * time.Second)
}
// Watch will reconcile the sentries with the kube api at a reasonable interval.
// It must be called only once.
func (w *SentryWatcher) Watch(ctx context.Context) {
defer close(w.done)
const interval = 30 * time.Second
timer := time.NewTimer(interval)
defer timer.Stop()

for {
if err := w.reconcileSentries(ctx); err != nil {
w.log.Error("Failed to reconcile sentries with kube api", "error", err)
}
}()
select {
case <-w.stop:
return
case <-ctx.Done():
return
case <-timer.C:
timer.Reset(interval)
}
}
}

return nil
// Stop cleans up the sentries and stops the watcher. It must be called only once.
func (w *SentryWatcher) Stop() error {
// The dual channel synchronization ensures w.sentries is only read/mutated by one goroutine.
close(w.stop)
<-w.done
var err error
for _, sentry := range w.sentries {
err = errors.Join(err, sentry.Stop())
}
return err
}

func reconcileSentries(
func (w *SentryWatcher) reconcileSentries(
ctx context.Context,
a *appState,
thisNode string,
clientset *kubernetes.Clientset,
all bool, // should we connect to sentries on all nodes, or just this node?
) error {
configNodes := make([]string, 0)

services, err := clientset.CoreV1().Services("").List(ctx, metav1.ListOptions{
services, err := w.client.CoreV1().Services("").List(ctx, metav1.ListOptions{
LabelSelector: labelCosmosSentry,
})

Expand All @@ -97,7 +136,7 @@ func reconcileSentries(

set := labels.Set(s.Spec.Selector)

pods, err := clientset.CoreV1().Pods(s.Namespace).List(ctx, metav1.ListOptions{LabelSelector: set.AsSelector().String()})
pods, err := w.client.CoreV1().Pods(s.Namespace).List(ctx, metav1.ListOptions{LabelSelector: set.AsSelector().String()})
if err != nil {
return fmt.Errorf("failed to list pods in namespace %s for service %s: %w", s.Namespace, s.Name, err)
}
Expand All @@ -106,7 +145,7 @@ func reconcileSentries(
continue
}

if !all && pods.Items[0].Spec.NodeName != thisNode {
if !w.all && pods.Items[0].Spec.NodeName != w.node {
continue
}

Expand All @@ -118,21 +157,21 @@ func reconcileSentries(

for _, newConfigSentry := range configNodes {
foundNewConfigSentry := false
for existingSentry := range a.sentries {
for existingSentry := range w.sentries {
if existingSentry == newConfigSentry {
foundNewConfigSentry = true
break
}
}
if !foundNewConfigSentry {
a.logger.Info("Will add new sentry", "address", newConfigSentry)
w.log.Info("Will add new sentry", "address", newConfigSentry)
newSentries = append(newSentries, newConfigSentry)
}
}

removedSentries := make([]string, 0)

for existingSentry := range a.sentries {
for existingSentry := range w.sentries {
foundExistingSentry := false
for _, newConfigSentry := range configNodes {
if existingSentry == newConfigSentry {
Expand All @@ -141,26 +180,26 @@ func reconcileSentries(
}
}
if !foundExistingSentry {
a.logger.Info("Will remove existing sentry", "address", existingSentry)
w.log.Info("Will remove existing sentry", "address", existingSentry)
removedSentries = append(removedSentries, existingSentry)
}
}

for _, s := range removedSentries {
if err := a.sentries[s].Stop(); err != nil {
if err := w.sentries[s].Stop(); err != nil {
return fmt.Errorf("failed to stop remote signer: %w", err)
}
delete(a.sentries, s)
delete(w.sentries, s)
}

for _, newSentry := range newSentries {
dialer := net.Dialer{Timeout: 2 * time.Second}
s := signer.NewReconnRemoteSigner(newSentry, a.logger, a.loadBalancer, dialer)
s := signer.NewReconnRemoteSigner(newSentry, w.log, w.lb, dialer)

if err := s.Start(); err != nil {
return fmt.Errorf("failed to start new remote signer(s): %w", err)
}
a.sentries[newSentry] = s
w.sentries[newSentry] = s
}

return nil
Expand Down
Loading

0 comments on commit 7c1fefd

Please sign in to comment.