diff --git a/network/engine.go b/network/engine.go index c9800e1a..62706db3 100644 --- a/network/engine.go +++ b/network/engine.go @@ -115,43 +115,49 @@ func Run(network, address string, server *Server) *gerr.GatewayDError { }(server) for { - conn, err := server.engine.listener.Accept() - if err != nil { - server.logger.Error().Err(err).Msg("Failed to accept connection") - return gerr.ErrAcceptFailed.Wrap(err) - } - - if out, action := server.OnOpen(conn); action != None { - if _, err := conn.Write(out); err != nil { - server.logger.Error().Err(err).Msg("Failed to write to connection") + select { + case <-server.engine.stopServer: + server.logger.Info().Msg("Server stopped") + return nil + default: + conn, err := server.engine.listener.Accept() + if err != nil { + server.logger.Error().Err(err).Msg("Failed to accept connection") + return gerr.ErrAcceptFailed.Wrap(err) } - conn.Close() - if action == Shutdown { - server.OnShutdown(server.engine) - return nil - } - } - server.engine.mu.Lock() - server.engine.connections++ - server.engine.mu.Unlock() - - // For every new connection, a new unbuffered channel is created to help - // stop the proxy, recycle the server connection and close stale connections. - stopConnection := make(chan struct{}) - go func(server *Server, conn net.Conn, stopConnection chan struct{}) { - if action := server.OnTraffic(conn, stopConnection); action == Close { - return - } - }(server, conn, stopConnection) - go func(server *Server, conn net.Conn, stopConnection chan struct{}) { - <-stopConnection + if out, action := server.OnOpen(conn); action != None { + if _, err := conn.Write(out); err != nil { + server.logger.Error().Err(err).Msg("Failed to write to connection") + } + conn.Close() + if action == Shutdown { + server.OnShutdown(server.engine) + return nil + } + } server.engine.mu.Lock() - server.engine.connections-- + server.engine.connections++ server.engine.mu.Unlock() - if action := server.OnClose(conn, err); action == Close { - return - } - }(server, conn, stopConnection) + + // For every new connection, a new unbuffered channel is created to help + // stop the proxy, recycle the server connection and close stale connections. + stopConnection := make(chan struct{}) + go func(server *Server, conn net.Conn, stopConnection chan struct{}) { + if action := server.OnTraffic(conn, stopConnection); action == Close { + return + } + }(server, conn, stopConnection) + + go func(server *Server, conn net.Conn, stopConnection chan struct{}) { + <-stopConnection + server.engine.mu.Lock() + server.engine.connections-- + server.engine.mu.Unlock() + if action := server.OnClose(conn, err); action == Close { + return + } + }(server, conn, stopConnection) + } } } diff --git a/network/proxy.go b/network/proxy.go index 66bf065d..b57e97c8 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -558,6 +558,7 @@ func (pr *Proxy) receiveTrafficFromClient(conn net.Conn) ([]byte, *gerr.GatewayD _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "receiveTrafficFromClient") defer span.End() + // TODO: Add a channel/context with cancel to stop the receive loop. // request contains the data from the client. received := 0 buffer := bytes.NewBuffer(nil) @@ -636,6 +637,7 @@ func (pr *Proxy) receiveTrafficFromServer(client *Client) (int, []byte, *gerr.Ga _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "receiveTrafficFromServer") defer span.End() + // TODO: Add a channel/context with cancel to stop the receive loop. // Receive the response from the server. received, response, err := client.Receive() pr.logger.Debug().Fields( diff --git a/network/server_test.go b/network/server_test.go index 3e646ac6..fc98043f 100644 --- a/network/server_test.go +++ b/network/server_test.go @@ -195,6 +195,8 @@ func TestRunServer(t *testing.T) { for { select { case <-stop: + server.Shutdown() + // Read the log file and check if the log file contains the expected log messages. if _, err := os.Stat("server_test.log"); err == nil { logFile, err := os.Open("server_test.log") @@ -216,7 +218,6 @@ func TestRunServer(t *testing.T) { assert.Contains(t, logLines, "GatewayD is shutting down...", "GatewayD should be shutting down") assert.NoError(t, os.Remove("server_test.log")) - server.Shutdown() } return case err := <-errs: @@ -231,14 +232,15 @@ func TestRunServer(t *testing.T) { }(t, server, stop, &waitGroup) waitGroup.Add(1) - go func(t *testing.T, server *Server, errs chan error, waitGroup *sync.WaitGroup) { + go func(t *testing.T, server *Server, errs chan error, stop chan struct{}, waitGroup *sync.WaitGroup) { t.Helper() if err := server.Run(); err != nil { errs <- err t.Fail() } + t.Log("Server is being shut down") waitGroup.Done() - }(t, server, errs, &waitGroup) + }(t, server, errs, stop, &waitGroup) waitGroup.Add(1) go func(t *testing.T, server *Server, proxy *Proxy, stop chan struct{}, waitGroup *sync.WaitGroup) { @@ -290,13 +292,12 @@ func TestRunServer(t *testing.T) { CollectAndComparePrometheusMetrics(t) client.Close() - time.Sleep(100 * time.Millisecond) - stop <- struct{}{} - waitGroup.Done() - return + break } time.Sleep(100 * time.Millisecond) } + stop <- struct{}{} + waitGroup.Done() }(t, server, proxy, stop, &waitGroup) waitGroup.Wait()