Skip to content

Commit

Permalink
Fixes and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alpe committed Jan 17, 2024
1 parent 4299373 commit d0b684a
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 51 deletions.
138 changes: 92 additions & 46 deletions pkg/proxy/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"errors"
"io"
"log"
"math/rand"
"net/http"
"time"
Expand Down Expand Up @@ -44,35 +43,35 @@ func NewRetryMiddleware(maxRetries int, other http.Handler, optRetryStatusCodes
}
}

type xResponseWriter interface {
http.ResponseWriter
discardedResponse() bool
capturedStatusCode() int
}
type xBodyCapturer interface {
io.ReadCloser
Capture()
}

func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
lazyBody := &lazyBodyCapturer{
reader: request.Body,
buf: bytes.NewBuffer([]byte{}),
}
lazyBody := newLazyBodyCapturer(request.Body)
request.Body = lazyBody
for i := 0; ; i++ {
capturedResp := &responseWriterDelegator{
isRetryableStatusCode: r.isRetryableStatusCode,
ResponseWriter: writer,
headerBuf: make(http.Header),
discardErrResp: i < r.maxRetries &&
request.Context().Err() == nil, // abort early on timeout, context cancel
}
discardErrResp := i < r.maxRetries && request.Context().Err() == nil
capturedResp := newResponseWriterDelegator(writer, r.isRetryableStatusCode, discardErrResp)
// call next handler in chain
req, err := http.NewRequestWithContext(request.Context(), request.Method, request.URL.String(), lazyBody)
if err != nil {
log.Printf("clone request: %v", err)
writer.WriteHeader(http.StatusInternalServerError)
return
}
r.nextHandler.ServeHTTP(capturedResp, req)
lazyBody.Capture()
if !capturedResp.discardErrResp || // max retries reached
!r.isRetryableStatusCode(capturedResp.statusCode) {
reqClone := request.Clone(request.Context()) // also copies the reference to the lazy body capturer
r.nextHandler.ServeHTTP(capturedResp, reqClone)

if !capturedResp.discardedResponse() || // max retries reached or context error
!r.isRetryableStatusCode(capturedResp.capturedStatusCode()) {
break
}
// setup for retry
lazyBody.Capture()

totalRetries.Inc()
// Exponential backoff
// exponential backoff
jitter := time.Duration(r.rSource.Intn(50))
time.Sleep(time.Millisecond*time.Duration(1<<uint(i)) + jitter)
}
Expand All @@ -85,12 +84,12 @@ func (r RetryMiddleware) isRetryableStatusCode(status int) bool {

var (
_ http.Flusher = &responseWriterDelegator{}
_ io.ReaderFrom = &responseWriterDelegator{}
_ io.ReaderFrom = &xResponseWriterDelegator{}
)

// responseWriterDelegator represents a wrapper around http.ResponseWriter that provides additional
// functionalities for handling response writing. Depending on the status code and discard settings,
// the heeader + content on write is skipped so that it can be re-used on retry.
// the header + content on write is skipped so that it can be re-used on retry.
type responseWriterDelegator struct {
http.ResponseWriter
headerBuf http.Header
Expand All @@ -101,6 +100,28 @@ type responseWriterDelegator struct {
isRetryableStatusCode func(status int) bool
}

// newResponseWriterDelegator constructor
func newResponseWriterDelegator(writer http.ResponseWriter, isRetryableStatusCode func(status int) bool, discardErrResp bool) xResponseWriter {
d := &responseWriterDelegator{
isRetryableStatusCode: isRetryableStatusCode,
ResponseWriter: writer,
headerBuf: make(http.Header),
discardErrResp: discardErrResp, // abort early on timeout, context cancel
}
if _, ok := writer.(io.ReaderFrom); ok {
return &xResponseWriterDelegator{responseWriterDelegator: d}
}
return d
}

func (r *responseWriterDelegator) discardedResponse() bool {
return r.discardErrResp
}

func (r *responseWriterDelegator) capturedStatusCode() int {
return r.statusCode
}

func (r *responseWriterDelegator) Header() http.Header {
return r.headerBuf
}
Expand Down Expand Up @@ -141,7 +162,22 @@ func (r *responseWriterDelegator) Write(data []byte) (int, error) {
}
}

func (r *responseWriterDelegator) ReadFrom(re io.Reader) (int64, error) {
func (r *responseWriterDelegator) Flush() {
if f, ok := r.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}

// xResponseWriterDelegator provides the same functionalities as responseWriterDelegator but also implements the
// io.ReaderFrom interface.
// The ReadFrom method ensures that the header is set before reading from the reader.
// In case discardErrResp is true and the response status code is retryable, the content is discarded.
// Otherwise, it calls the ReadFrom method of the underlying response writer and returns the result.
type xResponseWriterDelegator struct {
*responseWriterDelegator
}

func (r *xResponseWriterDelegator) ReadFrom(re io.Reader) (int64, error) {
// ensure header is set. default is 200 in Go stdlib
if !r.wroteHeader {
r.WriteHeader(http.StatusOK)
Expand All @@ -153,24 +189,33 @@ func (r *responseWriterDelegator) ReadFrom(re io.Reader) (int64, error) {
}
}

func (r *responseWriterDelegator) Flush() {
if f, ok := r.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}

var (
_ io.ReadCloser = &lazyBodyCapturer{}
_ io.WriterTo = &lazyBodyCapturer{}
_ io.WriterTo = &lazyBodyCapturerWriteTo{}
)

// lazyBodyCapturer represents a type that captures the request body lazily.
// It wraps an io.ReadCloser and provides methods for reading, closing,
// writing to an io.Writer, and capturing the body content.
type lazyBodyCapturer struct {
reader io.ReadCloser
capturedBody []byte
buf *bytes.Buffer
allRead bool
}

// newLazyBodyCapturer constructor
func newLazyBodyCapturer(body io.ReadCloser) xBodyCapturer {
c := &lazyBodyCapturer{
reader: body,
buf: bytes.NewBuffer([]byte{}),
}
if _, ok := c.reader.(io.WriterTo); ok {
return &lazyBodyCapturerWriteTo{c}
}
return c
}

func (m *lazyBodyCapturer) Read(p []byte) (int, error) {
if m.allRead {
return m.reader.Read(p)
Expand All @@ -186,23 +231,24 @@ func (m *lazyBodyCapturer) Close() error {
return m.reader.Close()
}

func (m *lazyBodyCapturer) WriteTo(w io.Writer) (int64, error) {
if m.allRead {
return m.reader.(io.WriterTo).WriteTo(w)
}
n, err := m.reader.(io.WriterTo).WriteTo(io.MultiWriter(w, m.buf))
if errors.Is(err, io.EOF) {
m.allRead = true
}
return n, err
}

func (m *lazyBodyCapturer) Capture() {
m.allRead = true
if m.buf != nil {
m.capturedBody = m.buf.Bytes()
m.buf = nil
} else {
m.reader = io.NopCloser(bytes.NewReader(m.capturedBody))
}
m.reader = io.NopCloser(bytes.NewReader(m.capturedBody))
}

type lazyBodyCapturerWriteTo struct {
*lazyBodyCapturer
}

func (m *lazyBodyCapturerWriteTo) WriteTo(w io.Writer) (int64, error) {
if m.allRead {
return m.reader.(io.WriterTo).WriteTo(w)
}
n, err := m.reader.(io.WriterTo).WriteTo(io.MultiWriter(w, m.buf))
m.allRead = true
return n, err
}
Loading

0 comments on commit d0b684a

Please sign in to comment.