diff --git a/delayed.go b/delayed.go new file mode 100644 index 0000000..98caf69 --- /dev/null +++ b/delayed.go @@ -0,0 +1,70 @@ +package transport + +import ( + "fmt" + "math/rand" + "net/http" + "time" +) + +// DelayedRequest is a middleware that delays requests, useful when testing +// timeouts while waiting on a request to be sent upstream. +func DelayedRequest(requestDelayMin, requestDelayMax time.Duration) func(http.RoundTripper) http.RoundTripper { + if requestDelayMin > requestDelayMax { + panic(fmt.Sprintf("requestDelayMin %v is greater than requestDelayMax %v", requestDelayMin, requestDelayMax)) + } + return delayedRoundTripper(randDelay(requestDelayMin, requestDelayMax), 0) +} + +// DelayedResponse is a middleware that delays responses, useful when testing +// timeouts after upstream has processed the request, the response is hold back +// until the delay is over. +func DelayedResponse(responseDelayMin, responseDelayMax time.Duration) func(http.RoundTripper) http.RoundTripper { + if responseDelayMin > responseDelayMax { + panic(fmt.Sprintf("responseDelayMin %v is greater than responseDelayMax %v", responseDelayMin, responseDelayMax)) + } + return delayedRoundTripper(0, randDelay(responseDelayMin, responseDelayMax)) +} + +func delayedRoundTripper(requestDelay, responseDelay time.Duration) func(http.RoundTripper) http.RoundTripper { + return func(next http.RoundTripper) http.RoundTripper { + return RoundTripFunc(func(req *http.Request) (*http.Response, error) { + ctx := req.Context() + + // wait before sending request + if requestDelay > 0 { + ticker := time.NewTicker(requestDelay) + defer ticker.Stop() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-ticker.C: + } + } + + res, err := next.RoundTrip(req) + + // wait before sending response body + if responseDelay > 0 { + ticker := time.NewTicker(responseDelay) + defer ticker.Stop() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-ticker.C: + } + } + + return res, err + }) + } +} + +func randDelay(min, max time.Duration) time.Duration { + if min >= max { + return min + } + return min + time.Duration(rand.Int63n(int64(max-min))) +} diff --git a/delayed_test.go b/delayed_test.go new file mode 100644 index 0000000..a53063f --- /dev/null +++ b/delayed_test.go @@ -0,0 +1,207 @@ +package transport_test + +import ( + "context" + "fmt" + "io/ioutil" + "testing" + "time" + + "net/http" + "net/http/httptest" + + "github.com/go-chi/transport" +) + +func TestDelayed(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "ok") + })) + defer server.Close() + + t.Run("default config", func(t *testing.T) { + client := &http.Client{ + Transport: transport.Chain( + nil, + transport.DelayedRequest(0, 0), + transport.DelayedResponse(0, 0), + ), + } + + request, err := http.NewRequest("GET", server.URL, nil) + if err != nil { + t.Fatal(err) + } + + timeStart := time.Now() + resp, err := client.Do(request) + if err != nil { + t.Fatal(err) + } + timeElapsed := time.Since(timeStart) + + t.Logf("elapsed time: %v", timeElapsed) + + if resp.StatusCode != 200 { + t.Fatal("expected some header, but did not receive") + } + + buf, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + t.Logf("response: %s", string(buf)) + }) + + t.Run("delayed response", func(t *testing.T) { + client := &http.Client{ + Transport: transport.Chain( + nil, + transport.DelayedResponse(100*time.Millisecond, 200*time.Millisecond), + ), + } + + request, err := http.NewRequest("GET", server.URL, nil) + if err != nil { + t.Fatal(err) + } + + timeStart := time.Now() + _, err = client.Do(request) + if err != nil { + t.Fatal(err) + } + timeElapsed := time.Since(timeStart) + + if timeElapsed < 100*time.Millisecond { + t.Fatalf("expected at least 100ms delay, but got %v", timeElapsed) + } + }) + + t.Run("delayed connect", func(t *testing.T) { + client := &http.Client{ + Transport: transport.Chain( + nil, + transport.DelayedResponse( + 100*time.Millisecond, + 200*time.Millisecond, + ), + ), + } + + request, err := http.NewRequest("GET", server.URL, nil) + if err != nil { + t.Fatal(err) + } + + timeStart := time.Now() + _, err = client.Do(request) + if err != nil { + t.Fatal(err) + } + timeElapsed := time.Since(timeStart) + + if timeElapsed < 100*time.Millisecond { + t.Fatalf("expected at least 100ms delay, but got %v", timeElapsed) + } + }) + + t.Run("delayed request and response", func(t *testing.T) { + client := &http.Client{ + Transport: transport.Chain( + nil, + transport.DelayedRequest(50*time.Millisecond, 100*time.Millisecond), + transport.DelayedResponse(50*time.Millisecond, 100*time.Millisecond), + ), + } + + request, err := http.NewRequest("GET", server.URL, nil) + if err != nil { + t.Fatal(err) + } + + timeStart := time.Now() + _, err = client.Do(request) + if err != nil { + t.Fatal(err) + } + timeElapsed := time.Since(timeStart) + + if timeElapsed < 100*time.Millisecond { + t.Fatalf("expected at least 100ms delay, but got %v", timeElapsed) + } + }) + + t.Run("chained transport", func(t *testing.T) { + var customTransportHit bool + + customTransport := transport.RoundTripFunc(func(req *http.Request) (*http.Response, error) { + customTransportHit = true + + return http.DefaultTransport.RoundTrip(req) + }) + + client := &http.Client{ + Transport: transport.Chain( + customTransport, + transport.DelayedRequest(100*time.Millisecond, 200*time.Millisecond), + ), + } + + request, err := http.NewRequest("GET", server.URL, nil) + if err != nil { + t.Fatal(err) + } + + timeStart := time.Now() + _, err = client.Do(request) + if err != nil { + t.Fatal(err) + } + timeElapsed := time.Since(timeStart) + + if timeElapsed < 100*time.Millisecond { + t.Fatalf("expected at least 100ms delay, but got %v", timeElapsed) + } + + if customTransportHit == false { + t.Fatal("expected custom transport to be hit, but it was not") + } + }) + + t.Run("honor request context", func(t *testing.T) { + client := &http.Client{ + Transport: transport.Chain( + nil, + transport.DelayedRequest(100*time.Millisecond, 200*time.Millisecond), + ), + } + + request, err := http.NewRequest("GET", server.URL, nil) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + request = request.WithContext(ctx) + + timeStart := time.Now() + _, err = client.Do(request) + timeElapsed := time.Since(timeStart) + + if err == nil { + t.Fatalf("expected error, but got none") + } + + if timeElapsed < 50*time.Millisecond { + t.Fatalf("expected at least 50ms delay, but got %v", timeElapsed) + } + + if timeElapsed > 100*time.Millisecond { + t.Fatalf("expected less than 100ms delay, but got %v", timeElapsed) + } + }) +}