Skip to content

Commit

Permalink
frontend: Remove SubscriptionFromContext
Browse files Browse the repository at this point in the history
No longer used apart from middleware testing.  Refactored tests
to extract subscription state from the in-memory cache.
  • Loading branch information
Matthew Barnes committed Sep 18, 2024
1 parent c66bb65 commit fa463f0
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 33 deletions.
15 changes: 0 additions & 15 deletions frontend/pkg/frontend/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ const (
contextKeyResourceID
contextKeyCorrelationData
contextKeySystemData
contextKeySubscription
)

func ContextWithOriginalPath(ctx context.Context, originalPath string) context.Context {
Expand Down Expand Up @@ -138,17 +137,3 @@ func SystemDataFromContext(ctx context.Context) (*arm.SystemData, error) {
}
return systemData, nil
}

func ContextWithSubscription(ctx context.Context, subscription arm.Subscription) context.Context {
return context.WithValue(ctx, contextKeySubscription, subscription)
}

func SubscriptionFromContext(ctx context.Context) (arm.Subscription, error) {
sub, ok := ctx.Value(contextKeySubscription).(arm.Subscription)
if !ok {
return arm.Subscription{}, &ContextError{
got: sub,
}
}
return sub, nil
}
2 changes: 0 additions & 2 deletions frontend/pkg/frontend/middleware_validatesubscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ func (s *SubscriptionStateMuxValidator) MiddlewareValidateSubscriptionState(w ht
}
}

ctx := ContextWithSubscription(r.Context(), *sub.Subscription)
r = r.WithContext(ctx)
switch sub.Subscription.State {
case arm.SubscriptionStateRegistered:
next(w, r)
Expand Down
31 changes: 15 additions & 16 deletions frontend/pkg/frontend/middleware_validatesubscription_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"log/slog"
"net/http"
"net/http/httptest"
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -190,22 +189,22 @@ func TestMiddlewareValidateSubscription(t *testing.T) {
}

middleware.MiddlewareValidateSubscriptionState(writer, request, next)
sub, err := SubscriptionFromContext(request.Context())
if err != nil {
if tt.expectedError != nil {
var actualError *arm.CloudError
body, _ := io.ReadAll(http.MaxBytesReader(writer, writer.Result().Body, 4*megabyte))
_ = json.Unmarshal(body, &actualError)
if (writer.Result().StatusCode != tt.expectedError.StatusCode) || actualError.Code != tt.expectedError.Code || actualError.Message != tt.expectedError.Message {
t.Errorf("unexpected CloudError, wanted %v, got %v", tt.expectedError, actualError)
}
} else {
t.Errorf("expected CloudError, wanted %v, got %v", tt.expectedError, err)
}
}

if !reflect.DeepEqual(sub.State, tt.expectedState) {
t.Error(cmp.Diff(sub.State, tt.expectedState))
if tt.expectedError != nil {
var actualError *arm.CloudError
body, _ := io.ReadAll(http.MaxBytesReader(writer, writer.Result().Body, 4*megabyte))
_ = json.Unmarshal(body, &actualError)
if (writer.Result().StatusCode != tt.expectedError.StatusCode) || actualError.Code != tt.expectedError.Code || actualError.Message != tt.expectedError.Message {
t.Errorf("unexpected CloudError, wanted %v, got %v", tt.expectedError, actualError)
}
} else {
doc, err := dbClient.GetSubscriptionDoc(request.Context(), subscriptionId)
if err != nil {
t.Fatal(err.Error())
}
if doc.Subscription.State != tt.expectedState {
t.Error(cmp.Diff(doc.Subscription.State, tt.expectedState))
}
}
})
}
Expand Down

0 comments on commit fa463f0

Please sign in to comment.