From 695ce80e8c901bb9ebab52c0032a6da4af3614e6 Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Wed, 14 Aug 2024 10:41:08 +0200 Subject: [PATCH 1/2] Implement transport.If() for conditional transports (e.g. for debugging) --- README.md | 14 ++++----- if.go | 20 ++++++++++++ if_test.go | 79 +++++++++++++++++++++++++++++++++++++++++++++++ transport.go | 12 +++++-- transport_test.go | 1 + 5 files changed, 117 insertions(+), 9 deletions(-) create mode 100644 if.go create mode 100644 if_test.go diff --git a/README.md b/README.md index a564729..9b7d8b3 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ There are multiple use-cases where this pattern comes handy such as request logg ## Examples -Set up HTTP client, which sets `User-Agent`, `Authorization` and `TraceID` headers automatically : +Set up HTTP client, which sets `User-Agent`, `Authorization` and `TraceID` headers automatically: ```go authClient := http.Client{ Transport: transport.Chain( @@ -22,12 +22,12 @@ authClient := http.Client{ Or debug all outgoing requests globally within your application: ```go -if debugMode { - http.DefaultTransport = transport.Chain( - http.DefaultTransport, - transport.LogRequests, - ) -} +debugMode := os.Getenv("DEBUG") == "true" + +http.DefaultTransport = transport.Chain( + http.DefaultTransport, + transport.If(debugMode, transport.LogRequests), +) ``` # Authors diff --git a/if.go b/if.go new file mode 100644 index 0000000..8fb3202 --- /dev/null +++ b/if.go @@ -0,0 +1,20 @@ +package transport + +import ( + "net/http" +) + +// If sets given transport if given condition is true. Otherwise it sets nil transport, which will be ignored. +// +// Example: +// +// http.DefaultTransport = transport.Chain( +// http.DefaultTransport, +// transport.If(debugMode, transport.LogRequests), +// ) +func If(condition bool, transport func(http.RoundTripper) http.RoundTripper) func(http.RoundTripper) http.RoundTripper { + if condition { + return transport + } + return nil +} diff --git a/if_test.go b/if_test.go new file mode 100644 index 0000000..49e052f --- /dev/null +++ b/if_test.go @@ -0,0 +1,79 @@ +package transport_test + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/transport" +) + +func TestIfTrue(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("debug") != "true" { + t.Error("expected debug=true") + w.WriteHeader(500) + return + } + + fmt.Fprintf(w, "ok") + })) + defer server.Close() + + client := &http.Client{ + Timeout: 15 * time.Second, + Transport: transport.Chain( + http.DefaultTransport, + transport.If(true, transport.SetHeader("debug", "true")), // Set header. + ), + } + + request, err := http.NewRequest("GET", server.URL, nil) + if err != nil { + t.Fatal(err) + } + resp, err := client.Do(request) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != 200 { + t.Fatal("unexpected response") + } +} + +func TestIfFalse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("debug") != "" { + t.Error("expected no debug header") + w.WriteHeader(500) + return + } + + fmt.Fprintf(w, "ok") + })) + defer server.Close() + + client := &http.Client{ + Timeout: 15 * time.Second, + Transport: transport.Chain( + http.DefaultTransport, + transport.If(false, transport.SetHeader("debug", "true")), // Do not set header. + ), + } + + request, err := http.NewRequest("GET", server.URL, nil) + if err != nil { + t.Fatal(err) + } + resp, err := client.Do(request) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != 200 { + t.Fatal("unexpected response") + } +} diff --git a/transport.go b/transport.go index 299f9a0..56971f2 100644 --- a/transport.go +++ b/transport.go @@ -45,14 +45,22 @@ func Chain(base http.RoundTripper, mw ...func(http.RoundTripper) http.RoundTripp base = http.DefaultTransport } + // Filter out nil transports. + filtered := []func(http.RoundTripper) http.RoundTripper{} + for _, fn := range mw { + if fn != nil { + filtered = append(filtered, fn) + } + } + if c, ok := base.(*chain); ok { - c.middlewares = append(c.middlewares, mw...) + c.middlewares = append(c.middlewares, filtered...) return c } return &chain{ baseTransport: base, - middlewares: mw, + middlewares: filtered, } } diff --git a/transport_test.go b/transport_test.go index e2edd3f..748d319 100644 --- a/transport_test.go +++ b/transport_test.go @@ -13,6 +13,7 @@ func TestChain(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("User-Agent") != "transport-chain/v1.0.0" { w.WriteHeader(500) + return } fmt.Fprintf(w, expected) From cfb83ac0116ff7e9547459612b62990ce9deaafb Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Wed, 14 Aug 2024 11:22:01 +0200 Subject: [PATCH 2/2] PR feedback --- if.go | 1 + transport.go | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/if.go b/if.go index 8fb3202..8d0ef16 100644 --- a/if.go +++ b/if.go @@ -16,5 +16,6 @@ func If(condition bool, transport func(http.RoundTripper) http.RoundTripper) fun if condition { return transport } + return nil } diff --git a/transport.go b/transport.go index 56971f2..72b20a5 100644 --- a/transport.go +++ b/transport.go @@ -46,21 +46,21 @@ func Chain(base http.RoundTripper, mw ...func(http.RoundTripper) http.RoundTripp } // Filter out nil transports. - filtered := []func(http.RoundTripper) http.RoundTripper{} + mws := []func(http.RoundTripper) http.RoundTripper{} for _, fn := range mw { if fn != nil { - filtered = append(filtered, fn) + mws = append(mws, fn) } } if c, ok := base.(*chain); ok { - c.middlewares = append(c.middlewares, filtered...) + c.middlewares = append(c.middlewares, mws...) return c } return &chain{ baseTransport: base, - middlewares: filtered, + middlewares: mws, } }