diff --git a/network/client.go b/network/client.go index 7526c0d1..3074b778 100644 --- a/network/client.go +++ b/network/client.go @@ -146,12 +146,23 @@ func (c *Client) Send(data []byte) (int, *gerr.GatewayDError) { _, span := otel.Tracer(config.TracerName).Start(c.ctx, "Send") defer span.End() - sent, err := c.Conn.Write(data) - if err != nil { - c.logger.Error().Err(err).Msg("Couldn't send data to the server") - span.RecordError(err) - return 0, gerr.ErrClientSendFailed.Wrap(err) + sent := 0 + received := len(data) + for { + if sent >= received { + break + } + + n, err := c.Conn.Write(data) + if err != nil { + c.logger.Error().Err(err).Msg("Couldn't send data to the server") + span.RecordError(err) + return 0, gerr.ErrClientSendFailed.Wrap(err) + } + + sent += n } + c.logger.Debug().Fields( map[string]interface{}{ "length": sent, diff --git a/network/proxy.go b/network/proxy.go index fbd04e8d..943113df 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -1,7 +1,11 @@ package network import ( + "bytes" "context" + "errors" + "io" + "net" "time" v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1" @@ -12,15 +16,15 @@ import ( "github.com/gatewayd-io/gatewayd/pool" "github.com/getsentry/sentry-go" "github.com/go-co-op/gocron" - "github.com/panjf2000/gnet/v2" "github.com/rs/zerolog" "go.opentelemetry.io/otel" ) type IProxy interface { - Connect(gconn gnet.Conn) *gerr.GatewayDError - Disconnect(gconn gnet.Conn) *gerr.GatewayDError - PassThrough(gconn gnet.Conn) *gerr.GatewayDError + Connect(conn net.Conn) *gerr.GatewayDError + Disconnect(conn net.Conn) *gerr.GatewayDError + PassThroughToServer(conn net.Conn) *gerr.GatewayDError + PassThroughToClient(conn net.Conn) *gerr.GatewayDError IsHealty(cl *Client) (*Client, *gerr.GatewayDError) IsExhausted() bool Shutdown() @@ -123,7 +127,7 @@ func NewProxy( // Connect maps a server connection from the available connection pool to a incoming connection. // It returns an error if the pool is exhausted. If the pool is elastic, it creates a new client // and maps it to the incoming connection. -func (pr *Proxy) Connect(gconn gnet.Conn) *gerr.GatewayDError { +func (pr *Proxy) Connect(conn net.Conn) *gerr.GatewayDError { _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "Connect") defer span.End() @@ -162,7 +166,7 @@ func (pr *Proxy) Connect(gconn gnet.Conn) *gerr.GatewayDError { span.RecordError(err) } - if err := pr.busyConnections.Put(gconn, client); err != nil { + if err := pr.busyConnections.Put(conn, client); err != nil { // This should never happen. span.RecordError(err) return err @@ -173,7 +177,7 @@ func (pr *Proxy) Connect(gconn gnet.Conn) *gerr.GatewayDError { fields := map[string]interface{}{ "function": "proxy.connect", "client": "unknown", - "server": RemoteAddr(gconn), + "server": RemoteAddr(conn), } if client.ID != "" { fields["client"] = client.ID[:7] @@ -198,11 +202,11 @@ func (pr *Proxy) Connect(gconn gnet.Conn) *gerr.GatewayDError { // Disconnect removes the client from the busy connection pool and tries to recycle // the server connection. -func (pr *Proxy) Disconnect(gconn gnet.Conn) *gerr.GatewayDError { +func (pr *Proxy) Disconnect(conn net.Conn) *gerr.GatewayDError { _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "Disconnect") defer span.End() - client := pr.busyConnections.Pop(gconn) + client := pr.busyConnections.Pop(conn) //nolint:nestif if client != nil { if client, ok := client.(*Client); ok { @@ -251,24 +255,20 @@ func (pr *Proxy) Disconnect(gconn gnet.Conn) *gerr.GatewayDError { return nil } -// PassThrough sends the data from the client to the server and vice versa. -func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { +// PassThrough sends the data from the client to the server. +func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError { _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "PassThrough") defer span.End() - // TODO: Handle bi-directional traffic - // Currently the passthrough is a one-way street from the client to the server, that is, - // the client can send data to the server and receive the response back, but the server - // cannot take initiative and send data to the client. So, there should be another event-loop - // that listens for data from the server and sends it to the client. + // Check if the proxy has a egress client for the incoming connection. var client *Client - if pr.busyConnections.Get(gconn) == nil { + if pr.busyConnections.Get(conn) == nil { span.RecordError(gerr.ErrClientNotFound) return gerr.ErrClientNotFound } // Get the client from the busy connection pool. - if cl, ok := pr.busyConnections.Get(gconn).(*Client); ok { + if cl, ok := pr.busyConnections.Get(conn).(*Client); ok { client = cl } else { span.RecordError(gerr.ErrCastFailed) @@ -276,17 +276,22 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { } span.AddEvent("Got the client from the busy connection pool") + if !client.IsConnected() || !pr.isConnectionHealthy(conn) { + return gerr.ErrClientNotConnected + } + // Receive the request from the client. - request, origErr := pr.receiveTrafficFromClient(gconn) + request, origErr := pr.receiveTrafficFromClient(conn) span.AddEvent("Received traffic from client") + // Run the OnTrafficFromClient hooks. pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), pr.pluginTimeout) defer cancel() - // Run the OnTrafficFromClient hooks. + result, err := pr.pluginRegistry.Run( pluginTimeoutCtx, trafficData( - gconn, + conn, client, []Field{ { @@ -302,6 +307,12 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { } span.AddEvent("Ran the OnTrafficFromClient hooks") + if errors.Is(origErr, io.EOF) { + // Client closed the connection. + span.AddEvent("Client closed the connection") + return gerr.ErrClientNotConnected + } + // If the hook wants to terminate the connection, do it. if pr.shouldTerminate(result) { if modResponse, modReceived := pr.getPluginModifiedResponse(result); modResponse != nil { @@ -311,7 +322,7 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { metrics.TotalTrafficBytes.Observe(float64(modReceived)) span.AddEvent("Terminating connection") - return pr.sendTrafficToClient(gconn, modResponse, modReceived) + return pr.sendTrafficToClient(conn, modResponse, modReceived) } span.RecordError(gerr.ErrHookTerminatedConnection) return gerr.ErrHookTerminatedConnection @@ -330,7 +341,7 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { _, err = pr.pluginRegistry.Run( pluginTimeoutCtx, trafficData( - gconn, + conn, client, []Field{ { @@ -346,36 +357,39 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { } span.AddEvent("Ran the OnTrafficToServer hooks") - // Receive the response from the server. - received, response, err := pr.receiveTrafficFromServer(client) - span.AddEvent("Received traffic from server") + return nil +} - // The connection to the server is closed, so we MUST reconnect, - // otherwise the client will be stuck. - // TODO: Fix bug in handling connection close - // See: https://github.com/gatewayd-io/gatewayd/issues/219 - if IsConnClosed(received, err) || IsConnTimedOut(err) { - pr.logger.Debug().Fields( - map[string]interface{}{ - "function": "proxy.passthrough", - "local": client.LocalAddr(), - "remote": client.RemoteAddr(), - }).Msg("Client disconnected") +// PassThroughToClient sends the data from the server to the client. +func (pr *Proxy) PassThroughToClient(conn net.Conn) *gerr.GatewayDError { + _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "PassThrough") + defer span.End() - client.Close() - client = NewClient(pr.ctx, pr.ClientConfig, pr.logger) - pr.busyConnections.Remove(gconn) - if err := pr.busyConnections.Put(gconn, client); err != nil { - span.RecordError(err) - // This should never happen - return err - } + var client *Client + if pr.busyConnections.Get(conn) == nil { + span.RecordError(gerr.ErrClientNotFound) + return gerr.ErrClientNotFound + } + + // Get the client from the busy connection pool. + if cl, ok := pr.busyConnections.Get(conn).(*Client); ok { + client = cl + } else { + span.RecordError(gerr.ErrCastFailed) + return gerr.ErrCastFailed + } + span.AddEvent("Got the client from the busy connection pool") + + if !client.IsConnected() || !pr.isConnectionHealthy(conn) { + return gerr.ErrClientNotConnected } + // Receive the response from the server. + received, response, err := pr.receiveTrafficFromServer(client) + span.AddEvent("Received traffic from server") + // If the response is empty, don't send anything, instead just close the ingress connection. - // TODO: Fix bug in handling connection close - // See: https://github.com/gatewayd-io/gatewayd/issues/219 - if received == 0 { + if received == 0 || err != nil { pr.logger.Debug().Fields( map[string]interface{}{ "function": "proxy.passthrough", @@ -387,17 +401,16 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { return err } + pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), pr.pluginTimeout) + defer cancel() + // Run the OnTrafficFromServer hooks. - result, err = pr.pluginRegistry.Run( + result, err := pr.pluginRegistry.Run( pluginTimeoutCtx, trafficData( - gconn, + conn, client, []Field{ - { - Name: "request", - Value: request, - }, { Name: "response", Value: response[:received], @@ -419,26 +432,22 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { } // Send the response to the client. - errVerdict := pr.sendTrafficToClient(gconn, response, received) + errVerdict := pr.sendTrafficToClient(conn, response, received) span.AddEvent("Sent traffic to client") // Run the OnTrafficToClient hooks. _, err = pr.pluginRegistry.Run( pluginTimeoutCtx, trafficData( - gconn, + conn, client, []Field{ - { - Name: "request", - Value: request, - }, { Name: "response", Value: response[:received], }, }, - err, + nil, ), v1.HookName_HOOK_NAME_ON_TRAFFIC_TO_CLIENT) if err != nil { @@ -501,8 +510,8 @@ func (pr *Proxy) Shutdown() { pr.logger.Debug().Msg("All available connections have been closed") pr.busyConnections.ForEach(func(key, value interface{}) bool { - if gconn, ok := key.(gnet.Conn); ok { - gconn.Close() + if conn, ok := key.(net.Conn); ok { + conn.Close() } if cl, ok := value.(*Client); ok { cl.Close() @@ -536,38 +545,60 @@ func (pr *Proxy) BusyConnections() []string { connections := make([]string, 0) pr.busyConnections.ForEach(func(key, _ interface{}) bool { - if gconn, ok := key.(gnet.Conn); ok { - connections = append(connections, RemoteAddr(gconn)) + if conn, ok := key.(net.Conn); ok { + connections = append(connections, RemoteAddr(conn)) } return true }) return connections } -// receiveTrafficFromClient is a function that receives data from the client. -func (pr *Proxy) receiveTrafficFromClient(gconn gnet.Conn) ([]byte, error) { +// receiveTrafficFromClient is a function that waits to receive data from the client. +func (pr *Proxy) receiveTrafficFromClient(conn net.Conn) ([]byte, error) { _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "receiveTrafficFromClient") defer span.End() // request contains the data from the client. - request, err := gconn.Next(-1) - if err != nil { - pr.logger.Error().Err(err).Msg("Error reading from client") - span.RecordError(err) + received := 0 + buffer := bytes.NewBuffer(nil) + for { + chunk := make([]byte, pr.ClientConfig.ReceiveChunkSize) + read, err := conn.Read(chunk) + if read == 0 || err != nil { + pr.logger.Debug().Err(err).Msg("Error reading from client") + span.RecordError(err) + + metrics.BytesReceivedFromClient.Observe(float64(read)) + metrics.TotalTrafficBytes.Observe(float64(read)) + + return chunk[:read], err + } + + received += read + buffer.Write(chunk[:read]) + + if received == 0 || received < pr.ClientConfig.ReceiveChunkSize { + break + } + + if !pr.isConnectionHealthy(conn) { + break + } } + + length := len(buffer.Bytes()) pr.logger.Debug().Fields( map[string]interface{}{ - "length": len(request), - "local": LocalAddr(gconn), - "remote": RemoteAddr(gconn), + "length": length, + "local": LocalAddr(conn), + "remote": RemoteAddr(conn), }, ).Msg("Received data from client") - - metrics.BytesReceivedFromClient.Observe(float64(len(request))) - metrics.TotalTrafficBytes.Observe(float64(len(request))) + metrics.BytesReceivedFromClient.Observe(float64(length)) + metrics.TotalTrafficBytes.Observe(float64(length)) //nolint:wrapcheck - return request, err + return buffer.Bytes(), nil } // sendTrafficToServer is a function that sends data to the server. @@ -575,6 +606,11 @@ func (pr *Proxy) sendTrafficToServer(client *Client, request []byte) (int, *gerr _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "sendTrafficToServer") defer span.End() + if len(request) == 0 { + pr.logger.Trace().Msg("Empty request") + return 0, nil + } + // Send the request to the server. sent, err := client.Send(request) if err != nil { @@ -620,30 +656,39 @@ func (pr *Proxy) receiveTrafficFromServer(client *Client) (int, []byte, *gerr.Ga // sendTrafficToClient is a function that sends data to the client. func (pr *Proxy) sendTrafficToClient( - gconn gnet.Conn, response []byte, received int, + conn net.Conn, response []byte, received int, ) *gerr.GatewayDError { _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "sendTrafficToClient") defer span.End() // Send the response to the client async. - origErr := gconn.AsyncWrite(response[:received], func(gconn gnet.Conn, err error) error { - pr.logger.Debug().Fields( - map[string]interface{}{ - "function": "proxy.passthrough", - "length": received, - "local": LocalAddr(gconn), - "remote": RemoteAddr(gconn), - }, - ).Msg("Sent data to client") - span.RecordError(err) - return err - }) - if origErr != nil { - pr.logger.Error().Err(origErr).Msg("Error writing to client") - span.RecordError(origErr) - return gerr.ErrServerSendFailed.Wrap(origErr) + sent := 0 + for { + if sent >= received { + break + } + + n, origErr := conn.Write(response[:received]) + if origErr != nil { + pr.logger.Error().Err(origErr).Msg("Error writing to client") + span.RecordError(origErr) + return gerr.ErrServerSendFailed.Wrap(origErr) + } + + sent += n } + pr.logger.Debug().Fields( + map[string]interface{}{ + "function": "proxy.passthrough", + "length": sent, + "local": LocalAddr(conn), + "remote": RemoteAddr(conn), + }, + ).Msg("Sent data to client") + + span.AddEvent("Sent data to client") + metrics.BytesSentToClient.Observe(float64(received)) metrics.TotalTrafficBytes.Observe(float64(received)) @@ -698,3 +743,17 @@ func (pr *Proxy) getPluginModifiedResponse(result map[string]interface{}) ([]byt return nil, 0 } + +func (pr *Proxy) isConnectionHealthy(conn net.Conn) bool { + if n, err := conn.Read([]byte{}); n == 0 && err != nil { + pr.logger.Debug().Fields( + map[string]interface{}{ + "remote": RemoteAddr(conn), + "local": LocalAddr(conn), + "reason": "read 0 bytes", + }).Msg("Connection to client is closed") + return false + } + + return true +} diff --git a/network/server.go b/network/server.go index cfba1da2..4ba918e9 100644 --- a/network/server.go +++ b/network/server.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "io" + "net" "os" "time" @@ -13,15 +13,13 @@ import ( gerr "github.com/gatewayd-io/gatewayd/errors" "github.com/gatewayd-io/gatewayd/metrics" "github.com/gatewayd-io/gatewayd/plugin" - "github.com/panjf2000/gnet/v2" "github.com/rs/zerolog" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" ) type Server struct { - gnet.BuiltinEventEngine - engine gnet.Engine + engine Engine proxy IProxy logger zerolog.Logger pluginRegistry *plugin.Registry @@ -30,7 +28,7 @@ type Server struct { Network string // tcp/udp/unix Address string - Options []gnet.Option + Options Option Status config.Status TickInterval time.Duration } @@ -38,7 +36,7 @@ type Server struct { // OnBoot is called when the server is booted. It calls the OnBooting and OnBooted hooks. // It also sets the status to running, which is used to determine if the server should be running // or shutdown. -func (s *Server) OnBoot(engine gnet.Engine) gnet.Action { +func (s *Server) OnBoot(engine Engine) Action { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnBoot") defer span.End() @@ -75,16 +73,16 @@ func (s *Server) OnBoot(engine gnet.Engine) gnet.Action { s.logger.Debug().Msg("GatewayD booted") - return gnet.None + return None } // OnOpen is called when a new connection is opened. It calls the OnOpening and OnOpened hooks. // It also checks if the server is at the soft or hard limit and closes the connection if it is. -func (s *Server) OnOpen(gconn gnet.Conn) ([]byte, gnet.Action) { +func (s *Server) OnOpen(conn net.Conn) ([]byte, Action) { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnOpen") defer span.End() - s.logger.Debug().Str("from", RemoteAddr(gconn)).Msg( + s.logger.Debug().Str("from", RemoteAddr(conn)).Msg( "GatewayD is opening a connection") pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.pluginTimeout) @@ -92,8 +90,8 @@ func (s *Server) OnOpen(gconn gnet.Conn) ([]byte, gnet.Action) { // Run the OnOpening hooks. onOpeningData := map[string]interface{}{ "client": map[string]interface{}{ - "local": LocalAddr(gconn), - "remote": RemoteAddr(gconn), + "local": LocalAddr(conn), + "remote": RemoteAddr(conn), }, } _, err := s.pluginRegistry.Run( @@ -107,24 +105,24 @@ func (s *Server) OnOpen(gconn gnet.Conn) ([]byte, gnet.Action) { // Use the proxy to connect to the backend. Close the connection if the pool is exhausted. // This effectively get a connection from the pool and puts both the incoming and the server // connections in the pool of the busy connections. - if err := s.proxy.Connect(gconn); err != nil { + if err := s.proxy.Connect(conn); err != nil { if errors.Is(err, gerr.ErrPoolExhausted) { span.RecordError(err) - return nil, gnet.Close + return nil, Close } // This should never happen. // TODO: Send error to client or retry connection s.logger.Error().Err(err).Msg("Failed to connect to proxy") span.RecordError(err) - return nil, gnet.None + return nil, None } // Run the OnOpened hooks. onOpenedData := map[string]interface{}{ "client": map[string]interface{}{ - "local": LocalAddr(gconn), - "remote": RemoteAddr(gconn), + "local": LocalAddr(conn), + "remote": RemoteAddr(conn), }, } _, err = s.pluginRegistry.Run( @@ -137,26 +135,27 @@ func (s *Server) OnOpen(gconn gnet.Conn) ([]byte, gnet.Action) { metrics.ClientConnections.Inc() - return nil, gnet.None + return nil, None } // OnClose is called when a connection is closed. It calls the OnClosing and OnClosed hooks. // It also recycles the connection back to the available connection pool, unless the pool // is elastic and reuse is disabled. -func (s *Server) OnClose(gconn gnet.Conn, err error) gnet.Action { +func (s *Server) OnClose(conn net.Conn, err error) Action { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnClose") defer span.End() - s.logger.Debug().Str("from", RemoteAddr(gconn)).Msg( + s.logger.Debug().Str("from", RemoteAddr(conn)).Msg( "GatewayD is closing a connection") + // Run the OnClosing hooks. pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.pluginTimeout) defer cancel() - // Run the OnClosing hooks. + data := map[string]interface{}{ "client": map[string]interface{}{ - "local": LocalAddr(gconn), - "remote": RemoteAddr(gconn), + "local": LocalAddr(conn), + "remote": RemoteAddr(conn), }, "error": "", } @@ -175,23 +174,30 @@ func (s *Server) OnClose(gconn gnet.Conn, err error) gnet.Action { // This is used to shutdown the server gracefully. if uint64(s.engine.CountConnections()) == 0 && s.Status == config.Stopped { span.AddEvent("Shutting down the server") - return gnet.Shutdown + return Shutdown } // Disconnect the connection from the proxy. This effectively removes the mapping between // the incoming and the server connections in the pool of the busy connections and either // recycles or disconnects the connections. - if err := s.proxy.Disconnect(gconn); err != nil { + if err := s.proxy.Disconnect(conn); err != nil { s.logger.Error().Err(err).Msg("Failed to disconnect the server connection") span.RecordError(err) - return gnet.Close + return Close + } + + // Close the incoming connection. + if err := conn.Close(); err != nil { + s.logger.Error().Err(err).Msg("Failed to close the incoming connection") + span.RecordError(err) + return Close } // Run the OnClosed hooks. data = map[string]interface{}{ "client": map[string]interface{}{ - "local": LocalAddr(gconn), - "remote": RemoteAddr(gconn), + "local": LocalAddr(conn), + "remote": RemoteAddr(conn), }, "error": "", } @@ -208,22 +214,23 @@ func (s *Server) OnClose(gconn gnet.Conn, err error) gnet.Action { metrics.ClientConnections.Dec() - return gnet.Close + return Close } // OnTraffic is called when data is received from the client. It calls the OnTraffic hooks. // It then passes the traffic to the proxied connection. -func (s *Server) OnTraffic(gconn gnet.Conn) gnet.Action { +func (s *Server) OnTraffic(conn net.Conn, stopConnection chan struct{}) Action { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnTraffic") defer span.End() + // Run the OnTraffic hooks. pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.pluginTimeout) defer cancel() - // Run the OnTraffic hooks. + onTrafficData := map[string]interface{}{ "client": map[string]interface{}{ - "local": LocalAddr(gconn), - "remote": RemoteAddr(gconn), + "local": LocalAddr(conn), + "remote": RemoteAddr(conn), }, } _, err := s.pluginRegistry.Run( @@ -234,33 +241,39 @@ func (s *Server) OnTraffic(gconn gnet.Conn) gnet.Action { } span.AddEvent("Ran the OnTraffic hooks") - // Pass the traffic from the client to server and vice versa. + // Pass the traffic from the client to server. // If there is an error, log it and close the connection. - if err := s.proxy.PassThrough(gconn); err != nil { - s.logger.Trace().Err(err).Msg("Failed to pass through traffic") - span.RecordError(err) - switch { - case errors.Is(err, gerr.ErrPoolExhausted), - errors.Is(err, gerr.ErrCastFailed), - errors.Is(err, gerr.ErrClientNotFound), - errors.Is(err, gerr.ErrClientNotConnected), - errors.Is(err, gerr.ErrClientSendFailed), - errors.Is(err, gerr.ErrClientReceiveFailed), - errors.Is(err, gerr.ErrHookTerminatedConnection), - errors.Is(err.Unwrap(), io.EOF): - // TODO: Fix bug in handling connection close - // See: https://github.com/gatewayd-io/gatewayd/issues/219 - return gnet.Close + go func(s *Server, conn net.Conn, stopConnection chan struct{}) { + for { + s.logger.Trace().Msg("Passing through traffic from client to server") + if err := s.proxy.PassThroughToServer(conn); err != nil { + s.logger.Trace().Err(err).Msg("Failed to pass through traffic") + span.RecordError(err) + stopConnection <- struct{}{} + break + } } - } - // Flush the connection to make sure all data is sent - gconn.Flush() + }(s, conn, stopConnection) + + // Pass the traffic from the server to client. + // If there is an error, log it and close the connection. + go func(s *Server, conn net.Conn, stopConnection chan struct{}) { + for { + s.logger.Debug().Msg("Passing through traffic from server to client") + if err := s.proxy.PassThroughToClient(conn); err != nil { + s.logger.Trace().Err(err).Msg("Failed to pass through traffic") + span.RecordError(err) + stopConnection <- struct{}{} + break + } + } + }(s, conn, stopConnection) - return gnet.None + return None } // OnShutdown is called when the server is shutting down. It calls the OnShutdown hooks. -func (s *Server) OnShutdown(gnet.Engine) { +func (s *Server) OnShutdown(Engine) { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnShutdown") defer span.End() @@ -287,7 +300,7 @@ func (s *Server) OnShutdown(gnet.Engine) { } // OnTick is called every TickInterval. It calls the OnTick hooks. -func (s *Server) OnTick() (time.Duration, gnet.Action) { +func (s *Server) OnTick() (time.Duration, Action) { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnTick") defer span.End() @@ -314,7 +327,7 @@ func (s *Server) OnTick() (time.Duration, gnet.Action) { // TickInterval is the interval at which the OnTick hooks are called. It can be adjusted // in the configuration file. - return s.TickInterval, gnet.None + return s.TickInterval, None } // Run starts the server and blocks until the server is stopped. It calls the OnRun hooks. @@ -334,7 +347,7 @@ func (s *Server) Run() error { pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.pluginTimeout) defer cancel() // Run the OnRun hooks. - // Since gnet.Run is blocking, we need to run OnRun before it. + // Since Run is blocking, we need to run OnRun before it. onRunData := map[string]interface{}{"address": addr} if err != nil && err.Unwrap() != nil { onRunData["error"] = err.OriginalError.Error() @@ -358,7 +371,7 @@ func (s *Server) Run() error { } // Start the server. - origErr := gnet.Run(s, s.Network+"://"+addr, s.Options...) + origErr := Run(s.Network, addr, s, s.Options) if origErr != nil { s.logger.Error().Err(origErr).Msg("Failed to start server") span.RecordError(origErr) @@ -400,7 +413,7 @@ func NewServer( ctx context.Context, network, address string, tickInterval time.Duration, - options []gnet.Option, + options Option, proxy IProxy, logger zerolog.Logger, pluginRegistry *plugin.Registry, diff --git a/network/utils.go b/network/utils.go index ff85f337..97fbf559 100644 --- a/network/utils.go +++ b/network/utils.go @@ -9,7 +9,6 @@ import ( "net" gerr "github.com/gatewayd-io/gatewayd/errors" - "github.com/panjf2000/gnet/v2" "github.com/rs/zerolog" ) @@ -52,19 +51,19 @@ func Resolve(network, address string, logger zerolog.Logger) (string, *gerr.Gate // trafficData creates the ingress/egress map for the traffic hooks. func trafficData( - gconn gnet.Conn, + conn net.Conn, client *Client, fields []Field, err interface{}, ) map[string]interface{} { - if gconn == nil || client == nil { + if conn == nil || client == nil { return nil } data := map[string]interface{}{ "client": map[string]interface{}{ - "local": LocalAddr(gconn), - "remote": RemoteAddr(gconn), + "local": LocalAddr(conn), + "remote": RemoteAddr(conn), }, "server": map[string]interface{}{ "local": client.LocalAddr(), @@ -126,17 +125,17 @@ func IsConnClosed(received int, err *gerr.GatewayDError) bool { } // LocalAddr returns the local address of the connection. -func LocalAddr(gconn gnet.Conn) string { - if gconn != nil && gconn.LocalAddr() != nil { - return gconn.LocalAddr().String() +func LocalAddr(conn net.Conn) string { + if conn != nil && conn.LocalAddr() != nil { + return conn.LocalAddr().String() } return "" } // RemoteAddr returns the remote address of the connection. -func RemoteAddr(gconn gnet.Conn) string { - if gconn != nil && gconn.RemoteAddr() != nil { - return gconn.RemoteAddr().String() +func RemoteAddr(conn net.Conn) string { + if conn != nil && conn.RemoteAddr() != nil { + return conn.RemoteAddr().String() } return "" }