Skip to content

Commit

Permalink
Handle model undeployment (#44)
Browse files Browse the repository at this point in the history
🚧 spike

Simple version to fail the request fast when a model backed was
undeployed while queued.

Not handled is the case when a model was removed from the deployment
annotation
  • Loading branch information
alpe authored Dec 21, 2023
1 parent 00a7e4d commit bc4ba0d
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 11 deletions.
22 changes: 20 additions & 2 deletions pkg/deployments/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (

appsv1 "k8s.io/api/apps/v1"
autoscalingv1 "k8s.io/api/autoscaling/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/types"

ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
)
Expand Down Expand Up @@ -65,7 +65,11 @@ func (r *Manager) SetDesiredScale(deploymentName string, n int32) {

func (r *Manager) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
var d appsv1.Deployment
if err := r.Get(ctx, req.NamespacedName, &d); err != nil {
switch err := r.Get(ctx, req.NamespacedName, &d); {
case apierrors.IsNotFound(err):
r.removeDeployment(req)
return ctrl.Result{}, nil
case err != nil:
return ctrl.Result{}, fmt.Errorf("get: %w", err)
}

Expand Down Expand Up @@ -98,6 +102,20 @@ func (r *Manager) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result,
return ctrl.Result{}, nil
}

func (r *Manager) removeDeployment(req ctrl.Request) {
r.scalersMtx.Lock()
delete(r.scalers, req.Name)
r.scalersMtx.Unlock()

r.modelToDeploymentMtx.Lock()
for model, deployment := range r.modelToDeployment {
if deployment == req.Name {
delete(r.modelToDeployment, model)
}
}
r.modelToDeploymentMtx.Unlock()
}

func (r *Manager) getScaler(deploymentName string) *scaler {
r.scalersMtx.Lock()
b, ok := r.scalers[deploymentName]
Expand Down
8 changes: 8 additions & 0 deletions pkg/proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Println("Admitted into queue", id)
defer complete()

// abort when deployment was removed meanwhile
if _, exists := h.Deployments.ResolveDeployment(modelName); !exists {
log.Printf("deployment not active for model removed: %v", err)
w.WriteHeader(http.StatusNotFound)
_, _ = w.Write([]byte(fmt.Sprintf("Deployment for model not found: %v", modelName)))
return
}

log.Println("Waiting for IPs", id)
host := h.Endpoints.GetHost(r.Context(), deploy, "http")
log.Printf("Got host: %v, id: %v\n", host, id)
Expand Down
85 changes: 77 additions & 8 deletions tests/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,20 @@ import (
"testing"
"time"

"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
disv1 "k8s.io/api/discovery/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/utils/ptr"
)

func TestIntegration(t *testing.T) {
func TestScaleUpAndDown(t *testing.T) {
const modelName = "test-model-a"
deploy := testDeployment(modelName)

Expand Down Expand Up @@ -57,7 +60,7 @@ func TestIntegration(t *testing.T) {

// Send request number 1
var wg sync.WaitGroup
sendRequests(t, &wg, modelName, 1)
sendRequests(t, &wg, modelName, 1, http.StatusOK)

requireDeploymentReplicas(t, deploy, 1)
require.Equal(t, int32(1), backendRequests.Load(), "ensure the request made its way to the backend")
Expand All @@ -66,11 +69,11 @@ func TestIntegration(t *testing.T) {
// Ensure the deployment scaled scaled past 1.
// 1/2 should be admitted
// 1/2 should remain in queue
sendRequests(t, &wg, modelName, 2)
sendRequests(t, &wg, modelName, 2, http.StatusOK)
requireDeploymentReplicas(t, deploy, 2)

// Make sure deployment will not be scaled past default max (3).
sendRequests(t, &wg, modelName, 2)
sendRequests(t, &wg, modelName, 2, http.StatusOK)
requireDeploymentReplicas(t, deploy, 3)

// Have the mock backend respond to the remaining 4 requests.
Expand All @@ -83,6 +86,71 @@ func TestIntegration(t *testing.T) {
wg.Wait()
}

func TestHandleModelUndeployment(t *testing.T) {
const modelName = "test-model-b"
deploy := testDeployment(modelName)

require.NoError(t, testK8sClient.Create(testCtx, deploy))

backendComplete := make(chan struct{})

backendRequests := &atomic.Int32{}
testBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Println("Serving request from testBackend")
backendRequests.Add(1)
<-backendComplete
w.WriteHeader(200)
}))

// Mock an EndpointSlice.
testBackendURL, err := url.Parse(testBackend.URL)
require.NoError(t, err)
testBackendPort, err := strconv.Atoi(testBackendURL.Port())
require.NoError(t, err)
require.NoError(t, testK8sClient.Create(testCtx,
endpointSlice(
modelName,
testBackendURL.Hostname(),
int32(testBackendPort),
),
))

// Wait for deployment mapping to sync.
time.Sleep(3 * time.Second)

// Send request number 1
var wg sync.WaitGroup
// send single request to scale up and block on the handler to build a queue
sendRequests(t, &wg, modelName, 1, http.StatusOK)

requireDeploymentReplicas(t, deploy, 1)
require.Equal(t, int32(1), backendRequests.Load(), "ensure the request made its way to the backend")
// Add some more requests to the queue but with 404 expected
// because the deployment is deleted before un-queued
sendRequests(t, &wg, modelName, 2, http.StatusNotFound)

require.NoError(t, testK8sClient.Delete(testCtx, deploy))

// Check that the deployment was deleted
err = testK8sClient.Get(testCtx, client.ObjectKey{
Namespace: deploy.Namespace,
Name: deploy.Name,
}, deploy)

// ErrNotFound is desired since we delete the resource earlier
assert.True(t, apierrors.IsNotFound(err))
// release blocked request
completeRequests(backendComplete, 1)

// Wait for deployment mapping to sync.
require.Eventually(t, func() bool {
return queueManager.TotalCounts()[modelName+"-deploy"] == 0
}, 3*time.Second, 100*time.Millisecond)

t.Logf("Waiting for wait group")
wg.Wait()
}

func requireDeploymentReplicas(t *testing.T, deploy *appsv1.Deployment, n int32) {
require.EventuallyWithT(t, func(t *assert.CollectT) {
err := testK8sClient.Get(testCtx, types.NamespacedName{Namespace: deploy.Namespace, Name: deploy.Name}, deploy)
Expand All @@ -92,13 +160,14 @@ func requireDeploymentReplicas(t *testing.T, deploy *appsv1.Deployment, n int32)
}, 3*time.Second, time.Second/2, "waiting for the deployment to be scaled up")
}

func sendRequests(t *testing.T, wg *sync.WaitGroup, modelName string, n int) {
func sendRequests(t *testing.T, wg *sync.WaitGroup, modelName string, n int, expCode int) {
for i := 0; i < n; i++ {
sendRequest(t, wg, modelName)
sendRequest(t, wg, modelName, expCode)
}
}

func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string) {
func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string, expCode int) {
t.Helper()
wg.Add(1)
go func() {
defer wg.Done()
Expand All @@ -109,7 +178,7 @@ func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string) {

res, err := testHTTPClient.Do(req)
require.NoError(t, err)
require.Equal(t, 200, res.StatusCode)
require.Equal(t, expCode, res.StatusCode)
}()
}

Expand Down
3 changes: 2 additions & 1 deletion tests/integration/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ var (
testCancel context.CancelFunc
testServer *httptest.Server
testHTTPClient = &http.Client{Timeout: 10 * time.Second}
queueManager *queue.Manager
)

func init() {
Expand Down Expand Up @@ -78,7 +79,7 @@ func TestMain(m *testing.M) {
requireNoError(err)

const concurrencyPerReplica = 1
queueManager := queue.NewManager(concurrencyPerReplica)
queueManager = queue.NewManager(concurrencyPerReplica)

endpointManager, err := endpoints.NewManager(mgr)
requireNoError(err)
Expand Down

0 comments on commit bc4ba0d

Please sign in to comment.