diff --git a/cmd/run.go b/cmd/run.go index 7aba4a41..dae3e339 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -98,8 +98,8 @@ func StopGracefully( } } - logger.Info().Msg("Stopping GatewayD") - span.AddEvent("Stopping GatewayD", trace.WithAttributes( + logger.Info().Msg("GatewayD is shutting down") + span.AddEvent("GatewayD is shutting down", trace.WithAttributes( attribute.String("signal", signal), )) if healthCheckScheduler != nil { @@ -749,9 +749,13 @@ var runCmd = &cobra.Command{ ) signalsCh := make(chan os.Signal, 1) signal.Notify(signalsCh, signals...) - go func(pluginRegistry *plugin.Registry, + go func(pluginTimeoutCtx context.Context, + pluginRegistry *plugin.Registry, logger zerolog.Logger, servers map[string]*network.Server, + metricsMerger *metrics.Merger, + metricsServer *http.Server, + stopChan chan struct{}, ) { for sig := range signalsCh { for _, s := range signals { @@ -771,13 +775,14 @@ var runCmd = &cobra.Command{ } } } - }(pluginRegistry, logger, servers) + }(pluginTimeoutCtx, pluginRegistry, logger, servers, metricsMerger, metricsServer, stopChan) _, span = otel.Tracer(config.TracerName).Start(runCtx, "Start servers") // Start the server. for name, server := range servers { logger := loggers[name] go func( + span trace.Span, server *network.Server, logger zerolog.Logger, healthCheckScheduler *gocron.Scheduler, @@ -797,7 +802,7 @@ var runCmd = &cobra.Command{ pluginRegistry.Shutdown() os.Exit(gerr.FailedToStartServer) } - }(server, logger, healthCheckScheduler, metricsMerger, pluginRegistry) + }(span, server, logger, healthCheckScheduler, metricsMerger, pluginRegistry) } span.End() diff --git a/errors/errors.go b/errors/errors.go index f16a6501..1100d236 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -22,6 +22,7 @@ const ( ErrCodeServerReceiveFailed ErrCodeServerSendFailed ErrCodeServerListenFailed + ErrCodeServerCloseFailed ErrCodeSplitHostPortFailed ErrCodeAcceptFailed ErrCodeReadFailed @@ -87,6 +88,8 @@ var ( ErrCodeSplitHostPortFailed, "failed to split host:port", nil) ErrAcceptFailed = NewGatewayDError( ErrCodeAcceptFailed, "failed to accept connection", nil) + ErrServerCloseFailed = NewGatewayDError( + ErrCodeServerCloseFailed, "failed to close server", nil) ErrReadFailed = NewGatewayDError( ErrCodeReadFailed, "failed to read from the client", nil) diff --git a/metrics/builtins.go b/metrics/builtins.go index e56b978d..2e588a58 100644 --- a/metrics/builtins.go +++ b/metrics/builtins.go @@ -75,10 +75,15 @@ var ( Name: "proxied_connections", Help: "Number of proxy connects", }) - ProxyPassThroughs = promauto.NewCounter(prometheus.CounterOpts{ + ProxyPassThroughsToClient = promauto.NewCounter(prometheus.CounterOpts{ Namespace: Namespace, - Name: "proxy_passthroughs_total", - Help: "Number of successful proxy passthroughs", + Name: "proxy_passthroughs_to_client_total", + Help: "Number of successful proxy passthroughs from server to client", + }) + ProxyPassThroughsToServer = promauto.NewCounter(prometheus.CounterOpts{ + Namespace: Namespace, + Name: "proxy_passthroughs_to_server_total", + Help: "Number of successful proxy passthroughs from client to server", }) ProxyPassThroughTerminations = promauto.NewCounter(prometheus.CounterOpts{ Namespace: Namespace, diff --git a/network/client.go b/network/client.go index de2ef2a8..603707ee 100644 --- a/network/client.go +++ b/network/client.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "net" + "sync/atomic" "time" "github.com/gatewayd-io/gatewayd/config" @@ -26,8 +27,9 @@ type IClient interface { type Client struct { net.Conn - logger zerolog.Logger - ctx context.Context //nolint:containedctx + logger zerolog.Logger + ctx context.Context //nolint:containedctx + connected atomic.Bool TCPKeepAlive bool TCPKeepAlivePeriod time.Duration @@ -53,6 +55,7 @@ func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog. return nil } + client.connected.Store(false) client.logger = logger // Try to resolve the address and log an error if it can't be resolved. @@ -87,6 +90,7 @@ func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog. } client.Conn = conn + client.connected.Store(true) // Set the TCP keep alive. client.TCPKeepAlive = clientConfig.TCPKeepAlive @@ -146,6 +150,11 @@ func (c *Client) Send(data []byte) (int, *gerr.GatewayDError) { _, span := otel.Tracer(config.TracerName).Start(c.ctx, "Send") defer span.End() + if !c.connected.Load() { + span.RecordError(gerr.ErrClientNotConnected) + return 0, gerr.ErrClientNotConnected + } + sent := 0 received := len(data) for { @@ -170,8 +179,6 @@ func (c *Client) Send(data []byte) (int, *gerr.GatewayDError) { }, ).Msg("Sent data to server") - metrics.BytesSentToServer.Observe(float64(sent)) - return sent, nil } @@ -180,6 +187,11 @@ func (c *Client) Receive() (int, []byte, *gerr.GatewayDError) { _, span := otel.Tracer(config.TracerName).Start(c.ctx, "Receive") defer span.End() + if !c.connected.Load() { + span.RecordError(gerr.ErrClientNotConnected) + return 0, nil, gerr.ErrClientNotConnected + } + var ctx context.Context var cancel context.CancelFunc if c.ReceiveTimeout > 0 { @@ -192,31 +204,29 @@ func (c *Client) Receive() (int, []byte, *gerr.GatewayDError) { var received int buffer := bytes.NewBuffer(nil) // Read the data in chunks. - select { //nolint:gosimple - case <-time.After(time.Millisecond): - for ctx.Err() == nil { - chunk := make([]byte, c.ReceiveChunkSize) - read, err := c.Conn.Read(chunk) - if err != nil { - c.logger.Error().Err(err).Msg("Couldn't receive data from the server") - span.RecordError(err) - metrics.BytesReceivedFromServer.Observe(float64(received)) - return received, buffer.Bytes(), gerr.ErrClientReceiveFailed.Wrap(err) - } - received += read - buffer.Write(chunk[:read]) + for ctx.Err() == nil { + chunk := make([]byte, c.ReceiveChunkSize) + read, err := c.Conn.Read(chunk) + if err != nil { + c.logger.Error().Err(err).Msg("Couldn't receive data from the server") + span.RecordError(err) + return received, buffer.Bytes(), gerr.ErrClientReceiveFailed.Wrap(err) + } + received += read + buffer.Write(chunk[:read]) - if read == 0 || read < c.ReceiveChunkSize { - break - } + if read == 0 || read < c.ReceiveChunkSize { + break } } - metrics.BytesReceivedFromServer.Observe(float64(received)) return received, buffer.Bytes(), nil } // Close closes the connection to the server. func (c *Client) Close() { + // Set the deadline to now so that the connection is closed immediately. + c.Conn.SetDeadline(time.Now()) + _, span := otel.Tracer(config.TracerName).Start(c.ctx, "Close") defer span.End() @@ -228,6 +238,7 @@ func (c *Client) Close() { c.Conn = nil c.Address = "" c.Network = "" + c.connected.Store(false) metrics.ServerConnections.Dec() } @@ -257,20 +268,15 @@ func (c *Client) IsConnected() bool { return false } - if n, err := c.Read([]byte{}); n == 0 && err != nil { - c.logger.Debug().Fields( - map[string]interface{}{ - "address": c.Address, - "reason": "read 0 bytes", - }).Msg("Connection to server is closed") - return false - } - - return true + return c.connected.Load() } // RemoteAddr returns the remote address of the client safely. func (c *Client) RemoteAddr() string { + if !c.connected.Load() { + return "" + } + if c.Conn != nil && c.Conn.RemoteAddr() != nil { return c.Conn.RemoteAddr().String() } @@ -280,6 +286,10 @@ func (c *Client) RemoteAddr() string { // LocalAddr returns the local address of the client safely. func (c *Client) LocalAddr() string { + if !c.connected.Load() { + return "" + } + if c.Conn != nil && c.Conn.LocalAddr() != nil { return c.Conn.LocalAddr().String() } diff --git a/network/engine.go b/network/engine.go index c9800e1a..a1331bac 100644 --- a/network/engine.go +++ b/network/engine.go @@ -5,6 +5,7 @@ import ( "net" "strconv" "sync" + "sync/atomic" "time" "github.com/gatewayd-io/gatewayd/config" @@ -35,6 +36,7 @@ type Engine struct { host string port int connections uint32 + running *atomic.Bool stopServer chan struct{} mu *sync.RWMutex } @@ -49,6 +51,11 @@ func (engine *Engine) Stop(ctx context.Context) error { _, cancel := context.WithDeadline(ctx, time.Now().Add(config.DefaultEngineStopTimeout)) defer cancel() + engine.running.Store(false) + if err := engine.listener.Close(); err != nil { + engine.stopServer <- struct{}{} + return gerr.ErrServerCloseFailed.Wrap(err) + } engine.stopServer <- struct{}{} return nil } @@ -59,6 +66,7 @@ func Run(network, address string, server *Server) *gerr.GatewayDError { connections: 0, stopServer: make(chan struct{}), mu: &sync.RWMutex{}, + running: &atomic.Bool{}, } if action := server.OnBoot(server.engine); action != None { @@ -71,7 +79,6 @@ func Run(network, address string, server *Server) *gerr.GatewayDError { server.logger.Error().Err(err).Msg("Server failed to start listening") return gerr.ErrServerListenFailed.Wrap(err) } - defer server.engine.listener.Close() if server.engine.listener == nil { server.logger.Error().Msg("Server is not properly initialized") @@ -114,44 +121,62 @@ func Run(network, address string, server *Server) *gerr.GatewayDError { } }(server) - for { - conn, err := server.engine.listener.Accept() - if err != nil { - server.logger.Error().Err(err).Msg("Failed to accept connection") - return gerr.ErrAcceptFailed.Wrap(err) - } + server.engine.running.Store(true) - if out, action := server.OnOpen(conn); action != None { - if _, err := conn.Write(out); err != nil { - server.logger.Error().Err(err).Msg("Failed to write to connection") - } - conn.Close() - if action == Shutdown { - server.OnShutdown(server.engine) - return nil - } - } - server.engine.mu.Lock() - server.engine.connections++ - server.engine.mu.Unlock() - - // For every new connection, a new unbuffered channel is created to help - // stop the proxy, recycle the server connection and close stale connections. - stopConnection := make(chan struct{}) - go func(server *Server, conn net.Conn, stopConnection chan struct{}) { - if action := server.OnTraffic(conn, stopConnection); action == Close { - return + for { + select { + case <-server.engine.stopServer: + server.logger.Info().Msg("Server stopped") + return nil + default: + conn, err := server.engine.listener.Accept() + if err != nil { + if !server.engine.running.Load() { + return nil + } + server.logger.Error().Err(err).Msg("Failed to accept connection") + return gerr.ErrAcceptFailed.Wrap(err) } - }(server, conn, stopConnection) - go func(server *Server, conn net.Conn, stopConnection chan struct{}) { - <-stopConnection + if out, action := server.OnOpen(conn); action != None { + if _, err := conn.Write(out); err != nil { + server.logger.Error().Err(err).Msg("Failed to write to connection") + } + conn.Close() + if action == Shutdown { + server.OnShutdown(server.engine) + return nil + } + } server.engine.mu.Lock() - server.engine.connections-- + server.engine.connections++ server.engine.mu.Unlock() - if action := server.OnClose(conn, err); action == Close { - return - } - }(server, conn, stopConnection) + + // For every new connection, a new unbuffered channel is created to help + // stop the proxy, recycle the server connection and close stale connections. + stopConnection := make(chan struct{}) + go func(server *Server, conn net.Conn, stopConnection chan struct{}) { + if action := server.OnTraffic(conn, stopConnection); action == Close { + return + } + }(server, conn, stopConnection) + + go func(server *Server, conn net.Conn, stopConnection chan struct{}) { + for { + select { + case <-stopConnection: + server.engine.mu.Lock() + server.engine.connections-- + server.engine.mu.Unlock() + if action := server.OnClose(conn, err); action == Close { + return + } + return + case <-server.engine.stopServer: + return + } + } + }(server, conn, stopConnection) + } } } diff --git a/network/network_helpers_test.go b/network/network_helpers_test.go index 766c0315..7c20be1b 100644 --- a/network/network_helpers_test.go +++ b/network/network_helpers_test.go @@ -91,8 +91,10 @@ func CollectAndComparePrometheusMetrics(t *testing.T) { # TYPE gatewayd_proxy_health_checks_total counter # HELP gatewayd_proxy_passthrough_terminations_total Number of proxy passthrough terminations by plugins # TYPE gatewayd_proxy_passthrough_terminations_total counter - # HELP gatewayd_proxy_passthroughs_total Number of successful proxy passthroughs - # TYPE gatewayd_proxy_passthroughs_total counter + # HELP gatewayd_proxy_passthroughs_to_client_total Number of successful proxy passthroughs + # TYPE gatewayd_proxy_passthroughs_to_client_total counter + # HELP gatewayd_proxy_passthroughs_to_server_total Number of successful proxy passthroughs + # TYPE gatewayd_proxy_passthroughs_to_server_total counter # HELP gatewayd_server_connections Number of server connections # TYPE gatewayd_server_connections gauge # HELP gatewayd_server_ticks_fired_total Total number of server ticks fired @@ -105,20 +107,21 @@ func CollectAndComparePrometheusMetrics(t *testing.T) { want = metadata + ` gatewayd_bytes_received_from_client_sum 67 gatewayd_bytes_received_from_client_count 1 - gatewayd_bytes_received_from_server_sum 96 - gatewayd_bytes_received_from_server_count 4 + gatewayd_bytes_received_from_server_sum 24 + gatewayd_bytes_received_from_server_count 1 gatewayd_bytes_sent_to_client_sum 24 gatewayd_bytes_sent_to_client_count 1 - gatewayd_bytes_sent_to_server_sum 282 - gatewayd_bytes_sent_to_server_count 5 + gatewayd_bytes_sent_to_server_sum 67 + gatewayd_bytes_sent_to_server_count 1 gatewayd_client_connections 1 - gatewayd_plugin_hooks_executed_total 11 + gatewayd_plugin_hooks_executed_total 10 gatewayd_plugin_hooks_registered_total 0 gatewayd_plugins_loaded_total 0 gatewayd_proxied_connections 1 gatewayd_proxy_health_checks_total 0 gatewayd_proxy_passthrough_terminations_total 0 - gatewayd_proxy_passthroughs_total 1 + gatewayd_proxy_passthroughs_to_client_total 1 + gatewayd_proxy_passthroughs_to_server_total 1 gatewayd_server_connections 5 gatewayd_traffic_bytes_sum 182 gatewayd_traffic_bytes_count 4 diff --git a/network/proxy.go b/network/proxy.go index 66bf065d..5019bee2 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -316,7 +316,7 @@ func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError { // If the hook wants to terminate the connection, do it. if pr.shouldTerminate(result) { if modResponse, modReceived := pr.getPluginModifiedResponse(result); modResponse != nil { - metrics.ProxyPassThroughs.Inc() + metrics.ProxyPassThroughsToClient.Inc() metrics.ProxyPassThroughTerminations.Inc() metrics.BytesSentToClient.Observe(float64(modReceived)) metrics.TotalTrafficBytes.Observe(float64(modReceived)) @@ -357,6 +357,8 @@ func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError { } span.AddEvent("Ran the OnTrafficToServer hooks") + metrics.ProxyPassThroughsToServer.Inc() + return nil } @@ -459,7 +461,7 @@ func (pr *Proxy) PassThroughToClient(conn net.Conn) *gerr.GatewayDError { span.RecordError(errVerdict) } - metrics.ProxyPassThroughs.Inc() + metrics.ProxyPassThroughsToClient.Inc() return errVerdict } @@ -502,7 +504,12 @@ func (pr *Proxy) Shutdown() { pr.availableConnections.ForEach(func(key, value interface{}) bool { if cl, ok := value.(*Client); ok { - cl.Close() + if cl.IsConnected() { + // This will stop all the Conn.Read() and Conn.Write() calls. + // Ref: https://groups.google.com/g/golang-nuts/c/VPVWFrpIEyo + cl.Conn.SetDeadline(time.Now()) + cl.Close() + } } return true }) @@ -511,10 +518,16 @@ func (pr *Proxy) Shutdown() { pr.busyConnections.ForEach(func(key, value interface{}) bool { if conn, ok := key.(net.Conn); ok { + // This will stop all the Conn.Read() and Conn.Write() calls. + conn.SetDeadline(time.Now()) conn.Close() } if cl, ok := value.(*Client); ok { - cl.Close() + if cl != nil { + // This will stop all the Conn.Read() and Conn.Write() calls. + cl.Conn.SetDeadline(time.Now()) + cl.Close() + } } return true }) diff --git a/network/server.go b/network/server.go index 3280fc30..1c5891ec 100644 --- a/network/server.go +++ b/network/server.go @@ -251,7 +251,11 @@ func (s *Server) OnTraffic(conn net.Conn, stopConnection chan struct{}) Action { if err := s.proxy.PassThroughToServer(conn); err != nil { s.logger.Trace().Err(err).Msg("Failed to pass through traffic") span.RecordError(err) - stopConnection <- struct{}{} + s.engine.mu.Lock() + if s.Status != config.Stopped { + stopConnection <- struct{}{} + } + s.engine.mu.Unlock() break } } @@ -265,7 +269,9 @@ func (s *Server) OnTraffic(conn net.Conn, stopConnection chan struct{}) Action { if err := s.proxy.PassThroughToClient(conn); err != nil { s.logger.Trace().Err(err).Msg("Failed to pass through traffic") span.RecordError(err) - stopConnection <- struct{}{} + if s.Status != config.Stopped { + stopConnection <- struct{}{} + } break } } @@ -279,7 +285,7 @@ func (s *Server) OnShutdown(Engine) { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnShutdown") defer span.End() - s.logger.Debug().Msg("GatewayD is shutting down...") + s.logger.Debug().Msg("GatewayD is shutting down") pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.pluginTimeout) defer cancel() @@ -376,7 +382,7 @@ func (s *Server) Run() error { // Start the server. origErr := Run(s.Network, addr, s) - if origErr != nil { + if origErr != nil && origErr.Unwrap() != nil { s.logger.Error().Err(origErr).Msg("Failed to start server") span.RecordError(origErr) return gerr.ErrFailedToStartServer.Wrap(origErr) diff --git a/network/server_test.go b/network/server_test.go index 3e646ac6..827a493d 100644 --- a/network/server_test.go +++ b/network/server_test.go @@ -2,6 +2,7 @@ package network import ( "bufio" + "bytes" "context" "errors" "io" @@ -23,11 +24,9 @@ import ( // TestRunServer tests an entire server run with a single client connection and hooks. func TestRunServer(t *testing.T) { errs := make(chan error) - defer close(errs) logger := logging.NewLogger(context.Background(), logging.LoggerConfig{ Output: []config.LogOutput{ - config.Console, config.File, }, TimeFormat: zerolog.TimeFormatUnix, @@ -57,9 +56,10 @@ func TestRunServer(t *testing.T) { errs <- errors.New("request is nil") //nolint:goerr113 } - logger.Info().Msg("Ingress traffic") if req, ok := paramsMap["request"].([]byte); ok { - assert.Equal(t, CreatePgStartupPacket(), req) + if !bytes.Equal(req, CreatePgStartupPacket()) { + errs <- errors.New("request does not match") //nolint:goerr113 + } } else { errs <- errors.New("request is not a []byte") //nolint:goerr113 } @@ -81,7 +81,9 @@ func TestRunServer(t *testing.T) { logger.Info().Msg("Ingress traffic") if req, ok := paramsMap["request"].([]byte); ok { - assert.Equal(t, CreatePgStartupPacket(), req) + if !bytes.Equal(req, CreatePgStartupPacket()) { + errs <- errors.New("request does not match") //nolint:goerr113 + } } else { errs <- errors.New("request is not a []byte") //nolint:goerr113 } @@ -175,7 +177,7 @@ func TestRunServer(t *testing.T) { "127.0.0.1:15432", config.DefaultTickInterval, Option{ - EnableTicker: false, + EnableTicker: true, }, proxy, logger, @@ -190,11 +192,17 @@ func TestRunServer(t *testing.T) { var waitGroup sync.WaitGroup waitGroup.Add(1) - go func(t *testing.T, server *Server, stop chan struct{}, waitGroup *sync.WaitGroup) { + go func(t *testing.T, server *Server, pluginRegistry *plugin.Registry, stop chan struct{}, waitGroup *sync.WaitGroup) { t.Helper() for { select { case <-stop: + server.Shutdown() + pluginRegistry.Shutdown() + + // Wait for the server to stop. + time.Sleep(100 * time.Millisecond) + // Read the log file and check if the log file contains the expected log messages. if _, err := os.Stat("server_test.log"); err == nil { logFile, err := os.Open("server_test.log") @@ -213,14 +221,15 @@ func TestRunServer(t *testing.T) { assert.Contains(t, logLines, "GatewayD is ticking...", "GatewayD should be ticking") assert.Contains(t, logLines, "Ingress traffic", "Ingress traffic should be logged") assert.Contains(t, logLines, "Egress traffic", "Egress traffic should be logged") - assert.Contains(t, logLines, "GatewayD is shutting down...", "GatewayD should be shutting down") + assert.Contains(t, logLines, "GatewayD is shutting down", "GatewayD should be shutting down") assert.NoError(t, os.Remove("server_test.log")) - server.Shutdown() } + waitGroup.Done() return case err := <-errs: server.Shutdown() + pluginRegistry.Shutdown() t.Log(err) t.Fail() waitGroup.Done() @@ -228,17 +237,17 @@ func TestRunServer(t *testing.T) { default: //nolint:staticcheck } } - }(t, server, stop, &waitGroup) + }(t, server, pluginRegistry, stop, &waitGroup) waitGroup.Add(1) - go func(t *testing.T, server *Server, errs chan error, waitGroup *sync.WaitGroup) { + go func(t *testing.T, server *Server, errs chan error, stop chan struct{}, waitGroup *sync.WaitGroup) { t.Helper() if err := server.Run(); err != nil { errs <- err t.Fail() } waitGroup.Done() - }(t, server, errs, &waitGroup) + }(t, server, errs, stop, &waitGroup) waitGroup.Add(1) go func(t *testing.T, server *Server, proxy *Proxy, stop chan struct{}, waitGroup *sync.WaitGroup) { @@ -290,13 +299,12 @@ func TestRunServer(t *testing.T) { CollectAndComparePrometheusMetrics(t) client.Close() - time.Sleep(100 * time.Millisecond) - stop <- struct{}{} - waitGroup.Done() - return + break } time.Sleep(100 * time.Millisecond) } + stop <- struct{}{} + waitGroup.Done() }(t, server, proxy, stop, &waitGroup) waitGroup.Wait()