Skip to content

Commit

Permalink
Lazy buffer request body for retry + reuse; refactorings
Browse files Browse the repository at this point in the history
  • Loading branch information
alpe committed Jan 16, 2024
1 parent e71df9f commit b74a395
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 34 deletions.
14 changes: 10 additions & 4 deletions pkg/proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,16 @@ func parseModel(r *http.Request) (string, *http.Request, error) {
if model := r.Header.Get("X-Model"); model != "" {
return model, r, nil
}
// parse request body for model name, ignore errors
body, err := io.ReadAll(r.Body)
if err != nil {
return "", r, nil
var body []byte
if mb, ok := r.Body.(*lazyBodyCapturer); ok && mb.capturedBody != nil {
body = mb.capturedBody
} else {
// parse request body for model name, ignore errors
var err error
body, err = io.ReadAll(r.Body)
if err != nil {
return "", r, nil
}
}

var payload struct {
Expand Down
25 changes: 21 additions & 4 deletions pkg/proxy/metrics.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package proxy

import (
"io"
"net/http"

"github.com/prometheus/client_golang/prometheus"
Expand All @@ -24,14 +25,30 @@ func MustRegister(r prometheus.Registerer) {
// captureStatusResponseWriter is a custom HTTP response writer that captures the status code.
type captureStatusResponseWriter struct {
http.ResponseWriter
statusCode int
statusCode int
wroteHeader bool
}

func newCaptureStatusCodeResponseWriter(responseWriter http.ResponseWriter) *captureStatusResponseWriter {
return &captureStatusResponseWriter{ResponseWriter: responseWriter}
}

func (srw *captureStatusResponseWriter) WriteHeader(code int) {
srw.statusCode = code
srw.ResponseWriter.WriteHeader(code)
func (c *captureStatusResponseWriter) WriteHeader(code int) {
c.wroteHeader = true
c.statusCode = code
c.ResponseWriter.WriteHeader(code)
}

func (c *captureStatusResponseWriter) Write(b []byte) (int, error) {
if !c.wroteHeader {
c.WriteHeader(http.StatusOK)
}
return c.ResponseWriter.Write(b)
}

func (c *captureStatusResponseWriter) ReadFrom(re io.Reader) (int64, error) {
if !c.wroteHeader {
c.WriteHeader(http.StatusOK)
}
return c.ResponseWriter.(io.ReaderFrom).ReadFrom(re)
}
63 changes: 61 additions & 2 deletions pkg/proxy/middleware.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package proxy

import (
"bytes"
"errors"
"io"
"math/rand"
"net/http"
Expand All @@ -27,6 +29,11 @@ func NewRetryMiddleware(maxRetries int, other http.Handler) *RetryMiddleware {
}

func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
lazyBody := &lazyBodyCapturer{
reader: request.Body,
buf: bytes.NewBuffer([]byte{}),
}
request.Body = lazyBody
var capturedResp *responseWriterDelegator
for i := 0; ; i++ {
capturedResp = &responseWriterDelegator{
Expand All @@ -36,8 +43,12 @@ func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Req
request.Context().Err() == nil, // abort early on timeout, context cancel
}
// call next handler in chain
r.nextHandler.ServeHTTP(capturedResp, request.Clone(request.Context()))

req, err := http.NewRequestWithContext(request.Context(), request.Method, request.URL.String(), lazyBody)
if err != nil {
panic(err)
}
r.nextHandler.ServeHTTP(capturedResp, req)
lazyBody.Capture()
if !capturedResp.discardErrResp || // max retries reached
!isRetryableStatusCode(capturedResp.statusCode) {
break
Expand Down Expand Up @@ -121,3 +132,51 @@ func (r *responseWriterDelegator) Flush() {
f.Flush()
}
}

var (
_ io.ReadCloser = &lazyBodyCapturer{}
_ io.WriterTo = &lazyBodyCapturer{}
)

type lazyBodyCapturer struct {
reader io.ReadCloser
capturedBody []byte
buf *bytes.Buffer
allRead bool
}

func (m *lazyBodyCapturer) Read(p []byte) (int, error) {
if m.allRead {
return m.reader.Read(p)
}
n, err := io.TeeReader(m.reader, m.buf).Read(p)
if errors.Is(err, io.EOF) {
m.allRead = true
}
return n, err
}

func (m *lazyBodyCapturer) Close() error {
return m.reader.Close()
}

func (m *lazyBodyCapturer) WriteTo(w io.Writer) (int64, error) {
if m.allRead {
return m.reader.(io.WriterTo).WriteTo(w)
}
n, err := m.reader.(io.WriterTo).WriteTo(io.MultiWriter(w, m.buf))
if errors.Is(err, io.EOF) {
m.allRead = true
}
return n, err
}

func (m *lazyBodyCapturer) Capture() {
m.allRead = true
if m.buf != nil {
m.capturedBody = m.buf.Bytes()
m.buf = nil
} else {
m.reader = io.NopCloser(bytes.NewReader(m.capturedBody))
}
}
102 changes: 79 additions & 23 deletions tests/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,7 @@ func TestScaleUpAndDown(t *testing.T) {
}))

// Mock an EndpointSlice.
testBackendURL, err := url.Parse(testBackend.URL)
require.NoError(t, err)
testBackendPort, err := strconv.Atoi(testBackendURL.Port())
require.NoError(t, err)
require.NoError(t, testK8sClient.Create(testCtx,
endpointSlice(
modelName,
testBackendURL.Hostname(),
int32(testBackendPort),
),
))
withMockEndpointSlice(t, testBackend, modelName)

// Wait for deployment mapping to sync.
time.Sleep(3 * time.Second)
Expand Down Expand Up @@ -103,17 +93,7 @@ func TestHandleModelUndeployment(t *testing.T) {
}))

// Mock an EndpointSlice.
testBackendURL, err := url.Parse(testBackend.URL)
require.NoError(t, err)
testBackendPort, err := strconv.Atoi(testBackendURL.Port())
require.NoError(t, err)
require.NoError(t, testK8sClient.Create(testCtx,
endpointSlice(
modelName,
testBackendURL.Hostname(),
int32(testBackendPort),
),
))
withMockEndpointSlice(t, testBackend, modelName)

// Wait for deployment mapping to sync.
time.Sleep(3 * time.Second)
Expand All @@ -132,7 +112,7 @@ func TestHandleModelUndeployment(t *testing.T) {
require.NoError(t, testK8sClient.Delete(testCtx, deploy))

// Check that the deployment was deleted
err = testK8sClient.Get(testCtx, client.ObjectKey{
err := testK8sClient.Get(testCtx, client.ObjectKey{
Namespace: deploy.Namespace,
Name: deploy.Name,
}, deploy)
Expand All @@ -151,6 +131,82 @@ func TestHandleModelUndeployment(t *testing.T) {
wg.Wait()
}

func TestRetryMiddleware(t *testing.T) {
const modelName = "test-model-c"
deploy := testDeployment(modelName)
require.NoError(t, testK8sClient.Create(testCtx, deploy))

// Wait for deployment mapping to sync.
time.Sleep(3 * time.Second)
backendRequests := &atomic.Int32{}
var serverCodes []int
testBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
i := backendRequests.Add(1)
code := serverCodes[i-1]
t.Logf("Serving request from testBackend: %d; code: %d\n", i, code)
w.WriteHeader(code)
}))

// Mock an EndpointSlice.
withMockEndpointSlice(t, testBackend, modelName)

specs := map[string]struct {
serverCodes []int
expResultCode int
expBackendHits int32
}{
"max retries - succeeds": {
serverCodes: []int{http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusOK},
expResultCode: http.StatusOK,
expBackendHits: 4,
},
"max retries - fails": {
serverCodes: []int{http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusBadGateway},
expResultCode: http.StatusBadGateway,
expBackendHits: 4,
},
"non retryable error code": {
serverCodes: []int{http.StatusNotImplemented},
expResultCode: http.StatusNotImplemented,
expBackendHits: 1,
},
"200 status code": {
serverCodes: []int{http.StatusOK},
expResultCode: http.StatusOK,
expBackendHits: 1,
},
}
for name, spec := range specs {
t.Run(name, func(t *testing.T) {
// setup
serverCodes = spec.serverCodes
backendRequests.Store(0)

// when single request sent
var wg sync.WaitGroup
sendRequest(t, &wg, modelName, spec.expResultCode)
wg.Wait()

// then
require.Equal(t, spec.expBackendHits, backendRequests.Load(), "ensure backend hit with retries")
})
}
}

func withMockEndpointSlice(t *testing.T, testBackend *httptest.Server, modelName string) {
testBackendURL, err := url.Parse(testBackend.URL)
require.NoError(t, err)
testBackendPort, err := strconv.Atoi(testBackendURL.Port())
require.NoError(t, err)
require.NoError(t, testK8sClient.Create(testCtx,
endpointSlice(
modelName,
testBackendURL.Hostname(),
int32(testBackendPort),
),
))
}

func requireDeploymentReplicas(t *testing.T, deploy *appsv1.Deployment, n int32) {
require.EventuallyWithT(t, func(t *assert.CollectT) {
err := testK8sClient.Get(testCtx, types.NamespacedName{Namespace: deploy.Namespace, Name: deploy.Name}, deploy)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func TestMain(m *testing.M) {
Endpoints: endpointManager,
Queues: queueManager,
}
testServer = httptest.NewServer(handler)
testServer = httptest.NewServer(proxy.NewRetryMiddleware(3, handler))
defer testServer.Close()

go func() {
Expand Down

0 comments on commit b74a395

Please sign in to comment.