Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements with OKTA OIDC provider integration #385

Merged
merged 5 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ private SessionToken toSessionToken(String accessToken, UserSession sessionToken
@Override
public SessionToken refresh(SecurityContext sc, String sessionId, String refreshToken) {
String provider =
(String) RequestContextHolder.getRequestAttributes().getAttribute(PROVIDER_KEY, 0);
(String)
Objects.requireNonNull(RequestContextHolder.getRequestAttributes())
.getAttribute(PROVIDER_KEY, 0);
SessionServiceDelegate delegate = getDelegate(provider);
return delegate.refresh(refreshToken, sessionId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,15 @@
import java.io.IOException;
import java.util.Date;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.http.*;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2ClientContext;
Expand All @@ -56,11 +58,13 @@
import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken;
import org.springframework.security.oauth2.common.DefaultOAuth2RefreshToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2RefreshToken;
import org.springframework.security.oauth2.provider.authentication.OAuth2AuthenticationDetails;
import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.HttpMessageConverterExtractor;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.context.request.RequestContextHolder;

Expand Down Expand Up @@ -90,25 +94,27 @@ public SessionToken refresh(String refreshToken, String accessToken) {
throw new NotFoundWebEx("Either the accessToken or the refresh token are missing");

OAuth2AccessToken currentToken = retrieveAccessToken(accessToken);
Date expiresIn = currentToken.getExpiration();
if (refreshToken == null || refreshToken.isEmpty())
refreshToken = getParameterValue(REFRESH_TOKEN_PARAM, request);
Date fiveMinutesFromNow = fiveMinutesFromNow();
String refreshTokenToUse =
currentToken.getRefreshToken() != null
&& currentToken.getRefreshToken().getValue() != null
&& !currentToken.getRefreshToken().getValue().isEmpty()
? currentToken.getRefreshToken().getValue()
: refreshToken;
if (refreshTokenToUse == null || refreshTokenToUse.isEmpty())
refreshTokenToUse = getParameterValue(REFRESH_TOKEN_PARAM, request);
SessionToken sessionToken = null;
OAuth2Configuration configuration = configuration();
if (configuration != null && configuration.isEnabled()) {
if ((expiresIn == null || fiveMinutesFromNow.after(expiresIn))
&& refreshToken != null) {
if (LOGGER.isDebugEnabled()) LOGGER.info("Going to refresh the token.");
try {
sessionToken = doRefresh(refreshToken, accessToken, configuration);
} catch (NullPointerException npe) {
LOGGER.error("Current configuration wasn't correctly initialized.");
}
if (LOGGER.isDebugEnabled()) LOGGER.info("Going to refresh the token.");
try {
sessionToken = doRefresh(refreshTokenToUse, accessToken, configuration);
} catch (NullPointerException npe) {
LOGGER.error("Current configuration wasn't correctly initialized.");
}
}
if (sessionToken == null)
sessionToken = sessionToken(accessToken, refreshToken, currentToken.getExpiration());
sessionToken =
sessionToken(accessToken, refreshTokenToUse, currentToken.getExpiration());

request.setAttribute(
OAuth2AuthenticationDetails.ACCESS_TOKEN_VALUE, sessionToken.getAccessToken());
Expand All @@ -119,86 +125,121 @@ public SessionToken refresh(String refreshToken, String accessToken) {
}

/**
* Invokes the refresh endpoint and return a session token holding the updated tokens details.
* Invokes the refresh endpoint to get a new session token with updated token details.
*
* @param refreshToken the refresh token.
* @param accessToken the access token.
* @param configuration the OAuth2Configuration.
* @return the SessionToken.
* <p>This method attempts to refresh the session by exchanging the provided refresh token for a
* new access token. If the refresh token is invalid or the request fails after several retries,
* the session is cleared, and the user is redirected to the login page.
*
* @param refreshToken the refresh token to use for obtaining new access and refresh tokens
* @param accessToken the current access token
* @param configuration the OAuth2Configuration containing client credentials and endpoint URI
* @return a SessionToken containing the new token details, or null if the refresh process
* failed
*/
protected SessionToken doRefresh(
String refreshToken, String accessToken, OAuth2Configuration configuration) {
SessionToken sessionToken = null;
int maxRetries = 3;
int attempt = 0;
boolean success = false;

// Setup HTTP headers and body for the request
RestTemplate restTemplate = new RestTemplate();
HttpHeaders headers = getHttpHeaders(accessToken, configuration);

MultiValueMap<String, String> requestBody = new LinkedMultiValueMap<>();
requestBody.add("grant_type", "refresh_token");
requestBody.add("refresh_token", refreshToken);
requestBody.add("client_secret", configuration.getClientSecret());
requestBody.add("client_id", configuration.getClientId());

HttpEntity<MultiValueMap<String, String>> requestEntity =
new HttpEntity<>(requestBody, headers);

OAuth2AccessToken newToken = null;
try {
newToken =
restTemplate
.exchange(
configuration
.buildRefreshTokenURI(), // Use exchange method for POST
// request
HttpMethod.POST,
requestEntity, // Include request body
OAuth2AccessToken.class)
.getBody();
} catch (Exception ex) {
LOGGER.error("Error trying to obtain a refresh token.", ex);
}
while (attempt < maxRetries && !success) {
attempt++;
LOGGER.info("Attempting to refresh token, attempt {} of {}", attempt, maxRetries);

if (refreshToken != null
&& accessToken != null
&& !refreshToken.isEmpty()
&& !accessToken.isEmpty()
&& newToken != null
&& newToken.getValue() != null
&& !newToken.getValue().isEmpty()) {
// update the Authentication
String newRefreshToken =
newToken.getRefreshToken() != null
&& newToken.getRefreshToken().getValue() != null
&& !newToken.getRefreshToken().getValue().isEmpty()
? newToken.getRefreshToken().getValue()
: refreshToken;
updateAuthToken(accessToken, newToken, newRefreshToken, configuration);
sessionToken =
sessionToken(newToken.getValue(), refreshToken, newToken.getExpiration());
} else if (accessToken != null) {
// update the Authentication
sessionToken = sessionToken(accessToken, refreshToken, null);
} else {
// the refresh token was invalid. let's clear the session and send a remote logout.
// then redirect to the login entry point.
LOGGER.info(
"Unable to refresh the token. The following request was performed: {}. Redirecting to login.",
configuration.buildRefreshTokenURI("offline"));
doLogout(null);
try {
getResponse()
.sendRedirect(
"../../openid/"
+ configuration.getProvider().toLowerCase()
+ "/login");
} catch (IOException e) {
LOGGER.error("Error while sending redirect to login service. ", e);
throw new RuntimeException(e);
ResponseEntity<OAuth2AccessToken> response =
restTemplate.exchange(
configuration.buildRefreshTokenURI(),
HttpMethod.POST,
requestEntity,
OAuth2AccessToken.class);

if (response.getStatusCode().is2xxSuccessful()) {
OAuth2AccessToken newToken = response.getBody();
if (newToken != null
&& newToken.getValue() != null
&& !newToken.getValue().isEmpty()) {
// Process and update the new token details
OAuth2RefreshToken newRefreshToken = newToken.getRefreshToken();
OAuth2RefreshToken refreshTokenToUse =
(newRefreshToken != null && newRefreshToken.getValue() != null)
? newRefreshToken
: new DefaultOAuth2RefreshToken(refreshToken);

updateAuthToken(accessToken, newToken, refreshTokenToUse, configuration);
sessionToken =
sessionToken(
newToken.getValue(),
refreshTokenToUse.getValue(),
newToken.getExpiration());

LOGGER.info("Token refreshed successfully on attempt {}", attempt);
success = true;
} else {
LOGGER.warn("Received empty or null token on attempt {}", attempt);
}
} else if (response.getStatusCode().is4xxClientError()) {
// For client errors (e.g., 400, 401, 403), do not retry.
LOGGER.error(
"Client error occurred: {}. Stopping further attempts.",
response.getStatusCode());
break;
} else {
// For server errors (5xx), continue retrying
LOGGER.warn("Server error occurred: {}. Retrying...", response.getStatusCode());
}
} catch (RestClientException ex) {
LOGGER.error("Attempt {}: Error refreshing token: {}", attempt, ex.getMessage());
if (attempt == maxRetries) {
LOGGER.error("Max retries reached. Unable to refresh token.");
}
}
}

// Handle unsuccessful refresh
if (!success) {
handleRefreshFailure(accessToken, refreshToken, configuration);
}
return sessionToken;
}

/**
* Handles the refresh failure by clearing the session, logging out remotely, and redirecting to
* login.
*
* @param accessToken the current access token
* @param refreshToken the current refresh token
* @param configuration the OAuth2Configuration with endpoint details
*/
private void handleRefreshFailure(
String accessToken, String refreshToken, OAuth2Configuration configuration) {
LOGGER.info(
"Unable to refresh token after max retries. Clearing session and redirecting to login.");
doLogout(null);

try {
String redirectUrl =
"../../openid/" + configuration.getProvider().toLowerCase() + "/login";
getResponse().sendRedirect(redirectUrl);
} catch (IOException e) {
LOGGER.error("Error while sending redirect to login service: ", e);
throw new RuntimeException("Failed to redirect to login", e);
}
}

private static HttpHeaders getHttpHeaders(
String accessToken, OAuth2Configuration configuration) {
HttpHeaders headers = new HttpHeaders();
Expand Down Expand Up @@ -227,21 +268,20 @@ private SessionToken sessionToken(String accessToken, String refreshToken, Date

// Builds an authentication instance out of the passed values.
// Sets it to the cache and to the SecurityContext to be sure the new token is updates.
private Authentication updateAuthToken(
private void updateAuthToken(
String oldToken,
OAuth2AccessToken newToken,
String refreshToken,
OAuth2RefreshToken refreshToken,
OAuth2Configuration conf) {
Authentication authentication = cache().get(oldToken);
if (authentication == null)
authentication = SecurityContextHolder.getContext().getAuthentication();

if (authentication instanceof PreAuthenticatedAuthenticationToken) {
if (LOGGER.isDebugEnabled())
LOGGER.info("Updating the cache and the SecurityContext with new Auth details");
String idToken = null;
if (LOGGER.isDebugEnabled())
LOGGER.info("Updating the cache and the SecurityContext with new Auth details");
if (authentication != null && !(authentication instanceof AnonymousAuthenticationToken)) {
TokenDetails details = getTokenDetails(authentication);
idToken = details.getIdToken();
String idToken = details.getIdToken();
cache().removeEntry(oldToken);
PreAuthenticatedAuthenticationToken updated =
new PreAuthenticatedAuthenticationToken(
Expand All @@ -250,20 +290,15 @@ private Authentication updateAuthToken(
authentication.getAuthorities());
DefaultOAuth2AccessToken accessToken = new DefaultOAuth2AccessToken(newToken);
if (refreshToken != null) {
accessToken.setRefreshToken(new DefaultOAuth2RefreshToken(refreshToken));
accessToken.setRefreshToken(refreshToken);
}
if (LOGGER.isDebugEnabled())
LOGGER.debug(
"Creating new details. AccessToken: "
+ accessToken
+ " IdToken: "
+ idToken);
"Creating new details. AccessToken: {} IdToken: {}", accessToken, idToken);
updated.setDetails(new TokenDetails(accessToken, idToken, conf.getBeanName()));
cache().putCacheEntry(newToken.getValue(), updated);
SecurityContextHolder.getContext().setAuthentication(updated);
authentication = updated;
}
return authentication;
}

private OAuth2AccessToken retrieveAccessToken(String accessToken) {
Expand Down Expand Up @@ -318,7 +353,7 @@ public void doLogout(String sessionId) {
if (token == null) {
token =
(String)
RequestContextHolder.getRequestAttributes()
Objects.requireNonNull(RequestContextHolder.getRequestAttributes())
.getAttribute(REFRESH_TOKEN_PARAM, 0);
}
}
Expand All @@ -333,7 +368,7 @@ public void doLogout(String sessionId) {
if (accessToken == null) {
accessToken =
(String)
RequestContextHolder.getRequestAttributes()
Objects.requireNonNull(RequestContextHolder.getRequestAttributes())
.getAttribute(ACCESS_TOKEN_PARAM, 0);
}
}
Expand Down Expand Up @@ -365,8 +400,10 @@ private void clearSession(OAuth2RestTemplate restTemplate, HttpServletRequest re
.removePreservedState(accessTokenRequest.getStateKey());
}
try {
accessTokenRequest.remove("access_token");
accessTokenRequest.remove("refresh_token");
if (accessTokenRequest != null) {
accessTokenRequest.remove("access_token");
accessTokenRequest.remove("refresh_token");
}
request.logout();
} catch (ServletException e) {
LOGGER.error("Error happened while doing request logout: ", e);
Expand Down Expand Up @@ -445,9 +482,8 @@ protected void callRemoteLogout(String token, String accessToken) {

protected void clearCookies(HttpServletRequest request, HttpServletResponse response) {
javax.servlet.http.Cookie[] allCookies = request.getCookies();
if (allCookies != null && allCookies.length > 0)
for (int i = 0; i < allCookies.length; i++) {
javax.servlet.http.Cookie toDelete = allCookies[i];
if (allCookies != null)
for (javax.servlet.http.Cookie toDelete : allCookies) {
if (deleteCookie(toDelete)) {
toDelete.setMaxAge(-1);
toDelete.setPath("/");
Expand Down
Loading