Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spike: Integration test fix #69

Closed
wants to merge 13 commits into from
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
ENVTEST_K8S_VERSION = 1.27.1

.PHONY: test
test: test-unit test-race test-integration test-e2e
test: test-unit test-integration test-e2e

.PHONY: test-unit
test-unit:
Expand Down
2 changes: 2 additions & 0 deletions cmd/lingo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ func run() error {

concurrency := getEnvInt("CONCURRENCY", 100)
scaleDownDelay := getEnvInt("SCALE_DOWN_DELAY", 30)
backendRetries := getEnvInt("BACKEND_RETRIES", 1)

var metricsAddr string
var probeAddr string
Expand Down Expand Up @@ -154,6 +155,7 @@ func run() error {

proxy.MustRegister(metricsRegistry)
proxyHandler := proxy.NewHandler(deploymentManager, endpointManager, queueManager)
proxyHandler.MaxRetries = backendRetries
proxyServer := &http.Server{Addr: ":8080", Handler: proxyHandler}

statsHandler := &stats.Handler{
Expand Down
240 changes: 137 additions & 103 deletions pkg/proxy/handler.go
Original file line number Diff line number Diff line change
@@ -1,153 +1,140 @@
package proxy

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/http/httputil"
"net/url"
"strconv"

"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"

"github.com/substratusai/lingo/pkg/deployments"
"github.com/substratusai/lingo/pkg/endpoints"
"github.com/substratusai/lingo/pkg/queue"
)

var httpDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "http_response_time_seconds",
Help: "Duration of HTTP requests.",
Buckets: prometheus.DefBuckets,
}, []string{"model", "status_code"})

func MustRegister(r prometheus.Registerer) {
r.MustRegister(httpDuration)
}

type DeploymentManager interface {
ResolveDeployment(model string) (string, bool)
AtLeastOne(model string)
}

type EndpointManager interface {
AwaitHostAddress(ctx context.Context, service, portName string) (string, error)
}

type QueueManager interface {
EnqueueAndWait(ctx context.Context, deploymentName, id string) func()
}

// Handler serves http requests for end-clients.
// It is also responsible for triggering scale-from-zero.
type Handler struct {
Deployments *deployments.Manager
Endpoints *endpoints.Manager
Queues *queue.Manager
Deployments DeploymentManager
Endpoints EndpointManager
Queues QueueManager

MaxRetries int
RetryCodes map[int]struct{}
}

func NewHandler(deployments *deployments.Manager, endpoints *endpoints.Manager, queues *queue.Manager) *Handler {
return &Handler{Deployments: deployments, Endpoints: endpoints, Queues: queues}
func NewHandler(
deployments DeploymentManager,
endpoints EndpointManager,
queues QueueManager,
) *Handler {
return &Handler{
Deployments: deployments,
Endpoints: endpoints,
Queues: queues,
}
}

var defaultRetryCodes = map[int]struct{}{
http.StatusInternalServerError: {},
http.StatusBadGateway: {},
http.StatusServiceUnavailable: {},
http.StatusGatewayTimeout: {},
}

func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var modelName string
captureStatusRespWriter := newCaptureStatusCodeResponseWriter(w)
w = captureStatusRespWriter
timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
httpDuration.WithLabelValues(modelName, strconv.Itoa(captureStatusRespWriter.statusCode)).Observe(v)
}))
defer timer.ObserveDuration()

id := uuid.New().String()
log.Printf("request: %v", r.URL)
log.Printf("url: %v", r.URL)

w.Header().Set("X-Proxy", "lingo")

var (
proxyRequest *http.Request
err error
)
pr := newProxyRequest(r)
defer pr.done()

// TODO: Only parse model for paths that would have a model.
modelName, proxyRequest, err = parseModel(r)
if err != nil || modelName == "" {
modelName = "unknown"
log.Printf("error reading model from request body: %v", err)
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("Bad request: unable to parse .model from JSON payload"))
if err := pr.parseModel(); err != nil {
pr.sendErrorResponse(w, http.StatusBadRequest, "unable to parse model: %v", err)
return
}
log.Println("model:", modelName)

deploy, found := h.Deployments.ResolveDeployment(modelName)
if !found {
log.Printf("deployment not found for model: %v", err)
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(fmt.Sprintf("Deployment for model not found: %v", modelName)))
log.Println("model:", pr.model)

var backendExists bool
pr.backendDeployment, backendExists = h.Deployments.ResolveDeployment(pr.model)
if !backendExists {
pr.sendErrorResponse(w, http.StatusNotFound, "model not found: %v", pr.model)
return
}

h.Deployments.AtLeastOne(deploy)
// Ensure the backend is scaled to at least one Pod.
h.Deployments.AtLeastOne(pr.backendDeployment)

log.Println("Entering queue", id)
complete := h.Queues.EnqueueAndWait(r.Context(), deploy, id)
log.Println("Admitted into queue", id)
log.Printf("Entering queue: %v", pr.id)

// Wait to until the request is admitted into the queue before proceeding with
// serving the request.
complete := h.Queues.EnqueueAndWait(r.Context(), pr.backendDeployment, pr.id)
defer complete()

// abort when deployment was removed meanwhile
if _, exists := h.Deployments.ResolveDeployment(modelName); !exists {
log.Printf("deployment not active for model removed: %v", err)
w.WriteHeader(http.StatusNotFound)
_, _ = w.Write([]byte(fmt.Sprintf("Deployment for model not found: %v", modelName)))
log.Printf("Admitted into queue: %v", pr.id)

// After waiting for the request to be admitted, double check that the model
// still exists. It's possible that the model was deleted while waiting.
// This would lead to a long subequent wait with the host lookup.
pr.backendDeployment, backendExists = h.Deployments.ResolveDeployment(pr.model)
if !backendExists {
pr.sendErrorResponse(w, http.StatusNotFound, "model not found after being dequeued: %v", pr.model)
return
}

log.Println("Waiting for IPs", id)
host, err := h.Endpoints.AwaitHostAddress(r.Context(), deploy, "http")
h.proxyHTTP(w, pr)
}

// AdditionalProxyRewrite is an injection point for modifying proxy requests.
// Used in tests.
var AdditionalProxyRewrite = func(*httputil.ProxyRequest) {}

func (h *Handler) proxyHTTP(w http.ResponseWriter, pr *proxyRequest) {
log.Printf("Waiting for host: %v", pr.id)

host, err := h.Endpoints.AwaitHostAddress(pr.r.Context(), pr.backendDeployment, "http")
if err != nil {
log.Printf("error while finding the host address %v", err)
switch {
case errors.Is(err, context.Canceled):
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("Request cancelled"))
pr.sendErrorResponse(w, http.StatusInternalServerError, "request cancelled while finding host: %v", err)
return
case errors.Is(err, context.DeadlineExceeded):
w.WriteHeader(http.StatusGatewayTimeout)
_, _ = w.Write([]byte(fmt.Sprintf("Request timed out for model: %v", modelName)))
pr.sendErrorResponse(w, http.StatusGatewayTimeout, "request timeout while finding host: %v", err)
return
default:
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("Internal server error"))
pr.sendErrorResponse(w, http.StatusGatewayTimeout, "unable to find host: %v", err)
return
}
}
log.Printf("Got host: %v, id: %v\n", host, id)

// TODO: Avoid creating new reverse proxies for each request.
// TODO: Consider implementing a round robin scheme.
log.Printf("Proxying request to host %v: %v\n", host, id)
newReverseProxy(host).ServeHTTP(w, proxyRequest)
}
log.Printf("Got host: %v, id: %v\n", host, pr.id)

// parseModel parses the model name from the request
// returns empty string when none found or an error for failures on the proxy request object
func parseModel(r *http.Request) (string, *http.Request, error) {
if model := r.Header.Get("X-Model"); model != "" {
return model, r, nil
}
// parse request body for model name, ignore errors
body, err := io.ReadAll(r.Body)
if err != nil {
return "", r, nil
}

var payload struct {
Model string `json:"model"`
}
var model string
if err := json.Unmarshal(body, &payload); err == nil {
model = payload.Model
}

// create new request object
proxyReq, err := http.NewRequestWithContext(r.Context(), r.Method, r.URL.String(), bytes.NewReader(body))
if err != nil {
return "", nil, fmt.Errorf("create proxy request: %w", err)
}
proxyReq.Header = r.Header
if err := proxyReq.ParseForm(); err != nil {
return "", nil, fmt.Errorf("parse proxy form: %w", err)
}
return model, proxyReq, nil
}

// AdditionalProxyRewrite is an injection point for modifying proxy requests.
// Used in tests.
var AdditionalProxyRewrite = func(*httputil.ProxyRequest) {}

func newReverseProxy(host string) *httputil.ReverseProxy {
proxy := &httputil.ReverseProxy{
Rewrite: func(r *httputil.ProxyRequest) {
r.SetURL(&url.URL{
Expand All @@ -158,5 +145,52 @@ func newReverseProxy(host string) *httputil.ReverseProxy {
AdditionalProxyRewrite(r)
},
}
return proxy

proxy.ModifyResponse = func(r *http.Response) error {
// Record the response for metrics.
pr.status = r.StatusCode

// This point is reached if a response code is received.
if h.isRetryCode(r.StatusCode) && pr.attempt < h.MaxRetries {
// Returning an error will trigger the ErrorHandler.
return ErrRetry
}

return nil
}

proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
// This point could be reached if a bad response code was sent by the backend
// or
// if there was an issue with the connection and no response was ever received.
if err != nil &&
r.Context().Err() == nil &&
pr.attempt < h.MaxRetries {
pr.attempt++

log.Printf("Retrying request (%v/%v): %v", pr.attempt, h.MaxRetries, pr.id)
h.proxyHTTP(w, pr)
return
}

if !errors.Is(err, ErrRetry) {
pr.sendErrorResponse(w, http.StatusBadGateway, "proxy: exceeded retries: %v/%v", pr.attempt, h.MaxRetries)
}
}

log.Printf("Proxying request to host %v: %v\n", host, pr.id)
proxy.ServeHTTP(w, pr.httpRequest())
}

var ErrRetry = errors.New("retry")

func (h *Handler) isRetryCode(status int) bool {
var retry bool
// TODO: avoid the nil check here and set a default map in the constructor.
if h.RetryCodes != nil {
_, retry = h.RetryCodes[status]
} else {
_, retry = defaultRetryCodes[status]
}
return retry
}
Loading
Loading