Skip to content

Commit

Permalink
Better status code capturing
Browse files Browse the repository at this point in the history
  • Loading branch information
alpe committed Jan 17, 2024
1 parent d0b684a commit 5fb2e66
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 12 deletions.
2 changes: 1 addition & 1 deletion pkg/proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
captureStatusRespWriter := newCaptureStatusCodeResponseWriter(w)
w = captureStatusRespWriter
timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
httpDuration.WithLabelValues(modelName, strconv.Itoa(captureStatusRespWriter.statusCode)).Observe(v)
httpDuration.WithLabelValues(modelName, strconv.Itoa(captureStatusRespWriter.CapturedStatusCode())).Observe(v)
}))
defer timer.ObserveDuration()

Expand Down
29 changes: 25 additions & 4 deletions pkg/proxy/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,32 @@ func MustRegister(r prometheus.Registerer) {
r.MustRegister(httpDuration, totalRetries)
}

// captureStatusResponseWriter is a custom HTTP response writer that captures the status code.
// statusCodeCapturer is an interface that extends the http.ResponseWriter interface and provides a method for reading the status code of an HTTP response.
type statusCodeCapturer interface {
http.ResponseWriter
CapturedStatusCode() int
}

// captureStatusResponseWriter is a custom HTTP response writer that implements statusCodeCapturer
type captureStatusResponseWriter struct {
http.ResponseWriter
statusCode int
wroteHeader bool
}

func newCaptureStatusCodeResponseWriter(responseWriter http.ResponseWriter) *captureStatusResponseWriter {
return &captureStatusResponseWriter{ResponseWriter: responseWriter}
func newCaptureStatusCodeResponseWriter(responseWriter http.ResponseWriter) statusCodeCapturer {
if o, ok := responseWriter.(statusCodeCapturer); ok { // nothing to do as code is captured already
return o
}
c := &captureStatusResponseWriter{ResponseWriter: responseWriter}
if _, ok := responseWriter.(io.ReaderFrom); ok {
return &captureStatusResponseWriterWithReadFrom{captureStatusResponseWriter: c}
}
return c
}

func (c *captureStatusResponseWriter) CapturedStatusCode() int {
return c.statusCode
}

func (c *captureStatusResponseWriter) WriteHeader(code int) {
Expand All @@ -46,7 +63,11 @@ func (c *captureStatusResponseWriter) Write(b []byte) (int, error) {
return c.ResponseWriter.Write(b)
}

func (c *captureStatusResponseWriter) ReadFrom(re io.Reader) (int64, error) {
type captureStatusResponseWriterWithReadFrom struct {
*captureStatusResponseWriter
}

func (c *captureStatusResponseWriterWithReadFrom) ReadFrom(re io.Reader) (int64, error) {
if !c.wroteHeader {
c.WriteHeader(http.StatusOK)
}
Expand Down
45 changes: 45 additions & 0 deletions pkg/proxy/metrics_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package proxy

import (
"io"
"net/http"
"net/http/httptest"
"strings"
Expand Down Expand Up @@ -83,6 +84,50 @@ func TestMetricsViaLinter(t *testing.T) {
require.Empty(t, problems)
}

func TestCaptureStatusCodeResponseWriters(t *testing.T) {
specs := map[string]struct {
rspWriter http.ResponseWriter
expType any
write func(t *testing.T, r http.ResponseWriter, content string)
}{
"implements statusCodeCapturer": {
rspWriter: &responseWriterDelegator{headerBuf: make(http.Header), ResponseWriter: httptest.NewRecorder()},
expType: &responseWriterDelegator{},
write: func(t *testing.T, r http.ResponseWriter, content string) {
r.WriteHeader(200)
},
},
"implements io.ReaderFrom": {
rspWriter: &testResponseWriter{ResponseRecorder: httptest.NewRecorder()},
expType: &captureStatusResponseWriterWithReadFrom{},
write: func(t *testing.T, r http.ResponseWriter, content string) {
n, err := r.(io.ReaderFrom).ReadFrom(strings.NewReader(content))
require.NoError(t, err)
assert.Equal(t, len(content), int(n))
},
},
"default": {
rspWriter: httptest.NewRecorder(),
expType: &captureStatusResponseWriter{},
write: func(t *testing.T, r http.ResponseWriter, content string) {
n, err := r.Write([]byte(content))
require.NoError(t, err)
assert.Equal(t, len(content), n)
},
},
}

for name, spec := range specs {
t.Run(name, func(t *testing.T) {
instance := newCaptureStatusCodeResponseWriter(spec.rspWriter)
require.IsType(t, spec.expType, instance)
spec.write(t, instance, "foo")
gotCode := instance.CapturedStatusCode()
assert.Equal(t, http.StatusOK, gotCode)
})
}
}

func toMap(s []*io_prometheus_client.LabelPair) map[string]string {
r := make(map[string]string, len(s))
for _, v := range s {
Expand Down
11 changes: 6 additions & 5 deletions pkg/proxy/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func NewRetryMiddleware(maxRetries int, other http.Handler, optRetryStatusCodes
type xResponseWriter interface {
http.ResponseWriter
discardedResponse() bool
capturedStatusCode() int
CapturedStatusCode() int
}
type xBodyCapturer interface {
io.ReadCloser
Expand All @@ -64,7 +64,7 @@ func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Req
r.nextHandler.ServeHTTP(capturedResp, reqClone)

if !capturedResp.discardedResponse() || // max retries reached or context error
!r.isRetryableStatusCode(capturedResp.capturedStatusCode()) {
!r.isRetryableStatusCode(capturedResp.CapturedStatusCode()) {
break
}
// setup for retry
Expand All @@ -83,8 +83,9 @@ func (r RetryMiddleware) isRetryableStatusCode(status int) bool {
}

var (
_ http.Flusher = &responseWriterDelegator{}
_ io.ReaderFrom = &xResponseWriterDelegator{}
_ http.Flusher = &responseWriterDelegator{}
_ io.ReaderFrom = &xResponseWriterDelegator{}
_ statusCodeCapturer = &responseWriterDelegator{}
)

// responseWriterDelegator represents a wrapper around http.ResponseWriter that provides additional
Expand Down Expand Up @@ -118,7 +119,7 @@ func (r *responseWriterDelegator) discardedResponse() bool {
return r.discardErrResp
}

func (r *responseWriterDelegator) capturedStatusCode() int {
func (r *responseWriterDelegator) CapturedStatusCode() int {
return r.statusCode
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/proxy/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func TestWriteDelegatorReadFrom(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, len(myTestContent), int(n))
assert.Equal(t, myTestContent, rec.Body.String())
assert.Equal(t, http.StatusOK, d.capturedStatusCode())
assert.Equal(t, http.StatusOK, d.CapturedStatusCode())

// scenario: discard on error enabled
rec = &testResponseWriter{ResponseRecorder: httptest.NewRecorder()}
Expand All @@ -142,7 +142,7 @@ func TestWriteDelegatorReadFrom(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, len(myTestContent), int(n))
assert.Equal(t, "", rec.Body.String())
assert.Equal(t, http.StatusOK, d.capturedStatusCode())
assert.Equal(t, http.StatusOK, d.CapturedStatusCode())

// scenario: not implementing io.ReaderFrom
d = newResponseWriterDelegator(httptest.NewRecorder(), func(int) bool { return true }, false)
Expand Down

0 comments on commit 5fb2e66

Please sign in to comment.