Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
theseanything committed Dec 11, 2024
1 parent aba7a43 commit a338f18
Show file tree
Hide file tree
Showing 16 changed files with 144 additions and 310 deletions.
21 changes: 11 additions & 10 deletions handlers/backend_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ import (
"time"

"github.com/prometheus/client_golang/prometheus"

"github.com/alphagov/router/logger"
"github.com/rs/zerolog"
)

var TLSSkipVerify bool
Expand All @@ -24,7 +23,7 @@ func NewBackendHandler(
backendID string,
backendURL *url.URL,
connectTimeout, headerTimeout time.Duration,
logger logger.Logger,
logger zerolog.Logger,
) http.Handler {

proxy := httputil.NewSingleHostReverseProxy(backendURL)
Expand Down Expand Up @@ -67,7 +66,7 @@ type backendTransport struct {
backendID string

wrapped *http.Transport
logger logger.Logger
logger zerolog.Logger
}

// Construct a backendTransport that wraps an http.Transport and implements http.RoundTripper.
Expand All @@ -76,7 +75,7 @@ type backendTransport struct {
func newBackendTransport(
backendID string,
connectTimeout, headerTimeout time.Duration,
logger logger.Logger,
logger zerolog.Logger,
) *backendTransport {

transport := http.Transport{}
Expand Down Expand Up @@ -161,11 +160,13 @@ func (bt *backendTransport) RoundTrip(req *http.Request) (resp *http.Response, e
responseCode = http.StatusInternalServerError
}
closeBody(resp)
logger.NotifySentry(logger.ReportableError{Error: err, Request: req, Response: resp})
bt.logger.LogFromBackendRequest(
map[string]interface{}{"error": err.Error(), "status": responseCode},
req,
)
bt.logger.Error().
Err(err).
Int("status", responseCode).
Str("method", req.Method).
Str("url", req.URL.String()).
Msg("backend request error")

return newErrorResponse(responseCode), nil
}
responseCode = resp.StatusCode
Expand Down
9 changes: 4 additions & 5 deletions handlers/backend_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,23 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"time"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/ghttp"
"github.com/rs/zerolog"

"github.com/prometheus/client_golang/prometheus"
promtest "github.com/prometheus/client_golang/prometheus/testutil"
prommodel "github.com/prometheus/client_model/go"

log "github.com/alphagov/router/logger"
)

var _ = Describe("Backend handler", func() {
var (
timeout = 1 * time.Second
logger log.Logger
logger zerolog.Logger

backend *ghttp.Server
backendURL *url.URL
Expand All @@ -33,8 +33,7 @@ var _ = Describe("Backend handler", func() {
BeforeEach(func() {
var err error

logger, err = log.New(GinkgoWriter)
Expect(err).NotTo(HaveOccurred(), "Could not create logger")
logger = zerolog.New(os.Stdout)

backend = ghttp.NewServer()

Expand Down
29 changes: 17 additions & 12 deletions handlers/redirect_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

"github.com/prometheus/client_golang/prometheus"

"github.com/alphagov/router/logger"
"github.com/rs/zerolog"
)

const (
Expand All @@ -20,43 +20,47 @@ const (
downcaseRedirectHandlerType = "downcase-redirect-handler"
)

func NewRedirectHandler(source, target string, preserve bool) http.Handler {
func NewRedirectHandler(source, target string, preserve bool, logger zerolog.Logger) http.Handler {
status := http.StatusMovedPermanently
if preserve {
return &pathPreservingRedirectHandler{source, target, status}
return &pathPreservingRedirectHandler{source, target, status, logger}
}
return &redirectHandler{target, status}
return &redirectHandler{target, status, logger}
}

func addCacheHeaders(w http.ResponseWriter) {
w.Header().Set("Expires", time.Now().Add(cacheDuration).Format(time.RFC1123))
w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d, public", cacheDuration/time.Second))
}

func addGAQueryParam(target string, r *http.Request) string {
func addGAQueryParam(target string, r *http.Request) (string, error) {
if ga := r.URL.Query().Get("_ga"); ga != "" {
u, err := url.Parse(target)
if err != nil {
defer logger.NotifySentry(logger.ReportableError{Error: err, Request: r})
return target
return target, err
}
values := u.Query()
values.Set("_ga", ga)
u.RawQuery = values.Encode()
return u.String()
return u.String(), nil
}
return target
return target, nil
}

type redirectHandler struct {
url string
code int
url string
code int
logger zerolog.Logger
}

func (handler *redirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
addCacheHeaders(w)

target := addGAQueryParam(handler.url, r)
target, err := addGAQueryParam(handler.url, r)
if err != nil {
handler.logger.Error().Err(err).Msg("failed to add GA query param")
}

http.Redirect(w, r, target, handler.code)

redirectCountMetric.With(prometheus.Labels{
Expand All @@ -68,6 +72,7 @@ type pathPreservingRedirectHandler struct {
sourcePrefix string
targetPrefix string
code int
logger zerolog.Logger
}

func (handler *pathPreservingRedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Expand Down
14 changes: 9 additions & 5 deletions handlers/redirect_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"time"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/rs/zerolog"

"github.com/prometheus/client_golang/prometheus"
promtest "github.com/prometheus/client_golang/prometheus/testutil"
Expand All @@ -17,16 +19,18 @@ var _ = Describe("A redirect handler", func() {
var handler http.Handler
var rr *httptest.ResponseRecorder
const url = "https://source.example.com/source/path/subpath?q1=a&q2=b"
var logger zerolog.Logger

BeforeEach(func() {
rr = httptest.NewRecorder()
logger = zerolog.New(os.Stdout)
})

// These behaviours apply to all combinations of both NewRedirectHandler flags.
for _, preserve := range []bool{true, false} {
Context(fmt.Sprintf("where preserve=%t", preserve), func() {
BeforeEach(func() {
handler = NewRedirectHandler("/source", "/target", preserve)
handler = NewRedirectHandler("/source", "/target", preserve, logger)
handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil))
})

Expand All @@ -49,7 +53,7 @@ var _ = Describe("A redirect handler", func() {

Context("where preserve=true", func() {
BeforeEach(func() {
handler = NewRedirectHandler("/source", "/target", true)
handler = NewRedirectHandler("/source", "/target", true, logger)
handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil))
})

Expand All @@ -60,7 +64,7 @@ var _ = Describe("A redirect handler", func() {

Context("where preserve=false", func() {
BeforeEach(func() {
handler = NewRedirectHandler("/source", "/target", false)
handler = NewRedirectHandler("/source", "/target", false, logger)
})

It("returns only the configured path in the location header", func() {
Expand All @@ -80,7 +84,7 @@ var _ = Describe("A redirect handler", func() {
Entry(nil, false, http.StatusMovedPermanently),
Entry(nil, true, http.StatusMovedPermanently),
func(preserve bool, expectedStatus int) {
handler = NewRedirectHandler("/source", "/target", preserve)
handler = NewRedirectHandler("/source", "/target", preserve, logger)
handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil))
Expect(rr.Result().StatusCode).To(Equal(expectedStatus))
})
Expand All @@ -95,7 +99,7 @@ var _ = Describe("A redirect handler", func() {
lbls := prometheus.Labels{"redirect_type": typeLabel}
before := promtest.ToFloat64(redirectCountMetric.With(lbls))

handler = NewRedirectHandler("/source", "/target", preserve)
handler = NewRedirectHandler("/source", "/target", preserve, logger)
handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil))

after := promtest.ToFloat64(redirectCountMetric.With(lbls))
Expand Down
9 changes: 4 additions & 5 deletions lib/backends.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
package router

import (
"fmt"
"net/http"
"net/url"
"os"
"strings"
"time"

"github.com/alphagov/router/handlers"
"github.com/alphagov/router/logger"
"github.com/rs/zerolog"
)

func loadBackendsFromEnv(connTimeout, headerTimeout time.Duration, logger logger.Logger) (backends map[string]http.Handler) {
func loadBackendsFromEnv(connTimeout, headerTimeout time.Duration, logger zerolog.Logger) (backends map[string]http.Handler) {
backends = make(map[string]http.Handler)

for _, envvar := range os.Environ() {
Expand All @@ -26,13 +25,13 @@ func loadBackendsFromEnv(connTimeout, headerTimeout time.Duration, logger logger
backendURL := pair[1]

if backendURL == "" {
logWarn(fmt.Errorf("router: couldn't find URL for backend %s, skipping", backendID))
logger.Warn().Msgf("no URL for backend %s provided, skipping", backendID)
continue
}

backend, err := url.Parse(backendURL)
if err != nil {
logWarn(fmt.Errorf("router: couldn't parse URL %s for backend %s (error: %w), skipping", backendURL, backendID, err))
logger.Warn().Err(err).Msgf("failed to parse URL %s for backend %s, skipping", backendURL, backendID)
continue
}

Expand Down
11 changes: 8 additions & 3 deletions lib/backends_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,20 @@ import (

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/rs/zerolog"
)

var _ = Describe("Backends", func() {
var (
logger = zerolog.New(os.Stdout)
)

Context("When calling loadBackendsFromEnv", func() {
It("should load backends from environment variables", func() {
os.Setenv("BACKEND_URL_testBackend", "http://example.com")
defer os.Unsetenv("BACKEND_URL_testBackend")

backends := loadBackendsFromEnv(1*time.Second, 20*time.Second, nil)
backends := loadBackendsFromEnv(1*time.Second, 20*time.Second, logger)

Expect(backends).To(HaveKey("testBackend"))
Expect(backends["testBackend"]).ToNot(BeNil())
Expand All @@ -24,7 +29,7 @@ var _ = Describe("Backends", func() {
os.Setenv("BACKEND_URL_emptyBackend", "")
defer os.Unsetenv("BACKEND_URL_emptyBackend")

backends := loadBackendsFromEnv(1*time.Second, 20*time.Second, nil)
backends := loadBackendsFromEnv(1*time.Second, 20*time.Second, logger)

Expect(backends).ToNot(HaveKey("emptyBackend"))
})
Expand All @@ -33,7 +38,7 @@ var _ = Describe("Backends", func() {
os.Setenv("BACKEND_URL_invalidBackend", "://invalid-url")
defer os.Unsetenv("BACKEND_URL_invalidBackend")

backends := loadBackendsFromEnv(1*time.Second, 20*time.Second, nil)
backends := loadBackendsFromEnv(1*time.Second, 20*time.Second, logger)

Expect(backends).ToNot(HaveKey("invalidBackend"))
})
Expand Down
Loading

0 comments on commit a338f18

Please sign in to comment.