Skip to content

Commit

Permalink
Implement Multi-Proxy and Load Balancer Strategy Support
Browse files Browse the repository at this point in the history
This commit introduces significant updates to enhance the handling of proxies within the server configuration and adds support for load balancing strategies.

Changes:
- Updated API tests to reflect changes from a single `Proxy` to a list of `Proxies`.
- Adjusted initialization and configuration of proxies in `run.go` to support multiple proxies and load balancer strategies.
- Updated configuration files to include new fields for multiple proxies and load balancer strategies.
- Enhanced global configuration validation for clients, pools, and proxies.
- Added new `loadBalancer` section in `gatewayd.yaml` for rules and strategies.
- Implemented load balancing strategy selection and Round Robin strategy.
- Added tests for load balancer strategies.
- Added new error type `ErrorCodeLoadBalancerStrategyNotFound`.
- Improved proxy connection handling and added informative comments.

Configuration Example:
- Updated `gatewayd.yaml` to reflect new support for multiple proxies and load balancer strategies.
- Ensure to update your configuration files accordingly.

Testing:
- Updated existing tests and added new tests for multi-proxy and load balancing functionality.
- Verified configuration validation for proxies and load balancers.

Impact:
- Improved flexibility and scalability of server configuration.
- Enabled robust proxy management and efficient load distribution.
  • Loading branch information
sinadarbouy committed Jul 17, 2024
1 parent 18ae198 commit 391ee7d
Show file tree
Hide file tree
Showing 16 changed files with 395 additions and 60 deletions.
2 changes: 1 addition & 1 deletion api/api_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func getAPIConfig() *API {
context.Background(),
network.Server{
Logger: logger,
Proxy: defaultProxy,
Proxies: []network.IProxy{defaultProxy},
PluginRegistry: pluginReg,
PluginTimeout: config.DefaultPluginTimeout,
Network: "tcp",
Expand Down
2 changes: 1 addition & 1 deletion api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ func TestGetServers(t *testing.T) {
Options: network.Option{
EnableTicker: false,
},
Proxy: proxy,
Proxies: []network.IProxy{proxy},
Logger: zerolog.Logger{},
PluginRegistry: pluginRegistry,
PluginTimeout: config.DefaultPluginTimeout,
Expand Down
2 changes: 1 addition & 1 deletion api/healthcheck_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func Test_Healthchecker(t *testing.T) {
Options: network.Option{
EnableTicker: false,
},
Proxy: proxy,
Proxies: []network.IProxy{proxy},
Logger: zerolog.Logger{},
PluginRegistry: pluginRegistry,
PluginTimeout: config.DefaultPluginTimeout,
Expand Down
29 changes: 21 additions & 8 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,18 @@ var runCmd = &cobra.Command{
// Create and initialize servers.
for name, cfg := range conf.Global.Servers {
logger := loggers[name]

var serverProxies []network.IProxy
for _, proxyName := range cfg.Proxies {
proxy, exists := proxies[proxyName]
if !exists {
// This may occur if a proxy referenced in the server configuration does not exist.
logger.Error().Str("proxyName", proxyName).Msg("failed to find proxy configuration")
return
}
serverProxies = append(serverProxies, proxy)
}

servers[name] = network.NewServer(
runCtx,
network.Server{
Expand All @@ -885,14 +897,15 @@ var runCmd = &cobra.Command{
// Can be used to send keepalive messages to the client.
EnableTicker: cfg.EnableTicker,
},
Proxy: proxies[name],
Logger: logger,
PluginRegistry: pluginRegistry,
PluginTimeout: conf.Plugin.Timeout,
EnableTLS: cfg.EnableTLS,
CertFile: cfg.CertFile,
KeyFile: cfg.KeyFile,
HandshakeTimeout: cfg.HandshakeTimeout,
Proxies: serverProxies,
Logger: logger,
PluginRegistry: pluginRegistry,
PluginTimeout: conf.Plugin.Timeout,
EnableTLS: cfg.EnableTLS,
CertFile: cfg.CertFile,
KeyFile: cfg.KeyFile,
HandshakeTimeout: cfg.HandshakeTimeout,
LoadbalancerStrategyName: cfg.LoadBalancer.Strategy,
},
)

Expand Down
62 changes: 49 additions & 13 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ func (c *Config) LoadDefaults(ctx context.Context) *gerr.GatewayDError {
CertFile: "",
KeyFile: "",
HandshakeTimeout: DefaultHandshakeTimeout,
Proxies: []string{Default},
LoadBalancer: LoadBalancer{Strategy: DefaultLoadBalancerStrategy},
}

c.globalDefaults = GlobalConfig{
Expand Down Expand Up @@ -413,7 +415,7 @@ func (c *Config) ValidateGlobalConfig(ctx context.Context) *gerr.GatewayDError {
}

var errors []*gerr.GatewayDError
configObjects := []string{"loggers", "metrics", "clients", "pools", "proxies", "servers"}
configObjects := []string{"loggers", "metrics", "servers"}
sort.Strings(configObjects)
var seenConfigObjects []string

Expand Down Expand Up @@ -441,18 +443,16 @@ func (c *Config) ValidateGlobalConfig(ctx context.Context) *gerr.GatewayDError {
seenConfigObjects = append(seenConfigObjects, "metrics")
}

clientConfigGroups := make(map[string]bool)
for configGroup := range globalConfig.Clients {
clientConfigGroups[configGroup] = true
if globalConfig.Clients[configGroup] == nil {
err := fmt.Errorf("\"clients.%s\" is nil or empty", configGroup)
span.RecordError(err)
errors = append(errors, gerr.ErrValidationFailed.Wrap(err))
}
}

if len(globalConfig.Clients) > 1 {
seenConfigObjects = append(seenConfigObjects, "clients")
}

for configGroup := range globalConfig.Pools {
if globalConfig.Pools[configGroup] == nil {
err := fmt.Errorf("\"pools.%s\" is nil or empty", configGroup)
Expand All @@ -461,10 +461,6 @@ func (c *Config) ValidateGlobalConfig(ctx context.Context) *gerr.GatewayDError {
}
}

if len(globalConfig.Pools) > 1 {
seenConfigObjects = append(seenConfigObjects, "pools")
}

for configGroup := range globalConfig.Proxies {
if globalConfig.Proxies[configGroup] == nil {
err := fmt.Errorf("\"proxies.%s\" is nil or empty", configGroup)
Expand All @@ -473,10 +469,6 @@ func (c *Config) ValidateGlobalConfig(ctx context.Context) *gerr.GatewayDError {
}
}

if len(globalConfig.Proxies) > 1 {
seenConfigObjects = append(seenConfigObjects, "proxies")
}

for configGroup := range globalConfig.Servers {
if globalConfig.Servers[configGroup] == nil {
err := fmt.Errorf("\"servers.%s\" is nil or empty", configGroup)
Expand All @@ -489,6 +481,50 @@ func (c *Config) ValidateGlobalConfig(ctx context.Context) *gerr.GatewayDError {
seenConfigObjects = append(seenConfigObjects, "servers")
}

// ValidateClientsPoolsProxies checks if all configGroups in globalConfig.Pools and globalConfig.Proxies
// are referenced in globalConfig.Clients.
if len(globalConfig.Clients) != len(globalConfig.Pools) || len(globalConfig.Clients) != len(globalConfig.Proxies) {
err := goerrors.New("clients, pools, and proxies do not have the same number of objects")
span.RecordError(err)
errors = append(errors, gerr.ErrValidationFailed.Wrap(err))
}

// Check if all proxies are referenced in client configuration
for configGroup := range globalConfig.Proxies {
if !clientConfigGroups[configGroup] {
err := fmt.Errorf(`"proxies.%s" not referenced in client configuration`, configGroup)
span.RecordError(err)
errors = append(errors, gerr.ErrValidationFailed.Wrap(err))
}
}

// Check if all pools are referenced in client configuration
for configGroup := range globalConfig.Pools {
if !clientConfigGroups[configGroup] {
err := fmt.Errorf(`"pools.%s" not referenced in client configuration`, configGroup)
span.RecordError(err)
errors = append(errors, gerr.ErrValidationFailed.Wrap(err))
}
}

// Each server configuration should have at least one proxy defined.
// Each proxy in the server configuration should be referenced in proxies configuration.
for serverName, server := range globalConfig.Servers {
if len(server.Proxies) == 0 {
err := fmt.Errorf(`"servers.%s" has no proxies defined`, serverName)
span.RecordError(err)
errors = append(errors, gerr.ErrValidationFailed.Wrap(err))
continue
}
for _, proxyName := range server.Proxies {
if _, exists := c.globalDefaults.Proxies[proxyName]; !exists {
err := fmt.Errorf(`"servers.%s" references a non-existent proxy "%s"`, serverName, proxyName)
span.RecordError(err)
errors = append(errors, gerr.ErrValidationFailed.Wrap(err))
}
}
}

sort.Strings(seenConfigObjects)

if len(seenConfigObjects) > 0 && !reflect.DeepEqual(configObjects, seenConfigObjects) {
Expand Down
9 changes: 5 additions & 4 deletions config/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,11 @@ const (
DefaultHealthCheckPeriod = 60 * time.Second // This must match PostgreSQL authentication timeout.

// Server constants.
DefaultListenNetwork = "tcp"
DefaultListenAddress = "0.0.0.0:15432"
DefaultTickInterval = 5 * time.Second
DefaultHandshakeTimeout = 5 * time.Second
DefaultListenNetwork = "tcp"
DefaultListenAddress = "0.0.0.0:15432"
DefaultTickInterval = 5 * time.Second
DefaultHandshakeTimeout = 5 * time.Second
DefaultLoadBalancerStrategy = "ROUND_ROBIN"

// Utility constants.
DefaultSeed = 1000
Expand Down
6 changes: 6 additions & 0 deletions config/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ type Proxy struct {
HealthCheckPeriod time.Duration `json:"healthCheckPeriod" jsonschema:"oneof_type=string;integer"`
}

type LoadBalancer struct {
Strategy string `json:"strategy"`
}

type Server struct {
EnableTicker bool `json:"enableTicker"`
TickInterval time.Duration `json:"tickInterval" jsonschema:"oneof_type=string;integer"`
Expand All @@ -105,6 +109,8 @@ type Server struct {
CertFile string `json:"certFile"`
KeyFile string `json:"keyFile"`
HandshakeTimeout time.Duration `json:"handshakeTimeout" jsonschema:"oneof_type=string;integer"`
Proxies []string `json:"proxies"`
LoadBalancer LoadBalancer `json:"loadBalancer"`
}

type API struct {
Expand Down
5 changes: 5 additions & 0 deletions errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ const (
ErrCodeMsgEncodeError
ErrCodeConfigParseError
ErrCodePublishAsyncAction
ErrorCodeLoadBalancerStrategyNotFound
)

var (
Expand Down Expand Up @@ -194,6 +195,10 @@ var (
ErrCodePublishAsyncAction, "error publishing async action", nil,
}

ErrLoadBalancerStrategyNotFound = &GatewayDError{
ErrorCodeLoadBalancerStrategyNotFound, "The specified load balancer strategy does not exist.", nil,
}

// Unwrapped errors.
ErrLoggerRequired = errors.New("terminate action requires a logger parameter")
)
Expand Down
32 changes: 32 additions & 0 deletions gatewayd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,51 @@ clients:
backoff: 1s # duration
backoffMultiplier: 2.0 # 0 means no backoff
disableBackoffCaps: false
default-2:
network: tcp
address: localhost:5433
tcpKeepAlive: False
tcpKeepAlivePeriod: 30s # duration
receiveChunkSize: 8192
receiveDeadline: 0s # duration, 0ms/0s means no deadline
receiveTimeout: 0s # duration, 0ms/0s means no timeout
sendDeadline: 0s # duration, 0ms/0s means no deadline
dialTimeout: 60s # duration
# Retry configuration
retries: 3 # 0 means no retry and fail immediately on the first attempt
backoff: 1s # duration
backoffMultiplier: 2.0 # 0 means no backoff
disableBackoffCaps: false

pools:
default:
size: 10
default-2:
size: 10

proxies:
default:
healthCheckPeriod: 60s # duration
default-2:
healthCheckPeriod: 60s # duration

servers:
default:
network: tcp
address: 0.0.0.0:15432
proxies:
- "default"
- "default-2"
loadBalancer:
strategy: ROUND_ROBIN
# Not yet implemented.
# loadBalancingRules:
# - condition: "default"
# percentages:
# - proxy: "default"
# percentage: 70
# - proxy: "default-2"
# percentage: 30
enableTicker: False
tickInterval: 5s # duration
enableTLS: False
Expand Down
6 changes: 6 additions & 0 deletions network/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package network

// Load balancing strategies.
const (
RoundRobinStrategy = "ROUND_ROBIN"
)
18 changes: 18 additions & 0 deletions network/loadbalancer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package network

import (
gerr "github.com/gatewayd-io/gatewayd/errors"
)

type LoadBalancerStrategy interface {
GetNextProxy() IProxy
}

func NewLoadBalancerStrategy(server *Server) (LoadBalancerStrategy, *gerr.GatewayDError) {
switch server.LoadbalancerStrategyName {
case RoundRobinStrategy:
return NewRoundRobin(server), nil
default:
return nil, gerr.ErrLoadBalancerStrategyNotFound
}
}
40 changes: 40 additions & 0 deletions network/loadbalancer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package network

import (
"errors"
"testing"

gerr "github.com/gatewayd-io/gatewayd/errors"
)

func TestNewLoadBalancerStrategy(t *testing.T) {
serverValid := &Server{
LoadbalancerStrategyName: RoundRobinStrategy,
Proxies: []IProxy{MockProxy{}},
}

// Test case 1: Valid strategy name
strategy, err := NewLoadBalancerStrategy(serverValid)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}

_, ok := strategy.(*RoundRobin)
if !ok {
t.Errorf("Expected strategy to be of type RoundRobin")
}

// Test case 2: InValid strategy name
serverInvalid := &Server{
LoadbalancerStrategyName: "InvalidStrategy",
Proxies: []IProxy{MockProxy{}},
}

strategy, err = NewLoadBalancerStrategy(serverInvalid)
if !errors.Is(err, gerr.ErrLoadBalancerStrategyNotFound) {
t.Errorf("Expected ErrLoadBalancerStrategyNotFound, got %v", err)
}
if strategy != nil {
t.Errorf("Expected strategy to be nil for invalid strategy name")
}
}
20 changes: 20 additions & 0 deletions network/round-robin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package network

import "sync/atomic"

type RoundRobin struct {
proxies []IProxy
next atomic.Uint32
}

func NewRoundRobin(server *Server) *RoundRobin {
return &RoundRobin{proxies: server.Proxies}
}

func (r *RoundRobin) GetNextProxy() IProxy {
proxiesLen := uint32(len(r.proxies))

nextIndex := r.next.Add(1)

return r.proxies[nextIndex%proxiesLen]
}
Loading

0 comments on commit 391ee7d

Please sign in to comment.