Skip to content

Commit

Permalink
Merge branch 'main' into feat/add-opentelemetry
Browse files Browse the repository at this point in the history
  • Loading branch information
SamMHD committed Feb 24, 2024
2 parents b21450d + ca84c80 commit e9ae4d2
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 34 deletions.
41 changes: 20 additions & 21 deletions pkg/auth/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ func (a *Authenticator) TestAccess(request *Request, wsvc WebservicesCacheEntry)
return
}
}
reason = CerberusReasonOK
return
}

Expand All @@ -116,17 +115,17 @@ func (a *Authenticator) readToken(request *Request, wsvc WebservicesCacheEntry)

// readService reads requested webservice from cache and
// will return error if the object would not be found in cache
func (a *Authenticator) readService(wsvc string) (bool, CerberusReason, WebservicesCacheEntry) {
func (a *Authenticator) readService(wsvc string) (CerberusReason, WebservicesCacheEntry) {
a.cacheLock.RLock()
cacheReaders.Inc()
defer a.cacheLock.RUnlock()
defer cacheReaders.Dec()

res, ok := a.webservicesCache.ReadWebservice(wsvc)
if !ok {
return false, CerberusReasonWebserviceNotFound, WebservicesCacheEntry{}
return CerberusReasonWebserviceNotFound, WebservicesCacheEntry{}
}
return true, "", res
return "", res
}

func toExtraHeaders(headers CerberusExtraHeaders) ExtraHeaders {
Expand Down Expand Up @@ -155,7 +154,7 @@ func (a *Authenticator) Check(ctx context.Context, request *Request) (*Response,
)

if reason != "" {
return generateResponse(false, reason, nil), nil
return generateResponse(reason, nil), nil
}
wsvc = v1alpha1.WebserviceReference{
Name: wsvc,
Expand All @@ -165,8 +164,8 @@ func (a *Authenticator) Check(ctx context.Context, request *Request) (*Response,
request.Context[HasUpstreamAuth] = "false"
var extraHeaders ExtraHeaders

ok, reason, wsvcCacheEntry := a.readService(wsvc)
if ok {
reason, wsvcCacheEntry := a.readService(wsvc)
if reason == "" {
var cerberusExtraHeaders CerberusExtraHeaders

// perform TestAccess
Expand All @@ -175,7 +174,7 @@ func (a *Authenticator) Check(ctx context.Context, request *Request) (*Response,
extraHeaders = toExtraHeaders(cerberusExtraHeaders)
if reason == CerberusReasonOK && hasUpstreamAuth(wsvcCacheEntry) {
request.Context[HasUpstreamAuth] = "true"
ok, reason = a.checkServiceUpstreamAuth(wsvcCacheEntry, request, &extraHeaders, ctx)
reason = a.checkServiceUpstreamAuth(wsvcCacheEntry, request, &extraHeaders, ctx)
}
}

Expand All @@ -184,7 +183,7 @@ func (a *Authenticator) Check(ctx context.Context, request *Request) (*Response,
err = status.Error(codes.DeadlineExceeded, "Timeout exceeded")
}

return generateResponse(ok, reason, extraHeaders), err
return generateResponse(reason, extraHeaders), err
}

func readRequestContext(request *Request) (wsvc string, ns string, reason CerberusReason) {
Expand Down Expand Up @@ -290,40 +289,37 @@ func processResponseError(err error) CerberusReason {

// checkServiceUpstreamAuth function is designed to validate the request through
// the upstream authentication for a given webservice
func (a *Authenticator) checkServiceUpstreamAuth(service WebservicesCacheEntry, request *Request, extraHeaders *ExtraHeaders, ctx context.Context) (ok bool, reason CerberusReason) {
func (a *Authenticator) checkServiceUpstreamAuth(service WebservicesCacheEntry, request *Request, extraHeaders *ExtraHeaders, ctx context.Context) (reason CerberusReason) {
downstreamDeadline, hasDownstreamDeadline := ctx.Deadline()
serviceUpstreamAuthCalls.With(AddWithDownstreamDeadline(nil, hasDownstreamDeadline)).Inc()

_, span := tracing.StartSpan(ctx, "upstream-auth")
defer func() {
span.SetAttributes(
attribute.String("upstream-auth-cerberus-reason", string(reason)),
attribute.Bool("upstream-auth-final-is-ok", ok),
)
span.End()
}()
span.SetAttributes(
attribute.String("upstream-auth-address", service.Spec.UpstreamHttpAuth.Address),
)

if reason = validateUpstreamAuthRequest(service); reason != "" {
ok = false
return
if reason := validateUpstreamAuthRequest(service); reason != "" {
return reason
}
upstreamAuth := service.Spec.UpstreamHttpAuth
req, err := setupUpstreamAuthRequest(&upstreamAuth, request)
if err != nil {
return false, CerberusReasonUpstreamAuthNoReq
return CerberusReasonUpstreamAuthNoReq
}
a.adjustTimeout(upstreamAuth.Timeout, downstreamDeadline, hasDownstreamDeadline)

reqStart := time.Now()
resp, err := a.httpClient.Do(req)
reqDuration := time.Since(reqStart)

if reason = processResponseError(err); reason != "" {
ok = false
return
if reason := processResponseError(err); reason != "" {
return reason
}

labels := AddWithDownstreamDeadline(AddStatusLabel(nil, resp.StatusCode), hasDownstreamDeadline)
Expand All @@ -335,11 +331,11 @@ func (a *Authenticator) checkServiceUpstreamAuth(service WebservicesCacheEntry,
)

if resp.StatusCode != http.StatusOK {
return false, CerberusReasonUnauthorized
return CerberusReasonUnauthorized
}
// add requested careHeaders to extraHeaders for response
copyUpstreamHeaders(resp, extraHeaders, service.Spec.UpstreamHttpAuth.CareHeaders)
return true, CerberusReasonOK
return ""
}

// hasUpstreamAuth evaluates whether the provided webservice
Expand All @@ -351,10 +347,13 @@ func hasUpstreamAuth(service WebservicesCacheEntry) bool {
// generateResponse initializes defaults for cerberus http result and creates a
// valid response from cerberus reasons and computed headers to inform the client
// that it has the access or not.
func generateResponse(ok bool, reason CerberusReason, extraHeaders ExtraHeaders) *Response {
func generateResponse(reason CerberusReason, extraHeaders ExtraHeaders) *Response {
ok := (reason == "")

var httpStatusCode int
if ok {
httpStatusCode = http.StatusOK
reason = CerberusReasonOK
} else {
httpStatusCode = http.StatusUnauthorized
}
Expand Down
12 changes: 11 additions & 1 deletion pkg/auth/authenticator_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ func (a *Authenticator) buildNewWebservicesCache(
}

a.logger.Info("webservice access cache built successfully", "len", len(newWebservicesCache))

for _, entry := range newWebservicesCache {
a.logger.Info("webservice stored", "name", entry.Name, "allowedNamespaces", entry.allowedNamespacesCache)
}

return &newWebservicesCache
}

Expand Down Expand Up @@ -199,7 +204,12 @@ func (a *Authenticator) buildNewAccessTokensCache(
}
}

a.logger.Info("webservice access cache built successfully", "len", len(newAccessTokensCache))
a.logger.Info("access token cache built successfully", "len", len(newAccessTokensCache))

for key, entry := range newAccessTokensCache {
a.logger.Info("webservice stored", "name", key, "allowedWebservices", entry.allowedWebservicesCache)
}

return &newAccessTokensCache
}

Expand Down
15 changes: 13 additions & 2 deletions pkg/auth/authenticator_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ func TestBuildNewWebservicesCache(t *testing.T) {
getBindingfromLogs := func(logs testutils.Logs) []string {
bindings := make([]string, 0)
for _, v := range logs {
bindings = append(bindings, v.KeyValues["binding"].(string))
fmt.Println(v)
if v.Message != "webservice stored" && v.Message != "webservice access cache built successfully" {
bindings = append(bindings, v.KeyValues["binding"].(string))
}
}
return bindings
}
Expand All @@ -136,7 +139,15 @@ func TestBuildNewWebservicesCache(t *testing.T) {
assert.ElementsMatch(t, bindingsNamesFromFixtures, bindingsNamesFromLog)
for _, log := range bindingLogs {
assert.Equal(t, "info", log.Type)
assert.Equal(t, "ignored some webservices over binding", log.Message)
assert.Contains(
t,
[]string{
"ignored some webservices over binding",
"webservice stored",
"webservice access cache built successfully",
},
log.Message,
)
}

assert.Len(t, *newWebservicesCache, 2)
Expand Down
2 changes: 1 addition & 1 deletion pkg/auth/authenticator_filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ var _ AuthenticationValidation = (*AuthenticationTokenAccessValidation)(nil)
// Validate checks token and webservice access
func (adv *AuthenticationTokenAccessValidation) Validate(ac *AccessTokensCacheEntry,
wsvc *WebservicesCacheEntry, request *Request) (CerberusReason, CerberusExtraHeaders) {
if !ac.TestAccess(wsvc.Name) {
if !ac.TestAccess(wsvc.LocalName()) {
return CerberusReasonWebserviceNotAllowed, CerberusExtraHeaders{}
}
return CerberusReasonNotSet, CerberusExtraHeaders{}
Expand Down
2 changes: 1 addition & 1 deletion pkg/auth/authenticator_filters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func TestAuthenticationTokenAccessValidation_Validate(t *testing.T) {
},
}
ac.allowedWebservicesCache = make(AllowedWebservicesCache)
ac.allowedWebservicesCache["test-ws"] = struct{}{}
ac.allowedWebservicesCache[wsvc.LocalName()] = struct{}{}

atcv := AuthenticationTokenAccessValidation{}

Expand Down
13 changes: 5 additions & 8 deletions pkg/auth/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,7 @@ func TestReadService(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.wsvc, func(t *testing.T) {
ok, reason, _ := authenticator.readService(tc.wsvc)
if ok != tc.expectedOk {
t.Errorf("Expected success: %v, Got: %v", tc.expectedOk, ok)
}
reason, _ := authenticator.readService(tc.wsvc)
if reason != tc.expectedReason {
t.Errorf("Expected reason: %v, Got: %v", tc.expectedReason, reason)
}
Expand Down Expand Up @@ -241,7 +238,7 @@ func TestTestAccessValidToken(t *testing.T) {

reason, extraHeaders := authenticator.TestAccess(request, webservice)

assert.Equal(t, CerberusReasonOK, reason, "Expected reason to be OK")
assert.Equal(t, CerberusReasonNotSet, reason, "Expected reason to be OK")
assert.Equal(t, "valid-token", extraHeaders[CerberusHeaderAccessToken], "Expected token in extraHeaders")
}

Expand Down Expand Up @@ -729,11 +726,11 @@ func Test_generateResponse(t *testing.T) {
StatusCode: http.StatusOK,
Header: http.Header{
ExternalAuthHandlerHeader: {"cerberus"},
CerberusHeaderReasonHeader: {"reason"},
CerberusHeaderReasonHeader: {string(CerberusReasonOK)},
},
},
}
actualResponse := generateResponse(true, "reason", nil)
actualResponse := generateResponse("", nil)
assert.Equal(t, expectedResponse.Allow, actualResponse.Allow, "Response should be allowed")
assert.Equal(t, expectedResponse.Response.StatusCode, actualResponse.Response.StatusCode, "HTTP status code should match")
assert.Equal(t, expectedResponse.Response.Header, actualResponse.Response.Header, "Response headers should match")
Expand All @@ -751,7 +748,7 @@ func Test_generateResponse(t *testing.T) {
},
},
}
actualResponse = generateResponse(false, "reason", extraHeaders)
actualResponse = generateResponse("reason", extraHeaders)
assert.Equal(t, expectedResponse.Allow, actualResponse.Allow, "Response should not be allowed")
assert.Equal(t, expectedResponse.Response.StatusCode, actualResponse.Response.StatusCode, "HTTP status code should match")
assert.Equal(t, expectedResponse.Response.Header, actualResponse.Response.Header, "Response headers should match")
Expand Down

0 comments on commit e9ae4d2

Please sign in to comment.