diff --git a/auth/handlers.go b/auth/handlers.go index d3e451295..7a70a019b 100644 --- a/auth/handlers.go +++ b/auth/handlers.go @@ -8,6 +8,9 @@ import ( "strings" "time" + _struct "github.com/golang/protobuf/ptypes/struct" + "google.golang.org/protobuf/encoding/protojson" + "github.com/flyteorg/flyteadmin/auth/interfaces" "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" @@ -318,6 +321,8 @@ func GetHTTPRequestCookieToMetadataHandler(authCtx interfaces.AuthenticationCont logger.Infof(ctx, "Failed to retrieve user info cookie. Ignoring. Error: %v", err) } + logger.Debugf(ctx, "Retrieved [%v] Additional Claims: [%+v]", len(userInfo.AdditionalClaims.AsMap()), userInfo.AdditionalClaims.AsMap()) + raw, err := json.Marshal(userInfo) if err != nil { logger.Infof(ctx, "Failed to marshal user info. Ignoring. Error: %v", err) @@ -373,6 +378,8 @@ func IdentityContextFromRequest(ctx context.Context, req *http.Request, authCtx return nil, fmt.Errorf("unauthenticated request. Error: %w", err) } + logger.Debugf(ctx, "Retrieved2 [%v] Additional Claims: [%+v]", len(userInfo.AdditionalClaims.AsMap()), userInfo.AdditionalClaims.AsMap()) + return IdentityContextFromIDTokenToken(ctx, idToken, authCtx.Options().UserAuth.OpenID.ClientID, authCtx.OidcProvider(), userInfo) } @@ -410,16 +417,45 @@ func QueryUserInfoUsingAccessToken(ctx context.Context, originalRequest *http.Re userInfo, err := authCtx.OidcProvider().UserInfo(ctx, tokenSource) if err != nil { logger.Errorf(ctx, "Error getting user info from IDP %s", err) - return &service.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP") + return &service.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP. Error: %w", err) } resp := &service.UserInfoResponse{} err = userInfo.Claims(&resp) if err != nil { logger.Errorf(ctx, "Error getting user info from IDP %s", err) - return &service.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP") + return &service.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP. Error: %w", err) + } + + allClaims := make(map[string]any, 10) + err = userInfo.Claims(&allClaims) + if err != nil { + logger.Errorf(ctx, "Error unmarshalling raw claims %s", err) + return &service.UserInfoResponse{}, fmt.Errorf("error unmarshalling raw claims. Error: %w", err) + } + + logger.Debugf(ctx, "Unmarshalled a total of [%v] claims: [%+v]", len(allClaims), allClaims) + + alreadyRead := []string{"subject", "name", "preferred_username", "given_name", "family_name", "email", "picture"} + for _, existing := range alreadyRead { + delete(allClaims, existing) } + logger.Debugf(ctx, "Remaining a total of [%v] additional claims: [%+v]", len(allClaims), allClaims) + + var response _struct.Struct + b, err := json.Marshal(allClaims) + if err != nil { + return &service.UserInfoResponse{}, fmt.Errorf("failed to marshal additional claims to json. Error: %w", err) + } + + err = protojson.Unmarshal(b, &response) + if err != nil { + return nil, fmt.Errorf("failed to unamarshal additional claims to proto.struct. Error: %w", err) + } + + resp.AdditionalClaims = &response + return resp, err } @@ -449,6 +485,7 @@ func GetUserInfoForwardResponseHandler() UserInfoForwardResponseHandler { return func(ctx context.Context, w http.ResponseWriter, m protoiface.MessageV1) error { info, ok := m.(*service.UserInfoResponse) if ok { + logger.Debugf(ctx, "GetUserInfoForwardResponseHandler: Additional claims: [%+v]", info.AdditionalClaims.GetFields()) if info.AdditionalClaims != nil { for k, v := range info.AdditionalClaims.GetFields() { jsonBytes, err := v.MarshalJSON() @@ -457,6 +494,7 @@ func GetUserInfoForwardResponseHandler() UserInfoForwardResponseHandler { continue } header := fmt.Sprintf("X-User-Claim-%s", strings.ReplaceAll(k, "_", "-")) + logger.Debugf(ctx, "Setting header [%v: %v]", header, string(jsonBytes)) w.Header().Set(header, string(jsonBytes)) } }