Skip to content

Commit

Permalink
Move Run function to server.Run
Browse files Browse the repository at this point in the history
Check if listener is not nil before closing it
  • Loading branch information
mostafa committed Oct 15, 2023
1 parent 068f03c commit 91998ea
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 143 deletions.
6 changes: 0 additions & 6 deletions errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ const (
ErrCodeNetworkNotSupported
ErrCodeResolveFailed
ErrCodePoolExhausted
ErrCodeStartServerFailed
ErrCodePluginNotFound
ErrCodePluginNotReady
ErrCodeStartPluginFailed
Expand All @@ -23,7 +22,6 @@ const (
ErrCodeServerSendFailed
ErrCodeServerListenFailed
ErrCodeSplitHostPortFailed
ErrCodeCloseListenerFailed
ErrCodeAcceptFailed
ErrCodeReadFailed
ErrCodePutFailed
Expand Down Expand Up @@ -56,8 +54,6 @@ var (
ErrCodeResolveFailed, "failed to resolve address", nil)
ErrPoolExhausted = NewGatewayDError(
ErrCodePoolExhausted, "pool is exhausted", nil)
ErrFailedToStartServer = NewGatewayDError(
ErrCodeStartServerFailed, "failed to start server", nil)

ErrPluginNotFound = NewGatewayDError(
ErrCodePluginNotFound, "plugin not found", nil)
Expand Down Expand Up @@ -87,8 +83,6 @@ var (
ErrCodeServerListenFailed, "couldn't listen on the server", nil)
ErrSplitHostPortFailed = NewGatewayDError(
ErrCodeSplitHostPortFailed, "failed to split host:port", nil)
ErrCloseListenerFailed = NewGatewayDError(
ErrCodeCloseListenerFailed, "failed to close listener", nil)
ErrAcceptFailed = NewGatewayDError(
ErrCodeAcceptFailed, "failed to accept connection", nil)

Expand Down
138 changes: 8 additions & 130 deletions network/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@ package network
import (
"context"
"net"
"strconv"
"sync"
"sync/atomic"
"time"

"github.com/gatewayd-io/gatewayd/config"
gerr "github.com/gatewayd-io/gatewayd/errors"
"github.com/rs/zerolog"
)

type Option struct {
Expand All @@ -36,6 +35,7 @@ type Engine struct {
host string
port int
connections uint32
logger zerolog.Logger
running *atomic.Bool
stopServer chan struct{}
mu *sync.RWMutex
Expand All @@ -52,136 +52,14 @@ func (engine *Engine) Stop(ctx context.Context) error {
defer cancel()

engine.running.Store(false)
if err := engine.listener.Close(); err != nil {
engine.stopServer <- struct{}{}
close(engine.stopServer)
return gerr.ErrCloseListenerFailed.Wrap(err)
if engine.listener != nil {
if err := engine.listener.Close(); err != nil {
engine.logger.Error().Err(err).Msg("Failed to close listener")
}
} else {
engine.logger.Error().Msg("Listener is not initialized")
}
engine.stopServer <- struct{}{}
close(engine.stopServer)
return nil
}

// Run starts a server and connects all the handlers.
func Run(network, address string, server *Server) *gerr.GatewayDError {
server.engine = Engine{
connections: 0,
stopServer: make(chan struct{}),
mu: &sync.RWMutex{},
running: &atomic.Bool{},
}

if action := server.OnBoot(server.engine); action != None {
return nil
}

var err error
server.engine.listener, err = net.Listen(network, address)
if err != nil {
server.logger.Error().Err(err).Msg("Server failed to start listening")
return gerr.ErrServerListenFailed.Wrap(err)
}

if server.engine.listener == nil {
server.logger.Error().Msg("Server is not properly initialized")
return nil
}

var port string
server.engine.host, port, err = net.SplitHostPort(server.engine.listener.Addr().String())
if err != nil {
server.logger.Error().Err(err).Msg("Failed to split host and port")
return gerr.ErrSplitHostPortFailed.Wrap(err)
}

if server.engine.port, err = strconv.Atoi(port); err != nil {
server.logger.Error().Err(err).Msg("Failed to convert port to integer")
return gerr.ErrCastFailed.Wrap(err)
}

go func(server *Server) {
<-server.engine.stopServer
server.OnShutdown()
server.logger.Debug().Msg("Server stopped")
}(server)

go func(server *Server) {
if !server.Options.EnableTicker {
return
}

for {
select {
case <-server.engine.stopServer:
return
default:
interval, action := server.OnTick()
if action == Shutdown {
server.OnShutdown()
return
}
if interval == time.Duration(0) {
return
}
time.Sleep(interval)
}
}
}(server)

server.engine.running.Store(true)

for {
select {
case <-server.engine.stopServer:
server.logger.Info().Msg("Server stopped")
return nil
default:
conn, err := server.engine.listener.Accept()
if err != nil {
if !server.engine.running.Load() {
return nil
}
server.logger.Error().Err(err).Msg("Failed to accept connection")
return gerr.ErrAcceptFailed.Wrap(err)
}

if out, action := server.OnOpen(conn); action != None {
if _, err := conn.Write(out); err != nil {
server.logger.Error().Err(err).Msg("Failed to write to connection")
}
conn.Close()
if action == Shutdown {
server.OnShutdown()
return nil
}
}
server.engine.mu.Lock()
server.engine.connections++
server.engine.mu.Unlock()

// For every new connection, a new unbuffered channel is created to help
// stop the proxy, recycle the server connection and close stale connections.
stopConnection := make(chan struct{})
go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
if action := server.OnTraffic(conn, stopConnection); action == Close {
stopConnection <- struct{}{}
}
}(server, conn, stopConnection)

go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
for {
select {
case <-stopConnection:
server.engine.mu.Lock()
server.engine.connections--
server.engine.mu.Unlock()
server.OnClose(conn, err)
return
case <-server.engine.stopServer:
return
}
}
}(server, conn, stopConnection)
}
}
}
129 changes: 122 additions & 7 deletions network/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (
"fmt"
"net"
"os"
"strconv"
"sync"
"sync/atomic"
"time"

v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1"
Expand Down Expand Up @@ -341,7 +343,7 @@ func (s *Server) OnTick() (time.Duration, Action) {
}

// Run starts the server and blocks until the server is stopped. It calls the OnRun hooks.
func (s *Server) Run() error {
func (s *Server) Run() *gerr.GatewayDError {
_, span := otel.Tracer("gatewayd").Start(s.ctx, "Run")
defer span.End()

Expand Down Expand Up @@ -381,14 +383,127 @@ func (s *Server) Run() error {
}

// Start the server.
origErr := Run(s.Network, addr, s)
if origErr != nil && origErr.Unwrap() != nil {
s.logger.Error().Err(origErr).Msg("Failed to start server")
span.RecordError(origErr)
return gerr.ErrFailedToStartServer.Wrap(origErr)
s.engine = Engine{
connections: 0,
logger: s.logger,
stopServer: make(chan struct{}),
mu: &sync.RWMutex{},
running: &atomic.Bool{},
}

return nil
if action := s.OnBoot(s.engine); action != None {
return nil
}

listener, origErr := net.Listen(s.Network, addr)
if origErr != nil {
s.logger.Error().Err(origErr).Msg("Server failed to start listening")
return gerr.ErrServerListenFailed.Wrap(origErr)
}
s.engine.listener = listener

if s.engine.listener == nil {
s.logger.Error().Msg("Server is not properly initialized")
return nil
}

var port string
s.engine.host, port, origErr = net.SplitHostPort(s.engine.listener.Addr().String())
if origErr != nil {
s.logger.Error().Err(origErr).Msg("Failed to split host and port")
return gerr.ErrSplitHostPortFailed.Wrap(origErr)
}

if s.engine.port, origErr = strconv.Atoi(port); origErr != nil {
s.logger.Error().Err(origErr).Msg("Failed to convert port to integer")
return gerr.ErrCastFailed.Wrap(origErr)
}

go func(server *Server) {
<-server.engine.stopServer
server.OnShutdown()
server.logger.Debug().Msg("Server stopped")
}(s)

go func(server *Server) {
if !server.Options.EnableTicker {
return
}

for {
select {
case <-server.engine.stopServer:
return
default:
interval, action := server.OnTick()
if action == Shutdown {
server.OnShutdown()
return
}
if interval == time.Duration(0) {
return
}
time.Sleep(interval)
}
}
}(s)

s.engine.running.Store(true)

for {
select {
case <-s.engine.stopServer:
s.logger.Info().Msg("Server stopped")
return nil
default:
conn, err := s.engine.listener.Accept()
if err != nil {
if !s.engine.running.Load() {
return nil
}
s.logger.Error().Err(err).Msg("Failed to accept connection")
return gerr.ErrAcceptFailed.Wrap(err)
}

if out, action := s.OnOpen(conn); action != None {
if _, err := conn.Write(out); err != nil {
s.logger.Error().Err(err).Msg("Failed to write to connection")
}
conn.Close()
if action == Shutdown {
s.OnShutdown()
return nil
}
}
s.engine.mu.Lock()
s.engine.connections++
s.engine.mu.Unlock()

// For every new connection, a new unbuffered channel is created to help
// stop the proxy, recycle the server connection and close stale connections.
stopConnection := make(chan struct{})
go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
if action := server.OnTraffic(conn, stopConnection); action == Close {
stopConnection <- struct{}{}
}
}(s, conn, stopConnection)

go func(server *Server, conn net.Conn, stopConnection chan struct{}) {
for {
select {
case <-stopConnection:
server.engine.mu.Lock()
server.engine.connections--
server.engine.mu.Unlock()
server.OnClose(conn, err)
return
case <-server.engine.stopServer:
return
}
}
}(s, conn, stopConnection)
}
}
}

// Shutdown stops the server.
Expand Down

0 comments on commit 91998ea

Please sign in to comment.