Skip to content

Commit

Permalink
fix: logout apis
Browse files Browse the repository at this point in the history
  • Loading branch information
sattvikc committed Sep 26, 2024
1 parent 2ea8d10 commit ef2e7fb
Show file tree
Hide file tree
Showing 12 changed files with 300 additions and 152 deletions.
4 changes: 2 additions & 2 deletions src/main/java/io/supertokens/Main.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import io.supertokens.config.Config;
import io.supertokens.config.CoreConfig;
import io.supertokens.cronjobs.Cronjobs;
import io.supertokens.cronjobs.cleanupOAuthRevokeList.CleanupOAuthRevokeList;
import io.supertokens.cronjobs.cleanupOAuthRevokeList.CleanupOAuthRevokeListAndChallenges;
import io.supertokens.cronjobs.deleteExpiredAccessTokenSigningKeys.DeleteExpiredAccessTokenSigningKeys;
import io.supertokens.cronjobs.deleteExpiredDashboardSessions.DeleteExpiredDashboardSessions;
import io.supertokens.cronjobs.deleteExpiredEmailVerificationTokens.DeleteExpiredEmailVerificationTokens;
Expand Down Expand Up @@ -257,7 +257,7 @@ private void init() throws IOException, StorageQueryException {
// starts DeleteExpiredAccessTokenSigningKeys cronjob if the access token signing keys can change
Cronjobs.addCronjob(this, DeleteExpiredAccessTokenSigningKeys.init(this, uniqueUserPoolIdsTenants));

Cronjobs.addCronjob(this, CleanupOAuthRevokeList.init(this, uniqueUserPoolIdsTenants));
Cronjobs.addCronjob(this, CleanupOAuthRevokeListAndChallenges.init(this, uniqueUserPoolIdsTenants));

// this is to ensure tenantInfos are in sync for the new cron job as well
MultitenancyHelper.getInstance(this).refreshCronjobs();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,27 @@
import io.supertokens.pluginInterface.oauth.OAuthStorage;
import io.supertokens.storageLayer.StorageLayer;

public class CleanupOAuthRevokeList extends CronTask {
public class CleanupOAuthRevokeListAndChallenges extends CronTask {

public static final String RESOURCE_KEY = "io.supertokens.cronjobs.cleanupOAuthRevokeList" +
".CleanupOAuthRevokeList";

private CleanupOAuthRevokeList(Main main, List<List<TenantIdentifier>> tenantsInfo) {
private CleanupOAuthRevokeListAndChallenges(Main main, List<List<TenantIdentifier>> tenantsInfo) {
super("CleanupOAuthRevokeList", main, tenantsInfo, true);
}

public static CleanupOAuthRevokeList init(Main main, List<List<TenantIdentifier>> tenantsInfo) {
return (CleanupOAuthRevokeList) main.getResourceDistributor()
public static CleanupOAuthRevokeListAndChallenges init(Main main, List<List<TenantIdentifier>> tenantsInfo) {
return (CleanupOAuthRevokeListAndChallenges) main.getResourceDistributor()
.setResource(new TenantIdentifier(null, null, null), RESOURCE_KEY,
new CleanupOAuthRevokeList(main, tenantsInfo));
new CleanupOAuthRevokeListAndChallenges(main, tenantsInfo));
}

@Override
protected void doTaskPerApp(AppIdentifier app) throws Exception {
Storage storage = StorageLayer.getStorage(app.getAsPublicTenantIdentifier(), main);
OAuthStorage oauthStorage = StorageUtils.getOAuthStorage(storage);
oauthStorage.cleanUpExpiredAndRevokedTokens(app);
oauthStorage.deleteLogoutChallengesBefore(app, System.currentTimeMillis() - 1000 * 60 * 60 * 48);
}

@Override
Expand Down
38 changes: 38 additions & 0 deletions src/main/java/io/supertokens/inmemorydb/Start.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import io.supertokens.pluginInterface.multitenancy.exceptions.DuplicateThirdPartyIdException;
import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException;
import io.supertokens.pluginInterface.multitenancy.sqlStorage.MultitenancySQLStorage;
import io.supertokens.pluginInterface.oauth.OAuthLogoutChallenge;
import io.supertokens.pluginInterface.oauth.sqlStorage.OAuthSQLStorage;
import io.supertokens.pluginInterface.passwordless.PasswordlessCode;
import io.supertokens.pluginInterface.passwordless.PasswordlessDevice;
Expand Down Expand Up @@ -3077,6 +3078,43 @@ public void addM2MToken(AppIdentifier appIdentifier, String clientId, long iat,
}
}

@Override
public void addLogoutChallenge(AppIdentifier appIdentifier, String challenge, String clientId,
String postLogoutRedirectionUri, String state, long timeCreated) throws StorageQueryException {
try {
OAuthQueries.addLogoutChallenge(this, appIdentifier, challenge, clientId, postLogoutRedirectionUri, state, timeCreated);
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public OAuthLogoutChallenge getLogoutChallenge(AppIdentifier appIdentifier, String challenge) throws StorageQueryException {
try {
return OAuthQueries.getLogoutChallenge(this, appIdentifier, challenge);
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public void deleteLogoutChallenge(AppIdentifier appIdentifier, String challenge) throws StorageQueryException {
try {
OAuthQueries.deleteLogoutChallenge(this, appIdentifier, challenge);
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public void deleteLogoutChallengesBefore(AppIdentifier appIdentifier, long time) throws StorageQueryException {
try {
OAuthQueries.deleteLogoutChallengesBefore(this, appIdentifier, time);
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public void cleanUpExpiredAndRevokedTokens(AppIdentifier appIdentifier) throws StorageQueryException {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,4 +176,8 @@ public String getOAuthRevokeTable() {
public String getOAuthM2MTokensTable() {
return "oauth_m2m_tokens";
}

public String getOAuthLogoutChallengesTable() {
return "oauth_logout_challenges";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,15 @@ public static void createTablesIfNotExists(Start start, Main main) throws SQLExc
update(start, OAuthQueries.getQueryToCreateOAuthM2MTokenIatIndex(start), NO_OP_SETTER);
update(start, OAuthQueries.getQueryToCreateOAuthM2MTokenExpIndex(start), NO_OP_SETTER);
}
}

if (!doesTableExists(start, Config.getConfig(start).getOAuthLogoutChallengesTable())) {
getInstance(main).addState(CREATING_NEW_TABLE, null);
update(start, OAuthQueries.getQueryToCreateOAuthLogoutChallengesTable(start), NO_OP_SETTER);

// index
update(start, OAuthQueries.getQueryToCreateOAuthLogoutChallengesTimeCreatedIndex(start), NO_OP_SETTER);
}
}

public static void setKeyValue_Transaction(Start start, Connection con, TenantIdentifier tenantIdentifier,
String key, KeyValueInfo info)
Expand Down
82 changes: 82 additions & 0 deletions src/main/java/io/supertokens/inmemorydb/queries/OAuthQueries.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.supertokens.inmemorydb.config.Config;
import io.supertokens.pluginInterface.exceptions.StorageQueryException;
import io.supertokens.pluginInterface.multitenancy.AppIdentifier;
import io.supertokens.pluginInterface.oauth.OAuthLogoutChallenge;

import java.sql.ResultSet;
import java.sql.SQLException;
Expand Down Expand Up @@ -92,6 +93,32 @@ public static String getQueryToCreateOAuthM2MTokenExpIndex(Start start) {
+ oAuth2M2MTokensTable + "(exp DESC, app_id DESC);";
}

public static String getQueryToCreateOAuthLogoutChallengesTable(Start start) {
String oAuth2LogoutChallengesTable = Config.getConfig(start).getOAuthLogoutChallengesTable();
// @formatter:off
return "CREATE TABLE IF NOT EXISTS " + oAuth2LogoutChallengesTable + " ("
+ "app_id VARCHAR(64) DEFAULT 'public',"
+ "challenge VARCHAR(128) NOT NULL,"
+ "client_id VARCHAR(128) NOT NULL,"
+ "post_logout_redirect_uri VARCHAR(1024),"
+ "gid VARCHAR(128),"
+ "state VARCHAR(128),"
+ "time_created BIGINT NOT NULL,"
+ "PRIMARY KEY (app_id, challenge),"
+ "FOREIGN KEY(app_id, client_id)"
+ " REFERENCES " + Config.getConfig(start).getOAuthClientsTable() + "(app_id, client_id) ON DELETE CASCADE,"
+ "FOREIGN KEY(app_id)"
+ " REFERENCES " + Config.getConfig(start).getAppsTable() + "(app_id) ON DELETE CASCADE"
+ ");";
// @formatter:on
}

public static String getQueryToCreateOAuthLogoutChallengesTimeCreatedIndex(Start start) {
String oAuth2LogoutChallengesTable = Config.getConfig(start).getOAuthLogoutChallengesTable();
return "CREATE INDEX IF NOT EXISTS oauth_logout_challenges_time_created_index ON "
+ oAuth2LogoutChallengesTable + "(time_created ASC, app_id ASC);";
}

public static boolean isClientIdForAppId(Start start, String clientId, AppIdentifier appIdentifier)
throws SQLException, StorageQueryException {
String QUERY = "SELECT app_id FROM " + Config.getConfig(start).getOAuthClientsTable() +
Expand Down Expand Up @@ -285,4 +312,59 @@ public static void cleanUpExpiredAndRevokedTokens(Start start, AppIdentifier app
});
}
}

public static void addLogoutChallenge(Start start, AppIdentifier appIdentifier, String challenge, String clientId,
String postLogoutRedirectionUri, String state, long timeCreated) throws SQLException, StorageQueryException {
String QUERY = "INSERT INTO " + Config.getConfig(start).getOAuthLogoutChallengesTable() +
" (app_id, challenge, client_id, post_logout_redirect_uri, state, time_created) VALUES (?, ?, ?, ?, ?, ?)";
update(start, QUERY, pst -> {
pst.setString(1, appIdentifier.getAppId());
pst.setString(2, challenge);
pst.setString(3, clientId);
pst.setString(4, postLogoutRedirectionUri);
pst.setString(5, state);
pst.setLong(6, timeCreated);
});
}

public static OAuthLogoutChallenge getLogoutChallenge(Start start, AppIdentifier appIdentifier, String challenge) throws SQLException, StorageQueryException {
String QUERY = "SELECT challenge, client_id, post_logout_redirect_uri, gid, state, time_created FROM " +
Config.getConfig(start).getOAuthLogoutChallengesTable() +
" WHERE app_id = ? AND challenge = ?";

return execute(start, QUERY, pst -> {
pst.setString(1, appIdentifier.getAppId());
pst.setString(2, challenge);
}, result -> {
if (result.next()) {
return new OAuthLogoutChallenge(
result.getString("challenge"),
result.getString("client_id"),
result.getString("post_logout_redirect_uri"),
result.getString("gid"),
result.getString("state"),
result.getLong("time_created")
);
}
return null;
});
}

public static void deleteLogoutChallenge(Start start, AppIdentifier appIdentifier, String challenge) throws SQLException, StorageQueryException {
String QUERY = "DELETE FROM " + Config.getConfig(start).getOAuthLogoutChallengesTable() +
" WHERE app_id = ? AND challenge = ?";
update(start, QUERY, pst -> {
pst.setString(1, appIdentifier.getAppId());
pst.setString(2, challenge);
});
}

public static void deleteLogoutChallengesBefore(Start start, AppIdentifier appIdentifier, long time) throws SQLException, StorageQueryException {
String QUERY = "DELETE FROM " + Config.getConfig(start).getOAuthLogoutChallengesTable() +
" WHERE app_id = ? AND time_created < ?";
update(start, QUERY, pst -> {
pst.setString(1, appIdentifier.getAppId());
pst.setLong(2, time);
});
}
}
74 changes: 45 additions & 29 deletions src/main/java/io/supertokens/oauth/OAuth.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException;
import io.supertokens.pluginInterface.multitenancy.AppIdentifier;
import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException;
import io.supertokens.pluginInterface.oauth.OAuthLogoutChallenge;
import io.supertokens.pluginInterface.oauth.OAuthStorage;
import io.supertokens.session.jwt.JWT.JWTException;
import io.supertokens.utils.Utils;
Expand Down Expand Up @@ -330,6 +331,8 @@ public static String transformTokensInAuthRedirect(Main main, AppIdentifier appI
public static JsonObject transformTokens(Main main, AppIdentifier appIdentifier, Storage storage, JsonObject jsonBody, String iss, JsonObject accessTokenUpdate, JsonObject idTokenUpdate, boolean useDynamicKey) throws IOException, JWTException, InvalidKeyException, NoSuchAlgorithmException, StorageQueryException, StorageTransactionLogicException, UnsupportedJWTSigningAlgorithmException, TenantOrAppNotFoundException, InvalidKeySpecException, JWTCreationException, InvalidConfigException {
String atHash = null;

System.out.println("transformTokens: " + jsonBody.toString());

if (jsonBody.has("refresh_token")) {
String refreshToken = jsonBody.get("refresh_token").getAsString();
refreshToken = refreshToken.replace("ory_rt_", "st_rt_");
Expand Down Expand Up @@ -551,45 +554,58 @@ public static void revokeSessionHandle(Main main, AppIdentifier appIdentifier, S
oauthStorage.revoke(appIdentifier, "session_handle", sessionHandle, exp);
}

public static void verifyIdTokenHintClientIdAndUpdateQueryParamsForLogout(Main main, AppIdentifier appIdentifier, Storage storage,
Map<String, String> queryParams) throws StorageQueryException, OAuthAPIException, TenantOrAppNotFoundException, UnsupportedJWTSigningAlgorithmException, StorageTransactionLogicException {
public static JsonObject verifyIdTokenAndGetPayload(Main main, AppIdentifier appIdentifier, Storage storage,
String idToken) throws StorageQueryException, OAuthAPIException, TenantOrAppNotFoundException, UnsupportedJWTSigningAlgorithmException, StorageTransactionLogicException {
try {
return OAuthToken.getPayloadFromJWTToken(appIdentifier, main, idToken);
} catch (TryRefreshTokenException e) {
// invalid id token
throw new OAuthAPIException("invalid_request", "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed.", 400);
}
}

String idTokenHint = queryParams.get("idTokenHint");
String clientId = queryParams.get("clientId");
public static void addM2MToken(Main main, AppIdentifier appIdentifier, Storage storage, String accessToken) throws StorageQueryException, TenantOrAppNotFoundException, TryRefreshTokenException, UnsupportedJWTSigningAlgorithmException, StorageTransactionLogicException {
OAuthStorage oauthStorage = StorageUtils.getOAuthStorage(storage);
JsonObject payload = OAuthToken.getPayloadFromJWTToken(appIdentifier, main, accessToken);
oauthStorage.addM2MToken(appIdentifier, payload.get("client_id").getAsString(), payload.get("iat").getAsLong(), payload.get("exp").getAsLong());
}

JsonObject idTokenPayload = null;
if (idTokenHint != null) {
queryParams.remove("idTokenHint");
public static String createLogoutRequestAndReturnRedirectUri(Main main, AppIdentifier appIdentifier, Storage storage, String clientId,
String postLogoutRedirectionUri, String state, String idTokenHint) throws StorageQueryException {

OAuthStorage oauthStorage = StorageUtils.getOAuthStorage(storage);

try {
idTokenPayload = OAuthToken.getPayloadFromJWTToken(appIdentifier, main, idTokenHint);
} catch (TryRefreshTokenException e) {
// invalid id token
throw new OAuthAPIException("invalid_request", "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed.", 400);
}
}
String logoutChallenge = UUID.randomUUID().toString();
oauthStorage.addLogoutChallenge(appIdentifier, logoutChallenge, clientId, postLogoutRedirectionUri, state, System.currentTimeMillis());
return "{apiDomain}/oauth/logout?logout_challenge=" + logoutChallenge;
}

if (idTokenPayload != null) {
if (!idTokenPayload.has("stt") || idTokenPayload.get("stt").getAsInt() != OAuthToken.TokenType.ID_TOKEN.getValue()) {
// Invalid id token
throw new OAuthAPIException("invalid_request", "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed.", 400);
}
public static String consumeLogoutChallengeAndGetRedirectUri(Main main, AppIdentifier appIdentifier, Storage storage, String challenge) throws StorageQueryException, OAuthAPIException {
OAuthStorage oauthStorage = StorageUtils.getOAuthStorage(storage);
OAuthLogoutChallenge logoutChallenge = oauthStorage.getLogoutChallenge(appIdentifier, challenge);

String clientIdInIdTokenPayload = idTokenPayload.get("aud").getAsString();
if (logoutChallenge == null) {
throw new OAuthAPIException("invalid_request", "Logout request not found", 400);
}

if (clientId != null) {
if (!clientId.equals(clientIdInIdTokenPayload)) {
throw new OAuthAPIException("invalid_request", "The client_id in the id_token_hint does not match the client_id in the request.", 400);
}
}
oauthStorage.revoke(appIdentifier, "gid", logoutChallenge.gid, 3600 * 24 * (183 + 31));

String url = null;
if (logoutChallenge.postLogoutRedirectionUri != null) {
url = logoutChallenge.postLogoutRedirectionUri;
} else {
url = "{apiDomain}/fallbacks/logout/callback";
}

queryParams.put("clientId", clientIdInIdTokenPayload);
if (logoutChallenge.state != null) {
return url + "?state=" + logoutChallenge.state;
} else {
return url;
}
}

public static void addM2MToken(Main main, AppIdentifier appIdentifier, Storage storage, String accessToken) throws StorageQueryException, TenantOrAppNotFoundException, TryRefreshTokenException, UnsupportedJWTSigningAlgorithmException, StorageTransactionLogicException {
public static void deleteLogoutChallenge(Main main, AppIdentifier appIdentifier, Storage storage, String challenge) throws StorageQueryException {
OAuthStorage oauthStorage = StorageUtils.getOAuthStorage(storage);
JsonObject payload = OAuthToken.getPayloadFromJWTToken(appIdentifier, main, accessToken);
oauthStorage.addM2MToken(appIdentifier, payload.get("client_id").getAsString(), payload.get("iat").getAsLong(), payload.get("exp").getAsLong());
oauthStorage.deleteLogoutChallenge(appIdentifier, challenge);
}
}
1 change: 0 additions & 1 deletion src/main/java/io/supertokens/webserver/Webserver.java
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,6 @@ private void setupRoutes() {
addAPI(new OAuthGetAuthLoginRequestAPI(main));
addAPI(new OAuthAcceptAuthLoginRequestAPI(main));
addAPI(new OAuthRejectAuthLoginRequestAPI(main));
addAPI(new OAuthGetAuthLogoutRequestAPI(main));
addAPI(new OAuthAcceptAuthLogoutRequestAPI(main));
addAPI(new OAuthRejectAuthLogoutRequestAPI(main));
addAPI(new OAuthTokenIntrospectAPI(main));
Expand Down
Loading

0 comments on commit ef2e7fb

Please sign in to comment.