From a4b1cf0e98acc77fe1b2366f0bcbda7e33ed5e6f Mon Sep 17 00:00:00 2001 From: Mihaly Lengyel Date: Fri, 25 Oct 2024 22:00:39 +0200 Subject: [PATCH] fix: add client_credentials and implicit flow tokens to oauth_sessions --- .../java/io/supertokens/oauth/OAuthToken.java | 11 ++++--- .../webserver/api/oauth/OAuthAuthAPI.java | 31 ++++++++++++++----- .../webserver/api/oauth/OAuthTokenAPI.java | 11 +++---- 3 files changed, 35 insertions(+), 18 deletions(-) diff --git a/src/main/java/io/supertokens/oauth/OAuthToken.java b/src/main/java/io/supertokens/oauth/OAuthToken.java index 0a9ac61df..7f17bba6b 100644 --- a/src/main/java/io/supertokens/oauth/OAuthToken.java +++ b/src/main/java/io/supertokens/oauth/OAuthToken.java @@ -25,10 +25,7 @@ import java.security.KeyException; import java.security.NoSuchAlgorithmException; import java.security.spec.InvalidKeySpecException; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.util.*; public class OAuthToken { public enum TokenType { @@ -125,6 +122,12 @@ public static String reSignToken(AppIdentifier appIdentifier, Main main, String payload.remove("ext"); payload.remove("initialPayload"); + // We ensure that the gid is there + // If it isn't that means that we are in a client_credentials (M2M) flow + if (!payload.has("gid")) { + payload.addProperty("gid", UUID.randomUUID().toString()); + } + if (payloadUpdate != null) { for (Map.Entry entry : payloadUpdate.entrySet()) { if (!NON_OVERRIDABLE_TOKEN_PROPS.contains(entry.getKey())) { diff --git a/src/main/java/io/supertokens/webserver/api/oauth/OAuthAuthAPI.java b/src/main/java/io/supertokens/webserver/api/oauth/OAuthAuthAPI.java index 50b3a5740..1d9ad5817 100644 --- a/src/main/java/io/supertokens/webserver/api/oauth/OAuthAuthAPI.java +++ b/src/main/java/io/supertokens/webserver/api/oauth/OAuthAuthAPI.java @@ -24,15 +24,16 @@ import io.supertokens.multitenancy.exception.BadPermissionException; import io.supertokens.oauth.HttpRequestForOAuthProvider; import io.supertokens.oauth.OAuth; -import io.supertokens.oauth.OAuthToken; import io.supertokens.pluginInterface.RECIPE_ID; import io.supertokens.pluginInterface.Storage; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.multitenancy.AppIdentifier; import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.pluginInterface.session.SessionInfo; import io.supertokens.pluginInterface.useridmapping.UserIdMapping; import io.supertokens.session.Session; +import io.supertokens.session.jwt.JWT; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.useridmapping.UserIdType; import io.supertokens.webserver.InputParser; @@ -112,14 +113,28 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I for (String part : parts) { if (part.startsWith("access_token=")) { String accessToken = java.net.URLDecoder.decode(part.split("=")[1], "UTF-8"); + JsonObject accessTokenPayload; try { - JsonObject accessTokenPayload = OAuthToken.getPayloadFromJWTToken(appIdentifier, main, accessToken); - if (accessTokenPayload.has("sessionHandle")) { - updateLastActive(appIdentifier, accessTokenPayload.get("sessionHandle").getAsString()); - } - } catch (Exception e) { - // ignore + JWT.JWTInfo jwtInfo = JWT.getPayloadWithoutVerifying(accessToken); + accessTokenPayload = jwtInfo.payload; + } catch (JWT.JWTException e) { + // This should never happen here since we just created/signed the token + throw new ServletException(e); } + + String clientId = accessTokenPayload.get("client_id").getAsString(); + String gid = accessTokenPayload.get("gid").getAsString(); + String jti = accessTokenPayload.get("jti").getAsString(); + + long exp = accessTokenPayload.get("exp").getAsLong(); + + String sessionHandle = null; + if (accessTokenPayload.has("sessionHandle")) { + sessionHandle = accessTokenPayload.get("sessionHandle").getAsString(); + updateLastActive(appIdentifier, sessionHandle); + } + + OAuth.createOrUpdateRefreshTokenMapping(main, appIdentifier, storage, clientId, gid, null, null, sessionHandle, List.of(jti), exp); } } } @@ -142,7 +157,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I super.sendJsonResponse(200, finalResponse, resp); } - } catch (IOException | TenantOrAppNotFoundException | BadPermissionException e) { + } catch (IOException | TenantOrAppNotFoundException | BadPermissionException | StorageQueryException e) { throw new ServletException(e); } } diff --git a/src/main/java/io/supertokens/webserver/api/oauth/OAuthTokenAPI.java b/src/main/java/io/supertokens/webserver/api/oauth/OAuthTokenAPI.java index 36dd026ef..c35ccf685 100644 --- a/src/main/java/io/supertokens/webserver/api/oauth/OAuthTokenAPI.java +++ b/src/main/java/io/supertokens/webserver/api/oauth/OAuthTokenAPI.java @@ -191,16 +191,14 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I String gid = null; String jti = null; String sessionHandle = null; + Long exp = null; if(response.jsonResponse.getAsJsonObject().has("access_token")){ try { JsonObject accessTokenPayload = OAuthToken.getPayloadFromJWTToken(appIdentifier, main, response.jsonResponse.getAsJsonObject().get("access_token").getAsString()); - if(accessTokenPayload.has("gid")) { - gid = accessTokenPayload.get("gid").getAsString(); - } - if(accessTokenPayload.has("jti")) { - jti = accessTokenPayload.get("jti").getAsString(); - } + gid = accessTokenPayload.get("gid").getAsString(); + jti = accessTokenPayload.get("jti").getAsString(); + exp = accessTokenPayload.get("exp").getAsLong(); } catch (TryRefreshTokenException e) { //ignore, shouldn't happen } @@ -257,6 +255,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I if (accessTokenPayload.has("sessionHandle")) { updateLastActive(appIdentifier, accessTokenPayload.get("sessionHandle").getAsString()); } + OAuth.createOrUpdateRefreshTokenMapping(main, appIdentifier, storage, clientId, gid, null, null, sessionHandle, List.of(jti), exp); } catch (Exception e) { // ignore }