Skip to content

Commit

Permalink
Refactor request parsing into apiutils
Browse files Browse the repository at this point in the history
  • Loading branch information
nstogner committed Dec 7, 2024
1 parent 219bd17 commit 7a8e3eb
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 204 deletions.
153 changes: 152 additions & 1 deletion internal/apiutils/requests.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
package apiutils

import "strings"
import (
"bytes"
"encoding/json"
"fmt"
"io"
"mime"
"mime/multipart"
"net/http"
"strings"

"github.com/google/uuid"
)

const (
// adapterSeparator is the separator used to split model and adapter names
Expand Down Expand Up @@ -35,3 +46,143 @@ func MergeModelAdapter(model, adapter string) string {
}
return model + adapterSeparator + adapter
}

// Request is a parsed inference API request: the (possibly rewritten)
// body to forward to a backend, plus routing metadata extracted from the
// HTTP headers and payload.
type Request struct {
	// Body is the request body to forward to the backend. It may differ
	// from the client's original body (the "model" field is removed for
	// multipart uploads, or rewritten to the adapter name for JSON).
	Body []byte

	// Selectors holds the values of the "X-Label-Selector" headers.
	Selectors []string

	// ID is a generated unique identifier for this request.
	ID string

	// RequestedModel is the model name requested by the client.
	// This might contain the adapter name as well.
	RequestedModel string

	// Model and Adapter are the result of splitting RequestedModel
	// on the adapter separator.
	Model   string
	Adapter string

	// ContentLength is the length of Body in bytes.
	ContentLength int64
}

// ParseRequest parses an inference request body and headers into a Request,
// extracting the requested model (and optional adapter suffix) and rewriting
// the body as needed for the backend.
//
// Content-type handling:
//   - "multipart/form-data": the "model" form field is read and then stripped
//     from the forwarded body (see the workaround note below).
//   - anything else (assumed "application/json"): the "model" field is read
//     from the JSON object and, when an adapter is present, rewritten to the
//     adapter name.
func ParseRequest(body io.Reader, headers http.Header) (*Request, error) {
	pr := &Request{
		// Assign a unique ID for request tracing/correlation.
		ID: uuid.New().String(),
	}

	pr.Selectors = headers.Values("X-Label-Selector")

	// Parse media type (with params - which are used for multipart form data)
	var (
		contentType = headers.Get("Content-Type")
		mediaType   string
		mediaParams map[string]string
	)
	if contentType == "" {
		// Default to JSON when the client omits the Content-Type header.
		mediaType = "application/json"
		mediaParams = map[string]string{}
	} else {
		var err error
		mediaType, mediaParams, err = mime.ParseMediaType(contentType)
		if err != nil {
			return nil, fmt.Errorf("parse media type: %w", err)
		}
	}

	switch mediaType {
	// Multipart form data is used for endpoints that accept file uploads:
	case "multipart/form-data":
		boundary := mediaParams["boundary"]
		if boundary == "" {
			return nil, fmt.Errorf("no boundary specified in multipart form data")
		}

		var buf bytes.Buffer
		mw := multipart.NewWriter(&buf)
		// Keep the same boundary as the initial request (probably not necessary).
		// SetBoundary validates the boundary and can fail, so check the error.
		if err := mw.SetBoundary(boundary); err != nil {
			return nil, fmt.Errorf("setting boundary: %w", err)
		}

		// Iterate over the parts of the multipart form data:
		// - If the part is named "model", save the value to the proxy request.
		// - Otherwise, just copy the part to the new multipart writer.
		mr := multipart.NewReader(body, boundary)
		for {
			p, err := mr.NextPart()
			if err == io.EOF {
				break
			}
			if err != nil {
				return nil, fmt.Errorf("iterating over multipart form: %w", err)
			}

			if p.FormName() == "model" {
				value, err := io.ReadAll(p)
				if err != nil {
					return nil, fmt.Errorf("reading multipart form value: %w", err)
				}
				pr.Model, pr.Adapter = SplitModelAdapter(string(value))
				pr.RequestedModel = string(value)
				// WORKAROUND ALERT:
				// Omit the "model" field from the proxy request to avoid FasterWhisper validation issues:
				// See https://github.com/fedirz/faster-whisper-server/issues/71
				continue
			}

			// Copy the part to the new multipart writer.
			pp, err := mw.CreatePart(p.Header)
			if err != nil {
				return nil, fmt.Errorf("creating part: %w", err)
			}
			if _, err := io.Copy(pp, p); err != nil {
				return nil, fmt.Errorf("copying part: %w", err)
			}
		}

		// Fully write to buffer.
		if err := mw.Close(); err != nil {
			return nil, fmt.Errorf("closing multipart writer: %w", err)
		}
		pr.Body = buf.Bytes()
		// Set a new content length based on the new body - which had the "model" field removed.
		pr.ContentLength = int64(len(pr.Body))

	// Assume "application/json":
	default:
		if err := pr.readModelFromBody(body); err != nil {
			return nil, fmt.Errorf("reading model from body: %w", err)
		}
	}

	return pr, nil
}

// readModelFromBody extracts the "model" field from a JSON request body,
// records the requested model/adapter on the Request, and re-serializes the
// (possibly rewritten) payload as the body to forward.
func (pr *Request) readModelFromBody(r io.Reader) error {
	payload := map[string]interface{}{}
	if err := json.NewDecoder(r).Decode(&payload); err != nil {
		return fmt.Errorf("decoding: %w", err)
	}

	raw, ok := payload["model"]
	if !ok {
		return fmt.Errorf("missing 'model' field")
	}
	model, ok := raw.(string)
	if !ok {
		return fmt.Errorf("field 'model' should be a string")
	}

	pr.RequestedModel = model
	pr.Model, pr.Adapter = SplitModelAdapter(model)

	// vLLM expects the adapter to be in the model field.
	if pr.Adapter != "" {
		payload["model"] = pr.Adapter
	}

	rewritten, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("remarshalling: %w", err)
	}
	pr.Body = rewritten
	pr.ContentLength = int64(len(rewritten))

	return nil
}
57 changes: 16 additions & 41 deletions internal/messenger/messenger.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,32 +200,32 @@ func (m *Messenger) handleRequest(ctx context.Context, msg *pubsub.Message) {
}

metricAttrs := metric.WithAttributeSet(attribute.NewSet(
metrics.AttrRequestModel.String(req.model),
metrics.AttrRequestModel.String(req.Model),
metrics.AttrRequestType.String(metrics.AttrRequestTypeMessage),
))
metrics.InferenceRequestsActive.Add(ctx, 1, metricAttrs)
defer metrics.InferenceRequestsActive.Add(ctx, -1, metricAttrs)

modelExists, err := m.modelScaler.LookupModel(ctx, req.model, req.adapter, nil)
modelExists, err := m.modelScaler.LookupModel(ctx, req.Model, req.Adapter, nil)
if err != nil {
m.sendResponse(req, m.jsonError("error checking if model exists: %v", err), http.StatusInternalServerError)
return
}
if !modelExists {
// Send a 400 response to the client, however it is possible the backend
// will be deployed soon or another subscriber will handle it.
m.sendResponse(req, m.jsonError("model not found: %s", req.model), http.StatusNotFound)
m.sendResponse(req, m.jsonError("model not found: %s", req.RequestedModel), http.StatusNotFound)
return
}

// Ensure the backend is scaled to at least one Pod.
m.modelScaler.ScaleAtLeastOneReplica(ctx, req.model)
m.modelScaler.ScaleAtLeastOneReplica(ctx, req.Model)

log.Printf("Awaiting host for message %s", msg.LoggableID)

host, completeFunc, err := m.loadBalancer.AwaitBestAddress(ctx, loadbalancer.AddressRequest{
Model: req.model,
Adapter: req.adapter,
Model: req.Model,
Adapter: req.Adapter,
// TODO: Prefix
})
if err != nil {
Expand All @@ -236,7 +236,7 @@ func (m *Messenger) handleRequest(ctx context.Context, msg *pubsub.Message) {

url := fmt.Sprintf("http://%s%s", host, req.path)
log.Printf("Sending request to backend for message %s: %s", msg.LoggableID, url)
respPayload, respCode, err := m.sendBackendRequest(ctx, url, req.body)
respPayload, respCode, err := m.sendBackendRequest(ctx, url, req.Body)
if err != nil {
m.sendResponse(req, m.jsonError("error sending request to backend: %v", err), http.StatusBadGateway)
return
Expand All @@ -250,14 +250,11 @@ func (m *Messenger) Stop(ctx context.Context) error {
}

type request struct {
ctx context.Context
msg *pubsub.Message
metadata map[string]interface{}
path string
body json.RawMessage
requestedModel string
model string
adapter string
ctx context.Context
*apiutils.Request
msg *pubsub.Message
metadata map[string]interface{}
path string
}

func parseRequest(ctx context.Context, msg *pubsub.Message) (*request, error) {
Expand Down Expand Up @@ -285,34 +282,12 @@ func parseRequest(ctx context.Context, msg *pubsub.Message) (*request, error) {

req.metadata = payload.Metadata
req.path = path
req.body = payload.Body

var payloadBody map[string]interface{}
if err := json.Unmarshal(payload.Body, &payloadBody); err != nil {
return nil, fmt.Errorf("decoding: %w", err)
}
modelInf, ok := payloadBody["model"]
if !ok {
return nil, fmt.Errorf("missing '.body.model' field")
}
modelStr, ok := modelInf.(string)
if !ok {
return nil, fmt.Errorf("field '.body.model' should be a string")
}

req.requestedModel = modelStr
req.model, req.adapter = apiutils.SplitModelAdapter(modelStr)

// Assuming this is a vLLM request.
// vLLM expects the adapter to be in the model field.
if req.adapter != "" {
payloadBody["model"] = req.adapter
rewrittenBody, err := json.Marshal(payloadBody)
if err != nil {
return nil, fmt.Errorf("remarshalling: %w", err)
}
req.body = rewrittenBody
apiR, err := apiutils.ParseRequest(bytes.NewReader(payload.Body), http.Header{})
if err != nil {
return nil, fmt.Errorf("parsing request: %w", err)
}
req.Request = apiR

return req, nil
}
Expand Down
31 changes: 15 additions & 16 deletions internal/modelproxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,35 +58,34 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

w.Header().Set("X-Proxy", "lingo")

pr := newProxyRequest(r)

// TODO: Only parse model for paths that would have a model.
if err := pr.parse(); err != nil {
pr, err := newProxyRequest(r)
if err != nil {
pr.sendErrorResponse(w, http.StatusBadRequest, "unable to parse model: %v", err)
return
}

log.Println("model:", pr.model, "adapter:", pr.adapter)
log.Println("model:", pr.Model, "adapter:", pr.Adapter)

metricAttrs := metric.WithAttributeSet(attribute.NewSet(
metrics.AttrRequestModel.String(pr.requestedModel),
metrics.AttrRequestModel.String(pr.RequestedModel),
metrics.AttrRequestType.String(metrics.AttrRequestTypeHTTP),
))
metrics.InferenceRequestsActive.Add(pr.r.Context(), 1, metricAttrs)
defer metrics.InferenceRequestsActive.Add(pr.r.Context(), -1, metricAttrs)
metrics.InferenceRequestsActive.Add(pr.http.Context(), 1, metricAttrs)
defer metrics.InferenceRequestsActive.Add(pr.http.Context(), -1, metricAttrs)

modelExists, err := h.modelScaler.LookupModel(r.Context(), pr.model, pr.adapter, pr.selectors)
modelExists, err := h.modelScaler.LookupModel(r.Context(), pr.Model, pr.Adapter, pr.Selectors)
if err != nil {
pr.sendErrorResponse(w, http.StatusInternalServerError, "unable to resolve model: %v", err)
return
}
if !modelExists {
pr.sendErrorResponse(w, http.StatusNotFound, "model not found: %v", pr.requestedModel)
pr.sendErrorResponse(w, http.StatusNotFound, "model not found: %v", pr.RequestedModel)
return
}

// Ensure the backend is scaled to at least one Pod.
if err := h.modelScaler.ScaleAtLeastOneReplica(r.Context(), pr.model); err != nil {
if err := h.modelScaler.ScaleAtLeastOneReplica(r.Context(), pr.Model); err != nil {
pr.sendErrorResponse(w, http.StatusInternalServerError, "unable to scale model: %v", err)
return
}
Expand All @@ -99,11 +98,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var AdditionalProxyRewrite = func(*httputil.ProxyRequest) {}

func (h *Handler) proxyHTTP(w http.ResponseWriter, pr *proxyRequest) {
log.Printf("Waiting for host: %v", pr.id)
log.Printf("Waiting for host: %v", pr.ID)

addr, decrementInflight, err := h.loadBalancer.AwaitBestAddress(pr.r.Context(), loadbalancer.AddressRequest{
Model: pr.model,
Adapter: pr.adapter,
addr, decrementInflight, err := h.loadBalancer.AwaitBestAddress(pr.http.Context(), loadbalancer.AddressRequest{
Model: pr.Model,
Adapter: pr.Adapter,
// TODO: Prefix
})
if err != nil {
Expand Down Expand Up @@ -153,7 +152,7 @@ func (h *Handler) proxyHTTP(w http.ResponseWriter, pr *proxyRequest) {
if err != nil && r.Context().Err() == nil && pr.attempt < h.maxRetries {
pr.attempt++

log.Printf("Retrying request (%v/%v): %v: %v", pr.attempt, h.maxRetries, pr.id, err)
log.Printf("Retrying request (%v/%v): %v: %v", pr.attempt, h.maxRetries, pr.ID, err)
h.proxyHTTP(w, pr)
return
}
Expand All @@ -163,7 +162,7 @@ func (h *Handler) proxyHTTP(w http.ResponseWriter, pr *proxyRequest) {
}
}

log.Printf("Proxying request to ip %v: %v\n", addr, pr.id)
log.Printf("Proxying request to ip %v: %v\n", addr, pr.ID)
proxy.ServeHTTP(w, pr.httpRequest())
}

Expand Down
Loading

0 comments on commit 7a8e3eb

Please sign in to comment.