diff --git a/network/engine.go b/network/engine.go index ec026d02..c7158ab2 100644 --- a/network/engine.go +++ b/network/engine.go @@ -163,7 +163,7 @@ func Run(network, address string, server *Server) *gerr.GatewayDError { stopConnection := make(chan struct{}) go func(server *Server, conn net.Conn, stopConnection chan struct{}) { if action := server.OnTraffic(conn, stopConnection); action == Close { - return + stopConnection <- struct{}{} } }(server, conn, stopConnection) @@ -174,9 +174,7 @@ func Run(network, address string, server *Server) *gerr.GatewayDError { server.engine.mu.Lock() server.engine.connections-- server.engine.mu.Unlock() - if action := server.OnClose(conn, err); action == Close { - return - } + server.OnClose(conn, err) return case <-server.engine.stopServer: return diff --git a/network/proxy.go b/network/proxy.go index 02f2605c..a6ac55b1 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -25,7 +25,7 @@ type IProxy interface { Disconnect(conn net.Conn) *gerr.GatewayDError PassThroughToServer(conn net.Conn) *gerr.GatewayDError PassThroughToClient(conn net.Conn) *gerr.GatewayDError - IsHealty(cl *Client) (*Client, *gerr.GatewayDError) + IsHealthy(cl *Client) (*Client, *gerr.GatewayDError) IsExhausted() bool Shutdown() AvailableConnections() []string @@ -160,7 +160,7 @@ func (pr *Proxy) Connect(conn net.Conn) *gerr.GatewayDError { } } - client, err := pr.IsHealty(client) + client, err := pr.IsHealthy(client) if err != nil { pr.logger.Error().Err(err).Msg("Failed to connect to the client") span.RecordError(err) @@ -207,34 +207,38 @@ func (pr *Proxy) Disconnect(conn net.Conn) *gerr.GatewayDError { defer span.End() client := pr.busyConnections.Pop(conn) + if client == nil { + // If this ever happens, it means that the client connection + // is pre-empted from the busy connections pool. + pr.logger.Debug().Msg("Client connection is pre-empted from the busy connections pool") + span.RecordError(gerr.ErrClientNotFound) + return gerr.ErrClientNotFound + } + //nolint:nestif - if client != nil { - if client, ok := client.(*Client); ok { - if (pr.Elastic && pr.ReuseElasticClients) || !pr.Elastic { - _, err := pr.IsHealty(client) - if err != nil { - pr.logger.Error().Err(err).Msg("Failed to reconnect to the client") - span.RecordError(err) - } - // If the client is not in the pool, put it back. - err = pr.availableConnections.Put(client.ID, client) - if err != nil { - pr.logger.Error().Err(err).Msg("Failed to put the client back in the pool") - span.RecordError(err) - } - } else { - span.RecordError(gerr.ErrClientNotConnected) - return gerr.ErrClientNotConnected + if client, ok := client.(*Client); ok { + if (pr.Elastic && pr.ReuseElasticClients) || !pr.Elastic { + // Recycle the server connection by reconnecting. + if err := client.Reconnect(); err != nil { + pr.logger.Error().Err(err).Msg("Failed to reconnect to the client") + span.RecordError(err) + } + + // If the client is not in the pool, put it back. + if err := pr.availableConnections.Put(client.ID, client); err != nil { + pr.logger.Error().Err(err).Msg("Failed to put the client back in the pool") + span.RecordError(err) } } else { - // This should never happen, but if it does, - // then there are some serious issues with the pool. - span.RecordError(gerr.ErrCastFailed) - return gerr.ErrCastFailed + span.RecordError(gerr.ErrClientNotConnected) + return gerr.ErrClientNotConnected } } else { - span.RecordError(gerr.ErrClientNotFound) - return gerr.ErrClientNotFound + // This should never happen, but if it does, + // then there are some serious issues with the pool. + pr.logger.Error().Msg("Failed to cast the client to the Client type") + span.RecordError(gerr.ErrCastFailed) + return gerr.ErrCastFailed } metrics.ProxiedConnections.Dec() @@ -255,7 +259,7 @@ func (pr *Proxy) Disconnect(conn net.Conn) *gerr.GatewayDError { return nil } -// PassThrough sends the data from the client to the server. +// PassThroughToServer sends the data from the client to the server. func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError { _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "PassThrough") defer span.End() @@ -466,9 +470,9 @@ func (pr *Proxy) PassThroughToClient(conn net.Conn) *gerr.GatewayDError { return errVerdict } -// IsHealty checks if the pool is exhausted or the client is disconnected. -func (pr *Proxy) IsHealty(client *Client) (*Client, *gerr.GatewayDError) { - _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "IsHealty") +// IsHealthy checks if the pool is exhausted or the client is disconnected. +func (pr *Proxy) IsHealthy(client *Client) (*Client, *gerr.GatewayDError) { + _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "IsHealthy") defer span.End() if pr.IsExhausted() { @@ -526,7 +530,10 @@ func (pr *Proxy) Shutdown() { pr.logger.Error().Err(err).Msg("Error setting the deadline") span.RecordError(err) } - conn.Close() + if err := conn.Close(); err != nil { + pr.logger.Error().Err(err).Msg("Failed to close the connection") + span.RecordError(err) + } } if client, ok := value.(*Client); ok { if client != nil { diff --git a/network/proxy_test.go b/network/proxy_test.go index ec072740..8bd82452 100644 --- a/network/proxy_test.go +++ b/network/proxy_test.go @@ -71,7 +71,7 @@ func TestNewProxy(t *testing.T) { assert.Equal(t, false, proxy.Elastic) assert.Equal(t, false, proxy.ReuseElasticClients) assert.Equal(t, false, proxy.IsExhausted()) - c, err := proxy.IsHealty(client) + c, err := proxy.IsHealthy(client) assert.Nil(t, err) assert.Equal(t, client, c) } @@ -369,7 +369,7 @@ func BenchmarkProxyIsHealthyAndIsExhausted(b *testing.B) { // Connect to the proxy for i := 0; i < b.N; i++ { - proxy.IsHealty(client) //nolint:errcheck + proxy.IsHealthy(client) //nolint:errcheck proxy.IsExhausted() } } diff --git a/network/server.go b/network/server.go index 1d85c300..7d4ba358 100644 --- a/network/server.go +++ b/network/server.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "os" + "sync" "time" v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1" @@ -25,6 +26,7 @@ type Server struct { pluginRegistry *plugin.Registry ctx context.Context //nolint:containedctx pluginTimeout time.Duration + mu *sync.RWMutex Network string // tcp/udp/unix Address string @@ -58,9 +60,9 @@ func (s *Server) OnBoot(engine Engine) Action { s.engine = engine // Set the server status to running. - s.engine.mu.Lock() + s.mu.Lock() s.Status = config.Running - s.engine.mu.Unlock() + s.mu.Unlock() // Run the OnBooted hooks. _, err = s.pluginRegistry.Run( @@ -173,11 +175,14 @@ func (s *Server) OnClose(conn net.Conn, err error) Action { span.AddEvent("Ran the OnClosing hooks") // Shutdown the server if there are no more connections and the server is stopped. - // This is used to shutdown the server gracefully. + // This is used to shut down the server gracefully. + s.mu.Lock() if uint64(s.engine.CountConnections()) == 0 && s.Status == config.Stopped { span.AddEvent("Shutting down the server") + s.mu.Unlock() return Shutdown } + s.mu.Unlock() // Disconnect the connection from the proxy. This effectively removes the mapping between // the incoming and the server connections in the pool of the busy connections and either @@ -251,11 +256,7 @@ func (s *Server) OnTraffic(conn net.Conn, stopConnection chan struct{}) Action { if err := server.proxy.PassThroughToServer(conn); err != nil { server.logger.Trace().Err(err).Msg("Failed to pass through traffic") span.RecordError(err) - server.engine.mu.Lock() - if server.Status == config.Stopped { - stopConnection <- struct{}{} - } - server.engine.mu.Unlock() + stopConnection <- struct{}{} break } } @@ -269,17 +270,14 @@ func (s *Server) OnTraffic(conn net.Conn, stopConnection chan struct{}) Action { if err := server.proxy.PassThroughToClient(conn); err != nil { server.logger.Trace().Err(err).Msg("Failed to pass through traffic") span.RecordError(err) - server.engine.mu.Lock() - if server.Status == config.Stopped { - stopConnection <- struct{}{} - } - server.engine.mu.Unlock() + stopConnection <- struct{}{} break } } }(s, conn, stopConnection) - return None + <-stopConnection + return Close } // OnShutdown is called when the server is shutting down. It calls the OnShutdown hooks. @@ -306,9 +304,9 @@ func (s *Server) OnShutdown(Engine) { s.proxy.Shutdown() // Set the server status to stopped. This is used to shutdown the server gracefully in OnClose. - s.engine.mu.Lock() + s.mu.Lock() s.Status = config.Stopped - s.engine.mu.Unlock() + s.mu.Unlock() } // OnTick is called every TickInterval. It calls the OnTick hooks. @@ -402,9 +400,9 @@ func (s *Server) Shutdown() { s.proxy.Shutdown() // Set the server status to stopped. This is used to shutdown the server gracefully in OnClose. - s.engine.mu.Lock() + s.mu.Lock() s.Status = config.Stopped - s.engine.mu.Unlock() + s.mu.Unlock() // Shutdown the server. if err := s.engine.Stop(context.Background()); err != nil { @@ -419,8 +417,8 @@ func (s *Server) IsRunning() bool { defer span.End() span.SetAttributes(attribute.Bool("status", s.Status == config.Running)) - s.engine.mu.Lock() - defer s.engine.mu.Unlock() + s.mu.Lock() + defer s.mu.Unlock() return s.Status == config.Running } @@ -450,6 +448,7 @@ func NewServer( logger: logger, pluginRegistry: pluginRegistry, pluginTimeout: pluginTimeout, + mu: &sync.RWMutex{}, } // Try to resolve the address and log an error if it can't be resolved.