Skip to content

Commit

Permalink
Internalize net.Conn in Client struct to prevent direct access to met…
Browse files Browse the repository at this point in the history
…hods

Use a mutex to prevent data races to client.conn
Reset context to prevent early timeouts
Check for field value before logging it
  • Loading branch information
mostafa committed Oct 15, 2023
1 parent 22a5d09 commit fb9fdab
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 46 deletions.
56 changes: 36 additions & 20 deletions network/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"fmt"
"net"
"sync"
"sync/atomic"
"time"

Expand All @@ -26,11 +27,11 @@ type IClient interface {
}

type Client struct {
net.Conn

conn net.Conn
logger zerolog.Logger
ctx context.Context //nolint:containedctx
connected atomic.Bool
mu sync.Mutex

TCPKeepAlive bool
TCPKeepAlivePeriod time.Duration
Expand Down Expand Up @@ -69,6 +70,7 @@ func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog.
// Create a resolved client.
client = Client{
ctx: clientCtx,
mu: sync.Mutex{},
Network: clientConfig.Network,
Address: addr,
}
Expand All @@ -90,14 +92,14 @@ func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog.
return nil
}

client.Conn = conn
client.conn = conn
client.connected.Store(true)

// Set the TCP keep alive.
client.TCPKeepAlive = clientConfig.TCPKeepAlive
client.TCPKeepAlivePeriod = clientConfig.TCPKeepAlivePeriod

if c, ok := client.Conn.(*net.TCPConn); ok {
if c, ok := client.conn.(*net.TCPConn); ok {
if err := c.SetKeepAlive(client.TCPKeepAlive); err != nil {
logger.Error().Err(err).Msg("Failed to set keep alive")
span.RecordError(err)
Expand All @@ -112,7 +114,7 @@ func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog.
// Set the receive deadline (timeout).
client.ReceiveDeadline = clientConfig.ReceiveDeadline
if client.ReceiveDeadline > 0 {
if err := client.Conn.SetReadDeadline(time.Now().Add(client.ReceiveDeadline)); err != nil {
if err := client.conn.SetReadDeadline(time.Now().Add(client.ReceiveDeadline)); err != nil {
logger.Error().Err(err).Msg("Failed to set receive deadline")
span.RecordError(err)
} else {
Expand All @@ -124,7 +126,7 @@ func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog.
// Set the send deadline (timeout).
client.SendDeadline = clientConfig.SendDeadline
if client.SendDeadline > 0 {
if err := client.Conn.SetWriteDeadline(time.Now().Add(client.SendDeadline)); err != nil {
if err := client.conn.SetWriteDeadline(time.Now().Add(client.SendDeadline)); err != nil {
logger.Error().Err(err).Msg("Failed to set send deadline")
span.RecordError(err)
} else {
Expand Down Expand Up @@ -163,7 +165,7 @@ func (c *Client) Send(data []byte) (int, *gerr.GatewayDError) {
break
}

written, err := c.Conn.Write(data)
written, err := c.conn.Write(data)
if err != nil {
c.logger.Error().Err(err).Msg("Couldn't send data to the server")
span.RecordError(err)
Expand Down Expand Up @@ -207,7 +209,7 @@ func (c *Client) Receive() (int, []byte, *gerr.GatewayDError) {
// Read the data in chunks.
for ctx.Err() == nil {
chunk := make([]byte, c.ReceiveChunkSize)
read, err := c.Conn.Read(chunk)
read, err := c.conn.Read(chunk)
if err != nil {
c.logger.Error().Err(err).Msg("Couldn't receive data from the server")
span.RecordError(err)
Expand All @@ -232,7 +234,7 @@ func (c *Client) Reconnect() error {
address := c.Address
network := c.Network

if c.Conn != nil {
if c.conn != nil {
c.Close()
}
c.connected.Store(false)
Expand All @@ -248,7 +250,7 @@ func (c *Client) Reconnect() error {
return gerr.ErrClientConnectionFailed.Wrap(err)
}

c.Conn = conn
c.conn = conn
c.ID = GetID(
conn.LocalAddr().Network(), conn.LocalAddr().String(), config.DefaultSeed, c.logger)
c.connected.Store(true)
Expand All @@ -263,21 +265,29 @@ func (c *Client) Close() {
_, span := otel.Tracer(config.TracerName).Start(c.ctx, "Close")
defer span.End()

c.mu.Lock()
defer c.mu.Unlock()

// Set the deadline to now so that the connection is closed immediately.
if err := c.Conn.SetDeadline(time.Now()); err != nil {
// This will stop all the Conn.Read() and Conn.Write() calls.
// Ref: https://groups.google.com/g/golang-nuts/c/VPVWFrpIEyo
if err := c.conn.SetDeadline(time.Now()); err != nil {
c.logger.Error().Err(err).Msg("Failed to set deadline")
span.RecordError(err)
}

c.connected.Store(false)
c.logger.Debug().Str("address", c.Address).Msg("Closing connection to server")
if c.Conn != nil {
c.Conn.Close()
if c.conn != nil {
if err := c.conn.Close(); err != nil {
c.logger.Error().Err(err).Msg("Failed to close connection")
span.RecordError(err)
}
}
c.ID = ""
c.Conn = nil
c.conn = nil
c.Address = ""
c.Network = ""
c.connected.Store(false)

metrics.ServerConnections.Dec()
}
Expand All @@ -298,7 +308,7 @@ func (c *Client) IsConnected() bool {
return false
}

if c != nil && c.Conn == nil || c.ID == "" {
if c != nil && c.conn == nil || c.ID == "" {
c.logger.Debug().Fields(
map[string]interface{}{
"address": c.Address,
Expand All @@ -316,8 +326,11 @@ func (c *Client) RemoteAddr() string {
return ""
}

if c.Conn != nil && c.Conn.RemoteAddr() != nil {
return c.Conn.RemoteAddr().String()
c.mu.Lock()
defer c.mu.Unlock()

if c.conn != nil && c.conn.RemoteAddr() != nil {
return c.conn.RemoteAddr().String()
}

return ""
Expand All @@ -329,8 +342,11 @@ func (c *Client) LocalAddr() string {
return ""
}

if c.Conn != nil && c.Conn.LocalAddr() != nil {
return c.Conn.LocalAddr().String()
c.mu.Lock()
defer c.mu.Unlock()

if c.conn != nil && c.conn.LocalAddr() != nil {
return c.conn.LocalAddr().String()
}

return ""
Expand Down
54 changes: 28 additions & 26 deletions network/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ func (pr *Proxy) PassThroughToServer(conn net.Conn) *gerr.GatewayDError {
_, err = pr.sendTrafficToServer(client, request)
span.AddEvent("Sent traffic to server")

pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), pr.pluginTimeout)
defer cancel()

// Run the OnTrafficToServer hooks.
_, err = pr.pluginRegistry.Run(
pluginTimeoutCtx,
Expand Down Expand Up @@ -396,12 +399,14 @@ func (pr *Proxy) PassThroughToClient(conn net.Conn) *gerr.GatewayDError {

// If the response is empty, don't send anything, instead just close the ingress connection.
if received == 0 || err != nil {
pr.logger.Debug().Fields(
map[string]interface{}{
"function": "proxy.passthrough",
"local": client.LocalAddr(),
"remote": client.RemoteAddr(),
}).Msg("No data to send to client")
fields := map[string]interface{}{"function": "proxy.passthrough"}
if client.LocalAddr() != "" {
fields["local_addr"] = client.LocalAddr()
}
if client.RemoteAddr() != "" {
fields["remote_addr"] = client.RemoteAddr()
}
pr.logger.Debug().Fields(fields).Msg("No data to send to client")
span.AddEvent("No data to send to client")
span.RecordError(err)
return err
Expand Down Expand Up @@ -442,6 +447,9 @@ func (pr *Proxy) PassThroughToClient(conn net.Conn) *gerr.GatewayDError {
span.AddEvent("Sent traffic to client")

// Run the OnTrafficToClient hooks.
pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), pr.pluginTimeout)
defer cancel()

_, err = pr.pluginRegistry.Run(
pluginTimeoutCtx,
trafficData(
Expand Down Expand Up @@ -509,12 +517,6 @@ func (pr *Proxy) Shutdown() {
pr.availableConnections.ForEach(func(key, value interface{}) bool {
if client, ok := value.(*Client); ok {
if client.IsConnected() {
// This will stop all the Conn.Read() and Conn.Write() calls.
// Ref: https://groups.google.com/g/golang-nuts/c/VPVWFrpIEyo
if err := client.Conn.SetDeadline(time.Now()); err != nil {
pr.logger.Error().Err(err).Msg("Error setting the deadline")
span.RecordError(err)
}
client.Close()
}
}
Expand All @@ -537,11 +539,6 @@ func (pr *Proxy) Shutdown() {
}
if client, ok := value.(*Client); ok {
if client != nil {
// This will stop all the Conn.Read() and Conn.Write() calls.
if err := client.Conn.SetDeadline(time.Now()); err != nil {
pr.logger.Error().Err(err).Msg("Error setting the deadline")
span.RecordError(err)
}
client.Close()
}
}
Expand All @@ -561,7 +558,7 @@ func (pr *Proxy) AvailableConnections() []string {
connections := make([]string, 0)
pr.availableConnections.ForEach(func(_, value interface{}) bool {
if cl, ok := value.(*Client); ok {
connections = append(connections, cl.Conn.LocalAddr().String())
connections = append(connections, cl.LocalAddr())
}
return true
})
Expand Down Expand Up @@ -668,14 +665,19 @@ func (pr *Proxy) receiveTrafficFromServer(client *Client) (int, []byte, *gerr.Ga

// Receive the response from the server.
received, response, err := client.Receive()
pr.logger.Debug().Fields(
map[string]interface{}{
"function": "proxy.passthrough",
"length": received,
"local": client.LocalAddr(),
"remote": client.RemoteAddr(),
},
).Msg("Received data from database")

fields := map[string]interface{}{
"function": "proxy.passthrough",
"length": received,
}
if client.LocalAddr() != "" {
fields["local"] = client.LocalAddr()
}
if client.RemoteAddr() != "" {
fields["remote"] = client.RemoteAddr()
}

pr.logger.Debug().Fields(fields).Msg("Received data from database")

metrics.BytesReceivedFromServer.Observe(float64(received))
metrics.TotalTrafficBytes.Observe(float64(received))
Expand Down

0 comments on commit fb9fdab

Please sign in to comment.