diff --git a/network/client.go b/network/client.go index 267e1df7..3d859599 100644 --- a/network/client.go +++ b/network/client.go @@ -231,7 +231,7 @@ func (c *Client) Receive() (int, []byte, *gerr.GatewayDError) { ctx = context.Background() } - var received int + total := 0 buffer := bytes.NewBuffer(nil) // Read the data in chunks. for ctx.Err() == nil { @@ -240,19 +240,19 @@ func (c *Client) Receive() (int, []byte, *gerr.GatewayDError) { if err != nil { c.logger.Error().Err(err).Msg("Couldn't receive data from the server") span.RecordError(err) - return received, buffer.Bytes(), gerr.ErrClientReceiveFailed.Wrap(err) + return total, buffer.Bytes(), gerr.ErrClientReceiveFailed.Wrap(err) } - received += read + total += read buffer.Write(chunk[:read]) - if read == 0 || read < c.ReceiveChunkSize { + if read < c.ReceiveChunkSize { break } } span.AddEvent("Received data from server") - return received, buffer.Bytes(), nil + return total, buffer.Bytes(), nil } // Reconnect reconnects to the server. diff --git a/network/proxy.go b/network/proxy.go index 7d3e5aac..b0a0ddaf 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -494,8 +494,17 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate received, response, err := pr.receiveTrafficFromServer(client) span.AddEvent("Received traffic from server") - // If the response is empty, don't send anything, instead just close the ingress connection. - if received == 0 || err != nil { + // If there is no data to send to the client, + // we don't need to run the hooks and + // we obviously have no data to send to the client. + if received == 0 { + span.AddEvent("No data to send to client") + stack.PopLastRequest() + return nil + } + + // If there is an error, close the ingress connection. + if err != nil { fields := map[string]interface{}{"function": "proxy.passthrough"} if client.LocalAddr() != "" { fields["localAddr"] = client.LocalAddr() @@ -517,7 +526,7 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate // Get the last request from the stack. lastRequest := stack.PopLastRequest() - request := make([]byte, 0) + request := []byte{} if lastRequest != nil { request = lastRequest.Data } @@ -698,7 +707,7 @@ func (pr *Proxy) receiveTrafficFromClient(conn net.Conn) ([]byte, *gerr.GatewayD defer span.End() // request contains the data from the client. - received := 0 + total := 0 buffer := bytes.NewBuffer(nil) for { chunk := make([]byte, pr.ClientConfig.ReceiveChunkSize) @@ -713,10 +722,10 @@ func (pr *Proxy) receiveTrafficFromClient(conn net.Conn) ([]byte, *gerr.GatewayD return chunk[:read], gerr.ErrReadFailed.Wrap(err) } - received += read + total += read buffer.Write(chunk[:read]) - if received == 0 || received < pr.ClientConfig.ReceiveChunkSize { + if read < pr.ClientConfig.ReceiveChunkSize { break } @@ -725,10 +734,9 @@ func (pr *Proxy) receiveTrafficFromClient(conn net.Conn) ([]byte, *gerr.GatewayD } } - length := len(buffer.Bytes()) pr.Logger.Debug().Fields( map[string]interface{}{ - "length": length, + "length": total, "local": LocalAddr(conn), "remote": RemoteAddr(conn), }, @@ -736,8 +744,8 @@ func (pr *Proxy) receiveTrafficFromClient(conn net.Conn) ([]byte, *gerr.GatewayD span.AddEvent("Received data from client") - metrics.BytesReceivedFromClient.Observe(float64(length)) - metrics.TotalTrafficBytes.Observe(float64(length)) + metrics.BytesReceivedFromClient.Observe(float64(total)) + metrics.TotalTrafficBytes.Observe(float64(total)) return buffer.Bytes(), nil } diff --git a/network/server.go b/network/server.go index a1b50deb..210931a9 100644 --- a/network/server.go +++ b/network/server.go @@ -79,7 +79,7 @@ type Server struct { LoadbalancerStrategyName string LoadbalancerRules []config.LoadBalancingRule LoadbalancerConsistentHash *config.ConsistentHash - connectionToProxyMap map[*ConnWrapper]IProxy + connectionToProxyMap *sync.Map } var _ IServer = (*Server)(nil) @@ -181,7 +181,7 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) { } // Assign connection to proxy - s.connectionToProxyMap[conn] = proxy + s.connectionToProxyMap.Store(conn, proxy) // Run the OnOpened hooks. pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), s.PluginTimeout) @@ -696,7 +696,7 @@ func NewServer( connections: 0, running: &atomic.Bool{}, stopServer: make(chan struct{}), - connectionToProxyMap: make(map[*ConnWrapper]IProxy), + connectionToProxyMap: &sync.Map{}, LoadbalancerStrategyName: srv.LoadbalancerStrategyName, LoadbalancerRules: srv.LoadbalancerRules, LoadbalancerConsistentHash: srv.LoadbalancerConsistentHash, @@ -737,11 +737,19 @@ func (s *Server) CountConnections() int { // GetProxyForConnection returns the proxy associated with the given connection. func (s *Server) GetProxyForConnection(conn *ConnWrapper) (IProxy, bool) { - proxy, exists := s.connectionToProxyMap[conn] - return proxy, exists + proxy, exists := s.connectionToProxyMap.Load(conn) + if !exists { + return nil, false + } + + if proxy, ok := proxy.(IProxy); ok { + return proxy, true + } + + return nil, false } // RemoveConnectionFromMap removes the given connection from the connection-to-proxy map. func (s *Server) RemoveConnectionFromMap(conn *ConnWrapper) { - delete(s.connectionToProxyMap, conn) + s.connectionToProxyMap.Delete(conn) }