Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
mostafa committed Oct 11, 2023
1 parent 74856a0 commit 7cf63b5
Show file tree
Hide file tree
Showing 10 changed files with 213 additions and 113 deletions.
15 changes: 10 additions & 5 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ func StopGracefully(
}
}

logger.Info().Msg("Stopping GatewayD")
span.AddEvent("Stopping GatewayD", trace.WithAttributes(
logger.Info().Msg("GatewayD is shutting down")
span.AddEvent("GatewayD is shutting down", trace.WithAttributes(
attribute.String("signal", signal),
))
if healthCheckScheduler != nil {
Expand Down Expand Up @@ -749,9 +749,13 @@ var runCmd = &cobra.Command{
)
signalsCh := make(chan os.Signal, 1)
signal.Notify(signalsCh, signals...)
go func(pluginRegistry *plugin.Registry,
go func(pluginTimeoutCtx context.Context,
pluginRegistry *plugin.Registry,
logger zerolog.Logger,
servers map[string]*network.Server,
metricsMerger *metrics.Merger,
metricsServer *http.Server,
stopChan chan struct{},
) {
for sig := range signalsCh {
for _, s := range signals {
Expand All @@ -771,13 +775,14 @@ var runCmd = &cobra.Command{
}
}
}
}(pluginRegistry, logger, servers)
}(pluginTimeoutCtx, pluginRegistry, logger, servers, metricsMerger, metricsServer, stopChan)

_, span = otel.Tracer(config.TracerName).Start(runCtx, "Start servers")
// Start the server.
for name, server := range servers {
logger := loggers[name]
go func(
span trace.Span,
server *network.Server,
logger zerolog.Logger,
healthCheckScheduler *gocron.Scheduler,
Expand All @@ -797,7 +802,7 @@ var runCmd = &cobra.Command{
pluginRegistry.Shutdown()
os.Exit(gerr.FailedToStartServer)
}
}(server, logger, healthCheckScheduler, metricsMerger, pluginRegistry)
}(span, server, logger, healthCheckScheduler, metricsMerger, pluginRegistry)
}
span.End()

Expand Down
10 changes: 10 additions & 0 deletions cmd/test_plugins.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
acceptancePolicy: accept
compatibilityPolicy: strict
enableMetricsMerger: true
healthCheckPeriod: 5s
metricsMergerPeriod: 5s
plugins: []
reloadOnCrash: true
terminationPolicy: stop
timeout: 30s
verificationPolicy: passdown
3 changes: 3 additions & 0 deletions errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const (
ErrCodeServerReceiveFailed
ErrCodeServerSendFailed
ErrCodeServerListenFailed
ErrCodeServerCloseFailed
ErrCodeSplitHostPortFailed
ErrCodeAcceptFailed
ErrCodeReadFailed
Expand Down Expand Up @@ -87,6 +88,8 @@ var (
ErrCodeSplitHostPortFailed, "failed to split host:port", nil)
ErrAcceptFailed = NewGatewayDError(
ErrCodeAcceptFailed, "failed to accept connection", nil)
ErrServerCloseFailed = NewGatewayDError(
ErrCodeServerCloseFailed, "failed to close server", nil)

ErrReadFailed = NewGatewayDError(
ErrCodeReadFailed, "failed to read from the client", nil)
Expand Down
11 changes: 8 additions & 3 deletions metrics/builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,15 @@ var (
Name: "proxied_connections",
Help: "Number of proxy connects",
})
ProxyPassThroughs = promauto.NewCounter(prometheus.CounterOpts{
ProxyPassThroughsToClient = promauto.NewCounter(prometheus.CounterOpts{
Namespace: Namespace,
Name: "proxy_passthroughs_total",
Help: "Number of successful proxy passthroughs",
Name: "proxy_passthroughs_to_client_total",
Help: "Number of successful proxy passthroughs from server to client",
})
ProxyPassThroughsToServer = promauto.NewCounter(prometheus.CounterOpts{
Namespace: Namespace,
Name: "proxy_passthroughs_to_server_total",
Help: "Number of successful proxy passthroughs from client to server",
})
ProxyPassThroughTerminations = promauto.NewCounter(prometheus.CounterOpts{
Namespace: Namespace,
Expand Down
75 changes: 44 additions & 31 deletions network/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"fmt"
"net"
"sync/atomic"
"time"

"github.com/gatewayd-io/gatewayd/config"
Expand All @@ -26,8 +27,9 @@ type IClient interface {
type Client struct {
net.Conn

logger zerolog.Logger
ctx context.Context //nolint:containedctx
logger zerolog.Logger
ctx context.Context //nolint:containedctx
connected atomic.Bool

TCPKeepAlive bool
TCPKeepAlivePeriod time.Duration
Expand All @@ -53,6 +55,7 @@ func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog.
return nil
}

client.connected.Store(false)
client.logger = logger

// Try to resolve the address and log an error if it can't be resolved.
Expand Down Expand Up @@ -87,6 +90,7 @@ func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog.
}

client.Conn = conn
client.connected.Store(true)

// Set the TCP keep alive.
client.TCPKeepAlive = clientConfig.TCPKeepAlive
Expand Down Expand Up @@ -146,6 +150,11 @@ func (c *Client) Send(data []byte) (int, *gerr.GatewayDError) {
_, span := otel.Tracer(config.TracerName).Start(c.ctx, "Send")
defer span.End()

if !c.connected.Load() {
span.RecordError(gerr.ErrClientNotConnected)
return 0, gerr.ErrClientNotConnected
}

sent := 0
received := len(data)
for {
Expand All @@ -170,8 +179,6 @@ func (c *Client) Send(data []byte) (int, *gerr.GatewayDError) {
},
).Msg("Sent data to server")

metrics.BytesSentToServer.Observe(float64(sent))

return sent, nil
}

Expand All @@ -180,6 +187,11 @@ func (c *Client) Receive() (int, []byte, *gerr.GatewayDError) {
_, span := otel.Tracer(config.TracerName).Start(c.ctx, "Receive")
defer span.End()

if !c.connected.Load() {
span.RecordError(gerr.ErrClientNotConnected)
return 0, nil, gerr.ErrClientNotConnected
}

var ctx context.Context
var cancel context.CancelFunc
if c.ReceiveTimeout > 0 {
Expand All @@ -192,26 +204,21 @@ func (c *Client) Receive() (int, []byte, *gerr.GatewayDError) {
var received int
buffer := bytes.NewBuffer(nil)
// Read the data in chunks.
select { //nolint:gosimple
case <-time.After(time.Millisecond):
for ctx.Err() == nil {
chunk := make([]byte, c.ReceiveChunkSize)
read, err := c.Conn.Read(chunk)
if err != nil {
c.logger.Error().Err(err).Msg("Couldn't receive data from the server")
span.RecordError(err)
metrics.BytesReceivedFromServer.Observe(float64(received))
return received, buffer.Bytes(), gerr.ErrClientReceiveFailed.Wrap(err)
}
received += read
buffer.Write(chunk[:read])
for ctx.Err() == nil {
chunk := make([]byte, c.ReceiveChunkSize)
read, err := c.Conn.Read(chunk)
if err != nil {
c.logger.Error().Err(err).Msg("Couldn't receive data from the server")
span.RecordError(err)
return received, buffer.Bytes(), gerr.ErrClientReceiveFailed.Wrap(err)
}
received += read
buffer.Write(chunk[:read])

if read == 0 || read < c.ReceiveChunkSize {
break
}
if read == 0 || read < c.ReceiveChunkSize {
break
}
}
metrics.BytesReceivedFromServer.Observe(float64(received))
return received, buffer.Bytes(), nil
}

Expand All @@ -220,6 +227,12 @@ func (c *Client) Close() {
_, span := otel.Tracer(config.TracerName).Start(c.ctx, "Close")
defer span.End()

// Set the deadline to now so that the connection is closed immediately.
if err := c.Conn.SetDeadline(time.Now()); err != nil {
c.logger.Error().Err(err).Msg("Failed to set deadline")
span.RecordError(err)
}

c.logger.Debug().Str("address", c.Address).Msg("Closing connection to server")
if c.Conn != nil {
c.Conn.Close()
Expand All @@ -228,6 +241,7 @@ func (c *Client) Close() {
c.Conn = nil
c.Address = ""
c.Network = ""
c.connected.Store(false)

metrics.ServerConnections.Dec()
}
Expand Down Expand Up @@ -257,20 +271,15 @@ func (c *Client) IsConnected() bool {
return false
}

if n, err := c.Read([]byte{}); n == 0 && err != nil {
c.logger.Debug().Fields(
map[string]interface{}{
"address": c.Address,
"reason": "read 0 bytes",
}).Msg("Connection to server is closed")
return false
}

return true
return c.connected.Load()
}

// RemoteAddr returns the remote address of the client safely.
func (c *Client) RemoteAddr() string {
if !c.connected.Load() {
return ""
}

if c.Conn != nil && c.Conn.RemoteAddr() != nil {
return c.Conn.RemoteAddr().String()
}
Expand All @@ -280,6 +289,10 @@ func (c *Client) RemoteAddr() string {

// LocalAddr returns the local address of the client safely.
func (c *Client) LocalAddr() string {
if !c.connected.Load() {
return ""
}

if c.Conn != nil && c.Conn.LocalAddr() != nil {
return c.Conn.LocalAddr().String()
}
Expand Down
Loading

0 comments on commit 7cf63b5

Please sign in to comment.