Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the way how we check the OIDC Access Token Expiration and Val… #391

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@
<spring-security-oauth2.version>2.0.17.RELEASE</spring-security-oauth2.version>
<jasypt.version>1.9.3</jasypt.version>
<keycloak-spring-security-adapter.version>18.0.0</keycloak-spring-security-adapter.version>
<spring-security-jwt.version>1.0.11.RELEASE</spring-security-jwt.version>
<spring-security-jwt.version>1.1.1.RELEASE</spring-security-jwt.version>
<java-jwt.version>3.18.3</java-jwt.version>
<wiremock-standalone.version>2.1.12</wiremock-standalone.version>
<hamcrest-core.version>1.3</hamcrest-core.version>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import static it.geosolutions.geostore.services.rest.security.oauth2.OAuth2Utils.*;

import com.fasterxml.jackson.databind.ObjectMapper;
import it.geosolutions.geostore.core.model.User;
import it.geosolutions.geostore.core.security.password.SecurityUtils;
import it.geosolutions.geostore.services.UserService;
Expand All @@ -52,6 +53,9 @@
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.jwt.Jwt;
import org.springframework.security.jwt.JwtHelper;
import org.springframework.security.jwt.crypto.sign.InvalidSignatureException;
import org.springframework.security.oauth2.client.OAuth2ClientContext;
import org.springframework.security.oauth2.client.OAuth2RestTemplate;
import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException;
Expand Down Expand Up @@ -94,66 +98,131 @@ public SessionToken refresh(String refreshToken, String accessToken) {
String errorMessage = "";
String warningMessage = "";
HttpServletRequest request = getRequest();
if (accessToken == null || accessToken.isEmpty())

// Ensure accessToken is available
if (accessToken == null || accessToken.isEmpty()) {
accessToken = OAuth2Utils.tokenFromParamsOrBearer(ACCESS_TOKEN_PARAM, request);
if (accessToken == null || accessToken.isEmpty())
}
if (accessToken == null || accessToken.isEmpty()) {
throw new NotFoundWebEx("Either the accessToken or the refresh token are missing");
}

OAuth2AccessToken currentToken = retrieveAccessToken(accessToken);
OAuth2AccessToken currentToken = retrieveAccessToken(accessToken, null);

// Determine refreshTokenToUse
String refreshTokenToUse =
currentToken.getRefreshToken() != null
&& currentToken.getRefreshToken().getValue() != null
&& !currentToken.getRefreshToken().getValue().isEmpty()
? currentToken.getRefreshToken().getValue()
: refreshToken;
if (refreshTokenToUse == null || refreshTokenToUse.isEmpty())
Optional.ofNullable(currentToken.getRefreshToken())
.map(OAuth2RefreshToken::getValue)
.filter(value -> !value.isEmpty())
.orElse(refreshToken);

if (refreshTokenToUse == null || refreshTokenToUse.isEmpty()) {
refreshTokenToUse = getParameterValue(REFRESH_TOKEN_PARAM, request);
}

SessionToken sessionToken = null;
OAuth2Configuration configuration = configuration();

if (configuration != null && configuration.isEnabled()) {
if (LOGGER.isDebugEnabled()) LOGGER.info("Going to refresh the token.");
LOGGER.info("Attempting to refresh the token.");
try {
sessionToken = doRefresh(refreshTokenToUse, accessToken, configuration);
if (sessionToken != null) {
currentToken =
retrieveAccessToken(
sessionToken.getAccessToken(), sessionToken.getExpires());
}
} catch (UserRedirectRequiredException e) {
// Log the warning and set the warning message in the session token
warningMessage = "A redirect is required to get the user's approval.";
LOGGER.warn(warningMessage);
} catch (NullPointerException npe) {
// Log the error and set the error message in the session token
errorMessage = "Current configuration wasn't correctly initialized.";
LOGGER.error("Current configuration wasn't correctly initialized.", npe);
LOGGER.warn(warningMessage, e);
} catch (Exception e) {
// Log the error and set the error message in the session token
errorMessage = "An error occurred during token refresh: " + e.getMessage();
LOGGER.error(errorMessage);
LOGGER.error(errorMessage, e);
}
} else {
LOGGER.warn("Configuration is null or disabled; skipping token refresh.");
}
if (sessionToken == null && !isTokenExpired(currentToken)) {
if (warningMessage.isEmpty())
warningMessage =
"Refresh Session Token was NULL for some reason... Seeding it with previous Access Token!";
sessionToken =
sessionToken(accessToken, refreshTokenToUse, currentToken.getExpiration());

if (sessionToken == null) {
if (isTokenExpired(currentToken) /* || !isTokenValid(currentToken) */) {
errorMessage = "Token is invalid or expired, and refresh failed.";
LOGGER.error(errorMessage);
handleRefreshFailure(accessToken, refreshTokenToUse, configuration);
return null;
} else {
if (warningMessage.isEmpty()) warningMessage = "Using existing access token.";
sessionToken =
sessionToken(accessToken, refreshTokenToUse, currentToken.getExpiration());
}
}

if (sessionToken != null) {
if (!warningMessage.isEmpty()) sessionToken.setWarning(warningMessage);
if (!errorMessage.isEmpty()) sessionToken.setError(errorMessage);
request.setAttribute(
OAuth2AuthenticationDetails.ACCESS_TOKEN_VALUE, sessionToken.getAccessToken());
request.setAttribute(
OAuth2AuthenticationDetails.ACCESS_TOKEN_TYPE, sessionToken.getTokenType());
if (!warningMessage.isEmpty()) {
sessionToken.setWarning(warningMessage);
}
if (!errorMessage.isEmpty()) {
sessionToken.setError(errorMessage);
}

request.setAttribute(
OAuth2AuthenticationDetails.ACCESS_TOKEN_VALUE, sessionToken.getAccessToken());
request.setAttribute(
OAuth2AuthenticationDetails.ACCESS_TOKEN_TYPE, sessionToken.getTokenType());

return sessionToken;
}

private boolean isTokenExpired(OAuth2AccessToken token) {
return token != null
&& !token.getValue().isEmpty()
&& (token.getExpiration() == null
|| (token.getExpiration() != null
&& token.getExpiration().before(new Date())));
if (token == null || token.getValue().isEmpty()) {
return true;
}

Date expiration = token.getExpiration();

if (expiration == null) {
expiration = getExpirationDateFromToken(token.getValue());
if (expiration == null) {
return true;
}
}

// Allow clock skew if necessary
return expiration.before(new Date());
}

private Date getExpirationDateFromToken(String token) {
try {
Jwt decodedToken = JwtHelper.decode(token);
String claimsJson = decodedToken.getClaims();

ObjectMapper mapper = new ObjectMapper();
Map<String, Object> claims = mapper.readValue(claimsJson, Map.class);

Object exp = claims.get("exp");
if (exp != null) {
long expLong;
if (exp instanceof Integer) {
expLong = ((Integer) exp).longValue();
} else if (exp instanceof Long) {
expLong = (Long) exp;
} else if (exp instanceof String) {
expLong = Long.parseLong((String) exp);
} else {
throw new IllegalArgumentException("Cannot parse 'exp' claim from token");
}

// The 'exp' claim is usually in seconds since epoch
Date expiration = new Date(expLong * 1000);
return expiration;
} else {
return null;
}
} catch (InvalidSignatureException e) {
LOGGER.error("Invalid JWT signature: {}", e.getMessage());
return null;
} catch (Exception e) {
LOGGER.error("Failed to parse JWT token: {}", e.getMessage());
return null;
}
}

/**
Expand Down Expand Up @@ -203,7 +272,8 @@ protected SessionToken doRefresh(

if (response.getStatusCode().is2xxSuccessful()) {
OAuth2AccessToken newToken = response.getBody();
if (newToken != null && !isTokenExpired(newToken)) {
if (newToken != null
&& !isTokenExpired(newToken) /* && isTokenValid(newToken) */) {
OAuth2RefreshToken newRefreshToken = newToken.getRefreshToken();
OAuth2RefreshToken refreshTokenToUse =
(newRefreshToken != null && newRefreshToken.getValue() != null)
Expand Down Expand Up @@ -290,9 +360,11 @@ public void handleRefreshFailure(
doLogout(null);

try {
String redirectUrl =
"../../openid/" + configuration.getProvider().toLowerCase() + "/login";
getResponse().sendRedirect(redirectUrl);
if (configuration != null && configuration.getProvider() != null) {
String redirectUrl =
"../../openid/" + configuration.getProvider().toLowerCase() + "/login";
getResponse().sendRedirect(redirectUrl);
}
} catch (IOException e) {
LOGGER.error("Error while sending redirect to login service: ", e);
throw new RuntimeException("Failed to redirect to login", e);
Expand Down Expand Up @@ -364,7 +436,7 @@ protected TokenDetails getTokenDetails(Authentication authentication) {
return OAuth2Utils.getTokenDetails(authentication);
}

protected OAuth2AccessToken retrieveAccessToken(String accessToken) {
protected OAuth2AccessToken retrieveAccessToken(String accessToken, Long expires) {
Authentication authentication = cache() != null ? cache().get(accessToken) : null;
OAuth2AccessToken result = null;
if (authentication != null) {
Expand All @@ -378,7 +450,12 @@ protected OAuth2AccessToken retrieveAccessToken(String accessToken) {
if (context != null) result = context.getAccessToken();
}
}
if (result == null) result = new DefaultOAuth2AccessToken(accessToken);
if (result == null) {
result = new DefaultOAuth2AccessToken(accessToken);
if (expires != null && expires > 0) {
((DefaultOAuth2AccessToken) result).setExpiration(new Date(expires));
}
}
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,7 @@ void testRefreshWithInvalidRefreshToken() {
"Refresh token should remain unchanged");
assertNotNull(sessionToken.getWarning(), "Warning message should be set");
assertTrue(
sessionToken
.getWarning()
.contains(
"Refresh Session Token was NULL for some reason... Seeding it with previous Access Token!"),
sessionToken.getWarning().contains("Using existing access token."),
"Expected error message in SessionToken");
}

Expand Down Expand Up @@ -235,10 +232,7 @@ void testRefreshWithServerError() {
"Refresh token should remain unchanged after server error");
assertNotNull(sessionToken.getWarning(), "Warning message should be set");
assertTrue(
sessionToken
.getWarning()
.contains(
"Refresh Session Token was NULL for some reason... Seeding it with previous Access Token!"),
sessionToken.getWarning().contains("Using existing access token."),
"Expected error message in SessionToken");
verify(restTemplate, times(3))
.exchange(
Expand Down Expand Up @@ -279,7 +273,7 @@ void testRefreshWithNullResponse() {
"Refresh token should remain unchanged");
assertNotNull(sessionToken.getWarning(), "Warning message should be set");
assertTrue(
sessionToken.getWarning().contains("Seeding it with previous Access Token!"),
sessionToken.getWarning().contains("Using existing access token."),
"Expected warning message in SessionToken");
}

Expand Down Expand Up @@ -558,7 +552,7 @@ protected TokenDetails getTokenDetails(Authentication authentication) {
}

@Override
protected OAuth2AccessToken retrieveAccessToken(String accessToken) {
protected OAuth2AccessToken retrieveAccessToken(String accessToken, Long expires) {
return currentAccessToken;
}

Expand Down
Loading