diff --git a/cmd/start.go b/cmd/start.go index 9979990..b23558d 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -16,6 +16,8 @@ const ( flagListen = "listen" flagAll = "all" flagGRPCAddress = "grpc" + flagOperator = "operator" + flagSentry = "sentry" ) func startCmd() *cobra.Command { @@ -65,7 +67,11 @@ func startCmd() *cobra.Command { ctx := cmd.Context() - watcher, err := NewSentryWatcher(ctx, logger, all, hc) + // if we're running in kubernetes, we can auto-discover sentries + operator, _ := cmd.Flags().GetBool(flagOperator) + sentries, _ := cmd.Flags().GetStringArray(flagSentry) + + watcher, err := NewSentryWatcher(ctx, logger, all, hc, operator, sentries) if err != nil { return err } @@ -78,7 +84,9 @@ func startCmd() *cobra.Command { }, } - cmd.Flags().StringArrayP(flagListen, "l", []string{"tcp://0.0.0.0:1234"}, "Privval listen addresses for the proxy") + cmd.Flags().StringArrayP(flagListen, "l", nil, "Privval listen addresses for the proxy (e.g. tcp://0.0.0.0:1234)") + cmd.Flags().StringArrayP(flagSentry, "s", nil, "Privval connect addresses for the proxy") + cmd.Flags().BoolP(flagOperator, "o", true, "Use this when running in kubernetes with the Cosmos Operator to auto-discover sentries") cmd.Flags().StringP(flagGRPCAddress, "g", "", "GRPC address for the proxy") cmd.Flags().BoolP(flagAll, "a", false, "Connect to sentries on all nodes") cmd.Flags().String(flagLogLevel, "info", "Set log level (debug, info, error, none)") diff --git a/cmd/watcher.go b/cmd/watcher.go index 754c634..5c1ebab 100644 --- a/cmd/watcher.go +++ b/cmd/watcher.go @@ -23,12 +23,14 @@ const ( ) type SentryWatcher struct { - all bool - client *kubernetes.Clientset - hc signer.HorcruxConnection - log cometlog.Logger - node string - sentries map[string]*signer.ReconnRemoteSigner + all bool + client *kubernetes.Clientset + hc signer.HorcruxConnection + log cometlog.Logger + node string + operator bool + persistentSentries []*signer.ReconnRemoteSigner + sentries map[string]*signer.ReconnRemoteSigner stop chan struct{} done chan struct{} @@ -39,50 +41,72 @@ func NewSentryWatcher( logger cometlog.Logger, all bool, // should we connect to sentries on all nodes, or just this node? hc signer.HorcruxConnection, + operator bool, + sentries []string, ) (*SentryWatcher, error) { - config, err := rest.InClusterConfig() - if err != nil { - return nil, fmt.Errorf("failed to get in cluster config: %w", err) - } - // creates the clientset - clientset, err := kubernetes.NewForConfig(config) - if err != nil { - return nil, fmt.Errorf("failed to create kube clientset: %w", err) - } - + var clientset *kubernetes.Clientset var thisNode string - if !all { - // need to determine which node this pod is on so we can only connect to sentries on this node - nsbz, err := os.ReadFile(namespaceFile) + if operator { + config, err := rest.InClusterConfig() if err != nil { - return nil, fmt.Errorf("failed to read namespace from service account: %w", err) + return nil, fmt.Errorf("failed to get in cluster config: %w", err) } - ns := string(nsbz) - - thisPod, err := clientset.CoreV1().Pods(ns).Get(ctx, os.Getenv("HOSTNAME"), metav1.GetOptions{}) + // creates the clientset + clientset, err := kubernetes.NewForConfig(config) if err != nil { - return nil, fmt.Errorf("failed to get this pod: %w", err) + return nil, fmt.Errorf("failed to create kube clientset: %w", err) + } + + if !all { + // need to determine which node this pod is on so we can only connect to sentries on this node + + nsbz, err := os.ReadFile(namespaceFile) + if err != nil { + return nil, fmt.Errorf("failed to read namespace from service account: %w", err) + } + ns := string(nsbz) + + thisPod, err := clientset.CoreV1().Pods(ns).Get(ctx, os.Getenv("HOSTNAME"), metav1.GetOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to get this pod: %w", err) + } + + thisNode = thisPod.Spec.NodeName } + } - thisNode = thisPod.Spec.NodeName + persistentSentries := make([]*signer.ReconnRemoteSigner, len(sentries)) + for i, sentry := range sentries { + dialer := net.Dialer{Timeout: 2 * time.Second} + persistentSentries[i] = signer.NewReconnRemoteSigner(sentry, logger, hc, dialer) } return &SentryWatcher{ - all: all, - client: clientset, - done: make(chan struct{}), - hc: hc, - log: logger, - node: thisNode, - sentries: make(map[string]*signer.ReconnRemoteSigner), - stop: make(chan struct{}), + all: all, + client: clientset, + done: make(chan struct{}), + hc: hc, + log: logger, + node: thisNode, + operator: operator, + persistentSentries: persistentSentries, + sentries: make(map[string]*signer.ReconnRemoteSigner), + stop: make(chan struct{}), }, nil } // Watch will reconcile the sentries with the kube api at a reasonable interval. // It must be called only once. func (w *SentryWatcher) Watch(ctx context.Context) { + for _, sentry := range w.persistentSentries { + if err := sentry.Start(); err != nil { + w.log.Error("Failed to start persistent sentry", "error", err) + } + } + if !w.operator { + return + } defer close(w.done) const interval = 30 * time.Second timer := time.NewTimer(interval) @@ -109,6 +133,9 @@ func (w *SentryWatcher) Stop() error { close(w.stop) <-w.done var err error + for _, sentry := range w.persistentSentries { + err = errors.Join(err, sentry.Stop()) + } for _, sentry := range w.sentries { err = errors.Join(err, sentry.Stop()) }