diff --git a/src/modules/rest/impl/pom.xml b/src/modules/rest/impl/pom.xml
index de722bc7..f278bb81 100644
--- a/src/modules/rest/impl/pom.xml
+++ b/src/modules/rest/impl/pom.xml
@@ -255,6 +255,22 @@
test
+
+
+ org.junit.jupiter
+ junit-jupiter
+ 5.9.1
+ test
+
+
+
+
+ org.mockito
+ mockito-junit-jupiter
+ 4.8.0
+ test
+
+
@@ -288,8 +304,15 @@
-->
+
+
+ maven-surefire-plugin
+ 3.0.0-M7
+
+ false
+
+
-
diff --git a/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/impl/RESTSessionServiceImpl.java b/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/impl/RESTSessionServiceImpl.java
index 99f04981..16055016 100644
--- a/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/impl/RESTSessionServiceImpl.java
+++ b/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/impl/RESTSessionServiceImpl.java
@@ -189,7 +189,9 @@ private SessionToken toSessionToken(String accessToken, UserSession sessionToken
@Override
public SessionToken refresh(SecurityContext sc, String sessionId, String refreshToken) {
String provider =
- (String) RequestContextHolder.getRequestAttributes().getAttribute(PROVIDER_KEY, 0);
+ (String)
+ Objects.requireNonNull(RequestContextHolder.getRequestAttributes())
+ .getAttribute(PROVIDER_KEY, 0);
SessionServiceDelegate delegate = getDelegate(provider);
return delegate.refresh(refreshToken, sessionId);
}
diff --git a/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/OAuth2SessionServiceDelegate.java b/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/OAuth2SessionServiceDelegate.java
index 8cbad1aa..1165d4c3 100644
--- a/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/OAuth2SessionServiceDelegate.java
+++ b/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/OAuth2SessionServiceDelegate.java
@@ -41,6 +41,7 @@
import java.io.IOException;
import java.util.Date;
import java.util.Map;
+import java.util.Objects;
import java.util.Optional;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
@@ -48,6 +49,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.http.*;
+import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2ClientContext;
@@ -56,11 +58,13 @@
import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken;
import org.springframework.security.oauth2.common.DefaultOAuth2RefreshToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
+import org.springframework.security.oauth2.common.OAuth2RefreshToken;
import org.springframework.security.oauth2.provider.authentication.OAuth2AuthenticationDetails;
import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.HttpMessageConverterExtractor;
+import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.context.request.RequestContextHolder;
@@ -81,6 +85,9 @@ public OAuth2SessionServiceDelegate(
this.userService = userService;
}
+ public OAuth2SessionServiceDelegate(
+ RestTemplate restTemplate, OAuth2Configuration configuration) {}
+
@Override
public SessionToken refresh(String refreshToken, String accessToken) {
HttpServletRequest request = getRequest();
@@ -90,25 +97,27 @@ public SessionToken refresh(String refreshToken, String accessToken) {
throw new NotFoundWebEx("Either the accessToken or the refresh token are missing");
OAuth2AccessToken currentToken = retrieveAccessToken(accessToken);
- Date expiresIn = currentToken.getExpiration();
- if (refreshToken == null || refreshToken.isEmpty())
- refreshToken = getParameterValue(REFRESH_TOKEN_PARAM, request);
- Date fiveMinutesFromNow = fiveMinutesFromNow();
+ String refreshTokenToUse =
+ currentToken.getRefreshToken() != null
+ && currentToken.getRefreshToken().getValue() != null
+ && !currentToken.getRefreshToken().getValue().isEmpty()
+ ? currentToken.getRefreshToken().getValue()
+ : refreshToken;
+ if (refreshTokenToUse == null || refreshTokenToUse.isEmpty())
+ refreshTokenToUse = getParameterValue(REFRESH_TOKEN_PARAM, request);
SessionToken sessionToken = null;
OAuth2Configuration configuration = configuration();
if (configuration != null && configuration.isEnabled()) {
- if ((expiresIn == null || fiveMinutesFromNow.after(expiresIn))
- && refreshToken != null) {
- if (LOGGER.isDebugEnabled()) LOGGER.info("Going to refresh the token.");
- try {
- sessionToken = doRefresh(refreshToken, accessToken, configuration);
- } catch (NullPointerException npe) {
- LOGGER.error("Current configuration wasn't correctly initialized.");
- }
+ if (LOGGER.isDebugEnabled()) LOGGER.info("Going to refresh the token.");
+ try {
+ sessionToken = doRefresh(refreshTokenToUse, accessToken, configuration);
+ } catch (NullPointerException npe) {
+ LOGGER.error("Current configuration wasn't correctly initialized.");
}
}
if (sessionToken == null)
- sessionToken = sessionToken(accessToken, refreshToken, currentToken.getExpiration());
+ sessionToken =
+ sessionToken(accessToken, refreshTokenToUse, currentToken.getExpiration());
request.setAttribute(
OAuth2AuthenticationDetails.ACCESS_TOKEN_VALUE, sessionToken.getAccessToken());
@@ -119,86 +128,122 @@ public SessionToken refresh(String refreshToken, String accessToken) {
}
/**
- * Invokes the refresh endpoint and return a session token holding the updated tokens details.
+ * Invokes the refresh endpoint to get a new session token with updated token details.
+ *
+ *
This method attempts to refresh the session by exchanging the provided refresh token for a
+ * new access token. If the refresh token is invalid or the request fails after several retries,
+ * the session is cleared, and the user is redirected to the login page.
*
- * @param refreshToken the refresh token.
- * @param accessToken the access token.
- * @param configuration the OAuth2Configuration.
- * @return the SessionToken.
+ * @param refreshToken the refresh token to use for obtaining new access and refresh tokens
+ * @param accessToken the current access token
+ * @param configuration the OAuth2Configuration containing client credentials and endpoint URI
+ * @return a SessionToken containing the new token details, or null if the refresh process
+ * failed
*/
protected SessionToken doRefresh(
String refreshToken, String accessToken, OAuth2Configuration configuration) {
SessionToken sessionToken = null;
+ int maxRetries = 3;
+ int attempt = 0;
+ boolean success = false;
- RestTemplate restTemplate = new RestTemplate();
+ // Setup HTTP headers and body for the request
+ // Use restTemplate() method to get RestTemplate instance
+ OAuth2RestTemplate restTemplate = restTemplate();
HttpHeaders headers = getHttpHeaders(accessToken, configuration);
-
MultiValueMap requestBody = new LinkedMultiValueMap<>();
requestBody.add("grant_type", "refresh_token");
requestBody.add("refresh_token", refreshToken);
requestBody.add("client_secret", configuration.getClientSecret());
requestBody.add("client_id", configuration.getClientId());
-
HttpEntity> requestEntity =
new HttpEntity<>(requestBody, headers);
- OAuth2AccessToken newToken = null;
- try {
- newToken =
- restTemplate
- .exchange(
- configuration
- .buildRefreshTokenURI(), // Use exchange method for POST
- // request
- HttpMethod.POST,
- requestEntity, // Include request body
- OAuth2AccessToken.class)
- .getBody();
- } catch (Exception ex) {
- LOGGER.error("Error trying to obtain a refresh token.", ex);
- }
+ while (attempt < maxRetries && !success) {
+ attempt++;
+ LOGGER.info("Attempting to refresh token, attempt {} of {}", attempt, maxRetries);
- if (refreshToken != null
- && accessToken != null
- && !refreshToken.isEmpty()
- && !accessToken.isEmpty()
- && newToken != null
- && newToken.getValue() != null
- && !newToken.getValue().isEmpty()) {
- // update the Authentication
- String newRefreshToken =
- newToken.getRefreshToken() != null
- && newToken.getRefreshToken().getValue() != null
- && !newToken.getRefreshToken().getValue().isEmpty()
- ? newToken.getRefreshToken().getValue()
- : refreshToken;
- updateAuthToken(accessToken, newToken, newRefreshToken, configuration);
- sessionToken =
- sessionToken(newToken.getValue(), refreshToken, newToken.getExpiration());
- } else if (accessToken != null) {
- // update the Authentication
- sessionToken = sessionToken(accessToken, refreshToken, null);
- } else {
- // the refresh token was invalid. let's clear the session and send a remote logout.
- // then redirect to the login entry point.
- LOGGER.info(
- "Unable to refresh the token. The following request was performed: {}. Redirecting to login.",
- configuration.buildRefreshTokenURI("offline"));
- doLogout(null);
try {
- getResponse()
- .sendRedirect(
- "../../openid/"
- + configuration.getProvider().toLowerCase()
- + "/login");
- } catch (IOException e) {
- LOGGER.error("Error while sending redirect to login service. ", e);
- throw new RuntimeException(e);
+ ResponseEntity response =
+ restTemplate.exchange(
+ configuration.buildRefreshTokenURI(),
+ HttpMethod.POST,
+ requestEntity,
+ OAuth2AccessToken.class);
+
+ if (response.getStatusCode().is2xxSuccessful()) {
+ OAuth2AccessToken newToken = response.getBody();
+ if (newToken != null
+ && newToken.getValue() != null
+ && !newToken.getValue().isEmpty()) {
+ // Process and update the new token details
+ OAuth2RefreshToken newRefreshToken = newToken.getRefreshToken();
+ OAuth2RefreshToken refreshTokenToUse =
+ (newRefreshToken != null && newRefreshToken.getValue() != null)
+ ? newRefreshToken
+ : new DefaultOAuth2RefreshToken(refreshToken);
+
+ updateAuthToken(accessToken, newToken, refreshTokenToUse, configuration);
+ sessionToken =
+ sessionToken(
+ newToken.getValue(),
+ refreshTokenToUse.getValue(),
+ newToken.getExpiration());
+
+ LOGGER.info("Token refreshed successfully on attempt {}", attempt);
+ success = true;
+ } else {
+ LOGGER.warn("Received empty or null token on attempt {}", attempt);
+ }
+ } else if (response.getStatusCode().is4xxClientError()) {
+ // For client errors (e.g., 400, 401, 403), do not retry.
+ LOGGER.error(
+ "Client error occurred: {}. Stopping further attempts.",
+ response.getStatusCode());
+ break;
+ } else {
+ // For server errors (5xx), continue retrying
+ LOGGER.warn("Server error occurred: {}. Retrying...", response.getStatusCode());
+ }
+ } catch (RestClientException ex) {
+ LOGGER.error("Attempt {}: Error refreshing token: {}", attempt, ex.getMessage());
+ if (attempt == maxRetries) {
+ LOGGER.error("Max retries reached. Unable to refresh token.");
+ }
}
}
+
+ // Handle unsuccessful refresh
+ if (!success) {
+ handleRefreshFailure(accessToken, refreshToken, configuration);
+ }
return sessionToken;
}
+ /**
+ * Handles the refresh failure by clearing the session, logging out remotely, and redirecting to
+ * login.
+ *
+ * @param accessToken the current access token
+ * @param refreshToken the current refresh token
+ * @param configuration the OAuth2Configuration with endpoint details
+ */
+ private void handleRefreshFailure(
+ String accessToken, String refreshToken, OAuth2Configuration configuration) {
+ LOGGER.info(
+ "Unable to refresh token after max retries. Clearing session and redirecting to login.");
+ doLogout(null);
+
+ try {
+ String redirectUrl =
+ "../../openid/" + configuration.getProvider().toLowerCase() + "/login";
+ getResponse().sendRedirect(redirectUrl);
+ } catch (IOException e) {
+ LOGGER.error("Error while sending redirect to login service: ", e);
+ throw new RuntimeException("Failed to redirect to login", e);
+ }
+ }
+
private static HttpHeaders getHttpHeaders(
String accessToken, OAuth2Configuration configuration) {
HttpHeaders headers = new HttpHeaders();
@@ -227,21 +272,20 @@ private SessionToken sessionToken(String accessToken, String refreshToken, Date
// Builds an authentication instance out of the passed values.
// Sets it to the cache and to the SecurityContext to be sure the new token is updates.
- private Authentication updateAuthToken(
+ protected void updateAuthToken(
String oldToken,
OAuth2AccessToken newToken,
- String refreshToken,
+ OAuth2RefreshToken refreshToken,
OAuth2Configuration conf) {
Authentication authentication = cache().get(oldToken);
if (authentication == null)
authentication = SecurityContextHolder.getContext().getAuthentication();
- if (authentication instanceof PreAuthenticatedAuthenticationToken) {
- if (LOGGER.isDebugEnabled())
- LOGGER.info("Updating the cache and the SecurityContext with new Auth details");
- String idToken = null;
+ if (LOGGER.isDebugEnabled())
+ LOGGER.info("Updating the cache and the SecurityContext with new Auth details");
+ if (authentication != null && !(authentication instanceof AnonymousAuthenticationToken)) {
TokenDetails details = getTokenDetails(authentication);
- idToken = details.getIdToken();
+ String idToken = details.getIdToken();
cache().removeEntry(oldToken);
PreAuthenticatedAuthenticationToken updated =
new PreAuthenticatedAuthenticationToken(
@@ -250,24 +294,23 @@ private Authentication updateAuthToken(
authentication.getAuthorities());
DefaultOAuth2AccessToken accessToken = new DefaultOAuth2AccessToken(newToken);
if (refreshToken != null) {
- accessToken.setRefreshToken(new DefaultOAuth2RefreshToken(refreshToken));
+ accessToken.setRefreshToken(refreshToken);
}
if (LOGGER.isDebugEnabled())
LOGGER.debug(
- "Creating new details. AccessToken: "
- + accessToken
- + " IdToken: "
- + idToken);
+ "Creating new details. AccessToken: {} IdToken: {}", accessToken, idToken);
updated.setDetails(new TokenDetails(accessToken, idToken, conf.getBeanName()));
cache().putCacheEntry(newToken.getValue(), updated);
SecurityContextHolder.getContext().setAuthentication(updated);
- authentication = updated;
}
- return authentication;
}
- private OAuth2AccessToken retrieveAccessToken(String accessToken) {
- Authentication authentication = cache().get(accessToken);
+ protected TokenDetails getTokenDetails(Authentication authentication) {
+ return OAuth2Utils.getTokenDetails(authentication);
+ }
+
+ protected OAuth2AccessToken retrieveAccessToken(String accessToken) {
+ Authentication authentication = cache() != null ? cache().get(accessToken) : null;
OAuth2AccessToken result = null;
if (authentication != null) {
TokenDetails details = OAuth2Utils.getTokenDetails(authentication);
@@ -284,6 +327,14 @@ private OAuth2AccessToken retrieveAccessToken(String accessToken) {
return result;
}
+ protected HttpServletRequest getRequest() {
+ return OAuth2Utils.getRequest();
+ }
+
+ protected HttpServletResponse getResponse() {
+ return OAuth2Utils.getResponse();
+ }
+
@Override
public void doLogout(String sessionId) {
HttpServletRequest request = getRequest();
@@ -318,7 +369,7 @@ public void doLogout(String sessionId) {
if (token == null) {
token =
(String)
- RequestContextHolder.getRequestAttributes()
+ Objects.requireNonNull(RequestContextHolder.getRequestAttributes())
.getAttribute(REFRESH_TOKEN_PARAM, 0);
}
}
@@ -333,7 +384,7 @@ public void doLogout(String sessionId) {
if (accessToken == null) {
accessToken =
(String)
- RequestContextHolder.getRequestAttributes()
+ Objects.requireNonNull(RequestContextHolder.getRequestAttributes())
.getAttribute(ACCESS_TOKEN_PARAM, 0);
}
}
@@ -365,8 +416,10 @@ private void clearSession(OAuth2RestTemplate restTemplate, HttpServletRequest re
.removePreservedState(accessTokenRequest.getStateKey());
}
try {
- accessTokenRequest.remove("access_token");
- accessTokenRequest.remove("refresh_token");
+ if (accessTokenRequest != null) {
+ accessTokenRequest.remove("access_token");
+ accessTokenRequest.remove("refresh_token");
+ }
request.logout();
} catch (ServletException e) {
LOGGER.error("Error happened while doing request logout: ", e);
@@ -445,9 +498,8 @@ protected void callRemoteLogout(String token, String accessToken) {
protected void clearCookies(HttpServletRequest request, HttpServletResponse response) {
javax.servlet.http.Cookie[] allCookies = request.getCookies();
- if (allCookies != null && allCookies.length > 0)
- for (int i = 0; i < allCookies.length; i++) {
- javax.servlet.http.Cookie toDelete = allCookies[i];
+ if (allCookies != null)
+ for (javax.servlet.http.Cookie toDelete : allCookies) {
if (deleteCookie(toDelete)) {
toDelete.setMaxAge(-1);
toDelete.setPath("/");
@@ -463,7 +515,7 @@ protected boolean deleteCookie(javax.servlet.http.Cookie c) {
|| c.getName().equalsIgnoreCase(REFRESH_TOKEN_PARAM);
}
- private TokenAuthenticationCache cache() {
+ protected TokenAuthenticationCache cache() {
return GeoStoreContext.bean("oAuth2Cache", TokenAuthenticationCache.class);
}
diff --git a/src/modules/rest/impl/src/test/java/it/geosolutions/geostore/rest/security/oauth2/openid_connect/RefreshTokenServiceTest.java b/src/modules/rest/impl/src/test/java/it/geosolutions/geostore/rest/security/oauth2/openid_connect/RefreshTokenServiceTest.java
new file mode 100644
index 00000000..68d63ceb
--- /dev/null
+++ b/src/modules/rest/impl/src/test/java/it/geosolutions/geostore/rest/security/oauth2/openid_connect/RefreshTokenServiceTest.java
@@ -0,0 +1,478 @@
+package it.geosolutions.geostore.rest.security.oauth2.openid_connect;
+
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.Mockito.*;
+
+import it.geosolutions.geostore.services.rest.model.SessionToken;
+import it.geosolutions.geostore.services.rest.security.TokenAuthenticationCache;
+import it.geosolutions.geostore.services.rest.security.oauth2.OAuth2Configuration;
+import it.geosolutions.geostore.services.rest.security.oauth2.OAuth2SessionServiceDelegate;
+import it.geosolutions.geostore.services.rest.security.oauth2.OAuth2Utils;
+import it.geosolutions.geostore.services.rest.security.oauth2.TokenDetails;
+import java.util.Date;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import org.junit.jupiter.api.*;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
+import org.springframework.http.*;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.oauth2.client.OAuth2RestTemplate;
+import org.springframework.security.oauth2.common.*;
+import org.springframework.web.client.HttpClientErrorException;
+import org.springframework.web.client.HttpServerErrorException;
+import org.springframework.web.context.request.RequestContextHolder;
+import org.springframework.web.context.request.ServletRequestAttributes;
+
+/** Test class for OAuth2SessionServiceDelegate. */
+class RefreshTokenServiceTest {
+
+ private TestOAuth2SessionServiceDelegate serviceDelegate;
+ private OAuth2Configuration configuration;
+ private OAuth2RestTemplate restTemplate;
+ private MockHttpServletRequest mockRequest;
+ private MockHttpServletResponse mockResponse;
+ private DefaultOAuth2AccessToken mockOAuth2AccessToken;
+
+ @Mock private TokenAuthenticationCache authenticationCache;
+
+ @BeforeEach
+ void setUp() {
+ MockitoAnnotations.openMocks(this);
+
+ // Initialize mocks and dependencies
+ configuration = mock(OAuth2Configuration.class);
+ restTemplate = mock(OAuth2RestTemplate.class);
+ authenticationCache = mock(TokenAuthenticationCache.class);
+
+ // Create an instance of the test subclass
+ serviceDelegate = spy(new TestOAuth2SessionServiceDelegate());
+ serviceDelegate.setRestTemplate(restTemplate);
+ serviceDelegate.setConfiguration(configuration);
+ serviceDelegate.authenticationCache = authenticationCache;
+
+ // Set up mock request and response
+ mockRequest = new MockHttpServletRequest();
+ mockResponse = new MockHttpServletResponse();
+
+ // Set the RequestAttributes in RequestContextHolder
+ RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(mockRequest));
+
+ // Mock configuration behavior
+ when(configuration.isEnabled()).thenReturn(true);
+ when(configuration.getClientId()).thenReturn("testClientId");
+ when(configuration.getClientSecret()).thenReturn("testClientSecret");
+ when(configuration.buildRefreshTokenURI()).thenReturn("https://example.com/oauth2/token");
+
+ // Mock the existing OAuth2AccessToken with a refresh token
+ mockOAuth2AccessToken = new DefaultOAuth2AccessToken("providedAccessToken");
+ OAuth2RefreshToken mockRefreshToken = new DefaultOAuth2RefreshToken("existingRefreshToken");
+ mockOAuth2AccessToken.setRefreshToken(mockRefreshToken);
+ mockOAuth2AccessToken.setExpiration(new Date(System.currentTimeMillis() + 3600 * 1000));
+
+ // Initialize currentAccessToken
+ serviceDelegate.currentAccessToken = mockOAuth2AccessToken;
+
+ // Mock the Authentication object
+ Authentication mockAuthentication = mock(Authentication.class);
+
+ // Mock TokenDetails
+ TokenDetails mockTokenDetails = mock(TokenDetails.class);
+ when(mockTokenDetails.getIdToken()).thenReturn("mockIdToken");
+
+ // Mock getTokenDetails(authentication) to return mockTokenDetails
+ doReturn(mockTokenDetails).when(serviceDelegate).getTokenDetails(mockAuthentication);
+
+ // Ensure cache returns the mocked Authentication for oldToken
+ when(authenticationCache.get("providedAccessToken")).thenReturn(mockAuthentication);
+
+ // Optionally, set up SecurityContextHolder
+ SecurityContext securityContext = mock(SecurityContext.class);
+ when(securityContext.getAuthentication()).thenReturn(mockAuthentication);
+ SecurityContextHolder.setContext(securityContext);
+ }
+
+ @AfterEach
+ void tearDown() {
+ // Clear the RequestContextHolder after each test
+ RequestContextHolder.resetRequestAttributes();
+
+ // Close static mocks if any
+ // For Mockito 3.4.0 and above
+ Mockito.framework().clearInlineMocks();
+ }
+
+ @Test
+ void testRefreshWithValidTokens() {
+ // Arrange
+ String refreshToken = "providedRefreshToken";
+ String accessToken = "providedAccessToken";
+
+ // Mock the RestTemplate exchange method to simulate a successful token refresh
+ DefaultOAuth2AccessToken newAccessToken = new DefaultOAuth2AccessToken("newAccessToken");
+ OAuth2RefreshToken newRefreshToken = new DefaultOAuth2RefreshToken("newRefreshToken");
+ newAccessToken.setRefreshToken(newRefreshToken);
+ newAccessToken.setExpiration(
+ new Date(System.currentTimeMillis() + 7200 * 1000)); // Expires in 2 hours
+
+ ResponseEntity responseEntity =
+ new ResponseEntity<>(newAccessToken, HttpStatus.OK);
+
+ when(restTemplate.exchange(
+ anyString(),
+ eq(HttpMethod.POST),
+ any(HttpEntity.class),
+ eq(OAuth2AccessToken.class)))
+ .thenReturn(responseEntity);
+
+ // Act
+ SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken);
+
+ // Assert
+ assertNotNull(sessionToken, "SessionToken should not be null");
+ assertEquals(
+ "newAccessToken", sessionToken.getAccessToken(), "Access token should be updated");
+ assertEquals(
+ "newRefreshToken",
+ sessionToken.getRefreshToken(),
+ "Refresh token should be updated");
+ assertTrue(
+ sessionToken.getExpires() > System.currentTimeMillis(),
+ "Token expiration should be in the future");
+ assertEquals("bearer", sessionToken.getTokenType(), "Token type should be 'bearer'");
+ }
+
+ @Test
+ void testRefreshWithInvalidRefreshToken() {
+ // Arrange
+ String refreshToken = "invalidRefreshToken";
+ String accessToken = "providedAccessToken";
+
+ // Mock the RestTemplate exchange method to simulate a client error (400 Bad Request)
+ when(restTemplate.exchange(
+ anyString(),
+ eq(HttpMethod.POST),
+ any(HttpEntity.class),
+ eq(OAuth2AccessToken.class)))
+ .thenThrow(new HttpClientErrorException(HttpStatus.BAD_REQUEST));
+
+ // Act
+ SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken);
+
+ // Assert
+ assertNotNull(sessionToken, "SessionToken should not be null even when refresh fails");
+ assertEquals(
+ "providedAccessToken",
+ sessionToken.getAccessToken(),
+ "Access token should remain unchanged");
+ assertEquals(
+ "existingRefreshToken",
+ sessionToken.getRefreshToken(),
+ "Refresh token should remain unchanged");
+ }
+
+ @Test
+ void testRefreshWithServerError() {
+ // Arrange
+ String refreshToken = "providedRefreshToken";
+ String accessToken = "providedAccessToken";
+
+ // Mock the RestTemplate exchange method to simulate a server error (500 Internal Server
+ // Error)
+ when(restTemplate.exchange(
+ anyString(),
+ eq(HttpMethod.POST),
+ any(HttpEntity.class),
+ eq(OAuth2AccessToken.class)))
+ .thenThrow(new HttpServerErrorException(HttpStatus.INTERNAL_SERVER_ERROR));
+
+ // Act
+ SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken);
+
+ // Assert
+ assertNotNull(sessionToken, "SessionToken should not be null even when refresh fails");
+ assertEquals(
+ "providedAccessToken",
+ sessionToken.getAccessToken(),
+ "Access token should remain unchanged after server error");
+ assertEquals(
+ "existingRefreshToken",
+ sessionToken.getRefreshToken(),
+ "Refresh token should remain unchanged after server error");
+ // You can also verify that the method retried the expected number of times
+ verify(restTemplate, times(3))
+ .exchange(
+ anyString(),
+ eq(HttpMethod.POST),
+ any(HttpEntity.class),
+ eq(OAuth2AccessToken.class));
+ }
+
+ @Test
+ void testRefreshWithNullResponse() {
+ // Arrange
+ String refreshToken = "providedRefreshToken";
+ String accessToken = "providedAccessToken";
+
+ // Mock the RestTemplate exchange method to return a response with null body
+ ResponseEntity responseEntity =
+ new ResponseEntity<>(null, HttpStatus.OK);
+ when(restTemplate.exchange(
+ anyString(),
+ eq(HttpMethod.POST),
+ any(HttpEntity.class),
+ eq(OAuth2AccessToken.class)))
+ .thenReturn(responseEntity);
+
+ // Act
+ SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken);
+
+ // Assert
+ assertNotNull(sessionToken, "SessionToken should not be null even when response is null");
+ assertEquals(
+ "providedAccessToken",
+ sessionToken.getAccessToken(),
+ "Access token should remain unchanged");
+ assertEquals(
+ "existingRefreshToken",
+ sessionToken.getRefreshToken(),
+ "Refresh token should remain unchanged");
+ }
+
+ @Test
+ void testRefreshWhenConfigurationDisabled() {
+ // Arrange
+ String refreshToken = "providedRefreshToken";
+ String accessToken = "providedAccessToken";
+
+ // Mock configuration to be disabled
+ when(configuration.isEnabled()).thenReturn(false);
+
+ // Act
+ SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken);
+
+ // Assert
+ assertNotNull(
+ sessionToken, "SessionToken should not be null when configuration is disabled");
+ assertEquals(
+ "providedAccessToken",
+ sessionToken.getAccessToken(),
+ "Access token should remain unchanged");
+ assertEquals(
+ "existingRefreshToken",
+ sessionToken.getRefreshToken(),
+ "Refresh token should remain unchanged");
+ // Verify that no exchange was attempted
+ verify(restTemplate, never())
+ .exchange(
+ anyString(),
+ any(HttpMethod.class),
+ any(HttpEntity.class),
+ eq(OAuth2AccessToken.class));
+ }
+
+ @Test
+ void testRefreshWithMissingAccessToken() {
+ // Arrange
+ String refreshToken = "providedRefreshToken";
+ String accessToken = null; // Access token is missing
+
+ // Act & Assert
+ Exception exception =
+ assertThrows(
+ RuntimeException.class,
+ () -> {
+ serviceDelegate.refresh(refreshToken, accessToken);
+ });
+
+ assertTrue(
+ exception
+ .getMessage()
+ .contains("Either the accessToken or the refresh token are missing"),
+ "Expected exception message");
+ }
+
+ @Test
+ void testRefreshWhenCacheReturnsNullAuthentication() {
+ // Arrange
+ String refreshToken = "providedRefreshToken";
+ String accessToken = "providedAccessToken";
+
+ // Mock cache to return null
+ when(authenticationCache.get("providedAccessToken")).thenReturn(null);
+
+ // Act
+ SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken);
+
+ // Assert
+ assertNotNull(
+ sessionToken,
+ "SessionToken should not be null even when authentication is not found in cache");
+ assertEquals(
+ "providedAccessToken",
+ sessionToken.getAccessToken(),
+ "Access token should remain unchanged");
+ assertEquals(
+ "existingRefreshToken",
+ sessionToken.getRefreshToken(),
+ "Refresh token should remain unchanged");
+ }
+
+ @Test
+ void testRefreshWhenAuthenticationIsAnonymous() {
+ // Arrange
+ String refreshToken = "providedRefreshToken";
+ String accessToken = "providedAccessToken";
+
+ // Mock an AnonymousAuthenticationToken
+ Authentication anonymousAuthentication =
+ mock(
+ org.springframework.security.authentication.AnonymousAuthenticationToken
+ .class);
+ when(authenticationCache.get("providedAccessToken")).thenReturn(anonymousAuthentication);
+
+ // Act
+ SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken);
+
+ // Assert
+ assertNotNull(
+ sessionToken,
+ "SessionToken should not be null even when authentication is anonymous");
+ assertEquals(
+ "providedAccessToken",
+ sessionToken.getAccessToken(),
+ "Access token should remain unchanged");
+ assertEquals(
+ "existingRefreshToken",
+ sessionToken.getRefreshToken(),
+ "Refresh token should remain unchanged");
+ }
+
+ @Test
+ void testRefreshWithExpiredAccessToken() {
+ // Arrange
+ String refreshToken = "providedRefreshToken";
+ String accessToken = "expiredAccessToken";
+
+ // Set the current access token to be expired
+ mockOAuth2AccessToken.setExpiration(
+ new Date(System.currentTimeMillis() - 1000)); // Set expiration in the past
+ serviceDelegate.currentAccessToken = mockOAuth2AccessToken;
+
+ // Mock the RestTemplate exchange method to simulate a successful token refresh
+ DefaultOAuth2AccessToken newAccessToken = new DefaultOAuth2AccessToken("newAccessToken");
+ OAuth2RefreshToken newRefreshToken = new DefaultOAuth2RefreshToken("newRefreshToken");
+ newAccessToken.setRefreshToken(newRefreshToken);
+ newAccessToken.setExpiration(
+ new Date(System.currentTimeMillis() + 7200 * 1000)); // Expires in 2 hours
+
+ ResponseEntity responseEntity =
+ new ResponseEntity<>(newAccessToken, HttpStatus.OK);
+
+ when(restTemplate.exchange(
+ anyString(),
+ eq(HttpMethod.POST),
+ any(HttpEntity.class),
+ eq(OAuth2AccessToken.class)))
+ .thenReturn(responseEntity);
+
+ // Act
+ SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken);
+
+ // Assert
+ assertNotNull(sessionToken, "SessionToken should not be null");
+ assertEquals(
+ "newAccessToken", sessionToken.getAccessToken(), "Access token should be updated");
+ assertEquals(
+ "newRefreshToken",
+ sessionToken.getRefreshToken(),
+ "Refresh token should be updated");
+ assertTrue(
+ sessionToken.getExpires() > System.currentTimeMillis(),
+ "Token expiration should be in the future");
+ }
+
+ /** Test subclass of OAuth2SessionServiceDelegate for testing purposes. */
+ class TestOAuth2SessionServiceDelegate extends OAuth2SessionServiceDelegate {
+
+ private OAuth2RestTemplate restTemplate;
+ private OAuth2Configuration configuration;
+ private OAuth2AccessToken currentAccessToken;
+ protected TokenAuthenticationCache authenticationCache;
+
+ public TestOAuth2SessionServiceDelegate() {
+ super(null, null); // Mocked dependencies
+ }
+
+ public void setRestTemplate(OAuth2RestTemplate restTemplate) {
+ this.restTemplate = restTemplate;
+ }
+
+ public void setConfiguration(OAuth2Configuration configuration) {
+ this.configuration = configuration;
+ }
+
+ @Override
+ protected OAuth2RestTemplate restTemplate() {
+ return restTemplate;
+ }
+
+ @Override
+ protected OAuth2Configuration configuration() {
+ return configuration;
+ }
+
+ @Override
+ protected HttpServletRequest getRequest() {
+ return mockRequest;
+ }
+
+ @Override
+ protected HttpServletResponse getResponse() {
+ return mockResponse;
+ }
+
+ @Override
+ protected TokenDetails getTokenDetails(Authentication authentication) {
+ // This method is now mocked in the test setup using doReturn()
+ return super.getTokenDetails(authentication);
+ }
+
+ @Override
+ protected OAuth2AccessToken retrieveAccessToken(String accessToken) {
+ return currentAccessToken;
+ }
+
+ @Override
+ protected TokenAuthenticationCache cache() {
+ return authenticationCache; // Return the mocked cache
+ }
+
+ @Override
+ protected void updateAuthToken(
+ String oldAccessToken,
+ OAuth2AccessToken newAccessToken,
+ OAuth2RefreshToken newRefreshToken,
+ OAuth2Configuration configuration) {
+ // Update the currentAccessToken to the newAccessToken
+ this.currentAccessToken = newAccessToken;
+
+ // Simulate updating the authentication in the cache
+ Authentication newAuthentication = mock(Authentication.class);
+ TokenDetails newTokenDetails = mock(TokenDetails.class);
+ when(newTokenDetails.getAccessToken()).thenReturn(newAccessToken);
+ when(OAuth2Utils.getTokenDetails(newAuthentication)).thenReturn(newTokenDetails);
+
+ // Remove the old token from the cache
+ authenticationCache.removeEntry(oldAccessToken);
+
+ // Add the new token to the cache
+ authenticationCache.putCacheEntry(newAccessToken.getValue(), newAuthentication);
+ }
+ }
+}