Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Delayed round-tripper #12

Merged
merged 2 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions delayed.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package transport

import (
"fmt"
"math/rand"
"net/http"
"time"
)

type DelayedConfig struct {
// RequestDelayMin is the delay before the request is sent.
RequestDelayMin time.Duration

// RequestDelayMax is the maximum delay before the request is sent.
RequestDelayMax time.Duration

// ResponseDelayMin is the delay before the response is returned.
ResponseDelayMin time.Duration

// ResponseDelayMax is the maximum delay before the response is returned.
ResponseDelayMax time.Duration
}
xiam marked this conversation as resolved.
Show resolved Hide resolved

// Delayed is a middleware that delays requests and responses, useful when
// testing timeouts.
func Delayed(conf DelayedConfig) func(http.RoundTripper) http.RoundTripper {
if conf.RequestDelayMin > conf.RequestDelayMax {
panic(fmt.Errorf("connect delay min %v is greater than max %v", conf.RequestDelayMin, conf.RequestDelayMax))
}

if conf.ResponseDelayMin > conf.ResponseDelayMax {
panic(fmt.Errorf("transport delay min %v is greater than max %v", conf.ResponseDelayMin, conf.ResponseDelayMax))
}

return func(next http.RoundTripper) http.RoundTripper {
return RoundTripFunc(func(req *http.Request) (*http.Response, error) {
ctx := req.Context()

requestDelay := randDelay(conf.RequestDelayMin, conf.RequestDelayMax)

// wait before sending request
if requestDelay > 0 {
ticker := time.NewTicker(requestDelay)
defer ticker.Stop()
VojtechVitek marked this conversation as resolved.
Show resolved Hide resolved

select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
VojtechVitek marked this conversation as resolved.
Show resolved Hide resolved
}
}

res, err := next.RoundTrip(req)

// wait before sending response body
responseDelay := randDelay(conf.ResponseDelayMin, conf.ResponseDelayMax)
if responseDelay > 0 {
ticker := time.NewTicker(responseDelay)
defer ticker.Stop()

select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
VojtechVitek marked this conversation as resolved.
Show resolved Hide resolved
}
}

return res, err
})
}
}

func randDelay(min, max time.Duration) time.Duration {
if min >= max {
return min
}
return min + time.Duration(rand.Int63n(int64(max-min)))
}
232 changes: 232 additions & 0 deletions delayed_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
package transport_test

import (
"context"
"fmt"
"io/ioutil"
"testing"
"time"

"net/http"
"net/http/httptest"

"github.com/go-chi/transport"
)

func TestDelayed(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "ok")
}))
defer server.Close()

t.Run("default config", func(t *testing.T) {
client := &http.Client{
Transport: transport.Chain(
nil,
transport.Delayed(
transport.DelayedConfig{},
),
),
}

request, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatal(err)
}

timeStart := time.Now()
resp, err := client.Do(request)
if err != nil {
t.Fatal(err)
}
timeElapsed := time.Since(timeStart)

t.Logf("elapsed time: %v", timeElapsed)

if resp.StatusCode != 200 {
t.Fatal("expected some header, but did not receive")
}

buf, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}

t.Logf("response: %s", string(buf))
})

t.Run("delayed response", func(t *testing.T) {
client := &http.Client{
Transport: transport.Chain(
nil,
transport.Delayed(
transport.DelayedConfig{
ResponseDelayMin: 100 * time.Millisecond,
ResponseDelayMax: 200 * time.Millisecond,
},
),
),
}

request, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatal(err)
}

timeStart := time.Now()
_, err = client.Do(request)
if err != nil {
t.Fatal(err)
}
timeElapsed := time.Since(timeStart)

if timeElapsed < 100*time.Millisecond {
t.Fatalf("expected at least 100ms delay, but got %v", timeElapsed)
}
})

t.Run("delayed connect", func(t *testing.T) {
client := &http.Client{
Transport: transport.Chain(
nil,
transport.Delayed(
transport.DelayedConfig{
RequestDelayMin: 100 * time.Millisecond,
RequestDelayMax: 200 * time.Millisecond,
},
),
),
}

request, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatal(err)
}

timeStart := time.Now()
_, err = client.Do(request)
if err != nil {
t.Fatal(err)
}
timeElapsed := time.Since(timeStart)

if timeElapsed < 100*time.Millisecond {
t.Fatalf("expected at least 100ms delay, but got %v", timeElapsed)
}
})

t.Run("delayed connect and response", func(t *testing.T) {
client := &http.Client{
Transport: transport.Chain(
nil,
transport.Delayed(
transport.DelayedConfig{
RequestDelayMin: 50 * time.Millisecond,
RequestDelayMax: 100 * time.Millisecond,
ResponseDelayMin: 50 * time.Millisecond,
ResponseDelayMax: 100 * time.Millisecond,
},
),
),
}

request, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatal(err)
}

timeStart := time.Now()
_, err = client.Do(request)
if err != nil {
t.Fatal(err)
}
timeElapsed := time.Since(timeStart)

if timeElapsed < 100*time.Millisecond {
t.Fatalf("expected at least 100ms delay, but got %v", timeElapsed)
}
})

t.Run("chained transport", func(t *testing.T) {
var customTransportHit bool

customTransport := transport.RoundTripFunc(func(req *http.Request) (*http.Response, error) {
customTransportHit = true

return http.DefaultTransport.RoundTrip(req)
})

client := &http.Client{
Transport: transport.Chain(
customTransport,

transport.Delayed(
transport.DelayedConfig{
RequestDelayMin: 100 * time.Millisecond,
RequestDelayMax: 200 * time.Millisecond,
},
),
),
}

request, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatal(err)
}

timeStart := time.Now()
_, err = client.Do(request)
if err != nil {
t.Fatal(err)
}
timeElapsed := time.Since(timeStart)

if timeElapsed < 100*time.Millisecond {
t.Fatalf("expected at least 100ms delay, but got %v", timeElapsed)
}

if customTransportHit == false {
t.Fatal("expected custom transport to be hit, but it was not")
}
})

t.Run("honor request context", func(t *testing.T) {
client := &http.Client{
Transport: transport.Chain(
nil,
transport.Delayed(
transport.DelayedConfig{
RequestDelayMin: 100 * time.Millisecond,
RequestDelayMax: 200 * time.Millisecond,
},
),
),
}

request, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatal(err)
}

ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()

request = request.WithContext(ctx)

timeStart := time.Now()
_, err = client.Do(request)
timeElapsed := time.Since(timeStart)

if err == nil {
t.Fatalf("expected error, but got none")
}

if timeElapsed < 50*time.Millisecond {
t.Fatalf("expected at least 50ms delay, but got %v", timeElapsed)
}

if timeElapsed > 100*time.Millisecond {
t.Fatalf("expected less than 100ms delay, but got %v", timeElapsed)
}
})
}
Loading