diff --git a/cmd/run.go b/cmd/run.go index 4d7bb407..b676c03a 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -30,7 +30,6 @@ import ( usage "github.com/gatewayd-io/gatewayd/usagereport/v1" "github.com/getsentry/sentry-go" "github.com/go-co-op/gocron" - "github.com/panjf2000/gnet/v2" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rs/zerolog" @@ -633,30 +632,30 @@ var runCmd = &cobra.Command{ cfg.Network, cfg.Address, cfg.GetTickInterval(), - []gnet.Option{ - // Scheduling options - gnet.WithMulticore(cfg.MultiCore), - gnet.WithLockOSThread(cfg.LockOSThread), - // NumEventLoop overrides Multicore option. - // gnet.WithNumEventLoop(1), - - // Can be used to send keepalive messages to the client. - gnet.WithTicker(cfg.EnableTicker), - - // Internal event-loop load balancing options - gnet.WithLoadBalancing(cfg.GetLoadBalancer()), - - // Buffer options - gnet.WithReadBufferCap(cfg.ReadBufferCap), - gnet.WithWriteBufferCap(cfg.WriteBufferCap), - gnet.WithSocketRecvBuffer(cfg.SocketRecvBuffer), - gnet.WithSocketSendBuffer(cfg.SocketSendBuffer), - - // TCP options - gnet.WithReuseAddr(cfg.ReuseAddress), - gnet.WithReusePort(cfg.ReusePort), - gnet.WithTCPKeepAlive(cfg.TCPKeepAlive), - gnet.WithTCPNoDelay(cfg.GetTCPNoDelay()), + []network.Option{ + // // Scheduling options + // gnet.WithMulticore(cfg.MultiCore), + // gnet.WithLockOSThread(cfg.LockOSThread), + // // NumEventLoop overrides Multicore option. + // // gnet.WithNumEventLoop(1), + + // // Can be used to send keepalive messages to the client. + // gnet.WithTicker(cfg.EnableTicker), + + // // Internal event-loop load balancing options + // gnet.WithLoadBalancing(cfg.GetLoadBalancer()), + + // // Buffer options + // gnet.WithReadBufferCap(cfg.ReadBufferCap), + // gnet.WithWriteBufferCap(cfg.WriteBufferCap), + // gnet.WithSocketRecvBuffer(cfg.SocketRecvBuffer), + // gnet.WithSocketSendBuffer(cfg.SocketSendBuffer), + + // // TCP options + // gnet.WithReuseAddr(cfg.ReuseAddress), + // gnet.WithReusePort(cfg.ReusePort), + // gnet.WithTCPKeepAlive(cfg.TCPKeepAlive), + // gnet.WithTCPNoDelay(cfg.GetTCPNoDelay()), }, proxies[name], logger, diff --git a/gatewayd.yaml b/gatewayd.yaml index f0032f2d..15386af9 100644 --- a/gatewayd.yaml +++ b/gatewayd.yaml @@ -3,7 +3,7 @@ loggers: default: output: ["console"] # "stdout", "stderr", "syslog", "rsyslog" and "file" - level: "info" # panic, fatal, error, warn, info (default), debug, trace + level: "debug" # panic, fatal, error, warn, info (default), debug, trace noColor: False timeFormat: "unix" # unixms, unixmicro and unixnano consoleTimeFormat: "RFC3339" # Go time format string diff --git a/gatewayd_plugins.yaml b/gatewayd_plugins.yaml index 655b7335..afc52483 100644 --- a/gatewayd_plugins.yaml +++ b/gatewayd_plugins.yaml @@ -75,7 +75,7 @@ timeout: 30s # and should only be used if one only has a single database in their PostgreSQL instance. plugins: - name: gatewayd-plugin-cache - enabled: True + enabled: False localPath: ../gatewayd-plugin-cache/gatewayd-plugin-cache args: ["--log-level", "debug"] env: diff --git a/network/client.go b/network/client.go index 7526c0d1..3074b778 100644 --- a/network/client.go +++ b/network/client.go @@ -146,12 +146,23 @@ func (c *Client) Send(data []byte) (int, *gerr.GatewayDError) { _, span := otel.Tracer(config.TracerName).Start(c.ctx, "Send") defer span.End() - sent, err := c.Conn.Write(data) - if err != nil { - c.logger.Error().Err(err).Msg("Couldn't send data to the server") - span.RecordError(err) - return 0, gerr.ErrClientSendFailed.Wrap(err) + sent := 0 + received := len(data) + for { + if sent >= received { + break + } + + n, err := c.Conn.Write(data) + if err != nil { + c.logger.Error().Err(err).Msg("Couldn't send data to the server") + span.RecordError(err) + return 0, gerr.ErrClientSendFailed.Wrap(err) + } + + sent += n } + c.logger.Debug().Fields( map[string]interface{}{ "length": sent, diff --git a/network/engine.go b/network/engine.go new file mode 100644 index 00000000..68c0a5e4 --- /dev/null +++ b/network/engine.go @@ -0,0 +1,176 @@ +package network + +import ( + "context" + "net" + "strconv" + "time" + + "github.com/rs/zerolog" +) + +type Option struct { + Multicore bool + NumEventLoop int + // LB LoadBalancing + ReuseAddr bool + ReusePort bool + MulticastInterfaceIndex int + ReadBufferCap int + WriteBufferCap int + LockOSThread bool + Ticker bool + TCPKeepAlive time.Duration + TCPNoDelay TCPSocketOpt + SocketRecvBuffer int + SocketSendBuffer int + LogPath string + LogLevel zerolog.Level + Logger zerolog.Logger +} + +type Action int + +const ( + None Action = iota + Close + Shutdown +) + +type TCPSocketOpt int + +const ( + TCPNoDelay TCPSocketOpt = iota + TCPDelay +) + +type Engine struct { + listener net.Listener + host string + port int + connections map[string]*net.Conn + bufferSize int + handler EventHandler + stopChannel chan struct{} +} + +func (engine *Engine) CountConnections() int { + return len(engine.connections) +} + +func (engine *Engine) Stop(ctx context.Context) error { + engine.stopChannel <- struct{}{} + return nil +} + +type ( + EventHandler interface { + OnBoot(eng Engine) (action Action) + OnShutdown(eng Engine) + OnOpen(c net.Conn) (out []byte, action Action) + OnClose(c net.Conn, err error) (action Action) + OnTraffic(c net.Conn) (action Action) + OnTick() (delay time.Duration, action Action) + } + + BuiltinEventEngine struct{} +) + +// Create a new TCP server. +func Run(network, address string, server *Server, opts ...Option) error { + engine := Engine{ + connections: make(map[string]*net.Conn), + stopChannel: make(chan struct{}), + } + + if action := server.OnBoot(engine); action != None { + return nil + } + server.logger.Debug().Msg("Server booted") + + if ln, err := net.Listen(network, address); err != nil { + server.logger.Error().Err(err).Msg("Failed to listen") + } else { + engine.listener = ln + } + defer engine.listener.Close() + server.logger.Debug().Str("address", engine.listener.Addr().String()).Msg("Server listening") + + if engine.listener == nil { + server.logger.Error().Msg("Listener is nil") + return nil + } + server.logger.Debug().Msg("Server started") + + if host, port, err := net.SplitHostPort(engine.listener.Addr().String()); err != nil { + server.logger.Error().Err(err).Msg("Failed to split host and port") + return err + } else { + engine.host = host + if engine.port, err = strconv.Atoi(port); err != nil { + server.logger.Error().Err(err).Msg("Failed to convert port to integer") + return err + } + } + + for { + // if <-engine.stopChannel == struct{}{} { + // server.logger.Debug().Msg("Server stopped") + // break + // } + server.logger.Debug().Msg("Server tick") + + conn, err := engine.listener.Accept() + if err != nil { + server.logger.Error().Err(err).Msg("Failed to accept connection") + return err + } + + server.logger.Debug().Str("address", conn.RemoteAddr().String()).Msg("Connection accepted") + + if out, action := server.OnOpen(conn); action != None { + conn.Write(out) + conn.Close() + if action == Shutdown { + server.logger.Debug().Str("address", conn.RemoteAddr().String()).Msg( + "Connection closed") + return nil + } + } + server.logger.Debug().Str("address", conn.RemoteAddr().String()).Msg("Connection accepted") + + // engine.connections[conn.RemoteAddr().String()] = &conn + go func(server *Server, conn net.Conn) { + for { + // if n, err := conn.Read([]byte{}); n == 0 && err != nil { + // return + // } + // if action := server.OnTraffic(conn); action == Close { + // if action := server.OnClose(conn, err); action == Close { + // conn.Close() + // break + // } + // } + if action := server.OnTraffic(conn); action == Close { + if action := server.OnClose(conn, err); action == Close { + // FIXME: this should be handled by the server + server.logger.Debug().Str("address", conn.RemoteAddr().String()).Msg( + "Connection closed") + conn.Close() + return + } + } + time.Sleep(100 * time.Millisecond) + } + }(server, conn) + + // defer delete(engine.connections, conn.RemoteAddr().String()) + + // if duration, action := server.OnTick(); action == Shutdown { + // return nil + // } else if duration > 0 { + // time.Sleep(duration) + // } + } + // engine.handler.OnShutdown(engine) +} diff --git a/network/proxy.go b/network/proxy.go index fbd04e8d..6493ae25 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -1,7 +1,11 @@ package network import ( + "bytes" "context" + "errors" + "io" + "net" "time" v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1" @@ -12,15 +16,17 @@ import ( "github.com/gatewayd-io/gatewayd/pool" "github.com/getsentry/sentry-go" "github.com/go-co-op/gocron" - "github.com/panjf2000/gnet/v2" "github.com/rs/zerolog" "go.opentelemetry.io/otel" ) type IProxy interface { - Connect(gconn gnet.Conn) *gerr.GatewayDError - Disconnect(gconn gnet.Conn) *gerr.GatewayDError - PassThrough(gconn gnet.Conn) *gerr.GatewayDError + Connect(conn net.Conn) *gerr.GatewayDError + Disconnect(conn net.Conn) *gerr.GatewayDError + PassThroughToServer(conn net.Conn) *gerr.GatewayDError + PassThroughToClient(conn net.Conn) *gerr.GatewayDError + // TODO: Close both connections simultaneously. + // Close() *gerr.GatewayDError IsHealty(cl *Client) (*Client, *gerr.GatewayDError) IsExhausted() bool Shutdown() @@ -123,7 +129,7 @@ func NewProxy( // Connect maps a server connection from the available connection pool to a incoming connection. // It returns an error if the pool is exhausted. If the pool is elastic, it creates a new client // and maps it to the incoming connection. -func (pr *Proxy) Connect(gconn gnet.Conn) *gerr.GatewayDError { +func (pr *Proxy) Connect(conn net.Conn) *gerr.GatewayDError { _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "Connect") defer span.End() @@ -162,7 +168,7 @@ func (pr *Proxy) Connect(gconn gnet.Conn) *gerr.GatewayDError { span.RecordError(err) } - if err := pr.busyConnections.Put(gconn, client); err != nil { + if err := pr.busyConnections.Put(conn, client); err != nil { // This should never happen. span.RecordError(err) return err @@ -173,7 +179,7 @@ func (pr *Proxy) Connect(gconn gnet.Conn) *gerr.GatewayDError { fields := map[string]interface{}{ "function": "proxy.connect", "client": "unknown", - "server": RemoteAddr(gconn), + "server": RemoteAddr(conn), } if client.ID != "" { fields["client"] = client.ID[:7] @@ -198,11 +204,11 @@ func (pr *Proxy) Connect(gconn gnet.Conn) *gerr.GatewayDError { // Disconnect removes the client from the busy connection pool and tries to recycle // the server connection. -func (pr *Proxy) Disconnect(gconn gnet.Conn) *gerr.GatewayDError { +func (pr *Proxy) Disconnect(conn net.Conn) *gerr.GatewayDError { _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "Disconnect") defer span.End() - client := pr.busyConnections.Pop(gconn) + client := pr.busyConnections.Pop(conn) //nolint:nestif if client != nil { if client, ok := client.(*Client); ok { @@ -251,24 +257,20 @@ func (pr *Proxy) Disconnect(gconn gnet.Conn) *gerr.GatewayDError { return nil } -// PassThrough sends the data from the client to the server and vice versa. -func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { +// PassThrough sends the data from the client to the server. +func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError { _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "PassThrough") defer span.End() - // TODO: Handle bi-directional traffic - // Currently the passthrough is a one-way street from the client to the server, that is, - // the client can send data to the server and receive the response back, but the server - // cannot take initiative and send data to the client. So, there should be another event-loop - // that listens for data from the server and sends it to the client. + // Check if the proxy has a egress client for the incoming connection. var client *Client - if pr.busyConnections.Get(gconn) == nil { + if pr.busyConnections.Get(conn) == nil { span.RecordError(gerr.ErrClientNotFound) return gerr.ErrClientNotFound } // Get the client from the busy connection pool. - if cl, ok := pr.busyConnections.Get(gconn).(*Client); ok { + if cl, ok := pr.busyConnections.Get(conn).(*Client); ok { client = cl } else { span.RecordError(gerr.ErrCastFailed) @@ -276,17 +278,22 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { } span.AddEvent("Got the client from the busy connection pool") + if !client.IsConnected() || !pr.isConnectionHealthy(conn) { + return gerr.ErrClientNotConnected + } + // Receive the request from the client. - request, origErr := pr.receiveTrafficFromClient(gconn) + request, origErr := pr.receiveTrafficFromClient(conn) span.AddEvent("Received traffic from client") + // Run the OnTrafficFromClient hooks. pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), pr.pluginTimeout) defer cancel() - // Run the OnTrafficFromClient hooks. + result, err := pr.pluginRegistry.Run( pluginTimeoutCtx, trafficData( - gconn, + conn, client, []Field{ { @@ -302,6 +309,12 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { } span.AddEvent("Ran the OnTrafficFromClient hooks") + if errors.Is(origErr, io.EOF) { + // Client closed the connection. + span.AddEvent("Client closed the connection") + return gerr.ErrClientNotConnected + } + // If the hook wants to terminate the connection, do it. if pr.shouldTerminate(result) { if modResponse, modReceived := pr.getPluginModifiedResponse(result); modResponse != nil { @@ -311,7 +324,7 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { metrics.TotalTrafficBytes.Observe(float64(modReceived)) span.AddEvent("Terminating connection") - return pr.sendTrafficToClient(gconn, modResponse, modReceived) + return pr.sendTrafficToClient(conn, modResponse, modReceived) } span.RecordError(gerr.ErrHookTerminatedConnection) return gerr.ErrHookTerminatedConnection @@ -330,7 +343,7 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { _, err = pr.pluginRegistry.Run( pluginTimeoutCtx, trafficData( - gconn, + conn, client, []Field{ { @@ -346,6 +359,33 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { } span.AddEvent("Ran the OnTrafficToServer hooks") + return nil +} + +// PassThroughToClient sends the data from the server to the client. +func (pr *Proxy) PassThroughToClient(conn net.Conn) *gerr.GatewayDError { + _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "PassThrough") + defer span.End() + + var client *Client + if pr.busyConnections.Get(conn) == nil { + span.RecordError(gerr.ErrClientNotFound) + return gerr.ErrClientNotFound + } + + // Get the client from the busy connection pool. + if cl, ok := pr.busyConnections.Get(conn).(*Client); ok { + client = cl + } else { + span.RecordError(gerr.ErrCastFailed) + return gerr.ErrCastFailed + } + span.AddEvent("Got the client from the busy connection pool") + + if !client.IsConnected() || !pr.isConnectionHealthy(conn) { + return gerr.ErrClientNotConnected + } + // Receive the response from the server. received, response, err := pr.receiveTrafficFromServer(client) span.AddEvent("Received traffic from server") @@ -354,28 +394,27 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { // otherwise the client will be stuck. // TODO: Fix bug in handling connection close // See: https://github.com/gatewayd-io/gatewayd/issues/219 - if IsConnClosed(received, err) || IsConnTimedOut(err) { - pr.logger.Debug().Fields( - map[string]interface{}{ - "function": "proxy.passthrough", - "local": client.LocalAddr(), - "remote": client.RemoteAddr(), - }).Msg("Client disconnected") - - client.Close() - client = NewClient(pr.ctx, pr.ClientConfig, pr.logger) - pr.busyConnections.Remove(gconn) - if err := pr.busyConnections.Put(gconn, client); err != nil { - span.RecordError(err) - // This should never happen - return err - } - } + // if IsConnClosed(received, err) || IsConnTimedOut(err) { + // pr.logger.Debug().Fields( + // map[string]interface{}{ + // "function": "proxy.passthrough", + // "local": client.LocalAddr(), + // "remote": client.RemoteAddr(), + // }).Msg("Client disconnected") + // client.Close() + // client = NewClient(pr.ctx, pr.ClientConfig, pr.logger) + // pr.busyConnections.Remove(conn) + // if err := pr.busyConnections.Put(conn, client); err != nil { + // span.RecordError(err) + // // This should never happen + // return err + // } + // } // If the response is empty, don't send anything, instead just close the ingress connection. // TODO: Fix bug in handling connection close // See: https://github.com/gatewayd-io/gatewayd/issues/219 - if received == 0 { + if received == 0 || err != nil { pr.logger.Debug().Fields( map[string]interface{}{ "function": "proxy.passthrough", @@ -387,17 +426,20 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { return err } + pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), pr.pluginTimeout) + defer cancel() + // Run the OnTrafficFromServer hooks. - result, err = pr.pluginRegistry.Run( + result, err := pr.pluginRegistry.Run( pluginTimeoutCtx, trafficData( - gconn, + conn, client, []Field{ - { - Name: "request", - Value: request, - }, + // { + // Name: "request", + // Value: request, + // }, { Name: "response", Value: response[:received], @@ -419,26 +461,26 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { } // Send the response to the client. - errVerdict := pr.sendTrafficToClient(gconn, response, received) + errVerdict := pr.sendTrafficToClient(conn, response, received) span.AddEvent("Sent traffic to client") // Run the OnTrafficToClient hooks. _, err = pr.pluginRegistry.Run( pluginTimeoutCtx, trafficData( - gconn, + conn, client, []Field{ - { - Name: "request", - Value: request, - }, + // { + // Name: "request", + // Value: request, + // }, { Name: "response", Value: response[:received], }, }, - err, + nil, ), v1.HookName_HOOK_NAME_ON_TRAFFIC_TO_CLIENT) if err != nil { @@ -501,8 +543,8 @@ func (pr *Proxy) Shutdown() { pr.logger.Debug().Msg("All available connections have been closed") pr.busyConnections.ForEach(func(key, value interface{}) bool { - if gconn, ok := key.(gnet.Conn); ok { - gconn.Close() + if conn, ok := key.(net.Conn); ok { + conn.Close() } if cl, ok := value.(*Client); ok { cl.Close() @@ -536,38 +578,62 @@ func (pr *Proxy) BusyConnections() []string { connections := make([]string, 0) pr.busyConnections.ForEach(func(key, _ interface{}) bool { - if gconn, ok := key.(gnet.Conn); ok { - connections = append(connections, RemoteAddr(gconn)) + if conn, ok := key.(net.Conn); ok { + connections = append(connections, RemoteAddr(conn)) } return true }) return connections } -// receiveTrafficFromClient is a function that receives data from the client. -func (pr *Proxy) receiveTrafficFromClient(gconn gnet.Conn) ([]byte, error) { +// receiveTrafficFromClient is a function that waits to receive data from the client. +func (pr *Proxy) receiveTrafficFromClient(conn net.Conn) ([]byte, error) { _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "receiveTrafficFromClient") defer span.End() // request contains the data from the client. - request, err := gconn.Next(-1) - if err != nil { - pr.logger.Error().Err(err).Msg("Error reading from client") - span.RecordError(err) + received := 0 + buffer := bytes.NewBuffer(nil) + for { + // TODO: Make this configurable: Server.ReceiveChunkSize + chunk := make([]byte, pr.ClientConfig.ReceiveChunkSize) + read, err := conn.Read(chunk) + if read == 0 || err != nil { + pr.logger.Debug().Err(err).Msg("Error reading from client") + span.RecordError(err) + + metrics.BytesReceivedFromClient.Observe(float64(read)) + metrics.TotalTrafficBytes.Observe(float64(read)) + + return chunk[:read], err + } + + received += read + buffer.Write(chunk[:read]) + + // TODO: Make this configurable: Server.ReceiveChunkSize + if received == 0 || received < pr.ClientConfig.ReceiveChunkSize { + break + } + + if !pr.isConnectionHealthy(conn) { + break + } } + + length := len(buffer.Bytes()) pr.logger.Debug().Fields( map[string]interface{}{ - "length": len(request), - "local": LocalAddr(gconn), - "remote": RemoteAddr(gconn), + "length": length, + "local": LocalAddr(conn), + "remote": RemoteAddr(conn), }, ).Msg("Received data from client") - - metrics.BytesReceivedFromClient.Observe(float64(len(request))) - metrics.TotalTrafficBytes.Observe(float64(len(request))) + metrics.BytesReceivedFromClient.Observe(float64(length)) + metrics.TotalTrafficBytes.Observe(float64(length)) //nolint:wrapcheck - return request, err + return buffer.Bytes(), nil } // sendTrafficToServer is a function that sends data to the server. @@ -575,6 +641,11 @@ func (pr *Proxy) sendTrafficToServer(client *Client, request []byte) (int, *gerr _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "sendTrafficToServer") defer span.End() + if len(request) == 0 { + pr.logger.Trace().Msg("Empty request") + return 0, nil + } + // Send the request to the server. sent, err := client.Send(request) if err != nil { @@ -620,30 +691,39 @@ func (pr *Proxy) receiveTrafficFromServer(client *Client) (int, []byte, *gerr.Ga // sendTrafficToClient is a function that sends data to the client. func (pr *Proxy) sendTrafficToClient( - gconn gnet.Conn, response []byte, received int, + conn net.Conn, response []byte, received int, ) *gerr.GatewayDError { _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "sendTrafficToClient") defer span.End() // Send the response to the client async. - origErr := gconn.AsyncWrite(response[:received], func(gconn gnet.Conn, err error) error { - pr.logger.Debug().Fields( - map[string]interface{}{ - "function": "proxy.passthrough", - "length": received, - "local": LocalAddr(gconn), - "remote": RemoteAddr(gconn), - }, - ).Msg("Sent data to client") - span.RecordError(err) - return err - }) - if origErr != nil { - pr.logger.Error().Err(origErr).Msg("Error writing to client") - span.RecordError(origErr) - return gerr.ErrServerSendFailed.Wrap(origErr) + sent := 0 + for { + if sent >= received { + break + } + + n, origErr := conn.Write(response[:received]) + if origErr != nil { + pr.logger.Error().Err(origErr).Msg("Error writing to client") + span.RecordError(origErr) + return gerr.ErrServerSendFailed.Wrap(origErr) + } + + sent += n } + pr.logger.Debug().Fields( + map[string]interface{}{ + "function": "proxy.passthrough", + "length": sent, + "local": LocalAddr(conn), + "remote": RemoteAddr(conn), + }, + ).Msg("Sent data to client") + + span.AddEvent("Sent data to client") + metrics.BytesSentToClient.Observe(float64(received)) metrics.TotalTrafficBytes.Observe(float64(received)) @@ -698,3 +778,17 @@ func (pr *Proxy) getPluginModifiedResponse(result map[string]interface{}) ([]byt return nil, 0 } + +func (pr *Proxy) isConnectionHealthy(conn net.Conn) bool { + if n, err := conn.Read([]byte{}); n == 0 && err != nil { + pr.logger.Debug().Fields( + map[string]interface{}{ + "remote": RemoteAddr(conn), + "local": LocalAddr(conn), + "reason": "read 0 bytes", + }).Msg("Connection to client is closed") + return false + } + + return true +} diff --git a/network/server.go b/network/server.go index cfba1da2..eeca6a37 100644 --- a/network/server.go +++ b/network/server.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "io" + "net" "os" "time" @@ -13,15 +13,14 @@ import ( gerr "github.com/gatewayd-io/gatewayd/errors" "github.com/gatewayd-io/gatewayd/metrics" "github.com/gatewayd-io/gatewayd/plugin" - "github.com/panjf2000/gnet/v2" "github.com/rs/zerolog" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" ) type Server struct { - gnet.BuiltinEventEngine - engine gnet.Engine + BuiltinEventEngine + engine Engine proxy IProxy logger zerolog.Logger pluginRegistry *plugin.Registry @@ -30,7 +29,7 @@ type Server struct { Network string // tcp/udp/unix Address string - Options []gnet.Option + Options []Option Status config.Status TickInterval time.Duration } @@ -38,7 +37,7 @@ type Server struct { // OnBoot is called when the server is booted. It calls the OnBooting and OnBooted hooks. // It also sets the status to running, which is used to determine if the server should be running // or shutdown. -func (s *Server) OnBoot(engine gnet.Engine) gnet.Action { +func (s *Server) OnBoot(engine Engine) Action { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnBoot") defer span.End() @@ -75,16 +74,16 @@ func (s *Server) OnBoot(engine gnet.Engine) gnet.Action { s.logger.Debug().Msg("GatewayD booted") - return gnet.None + return None } // OnOpen is called when a new connection is opened. It calls the OnOpening and OnOpened hooks. // It also checks if the server is at the soft or hard limit and closes the connection if it is. -func (s *Server) OnOpen(gconn gnet.Conn) ([]byte, gnet.Action) { +func (s *Server) OnOpen(conn net.Conn) ([]byte, Action) { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnOpen") defer span.End() - s.logger.Debug().Str("from", RemoteAddr(gconn)).Msg( + s.logger.Debug().Str("from", RemoteAddr(conn)).Msg( "GatewayD is opening a connection") pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.pluginTimeout) @@ -92,8 +91,8 @@ func (s *Server) OnOpen(gconn gnet.Conn) ([]byte, gnet.Action) { // Run the OnOpening hooks. onOpeningData := map[string]interface{}{ "client": map[string]interface{}{ - "local": LocalAddr(gconn), - "remote": RemoteAddr(gconn), + "local": LocalAddr(conn), + "remote": RemoteAddr(conn), }, } _, err := s.pluginRegistry.Run( @@ -107,24 +106,24 @@ func (s *Server) OnOpen(gconn gnet.Conn) ([]byte, gnet.Action) { // Use the proxy to connect to the backend. Close the connection if the pool is exhausted. // This effectively get a connection from the pool and puts both the incoming and the server // connections in the pool of the busy connections. - if err := s.proxy.Connect(gconn); err != nil { + if err := s.proxy.Connect(conn); err != nil { if errors.Is(err, gerr.ErrPoolExhausted) { span.RecordError(err) - return nil, gnet.Close + return nil, Close } // This should never happen. // TODO: Send error to client or retry connection s.logger.Error().Err(err).Msg("Failed to connect to proxy") span.RecordError(err) - return nil, gnet.None + return nil, None } // Run the OnOpened hooks. onOpenedData := map[string]interface{}{ "client": map[string]interface{}{ - "local": LocalAddr(gconn), - "remote": RemoteAddr(gconn), + "local": LocalAddr(conn), + "remote": RemoteAddr(conn), }, } _, err = s.pluginRegistry.Run( @@ -137,26 +136,27 @@ func (s *Server) OnOpen(gconn gnet.Conn) ([]byte, gnet.Action) { metrics.ClientConnections.Inc() - return nil, gnet.None + return nil, None } // OnClose is called when a connection is closed. It calls the OnClosing and OnClosed hooks. // It also recycles the connection back to the available connection pool, unless the pool // is elastic and reuse is disabled. -func (s *Server) OnClose(gconn gnet.Conn, err error) gnet.Action { +func (s *Server) OnClose(conn net.Conn, err error) Action { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnClose") defer span.End() - s.logger.Debug().Str("from", RemoteAddr(gconn)).Msg( + s.logger.Debug().Str("from", RemoteAddr(conn)).Msg( "GatewayD is closing a connection") + // Run the OnClosing hooks. pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.pluginTimeout) defer cancel() - // Run the OnClosing hooks. + data := map[string]interface{}{ "client": map[string]interface{}{ - "local": LocalAddr(gconn), - "remote": RemoteAddr(gconn), + "local": LocalAddr(conn), + "remote": RemoteAddr(conn), }, "error": "", } @@ -175,23 +175,23 @@ func (s *Server) OnClose(gconn gnet.Conn, err error) gnet.Action { // This is used to shutdown the server gracefully. if uint64(s.engine.CountConnections()) == 0 && s.Status == config.Stopped { span.AddEvent("Shutting down the server") - return gnet.Shutdown + return Shutdown } // Disconnect the connection from the proxy. This effectively removes the mapping between // the incoming and the server connections in the pool of the busy connections and either // recycles or disconnects the connections. - if err := s.proxy.Disconnect(gconn); err != nil { + if err := s.proxy.Disconnect(conn); err != nil { s.logger.Error().Err(err).Msg("Failed to disconnect the server connection") span.RecordError(err) - return gnet.Close + return Close } // Run the OnClosed hooks. data = map[string]interface{}{ "client": map[string]interface{}{ - "local": LocalAddr(gconn), - "remote": RemoteAddr(gconn), + "local": LocalAddr(conn), + "remote": RemoteAddr(conn), }, "error": "", } @@ -208,22 +208,23 @@ func (s *Server) OnClose(gconn gnet.Conn, err error) gnet.Action { metrics.ClientConnections.Dec() - return gnet.Close + return Close } // OnTraffic is called when data is received from the client. It calls the OnTraffic hooks. // It then passes the traffic to the proxied connection. -func (s *Server) OnTraffic(gconn gnet.Conn) gnet.Action { +func (s *Server) OnTraffic(conn net.Conn) Action { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnTraffic") defer span.End() + // Run the OnTraffic hooks. pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.pluginTimeout) defer cancel() - // Run the OnTraffic hooks. + onTrafficData := map[string]interface{}{ "client": map[string]interface{}{ - "local": LocalAddr(gconn), - "remote": RemoteAddr(gconn), + "local": LocalAddr(conn), + "remote": RemoteAddr(conn), }, } _, err := s.pluginRegistry.Run( @@ -234,33 +235,67 @@ func (s *Server) OnTraffic(gconn gnet.Conn) gnet.Action { } span.AddEvent("Ran the OnTraffic hooks") - // Pass the traffic from the client to server and vice versa. + // Pass the traffic from the client to server. // If there is an error, log it and close the connection. - if err := s.proxy.PassThrough(gconn); err != nil { - s.logger.Trace().Err(err).Msg("Failed to pass through traffic") - span.RecordError(err) - switch { - case errors.Is(err, gerr.ErrPoolExhausted), - errors.Is(err, gerr.ErrCastFailed), - errors.Is(err, gerr.ErrClientNotFound), - errors.Is(err, gerr.ErrClientNotConnected), - errors.Is(err, gerr.ErrClientSendFailed), - errors.Is(err, gerr.ErrClientReceiveFailed), - errors.Is(err, gerr.ErrHookTerminatedConnection), - errors.Is(err.Unwrap(), io.EOF): - // TODO: Fix bug in handling connection close - // See: https://github.com/gatewayd-io/gatewayd/issues/219 - return gnet.Close + go func(s *Server, conn net.Conn) { + if err := s.proxy.PassThroughToServer(conn); err != nil { + s.logger.Trace().Err(err).Msg("Failed to pass through traffic") + span.RecordError(err) + + // switch { + // case errors.Is(err, gerr.ErrPoolExhausted), + // errors.Is(err, gerr.ErrCastFailed), + // errors.Is(err, gerr.ErrClientNotFound), + // errors.Is(err, gerr.ErrClientNotConnected), + // errors.Is(err, gerr.ErrClientSendFailed), + // errors.Is(err, gerr.ErrClientReceiveFailed), + // errors.Is(err, gerr.ErrHookTerminatedConnection), + // errors.Is(err.Unwrap(), io.EOF): + // // TODO: Fix bug in handling connection close + // // See: https://github.com/gatewayd-io/gatewayd/issues/219 + // // if err := conn.Close(); err != nil { + // // s.logger.Error().Err(err).Msg("Failed to close connection") + // // span.RecordError(err) + // // } + // return + // } } - } - // Flush the connection to make sure all data is sent - gconn.Flush() + return + }(s, conn) + + // Pass the traffic from the server to client. + // If there is an error, log it and close the connection. + go func(s *Server, conn net.Conn) { + if err := s.proxy.PassThroughToClient(conn); err != nil { + s.logger.Trace().Err(err).Msg("Failed to pass through traffic") + span.RecordError(err) + + // switch { + // case errors.Is(err, gerr.ErrPoolExhausted), + // errors.Is(err, gerr.ErrCastFailed), + // errors.Is(err, gerr.ErrClientNotFound), + // errors.Is(err, gerr.ErrClientNotConnected), + // errors.Is(err, gerr.ErrClientSendFailed), + // errors.Is(err, gerr.ErrClientReceiveFailed), + // errors.Is(err, gerr.ErrHookTerminatedConnection), + // errors.Is(err.Unwrap(), io.EOF): + // // TODO: Fix bug in handling connection close + // // See: https://github.com/gatewayd-io/gatewayd/issues/219 + // if err := conn.Close(); err != nil { + // s.logger.Error().Err(err).Msg("Failed to close connection") + // span.RecordError(err) + // } + // return + // } + } + return + }(s, conn) - return gnet.None + return None } // OnShutdown is called when the server is shutting down. It calls the OnShutdown hooks. -func (s *Server) OnShutdown(gnet.Engine) { +func (s *Server) OnShutdown(Engine) { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnShutdown") defer span.End() @@ -287,7 +322,7 @@ func (s *Server) OnShutdown(gnet.Engine) { } // OnTick is called every TickInterval. It calls the OnTick hooks. -func (s *Server) OnTick() (time.Duration, gnet.Action) { +func (s *Server) OnTick() (time.Duration, Action) { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnTick") defer span.End() @@ -314,7 +349,7 @@ func (s *Server) OnTick() (time.Duration, gnet.Action) { // TickInterval is the interval at which the OnTick hooks are called. It can be adjusted // in the configuration file. - return s.TickInterval, gnet.None + return s.TickInterval, None } // Run starts the server and blocks until the server is stopped. It calls the OnRun hooks. @@ -334,7 +369,7 @@ func (s *Server) Run() error { pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.pluginTimeout) defer cancel() // Run the OnRun hooks. - // Since gnet.Run is blocking, we need to run OnRun before it. + // Since Run is blocking, we need to run OnRun before it. onRunData := map[string]interface{}{"address": addr} if err != nil && err.Unwrap() != nil { onRunData["error"] = err.OriginalError.Error() @@ -358,7 +393,7 @@ func (s *Server) Run() error { } // Start the server. - origErr := gnet.Run(s, s.Network+"://"+addr, s.Options...) + origErr := Run(s.Network, addr, s, s.Options...) if origErr != nil { s.logger.Error().Err(origErr).Msg("Failed to start server") span.RecordError(origErr) @@ -400,7 +435,7 @@ func NewServer( ctx context.Context, network, address string, tickInterval time.Duration, - options []gnet.Option, + options []Option, proxy IProxy, logger zerolog.Logger, pluginRegistry *plugin.Registry, diff --git a/network/utils.go b/network/utils.go index ff85f337..97fbf559 100644 --- a/network/utils.go +++ b/network/utils.go @@ -9,7 +9,6 @@ import ( "net" gerr "github.com/gatewayd-io/gatewayd/errors" - "github.com/panjf2000/gnet/v2" "github.com/rs/zerolog" ) @@ -52,19 +51,19 @@ func Resolve(network, address string, logger zerolog.Logger) (string, *gerr.Gate // trafficData creates the ingress/egress map for the traffic hooks. func trafficData( - gconn gnet.Conn, + conn net.Conn, client *Client, fields []Field, err interface{}, ) map[string]interface{} { - if gconn == nil || client == nil { + if conn == nil || client == nil { return nil } data := map[string]interface{}{ "client": map[string]interface{}{ - "local": LocalAddr(gconn), - "remote": RemoteAddr(gconn), + "local": LocalAddr(conn), + "remote": RemoteAddr(conn), }, "server": map[string]interface{}{ "local": client.LocalAddr(), @@ -126,17 +125,17 @@ func IsConnClosed(received int, err *gerr.GatewayDError) bool { } // LocalAddr returns the local address of the connection. -func LocalAddr(gconn gnet.Conn) string { - if gconn != nil && gconn.LocalAddr() != nil { - return gconn.LocalAddr().String() +func LocalAddr(conn net.Conn) string { + if conn != nil && conn.LocalAddr() != nil { + return conn.LocalAddr().String() } return "" } // RemoteAddr returns the remote address of the connection. -func RemoteAddr(gconn gnet.Conn) string { - if gconn != nil && gconn.RemoteAddr() != nil { - return gconn.RemoteAddr().String() +func RemoteAddr(conn net.Conn) string { + if conn != nil && conn.RemoteAddr() != nil { + return conn.RemoteAddr().String() } return "" }