Skip to content

Commit

Permalink
WIP: waitgroup
Browse files Browse the repository at this point in the history
  • Loading branch information
mostafa committed Oct 11, 2023
1 parent 74856a0 commit 4a1d08e
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 41 deletions.
74 changes: 40 additions & 34 deletions network/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,43 +115,49 @@ func Run(network, address string, server *Server) *gerr.GatewayDError {
}(server)

for {
conn, err := server.engine.listener.Accept()
if err != nil {
server.logger.Error().Err(err).Msg("Failed to accept connection")
return gerr.ErrAcceptFailed.Wrap(err)
}

if out, action := server.OnOpen(conn); action != None {
if _, err := conn.Write(out); err != nil {
server.logger.Error().Err(err).Msg("Failed to write to connection")
select {
case <-server.engine.stopServer:
server.logger.Info().Msg("Server stopped")
return nil
default:
conn, err := server.engine.listener.Accept()
if err != nil {
server.logger.Error().Err(err).Msg("Failed to accept connection")
return gerr.ErrAcceptFailed.Wrap(err)
}
conn.Close()
if action == Shutdown {
server.OnShutdown(server.engine)
return nil
}
}
server.engine.mu.Lock()
server.engine.connections++
server.engine.mu.Unlock()

// For every new connection, a new unbuffered channel is created to help
// stop the proxy, recycle the server connection and close stale connections.
stopConnection := make(chan struct{})
go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
if action := server.OnTraffic(conn, stopConnection); action == Close {
return
}
}(server, conn, stopConnection)

go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
<-stopConnection
if out, action := server.OnOpen(conn); action != None {
if _, err := conn.Write(out); err != nil {
server.logger.Error().Err(err).Msg("Failed to write to connection")
}
conn.Close()
if action == Shutdown {
server.OnShutdown(server.engine)
return nil
}
}
server.engine.mu.Lock()
server.engine.connections--
server.engine.connections++
server.engine.mu.Unlock()
if action := server.OnClose(conn, err); action == Close {
return
}
}(server, conn, stopConnection)

// For every new connection, a new unbuffered channel is created to help
// stop the proxy, recycle the server connection and close stale connections.
stopConnection := make(chan struct{})
go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
if action := server.OnTraffic(conn, stopConnection); action == Close {
return
}
}(server, conn, stopConnection)

go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
<-stopConnection
server.engine.mu.Lock()
server.engine.connections--
server.engine.mu.Unlock()
if action := server.OnClose(conn, err); action == Close {
return
}
}(server, conn, stopConnection)
}
}
}
2 changes: 2 additions & 0 deletions network/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,7 @@ func (pr *Proxy) receiveTrafficFromClient(conn net.Conn) ([]byte, *gerr.GatewayD
_, span := otel.Tracer(config.TracerName).Start(pr.ctx, "receiveTrafficFromClient")
defer span.End()

// TODO: Add a channel/context with cancel to stop the receive loop.
// request contains the data from the client.
received := 0
buffer := bytes.NewBuffer(nil)
Expand Down Expand Up @@ -636,6 +637,7 @@ func (pr *Proxy) receiveTrafficFromServer(client *Client) (int, []byte, *gerr.Ga
_, span := otel.Tracer(config.TracerName).Start(pr.ctx, "receiveTrafficFromServer")
defer span.End()

// TODO: Add a channel/context with cancel to stop the receive loop.
// Receive the response from the server.
received, response, err := client.Receive()
pr.logger.Debug().Fields(
Expand Down
15 changes: 8 additions & 7 deletions network/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ func TestRunServer(t *testing.T) {
for {
select {
case <-stop:
server.Shutdown()

// Read the log file and check if the log file contains the expected log messages.
if _, err := os.Stat("server_test.log"); err == nil {
logFile, err := os.Open("server_test.log")
Expand All @@ -216,7 +218,6 @@ func TestRunServer(t *testing.T) {
assert.Contains(t, logLines, "GatewayD is shutting down...", "GatewayD should be shutting down")

assert.NoError(t, os.Remove("server_test.log"))
server.Shutdown()
}
return
case err := <-errs:
Expand All @@ -231,14 +232,15 @@ func TestRunServer(t *testing.T) {
}(t, server, stop, &waitGroup)

waitGroup.Add(1)
go func(t *testing.T, server *Server, errs chan error, waitGroup *sync.WaitGroup) {
go func(t *testing.T, server *Server, errs chan error, stop chan struct{}, waitGroup *sync.WaitGroup) {

Check failure on line 235 in network/server_test.go

View workflow job for this annotation

GitHub Actions / Test GatewayD

`TestRunServer$6` - `stop` is unused (unparam)
t.Helper()
if err := server.Run(); err != nil {
errs <- err
t.Fail()
}
t.Log("Server is being shut down")
waitGroup.Done()
}(t, server, errs, &waitGroup)
}(t, server, errs, stop, &waitGroup)

waitGroup.Add(1)
go func(t *testing.T, server *Server, proxy *Proxy, stop chan struct{}, waitGroup *sync.WaitGroup) {
Expand Down Expand Up @@ -290,13 +292,12 @@ func TestRunServer(t *testing.T) {
CollectAndComparePrometheusMetrics(t)

client.Close()
time.Sleep(100 * time.Millisecond)
stop <- struct{}{}
waitGroup.Done()
return
break
}
time.Sleep(100 * time.Millisecond)
}
stop <- struct{}{}
waitGroup.Done()
}(t, server, proxy, stop, &waitGroup)

waitGroup.Wait()
Expand Down

0 comments on commit 4a1d08e

Please sign in to comment.