From b74a3950fce7f729183e20f3453cb815ace3f756 Mon Sep 17 00:00:00 2001 From: Alex Peters Date: Tue, 16 Jan 2024 11:28:58 +0100 Subject: [PATCH] Lazy buffer request body for retry + reuse; refactorings --- pkg/proxy/handler.go | 14 +++- pkg/proxy/metrics.go | 25 ++++++- pkg/proxy/middleware.go | 63 +++++++++++++++- tests/integration/integration_test.go | 102 ++++++++++++++++++++------ tests/integration/main_test.go | 2 +- 5 files changed, 172 insertions(+), 34 deletions(-) diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go index c7b18c6e..c2419e60 100644 --- a/pkg/proxy/handler.go +++ b/pkg/proxy/handler.go @@ -129,10 +129,16 @@ func parseModel(r *http.Request) (string, *http.Request, error) { if model := r.Header.Get("X-Model"); model != "" { return model, r, nil } - // parse request body for model name, ignore errors - body, err := io.ReadAll(r.Body) - if err != nil { - return "", r, nil + var body []byte + if mb, ok := r.Body.(*lazyBodyCapturer); ok && mb.capturedBody != nil { + body = mb.capturedBody + } else { + // parse request body for model name, ignore errors + var err error + body, err = io.ReadAll(r.Body) + if err != nil { + return "", r, nil + } } var payload struct { diff --git a/pkg/proxy/metrics.go b/pkg/proxy/metrics.go index 45d20e8e..48d204c8 100644 --- a/pkg/proxy/metrics.go +++ b/pkg/proxy/metrics.go @@ -1,6 +1,7 @@ package proxy import ( + "io" "net/http" "github.com/prometheus/client_golang/prometheus" @@ -24,14 +25,30 @@ func MustRegister(r prometheus.Registerer) { // captureStatusResponseWriter is a custom HTTP response writer that captures the status code. type captureStatusResponseWriter struct { http.ResponseWriter - statusCode int + statusCode int + wroteHeader bool } func newCaptureStatusCodeResponseWriter(responseWriter http.ResponseWriter) *captureStatusResponseWriter { return &captureStatusResponseWriter{ResponseWriter: responseWriter} } -func (srw *captureStatusResponseWriter) WriteHeader(code int) { - srw.statusCode = code - srw.ResponseWriter.WriteHeader(code) +func (c *captureStatusResponseWriter) WriteHeader(code int) { + c.wroteHeader = true + c.statusCode = code + c.ResponseWriter.WriteHeader(code) +} + +func (c *captureStatusResponseWriter) Write(b []byte) (int, error) { + if !c.wroteHeader { + c.WriteHeader(http.StatusOK) + } + return c.ResponseWriter.Write(b) +} + +func (c *captureStatusResponseWriter) ReadFrom(re io.Reader) (int64, error) { + if !c.wroteHeader { + c.WriteHeader(http.StatusOK) + } + return c.ResponseWriter.(io.ReaderFrom).ReadFrom(re) } diff --git a/pkg/proxy/middleware.go b/pkg/proxy/middleware.go index 2c614fec..e0b420e3 100644 --- a/pkg/proxy/middleware.go +++ b/pkg/proxy/middleware.go @@ -1,6 +1,8 @@ package proxy import ( + "bytes" + "errors" "io" "math/rand" "net/http" @@ -27,6 +29,11 @@ func NewRetryMiddleware(maxRetries int, other http.Handler) *RetryMiddleware { } func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + lazyBody := &lazyBodyCapturer{ + reader: request.Body, + buf: bytes.NewBuffer([]byte{}), + } + request.Body = lazyBody var capturedResp *responseWriterDelegator for i := 0; ; i++ { capturedResp = &responseWriterDelegator{ @@ -36,8 +43,12 @@ func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Req request.Context().Err() == nil, // abort early on timeout, context cancel } // call next handler in chain - r.nextHandler.ServeHTTP(capturedResp, request.Clone(request.Context())) - + req, err := http.NewRequestWithContext(request.Context(), request.Method, request.URL.String(), lazyBody) + if err != nil { + panic(err) + } + r.nextHandler.ServeHTTP(capturedResp, req) + lazyBody.Capture() if !capturedResp.discardErrResp || // max retries reached !isRetryableStatusCode(capturedResp.statusCode) { break @@ -121,3 +132,51 @@ func (r *responseWriterDelegator) Flush() { f.Flush() } } + +var ( + _ io.ReadCloser = &lazyBodyCapturer{} + _ io.WriterTo = &lazyBodyCapturer{} +) + +type lazyBodyCapturer struct { + reader io.ReadCloser + capturedBody []byte + buf *bytes.Buffer + allRead bool +} + +func (m *lazyBodyCapturer) Read(p []byte) (int, error) { + if m.allRead { + return m.reader.Read(p) + } + n, err := io.TeeReader(m.reader, m.buf).Read(p) + if errors.Is(err, io.EOF) { + m.allRead = true + } + return n, err +} + +func (m *lazyBodyCapturer) Close() error { + return m.reader.Close() +} + +func (m *lazyBodyCapturer) WriteTo(w io.Writer) (int64, error) { + if m.allRead { + return m.reader.(io.WriterTo).WriteTo(w) + } + n, err := m.reader.(io.WriterTo).WriteTo(io.MultiWriter(w, m.buf)) + if errors.Is(err, io.EOF) { + m.allRead = true + } + return n, err +} + +func (m *lazyBodyCapturer) Capture() { + m.allRead = true + if m.buf != nil { + m.capturedBody = m.buf.Bytes() + m.buf = nil + } else { + m.reader = io.NopCloser(bytes.NewReader(m.capturedBody)) + } +} diff --git a/tests/integration/integration_test.go b/tests/integration/integration_test.go index 10fb5e45..bd4f25e7 100644 --- a/tests/integration/integration_test.go +++ b/tests/integration/integration_test.go @@ -43,17 +43,7 @@ func TestScaleUpAndDown(t *testing.T) { })) // Mock an EndpointSlice. - testBackendURL, err := url.Parse(testBackend.URL) - require.NoError(t, err) - testBackendPort, err := strconv.Atoi(testBackendURL.Port()) - require.NoError(t, err) - require.NoError(t, testK8sClient.Create(testCtx, - endpointSlice( - modelName, - testBackendURL.Hostname(), - int32(testBackendPort), - ), - )) + withMockEndpointSlice(t, testBackend, modelName) // Wait for deployment mapping to sync. time.Sleep(3 * time.Second) @@ -103,17 +93,7 @@ func TestHandleModelUndeployment(t *testing.T) { })) // Mock an EndpointSlice. - testBackendURL, err := url.Parse(testBackend.URL) - require.NoError(t, err) - testBackendPort, err := strconv.Atoi(testBackendURL.Port()) - require.NoError(t, err) - require.NoError(t, testK8sClient.Create(testCtx, - endpointSlice( - modelName, - testBackendURL.Hostname(), - int32(testBackendPort), - ), - )) + withMockEndpointSlice(t, testBackend, modelName) // Wait for deployment mapping to sync. time.Sleep(3 * time.Second) @@ -132,7 +112,7 @@ func TestHandleModelUndeployment(t *testing.T) { require.NoError(t, testK8sClient.Delete(testCtx, deploy)) // Check that the deployment was deleted - err = testK8sClient.Get(testCtx, client.ObjectKey{ + err := testK8sClient.Get(testCtx, client.ObjectKey{ Namespace: deploy.Namespace, Name: deploy.Name, }, deploy) @@ -151,6 +131,82 @@ func TestHandleModelUndeployment(t *testing.T) { wg.Wait() } +func TestRetryMiddleware(t *testing.T) { + const modelName = "test-model-c" + deploy := testDeployment(modelName) + require.NoError(t, testK8sClient.Create(testCtx, deploy)) + + // Wait for deployment mapping to sync. + time.Sleep(3 * time.Second) + backendRequests := &atomic.Int32{} + var serverCodes []int + testBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + i := backendRequests.Add(1) + code := serverCodes[i-1] + t.Logf("Serving request from testBackend: %d; code: %d\n", i, code) + w.WriteHeader(code) + })) + + // Mock an EndpointSlice. + withMockEndpointSlice(t, testBackend, modelName) + + specs := map[string]struct { + serverCodes []int + expResultCode int + expBackendHits int32 + }{ + "max retries - succeeds": { + serverCodes: []int{http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusOK}, + expResultCode: http.StatusOK, + expBackendHits: 4, + }, + "max retries - fails": { + serverCodes: []int{http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusBadGateway}, + expResultCode: http.StatusBadGateway, + expBackendHits: 4, + }, + "non retryable error code": { + serverCodes: []int{http.StatusNotImplemented}, + expResultCode: http.StatusNotImplemented, + expBackendHits: 1, + }, + "200 status code": { + serverCodes: []int{http.StatusOK}, + expResultCode: http.StatusOK, + expBackendHits: 1, + }, + } + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + // setup + serverCodes = spec.serverCodes + backendRequests.Store(0) + + // when single request sent + var wg sync.WaitGroup + sendRequest(t, &wg, modelName, spec.expResultCode) + wg.Wait() + + // then + require.Equal(t, spec.expBackendHits, backendRequests.Load(), "ensure backend hit with retries") + }) + } +} + +func withMockEndpointSlice(t *testing.T, testBackend *httptest.Server, modelName string) { + testBackendURL, err := url.Parse(testBackend.URL) + require.NoError(t, err) + testBackendPort, err := strconv.Atoi(testBackendURL.Port()) + require.NoError(t, err) + require.NoError(t, testK8sClient.Create(testCtx, + endpointSlice( + modelName, + testBackendURL.Hostname(), + int32(testBackendPort), + ), + )) +} + func requireDeploymentReplicas(t *testing.T, deploy *appsv1.Deployment, n int32) { require.EventuallyWithT(t, func(t *assert.CollectT) { err := testK8sClient.Get(testCtx, types.NamespacedName{Namespace: deploy.Namespace, Name: deploy.Name}, deploy) diff --git a/tests/integration/main_test.go b/tests/integration/main_test.go index 7ec22fb7..36d8987c 100644 --- a/tests/integration/main_test.go +++ b/tests/integration/main_test.go @@ -109,7 +109,7 @@ func TestMain(m *testing.M) { Endpoints: endpointManager, Queues: queueManager, } - testServer = httptest.NewServer(handler) + testServer = httptest.NewServer(proxy.NewRetryMiddleware(3, handler)) defer testServer.Close() go func() {