Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
mostafa committed Oct 11, 2023
1 parent 74856a0 commit 227a6ac
Show file tree
Hide file tree
Showing 10 changed files with 205 additions and 115 deletions.
15 changes: 10 additions & 5 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ func StopGracefully(
}
}

logger.Info().Msg("Stopping GatewayD")
span.AddEvent("Stopping GatewayD", trace.WithAttributes(
logger.Info().Msg("GatewayD is shutting down")
span.AddEvent("GatewayD is shutting down", trace.WithAttributes(
attribute.String("signal", signal),
))
if healthCheckScheduler != nil {
Expand Down Expand Up @@ -749,9 +749,13 @@ var runCmd = &cobra.Command{
)
signalsCh := make(chan os.Signal, 1)
signal.Notify(signalsCh, signals...)
go func(pluginRegistry *plugin.Registry,
go func(pluginTimeoutCtx context.Context,
pluginRegistry *plugin.Registry,
logger zerolog.Logger,
servers map[string]*network.Server,
metricsMerger *metrics.Merger,
metricsServer *http.Server,
stopChan chan struct{},
) {
for sig := range signalsCh {
for _, s := range signals {
Expand All @@ -771,13 +775,14 @@ var runCmd = &cobra.Command{
}
}
}
}(pluginRegistry, logger, servers)
}(pluginTimeoutCtx, pluginRegistry, logger, servers, metricsMerger, metricsServer, stopChan)

_, span = otel.Tracer(config.TracerName).Start(runCtx, "Start servers")
// Start the server.
for name, server := range servers {
logger := loggers[name]
go func(
span trace.Span,
server *network.Server,
logger zerolog.Logger,
healthCheckScheduler *gocron.Scheduler,
Expand All @@ -797,7 +802,7 @@ var runCmd = &cobra.Command{
pluginRegistry.Shutdown()
os.Exit(gerr.FailedToStartServer)
}
}(server, logger, healthCheckScheduler, metricsMerger, pluginRegistry)
}(span, server, logger, healthCheckScheduler, metricsMerger, pluginRegistry)
}
span.End()

Expand Down
3 changes: 3 additions & 0 deletions errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const (
ErrCodeServerReceiveFailed
ErrCodeServerSendFailed
ErrCodeServerListenFailed
ErrCodeServerCloseFailed
ErrCodeSplitHostPortFailed
ErrCodeAcceptFailed
ErrCodeReadFailed
Expand Down Expand Up @@ -87,6 +88,8 @@ var (
ErrCodeSplitHostPortFailed, "failed to split host:port", nil)
ErrAcceptFailed = NewGatewayDError(
ErrCodeAcceptFailed, "failed to accept connection", nil)
ErrServerCloseFailed = NewGatewayDError(
ErrCodeServerCloseFailed, "failed to close server", nil)

ErrReadFailed = NewGatewayDError(
ErrCodeReadFailed, "failed to read from the client", nil)
Expand Down
11 changes: 8 additions & 3 deletions metrics/builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,15 @@ var (
Name: "proxied_connections",
Help: "Number of proxy connects",
})
ProxyPassThroughs = promauto.NewCounter(prometheus.CounterOpts{
ProxyPassThroughsToClient = promauto.NewCounter(prometheus.CounterOpts{
Namespace: Namespace,
Name: "proxy_passthroughs_total",
Help: "Number of successful proxy passthroughs",
Name: "proxy_passthroughs_to_client_total",
Help: "Number of successful proxy passthroughs from server to client",
})
ProxyPassThroughsToServer = promauto.NewCounter(prometheus.CounterOpts{
Namespace: Namespace,
Name: "proxy_passthroughs_to_server_total",
Help: "Number of successful proxy passthroughs from client to server",
})
ProxyPassThroughTerminations = promauto.NewCounter(prometheus.CounterOpts{
Namespace: Namespace,
Expand Down
75 changes: 44 additions & 31 deletions network/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"fmt"
"net"
"sync/atomic"
"time"

"github.com/gatewayd-io/gatewayd/config"
Expand All @@ -26,8 +27,9 @@ type IClient interface {
type Client struct {
net.Conn

logger zerolog.Logger
ctx context.Context //nolint:containedctx
logger zerolog.Logger
ctx context.Context //nolint:containedctx
connected atomic.Bool

TCPKeepAlive bool
TCPKeepAlivePeriod time.Duration
Expand All @@ -53,6 +55,7 @@ func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog.
return nil
}

client.connected.Store(false)
client.logger = logger

// Try to resolve the address and log an error if it can't be resolved.
Expand Down Expand Up @@ -87,6 +90,7 @@ func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog.
}

client.Conn = conn
client.connected.Store(true)

// Set the TCP keep alive.
client.TCPKeepAlive = clientConfig.TCPKeepAlive
Expand Down Expand Up @@ -146,6 +150,11 @@ func (c *Client) Send(data []byte) (int, *gerr.GatewayDError) {
_, span := otel.Tracer(config.TracerName).Start(c.ctx, "Send")
defer span.End()

if !c.connected.Load() {
span.RecordError(gerr.ErrClientNotConnected)
return 0, gerr.ErrClientNotConnected
}

sent := 0
received := len(data)
for {
Expand All @@ -170,8 +179,6 @@ func (c *Client) Send(data []byte) (int, *gerr.GatewayDError) {
},
).Msg("Sent data to server")

metrics.BytesSentToServer.Observe(float64(sent))

return sent, nil
}

Expand All @@ -180,6 +187,11 @@ func (c *Client) Receive() (int, []byte, *gerr.GatewayDError) {
_, span := otel.Tracer(config.TracerName).Start(c.ctx, "Receive")
defer span.End()

if !c.connected.Load() {
span.RecordError(gerr.ErrClientNotConnected)
return 0, nil, gerr.ErrClientNotConnected
}

var ctx context.Context
var cancel context.CancelFunc
if c.ReceiveTimeout > 0 {
Expand All @@ -192,26 +204,21 @@ func (c *Client) Receive() (int, []byte, *gerr.GatewayDError) {
var received int
buffer := bytes.NewBuffer(nil)
// Read the data in chunks.
select { //nolint:gosimple
case <-time.After(time.Millisecond):
for ctx.Err() == nil {
chunk := make([]byte, c.ReceiveChunkSize)
read, err := c.Conn.Read(chunk)
if err != nil {
c.logger.Error().Err(err).Msg("Couldn't receive data from the server")
span.RecordError(err)
metrics.BytesReceivedFromServer.Observe(float64(received))
return received, buffer.Bytes(), gerr.ErrClientReceiveFailed.Wrap(err)
}
received += read
buffer.Write(chunk[:read])
for ctx.Err() == nil {
chunk := make([]byte, c.ReceiveChunkSize)
read, err := c.Conn.Read(chunk)
if err != nil {
c.logger.Error().Err(err).Msg("Couldn't receive data from the server")
span.RecordError(err)
return received, buffer.Bytes(), gerr.ErrClientReceiveFailed.Wrap(err)
}
received += read
buffer.Write(chunk[:read])

if read == 0 || read < c.ReceiveChunkSize {
break
}
if read == 0 || read < c.ReceiveChunkSize {
break
}
}
metrics.BytesReceivedFromServer.Observe(float64(received))
return received, buffer.Bytes(), nil
}

Expand All @@ -220,6 +227,12 @@ func (c *Client) Close() {
_, span := otel.Tracer(config.TracerName).Start(c.ctx, "Close")
defer span.End()

// Set the deadline to now so that the connection is closed immediately.
if err := c.Conn.SetDeadline(time.Now()); err != nil {
c.logger.Error().Err(err).Msg("Failed to set deadline")
span.RecordError(err)
}

c.logger.Debug().Str("address", c.Address).Msg("Closing connection to server")
if c.Conn != nil {
c.Conn.Close()
Expand All @@ -228,6 +241,7 @@ func (c *Client) Close() {
c.Conn = nil
c.Address = ""
c.Network = ""
c.connected.Store(false)

metrics.ServerConnections.Dec()
}
Expand Down Expand Up @@ -257,20 +271,15 @@ func (c *Client) IsConnected() bool {
return false
}

if n, err := c.Read([]byte{}); n == 0 && err != nil {
c.logger.Debug().Fields(
map[string]interface{}{
"address": c.Address,
"reason": "read 0 bytes",
}).Msg("Connection to server is closed")
return false
}

return true
return c.connected.Load()
}

// RemoteAddr returns the remote address of the client safely.
func (c *Client) RemoteAddr() string {
if !c.connected.Load() {
return ""
}

if c.Conn != nil && c.Conn.RemoteAddr() != nil {
return c.Conn.RemoteAddr().String()
}
Expand All @@ -280,6 +289,10 @@ func (c *Client) RemoteAddr() string {

// LocalAddr returns the local address of the client safely.
func (c *Client) LocalAddr() string {
if !c.connected.Load() {
return ""
}

if c.Conn != nil && c.Conn.LocalAddr() != nil {
return c.Conn.LocalAddr().String()
}
Expand Down
95 changes: 60 additions & 35 deletions network/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net"
"strconv"
"sync"
"sync/atomic"
"time"

"github.com/gatewayd-io/gatewayd/config"
Expand Down Expand Up @@ -35,6 +36,7 @@ type Engine struct {
host string
port int
connections uint32
running *atomic.Bool
stopServer chan struct{}
mu *sync.RWMutex
}
Expand All @@ -49,6 +51,11 @@ func (engine *Engine) Stop(ctx context.Context) error {
_, cancel := context.WithDeadline(ctx, time.Now().Add(config.DefaultEngineStopTimeout))
defer cancel()

engine.running.Store(false)
if err := engine.listener.Close(); err != nil {
engine.stopServer <- struct{}{}
return gerr.ErrServerCloseFailed.Wrap(err)
}
engine.stopServer <- struct{}{}
return nil
}
Expand All @@ -59,6 +66,7 @@ func Run(network, address string, server *Server) *gerr.GatewayDError {
connections: 0,
stopServer: make(chan struct{}),
mu: &sync.RWMutex{},
running: &atomic.Bool{},
}

if action := server.OnBoot(server.engine); action != None {
Expand All @@ -71,7 +79,6 @@ func Run(network, address string, server *Server) *gerr.GatewayDError {
server.logger.Error().Err(err).Msg("Server failed to start listening")
return gerr.ErrServerListenFailed.Wrap(err)
}
defer server.engine.listener.Close()

if server.engine.listener == nil {
server.logger.Error().Msg("Server is not properly initialized")
Expand Down Expand Up @@ -114,44 +121,62 @@ func Run(network, address string, server *Server) *gerr.GatewayDError {
}
}(server)

for {
conn, err := server.engine.listener.Accept()
if err != nil {
server.logger.Error().Err(err).Msg("Failed to accept connection")
return gerr.ErrAcceptFailed.Wrap(err)
}
server.engine.running.Store(true)

if out, action := server.OnOpen(conn); action != None {
if _, err := conn.Write(out); err != nil {
server.logger.Error().Err(err).Msg("Failed to write to connection")
}
conn.Close()
if action == Shutdown {
server.OnShutdown(server.engine)
return nil
}
}
server.engine.mu.Lock()
server.engine.connections++
server.engine.mu.Unlock()

// For every new connection, a new unbuffered channel is created to help
// stop the proxy, recycle the server connection and close stale connections.
stopConnection := make(chan struct{})
go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
if action := server.OnTraffic(conn, stopConnection); action == Close {
return
for {
select {
case <-server.engine.stopServer:
server.logger.Info().Msg("Server stopped")
return nil
default:
conn, err := server.engine.listener.Accept()
if err != nil {
if !server.engine.running.Load() {
return nil
}
server.logger.Error().Err(err).Msg("Failed to accept connection")
return gerr.ErrAcceptFailed.Wrap(err)
}
}(server, conn, stopConnection)

go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
<-stopConnection
if out, action := server.OnOpen(conn); action != None {
if _, err := conn.Write(out); err != nil {
server.logger.Error().Err(err).Msg("Failed to write to connection")
}
conn.Close()
if action == Shutdown {
server.OnShutdown(server.engine)
return nil
}
}
server.engine.mu.Lock()
server.engine.connections--
server.engine.connections++
server.engine.mu.Unlock()
if action := server.OnClose(conn, err); action == Close {
return
}
}(server, conn, stopConnection)

// For every new connection, a new unbuffered channel is created to help
// stop the proxy, recycle the server connection and close stale connections.
stopConnection := make(chan struct{})
go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
if action := server.OnTraffic(conn, stopConnection); action == Close {
return
}
}(server, conn, stopConnection)

go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
for {
select {
case <-stopConnection:
server.engine.mu.Lock()
server.engine.connections--
server.engine.mu.Unlock()
if action := server.OnClose(conn, err); action == Close {
return
}
return
case <-server.engine.stopServer:
return
}
}
}(server, conn, stopConnection)
}
}
}
Loading

0 comments on commit 227a6ac

Please sign in to comment.