diff --git a/cmd/run.go b/cmd/run.go
index 7aba4a41..03653c00 100644
--- a/cmd/run.go
+++ b/cmd/run.go
@@ -98,11 +98,12 @@ func StopGracefully(
 		}
 	}
 
-	logger.Info().Msg("Stopping GatewayD")
-	span.AddEvent("Stopping GatewayD", trace.WithAttributes(
+	logger.Info().Msg("GatewayD is shutting down")
+	span.AddEvent("GatewayD is shutting down", trace.WithAttributes(
 		attribute.String("signal", signal),
 	))
 	if healthCheckScheduler != nil {
+		healthCheckScheduler.Stop()
 		healthCheckScheduler.Clear()
 		logger.Info().Msg("Stopped health check scheduler")
 		span.AddEvent("Stopped health check scheduler")
@@ -268,7 +269,7 @@ var runCmd = &cobra.Command{
 		startDelay := time.Now().Add(conf.Plugin.HealthCheckPeriod)
 		if _, err := healthCheckScheduler.Every(
 			conf.Plugin.HealthCheckPeriod).SingletonMode().StartAt(startDelay).Do(func() {
-			_, span = otel.Tracer(config.TracerName).Start(ctx, "Run plugin health check")
+			_, span := otel.Tracer(config.TracerName).Start(ctx, "Run plugin health check")
 			defer span.End()
 
 			var plugins []string
@@ -749,9 +750,13 @@ var runCmd = &cobra.Command{
 	)
 	signalsCh := make(chan os.Signal, 1)
 	signal.Notify(signalsCh, signals...)
-	go func(pluginRegistry *plugin.Registry,
+	go func(pluginTimeoutCtx context.Context,
+		pluginRegistry *plugin.Registry,
 		logger zerolog.Logger,
 		servers map[string]*network.Server,
+		metricsMerger *metrics.Merger,
+		metricsServer *http.Server,
+		stopChan chan struct{},
 	) {
 		for sig := range signalsCh {
 			for _, s := range signals {
@@ -771,13 +776,14 @@ var runCmd = &cobra.Command{
 				}
 			}
 		}
-	}(pluginRegistry, logger, servers)
+	}(pluginTimeoutCtx, pluginRegistry, logger, servers, metricsMerger, metricsServer, stopChan)
 
 	_, span = otel.Tracer(config.TracerName).Start(runCtx, "Start servers")
 	// Start the server.
 	for name, server := range servers {
 		logger := loggers[name]
 		go func(
+			span trace.Span,
 			server *network.Server,
 			logger zerolog.Logger,
 			healthCheckScheduler *gocron.Scheduler,
@@ -797,7 +803,7 @@ var runCmd = &cobra.Command{
 				pluginRegistry.Shutdown()
 				os.Exit(gerr.FailedToStartServer)
 			}
-		}(server, logger, healthCheckScheduler, metricsMerger, pluginRegistry)
+		}(span, server, logger, healthCheckScheduler, metricsMerger, pluginRegistry)
 	}
 	span.End()
 
diff --git a/errors/errors.go b/errors/errors.go
index f16a6501..a8799bbc 100644
--- a/errors/errors.go
+++ b/errors/errors.go
@@ -23,6 +23,7 @@ const (
 	ErrCodeServerSendFailed
 	ErrCodeServerListenFailed
 	ErrCodeSplitHostPortFailed
+	ErrCodeCloseListenerFailed
 	ErrCodeAcceptFailed
 	ErrCodeReadFailed
 	ErrCodePutFailed
@@ -85,6 +86,8 @@ var (
 	ErrServerListenFailed = NewGatewayDError(
 		ErrCodeServerListenFailed, "couldn't listen on the server", nil)
 	ErrSplitHostPortFailed = NewGatewayDError(
 		ErrCodeSplitHostPortFailed, "failed to split host:port", nil)
+	ErrCloseListenerFailed = NewGatewayDError(
+		ErrCodeCloseListenerFailed, "failed to close listener", nil)
 	ErrAcceptFailed = NewGatewayDError(
 		ErrCodeAcceptFailed, "failed to accept connection", nil)
diff --git a/metrics/builtins.go b/metrics/builtins.go
index e56b978d..2e588a58 100644
--- a/metrics/builtins.go
+++ b/metrics/builtins.go
@@ -75,10 +75,15 @@ var (
 		Name:      "proxied_connections",
 		Help:      "Number of proxy connects",
 	})
-	ProxyPassThroughs = promauto.NewCounter(prometheus.CounterOpts{
+	ProxyPassThroughsToClient = promauto.NewCounter(prometheus.CounterOpts{
 		Namespace: Namespace,
-		Name:      "proxy_passthroughs_total",
-		Help:      "Number of successful proxy passthroughs",
+		Name:      "proxy_passthroughs_to_client_total",
+		Help:      "Number of successful proxy passthroughs from server to client",
+	})
+	ProxyPassThroughsToServer = promauto.NewCounter(prometheus.CounterOpts{
+		Namespace: Namespace,
+		Name:      "proxy_passthroughs_to_server_total",
+		Help:      "Number of successful proxy passthroughs from client to server",
 	})
 	ProxyPassThroughTerminations = promauto.NewCounter(prometheus.CounterOpts{
 		Namespace: Namespace,
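Note on the Stop()/Clear() pairing added in StopGracefully: with gocron v1, Clear() only removes the registered jobs, while Stop() halts the scheduler's ticker, so both calls are needed for a quiet shutdown. A minimal sketch of that lifecycle (illustrative only, not GatewayD code; the job body and intervals are made up):

```go
package main

import (
	"fmt"
	"time"

	"github.com/go-co-op/gocron"
)

func main() {
	scheduler := gocron.NewScheduler(time.UTC)

	// Mirrors the health-check registration: delayed start, singleton mode.
	startDelay := time.Now().Add(time.Second)
	if _, err := scheduler.Every(time.Second).SingletonMode().StartAt(startDelay).Do(func() {
		fmt.Println("health check")
	}); err != nil {
		panic(err)
	}

	scheduler.StartAsync()
	time.Sleep(3 * time.Second)

	// Shutdown order matching StopGracefully: halt the ticker, then drop the jobs.
	scheduler.Stop()
	scheduler.Clear()
}
```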
diff --git a/network/client.go b/network/client.go
index de2ef2a8..9cff7a9d 100644
--- a/network/client.go
+++ b/network/client.go
@@ -5,6 +5,7 @@ import (
 	"context"
 	"fmt"
 	"net"
+	"sync/atomic"
 	"time"
 
 	"github.com/gatewayd-io/gatewayd/config"
@@ -26,8 +27,9 @@ type IClient interface {
 type Client struct {
 	net.Conn
 
-	logger zerolog.Logger
-	ctx    context.Context //nolint:containedctx
+	logger    zerolog.Logger
+	ctx       context.Context //nolint:containedctx
+	connected atomic.Bool
 
 	TCPKeepAlive       bool
 	TCPKeepAlivePeriod time.Duration
@@ -53,6 +55,7 @@ func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog.
 		return nil
 	}
 
+	client.connected.Store(false)
 	client.logger = logger
 
 	// Try to resolve the address and log an error if it can't be resolved.
@@ -87,6 +90,7 @@ func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog.
 	}
 
 	client.Conn = conn
+	client.connected.Store(true)
 
 	// Set the TCP keep alive.
 	client.TCPKeepAlive = clientConfig.TCPKeepAlive
@@ -146,6 +150,11 @@ func (c *Client) Send(data []byte) (int, *gerr.GatewayDError) {
 	_, span := otel.Tracer(config.TracerName).Start(c.ctx, "Send")
 	defer span.End()
 
+	if !c.connected.Load() {
+		span.RecordError(gerr.ErrClientNotConnected)
+		return 0, gerr.ErrClientNotConnected
+	}
+
 	sent := 0
 	received := len(data)
 	for {
@@ -170,8 +179,6 @@ func (c *Client) Send(data []byte) (int, *gerr.GatewayDError) {
 		},
 	).Msg("Sent data to server")
 
-	metrics.BytesSentToServer.Observe(float64(sent))
-
 	return sent, nil
 }
 
@@ -180,6 +187,11 @@ func (c *Client) Receive() (int, []byte, *gerr.GatewayDError) {
 	_, span := otel.Tracer(config.TracerName).Start(c.ctx, "Receive")
 	defer span.End()
 
+	if !c.connected.Load() {
+		span.RecordError(gerr.ErrClientNotConnected)
+		return 0, nil, gerr.ErrClientNotConnected
+	}
+
 	var ctx context.Context
 	var cancel context.CancelFunc
 	if c.ReceiveTimeout > 0 {
@@ -192,26 +204,21 @@ func (c *Client) Receive() (int, []byte, *gerr.GatewayDError) {
 	var received int
 	buffer := bytes.NewBuffer(nil)
 	// Read the data in chunks.
-	select { //nolint:gosimple
-	case <-time.After(time.Millisecond):
-		for ctx.Err() == nil {
-			chunk := make([]byte, c.ReceiveChunkSize)
-			read, err := c.Conn.Read(chunk)
-			if err != nil {
-				c.logger.Error().Err(err).Msg("Couldn't receive data from the server")
-				span.RecordError(err)
-				metrics.BytesReceivedFromServer.Observe(float64(received))
-				return received, buffer.Bytes(), gerr.ErrClientReceiveFailed.Wrap(err)
-			}
-			received += read
-			buffer.Write(chunk[:read])
+	for ctx.Err() == nil {
+		chunk := make([]byte, c.ReceiveChunkSize)
+		read, err := c.Conn.Read(chunk)
+		if err != nil {
+			c.logger.Error().Err(err).Msg("Couldn't receive data from the server")
+			span.RecordError(err)
+			return received, buffer.Bytes(), gerr.ErrClientReceiveFailed.Wrap(err)
+		}
+		received += read
+		buffer.Write(chunk[:read])
 
-			if read == 0 || read < c.ReceiveChunkSize {
-				break
-			}
+		if read == 0 || read < c.ReceiveChunkSize {
+			break
 		}
 	}
-	metrics.BytesReceivedFromServer.Observe(float64(received))
 	return received, buffer.Bytes(), nil
 }
 
@@ -220,6 +227,12 @@ func (c *Client) Close() {
 	_, span := otel.Tracer(config.TracerName).Start(c.ctx, "Close")
 	defer span.End()
 
+	// Set the deadline to now so that the connection is closed immediately.
+	if err := c.Conn.SetDeadline(time.Now()); err != nil {
+		c.logger.Error().Err(err).Msg("Failed to set deadline")
+		span.RecordError(err)
+	}
+
 	c.logger.Debug().Str("address", c.Address).Msg("Closing connection to server")
 	if c.Conn != nil {
 		c.Conn.Close()
@@ -228,6 +241,7 @@ func (c *Client) Close() {
 	c.Conn = nil
 	c.Address = ""
 	c.Network = ""
+	c.connected.Store(false)
 
 	metrics.ServerConnections.Dec()
 }
@@ -257,20 +271,15 @@ func (c *Client) IsConnected() bool {
 		return false
 	}
 
-	if n, err := c.Read([]byte{}); n == 0 && err != nil {
-		c.logger.Debug().Fields(
-			map[string]interface{}{
-				"address": c.Address,
-				"reason":  "read 0 bytes",
-			}).Msg("Connection to server is closed")
-		return false
-	}
-
-	return true
+	return c.connected.Load()
 }
 
 // RemoteAddr returns the remote address of the client safely.
 func (c *Client) RemoteAddr() string {
+	if !c.connected.Load() {
+		return ""
+	}
+
 	if c.Conn != nil && c.Conn.RemoteAddr() != nil {
 		return c.Conn.RemoteAddr().String()
 	}
@@ -280,6 +289,10 @@ func (c *Client) RemoteAddr() string {
 
 // LocalAddr returns the local address of the client safely.
 func (c *Client) LocalAddr() string {
+	if !c.connected.Load() {
+		return ""
+	}
+
 	if c.Conn != nil && c.Conn.LocalAddr() != nil {
 		return c.Conn.LocalAddr().String()
 	}
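The connected flag threaded through client.go follows a common guard pattern: flip an atomic.Bool (Go 1.19+) to true only after the dial succeeds, flip it back before teardown, and have Send/Receive fail fast instead of touching a nil or half-closed net.Conn. A self-contained sketch of the pattern (names are illustrative):

```go
package main

import (
	"errors"
	"net"
	"sync/atomic"
)

var errNotConnected = errors.New("client is not connected")

type guardedConn struct {
	conn      net.Conn
	connected atomic.Bool
}

func dial(address string) (*guardedConn, error) {
	gc := &guardedConn{}
	conn, err := net.Dial("tcp", address)
	if err != nil {
		return nil, err
	}
	gc.conn = conn
	gc.connected.Store(true) // only mark connected once Dial has succeeded
	return gc, nil
}

func (gc *guardedConn) send(data []byte) (int, error) {
	if !gc.connected.Load() {
		return 0, errNotConnected // fail fast instead of writing to a dead conn
	}
	return gc.conn.Write(data)
}

func (gc *guardedConn) close() {
	gc.connected.Store(false) // flip the flag before tearing the conn down
	gc.conn.Close()
	gc.conn = nil
}
```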
diff --git a/network/engine.go b/network/engine.go
index c9800e1a..af480250 100644
--- a/network/engine.go
+++ b/network/engine.go
@@ -5,6 +5,7 @@ import (
 	"net"
 	"strconv"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/gatewayd-io/gatewayd/config"
@@ -35,6 +36,7 @@ type Engine struct {
 	host        string
 	port        int
 	connections uint32
+	running     *atomic.Bool
 	stopServer  chan struct{}
 	mu          *sync.RWMutex
 }
@@ -49,7 +51,15 @@ func (engine *Engine) Stop(ctx context.Context) error {
 	_, cancel := context.WithDeadline(ctx, time.Now().Add(config.DefaultEngineStopTimeout))
 	defer cancel()
 
+	engine.running.Store(false)
+	if engine.listener != nil {
+		if err := engine.listener.Close(); err != nil {
+			engine.stopServer <- struct{}{}
+			return gerr.ErrCloseListenerFailed.Wrap(err)
+		}
+	}
 	engine.stopServer <- struct{}{}
+	close(engine.stopServer)
 	return nil
 }
 
@@ -59,6 +69,7 @@ func Run(network, address string, server *Server) *gerr.GatewayDError {
 		connections: 0,
 		stopServer:  make(chan struct{}),
 		mu:          &sync.RWMutex{},
+		running:     &atomic.Bool{},
 	}
 
 	if action := server.OnBoot(server.engine); action != None {
@@ -71,7 +82,6 @@ func Run(network, address string, server *Server) *gerr.GatewayDError {
 		server.logger.Error().Err(err).Msg("Server failed to start listening")
 		return gerr.ErrServerListenFailed.Wrap(err)
 	}
-	defer server.engine.listener.Close()
 
 	if server.engine.listener == nil {
 		server.logger.Error().Msg("Server is not properly initialized")
@@ -102,56 +112,79 @@ func Run(network, address string, server *Server) *gerr.GatewayDError {
 		}
 
 		for {
-			interval, action := server.OnTick()
-			if action == Shutdown {
-				server.OnShutdown(server.engine)
-				return
-			}
-			if interval == time.Duration(0) {
+			select {
+			case <-server.engine.stopServer:
 				return
+			default:
+				interval, action := server.OnTick()
+				if action == Shutdown {
+					server.OnShutdown(server.engine)
+					return
+				}
+				if interval == time.Duration(0) {
+					return
+				}
+				time.Sleep(interval)
 			}
-			time.Sleep(interval)
 		}
 	}(server)
 
-	for {
-		conn, err := server.engine.listener.Accept()
-		if err != nil {
-			server.logger.Error().Err(err).Msg("Failed to accept connection")
-			return gerr.ErrAcceptFailed.Wrap(err)
-		}
+	server.engine.running.Store(true)
 
-		if out, action := server.OnOpen(conn); action != None {
-			if _, err := conn.Write(out); err != nil {
-				server.logger.Error().Err(err).Msg("Failed to write to connection")
-			}
-			conn.Close()
-			if action == Shutdown {
-				server.OnShutdown(server.engine)
-				return nil
-			}
-		}
-		server.engine.mu.Lock()
-		server.engine.connections++
-		server.engine.mu.Unlock()
-
-		// For every new connection, a new unbuffered channel is created to help
-		// stop the proxy, recycle the server connection and close stale connections.
-		stopConnection := make(chan struct{})
-		go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
-			if action := server.OnTraffic(conn, stopConnection); action == Close {
-				return
+	for {
+		select {
+		case <-server.engine.stopServer:
+			server.logger.Info().Msg("Server stopped")
+			return nil
+		default:
+			conn, err := server.engine.listener.Accept()
+			if err != nil {
+				if !server.engine.running.Load() {
+					return nil
+				}
+				server.logger.Error().Err(err).Msg("Failed to accept connection")
+				return gerr.ErrAcceptFailed.Wrap(err)
 			}
-		}(server, conn, stopConnection)
 
-		go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
-			<-stopConnection
-			server.engine.mu.Lock()
-			server.engine.connections--
-			server.engine.mu.Unlock()
-			if action := server.OnClose(conn, err); action == Close {
-				return
-			}
-		}(server, conn, stopConnection)
+			if out, action := server.OnOpen(conn); action != None {
+				if _, err := conn.Write(out); err != nil {
+					server.logger.Error().Err(err).Msg("Failed to write to connection")
+				}
+				conn.Close()
+				if action == Shutdown {
+					server.OnShutdown(server.engine)
+					return nil
+				}
+			}
+			server.engine.mu.Lock()
+			server.engine.connections++
+			server.engine.mu.Unlock()
+
+			// For every new connection, a new unbuffered channel is created to help
+			// stop the proxy, recycle the server connection and close stale connections.
+			stopConnection := make(chan struct{})
+			go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
+				if action := server.OnTraffic(conn, stopConnection); action == Close {
+					return
+				}
+			}(server, conn, stopConnection)
+
+			go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
+				for {
+					select {
+					case <-stopConnection:
+						server.engine.mu.Lock()
+						server.engine.connections--
+						server.engine.mu.Unlock()
+						if action := server.OnClose(conn, err); action == Close {
+							return
+						}
+						return
+					case <-server.engine.stopServer:
+						return
+					}
+				}
+			}(server, conn, stopConnection)
+		}
 	}
 }
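The engine changes implement the standard Go recipe for stopping a blocking accept loop: Stop() flips a running flag and closes the listener, which forces the pending Accept() to return an error; the loop then consults the flag to distinguish a deliberate shutdown from a genuine accept failure. A condensed sketch (illustrative types and names):

```go
package main

import (
	"net"
	"sync/atomic"
)

type server struct {
	listener net.Listener
	running  atomic.Bool
}

func (s *server) stop() {
	s.running.Store(false)
	s.listener.Close() // unblocks the pending Accept() with an error
}

func (s *server) serve() error {
	s.running.Store(true)
	for {
		conn, err := s.listener.Accept()
		if err != nil {
			if !s.running.Load() {
				return nil // listener was closed by stop(): clean shutdown
			}
			return err // genuine accept failure
		}
		go func(c net.Conn) {
			defer c.Close()
			// handle the connection here
		}(conn)
	}
}
```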
diff --git a/network/network_helpers_test.go b/network/network_helpers_test.go
index 766c0315..54d68ad2 100644
--- a/network/network_helpers_test.go
+++ b/network/network_helpers_test.go
@@ -91,8 +91,10 @@ func CollectAndComparePrometheusMetrics(t *testing.T) {
 	# TYPE gatewayd_proxy_health_checks_total counter
 	# HELP gatewayd_proxy_passthrough_terminations_total Number of proxy passthrough terminations by plugins
 	# TYPE gatewayd_proxy_passthrough_terminations_total counter
-	# HELP gatewayd_proxy_passthroughs_total Number of successful proxy passthroughs
-	# TYPE gatewayd_proxy_passthroughs_total counter
+	# HELP gatewayd_proxy_passthroughs_to_client_total Number of successful proxy passthroughs from server to client
+	# TYPE gatewayd_proxy_passthroughs_to_client_total counter
+	# HELP gatewayd_proxy_passthroughs_to_server_total Number of successful proxy passthroughs from client to server
+	# TYPE gatewayd_proxy_passthroughs_to_server_total counter
 	# HELP gatewayd_server_connections Number of server connections
 	# TYPE gatewayd_server_connections gauge
 	# HELP gatewayd_server_ticks_fired_total Total number of server ticks fired
@@ -105,12 +107,12 @@ func CollectAndComparePrometheusMetrics(t *testing.T) {
 	want = metadata + `
 	gatewayd_bytes_received_from_client_sum 67
 	gatewayd_bytes_received_from_client_count 1
-	gatewayd_bytes_received_from_server_sum 96
-	gatewayd_bytes_received_from_server_count 4
+	gatewayd_bytes_received_from_server_sum 24
+	gatewayd_bytes_received_from_server_count 1
 	gatewayd_bytes_sent_to_client_sum 24
 	gatewayd_bytes_sent_to_client_count 1
-	gatewayd_bytes_sent_to_server_sum 282
-	gatewayd_bytes_sent_to_server_count 5
+	gatewayd_bytes_sent_to_server_sum 67
+	gatewayd_bytes_sent_to_server_count 1
 	gatewayd_client_connections 1
 	gatewayd_plugin_hooks_executed_total 11
 	gatewayd_plugin_hooks_registered_total 0
@@ -118,8 +120,9 @@ func CollectAndComparePrometheusMetrics(t *testing.T) {
 	gatewayd_proxied_connections 1
 	gatewayd_proxy_health_checks_total 0
 	gatewayd_proxy_passthrough_terminations_total 0
-	gatewayd_proxy_passthroughs_total 1
-	gatewayd_server_connections 5
+	gatewayd_proxy_passthroughs_to_client_total 1
+	gatewayd_proxy_passthroughs_to_server_total 1
+	gatewayd_server_connections 4
 	gatewayd_traffic_bytes_sum 182
 	gatewayd_traffic_bytes_count 4
 	gatewayd_server_ticks_fired_total 1
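The HELP lines above now mirror the Help strings in builtins.go word for word; that matters because testutil compares the full exposition format, HELP and TYPE lines included, and a mismatch fails the test. A small standalone sketch of the mechanism (not the actual helper):

```go
package main

import (
	"strings"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/testutil"
)

func main() {
	counter := prometheus.NewCounter(prometheus.CounterOpts{
		Namespace: "gatewayd",
		Name:      "proxy_passthroughs_to_server_total",
		Help:      "Number of successful proxy passthroughs from client to server",
	})
	counter.Inc()

	// The expected text must match HELP, TYPE, and value exactly.
	expected := `# HELP gatewayd_proxy_passthroughs_to_server_total Number of successful proxy passthroughs from client to server
# TYPE gatewayd_proxy_passthroughs_to_server_total counter
gatewayd_proxy_passthroughs_to_server_total 1
`
	if err := testutil.CollectAndCompare(counter, strings.NewReader(expected)); err != nil {
		panic(err)
	}
}
```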
diff --git a/network/proxy.go b/network/proxy.go
index 66bf065d..02f2605c 100644
--- a/network/proxy.go
+++ b/network/proxy.go
@@ -316,7 +316,7 @@ func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError {
 	// If the hook wants to terminate the connection, do it.
 	if pr.shouldTerminate(result) {
 		if modResponse, modReceived := pr.getPluginModifiedResponse(result); modResponse != nil {
-			metrics.ProxyPassThroughs.Inc()
+			metrics.ProxyPassThroughsToClient.Inc()
 			metrics.ProxyPassThroughTerminations.Inc()
 			metrics.BytesSentToClient.Observe(float64(modReceived))
 			metrics.TotalTrafficBytes.Observe(float64(modReceived))
@@ -357,6 +357,8 @@ func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError {
 	}
 	span.AddEvent("Ran the OnTrafficToServer hooks")
 
+	metrics.ProxyPassThroughsToServer.Inc()
+
 	return nil
 }
 
@@ -459,7 +461,7 @@ func (pr *Proxy) PassThroughToClient(conn net.Conn) *gerr.GatewayDError {
 		span.RecordError(errVerdict)
 	}
 
-	metrics.ProxyPassThroughs.Inc()
+	metrics.ProxyPassThroughsToClient.Inc()
 
 	return errVerdict
 }
@@ -501,8 +503,16 @@ func (pr *Proxy) Shutdown() {
 	defer span.End()
 
 	pr.availableConnections.ForEach(func(key, value interface{}) bool {
-		if cl, ok := value.(*Client); ok {
-			cl.Close()
+		if client, ok := value.(*Client); ok {
+			if client.IsConnected() {
+				// This will stop all the Conn.Read() and Conn.Write() calls.
+				// Ref: https://groups.google.com/g/golang-nuts/c/VPVWFrpIEyo
+				if err := client.Conn.SetDeadline(time.Now()); err != nil {
+					pr.logger.Error().Err(err).Msg("Error setting the deadline")
+					span.RecordError(err)
+				}
+				client.Close()
+			}
 		}
 		return true
 	})
@@ -511,14 +521,27 @@ func (pr *Proxy) Shutdown() {
 
 	pr.busyConnections.ForEach(func(key, value interface{}) bool {
 		if conn, ok := key.(net.Conn); ok {
+			// This will stop all the Conn.Read() and Conn.Write() calls.
+			if err := conn.SetDeadline(time.Now()); err != nil {
+				pr.logger.Error().Err(err).Msg("Error setting the deadline")
+				span.RecordError(err)
+			}
 			conn.Close()
 		}
-		if cl, ok := value.(*Client); ok {
-			cl.Close()
+		if client, ok := value.(*Client); ok {
+			if client != nil {
+				// This will stop all the Conn.Read() and Conn.Write() calls.
+				if err := client.Conn.SetDeadline(time.Now()); err != nil {
+					pr.logger.Error().Err(err).Msg("Error setting the deadline")
+					span.RecordError(err)
+				}
+				client.Close()
+			}
 		}
 		return true
 	})
 	pr.busyConnections.Clear()
+	pr.scheduler.Stop()
 	pr.scheduler.Clear()
 	pr.logger.Debug().Msg("All busy connections have been closed")
 }
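The SetDeadline(time.Now()) calls sprinkled through Shutdown() are the usual trick for interrupting blocked I/O: a deadline in the past makes any goroutine parked in Read or Write return immediately with a timeout error, so Close() does not have to wait for in-flight calls (same idea as the golang-nuts thread referenced above). In miniature (illustrative):

```go
package main

import (
	"fmt"
	"net"
	"time"
)

func forceClose(conn net.Conn) {
	// Wake up any goroutine blocked in Read()/Write() before closing.
	if err := conn.SetDeadline(time.Now()); err != nil {
		fmt.Println("failed to set deadline:", err)
	}
	conn.Close()
}
```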
diff --git a/network/proxy_test.go b/network/proxy_test.go
index 14f7a0ec..ec072740 100644
--- a/network/proxy_test.go
+++ b/network/proxy_test.go
@@ -23,8 +23,8 @@ func TestNewProxy(t *testing.T) {
 		NoColor: true,
 	})
 
 	// Create a connection pool
-	pool := pool.NewPool(context.Background(), config.EmptyPoolCapacity)
+	newPool := pool.NewPool(context.Background(), config.EmptyPoolCapacity)
 
 	client := NewClient(
 		context.Background(),
@@ -38,13 +38,13 @@ func TestNewProxy(t *testing.T) {
 			TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod,
 		},
 		logger)
 
-	err := pool.Put(client.ID, client)
+	err := newPool.Put(client.ID, client)
 	assert.Nil(t, err)
 
 	// Create a proxy with a fixed buffer pool
 	proxy := NewProxy(
 		context.Background(),
-		pool,
+		newPool,
 		plugin.NewRegistry(
 			context.Background(),
 			config.Loose,
@@ -86,13 +86,13 @@ func TestNewProxyElastic(t *testing.T) {
 		NoColor: true,
 	})
 
 	// Create a connection pool
-	pool := pool.NewPool(context.Background(), config.EmptyPoolCapacity)
+	newPool := pool.NewPool(context.Background(), config.EmptyPoolCapacity)
 
 	// Create a proxy with an elastic buffer pool
 	proxy := NewProxy(
 		context.Background(),
-		pool,
+		newPool,
 		plugin.NewRegistry(
 			context.Background(),
 			config.Loose,
@@ -136,14 +136,14 @@ func BenchmarkNewProxy(b *testing.B) {
 		NoColor: true,
 	})
 
 	// Create a connection pool
-	pool := pool.NewPool(context.Background(), config.EmptyPoolCapacity)
+	newPool := pool.NewPool(context.Background(), config.EmptyPoolCapacity)
 
 	// Create a proxy with a fixed buffer pool
 	for i := 0; i < b.N; i++ {
 		proxy := NewProxy(
 			context.Background(),
-			pool,
+			newPool,
 			plugin.NewRegistry(
 				context.Background(),
 				config.Loose,
@@ -172,14 +172,14 @@ func BenchmarkNewProxyElastic(b *testing.B) {
 		NoColor: true,
 	})
 
 	// Create a connection pool
-	pool := pool.NewPool(context.Background(), config.EmptyPoolCapacity)
+	newPool := pool.NewPool(context.Background(), config.EmptyPoolCapacity)
 
 	// Create a proxy with an elastic buffer pool
 	for i := 0; i < b.N; i++ {
 		proxy := NewProxy(
 			context.Background(),
-			pool,
+			newPool,
 			plugin.NewRegistry(
 				context.Background(),
 				config.Loose,
@@ -216,8 +216,8 @@ func BenchmarkProxyConnectDisconnect(b *testing.B) {
 		NoColor: true,
 	})
 
 	// Create a connection pool
-	pool := pool.NewPool(context.Background(), 1)
+	newPool := pool.NewPool(context.Background(), 1)
 
 	clientConfig := config.Client{
 		Network: "tcp",
@@ -229,12 +229,12 @@ func BenchmarkProxyConnectDisconnect(b *testing.B) {
 		TCPKeepAlive:       false,
 		TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod,
 	}
-	pool.Put("client", NewClient(context.Background(), &clientConfig, logger)) //nolint:errcheck
+	newPool.Put("client", NewClient(context.Background(), &clientConfig, logger)) //nolint:errcheck
 
 	// Create a proxy with a fixed buffer pool
 	proxy := NewProxy(
 		context.Background(),
-		pool,
+		newPool,
 		plugin.NewRegistry(
 			context.Background(),
 			config.Loose,
@@ -270,8 +270,8 @@ func BenchmarkProxyPassThrough(b *testing.B) {
 		NoColor: true,
 	})
 
 	// Create a connection pool
-	pool := pool.NewPool(context.Background(), 1)
+	newPool := pool.NewPool(context.Background(), 1)
 
 	clientConfig := config.Client{
 		Network: "tcp",
@@ -283,12 +283,12 @@ func BenchmarkProxyPassThrough(b *testing.B) {
 		TCPKeepAlive:       false,
 		TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod,
 	}
-	pool.Put("client", NewClient(context.Background(), &clientConfig, logger)) //nolint:errcheck
+	newPool.Put("client", NewClient(context.Background(), &clientConfig, logger)) //nolint:errcheck
 
 	// Create a proxy with a fixed buffer pool
 	proxy := NewProxy(
 		context.Background(),
-		pool,
+		newPool,
 		plugin.NewRegistry(
 			context.Background(),
 			config.Loose,
@@ -326,8 +326,8 @@ func BenchmarkProxyIsHealthyAndIsExhausted(b *testing.B) {
 		NoColor: true,
 	})
 
 	// Create a connection pool
-	pool := pool.NewPool(context.Background(), 1)
+	newPool := pool.NewPool(context.Background(), 1)
 
 	clientConfig := config.Client{
 		Network: "tcp",
@@ -340,12 +340,12 @@ func BenchmarkProxyIsHealthyAndIsExhausted(b *testing.B) {
 		TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod,
 	}
 	client := NewClient(context.Background(), &clientConfig, logger)
-	pool.Put("client", client) //nolint:errcheck
+	newPool.Put("client", client) //nolint:errcheck
 
 	// Create a proxy with a fixed buffer pool
 	proxy := NewProxy(
 		context.Background(),
-		pool,
+		newPool,
 		plugin.NewRegistry(
 			context.Background(),
 			config.Loose,
@@ -383,8 +383,8 @@ func BenchmarkProxyAvailableAndBusyConnections(b *testing.B) {
 		NoColor: true,
 	})
 
 	// Create a connection pool
-	pool := pool.NewPool(context.Background(), 1)
+	newPool := pool.NewPool(context.Background(), 1)
 
 	clientConfig := config.Client{
 		Network: "tcp",
@@ -397,12 +397,12 @@ func BenchmarkProxyAvailableAndBusyConnections(b *testing.B) {
 		TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod,
 	}
 	client := NewClient(context.Background(), &clientConfig, logger)
-	pool.Put("client", client) //nolint:errcheck
+	newPool.Put("client", client) //nolint:errcheck
 
 	// Create a proxy with a fixed buffer pool
 	proxy := NewProxy(
 		context.Background(),
-		pool,
+		newPool,
 		plugin.NewRegistry(
 			context.Background(),
 			config.Loose,
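The pool → newPool rename in these tests is more than cosmetic: the old local variable shadowed the imported pool package, so any later pool.NewPool call in the same scope would have resolved to the variable rather than the package. A minimal reproduction (the commented-out line is the one that stops compiling):

```go
package main

import (
	"context"

	"github.com/gatewayd-io/gatewayd/pool"
)

func main() {
	// The local name shadows the imported package from here on.
	pool := pool.NewPool(context.Background(), 1)
	_ = pool

	// This would no longer compile, because "pool" now refers to the
	// variable, not the package:
	// other := pool.NewPool(context.Background(), 1)
}
```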
diff --git a/network/server.go b/network/server.go
index 3280fc30..644b62e2 100644
--- a/network/server.go
+++ b/network/server.go
@@ -245,13 +245,17 @@ func (s *Server) OnTraffic(conn net.Conn, stopConnection chan struct{}) Action {
 
 	// Pass the traffic from the client to server.
 	// If there is an error, log it and close the connection.
-	go func(s *Server, conn net.Conn, stopConnection chan struct{}) {
+	go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
 		for {
-			s.logger.Trace().Msg("Passing through traffic from client to server")
-			if err := s.proxy.PassThroughToServer(conn); err != nil {
-				s.logger.Trace().Err(err).Msg("Failed to pass through traffic")
+			server.logger.Trace().Msg("Passing through traffic from client to server")
+			if err := server.proxy.PassThroughToServer(conn); err != nil {
+				server.logger.Trace().Err(err).Msg("Failed to pass through traffic")
 				span.RecordError(err)
-				stopConnection <- struct{}{}
+				server.engine.mu.Lock()
+				if server.Status == config.Stopped {
+					stopConnection <- struct{}{}
+				}
+				server.engine.mu.Unlock()
 				break
 			}
 		}
@@ -259,13 +263,17 @@ func (s *Server) OnTraffic(conn net.Conn, stopConnection chan struct{}) Action {
 
 	// Pass the traffic from the server to client.
 	// If there is an error, log it and close the connection.
-	go func(s *Server, conn net.Conn, stopConnection chan struct{}) {
+	go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
 		for {
-			s.logger.Debug().Msg("Passing through traffic from server to client")
-			if err := s.proxy.PassThroughToClient(conn); err != nil {
-				s.logger.Trace().Err(err).Msg("Failed to pass through traffic")
+			server.logger.Debug().Msg("Passing through traffic from server to client")
+			if err := server.proxy.PassThroughToClient(conn); err != nil {
+				server.logger.Trace().Err(err).Msg("Failed to pass through traffic")
 				span.RecordError(err)
-				stopConnection <- struct{}{}
+				server.engine.mu.Lock()
+				if server.Status == config.Stopped {
+					stopConnection <- struct{}{}
+				}
+				server.engine.mu.Unlock()
 				break
 			}
 		}
@@ -279,7 +287,7 @@ func (s *Server) OnShutdown(Engine) {
 	_, span := otel.Tracer("gatewayd").Start(s.ctx, "OnShutdown")
 	defer span.End()
 
-	s.logger.Debug().Msg("GatewayD is shutting down...")
+	s.logger.Debug().Msg("GatewayD is shutting down")
 	pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.pluginTimeout)
 	defer cancel()
 
@@ -376,7 +384,7 @@ func (s *Server) Run() error {
 
 	// Start the server.
 	origErr := Run(s.Network, addr, s)
-	if origErr != nil {
+	if origErr != nil && origErr.Unwrap() != nil {
 		s.logger.Error().Err(origErr).Msg("Failed to start server")
 		span.RecordError(origErr)
 		return gerr.ErrFailedToStartServer.Wrap(origErr)
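For context, OnTraffic sets up one pump goroutine per direction, each looping on its pass-through until it errors, and (after this change) signalling the stop channel only when the server is actually stopping. The shape of one pump, reduced to its essentials (the shouldStop callback stands in for the Status check and is illustrative):

```go
package main

import "net"

// startPump loops a one-directional pass-through until it fails, then
// signals teardown only if the server is shutting down.
func startPump(
	conn net.Conn,
	passThrough func(net.Conn) error,
	shouldStop func() bool,
	stopConnection chan<- struct{},
) {
	go func() {
		for {
			if err := passThrough(conn); err != nil {
				if shouldStop() {
					stopConnection <- struct{}{}
				}
				return
			}
		}
	}()
}
```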
diff --git a/network/server_test.go b/network/server_test.go
index 3e646ac6..4dbb7d24 100644
--- a/network/server_test.go
+++ b/network/server_test.go
@@ -2,6 +2,7 @@ package network
 
 import (
 	"bufio"
+	"bytes"
 	"context"
 	"errors"
 	"io"
@@ -23,11 +24,9 @@ import (
 // TestRunServer tests an entire server run with a single client connection and hooks.
 func TestRunServer(t *testing.T) {
 	errs := make(chan error)
-	defer close(errs)
 
 	logger := logging.NewLogger(context.Background(), logging.LoggerConfig{
 		Output: []config.LogOutput{
-			config.Console,
 			config.File,
 		},
 		TimeFormat: zerolog.TimeFormatUnix,
@@ -57,9 +56,10 @@ func TestRunServer(t *testing.T) {
 			errs <- errors.New("request is nil") //nolint:goerr113
 		}
 
-		logger.Info().Msg("Ingress traffic")
 		if req, ok := paramsMap["request"].([]byte); ok {
-			assert.Equal(t, CreatePgStartupPacket(), req)
+			if !bytes.Equal(req, CreatePgStartupPacket()) {
+				errs <- errors.New("request does not match") //nolint:goerr113
+			}
 		} else {
 			errs <- errors.New("request is not a []byte") //nolint:goerr113
 		}
@@ -81,7 +81,9 @@ func TestRunServer(t *testing.T) {
 
 		logger.Info().Msg("Ingress traffic")
 		if req, ok := paramsMap["request"].([]byte); ok {
-			assert.Equal(t, CreatePgStartupPacket(), req)
+			if !bytes.Equal(req, CreatePgStartupPacket()) {
+				errs <- errors.New("request does not match") //nolint:goerr113
+			}
 		} else {
 			errs <- errors.New("request is not a []byte") //nolint:goerr113
 		}
@@ -144,22 +146,22 @@ func TestRunServer(t *testing.T) {
 		TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod,
 	}
 
 	// Create a connection pool.
-	pool := pool.NewPool(context.Background(), 3)
+	newPool := pool.NewPool(context.Background(), 3)
 	client1 := NewClient(context.Background(), &clientConfig, logger)
-	err := pool.Put(client1.ID, client1)
+	err := newPool.Put(client1.ID, client1)
 	assert.Nil(t, err)
 	client2 := NewClient(context.Background(), &clientConfig, logger)
-	err = pool.Put(client2.ID, client2)
+	err = newPool.Put(client2.ID, client2)
 	assert.Nil(t, err)
 	client3 := NewClient(context.Background(), &clientConfig, logger)
-	err = pool.Put(client3.ID, client3)
+	err = newPool.Put(client3.ID, client3)
 	assert.Nil(t, err)
 
 	// Create a proxy with a fixed buffer pool.
 	proxy := NewProxy(
 		context.Background(),
-		pool,
+		newPool,
 		pluginRegistry,
 		false,
 		false,
@@ -175,7 +177,7 @@ func TestRunServer(t *testing.T) {
 		"127.0.0.1:15432",
 		config.DefaultTickInterval,
 		Option{
-			EnableTicker: false,
+			EnableTicker: true,
 		},
 		proxy,
 		logger,
@@ -185,21 +187,25 @@ func TestRunServer(t *testing.T) {
 	assert.NotNil(t, server)
 
 	stop := make(chan struct{})
-	defer close(stop)
 
 	var waitGroup sync.WaitGroup
 	waitGroup.Add(1)
-	go func(t *testing.T, server *Server, stop chan struct{}, waitGroup *sync.WaitGroup) {
+	go func(t *testing.T, server *Server, pluginRegistry *plugin.Registry, stop chan struct{}, waitGroup *sync.WaitGroup) {
 		t.Helper()
 		for {
 			select {
 			case <-stop:
+				server.Shutdown()
+				pluginRegistry.Shutdown()
+
+				// Wait for the server to stop.
+				time.Sleep(100 * time.Millisecond)
+
 				// Read the log file and check if the log file contains the expected log messages.
 				if _, err := os.Stat("server_test.log"); err == nil {
 					logFile, err := os.Open("server_test.log")
 					assert.NoError(t, err)
-					defer logFile.Close()
 
 					reader := bufio.NewReader(logFile)
 					assert.NotNil(t, reader)
@@ -207,28 +213,28 @@ func TestRunServer(t *testing.T) {
 					buffer, err := io.ReadAll(reader)
 					assert.NoError(t, err)
 					assert.Greater(t, len(buffer), 0) // The log file should not be empty.
+					assert.NoError(t, logFile.Close())
 
 					logLines := string(buffer)
 					assert.Contains(t, logLines, "GatewayD is running", "GatewayD should be running")
 					assert.Contains(t, logLines, "GatewayD is ticking...", "GatewayD should be ticking")
 					assert.Contains(t, logLines, "Ingress traffic", "Ingress traffic should be logged")
 					assert.Contains(t, logLines, "Egress traffic", "Egress traffic should be logged")
-					assert.Contains(t, logLines, "GatewayD is shutting down...", "GatewayD should be shutting down")
+					assert.Contains(t, logLines, "GatewayD is shutting down", "GatewayD should be shutting down")
 
 					assert.NoError(t, os.Remove("server_test.log"))
-					server.Shutdown()
 				}
+				waitGroup.Done()
 				return
-			case err := <-errs:
+			case <-errs:
 				server.Shutdown()
-				t.Log(err)
-				t.Fail()
+				pluginRegistry.Shutdown()
 				waitGroup.Done()
 				return
 			default: //nolint:staticcheck
 			}
 		}
-	}(t, server, stop, &waitGroup)
+	}(t, server, pluginRegistry, stop, &waitGroup)
 
 	waitGroup.Add(1)
 	go func(t *testing.T, server *Server, errs chan error, waitGroup *sync.WaitGroup) {
@@ -290,13 +296,13 @@ func TestRunServer(t *testing.T) {
 			CollectAndComparePrometheusMetrics(t)
 
 			client.Close()
-			time.Sleep(100 * time.Millisecond)
-			stop <- struct{}{}
-			waitGroup.Done()
-			return
+			break
 		}
 		time.Sleep(100 * time.Millisecond)
 	}
+		stop <- struct{}{}
+		close(stop)
+		waitGroup.Done()
 	}(t, server, proxy, stop, &waitGroup)
 
 	waitGroup.Wait()
diff --git a/plugin/plugin_registry.go b/plugin/plugin_registry.go
index b975d86e..e8d0d993 100644
--- a/plugin/plugin_registry.go
+++ b/plugin/plugin_registry.go
@@ -271,6 +271,8 @@ func (reg *Registry) Run(
 	_, span := otel.Tracer(config.TracerName).Start(reg.ctx, "Run")
 	defer span.End()
 
+	metrics.PluginHooksExecuted.Inc()
+
 	if ctx == nil {
 		return nil, gerr.ErrNilContext
 	}
@@ -377,8 +379,6 @@ func (reg *Registry) Run(
 		delete(reg.hooks[hookName], priority)
 	}
 
-	metrics.PluginHooksExecuted.Inc()
-
 	return returnVal.AsMap(), nil
 }
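Moving PluginHooksExecuted.Inc() to the top of Registry.Run means every invocation is counted, including calls that bail out early (for example on a nil context). Counter deltas like this are easy to assert with testutil.ToFloat64; a standalone sketch, not GatewayD's actual test (the Help string here is made up):

```go
package main

import (
	"fmt"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/testutil"
)

func main() {
	hooksExecuted := prometheus.NewCounter(prometheus.CounterOpts{
		Namespace: "gatewayd",
		Name:      "plugin_hooks_executed_total",
		Help:      "Number of plugin hook executions", // illustrative Help text
	})

	before := testutil.ToFloat64(hooksExecuted)
	hooksExecuted.Inc() // what Registry.Run now does on entry, before any early return
	after := testutil.ToFloat64(hooksExecuted)

	fmt.Println(after-before == 1) // true
}
```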