diff --git a/src/main/java/io/supertokens/storage/postgresql/Start.java b/src/main/java/io/supertokens/storage/postgresql/Start.java index 57da9eaf..bc4915a1 100644 --- a/src/main/java/io/supertokens/storage/postgresql/Start.java +++ b/src/main/java/io/supertokens/storage/postgresql/Start.java @@ -55,7 +55,9 @@ import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.pluginInterface.multitenancy.sqlStorage.MultitenancySQLStorage; import io.supertokens.pluginInterface.oauth.OAuthLogoutChallenge; -import io.supertokens.pluginInterface.oauth.sqlStorage.OAuthSQLStorage; +import io.supertokens.pluginInterface.oauth.OAuthRevokeTargetType; +import io.supertokens.pluginInterface.oauth.OAuthStorage; +import io.supertokens.pluginInterface.oauth.exception.DuplicateOAuthLogoutChallengeException; import io.supertokens.pluginInterface.passwordless.PasswordlessCode; import io.supertokens.pluginInterface.passwordless.PasswordlessDevice; import io.supertokens.pluginInterface.passwordless.exception.*; @@ -108,7 +110,7 @@ public class Start implements SessionSQLStorage, EmailPasswordSQLStorage, EmailVerificationSQLStorage, ThirdPartySQLStorage, JWTRecipeSQLStorage, PasswordlessSQLStorage, UserMetadataSQLStorage, UserRolesSQLStorage, UserIdMappingStorage, UserIdMappingSQLStorage, MultitenancyStorage, MultitenancySQLStorage, DashboardSQLStorage, TOTPSQLStorage, - ActiveUsersStorage, ActiveUsersSQLStorage, AuthRecipeSQLStorage, OAuthSQLStorage { + ActiveUsersStorage, ActiveUsersSQLStorage, AuthRecipeSQLStorage, OAuthStorage { // these configs are protected from being modified / viewed by the dev using the SuperTokens // SaaS. If the core is not running in SuperTokens SaaS, this array has no effect. @@ -865,6 +867,8 @@ public void addInfoToNonAuthRecipesBasedOnUserId(TenantIdentifier tenantIdentifi } } else if (className.equals(JWTRecipeStorage.class.getName())) { /* Since JWT recipe tables do not store userId we do not add any data to them */ + } else if (className.equals(OAuthStorage.class.getName())) { + /* Since OAuth recipe tables do not store userId we do not add any data to them */ } else if (className.equals(ActiveUsersStorage.class.getName())) { try { ActiveUsersQueries.updateUserLastActive(this, tenantIdentifier.toAppIdentifier(), userId); @@ -3082,154 +3086,162 @@ public int countUsersThatHaveMoreThanOneLoginMethodOrTOTPEnabledAndActiveSince(A } @Override - public boolean doesClientIdExistForApp(AppIdentifier appIdentifier, String clientId) + public boolean doesOAuthClientIdExist(AppIdentifier appIdentifier, String clientId) throws StorageQueryException { try { - return OAuthQueries.isClientIdForAppId(this, clientId, appIdentifier); + return OAuthQueries.doesOAuthClientIdExist(this, clientId, appIdentifier); } catch (SQLException e) { throw new StorageQueryException(e); } } @Override - public void addOrUpdateClientForApp(AppIdentifier appIdentifier, String clientId, boolean isClientCredentialsOnly) + public void addOrUpdateOauthClient(AppIdentifier appIdentifier, String clientId, boolean isClientCredentialsOnly) throws StorageQueryException { try { - OAuthQueries.insertClientIdForAppId(this, appIdentifier, clientId, isClientCredentialsOnly); + OAuthQueries.addOrUpdateOauthClient(this, appIdentifier, clientId, isClientCredentialsOnly); } catch (SQLException e) { throw new StorageQueryException(e); } } @Override - public boolean removeAppClientAssociation(AppIdentifier appIdentifier, String clientId) - throws StorageQueryException { + public boolean deleteOAuthClient(AppIdentifier appIdentifier, String clientId) throws StorageQueryException { try { - return OAuthQueries.deleteClientIdForAppId(this, clientId, appIdentifier); + return OAuthQueries.deleteOAuthClient(this, clientId, appIdentifier); } catch (SQLException e) { throw new StorageQueryException(e); } } @Override - public List listClientsForApp(AppIdentifier appIdentifier) throws StorageQueryException { + public List listOAuthClients(AppIdentifier appIdentifier) throws StorageQueryException { try { - return OAuthQueries.listClientsForApp(this, appIdentifier); + return OAuthQueries.listOAuthClients(this, appIdentifier); } catch (SQLException e) { throw new StorageQueryException(e); } } @Override - public void revoke(AppIdentifier appIdentifier, String targetType, String targetValue, long exp) + public void revokeOAuthTokensBasedOnTargetFields(AppIdentifier appIdentifier, OAuthRevokeTargetType targetType, String targetValue, long exp) throws StorageQueryException { try { - OAuthQueries.revoke(this, appIdentifier, targetType, targetValue, exp); + OAuthQueries.revokeOAuthTokensBasedOnTargetFields(this, appIdentifier, targetType, targetValue, exp); } catch (SQLException e) { throw new StorageQueryException(e); } + } @Override - public boolean isRevoked(AppIdentifier appIdentifier, String[] targetTypes, String[] targetValues, long issuedAt) + public boolean isOAuthTokenRevokedBasedOnTargetFields(AppIdentifier appIdentifier, OAuthRevokeTargetType[] targetTypes, String[] targetValues, long issuedAt) throws StorageQueryException { try { - return OAuthQueries.isRevoked(this, appIdentifier, targetTypes, targetValues, issuedAt); + return OAuthQueries.isOAuthTokenRevokedBasedOnTargetFields(this, appIdentifier, targetTypes, targetValues, issuedAt); } catch (SQLException e) { throw new StorageQueryException(e); } } @Override - public void addM2MToken(AppIdentifier appIdentifier, String clientId, long iat, long exp) + public void addOAuthM2MTokenForStats(AppIdentifier appIdentifier, String clientId, long iat, long exp) throws StorageQueryException { try { - OAuthQueries.addM2MToken(this, appIdentifier, clientId, iat, exp); + OAuthQueries.addOAuthM2MTokenForStats(this, appIdentifier, clientId, iat, exp); } catch (SQLException e) { throw new StorageQueryException(e); } } @Override - public void addLogoutChallenge(AppIdentifier appIdentifier, String challenge, String clientId, - String postLogoutRedirectionUri, String sessionHandle, String state, long timeCreated) throws StorageQueryException { + public void cleanUpExpiredAndRevokedOAuthTokensList() throws StorageQueryException { try { - OAuthQueries.addLogoutChallenge(this, appIdentifier, challenge, clientId, postLogoutRedirectionUri, sessionHandle, state, timeCreated); + OAuthQueries.cleanUpExpiredAndRevokedOAuthTokensList(this); } catch (SQLException e) { throw new StorageQueryException(e); } } @Override - public OAuthLogoutChallenge getLogoutChallenge(AppIdentifier appIdentifier, String challenge) - throws StorageQueryException { + public void addOAuthLogoutChallenge(AppIdentifier appIdentifier, String challenge, String clientId, + String postLogoutRedirectionUri, String sessionHandle, String state, long timeCreated) + throws StorageQueryException, DuplicateOAuthLogoutChallengeException { try { - return OAuthQueries.getLogoutChallenge(this, appIdentifier, challenge); + OAuthQueries.addOAuthLogoutChallenge(this, appIdentifier, challenge, clientId, postLogoutRedirectionUri, sessionHandle, state, timeCreated); } catch (SQLException e) { + PostgreSQLConfig config = Config.getConfig(this); + if (e instanceof PSQLException) { + ServerErrorMessage serverMessage = ((PSQLException) e).getServerErrorMessage(); + + if (isPrimaryKeyError(serverMessage, config.getOAuthLogoutChallengesTable())) { + throw new DuplicateOAuthLogoutChallengeException(); + } + } throw new StorageQueryException(e); } } @Override - public void deleteLogoutChallenge(AppIdentifier appIdentifier, String challenge) throws StorageQueryException { + public OAuthLogoutChallenge getOAuthLogoutChallenge(AppIdentifier appIdentifier, String challenge) throws StorageQueryException { try { - OAuthQueries.deleteLogoutChallenge(this, appIdentifier, challenge); + return OAuthQueries.getOAuthLogoutChallenge(this, appIdentifier, challenge); } catch (SQLException e) { throw new StorageQueryException(e); } } @Override - public void deleteLogoutChallengesBefore(AppIdentifier appIdentifier, long time) throws StorageQueryException { + public void deleteOAuthLogoutChallenge(AppIdentifier appIdentifier, String challenge) throws StorageQueryException { try { - OAuthQueries.deleteLogoutChallengesBefore(this, appIdentifier, time); + OAuthQueries.deleteOAuthLogoutChallenge(this, appIdentifier, challenge); } catch (SQLException e) { throw new StorageQueryException(e); } } @Override - public void cleanUpExpiredAndRevokedTokens(AppIdentifier appIdentifier) throws StorageQueryException { + public void deleteOAuthLogoutChallengesBefore(long time) throws StorageQueryException { try { - OAuthQueries.cleanUpExpiredAndRevokedTokens(this, appIdentifier); + OAuthQueries.deleteOAuthLogoutChallengesBefore(this, time); } catch (SQLException e) { throw new StorageQueryException(e); } } @Override - public int countTotalNumberOfM2MTokensAlive(AppIdentifier appIdentifier) throws StorageQueryException { + public int countTotalNumberOfOAuthClients(AppIdentifier appIdentifier) throws StorageQueryException { try { - return OAuthQueries.countTotalNumberOfM2MTokensAlive(this, appIdentifier); + return OAuthQueries.countTotalNumberOfClients(this, appIdentifier, false); } catch (SQLException e) { throw new StorageQueryException(e); } } @Override - public int countTotalNumberOfM2MTokensCreatedSince(AppIdentifier appIdentifier, long since) + public int countTotalNumberOfClientCredentialsOnlyOAuthClients(AppIdentifier appIdentifier) throws StorageQueryException { try { - return OAuthQueries.countTotalNumberOfM2MTokensCreatedSince(this, appIdentifier, since); + return OAuthQueries.countTotalNumberOfClients(this, appIdentifier, true); } catch (SQLException e) { throw new StorageQueryException(e); } } @Override - public int countTotalNumberOfClientCredentialsOnlyClientsForApp(AppIdentifier appIdentifier) + public int countTotalNumberOfOAuthM2MTokensCreatedSince(AppIdentifier appIdentifier, long since) throws StorageQueryException { try { - return OAuthQueries.countTotalNumberOfClientsForApp(this, appIdentifier, true); + return OAuthQueries.countTotalNumberOfOAuthM2MTokensCreatedSince(this, appIdentifier, since); } catch (SQLException e) { throw new StorageQueryException(e); } } @Override - public int countTotalNumberOfClientsForApp(AppIdentifier appIdentifier) throws StorageQueryException { + public int countTotalNumberOfOAuthM2MTokensAlive(AppIdentifier appIdentifier) throws StorageQueryException { try { - return OAuthQueries.countTotalNumberOfClientsForApp(this, appIdentifier, false); + return OAuthQueries.countTotalNumberOfOAuthM2MTokensAlive(this, appIdentifier); } catch (SQLException e) { throw new StorageQueryException(e); } diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/OAuthQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/OAuthQueries.java index df7af769..217c5dea 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/OAuthQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/OAuthQueries.java @@ -1,4 +1,20 @@ -package io.supertokens.storage.postgresql.queries; +/* + * Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + + package io.supertokens.storage.postgresql.queries; import java.sql.ResultSet; import java.sql.SQLException; @@ -8,6 +24,7 @@ import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.multitenancy.AppIdentifier; import io.supertokens.pluginInterface.oauth.OAuthLogoutChallenge; +import io.supertokens.pluginInterface.oauth.OAuthRevokeTargetType; import io.supertokens.storage.postgresql.Start; import io.supertokens.storage.postgresql.config.Config; import io.supertokens.storage.postgresql.utils.Utils; @@ -61,7 +78,7 @@ public static String getQueryToCreateOAuthRevokeTimestampIndex(Start start) { public static String getQueryToCreateOAuthRevokeExpIndex(Start start) { String oAuthRevokeTable = Config.getConfig(start).getOAuthRevokeTable(); return "CREATE INDEX IF NOT EXISTS oauth_revoke_exp_index ON " - + oAuthRevokeTable + "(exp DESC, app_id DESC);"; + + oAuthRevokeTable + "(exp DESC);"; } public static String getQueryToCreateOAuthM2MTokensTable(Start start) { @@ -91,7 +108,7 @@ public static String getQueryToCreateOAuthM2MTokenIatIndex(Start start) { public static String getQueryToCreateOAuthM2MTokenExpIndex(Start start) { String oAuthM2MTokensTable = Config.getConfig(start).getOAuthM2MTokensTable(); return "CREATE INDEX IF NOT EXISTS oauth_m2m_token_exp_index ON " - + oAuthM2MTokensTable + "(exp DESC, app_id DESC);"; + + oAuthM2MTokensTable + "(exp DESC);"; } public static String getQueryToCreateOAuthLogoutChallengesTable(Start start) { @@ -124,7 +141,7 @@ public static String getQueryToCreateOAuthLogoutChallengesTimeCreatedIndex(Start + oAuth2LogoutChallengesTable + "(time_created ASC, app_id ASC);"; } - public static boolean isClientIdForAppId(Start start, String clientId, AppIdentifier appIdentifier) + public static boolean doesOAuthClientIdExist(Start start, String clientId, AppIdentifier appIdentifier) throws SQLException, StorageQueryException { String QUERY = "SELECT app_id FROM " + Config.getConfig(start).getOAuthClientsTable() + " WHERE client_id = ? AND app_id = ?"; @@ -135,7 +152,7 @@ public static boolean isClientIdForAppId(Start start, String clientId, AppIdenti }, ResultSet::next); } - public static List listClientsForApp(Start start, AppIdentifier appIdentifier) + public static List listOAuthClients(Start start, AppIdentifier appIdentifier) throws SQLException, StorageQueryException { String QUERY = "SELECT client_id FROM " + Config.getConfig(start).getOAuthClientsTable() + " WHERE app_id = ?"; @@ -150,7 +167,7 @@ public static List listClientsForApp(Start start, AppIdentifier appIdent }); } - public static void insertClientIdForAppId(Start start, AppIdentifier appIdentifier, String clientId, + public static void addOrUpdateOauthClient(Start start, AppIdentifier appIdentifier, String clientId, boolean isClientCredentialsOnly) throws SQLException, StorageQueryException { String INSERT = "INSERT INTO " + Config.getConfig(start).getOAuthClientsTable() @@ -164,7 +181,7 @@ public static void insertClientIdForAppId(Start start, AppIdentifier appIdentifi }); } - public static boolean deleteClientIdForAppId(Start start, String clientId, AppIdentifier appIdentifier) + public static boolean deleteOAuthClient(Start start, String clientId, AppIdentifier appIdentifier) throws SQLException, StorageQueryException { String DELETE = "DELETE FROM " + Config.getConfig(start).getOAuthClientsTable() + " WHERE app_id = ? AND client_id = ?"; @@ -175,7 +192,7 @@ public static boolean deleteClientIdForAppId(Start start, String clientId, AppId return numberOfRow > 0; } - public static void revoke(Start start, AppIdentifier appIdentifier, String targetType, String targetValue, long exp) + public static void revokeOAuthTokensBasedOnTargetFields(Start start, AppIdentifier appIdentifier, OAuthRevokeTargetType targetType, String targetValue, long exp) throws SQLException, StorageQueryException { String INSERT = "INSERT INTO " + Config.getConfig(start).getOAuthRevokeTable() + "(app_id, target_type, target_value, timestamp, exp) VALUES (?, ?, ?, ?, ?) " @@ -184,7 +201,7 @@ public static void revoke(Start start, AppIdentifier appIdentifier, String targe long currentTime = System.currentTimeMillis() / 1000; update(start, INSERT, pst -> { pst.setString(1, appIdentifier.getAppId()); - pst.setString(2, targetType); + pst.setString(2, targetType.getValue()); pst.setString(3, targetValue); pst.setLong(4, currentTime); pst.setLong(5, exp); @@ -193,11 +210,11 @@ public static void revoke(Start start, AppIdentifier appIdentifier, String targe }); } - public static boolean isRevoked(Start start, AppIdentifier appIdentifier, String[] targetTypes, + public static boolean isOAuthTokenRevokedBasedOnTargetFields(Start start, AppIdentifier appIdentifier, OAuthRevokeTargetType[] targetTypes, String[] targetValues, long issuedAt) throws SQLException, StorageQueryException { String QUERY = "SELECT app_id FROM " + Config.getConfig(start).getOAuthRevokeTable() + - " WHERE app_id = ? AND timestamp > ? AND ("; + " WHERE app_id = ? AND timestamp >= ? AND ("; for (int i = 0; i < targetTypes.length; i++) { QUERY += "(target_type = ? AND target_value = ?)"; @@ -215,7 +232,7 @@ public static boolean isRevoked(Start start, AppIdentifier appIdentifier, String int index = 3; for (int i = 0; i < targetTypes.length; i++) { - pst.setString(index, targetTypes[i]); + pst.setString(index, targetTypes[i].getValue()); index++; pst.setString(index, targetValues[i]); index++; @@ -223,7 +240,7 @@ public static boolean isRevoked(Start start, AppIdentifier appIdentifier, String }, ResultSet::next); } - public static int countTotalNumberOfClientsForApp(Start start, AppIdentifier appIdentifier, + public static int countTotalNumberOfClients(Start start, AppIdentifier appIdentifier, boolean filterByClientCredentialsOnly) throws SQLException, StorageQueryException { if (filterByClientCredentialsOnly) { String QUERY = "SELECT COUNT(*) as c FROM " + Config.getConfig(start).getOAuthClientsTable() + @@ -251,7 +268,7 @@ public static int countTotalNumberOfClientsForApp(Start start, AppIdentifier app } } - public static int countTotalNumberOfM2MTokensAlive(Start start, AppIdentifier appIdentifier) + public static int countTotalNumberOfOAuthM2MTokensAlive(Start start, AppIdentifier appIdentifier) throws SQLException, StorageQueryException { String QUERY = "SELECT COUNT(*) as c FROM " + Config.getConfig(start).getOAuthM2MTokensTable() + " WHERE app_id = ? AND exp > ?"; @@ -266,7 +283,7 @@ public static int countTotalNumberOfM2MTokensAlive(Start start, AppIdentifier ap }); } - public static int countTotalNumberOfM2MTokensCreatedSince(Start start, AppIdentifier appIdentifier, long since) + public static int countTotalNumberOfOAuthM2MTokensCreatedSince(Start start, AppIdentifier appIdentifier, long since) throws SQLException, StorageQueryException { String QUERY = "SELECT COUNT(*) as c FROM " + Config.getConfig(start).getOAuthM2MTokensTable() + " WHERE app_id = ? AND iat >= ?"; @@ -281,7 +298,7 @@ public static int countTotalNumberOfM2MTokensCreatedSince(Start start, AppIdenti }); } - public static void addM2MToken(Start start, AppIdentifier appIdentifier, String clientId, long iat, long exp) + public static void addOAuthM2MTokenForStats(Start start, AppIdentifier appIdentifier, String clientId, long iat, long exp) throws SQLException, StorageQueryException { String QUERY = "INSERT INTO " + Config.getConfig(start).getOAuthM2MTokensTable() + " (app_id, client_id, iat, exp) VALUES (?, ?, ?, ?)"; @@ -293,33 +310,31 @@ public static void addM2MToken(Start start, AppIdentifier appIdentifier, String }); } - public static void cleanUpExpiredAndRevokedTokens(Start start, AppIdentifier appIdentifier) throws SQLException, StorageQueryException { + public static void cleanUpExpiredAndRevokedOAuthTokensList(Start start) throws SQLException, StorageQueryException { { // delete expired M2M tokens String QUERY = "DELETE FROM " + Config.getConfig(start).getOAuthM2MTokensTable() + - " WHERE app_id = ? AND exp < ?"; + " WHERE exp < ?"; long timestamp = System.currentTimeMillis() / 1000 - 3600 * 24 * 31; // expired 31 days ago update(start, QUERY, pst -> { - pst.setString(1, appIdentifier.getAppId()); - pst.setLong(2, timestamp); + pst.setLong(1, timestamp); }); } { // delete expired revoked tokens String QUERY = "DELETE FROM " + Config.getConfig(start).getOAuthRevokeTable() + - " WHERE app_id = ? AND exp < ?"; + " WHERE exp < ?"; long timestamp = System.currentTimeMillis() / 1000 - 3600 * 24 * 31; // expired 31 days ago update(start, QUERY, pst -> { - pst.setString(1, appIdentifier.getAppId()); - pst.setLong(2, timestamp); + pst.setLong(1, timestamp); }); } } - public static void addLogoutChallenge(Start start, AppIdentifier appIdentifier, String challenge, String clientId, + public static void addOAuthLogoutChallenge(Start start, AppIdentifier appIdentifier, String challenge, String clientId, String postLogoutRedirectionUri, String sessionHandle, String state, long timeCreated) throws SQLException, StorageQueryException { String QUERY = "INSERT INTO " + Config.getConfig(start).getOAuthLogoutChallengesTable() + " (app_id, challenge, client_id, post_logout_redirect_uri, session_handle, state, time_created) VALUES (?, ?, ?, ?, ?, ?, ?)"; @@ -334,7 +349,7 @@ public static void addLogoutChallenge(Start start, AppIdentifier appIdentifier, }); } - public static OAuthLogoutChallenge getLogoutChallenge(Start start, AppIdentifier appIdentifier, String challenge) throws SQLException, StorageQueryException { + public static OAuthLogoutChallenge getOAuthLogoutChallenge(Start start, AppIdentifier appIdentifier, String challenge) throws SQLException, StorageQueryException { String QUERY = "SELECT challenge, client_id, post_logout_redirect_uri, session_handle, state, time_created FROM " + Config.getConfig(start).getOAuthLogoutChallengesTable() + " WHERE app_id = ? AND challenge = ?"; @@ -357,7 +372,7 @@ public static OAuthLogoutChallenge getLogoutChallenge(Start start, AppIdentifier }); } - public static void deleteLogoutChallenge(Start start, AppIdentifier appIdentifier, String challenge) throws SQLException, StorageQueryException { + public static void deleteOAuthLogoutChallenge(Start start, AppIdentifier appIdentifier, String challenge) throws SQLException, StorageQueryException { String QUERY = "DELETE FROM " + Config.getConfig(start).getOAuthLogoutChallengesTable() + " WHERE app_id = ? AND challenge = ?"; update(start, QUERY, pst -> { @@ -366,12 +381,11 @@ public static void deleteLogoutChallenge(Start start, AppIdentifier appIdentifie }); } - public static void deleteLogoutChallengesBefore(Start start, AppIdentifier appIdentifier, long time) throws SQLException, StorageQueryException { + public static void deleteOAuthLogoutChallengesBefore(Start start, long time) throws SQLException, StorageQueryException { String QUERY = "DELETE FROM " + Config.getConfig(start).getOAuthLogoutChallengesTable() + - " WHERE app_id = ? AND time_created < ?"; + " WHERE time_created < ?"; update(start, QUERY, pst -> { - pst.setString(1, appIdentifier.getAppId()); - pst.setLong(2, time); + pst.setLong(1, time); }); } }