Skip to content

Commit

Permalink
Fix connection close issue
Browse files Browse the repository at this point in the history
Use a separate rwmutex for server code
  • Loading branch information
mostafa committed Oct 13, 2023
1 parent 020e63f commit 22a5d09
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 56 deletions.
6 changes: 2 additions & 4 deletions network/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func Run(network, address string, server *Server) *gerr.GatewayDError {
stopConnection := make(chan struct{})
go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
if action := server.OnTraffic(conn, stopConnection); action == Close {
return
stopConnection <- struct{}{}
}
}(server, conn, stopConnection)

Expand All @@ -174,9 +174,7 @@ func Run(network, address string, server *Server) *gerr.GatewayDError {
server.engine.mu.Lock()
server.engine.connections--
server.engine.mu.Unlock()
if action := server.OnClose(conn, err); action == Close {
return
}
server.OnClose(conn, err)
return
case <-server.engine.stopServer:
return
Expand Down
67 changes: 37 additions & 30 deletions network/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type IProxy interface {
Disconnect(conn net.Conn) *gerr.GatewayDError
PassThroughToServer(conn net.Conn) *gerr.GatewayDError
PassThroughToClient(conn net.Conn) *gerr.GatewayDError
IsHealty(cl *Client) (*Client, *gerr.GatewayDError)
IsHealthy(cl *Client) (*Client, *gerr.GatewayDError)
IsExhausted() bool
Shutdown()
AvailableConnections() []string
Expand Down Expand Up @@ -160,7 +160,7 @@ func (pr *Proxy) Connect(conn net.Conn) *gerr.GatewayDError {
}
}

client, err := pr.IsHealty(client)
client, err := pr.IsHealthy(client)
if err != nil {
pr.logger.Error().Err(err).Msg("Failed to connect to the client")
span.RecordError(err)
Expand Down Expand Up @@ -207,34 +207,38 @@ func (pr *Proxy) Disconnect(conn net.Conn) *gerr.GatewayDError {
defer span.End()

client := pr.busyConnections.Pop(conn)
if client == nil {
// If this ever happens, it means that the client connection
// is pre-empted from the busy connections pool.
pr.logger.Debug().Msg("Client connection is pre-empted from the busy connections pool")
span.RecordError(gerr.ErrClientNotFound)
return gerr.ErrClientNotFound
}

//nolint:nestif
if client != nil {
if client, ok := client.(*Client); ok {
if (pr.Elastic && pr.ReuseElasticClients) || !pr.Elastic {
_, err := pr.IsHealty(client)
if err != nil {
pr.logger.Error().Err(err).Msg("Failed to reconnect to the client")
span.RecordError(err)
}
// If the client is not in the pool, put it back.
err = pr.availableConnections.Put(client.ID, client)
if err != nil {
pr.logger.Error().Err(err).Msg("Failed to put the client back in the pool")
span.RecordError(err)
}
} else {
span.RecordError(gerr.ErrClientNotConnected)
return gerr.ErrClientNotConnected
if client, ok := client.(*Client); ok {
if (pr.Elastic && pr.ReuseElasticClients) || !pr.Elastic {
// Recycle the server connection by reconnecting.
if err := client.Reconnect(); err != nil {
pr.logger.Error().Err(err).Msg("Failed to reconnect to the client")
span.RecordError(err)
}

// If the client is not in the pool, put it back.
if err := pr.availableConnections.Put(client.ID, client); err != nil {
pr.logger.Error().Err(err).Msg("Failed to put the client back in the pool")
span.RecordError(err)
}
} else {
// This should never happen, but if it does,
// then there are some serious issues with the pool.
span.RecordError(gerr.ErrCastFailed)
return gerr.ErrCastFailed
span.RecordError(gerr.ErrClientNotConnected)
return gerr.ErrClientNotConnected
}
} else {
span.RecordError(gerr.ErrClientNotFound)
return gerr.ErrClientNotFound
// This should never happen, but if it does,
// then there are some serious issues with the pool.
pr.logger.Error().Msg("Failed to cast the client to the Client type")
span.RecordError(gerr.ErrCastFailed)
return gerr.ErrCastFailed
}

metrics.ProxiedConnections.Dec()
Expand All @@ -255,7 +259,7 @@ func (pr *Proxy) Disconnect(conn net.Conn) *gerr.GatewayDError {
return nil
}

// PassThrough sends the data from the client to the server.
// PassThroughToServer sends the data from the client to the server.
func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError {
_, span := otel.Tracer(config.TracerName).Start(pr.ctx, "PassThrough")
defer span.End()
Expand Down Expand Up @@ -466,9 +470,9 @@ func (pr *Proxy) PassThroughToClient(conn net.Conn) *gerr.GatewayDError {
return errVerdict
}

// IsHealty checks if the pool is exhausted or the client is disconnected.
func (pr *Proxy) IsHealty(client *Client) (*Client, *gerr.GatewayDError) {
_, span := otel.Tracer(config.TracerName).Start(pr.ctx, "IsHealty")
// IsHealthy checks if the pool is exhausted or the client is disconnected.
func (pr *Proxy) IsHealthy(client *Client) (*Client, *gerr.GatewayDError) {
_, span := otel.Tracer(config.TracerName).Start(pr.ctx, "IsHealthy")
defer span.End()

if pr.IsExhausted() {
Expand Down Expand Up @@ -526,7 +530,10 @@ func (pr *Proxy) Shutdown() {
pr.logger.Error().Err(err).Msg("Error setting the deadline")
span.RecordError(err)
}
conn.Close()
if err := conn.Close(); err != nil {
pr.logger.Error().Err(err).Msg("Failed to close the connection")
span.RecordError(err)
}
}
if client, ok := value.(*Client); ok {
if client != nil {
Expand Down
4 changes: 2 additions & 2 deletions network/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func TestNewProxy(t *testing.T) {
assert.Equal(t, false, proxy.Elastic)
assert.Equal(t, false, proxy.ReuseElasticClients)
assert.Equal(t, false, proxy.IsExhausted())
c, err := proxy.IsHealty(client)
c, err := proxy.IsHealthy(client)
assert.Nil(t, err)
assert.Equal(t, client, c)
}
Expand Down Expand Up @@ -369,7 +369,7 @@ func BenchmarkProxyIsHealthyAndIsExhausted(b *testing.B) {

// Connect to the proxy
for i := 0; i < b.N; i++ {
proxy.IsHealty(client) //nolint:errcheck
proxy.IsHealthy(client) //nolint:errcheck
proxy.IsExhausted()
}
}
Expand Down
39 changes: 19 additions & 20 deletions network/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net"
"os"
"sync"
"time"

v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1"
Expand All @@ -25,6 +26,7 @@ type Server struct {
pluginRegistry *plugin.Registry
ctx context.Context //nolint:containedctx
pluginTimeout time.Duration
mu *sync.RWMutex

Network string // tcp/udp/unix
Address string
Expand Down Expand Up @@ -58,9 +60,9 @@ func (s *Server) OnBoot(engine Engine) Action {
s.engine = engine

// Set the server status to running.
s.engine.mu.Lock()
s.mu.Lock()
s.Status = config.Running
s.engine.mu.Unlock()
s.mu.Unlock()

// Run the OnBooted hooks.
_, err = s.pluginRegistry.Run(
Expand Down Expand Up @@ -173,11 +175,14 @@ func (s *Server) OnClose(conn net.Conn, err error) Action {
span.AddEvent("Ran the OnClosing hooks")

// Shutdown the server if there are no more connections and the server is stopped.
// This is used to shutdown the server gracefully.
// This is used to shut down the server gracefully.
s.mu.Lock()
if uint64(s.engine.CountConnections()) == 0 && s.Status == config.Stopped {
span.AddEvent("Shutting down the server")
s.mu.Unlock()
return Shutdown
}
s.mu.Unlock()

// Disconnect the connection from the proxy. This effectively removes the mapping between
// the incoming and the server connections in the pool of the busy connections and either
Expand Down Expand Up @@ -251,11 +256,7 @@ func (s *Server) OnTraffic(conn net.Conn, stopConnection chan struct{}) Action {
if err := server.proxy.PassThroughToServer(conn); err != nil {
server.logger.Trace().Err(err).Msg("Failed to pass through traffic")
span.RecordError(err)
server.engine.mu.Lock()
if server.Status == config.Stopped {
stopConnection <- struct{}{}
}
server.engine.mu.Unlock()
stopConnection <- struct{}{}
break
}
}
Expand All @@ -269,17 +270,14 @@ func (s *Server) OnTraffic(conn net.Conn, stopConnection chan struct{}) Action {
if err := server.proxy.PassThroughToClient(conn); err != nil {
server.logger.Trace().Err(err).Msg("Failed to pass through traffic")
span.RecordError(err)
server.engine.mu.Lock()
if server.Status == config.Stopped {
stopConnection <- struct{}{}
}
server.engine.mu.Unlock()
stopConnection <- struct{}{}
break
}
}
}(s, conn, stopConnection)

return None
<-stopConnection
return Close
}

// OnShutdown is called when the server is shutting down. It calls the OnShutdown hooks.
Expand All @@ -306,9 +304,9 @@ func (s *Server) OnShutdown(Engine) {
s.proxy.Shutdown()

// Set the server status to stopped. This is used to shutdown the server gracefully in OnClose.
s.engine.mu.Lock()
s.mu.Lock()
s.Status = config.Stopped
s.engine.mu.Unlock()
s.mu.Unlock()
}

// OnTick is called every TickInterval. It calls the OnTick hooks.
Expand Down Expand Up @@ -402,9 +400,9 @@ func (s *Server) Shutdown() {
s.proxy.Shutdown()

// Set the server status to stopped. This is used to shutdown the server gracefully in OnClose.
s.engine.mu.Lock()
s.mu.Lock()
s.Status = config.Stopped
s.engine.mu.Unlock()
s.mu.Unlock()

// Shutdown the server.
if err := s.engine.Stop(context.Background()); err != nil {
Expand All @@ -419,8 +417,8 @@ func (s *Server) IsRunning() bool {
defer span.End()
span.SetAttributes(attribute.Bool("status", s.Status == config.Running))

s.engine.mu.Lock()
defer s.engine.mu.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
return s.Status == config.Running
}

Expand Down Expand Up @@ -450,6 +448,7 @@ func NewServer(
logger: logger,
pluginRegistry: pluginRegistry,
pluginTimeout: pluginTimeout,
mu: &sync.RWMutex{},
}

// Try to resolve the address and log an error if it can't be resolved.
Expand Down

0 comments on commit 22a5d09

Please sign in to comment.