From aadf9613871ef6986867dcf10e95f792c63b0d47 Mon Sep 17 00:00:00 2001 From: Nash Gadre Date: Tue, 29 Oct 2024 23:22:20 -0400 Subject: [PATCH] update fast server proxy handler --- fast-server/config/config.go | 214 ++++++++++++++++++++++---- fast-server/handlers/proxy_handler.go | 39 ++--- 2 files changed, 201 insertions(+), 52 deletions(-) diff --git a/fast-server/config/config.go b/fast-server/config/config.go index 7986bee..f20519b 100644 --- a/fast-server/config/config.go +++ b/fast-server/config/config.go @@ -14,35 +14,140 @@ import ( var ProductionConfigPath = "/etc/fast/config.yaml" +type ProxyConfig struct { + Host string `yaml:"host"` + Port int `yaml:"port"` + Protocol string `yaml:"protocol,omitempty"` + InsecureSkipVerify bool `yaml:"insecure_skip_verify,omitempty"` +} + +func (p *ProxyConfig) setDefaults() { + if p.Protocol == "" { + p.Protocol = "http" + } + // InsecureSkipVerify defaults to false +} + +func (p *ProxyConfig) validate() error { + if p.Host == "" { + return fmt.Errorf("proxy host cannot be empty") + } + if p.Port <= 0 || p.Port > 65535 { + return fmt.Errorf("invalid proxy port: %d", p.Port) + } + if p.Protocol != "http" && p.Protocol != "https" { + return fmt.Errorf("invalid proxy protocol: %s", p.Protocol) + } + return nil +} + +type SSLConfig struct { + CertFile string `yaml:"cert_file"` + KeyFile string `yaml:"key_file"` +} + +func (s *SSLConfig) validate() error { + if s.CertFile == "" || s.KeyFile == "" { + return fmt.Errorf("SSL cert_file and key_file must both be specified") + } + return nil +} + type Domain struct { - Name string `yaml:"name"` - Type string `yaml:"type"` - PublicDir string `yaml:"public_dir"` - Proxy struct { - Host string `yaml:"host"` - Port int `yaml:"port"` - } `yaml:"proxy"` - SSL struct { - CertFile string `yaml:"cert_file"` - KeyFile string `yaml:"key_file"` - } `yaml:"ssl"` + Name string `yaml:"name"` + Type string `yaml:"type"` + PublicDir string `yaml:"public_dir"` + Proxy ProxyConfig `yaml:"proxy"` + SSL SSLConfig `yaml:"ssl"` +} + +func (d *Domain) setDefaults() { + if d.Type == "proxy" { + d.Proxy.setDefaults() + } +} + +func (d *Domain) validate() error { + if d.Name == "" { + return fmt.Errorf("domain name cannot be empty") + } + + validTypes := map[string]bool{ + "static": true, + "proxy": true, + "file_directory": true, + } + if !validTypes[d.Type] { + return fmt.Errorf("invalid domain type: %s", d.Type) + } + + if d.Type != "proxy" && d.PublicDir == "" { + return fmt.Errorf("public_dir is required for type: %s", d.Type) + } + + if d.Type == "proxy" { + if err := d.Proxy.validate(); err != nil { + return fmt.Errorf("proxy configuration error for %s: %v", d.Name, err) + } + } + + return d.SSL.validate() +} + +type ServerConfig struct { + Port int `yaml:"port"` + HTTPPort int `yaml:"http_port"` +} + +func (s *ServerConfig) setDefaults() { + if s.Port == 0 { + s.Port = 443 + } + if s.HTTPPort == 0 { + s.HTTPPort = 80 + } +} + +func (s *ServerConfig) validate() error { + if s.Port <= 0 || s.Port > 65535 { + return fmt.Errorf("invalid server port: %d", s.Port) + } + if s.HTTPPort <= 0 || s.HTTPPort > 65535 { + return fmt.Errorf("invalid HTTP port: %d", s.HTTPPort) + } + return nil +} + +type LogConfig struct { + File string `yaml:"file"` + Level string `yaml:"level"` +} + +func (l *LogConfig) setDefaults() { + if l.Level == "" { + l.Level = "info" + } +} + +func (l *LogConfig) validate() error { + validLevels := map[string]bool{ + "debug": true, + "info": true, + "warn": true, + "error": true, + } + if !validLevels[l.Level] { + return fmt.Errorf("invalid log level: %s", l.Level) + } + return nil } type Config struct { - Server struct { - Port int `yaml:"port"` - HTTPPort int `yaml:"http_port"` - } `yaml:"server"` - Domains []Domain `yaml:"domains"` - GlobalSSL struct { - CertFile string `yaml:"cert_file"` - KeyFile string `yaml:"key_file"` - } `yaml:"global_ssl"` - Log struct { - File string `yaml:"file"` - Level string `yaml:"level"` - } `yaml:"log"` - Settings struct { + Server ServerConfig `yaml:"server"` + Domains []Domain `yaml:"domains"` + GlobalSSL SSLConfig `yaml:"global_ssl"` + Log LogConfig `yaml:"log"` + Settings struct { ReadTimeout string `yaml:"read_timeout"` WriteTimeout string `yaml:"write_timeout"` GracefulShutdownTimeout string `yaml:"graceful_shutdown_timeout"` @@ -50,15 +155,53 @@ type Config struct { IsDevelopment bool `yaml:"is_development"` } +func (c *Config) setDefaults() { + c.Server.setDefaults() + c.Log.setDefaults() + + for i := range c.Domains { + c.Domains[i].setDefaults() + } + + if c.Settings.ReadTimeout == "" { + c.Settings.ReadTimeout = "5s" + } + if c.Settings.WriteTimeout == "" { + c.Settings.WriteTimeout = "10s" + } + if c.Settings.GracefulShutdownTimeout == "" { + c.Settings.GracefulShutdownTimeout = "30s" + } +} + +func (c *Config) validate() error { + if err := c.Server.validate(); err != nil { + return fmt.Errorf("server configuration error: %v", err) + } + + if err := c.Log.validate(); err != nil { + return fmt.Errorf("log configuration error: %v", err) + } + + if len(c.Domains) == 0 { + return fmt.Errorf("at least one domain must be configured") + } + + for _, domain := range c.Domains { + if err := domain.validate(); err != nil { + return fmt.Errorf("domain %s configuration error: %v", domain.Name, err) + } + } + + return nil +} + func isLaunchedByDebugger() bool { - // Check if gops is available _, err := exec.LookPath("gops") if err != nil { - // If gops is not available, fall back to a simple check return strings.Contains(os.Args[0], "debugger") || strings.Contains(os.Args[0], "___go_build_") } - // Use gops to check the parent process gopsOut, err := exec.Command("gops", strconv.Itoa(os.Getppid())).Output() if err != nil { echo.New().Logger.Warnf("Error running gops: %v", err) @@ -74,7 +217,7 @@ func isLaunchedByDebugger() bool { return strings.Contains(gopsOutStr, "/dlv") || strings.Contains(gopsOutStr, "/dlv-dap") || strings.Contains(gopsOutStr, "debugserver") - default: // linux and others + default: return strings.Contains(gopsOutStr, "/dlv") } } @@ -82,10 +225,10 @@ func isLaunchedByDebugger() bool { func LoadConfig() (*Config, error) { var configPath string if isLaunchedByDebugger() { - configPath = "test/config.yaml" // Local path for development + configPath = "test/config.yaml" echo.New().Logger.Info("Debug mode detected. Using local config.yaml") } else { - configPath = ProductionConfigPath // Default production path + configPath = ProductionConfigPath } data, err := ioutil.ReadFile(configPath) @@ -94,10 +237,15 @@ func LoadConfig() (*Config, error) { } var config Config - err = yaml.Unmarshal(data, &config) - if err != nil { + if err := yaml.Unmarshal(data, &config); err != nil { return nil, fmt.Errorf("failed to parse config file: %v", err) } + config.setDefaults() + + if err := config.validate(); err != nil { + return nil, fmt.Errorf("invalid configuration: %v", err) + } + return &config, nil } diff --git a/fast-server/handlers/proxy_handler.go b/fast-server/handlers/proxy_handler.go index 5a84c1a..7d37110 100644 --- a/fast-server/handlers/proxy_handler.go +++ b/fast-server/handlers/proxy_handler.go @@ -12,14 +12,13 @@ import ( ) func HandleProxy(c echo.Context, domain config.Domain) error { - - // Determine scheme based on request - scheme := "http" - if c.Request().TLS != nil { - scheme = "https" + // Determine target scheme from config + targetScheme := domain.Proxy.Protocol + if targetScheme == "" { + targetScheme = "http" // Default to HTTP if not specified } - target, err := url.Parse(fmt.Sprintf("%s://%s:%d", scheme, domain.Proxy.Host, domain.Proxy.Port)) + target, err := url.Parse(fmt.Sprintf("%s://%s:%d", targetScheme, domain.Proxy.Host, domain.Proxy.Port)) if err != nil { c.Logger().Errorf("Error parsing proxy URL: %v", err) return echo.ErrInternalServerError @@ -34,7 +33,7 @@ func HandleProxy(c echo.Context, domain config.Domain) error { TLSClientConfig: &tls.Config{ MinVersion: tls.VersionTLS12, MaxVersion: tls.VersionTLS13, - InsecureSkipVerify: true, // Disabled for Proxy + InsecureSkipVerify: domain.Proxy.InsecureSkipVerify, CipherSuites: []uint16{ tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, @@ -52,6 +51,12 @@ func HandleProxy(c echo.Context, domain config.Domain) error { DisableCompression: false, } + // Get original scheme + sourceScheme := "http" + if c.Request().TLS != nil { + sourceScheme = "https" + } + proxyMiddleware := middleware.ProxyWithConfig(middleware.ProxyConfig{ Balancer: middleware.NewRoundRobinBalancer([]*middleware.ProxyTarget{ { @@ -62,26 +67,22 @@ func HandleProxy(c echo.Context, domain config.Domain) error { "/*": "/$1", }, Transport: transport, + ModifyResponse: func(res *http.Response) error { + // Optional: Modify response headers here if needed + return nil + }, }) - // Set proxy headers before proxying the request + // Set proxy headers c.Request().Header.Set("X-Forwarded-Host", originalHost) c.Request().Header.Set("X-Real-IP", c.RealIP()) c.Request().Header.Set("X-Forwarded-For", c.RealIP()) - c.Request().Header.Set("X-Forwarded-Proto", scheme) + c.Request().Header.Set("X-Forwarded-Proto", sourceScheme) - // Important: Keep the original host for the second FAST server + // Keep original host c.Request().Host = originalHost - // Execute the proxy middleware - err = proxyMiddleware(func(c echo.Context) error { + return proxyMiddleware(func(c echo.Context) error { return nil })(c) - - if err != nil { - c.Logger().Errorf("Proxy error: %v", err) - return err - } - - return nil }