diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go index c2419e60..e224f4a9 100644 --- a/pkg/proxy/handler.go +++ b/pkg/proxy/handler.go @@ -42,7 +42,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { captureStatusRespWriter := newCaptureStatusCodeResponseWriter(w) w = captureStatusRespWriter timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { - httpDuration.WithLabelValues(modelName, strconv.Itoa(captureStatusRespWriter.statusCode)).Observe(v) + httpDuration.WithLabelValues(modelName, strconv.Itoa(captureStatusRespWriter.CapturedStatusCode())).Observe(v) })) defer timer.ObserveDuration() diff --git a/pkg/proxy/metrics.go b/pkg/proxy/metrics.go index 48d204c8..a9ef3b32 100644 --- a/pkg/proxy/metrics.go +++ b/pkg/proxy/metrics.go @@ -22,15 +22,32 @@ func MustRegister(r prometheus.Registerer) { r.MustRegister(httpDuration, totalRetries) } -// captureStatusResponseWriter is a custom HTTP response writer that captures the status code. +// statusCodeCapturer is an interface that extends the http.ResponseWriter interface and provides a method for reading the status code of an HTTP response. +type statusCodeCapturer interface { + http.ResponseWriter + CapturedStatusCode() int +} + +// captureStatusResponseWriter is a custom HTTP response writer that implements statusCodeCapturer type captureStatusResponseWriter struct { http.ResponseWriter statusCode int wroteHeader bool } -func newCaptureStatusCodeResponseWriter(responseWriter http.ResponseWriter) *captureStatusResponseWriter { - return &captureStatusResponseWriter{ResponseWriter: responseWriter} +func newCaptureStatusCodeResponseWriter(responseWriter http.ResponseWriter) statusCodeCapturer { + if o, ok := responseWriter.(statusCodeCapturer); ok { // nothing to do as code is captured already + return o + } + c := &captureStatusResponseWriter{ResponseWriter: responseWriter} + if _, ok := responseWriter.(io.ReaderFrom); ok { + return &captureStatusResponseWriterWithReadFrom{captureStatusResponseWriter: c} + } + return c +} + +func (c *captureStatusResponseWriter) CapturedStatusCode() int { + return c.statusCode } func (c *captureStatusResponseWriter) WriteHeader(code int) { @@ -46,7 +63,11 @@ func (c *captureStatusResponseWriter) Write(b []byte) (int, error) { return c.ResponseWriter.Write(b) } -func (c *captureStatusResponseWriter) ReadFrom(re io.Reader) (int64, error) { +type captureStatusResponseWriterWithReadFrom struct { + *captureStatusResponseWriter +} + +func (c *captureStatusResponseWriterWithReadFrom) ReadFrom(re io.Reader) (int64, error) { if !c.wroteHeader { c.WriteHeader(http.StatusOK) } diff --git a/pkg/proxy/metrics_test.go b/pkg/proxy/metrics_test.go index cf0d5c3e..b43b1606 100644 --- a/pkg/proxy/metrics_test.go +++ b/pkg/proxy/metrics_test.go @@ -1,6 +1,7 @@ package proxy import ( + "io" "net/http" "net/http/httptest" "strings" @@ -83,6 +84,50 @@ func TestMetricsViaLinter(t *testing.T) { require.Empty(t, problems) } +func TestCaptureStatusCodeResponseWriters(t *testing.T) { + specs := map[string]struct { + rspWriter http.ResponseWriter + expType any + write func(t *testing.T, r http.ResponseWriter, content string) + }{ + "implements statusCodeCapturer": { + rspWriter: &responseWriterDelegator{headerBuf: make(http.Header), ResponseWriter: httptest.NewRecorder()}, + expType: &responseWriterDelegator{}, + write: func(t *testing.T, r http.ResponseWriter, content string) { + r.WriteHeader(200) + }, + }, + "implements io.ReaderFrom": { + rspWriter: &testResponseWriter{ResponseRecorder: httptest.NewRecorder()}, + expType: &captureStatusResponseWriterWithReadFrom{}, + write: func(t *testing.T, r http.ResponseWriter, content string) { + n, err := r.(io.ReaderFrom).ReadFrom(strings.NewReader(content)) + require.NoError(t, err) + assert.Equal(t, len(content), int(n)) + }, + }, + "default": { + rspWriter: httptest.NewRecorder(), + expType: &captureStatusResponseWriter{}, + write: func(t *testing.T, r http.ResponseWriter, content string) { + n, err := r.Write([]byte(content)) + require.NoError(t, err) + assert.Equal(t, len(content), n) + }, + }, + } + + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + instance := newCaptureStatusCodeResponseWriter(spec.rspWriter) + require.IsType(t, spec.expType, instance) + spec.write(t, instance, "foo") + gotCode := instance.CapturedStatusCode() + assert.Equal(t, http.StatusOK, gotCode) + }) + } +} + func toMap(s []*io_prometheus_client.LabelPair) map[string]string { r := make(map[string]string, len(s)) for _, v := range s { diff --git a/pkg/proxy/middleware.go b/pkg/proxy/middleware.go index 6964a84a..2fb1ae68 100644 --- a/pkg/proxy/middleware.go +++ b/pkg/proxy/middleware.go @@ -46,7 +46,7 @@ func NewRetryMiddleware(maxRetries int, other http.Handler, optRetryStatusCodes type xResponseWriter interface { http.ResponseWriter discardedResponse() bool - capturedStatusCode() int + CapturedStatusCode() int } type xBodyCapturer interface { io.ReadCloser @@ -64,7 +64,7 @@ func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Req r.nextHandler.ServeHTTP(capturedResp, reqClone) if !capturedResp.discardedResponse() || // max retries reached or context error - !r.isRetryableStatusCode(capturedResp.capturedStatusCode()) { + !r.isRetryableStatusCode(capturedResp.CapturedStatusCode()) { break } // setup for retry @@ -83,8 +83,9 @@ func (r RetryMiddleware) isRetryableStatusCode(status int) bool { } var ( - _ http.Flusher = &responseWriterDelegator{} - _ io.ReaderFrom = &xResponseWriterDelegator{} + _ http.Flusher = &responseWriterDelegator{} + _ io.ReaderFrom = &xResponseWriterDelegator{} + _ statusCodeCapturer = &responseWriterDelegator{} ) // responseWriterDelegator represents a wrapper around http.ResponseWriter that provides additional @@ -118,7 +119,7 @@ func (r *responseWriterDelegator) discardedResponse() bool { return r.discardErrResp } -func (r *responseWriterDelegator) capturedStatusCode() int { +func (r *responseWriterDelegator) CapturedStatusCode() int { return r.statusCode } diff --git a/pkg/proxy/middleware_test.go b/pkg/proxy/middleware_test.go index 452fee12..18bc88d7 100644 --- a/pkg/proxy/middleware_test.go +++ b/pkg/proxy/middleware_test.go @@ -130,7 +130,7 @@ func TestWriteDelegatorReadFrom(t *testing.T) { require.NoError(t, err) assert.Equal(t, len(myTestContent), int(n)) assert.Equal(t, myTestContent, rec.Body.String()) - assert.Equal(t, http.StatusOK, d.capturedStatusCode()) + assert.Equal(t, http.StatusOK, d.CapturedStatusCode()) // scenario: discard on error enabled rec = &testResponseWriter{ResponseRecorder: httptest.NewRecorder()} @@ -142,7 +142,7 @@ func TestWriteDelegatorReadFrom(t *testing.T) { require.NoError(t, err) assert.Equal(t, len(myTestContent), int(n)) assert.Equal(t, "", rec.Body.String()) - assert.Equal(t, http.StatusOK, d.capturedStatusCode()) + assert.Equal(t, http.StatusOK, d.CapturedStatusCode()) // scenario: not implementing io.ReaderFrom d = newResponseWriterDelegator(httptest.NewRecorder(), func(int) bool { return true }, false)