From e89a0495a656949481555f9589f8fbd05736fb5f Mon Sep 17 00:00:00 2001 From: Haytham Abuelfutuh Date: Tue, 10 Jan 2023 18:23:28 -0800 Subject: [PATCH 1/4] Store additional claims in the QueryUserInfoFromAccessToken path Signed-off-by: Haytham Abuelfutuh --- auth/handlers.go | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/auth/handlers.go b/auth/handlers.go index d3e451295..d3199d621 100644 --- a/auth/handlers.go +++ b/auth/handlers.go @@ -8,6 +8,9 @@ import ( "strings" "time" + _struct "github.com/golang/protobuf/ptypes/struct" + "google.golang.org/protobuf/encoding/protojson" + "github.com/flyteorg/flyteadmin/auth/interfaces" "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" @@ -410,16 +413,41 @@ func QueryUserInfoUsingAccessToken(ctx context.Context, originalRequest *http.Re userInfo, err := authCtx.OidcProvider().UserInfo(ctx, tokenSource) if err != nil { logger.Errorf(ctx, "Error getting user info from IDP %s", err) - return &service.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP") + return &service.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP. Error: %w", err) } resp := &service.UserInfoResponse{} err = userInfo.Claims(&resp) if err != nil { logger.Errorf(ctx, "Error getting user info from IDP %s", err) - return &service.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP") + return &service.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP. Error: %w", err) + } + + allClaims := make(map[string]any, 10) + err = userInfo.Claims(&allClaims) + if err != nil { + logger.Errorf(ctx, "Error unmarshalling raw claims %s", err) + return &service.UserInfoResponse{}, fmt.Errorf("error unmarshalling raw claims. Error: %w", err) } + alreadyRead := []string{"subject", "name", "preferred_username", "given_name", "family_name", "email", "picture"} + for _, existing := range alreadyRead { + delete(allClaims, existing) + } + + var response _struct.Struct + b, err := json.Marshal(allClaims) + if err != nil { + return &service.UserInfoResponse{}, fmt.Errorf("failed to marshal additional claims to json. Error: %w", err) + } + + err = protojson.Unmarshal(b, &response) + if err != nil { + return nil, fmt.Errorf("failed to unamarshal additional claims to proto.struct. Error: %w", err) + } + + resp.AdditionalClaims = &response + return resp, err } From 1e12d93aa6d065d4595f9ed9ccf914d3f4e89e8a Mon Sep 17 00:00:00 2001 From: Haytham Abuelfutuh Date: Tue, 10 Jan 2023 18:55:23 -0800 Subject: [PATCH 2/4] Add logs Signed-off-by: Haytham Abuelfutuh --- auth/handlers.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/auth/handlers.go b/auth/handlers.go index d3199d621..26d52092a 100644 --- a/auth/handlers.go +++ b/auth/handlers.go @@ -321,6 +321,8 @@ func GetHTTPRequestCookieToMetadataHandler(authCtx interfaces.AuthenticationCont logger.Infof(ctx, "Failed to retrieve user info cookie. Ignoring. Error: %v", err) } + logger.Debugf(ctx, "Retrieved [%v] Additional Claims: [%+v]", len(userInfo.AdditionalClaims.AsMap()), userInfo.AdditionalClaims.AsMap()) + raw, err := json.Marshal(userInfo) if err != nil { logger.Infof(ctx, "Failed to marshal user info. Ignoring. Error: %v", err) @@ -376,6 +378,8 @@ func IdentityContextFromRequest(ctx context.Context, req *http.Request, authCtx return nil, fmt.Errorf("unauthenticated request. Error: %w", err) } + logger.Debugf(ctx, "Retrieved2 [%v] Additional Claims: [%+v]", len(userInfo.AdditionalClaims.AsMap()), userInfo.AdditionalClaims.AsMap()) + return IdentityContextFromIDTokenToken(ctx, idToken, authCtx.Options().UserAuth.OpenID.ClientID, authCtx.OidcProvider(), userInfo) } @@ -430,11 +434,15 @@ func QueryUserInfoUsingAccessToken(ctx context.Context, originalRequest *http.Re return &service.UserInfoResponse{}, fmt.Errorf("error unmarshalling raw claims. Error: %w", err) } + logger.Debugf(ctx, "Unmarshalled a total of [%v] claims: [%+v]", len(allClaims), allClaims) + alreadyRead := []string{"subject", "name", "preferred_username", "given_name", "family_name", "email", "picture"} for _, existing := range alreadyRead { delete(allClaims, existing) } + logger.Debugf(ctx, "Remaining a total of [%v] additional claims: [%+v]", len(allClaims), allClaims) + var response _struct.Struct b, err := json.Marshal(allClaims) if err != nil { From d6e1525e8a7ccf1b1d3b4198e4b28c3a86dc883d Mon Sep 17 00:00:00 2001 From: Haytham Abuelfutuh Date: Tue, 10 Jan 2023 19:31:07 -0800 Subject: [PATCH 3/4] More logs Signed-off-by: Haytham Abuelfutuh --- auth/handlers.go | 1 + 1 file changed, 1 insertion(+) diff --git a/auth/handlers.go b/auth/handlers.go index 26d52092a..0214a15ea 100644 --- a/auth/handlers.go +++ b/auth/handlers.go @@ -485,6 +485,7 @@ func GetUserInfoForwardResponseHandler() UserInfoForwardResponseHandler { return func(ctx context.Context, w http.ResponseWriter, m protoiface.MessageV1) error { info, ok := m.(*service.UserInfoResponse) if ok { + logger.Debugf(ctx, "GetUserInfoForwardResponseHandler: Additional claims: [%+v]", info.AdditionalClaims) if info.AdditionalClaims != nil { for k, v := range info.AdditionalClaims.GetFields() { jsonBytes, err := v.MarshalJSON() From e1134a49360c356ad9d5652b7a2b4aaa45118d3d Mon Sep 17 00:00:00 2001 From: Haytham Abuelfutuh Date: Tue, 10 Jan 2023 19:40:09 -0800 Subject: [PATCH 4/4] More logs Signed-off-by: Haytham Abuelfutuh --- auth/handlers.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/auth/handlers.go b/auth/handlers.go index 0214a15ea..7a70a019b 100644 --- a/auth/handlers.go +++ b/auth/handlers.go @@ -485,7 +485,7 @@ func GetUserInfoForwardResponseHandler() UserInfoForwardResponseHandler { return func(ctx context.Context, w http.ResponseWriter, m protoiface.MessageV1) error { info, ok := m.(*service.UserInfoResponse) if ok { - logger.Debugf(ctx, "GetUserInfoForwardResponseHandler: Additional claims: [%+v]", info.AdditionalClaims) + logger.Debugf(ctx, "GetUserInfoForwardResponseHandler: Additional claims: [%+v]", info.AdditionalClaims.GetFields()) if info.AdditionalClaims != nil { for k, v := range info.AdditionalClaims.GetFields() { jsonBytes, err := v.MarshalJSON() @@ -494,6 +494,7 @@ func GetUserInfoForwardResponseHandler() UserInfoForwardResponseHandler { continue } header := fmt.Sprintf("X-User-Claim-%s", strings.ReplaceAll(k, "_", "-")) + logger.Debugf(ctx, "Setting header [%v: %v]", header, string(jsonBytes)) w.Header().Set(header, string(jsonBytes)) } }