Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for HTTP X-Label-Selector headers to support Multitenancy #282

Merged
merged 13 commits into from
Oct 24, 2024
4 changes: 2 additions & 2 deletions internal/messenger/messenger.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func NewMessenger(
}

type ModelScaler interface {
ModelExists(ctx context.Context, model string) (bool, error)
ModelFound(ctx context.Context, model string, selectors []string) (bool, error)
ScaleAtLeastOneReplica(ctx context.Context, model string) error
}

Expand Down Expand Up @@ -204,7 +204,7 @@ func (m *Messenger) handleRequest(ctx context.Context, msg *pubsub.Message) {
metrics.InferenceRequestsActive.Add(ctx, 1, metricAttrs)
defer metrics.InferenceRequestsActive.Add(ctx, -1, metricAttrs)

modelExists, err := m.modelScaler.ModelExists(ctx, req.model)
modelExists, err := m.modelScaler.ModelFound(ctx, req.model, nil)
if err != nil {
m.sendResponse(req, m.jsonError("error checking if model exists: %v", err), http.StatusInternalServerError)
return
Expand Down
6 changes: 3 additions & 3 deletions internal/modelproxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
)

type ModelScaler interface {
ModelExists(ctx context.Context, model string) (bool, error)
ModelFound(ctx context.Context, model string, selectors []string) (bool, error)
nstogner marked this conversation as resolved.
Show resolved Hide resolved
ScaleAtLeastOneReplica(ctx context.Context, model string) error
}

Expand Down Expand Up @@ -60,7 +60,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
pr := newProxyRequest(r)

// TODO: Only parse model for paths that would have a model.
if err := pr.parseModel(); err != nil {
if err := pr.parse(); err != nil {
pr.sendErrorResponse(w, http.StatusBadRequest, "unable to parse model: %v", err)
return
}
Expand All @@ -74,7 +74,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
metrics.InferenceRequestsActive.Add(pr.r.Context(), 1, metricAttrs)
defer metrics.InferenceRequestsActive.Add(pr.r.Context(), -1, metricAttrs)

modelExists, err := h.modelScaler.ModelExists(r.Context(), pr.model)
modelExists, err := h.modelScaler.ModelFound(r.Context(), pr.model, pr.selectors)
if err != nil {
pr.sendErrorResponse(w, http.StatusInternalServerError, "unable to resolve model: %v", err)
return
Expand Down
2 changes: 1 addition & 1 deletion internal/modelproxy/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ type testModelInterface struct {
models map[string]string
}

func (t *testModelInterface) ModelExists(ctx context.Context, model string) (bool, error) {
func (t *testModelInterface) ModelFound(ctx context.Context, model string, selector []string) (bool, error) {
nstogner marked this conversation as resolved.
Show resolved Hide resolved
nstogner marked this conversation as resolved.
Show resolved Hide resolved
_, ok := t.models[model]
return ok, nil
}
Expand Down
9 changes: 5 additions & 4 deletions internal/modelproxy/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type proxyRequest struct {
// in order to determine the model.
body []byte

// metadata:
selectors []string

id string
status int
Expand All @@ -38,14 +38,15 @@ func newProxyRequest(r *http.Request) *proxyRequest {
}

return pr

}

// parseModel attempts to determine the model from the request.
// parse attempts to determine the model from the request.
// It first checks the "X-Model" header, and if that is not set, it
// attempts to unmarshal the request body as JSON and extract the
// .model field.
func (pr *proxyRequest) parseModel() error {
func (pr *proxyRequest) parse() error {
pr.selectors = pr.r.Header.Values("X-Selector")
nstogner marked this conversation as resolved.
Show resolved Hide resolved

// Try to get the model from the header first
if headerModel := pr.r.Header.Get("X-Model"); headerModel != "" {
pr.model = headerModel
Expand Down
21 changes: 19 additions & 2 deletions internal/modelscaler/scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
kubeaiv1 "github.com/substratusai/kubeai/api/v1"
autoscalingv1 "k8s.io/api/autoscaling/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
)
Expand All @@ -26,13 +27,29 @@ func NewModelScaler(client client.Client, namespace string) *ModelScaler {
return &ModelScaler{client: client, namespace: namespace, consecutiveScaleDowns: map[string]int{}}
}

func (s *ModelScaler) ModelExists(ctx context.Context, model string) (bool, error) {
if err := s.client.Get(ctx, types.NamespacedName{Name: model, Namespace: s.namespace}, &kubeaiv1.Model{}); err != nil {
func (s *ModelScaler) ModelFound(ctx context.Context, model string, labelSelectors []string) (bool, error) {
m := &kubeaiv1.Model{}
if err := s.client.Get(ctx, types.NamespacedName{Name: model, Namespace: s.namespace}, m); err != nil {
if apierrors.IsNotFound(err) {
return false, nil
}
return false, err
}

lbls := m.GetLabels()
nstogner marked this conversation as resolved.
Show resolved Hide resolved
if lbls == nil {
lbls = map[string]string{}
}
for _, sel := range labelSelectors {
parsedSel, err := labels.Parse(sel)
if err != nil {
return false, fmt.Errorf("parse label selector: %w", err)
}
if !parsedSel.Matches(labels.Set(lbls)) {
return false, nil
}
}

return true, nil
}

Expand Down
18 changes: 16 additions & 2 deletions internal/openaiserver/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"

kubeaiv1 "github.com/substratusai/kubeai/api/v1"
"k8s.io/apimachinery/pkg/labels"
"sigs.k8s.io/controller-runtime/pkg/client"
)

Expand All @@ -22,14 +23,27 @@ func (h *Handler) getModels(w http.ResponseWriter, r *http.Request) {
features = []string{kubeaiv1.ModelFeatureTextGeneration}
}

var listOpts []client.ListOption
headerSelectors := r.Header.Values("X-Selector")
for _, sel := range headerSelectors {
parsedSel, err := labels.Parse(sel)
if err != nil {
sendErrorResponse(w, http.StatusBadRequest, "failed to parse label selector: %v", err)
return
}
listOpts = append(listOpts, client.MatchingLabelsSelector{Selector: parsedSel})
}

var k8sModels []kubeaiv1.Model
k8sModelNames := map[string]struct{}{}
for _, feature := range features {
// NOTE(nstogner): Could not find a way to do an OR query with the client,
// NOTE: At time of writing an OR query is not supported with the
// Kubernetes API server
// so we just do multiple queries and merge the results.
labelSelector := client.MatchingLabels{kubeaiv1.ModelFeatureLabelDomain + "/" + feature: "true"}
list := &kubeaiv1.ModelList{}
if err := h.K8sClient.List(r.Context(), list, labelSelector); err != nil {
opts := append([]client.ListOption{labelSelector}, listOpts...)
if err := h.K8sClient.List(r.Context(), list, opts...); err != nil {
sendErrorResponse(w, http.StatusInternalServerError, "failed to list models: %v", err)
return
}
Expand Down
13 changes: 7 additions & 6 deletions test/integration/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,32 +55,33 @@ func TestProxy(t *testing.T) {
// Wait for controller cache to sync.
time.Sleep(3 * time.Second)

// Send request number 1
var wg sync.WaitGroup
sendRequests(t, &wg, m.Name, 1, http.StatusOK, "request 1")

// Send request number 1
sendRequests(t, &wg, m.Name, nil, 1, http.StatusOK, "", "request 1")

requireModelReplicas(t, m, 1, "Replicas should be scaled up to 1 to process messaging request", 5*time.Second)
requireModelPods(t, m, 1, "Pod should be created for the messaging request", 5*time.Second)
markAllModelPodsReady(t, m)
completeRequests(backendComplete, 1)
closeChannels(backendComplete, 1)
require.Equal(t, int32(1), totalBackendRequests.Load(), "ensure the request made its way to the backend")

const autoscaleUpWait = 25 * time.Second
// Ensure the deployment is autoscaled past 1.
// Simulate the backend processing the request.
sendRequests(t, &wg, m.Name, 2, http.StatusOK, "request 2,3")
sendRequests(t, &wg, m.Name, nil, 2, http.StatusOK, "", "request 2,3")
requireModelReplicas(t, m, 2, "Replicas should be scaled up to 2 to process pending messaging request", autoscaleUpWait)
requireModelPods(t, m, 2, "2 Pods should be created for the messaging requests", 5*time.Second)
markAllModelPodsReady(t, m)

// Make sure deployment will not be scaled past max (3).
sendRequests(t, &wg, m.Name, 2, http.StatusOK, "request 4,5")
sendRequests(t, &wg, m.Name, nil, 2, http.StatusOK, "", "request 4,5")
require.Never(t, func() bool {
assert.NoError(t, testK8sClient.Get(testCtx, client.ObjectKeyFromObject(m), m))
return *m.Spec.Replicas > *m.Spec.MaxReplicas
}, autoscaleUpWait, time.Second/10, "Replicas should not be scaled past MaxReplicas")

completeRequests(backendComplete, 4)
closeChannels(backendComplete, 4)
require.Equal(t, int32(5), totalBackendRequests.Load(), "ensure all the requests made their way to the backend")

// Ensure the deployment is autoscaled back down to MinReplicas.
Expand Down
Loading
Loading