From 22f89e799617fa080296b10816f26be7fef90bad Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Wed, 11 Oct 2023 22:08:34 +0200 Subject: [PATCH] WIP --- cmd/run.go | 18 +++-- errors/errors.go | 3 + metrics/builtins.go | 11 ++- network/client.go | 75 ++++++++++++--------- network/engine.go | 115 ++++++++++++++++++++------------ network/network_helpers_test.go | 19 +++--- network/proxy.go | 35 ++++++++-- network/server.go | 32 +++++---- network/server_test.go | 45 +++++++------ plugin/plugin_registry.go | 4 +- 10 files changed, 227 insertions(+), 130 deletions(-) diff --git a/cmd/run.go b/cmd/run.go index 7aba4a41..03653c00 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -98,11 +98,12 @@ func StopGracefully( } } - logger.Info().Msg("Stopping GatewayD") - span.AddEvent("Stopping GatewayD", trace.WithAttributes( + logger.Info().Msg("GatewayD is shutting down") + span.AddEvent("GatewayD is shutting down", trace.WithAttributes( attribute.String("signal", signal), )) if healthCheckScheduler != nil { + healthCheckScheduler.Stop() healthCheckScheduler.Clear() logger.Info().Msg("Stopped health check scheduler") span.AddEvent("Stopped health check scheduler") @@ -268,7 +269,7 @@ var runCmd = &cobra.Command{ startDelay := time.Now().Add(conf.Plugin.HealthCheckPeriod) if _, err := healthCheckScheduler.Every( conf.Plugin.HealthCheckPeriod).SingletonMode().StartAt(startDelay).Do(func() { - _, span = otel.Tracer(config.TracerName).Start(ctx, "Run plugin health check") + _, span := otel.Tracer(config.TracerName).Start(ctx, "Run plugin health check") defer span.End() var plugins []string @@ -749,9 +750,13 @@ var runCmd = &cobra.Command{ ) signalsCh := make(chan os.Signal, 1) signal.Notify(signalsCh, signals...) - go func(pluginRegistry *plugin.Registry, + go func(pluginTimeoutCtx context.Context, + pluginRegistry *plugin.Registry, logger zerolog.Logger, servers map[string]*network.Server, + metricsMerger *metrics.Merger, + metricsServer *http.Server, + stopChan chan struct{}, ) { for sig := range signalsCh { for _, s := range signals { @@ -771,13 +776,14 @@ var runCmd = &cobra.Command{ } } } - }(pluginRegistry, logger, servers) + }(pluginTimeoutCtx, pluginRegistry, logger, servers, metricsMerger, metricsServer, stopChan) _, span = otel.Tracer(config.TracerName).Start(runCtx, "Start servers") // Start the server. for name, server := range servers { logger := loggers[name] go func( + span trace.Span, server *network.Server, logger zerolog.Logger, healthCheckScheduler *gocron.Scheduler, @@ -797,7 +803,7 @@ var runCmd = &cobra.Command{ pluginRegistry.Shutdown() os.Exit(gerr.FailedToStartServer) } - }(server, logger, healthCheckScheduler, metricsMerger, pluginRegistry) + }(span, server, logger, healthCheckScheduler, metricsMerger, pluginRegistry) } span.End() diff --git a/errors/errors.go b/errors/errors.go index f16a6501..a8799bbc 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -23,6 +23,7 @@ const ( ErrCodeServerSendFailed ErrCodeServerListenFailed ErrCodeSplitHostPortFailed + ErrCodeCloseListenerFailed ErrCodeAcceptFailed ErrCodeReadFailed ErrCodePutFailed @@ -85,6 +86,8 @@ var ( ErrCodeServerListenFailed, "couldn't listen on the server", nil) ErrSplitHostPortFailed = NewGatewayDError( ErrCodeSplitHostPortFailed, "failed to split host:port", nil) + ErrCloseListenerFailed = NewGatewayDError( + ErrCodeCloseListenerFailed, "failed to close listener", nil) ErrAcceptFailed = NewGatewayDError( ErrCodeAcceptFailed, "failed to accept connection", nil) diff --git a/metrics/builtins.go b/metrics/builtins.go index e56b978d..2e588a58 100644 --- a/metrics/builtins.go +++ b/metrics/builtins.go @@ -75,10 +75,15 @@ var ( Name: "proxied_connections", Help: "Number of proxy connects", }) - ProxyPassThroughs = promauto.NewCounter(prometheus.CounterOpts{ + ProxyPassThroughsToClient = promauto.NewCounter(prometheus.CounterOpts{ Namespace: Namespace, - Name: "proxy_passthroughs_total", - Help: "Number of successful proxy passthroughs", + Name: "proxy_passthroughs_to_client_total", + Help: "Number of successful proxy passthroughs from server to client", + }) + ProxyPassThroughsToServer = promauto.NewCounter(prometheus.CounterOpts{ + Namespace: Namespace, + Name: "proxy_passthroughs_to_server_total", + Help: "Number of successful proxy passthroughs from client to server", }) ProxyPassThroughTerminations = promauto.NewCounter(prometheus.CounterOpts{ Namespace: Namespace, diff --git a/network/client.go b/network/client.go index de2ef2a8..9cff7a9d 100644 --- a/network/client.go +++ b/network/client.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "net" + "sync/atomic" "time" "github.com/gatewayd-io/gatewayd/config" @@ -26,8 +27,9 @@ type IClient interface { type Client struct { net.Conn - logger zerolog.Logger - ctx context.Context //nolint:containedctx + logger zerolog.Logger + ctx context.Context //nolint:containedctx + connected atomic.Bool TCPKeepAlive bool TCPKeepAlivePeriod time.Duration @@ -53,6 +55,7 @@ func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog. return nil } + client.connected.Store(false) client.logger = logger // Try to resolve the address and log an error if it can't be resolved. @@ -87,6 +90,7 @@ func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog. } client.Conn = conn + client.connected.Store(true) // Set the TCP keep alive. client.TCPKeepAlive = clientConfig.TCPKeepAlive @@ -146,6 +150,11 @@ func (c *Client) Send(data []byte) (int, *gerr.GatewayDError) { _, span := otel.Tracer(config.TracerName).Start(c.ctx, "Send") defer span.End() + if !c.connected.Load() { + span.RecordError(gerr.ErrClientNotConnected) + return 0, gerr.ErrClientNotConnected + } + sent := 0 received := len(data) for { @@ -170,8 +179,6 @@ func (c *Client) Send(data []byte) (int, *gerr.GatewayDError) { }, ).Msg("Sent data to server") - metrics.BytesSentToServer.Observe(float64(sent)) - return sent, nil } @@ -180,6 +187,11 @@ func (c *Client) Receive() (int, []byte, *gerr.GatewayDError) { _, span := otel.Tracer(config.TracerName).Start(c.ctx, "Receive") defer span.End() + if !c.connected.Load() { + span.RecordError(gerr.ErrClientNotConnected) + return 0, nil, gerr.ErrClientNotConnected + } + var ctx context.Context var cancel context.CancelFunc if c.ReceiveTimeout > 0 { @@ -192,26 +204,21 @@ func (c *Client) Receive() (int, []byte, *gerr.GatewayDError) { var received int buffer := bytes.NewBuffer(nil) // Read the data in chunks. - select { //nolint:gosimple - case <-time.After(time.Millisecond): - for ctx.Err() == nil { - chunk := make([]byte, c.ReceiveChunkSize) - read, err := c.Conn.Read(chunk) - if err != nil { - c.logger.Error().Err(err).Msg("Couldn't receive data from the server") - span.RecordError(err) - metrics.BytesReceivedFromServer.Observe(float64(received)) - return received, buffer.Bytes(), gerr.ErrClientReceiveFailed.Wrap(err) - } - received += read - buffer.Write(chunk[:read]) + for ctx.Err() == nil { + chunk := make([]byte, c.ReceiveChunkSize) + read, err := c.Conn.Read(chunk) + if err != nil { + c.logger.Error().Err(err).Msg("Couldn't receive data from the server") + span.RecordError(err) + return received, buffer.Bytes(), gerr.ErrClientReceiveFailed.Wrap(err) + } + received += read + buffer.Write(chunk[:read]) - if read == 0 || read < c.ReceiveChunkSize { - break - } + if read == 0 || read < c.ReceiveChunkSize { + break } } - metrics.BytesReceivedFromServer.Observe(float64(received)) return received, buffer.Bytes(), nil } @@ -220,6 +227,12 @@ func (c *Client) Close() { _, span := otel.Tracer(config.TracerName).Start(c.ctx, "Close") defer span.End() + // Set the deadline to now so that the connection is closed immediately. + if err := c.Conn.SetDeadline(time.Now()); err != nil { + c.logger.Error().Err(err).Msg("Failed to set deadline") + span.RecordError(err) + } + c.logger.Debug().Str("address", c.Address).Msg("Closing connection to server") if c.Conn != nil { c.Conn.Close() @@ -228,6 +241,7 @@ func (c *Client) Close() { c.Conn = nil c.Address = "" c.Network = "" + c.connected.Store(false) metrics.ServerConnections.Dec() } @@ -257,20 +271,15 @@ func (c *Client) IsConnected() bool { return false } - if n, err := c.Read([]byte{}); n == 0 && err != nil { - c.logger.Debug().Fields( - map[string]interface{}{ - "address": c.Address, - "reason": "read 0 bytes", - }).Msg("Connection to server is closed") - return false - } - - return true + return c.connected.Load() } // RemoteAddr returns the remote address of the client safely. func (c *Client) RemoteAddr() string { + if !c.connected.Load() { + return "" + } + if c.Conn != nil && c.Conn.RemoteAddr() != nil { return c.Conn.RemoteAddr().String() } @@ -280,6 +289,10 @@ func (c *Client) RemoteAddr() string { // LocalAddr returns the local address of the client safely. func (c *Client) LocalAddr() string { + if !c.connected.Load() { + return "" + } + if c.Conn != nil && c.Conn.LocalAddr() != nil { return c.Conn.LocalAddr().String() } diff --git a/network/engine.go b/network/engine.go index c9800e1a..ec026d02 100644 --- a/network/engine.go +++ b/network/engine.go @@ -5,6 +5,7 @@ import ( "net" "strconv" "sync" + "sync/atomic" "time" "github.com/gatewayd-io/gatewayd/config" @@ -35,6 +36,7 @@ type Engine struct { host string port int connections uint32 + running *atomic.Bool stopServer chan struct{} mu *sync.RWMutex } @@ -49,7 +51,13 @@ func (engine *Engine) Stop(ctx context.Context) error { _, cancel := context.WithDeadline(ctx, time.Now().Add(config.DefaultEngineStopTimeout)) defer cancel() + engine.running.Store(false) + if err := engine.listener.Close(); err != nil { + engine.stopServer <- struct{}{} + return gerr.ErrCloseListenerFailed.Wrap(err) + } engine.stopServer <- struct{}{} + close(engine.stopServer) return nil } @@ -59,6 +67,7 @@ func Run(network, address string, server *Server) *gerr.GatewayDError { connections: 0, stopServer: make(chan struct{}), mu: &sync.RWMutex{}, + running: &atomic.Bool{}, } if action := server.OnBoot(server.engine); action != None { @@ -71,7 +80,6 @@ func Run(network, address string, server *Server) *gerr.GatewayDError { server.logger.Error().Err(err).Msg("Server failed to start listening") return gerr.ErrServerListenFailed.Wrap(err) } - defer server.engine.listener.Close() if server.engine.listener == nil { server.logger.Error().Msg("Server is not properly initialized") @@ -102,56 +110,79 @@ func Run(network, address string, server *Server) *gerr.GatewayDError { } for { - interval, action := server.OnTick() - if action == Shutdown { - server.OnShutdown(server.engine) - return - } - if interval == time.Duration(0) { + select { + case <-server.engine.stopServer: return + default: + interval, action := server.OnTick() + if action == Shutdown { + server.OnShutdown(server.engine) + return + } + if interval == time.Duration(0) { + return + } + time.Sleep(interval) } - time.Sleep(interval) } }(server) - for { - conn, err := server.engine.listener.Accept() - if err != nil { - server.logger.Error().Err(err).Msg("Failed to accept connection") - return gerr.ErrAcceptFailed.Wrap(err) - } + server.engine.running.Store(true) - if out, action := server.OnOpen(conn); action != None { - if _, err := conn.Write(out); err != nil { - server.logger.Error().Err(err).Msg("Failed to write to connection") - } - conn.Close() - if action == Shutdown { - server.OnShutdown(server.engine) - return nil - } - } - server.engine.mu.Lock() - server.engine.connections++ - server.engine.mu.Unlock() - - // For every new connection, a new unbuffered channel is created to help - // stop the proxy, recycle the server connection and close stale connections. - stopConnection := make(chan struct{}) - go func(server *Server, conn net.Conn, stopConnection chan struct{}) { - if action := server.OnTraffic(conn, stopConnection); action == Close { - return + for { + select { + case <-server.engine.stopServer: + server.logger.Info().Msg("Server stopped") + return nil + default: + conn, err := server.engine.listener.Accept() + if err != nil { + if !server.engine.running.Load() { + return nil + } + server.logger.Error().Err(err).Msg("Failed to accept connection") + return gerr.ErrAcceptFailed.Wrap(err) } - }(server, conn, stopConnection) - go func(server *Server, conn net.Conn, stopConnection chan struct{}) { - <-stopConnection + if out, action := server.OnOpen(conn); action != None { + if _, err := conn.Write(out); err != nil { + server.logger.Error().Err(err).Msg("Failed to write to connection") + } + conn.Close() + if action == Shutdown { + server.OnShutdown(server.engine) + return nil + } + } server.engine.mu.Lock() - server.engine.connections-- + server.engine.connections++ server.engine.mu.Unlock() - if action := server.OnClose(conn, err); action == Close { - return - } - }(server, conn, stopConnection) + + // For every new connection, a new unbuffered channel is created to help + // stop the proxy, recycle the server connection and close stale connections. + stopConnection := make(chan struct{}) + go func(server *Server, conn net.Conn, stopConnection chan struct{}) { + if action := server.OnTraffic(conn, stopConnection); action == Close { + return + } + }(server, conn, stopConnection) + + go func(server *Server, conn net.Conn, stopConnection chan struct{}) { + for { + select { + case <-stopConnection: + server.engine.mu.Lock() + server.engine.connections-- + server.engine.mu.Unlock() + if action := server.OnClose(conn, err); action == Close { + return + } + return + case <-server.engine.stopServer: + return + } + } + }(server, conn, stopConnection) + } } } diff --git a/network/network_helpers_test.go b/network/network_helpers_test.go index 766c0315..18c967e0 100644 --- a/network/network_helpers_test.go +++ b/network/network_helpers_test.go @@ -91,8 +91,10 @@ func CollectAndComparePrometheusMetrics(t *testing.T) { # TYPE gatewayd_proxy_health_checks_total counter # HELP gatewayd_proxy_passthrough_terminations_total Number of proxy passthrough terminations by plugins # TYPE gatewayd_proxy_passthrough_terminations_total counter - # HELP gatewayd_proxy_passthroughs_total Number of successful proxy passthroughs - # TYPE gatewayd_proxy_passthroughs_total counter + # HELP gatewayd_proxy_passthroughs_to_client_total Number of successful proxy passthroughs + # TYPE gatewayd_proxy_passthroughs_to_client_total counter + # HELP gatewayd_proxy_passthroughs_to_server_total Number of successful proxy passthroughs + # TYPE gatewayd_proxy_passthroughs_to_server_total counter # HELP gatewayd_server_connections Number of server connections # TYPE gatewayd_server_connections gauge # HELP gatewayd_server_ticks_fired_total Total number of server ticks fired @@ -105,12 +107,12 @@ func CollectAndComparePrometheusMetrics(t *testing.T) { want = metadata + ` gatewayd_bytes_received_from_client_sum 67 gatewayd_bytes_received_from_client_count 1 - gatewayd_bytes_received_from_server_sum 96 - gatewayd_bytes_received_from_server_count 4 + gatewayd_bytes_received_from_server_sum 24 + gatewayd_bytes_received_from_server_count 1 gatewayd_bytes_sent_to_client_sum 24 gatewayd_bytes_sent_to_client_count 1 - gatewayd_bytes_sent_to_server_sum 282 - gatewayd_bytes_sent_to_server_count 5 + gatewayd_bytes_sent_to_server_sum 67 + gatewayd_bytes_sent_to_server_count 1 gatewayd_client_connections 1 gatewayd_plugin_hooks_executed_total 11 gatewayd_plugin_hooks_registered_total 0 @@ -118,8 +120,9 @@ func CollectAndComparePrometheusMetrics(t *testing.T) { gatewayd_proxied_connections 1 gatewayd_proxy_health_checks_total 0 gatewayd_proxy_passthrough_terminations_total 0 - gatewayd_proxy_passthroughs_total 1 - gatewayd_server_connections 5 + gatewayd_proxy_passthroughs_to_client_total 1 + gatewayd_proxy_passthroughs_to_server_total 1 + gatewayd_server_connections 3 gatewayd_traffic_bytes_sum 182 gatewayd_traffic_bytes_count 4 gatewayd_server_ticks_fired_total 1 diff --git a/network/proxy.go b/network/proxy.go index 66bf065d..02f2605c 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -316,7 +316,7 @@ func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError { // If the hook wants to terminate the connection, do it. if pr.shouldTerminate(result) { if modResponse, modReceived := pr.getPluginModifiedResponse(result); modResponse != nil { - metrics.ProxyPassThroughs.Inc() + metrics.ProxyPassThroughsToClient.Inc() metrics.ProxyPassThroughTerminations.Inc() metrics.BytesSentToClient.Observe(float64(modReceived)) metrics.TotalTrafficBytes.Observe(float64(modReceived)) @@ -357,6 +357,8 @@ func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError { } span.AddEvent("Ran the OnTrafficToServer hooks") + metrics.ProxyPassThroughsToServer.Inc() + return nil } @@ -459,7 +461,7 @@ func (pr *Proxy) PassThroughToClient(conn net.Conn) *gerr.GatewayDError { span.RecordError(errVerdict) } - metrics.ProxyPassThroughs.Inc() + metrics.ProxyPassThroughsToClient.Inc() return errVerdict } @@ -501,8 +503,16 @@ func (pr *Proxy) Shutdown() { defer span.End() pr.availableConnections.ForEach(func(key, value interface{}) bool { - if cl, ok := value.(*Client); ok { - cl.Close() + if client, ok := value.(*Client); ok { + if client.IsConnected() { + // This will stop all the Conn.Read() and Conn.Write() calls. + // Ref: https://groups.google.com/g/golang-nuts/c/VPVWFrpIEyo + if err := client.Conn.SetDeadline(time.Now()); err != nil { + pr.logger.Error().Err(err).Msg("Error setting the deadline") + span.RecordError(err) + } + client.Close() + } } return true }) @@ -511,14 +521,27 @@ func (pr *Proxy) Shutdown() { pr.busyConnections.ForEach(func(key, value interface{}) bool { if conn, ok := key.(net.Conn); ok { + // This will stop all the Conn.Read() and Conn.Write() calls. + if err := conn.SetDeadline(time.Now()); err != nil { + pr.logger.Error().Err(err).Msg("Error setting the deadline") + span.RecordError(err) + } conn.Close() } - if cl, ok := value.(*Client); ok { - cl.Close() + if client, ok := value.(*Client); ok { + if client != nil { + // This will stop all the Conn.Read() and Conn.Write() calls. + if err := client.Conn.SetDeadline(time.Now()); err != nil { + pr.logger.Error().Err(err).Msg("Error setting the deadline") + span.RecordError(err) + } + client.Close() + } } return true }) pr.busyConnections.Clear() + pr.scheduler.Stop() pr.scheduler.Clear() pr.logger.Debug().Msg("All busy connections have been closed") } diff --git a/network/server.go b/network/server.go index 3280fc30..644b62e2 100644 --- a/network/server.go +++ b/network/server.go @@ -245,13 +245,17 @@ func (s *Server) OnTraffic(conn net.Conn, stopConnection chan struct{}) Action { // Pass the traffic from the client to server. // If there is an error, log it and close the connection. - go func(s *Server, conn net.Conn, stopConnection chan struct{}) { + go func(server *Server, conn net.Conn, stopConnection chan struct{}) { for { - s.logger.Trace().Msg("Passing through traffic from client to server") - if err := s.proxy.PassThroughToServer(conn); err != nil { - s.logger.Trace().Err(err).Msg("Failed to pass through traffic") + server.logger.Trace().Msg("Passing through traffic from client to server") + if err := server.proxy.PassThroughToServer(conn); err != nil { + server.logger.Trace().Err(err).Msg("Failed to pass through traffic") span.RecordError(err) - stopConnection <- struct{}{} + server.engine.mu.Lock() + if server.Status == config.Stopped { + stopConnection <- struct{}{} + } + server.engine.mu.Unlock() break } } @@ -259,13 +263,17 @@ func (s *Server) OnTraffic(conn net.Conn, stopConnection chan struct{}) Action { // Pass the traffic from the server to client. // If there is an error, log it and close the connection. - go func(s *Server, conn net.Conn, stopConnection chan struct{}) { + go func(server *Server, conn net.Conn, stopConnection chan struct{}) { for { - s.logger.Debug().Msg("Passing through traffic from server to client") - if err := s.proxy.PassThroughToClient(conn); err != nil { - s.logger.Trace().Err(err).Msg("Failed to pass through traffic") + server.logger.Debug().Msg("Passing through traffic from server to client") + if err := server.proxy.PassThroughToClient(conn); err != nil { + server.logger.Trace().Err(err).Msg("Failed to pass through traffic") span.RecordError(err) - stopConnection <- struct{}{} + server.engine.mu.Lock() + if server.Status == config.Stopped { + stopConnection <- struct{}{} + } + server.engine.mu.Unlock() break } } @@ -279,7 +287,7 @@ func (s *Server) OnShutdown(Engine) { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnShutdown") defer span.End() - s.logger.Debug().Msg("GatewayD is shutting down...") + s.logger.Debug().Msg("GatewayD is shutting down") pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.pluginTimeout) defer cancel() @@ -376,7 +384,7 @@ func (s *Server) Run() error { // Start the server. origErr := Run(s.Network, addr, s) - if origErr != nil { + if origErr != nil && origErr.Unwrap() != nil { s.logger.Error().Err(origErr).Msg("Failed to start server") span.RecordError(origErr) return gerr.ErrFailedToStartServer.Wrap(origErr) diff --git a/network/server_test.go b/network/server_test.go index 3e646ac6..b40dd6eb 100644 --- a/network/server_test.go +++ b/network/server_test.go @@ -2,6 +2,7 @@ package network import ( "bufio" + "bytes" "context" "errors" "io" @@ -23,11 +24,9 @@ import ( // TestRunServer tests an entire server run with a single client connection and hooks. func TestRunServer(t *testing.T) { errs := make(chan error) - defer close(errs) logger := logging.NewLogger(context.Background(), logging.LoggerConfig{ Output: []config.LogOutput{ - config.Console, config.File, }, TimeFormat: zerolog.TimeFormatUnix, @@ -57,9 +56,10 @@ func TestRunServer(t *testing.T) { errs <- errors.New("request is nil") //nolint:goerr113 } - logger.Info().Msg("Ingress traffic") if req, ok := paramsMap["request"].([]byte); ok { - assert.Equal(t, CreatePgStartupPacket(), req) + if !bytes.Equal(req, CreatePgStartupPacket()) { + errs <- errors.New("request does not match") //nolint:goerr113 + } } else { errs <- errors.New("request is not a []byte") //nolint:goerr113 } @@ -81,7 +81,9 @@ func TestRunServer(t *testing.T) { logger.Info().Msg("Ingress traffic") if req, ok := paramsMap["request"].([]byte); ok { - assert.Equal(t, CreatePgStartupPacket(), req) + if !bytes.Equal(req, CreatePgStartupPacket()) { + errs <- errors.New("request does not match") //nolint:goerr113 + } } else { errs <- errors.New("request is not a []byte") //nolint:goerr113 } @@ -175,7 +177,7 @@ func TestRunServer(t *testing.T) { "127.0.0.1:15432", config.DefaultTickInterval, Option{ - EnableTicker: false, + EnableTicker: true, }, proxy, logger, @@ -185,16 +187,21 @@ func TestRunServer(t *testing.T) { assert.NotNil(t, server) stop := make(chan struct{}) - defer close(stop) var waitGroup sync.WaitGroup waitGroup.Add(1) - go func(t *testing.T, server *Server, stop chan struct{}, waitGroup *sync.WaitGroup) { + go func(t *testing.T, server *Server, pluginRegistry *plugin.Registry, stop chan struct{}, waitGroup *sync.WaitGroup) { t.Helper() for { select { case <-stop: + server.Shutdown() + pluginRegistry.Shutdown() + + // Wait for the server to stop. + time.Sleep(100 * time.Millisecond) + // Read the log file and check if the log file contains the expected log messages. if _, err := os.Stat("server_test.log"); err == nil { logFile, err := os.Open("server_test.log") @@ -213,22 +220,21 @@ func TestRunServer(t *testing.T) { assert.Contains(t, logLines, "GatewayD is ticking...", "GatewayD should be ticking") assert.Contains(t, logLines, "Ingress traffic", "Ingress traffic should be logged") assert.Contains(t, logLines, "Egress traffic", "Egress traffic should be logged") - assert.Contains(t, logLines, "GatewayD is shutting down...", "GatewayD should be shutting down") + assert.Contains(t, logLines, "GatewayD is shutting down", "GatewayD should be shutting down") assert.NoError(t, os.Remove("server_test.log")) - server.Shutdown() } + waitGroup.Done() return - case err := <-errs: + case <-errs: server.Shutdown() - t.Log(err) - t.Fail() + pluginRegistry.Shutdown() waitGroup.Done() return default: //nolint:staticcheck } } - }(t, server, stop, &waitGroup) + }(t, server, pluginRegistry, stop, &waitGroup) waitGroup.Add(1) go func(t *testing.T, server *Server, errs chan error, waitGroup *sync.WaitGroup) { @@ -285,18 +291,17 @@ func TestRunServer(t *testing.T) { assert.Equal(t, 2, proxy.availableConnections.Size()) assert.Equal(t, 1, proxy.busyConnections.Size()) + client.Close() // Test Prometheus metrics. CollectAndComparePrometheusMetrics(t) - - client.Close() - time.Sleep(100 * time.Millisecond) - stop <- struct{}{} - waitGroup.Done() - return + break } time.Sleep(100 * time.Millisecond) } + stop <- struct{}{} + close(stop) + waitGroup.Done() }(t, server, proxy, stop, &waitGroup) waitGroup.Wait() diff --git a/plugin/plugin_registry.go b/plugin/plugin_registry.go index b975d86e..e8d0d993 100644 --- a/plugin/plugin_registry.go +++ b/plugin/plugin_registry.go @@ -271,6 +271,8 @@ func (reg *Registry) Run( _, span := otel.Tracer(config.TracerName).Start(reg.ctx, "Run") defer span.End() + metrics.PluginHooksExecuted.Inc() + if ctx == nil { return nil, gerr.ErrNilContext } @@ -377,8 +379,6 @@ func (reg *Registry) Run( delete(reg.hooks[hookName], priority) } - metrics.PluginHooksExecuted.Inc() - return returnVal.AsMap(), nil }