Skip to content

Commit

Permalink
fix: various fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
tamassoltesz committed Oct 23, 2024
1 parent 8f9c59c commit 11229d9
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 92 deletions.
31 changes: 13 additions & 18 deletions src/main/java/io/supertokens/inmemorydb/queries/OAuthQueries.java
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public static void createOrUpdateOAuthSession(Start start, AppIdentifier appIden
List<String> jtis, long exp)
throws SQLException, StorageQueryException {
String QUERY = "INSERT INTO " + Config.getConfig(start).getOAuthSessionsTable() +
" (gid, client_id, app_id, external_refresh_token, internal_refresh_token, session_handle, jti, exp) VALUES (?, ?, ?, ?, ?, ?) " +
" (gid, client_id, app_id, external_refresh_token, internal_refresh_token, session_handle, jti, exp) VALUES (?, ?, ?, ?, ?, ?, ?, ?) " +
"ON CONFLICT (gid) DO UPDATE SET external_refresh_token = ?, internal_refresh_token = ?, " +
"session_handle = ? , jti = ?, exp = ?";
update(start, QUERY, pst -> {
Expand Down Expand Up @@ -251,22 +251,17 @@ public static boolean deleteOAuthSessionBySessionHandle(Start start, AppIdentifi

public static boolean deleteJTIFromOAuthSession(Start start, AppIdentifier appIdentifier, String gid, String jti)
throws SQLException, StorageQueryException {
//jti is a coma separated list. When deleting a jti, just have to delete from the list
//jti is a comma separated list. When deleting a jti, just have to delete from the list
List<String> savedJTIs = getOAuthJTIsByGID(start, appIdentifier, gid);
List<String> toSaveJTIs = new ArrayList<>(savedJTIs);
boolean deletionHappened = false;
if (savedJTIs != null && savedJTIs.contains(jti)){
savedJTIs.remove(jti);
deletionHappened = updateOAuthJTIsByGID(start, appIdentifier, gid, savedJTIs) > 0;
if (toSaveJTIs != null && toSaveJTIs.contains(jti)){
toSaveJTIs.remove(jti);
deletionHappened = updateOAuthJTIsByGID(start, appIdentifier, gid, toSaveJTIs) > 0;
}
return deletionHappened;
}

// public static boolean isOAuthTokenRevokedBasedOnTargetFields(Start start, AppIdentifier appIdentifier, OAuthRevokeTargetType[] targetTypes, String[] targetValues, long issuedAt)
// throws SQLException, StorageQueryException {
// String oAuth2RefreshTokenMappingTable = Config.getConfig(start).getOAuthRefreshTokenMappingTable();
//
// }

public static int countTotalNumberOfClients(Start start, AppIdentifier appIdentifier,
boolean filterByClientCredentialsOnly) throws SQLException, StorageQueryException {
if (filterByClientCredentialsOnly) {
Expand Down Expand Up @@ -421,14 +416,14 @@ public static void deleteRefreshTokenMapping(Start start, AppIdentifier appIdent

public static List<String> getOAuthJTIsByGID(Start start, AppIdentifier appIdentifier, String gid)
throws SQLException, StorageQueryException {
String SELECT = "SELECT jit FROM " + Config.getConfig(start).getOAuthSessionsTable() +
String SELECT = "SELECT jti FROM " + Config.getConfig(start).getOAuthSessionsTable() +
" WHERE app_id = ? AND gid = ?";
return execute(start, SELECT, pst -> {
pst.setString(1, appIdentifier.getAppId());
pst.setString(2, gid);
}, result -> {
if (result.next()) {
return List.of(result.getString("jit").split(","));
return List.of(result.getString("jti").split(","));
}
return null;
});
Expand All @@ -437,7 +432,7 @@ public static List<String> getOAuthJTIsByGID(Start start, AppIdentifier appIdent
public static int updateOAuthJTIsByGID(Start start, AppIdentifier appIdentifier, String gid, List<String> jtis)
throws SQLException, StorageQueryException {
String UPDATE = "UPDATE " + Config.getConfig(start).getOAuthSessionsTable() +
" SET jit = ? WHERE app_id = ? AND gid = ?";
" SET jti = ? WHERE app_id = ? AND gid = ?";
return update(start, UPDATE, pst -> {
pst.setString(1, String.join(",", jtis));
pst.setString(2, appIdentifier.getAppId());
Expand Down Expand Up @@ -474,7 +469,7 @@ public static boolean isOAuthSessionExistsByGID(Start start, AppIdentifier appId
pst.setString(2, gid);
}, result -> {
if(result.next()){
return result.getInt(0) > 0;
return result.getInt(1) > 0;
}
return false;
});
Expand All @@ -489,7 +484,7 @@ public static boolean isOAuthSessionExistsByClientId(Start start, AppIdentifier
pst.setString(2, clientId);
}, result -> {
if(result.next()){
return result.getInt(0) > 0;
return result.getInt(1) > 0;
}
return false;
});
Expand All @@ -504,7 +499,7 @@ public static boolean isOAuthSessionExistsBySessionHandle(Start start, AppIdenti
pst.setString(2, sessionHandle);
}, result -> {
if(result.next()){
return result.getInt(0) > 0;
return result.getInt(1) > 0;
}
return false;
});
Expand All @@ -519,7 +514,7 @@ public static boolean isOAuthSessionExistsByJTI(Start start, AppIdentifier appId
pst.setString(2, gid);
}, result -> {
if(result.next()){
List<String> jtis = List.of(result.getString(0).split(","));
List<String> jtis = List.of(result.getString(1).split(","));
return jtis.contains(jti);
}
return false;
Expand Down
12 changes: 6 additions & 6 deletions src/main/java/io/supertokens/oauth/OAuth.java
Original file line number Diff line number Diff line change
Expand Up @@ -654,14 +654,14 @@ public static OAuthClient getOAuthClientById(Main main, AppIdentifier appIdentif
return client;
}

public static String getOAuthProviderRefreshToken(Main main, AppIdentifier appIdentifier, Storage storage,
String refreshToken) throws StorageQueryException {
public static String getInternalRefreshToken(Main main, AppIdentifier appIdentifier, Storage storage,
String externalRefreshToken) throws StorageQueryException {
OAuthStorage oauthStorage = StorageUtils.getOAuthStorage(storage);
String opRefreshToken = oauthStorage.getRefreshTokenMapping(appIdentifier, refreshToken);
if (opRefreshToken == null) {
return refreshToken;
String internalRefreshToken = oauthStorage.getRefreshTokenMapping(appIdentifier, externalRefreshToken);
if (internalRefreshToken == null) {
return externalRefreshToken;
}
return opRefreshToken;
return internalRefreshToken;
}

public static void createOrUpdateRefreshTokenMapping(Main main, AppIdentifier appIdentifier, Storage storage,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I
String refreshToken = InputParser.parseStringOrThrowError(bodyFromSDK, "refresh_token", false);
inputRefreshToken = refreshToken;

String internalRefreshToken = OAuth.getOAuthProviderRefreshToken(main, appIdentifier, storage, refreshToken);
String internalRefreshToken = OAuth.getInternalRefreshToken(main, appIdentifier, storage, refreshToken);

Map<String, String> formFieldsForTokenIntrospect = new HashMap<>();
formFieldsForTokenIntrospect.put("token", internalRefreshToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I
AppIdentifier appIdentifier = getAppIdentifier(req);
Storage storage = enforcePublicTenantAndGetPublicTenantStorage(req);

token = OAuth.getOAuthProviderRefreshToken(main, appIdentifier, storage, token);
token = OAuth.getInternalRefreshToken(main, appIdentifier, storage, token);
formFields.put("token", token);

HttpRequestForOAuthProvider.Response response = OAuthProxyHelper.proxyFormPOST(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I
Storage storage = enforcePublicTenantAndGetPublicTenantStorage(req);

if (token.startsWith("st_rt_")) {
token = OAuth.getOAuthProviderRefreshToken(main, appIdentifier, storage, token);
token = OAuth.getInternalRefreshToken(main, appIdentifier, storage, token);

String gid = null;
long exp = -1;
Expand Down
98 changes: 33 additions & 65 deletions src/test/java/io/supertokens/test/oauth/OAuthStorageTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException;
import io.supertokens.pluginInterface.oauth.OAuthClient;
import io.supertokens.pluginInterface.oauth.OAuthLogoutChallenge;
import io.supertokens.pluginInterface.oauth.OAuthRevokeTargetType;
import io.supertokens.pluginInterface.oauth.OAuthStorage;
import io.supertokens.pluginInterface.oauth.exception.DuplicateOAuthLogoutChallengeException;
import io.supertokens.pluginInterface.oauth.exception.OAuthClientNotFoundException;
Expand Down Expand Up @@ -177,65 +176,35 @@ public void testRevoke() throws Exception {

AppIdentifier appIdentifier = new AppIdentifier(null, null);

storage.revokeOAuthTokensBasedOnTargetFields(appIdentifier, OAuthRevokeTargetType.GID, "abcd", System.currentTimeMillis()/1000 + 2 - 3600 * 24 * 31);
storage.revokeOAuthTokensBasedOnTargetFields(appIdentifier, OAuthRevokeTargetType.SESSION_HANDLE, "efgh", System.currentTimeMillis()/1000 + 2 - 3600 * 24 * 31);
storage.revokeOAuthTokensBasedOnTargetFields(appIdentifier, OAuthRevokeTargetType.JTI, "ijkl", System.currentTimeMillis()/1000 + 2 - 3600 * 24 * 31);

assertTrue(storage.isOAuthTokenRevokedBasedOnTargetFields(
appIdentifier,
new OAuthRevokeTargetType[]{OAuthRevokeTargetType.GID},
new String[]{"abcd"},
System.currentTimeMillis()/1000 - 2
));
assertFalse(storage.isOAuthTokenRevokedBasedOnTargetFields(
appIdentifier,
new OAuthRevokeTargetType[]{OAuthRevokeTargetType.GID},
new String[]{"efgh"},
System.currentTimeMillis()/1000 - 2
));
assertTrue(storage.isOAuthTokenRevokedBasedOnTargetFields(
appIdentifier,
new OAuthRevokeTargetType[]{OAuthRevokeTargetType.GID, OAuthRevokeTargetType.SESSION_HANDLE},
new String[]{"efgh", "efgh"},
System.currentTimeMillis()/1000 - 2
));
storage.addOrUpdateOauthClient(appIdentifier, "clientid", "clientSecret", false, true);
storage.createOrUpdateOAuthSession(appIdentifier, "abcd", "clientid", "externalRefreshToken",
"internalRefreshToken", "efgh", List.of("ijkl", "mnop"), System.currentTimeMillis() + 1000 * 60 * 60 * 24);

assertFalse(storage.isOAuthTokenRevokedByGID(appIdentifier,"abcd"));
assertFalse(storage.isOAuthTokenRevokedByClientId(appIdentifier,"clientid"));
assertFalse(storage.isOAuthTokenRevokedBySessionHandle(appIdentifier, "efgh"));
assertFalse(storage.isOAuthTokenRevokedByJTI(appIdentifier, "abcd", "ijkl"));

storage.revokeOAuthTokenByJTI(appIdentifier, "abcd","ijkl");
assertTrue(storage.isOAuthTokenRevokedByJTI(appIdentifier, "abcd", "ijkl"));
assertFalse(storage.isOAuthTokenRevokedByJTI(appIdentifier, "abcd", "mnop"));

storage.revokeOAuthTokenByJTI(appIdentifier, "abcd","mnop");
assertTrue(storage.isOAuthTokenRevokedByJTI(appIdentifier, "abcd", "ijkl"));
assertTrue(storage.isOAuthTokenRevokedByJTI(appIdentifier, "abcd", "mnop"));


storage.revokeOAuthTokenByGID(appIdentifier, "abcd");
assertTrue(storage.isOAuthTokenRevokedByGID(appIdentifier,"abcd"));

storage.createOrUpdateOAuthSession(appIdentifier, "abcd", "clientid", "externalRefreshToken",
"internalRefreshToken", "efgh", List.of("ijkl", "mnop"), System.currentTimeMillis() + 1000 * 60 * 60 * 24);
storage.revokeOAuthTokenBySessionHandle(appIdentifier, "efgh");
assertTrue(storage.isOAuthTokenRevokedBySessionHandle(appIdentifier, "efgh"));

// test cleanup
Thread.sleep(3000);
storage.deleteExpiredRevokedOAuthTokens(System.currentTimeMillis() / 1000 - 3);

assertFalse(storage.isOAuthTokenRevokedBasedOnTargetFields(
appIdentifier,
new OAuthRevokeTargetType[]{OAuthRevokeTargetType.GID},
new String[]{"abcd"},
System.currentTimeMillis()/1000 - 5
));
assertFalse(storage.isOAuthTokenRevokedBasedOnTargetFields(
appIdentifier,
new OAuthRevokeTargetType[]{OAuthRevokeTargetType.GID, OAuthRevokeTargetType.SESSION_HANDLE},
new String[]{"efgh", "efgh"},
System.currentTimeMillis()/1000 - 5
));

// newly issued should be allowed
storage.revokeOAuthTokensBasedOnTargetFields(appIdentifier, OAuthRevokeTargetType.GID, "abcd", System.currentTimeMillis()/1000 + 2 - 3600 * 24 * 31);
storage.revokeOAuthTokensBasedOnTargetFields(appIdentifier, OAuthRevokeTargetType.SESSION_HANDLE, "efgh", System.currentTimeMillis()/1000 + 2 - 3600 * 24 * 31);
storage.revokeOAuthTokensBasedOnTargetFields(appIdentifier, OAuthRevokeTargetType.JTI, "ijkl", System.currentTimeMillis()/1000 + 2 - 3600 * 24 * 31);

Thread.sleep(2000);

assertFalse(storage.isOAuthTokenRevokedBasedOnTargetFields(
appIdentifier,
new OAuthRevokeTargetType[]{OAuthRevokeTargetType.GID},
new String[]{"abcd"},
System.currentTimeMillis()/1000
));
assertFalse(storage.isOAuthTokenRevokedBasedOnTargetFields(
appIdentifier,
new OAuthRevokeTargetType[]{OAuthRevokeTargetType.GID, OAuthRevokeTargetType.SESSION_HANDLE},
new String[]{"efgh", "efgh"},
System.currentTimeMillis()/1000
));
storage.deleteExpiredOAuthSessions(System.currentTimeMillis() / 1000 - 3);

process.kill();
assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED));
Expand Down Expand Up @@ -304,8 +273,7 @@ public void testConstraints() throws Exception {
// this is what we expect
}
{
storage.revokeOAuthTokensBasedOnTargetFields(appIdentifier, OAuthRevokeTargetType.GID, "abcd", 0);
storage.revokeOAuthTokensBasedOnTargetFields(appIdentifier, OAuthRevokeTargetType.GID, "abcd", 0); // should update
storage.revokeOAuthTokenByGID(appIdentifier, "abcd");
}

// App id FK
Expand All @@ -316,12 +284,12 @@ public void testConstraints() throws Exception {
} catch (TenantOrAppNotFoundException e) {
// expected
}
try {
storage.revokeOAuthTokensBasedOnTargetFields(appIdentifier2, OAuthRevokeTargetType.GID, "abcd", 0);
fail();
} catch (TenantOrAppNotFoundException e) {
// expected
}
// try {
storage.revokeOAuthTokenByGID(appIdentifier2, "abcd");
// fail();
// } catch (TenantOrAppNotFoundException e) {
// // expected
// }

// Client FK
try {
Expand Down

0 comments on commit 11229d9

Please sign in to comment.