Skip to content

Commit

Permalink
Merge pull request #335 from gatewayd-io/shutdown-metrics-server-grac…
Browse files Browse the repository at this point in the history
…efully

Shutdown metrics server gracefully
  • Loading branch information
mostafa authored Sep 24, 2023
2 parents cb97fbe + 2b6f503 commit f4e6be2
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 13 deletions.
37 changes: 32 additions & 5 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cmd
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log"
"net/http"
Expand Down Expand Up @@ -54,6 +55,7 @@ var (
globalConfigFile string
conf *config.Config
pluginRegistry *plugin.Registry
metricsServer *http.Server

UsageReportURL = "localhost:59091"

Expand All @@ -72,6 +74,7 @@ func StopGracefully(
pluginTimeoutCtx context.Context,
sig os.Signal,
metricsMerger *metrics.Merger,
metricsServer *http.Server,
pluginRegistry *plugin.Registry,
logger zerolog.Logger,
servers map[string]*network.Server,
Expand Down Expand Up @@ -110,6 +113,16 @@ func StopGracefully(
logger.Info().Msg("Stopped metrics merger")
span.AddEvent("Stopped metrics merger")
}
if metricsServer != nil {
//nolint:contextcheck
if err := metricsServer.Shutdown(context.Background()); err != nil {
logger.Error().Err(err).Msg("Failed to stop metrics server")
span.RecordError(err)
} else {
logger.Info().Msg("Stopped metrics server")
span.AddEvent("Stopped metrics server")
}
}
for name, server := range servers {
logger.Info().Str("name", name).Msg("Stopping server")
server.Shutdown() //nolint:contextcheck
Expand Down Expand Up @@ -352,7 +365,14 @@ var runCmd = &cobra.Command{
span.RecordError(err)
sentry.CaptureException(err)
}
next.ServeHTTP(responseWriter, request)
// The WriteHeader method intentionally does nothing, to prevent a bug
// in the merging metrics that causes the headers to be written twice,
// which results in an error: "http: superfluous response.WriteHeader call".
next.ServeHTTP(
&metrics.HeaderBypassResponseWriter{
ResponseWriter: responseWriter,
},
request)
}
return http.HandlerFunc(handler)
}
Expand All @@ -371,6 +391,7 @@ var runCmd = &cobra.Command{
if conf.Plugin.EnableMetricsMerger && metricsMerger != nil {
handler = mergedMetricsHandler(handler)
}

// Check if the metrics server is already running before registering the handler.
if _, err = http.Get(address); err != nil { //nolint:gosec
http.Handle(metricsConfig.Path, gziphandler.GzipHandler(handler))
Expand All @@ -379,16 +400,21 @@ var runCmd = &cobra.Command{
span.RecordError(err)
}

//nolint:gosec
if err = http.ListenAndServe(
metricsConfig.Address, nil); err != nil {
// Create a new metrics server.
metricsServer = &http.Server{
Addr: metricsConfig.Address,
Handler: handler,
ReadHeaderTimeout: metricsConfig.GetReadHeaderTimeout(),
}

// Start the metrics server.
if err = metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
logger.Error().Err(err).Msg("Failed to start metrics server")
span.RecordError(err)
}
}(conf.Global.Metrics[config.Default], logger)

// This is a notification hook, so we don't care about the result.
// TODO: Use a context with a timeout
if data, ok := conf.GlobalKoanf.Get("loggers").(map[string]interface{}); ok {
_, err = pluginRegistry.Run(
pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_LOGGER)
Expand Down Expand Up @@ -723,6 +749,7 @@ var runCmd = &cobra.Command{
pluginTimeoutCtx,
sig,
metricsMerger,
metricsServer,
pluginRegistry,
logger,
servers,
Expand Down
2 changes: 2 additions & 0 deletions cmd/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func Test_runCmd(t *testing.T) {
nil,
nil,
nil,
nil,
loggers[config.Default],
servers,
stopChan,
Expand Down Expand Up @@ -115,6 +116,7 @@ func Test_runCmdWithCachePlugin(t *testing.T) {
nil,
nil,
nil,
nil,
loggers[config.Default],
servers,
stopChan,
Expand Down
7 changes: 4 additions & 3 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,10 @@ func (c *Config) LoadDefaults(ctx context.Context) {
}

defaultMetric := Metrics{
Enabled: true,
Address: DefaultMetricsAddress,
Path: DefaultMetricsPath,
Enabled: true,
Address: DefaultMetricsAddress,
Path: DefaultMetricsPath,
ReadHeaderTimeout: DefaultReadHeaderTimeout,
}

defaultClient := Client{
Expand Down
5 changes: 3 additions & 2 deletions config/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,9 @@ const (
ChecksumBufferSize = 65536

// Metrics constants.
DefaultMetricsAddress = "localhost:9090"
DefaultMetricsPath = "/metrics"
DefaultMetricsAddress = "localhost:9090"
DefaultMetricsPath = "/metrics"
DefaultReadHeaderTimeout = 10 * time.Second

// Sentry constants.
DefaultTraceSampleRate = 0.2
Expand Down
7 changes: 7 additions & 0 deletions config/getters.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,10 @@ func GetDefaultConfigFilePath(filename string) string {
// The fallback is the current directory.
return filepath.Join("./", filename)
}

func (m Metrics) GetReadHeaderTimeout() time.Duration {
if m.ReadHeaderTimeout <= 0 {
return DefaultReadHeaderTimeout
}
return m.ReadHeaderTimeout
}
6 changes: 6 additions & 0 deletions config/getters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,9 @@ func TestGetPlugins(t *testing.T) {
func TestGetDefaultConfigFilePath(t *testing.T) {
assert.Equal(t, GlobalConfigFilename, GetDefaultConfigFilePath(GlobalConfigFilename))
}

// TestGetReadTimeout tests the GetReadTimeout function.
func TestGetReadHeaderTimeout(t *testing.T) {
metrics := Metrics{}
assert.Equal(t, DefaultReadHeaderTimeout, metrics.GetReadHeaderTimeout())
}
7 changes: 4 additions & 3 deletions config/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ type Logger struct {
}

type Metrics struct {
Enabled bool `json:"enabled"`
Address string `json:"address"`
Path string `json:"path"`
Enabled bool `json:"enabled"`
Address string `json:"address"`
Path string `json:"path"`
ReadHeaderTimeout time.Duration `json:"readHeaderTimeout" jsonschema:"oneof_type=string;integer"`
}

type Pool struct {
Expand Down
1 change: 1 addition & 0 deletions gatewayd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ metrics:
enabled: True
address: localhost:9090
path: /metrics
readHeaderTimeout: 10s # duration, prevents Slowloris attacks

clients:
default:
Expand Down
19 changes: 19 additions & 0 deletions metrics/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package metrics

import "net/http"

// HeaderBypassResponseWriter implements the http.ResponseWriter interface
// and allows us to bypass the response header when writing to the response.
// This is useful for merging metrics from multiple sources.
type HeaderBypassResponseWriter struct {
http.ResponseWriter
}

// WriteHeader intentionally does nothing, but is required to
// implement the http.ResponseWriter.
func (w *HeaderBypassResponseWriter) WriteHeader(int) {}

// Write writes the data to the response.
func (w *HeaderBypassResponseWriter) Write(data []byte) (int, error) {
return w.ResponseWriter.Write(data) //nolint:wrapcheck
}

0 comments on commit f4e6be2

Please sign in to comment.