diff --git a/src/modules/rest/impl/pom.xml b/src/modules/rest/impl/pom.xml index de722bc7..f278bb81 100644 --- a/src/modules/rest/impl/pom.xml +++ b/src/modules/rest/impl/pom.xml @@ -255,6 +255,22 @@ test + + + org.junit.jupiter + junit-jupiter + 5.9.1 + test + + + + + org.mockito + mockito-junit-jupiter + 4.8.0 + test + + @@ -288,8 +304,15 @@ --> + + + maven-surefire-plugin + 3.0.0-M7 + + false + + - diff --git a/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/impl/RESTSessionServiceImpl.java b/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/impl/RESTSessionServiceImpl.java index 99f04981..16055016 100644 --- a/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/impl/RESTSessionServiceImpl.java +++ b/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/impl/RESTSessionServiceImpl.java @@ -189,7 +189,9 @@ private SessionToken toSessionToken(String accessToken, UserSession sessionToken @Override public SessionToken refresh(SecurityContext sc, String sessionId, String refreshToken) { String provider = - (String) RequestContextHolder.getRequestAttributes().getAttribute(PROVIDER_KEY, 0); + (String) + Objects.requireNonNull(RequestContextHolder.getRequestAttributes()) + .getAttribute(PROVIDER_KEY, 0); SessionServiceDelegate delegate = getDelegate(provider); return delegate.refresh(refreshToken, sessionId); } diff --git a/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/OAuth2SessionServiceDelegate.java b/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/OAuth2SessionServiceDelegate.java index 8cbad1aa..1165d4c3 100644 --- a/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/OAuth2SessionServiceDelegate.java +++ b/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/OAuth2SessionServiceDelegate.java @@ -41,6 +41,7 @@ import java.io.IOException; import java.util.Date; import java.util.Map; +import java.util.Objects; import java.util.Optional; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -48,6 +49,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.springframework.http.*; +import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.OAuth2ClientContext; @@ -56,11 +58,13 @@ import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken; import org.springframework.security.oauth2.common.DefaultOAuth2RefreshToken; import org.springframework.security.oauth2.common.OAuth2AccessToken; +import org.springframework.security.oauth2.common.OAuth2RefreshToken; import org.springframework.security.oauth2.provider.authentication.OAuth2AuthenticationDetails; import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.client.HttpMessageConverterExtractor; +import org.springframework.web.client.RestClientException; import org.springframework.web.client.RestTemplate; import org.springframework.web.context.request.RequestContextHolder; @@ -81,6 +85,9 @@ public OAuth2SessionServiceDelegate( this.userService = userService; } + public OAuth2SessionServiceDelegate( + RestTemplate restTemplate, OAuth2Configuration configuration) {} + @Override public SessionToken refresh(String refreshToken, String accessToken) { HttpServletRequest request = getRequest(); @@ -90,25 +97,27 @@ public SessionToken refresh(String refreshToken, String accessToken) { throw new NotFoundWebEx("Either the accessToken or the refresh token are missing"); OAuth2AccessToken currentToken = retrieveAccessToken(accessToken); - Date expiresIn = currentToken.getExpiration(); - if (refreshToken == null || refreshToken.isEmpty()) - refreshToken = getParameterValue(REFRESH_TOKEN_PARAM, request); - Date fiveMinutesFromNow = fiveMinutesFromNow(); + String refreshTokenToUse = + currentToken.getRefreshToken() != null + && currentToken.getRefreshToken().getValue() != null + && !currentToken.getRefreshToken().getValue().isEmpty() + ? currentToken.getRefreshToken().getValue() + : refreshToken; + if (refreshTokenToUse == null || refreshTokenToUse.isEmpty()) + refreshTokenToUse = getParameterValue(REFRESH_TOKEN_PARAM, request); SessionToken sessionToken = null; OAuth2Configuration configuration = configuration(); if (configuration != null && configuration.isEnabled()) { - if ((expiresIn == null || fiveMinutesFromNow.after(expiresIn)) - && refreshToken != null) { - if (LOGGER.isDebugEnabled()) LOGGER.info("Going to refresh the token."); - try { - sessionToken = doRefresh(refreshToken, accessToken, configuration); - } catch (NullPointerException npe) { - LOGGER.error("Current configuration wasn't correctly initialized."); - } + if (LOGGER.isDebugEnabled()) LOGGER.info("Going to refresh the token."); + try { + sessionToken = doRefresh(refreshTokenToUse, accessToken, configuration); + } catch (NullPointerException npe) { + LOGGER.error("Current configuration wasn't correctly initialized."); } } if (sessionToken == null) - sessionToken = sessionToken(accessToken, refreshToken, currentToken.getExpiration()); + sessionToken = + sessionToken(accessToken, refreshTokenToUse, currentToken.getExpiration()); request.setAttribute( OAuth2AuthenticationDetails.ACCESS_TOKEN_VALUE, sessionToken.getAccessToken()); @@ -119,86 +128,122 @@ public SessionToken refresh(String refreshToken, String accessToken) { } /** - * Invokes the refresh endpoint and return a session token holding the updated tokens details. + * Invokes the refresh endpoint to get a new session token with updated token details. + * + *

This method attempts to refresh the session by exchanging the provided refresh token for a + * new access token. If the refresh token is invalid or the request fails after several retries, + * the session is cleared, and the user is redirected to the login page. * - * @param refreshToken the refresh token. - * @param accessToken the access token. - * @param configuration the OAuth2Configuration. - * @return the SessionToken. + * @param refreshToken the refresh token to use for obtaining new access and refresh tokens + * @param accessToken the current access token + * @param configuration the OAuth2Configuration containing client credentials and endpoint URI + * @return a SessionToken containing the new token details, or null if the refresh process + * failed */ protected SessionToken doRefresh( String refreshToken, String accessToken, OAuth2Configuration configuration) { SessionToken sessionToken = null; + int maxRetries = 3; + int attempt = 0; + boolean success = false; - RestTemplate restTemplate = new RestTemplate(); + // Setup HTTP headers and body for the request + // Use restTemplate() method to get RestTemplate instance + OAuth2RestTemplate restTemplate = restTemplate(); HttpHeaders headers = getHttpHeaders(accessToken, configuration); - MultiValueMap requestBody = new LinkedMultiValueMap<>(); requestBody.add("grant_type", "refresh_token"); requestBody.add("refresh_token", refreshToken); requestBody.add("client_secret", configuration.getClientSecret()); requestBody.add("client_id", configuration.getClientId()); - HttpEntity> requestEntity = new HttpEntity<>(requestBody, headers); - OAuth2AccessToken newToken = null; - try { - newToken = - restTemplate - .exchange( - configuration - .buildRefreshTokenURI(), // Use exchange method for POST - // request - HttpMethod.POST, - requestEntity, // Include request body - OAuth2AccessToken.class) - .getBody(); - } catch (Exception ex) { - LOGGER.error("Error trying to obtain a refresh token.", ex); - } + while (attempt < maxRetries && !success) { + attempt++; + LOGGER.info("Attempting to refresh token, attempt {} of {}", attempt, maxRetries); - if (refreshToken != null - && accessToken != null - && !refreshToken.isEmpty() - && !accessToken.isEmpty() - && newToken != null - && newToken.getValue() != null - && !newToken.getValue().isEmpty()) { - // update the Authentication - String newRefreshToken = - newToken.getRefreshToken() != null - && newToken.getRefreshToken().getValue() != null - && !newToken.getRefreshToken().getValue().isEmpty() - ? newToken.getRefreshToken().getValue() - : refreshToken; - updateAuthToken(accessToken, newToken, newRefreshToken, configuration); - sessionToken = - sessionToken(newToken.getValue(), refreshToken, newToken.getExpiration()); - } else if (accessToken != null) { - // update the Authentication - sessionToken = sessionToken(accessToken, refreshToken, null); - } else { - // the refresh token was invalid. let's clear the session and send a remote logout. - // then redirect to the login entry point. - LOGGER.info( - "Unable to refresh the token. The following request was performed: {}. Redirecting to login.", - configuration.buildRefreshTokenURI("offline")); - doLogout(null); try { - getResponse() - .sendRedirect( - "../../openid/" - + configuration.getProvider().toLowerCase() - + "/login"); - } catch (IOException e) { - LOGGER.error("Error while sending redirect to login service. ", e); - throw new RuntimeException(e); + ResponseEntity response = + restTemplate.exchange( + configuration.buildRefreshTokenURI(), + HttpMethod.POST, + requestEntity, + OAuth2AccessToken.class); + + if (response.getStatusCode().is2xxSuccessful()) { + OAuth2AccessToken newToken = response.getBody(); + if (newToken != null + && newToken.getValue() != null + && !newToken.getValue().isEmpty()) { + // Process and update the new token details + OAuth2RefreshToken newRefreshToken = newToken.getRefreshToken(); + OAuth2RefreshToken refreshTokenToUse = + (newRefreshToken != null && newRefreshToken.getValue() != null) + ? newRefreshToken + : new DefaultOAuth2RefreshToken(refreshToken); + + updateAuthToken(accessToken, newToken, refreshTokenToUse, configuration); + sessionToken = + sessionToken( + newToken.getValue(), + refreshTokenToUse.getValue(), + newToken.getExpiration()); + + LOGGER.info("Token refreshed successfully on attempt {}", attempt); + success = true; + } else { + LOGGER.warn("Received empty or null token on attempt {}", attempt); + } + } else if (response.getStatusCode().is4xxClientError()) { + // For client errors (e.g., 400, 401, 403), do not retry. + LOGGER.error( + "Client error occurred: {}. Stopping further attempts.", + response.getStatusCode()); + break; + } else { + // For server errors (5xx), continue retrying + LOGGER.warn("Server error occurred: {}. Retrying...", response.getStatusCode()); + } + } catch (RestClientException ex) { + LOGGER.error("Attempt {}: Error refreshing token: {}", attempt, ex.getMessage()); + if (attempt == maxRetries) { + LOGGER.error("Max retries reached. Unable to refresh token."); + } } } + + // Handle unsuccessful refresh + if (!success) { + handleRefreshFailure(accessToken, refreshToken, configuration); + } return sessionToken; } + /** + * Handles the refresh failure by clearing the session, logging out remotely, and redirecting to + * login. + * + * @param accessToken the current access token + * @param refreshToken the current refresh token + * @param configuration the OAuth2Configuration with endpoint details + */ + private void handleRefreshFailure( + String accessToken, String refreshToken, OAuth2Configuration configuration) { + LOGGER.info( + "Unable to refresh token after max retries. Clearing session and redirecting to login."); + doLogout(null); + + try { + String redirectUrl = + "../../openid/" + configuration.getProvider().toLowerCase() + "/login"; + getResponse().sendRedirect(redirectUrl); + } catch (IOException e) { + LOGGER.error("Error while sending redirect to login service: ", e); + throw new RuntimeException("Failed to redirect to login", e); + } + } + private static HttpHeaders getHttpHeaders( String accessToken, OAuth2Configuration configuration) { HttpHeaders headers = new HttpHeaders(); @@ -227,21 +272,20 @@ private SessionToken sessionToken(String accessToken, String refreshToken, Date // Builds an authentication instance out of the passed values. // Sets it to the cache and to the SecurityContext to be sure the new token is updates. - private Authentication updateAuthToken( + protected void updateAuthToken( String oldToken, OAuth2AccessToken newToken, - String refreshToken, + OAuth2RefreshToken refreshToken, OAuth2Configuration conf) { Authentication authentication = cache().get(oldToken); if (authentication == null) authentication = SecurityContextHolder.getContext().getAuthentication(); - if (authentication instanceof PreAuthenticatedAuthenticationToken) { - if (LOGGER.isDebugEnabled()) - LOGGER.info("Updating the cache and the SecurityContext with new Auth details"); - String idToken = null; + if (LOGGER.isDebugEnabled()) + LOGGER.info("Updating the cache and the SecurityContext with new Auth details"); + if (authentication != null && !(authentication instanceof AnonymousAuthenticationToken)) { TokenDetails details = getTokenDetails(authentication); - idToken = details.getIdToken(); + String idToken = details.getIdToken(); cache().removeEntry(oldToken); PreAuthenticatedAuthenticationToken updated = new PreAuthenticatedAuthenticationToken( @@ -250,24 +294,23 @@ private Authentication updateAuthToken( authentication.getAuthorities()); DefaultOAuth2AccessToken accessToken = new DefaultOAuth2AccessToken(newToken); if (refreshToken != null) { - accessToken.setRefreshToken(new DefaultOAuth2RefreshToken(refreshToken)); + accessToken.setRefreshToken(refreshToken); } if (LOGGER.isDebugEnabled()) LOGGER.debug( - "Creating new details. AccessToken: " - + accessToken - + " IdToken: " - + idToken); + "Creating new details. AccessToken: {} IdToken: {}", accessToken, idToken); updated.setDetails(new TokenDetails(accessToken, idToken, conf.getBeanName())); cache().putCacheEntry(newToken.getValue(), updated); SecurityContextHolder.getContext().setAuthentication(updated); - authentication = updated; } - return authentication; } - private OAuth2AccessToken retrieveAccessToken(String accessToken) { - Authentication authentication = cache().get(accessToken); + protected TokenDetails getTokenDetails(Authentication authentication) { + return OAuth2Utils.getTokenDetails(authentication); + } + + protected OAuth2AccessToken retrieveAccessToken(String accessToken) { + Authentication authentication = cache() != null ? cache().get(accessToken) : null; OAuth2AccessToken result = null; if (authentication != null) { TokenDetails details = OAuth2Utils.getTokenDetails(authentication); @@ -284,6 +327,14 @@ private OAuth2AccessToken retrieveAccessToken(String accessToken) { return result; } + protected HttpServletRequest getRequest() { + return OAuth2Utils.getRequest(); + } + + protected HttpServletResponse getResponse() { + return OAuth2Utils.getResponse(); + } + @Override public void doLogout(String sessionId) { HttpServletRequest request = getRequest(); @@ -318,7 +369,7 @@ public void doLogout(String sessionId) { if (token == null) { token = (String) - RequestContextHolder.getRequestAttributes() + Objects.requireNonNull(RequestContextHolder.getRequestAttributes()) .getAttribute(REFRESH_TOKEN_PARAM, 0); } } @@ -333,7 +384,7 @@ public void doLogout(String sessionId) { if (accessToken == null) { accessToken = (String) - RequestContextHolder.getRequestAttributes() + Objects.requireNonNull(RequestContextHolder.getRequestAttributes()) .getAttribute(ACCESS_TOKEN_PARAM, 0); } } @@ -365,8 +416,10 @@ private void clearSession(OAuth2RestTemplate restTemplate, HttpServletRequest re .removePreservedState(accessTokenRequest.getStateKey()); } try { - accessTokenRequest.remove("access_token"); - accessTokenRequest.remove("refresh_token"); + if (accessTokenRequest != null) { + accessTokenRequest.remove("access_token"); + accessTokenRequest.remove("refresh_token"); + } request.logout(); } catch (ServletException e) { LOGGER.error("Error happened while doing request logout: ", e); @@ -445,9 +498,8 @@ protected void callRemoteLogout(String token, String accessToken) { protected void clearCookies(HttpServletRequest request, HttpServletResponse response) { javax.servlet.http.Cookie[] allCookies = request.getCookies(); - if (allCookies != null && allCookies.length > 0) - for (int i = 0; i < allCookies.length; i++) { - javax.servlet.http.Cookie toDelete = allCookies[i]; + if (allCookies != null) + for (javax.servlet.http.Cookie toDelete : allCookies) { if (deleteCookie(toDelete)) { toDelete.setMaxAge(-1); toDelete.setPath("/"); @@ -463,7 +515,7 @@ protected boolean deleteCookie(javax.servlet.http.Cookie c) { || c.getName().equalsIgnoreCase(REFRESH_TOKEN_PARAM); } - private TokenAuthenticationCache cache() { + protected TokenAuthenticationCache cache() { return GeoStoreContext.bean("oAuth2Cache", TokenAuthenticationCache.class); } diff --git a/src/modules/rest/impl/src/test/java/it/geosolutions/geostore/rest/security/oauth2/openid_connect/RefreshTokenServiceTest.java b/src/modules/rest/impl/src/test/java/it/geosolutions/geostore/rest/security/oauth2/openid_connect/RefreshTokenServiceTest.java new file mode 100644 index 00000000..68d63ceb --- /dev/null +++ b/src/modules/rest/impl/src/test/java/it/geosolutions/geostore/rest/security/oauth2/openid_connect/RefreshTokenServiceTest.java @@ -0,0 +1,478 @@ +package it.geosolutions.geostore.rest.security.oauth2.openid_connect; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +import it.geosolutions.geostore.services.rest.model.SessionToken; +import it.geosolutions.geostore.services.rest.security.TokenAuthenticationCache; +import it.geosolutions.geostore.services.rest.security.oauth2.OAuth2Configuration; +import it.geosolutions.geostore.services.rest.security.oauth2.OAuth2SessionServiceDelegate; +import it.geosolutions.geostore.services.rest.security.oauth2.OAuth2Utils; +import it.geosolutions.geostore.services.rest.security.oauth2.TokenDetails; +import java.util.Date; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.*; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.springframework.http.*; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.client.OAuth2RestTemplate; +import org.springframework.security.oauth2.common.*; +import org.springframework.web.client.HttpClientErrorException; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; + +/** Test class for OAuth2SessionServiceDelegate. */ +class RefreshTokenServiceTest { + + private TestOAuth2SessionServiceDelegate serviceDelegate; + private OAuth2Configuration configuration; + private OAuth2RestTemplate restTemplate; + private MockHttpServletRequest mockRequest; + private MockHttpServletResponse mockResponse; + private DefaultOAuth2AccessToken mockOAuth2AccessToken; + + @Mock private TokenAuthenticationCache authenticationCache; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + + // Initialize mocks and dependencies + configuration = mock(OAuth2Configuration.class); + restTemplate = mock(OAuth2RestTemplate.class); + authenticationCache = mock(TokenAuthenticationCache.class); + + // Create an instance of the test subclass + serviceDelegate = spy(new TestOAuth2SessionServiceDelegate()); + serviceDelegate.setRestTemplate(restTemplate); + serviceDelegate.setConfiguration(configuration); + serviceDelegate.authenticationCache = authenticationCache; + + // Set up mock request and response + mockRequest = new MockHttpServletRequest(); + mockResponse = new MockHttpServletResponse(); + + // Set the RequestAttributes in RequestContextHolder + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(mockRequest)); + + // Mock configuration behavior + when(configuration.isEnabled()).thenReturn(true); + when(configuration.getClientId()).thenReturn("testClientId"); + when(configuration.getClientSecret()).thenReturn("testClientSecret"); + when(configuration.buildRefreshTokenURI()).thenReturn("https://example.com/oauth2/token"); + + // Mock the existing OAuth2AccessToken with a refresh token + mockOAuth2AccessToken = new DefaultOAuth2AccessToken("providedAccessToken"); + OAuth2RefreshToken mockRefreshToken = new DefaultOAuth2RefreshToken("existingRefreshToken"); + mockOAuth2AccessToken.setRefreshToken(mockRefreshToken); + mockOAuth2AccessToken.setExpiration(new Date(System.currentTimeMillis() + 3600 * 1000)); + + // Initialize currentAccessToken + serviceDelegate.currentAccessToken = mockOAuth2AccessToken; + + // Mock the Authentication object + Authentication mockAuthentication = mock(Authentication.class); + + // Mock TokenDetails + TokenDetails mockTokenDetails = mock(TokenDetails.class); + when(mockTokenDetails.getIdToken()).thenReturn("mockIdToken"); + + // Mock getTokenDetails(authentication) to return mockTokenDetails + doReturn(mockTokenDetails).when(serviceDelegate).getTokenDetails(mockAuthentication); + + // Ensure cache returns the mocked Authentication for oldToken + when(authenticationCache.get("providedAccessToken")).thenReturn(mockAuthentication); + + // Optionally, set up SecurityContextHolder + SecurityContext securityContext = mock(SecurityContext.class); + when(securityContext.getAuthentication()).thenReturn(mockAuthentication); + SecurityContextHolder.setContext(securityContext); + } + + @AfterEach + void tearDown() { + // Clear the RequestContextHolder after each test + RequestContextHolder.resetRequestAttributes(); + + // Close static mocks if any + // For Mockito 3.4.0 and above + Mockito.framework().clearInlineMocks(); + } + + @Test + void testRefreshWithValidTokens() { + // Arrange + String refreshToken = "providedRefreshToken"; + String accessToken = "providedAccessToken"; + + // Mock the RestTemplate exchange method to simulate a successful token refresh + DefaultOAuth2AccessToken newAccessToken = new DefaultOAuth2AccessToken("newAccessToken"); + OAuth2RefreshToken newRefreshToken = new DefaultOAuth2RefreshToken("newRefreshToken"); + newAccessToken.setRefreshToken(newRefreshToken); + newAccessToken.setExpiration( + new Date(System.currentTimeMillis() + 7200 * 1000)); // Expires in 2 hours + + ResponseEntity responseEntity = + new ResponseEntity<>(newAccessToken, HttpStatus.OK); + + when(restTemplate.exchange( + anyString(), + eq(HttpMethod.POST), + any(HttpEntity.class), + eq(OAuth2AccessToken.class))) + .thenReturn(responseEntity); + + // Act + SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken); + + // Assert + assertNotNull(sessionToken, "SessionToken should not be null"); + assertEquals( + "newAccessToken", sessionToken.getAccessToken(), "Access token should be updated"); + assertEquals( + "newRefreshToken", + sessionToken.getRefreshToken(), + "Refresh token should be updated"); + assertTrue( + sessionToken.getExpires() > System.currentTimeMillis(), + "Token expiration should be in the future"); + assertEquals("bearer", sessionToken.getTokenType(), "Token type should be 'bearer'"); + } + + @Test + void testRefreshWithInvalidRefreshToken() { + // Arrange + String refreshToken = "invalidRefreshToken"; + String accessToken = "providedAccessToken"; + + // Mock the RestTemplate exchange method to simulate a client error (400 Bad Request) + when(restTemplate.exchange( + anyString(), + eq(HttpMethod.POST), + any(HttpEntity.class), + eq(OAuth2AccessToken.class))) + .thenThrow(new HttpClientErrorException(HttpStatus.BAD_REQUEST)); + + // Act + SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken); + + // Assert + assertNotNull(sessionToken, "SessionToken should not be null even when refresh fails"); + assertEquals( + "providedAccessToken", + sessionToken.getAccessToken(), + "Access token should remain unchanged"); + assertEquals( + "existingRefreshToken", + sessionToken.getRefreshToken(), + "Refresh token should remain unchanged"); + } + + @Test + void testRefreshWithServerError() { + // Arrange + String refreshToken = "providedRefreshToken"; + String accessToken = "providedAccessToken"; + + // Mock the RestTemplate exchange method to simulate a server error (500 Internal Server + // Error) + when(restTemplate.exchange( + anyString(), + eq(HttpMethod.POST), + any(HttpEntity.class), + eq(OAuth2AccessToken.class))) + .thenThrow(new HttpServerErrorException(HttpStatus.INTERNAL_SERVER_ERROR)); + + // Act + SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken); + + // Assert + assertNotNull(sessionToken, "SessionToken should not be null even when refresh fails"); + assertEquals( + "providedAccessToken", + sessionToken.getAccessToken(), + "Access token should remain unchanged after server error"); + assertEquals( + "existingRefreshToken", + sessionToken.getRefreshToken(), + "Refresh token should remain unchanged after server error"); + // You can also verify that the method retried the expected number of times + verify(restTemplate, times(3)) + .exchange( + anyString(), + eq(HttpMethod.POST), + any(HttpEntity.class), + eq(OAuth2AccessToken.class)); + } + + @Test + void testRefreshWithNullResponse() { + // Arrange + String refreshToken = "providedRefreshToken"; + String accessToken = "providedAccessToken"; + + // Mock the RestTemplate exchange method to return a response with null body + ResponseEntity responseEntity = + new ResponseEntity<>(null, HttpStatus.OK); + when(restTemplate.exchange( + anyString(), + eq(HttpMethod.POST), + any(HttpEntity.class), + eq(OAuth2AccessToken.class))) + .thenReturn(responseEntity); + + // Act + SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken); + + // Assert + assertNotNull(sessionToken, "SessionToken should not be null even when response is null"); + assertEquals( + "providedAccessToken", + sessionToken.getAccessToken(), + "Access token should remain unchanged"); + assertEquals( + "existingRefreshToken", + sessionToken.getRefreshToken(), + "Refresh token should remain unchanged"); + } + + @Test + void testRefreshWhenConfigurationDisabled() { + // Arrange + String refreshToken = "providedRefreshToken"; + String accessToken = "providedAccessToken"; + + // Mock configuration to be disabled + when(configuration.isEnabled()).thenReturn(false); + + // Act + SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken); + + // Assert + assertNotNull( + sessionToken, "SessionToken should not be null when configuration is disabled"); + assertEquals( + "providedAccessToken", + sessionToken.getAccessToken(), + "Access token should remain unchanged"); + assertEquals( + "existingRefreshToken", + sessionToken.getRefreshToken(), + "Refresh token should remain unchanged"); + // Verify that no exchange was attempted + verify(restTemplate, never()) + .exchange( + anyString(), + any(HttpMethod.class), + any(HttpEntity.class), + eq(OAuth2AccessToken.class)); + } + + @Test + void testRefreshWithMissingAccessToken() { + // Arrange + String refreshToken = "providedRefreshToken"; + String accessToken = null; // Access token is missing + + // Act & Assert + Exception exception = + assertThrows( + RuntimeException.class, + () -> { + serviceDelegate.refresh(refreshToken, accessToken); + }); + + assertTrue( + exception + .getMessage() + .contains("Either the accessToken or the refresh token are missing"), + "Expected exception message"); + } + + @Test + void testRefreshWhenCacheReturnsNullAuthentication() { + // Arrange + String refreshToken = "providedRefreshToken"; + String accessToken = "providedAccessToken"; + + // Mock cache to return null + when(authenticationCache.get("providedAccessToken")).thenReturn(null); + + // Act + SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken); + + // Assert + assertNotNull( + sessionToken, + "SessionToken should not be null even when authentication is not found in cache"); + assertEquals( + "providedAccessToken", + sessionToken.getAccessToken(), + "Access token should remain unchanged"); + assertEquals( + "existingRefreshToken", + sessionToken.getRefreshToken(), + "Refresh token should remain unchanged"); + } + + @Test + void testRefreshWhenAuthenticationIsAnonymous() { + // Arrange + String refreshToken = "providedRefreshToken"; + String accessToken = "providedAccessToken"; + + // Mock an AnonymousAuthenticationToken + Authentication anonymousAuthentication = + mock( + org.springframework.security.authentication.AnonymousAuthenticationToken + .class); + when(authenticationCache.get("providedAccessToken")).thenReturn(anonymousAuthentication); + + // Act + SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken); + + // Assert + assertNotNull( + sessionToken, + "SessionToken should not be null even when authentication is anonymous"); + assertEquals( + "providedAccessToken", + sessionToken.getAccessToken(), + "Access token should remain unchanged"); + assertEquals( + "existingRefreshToken", + sessionToken.getRefreshToken(), + "Refresh token should remain unchanged"); + } + + @Test + void testRefreshWithExpiredAccessToken() { + // Arrange + String refreshToken = "providedRefreshToken"; + String accessToken = "expiredAccessToken"; + + // Set the current access token to be expired + mockOAuth2AccessToken.setExpiration( + new Date(System.currentTimeMillis() - 1000)); // Set expiration in the past + serviceDelegate.currentAccessToken = mockOAuth2AccessToken; + + // Mock the RestTemplate exchange method to simulate a successful token refresh + DefaultOAuth2AccessToken newAccessToken = new DefaultOAuth2AccessToken("newAccessToken"); + OAuth2RefreshToken newRefreshToken = new DefaultOAuth2RefreshToken("newRefreshToken"); + newAccessToken.setRefreshToken(newRefreshToken); + newAccessToken.setExpiration( + new Date(System.currentTimeMillis() + 7200 * 1000)); // Expires in 2 hours + + ResponseEntity responseEntity = + new ResponseEntity<>(newAccessToken, HttpStatus.OK); + + when(restTemplate.exchange( + anyString(), + eq(HttpMethod.POST), + any(HttpEntity.class), + eq(OAuth2AccessToken.class))) + .thenReturn(responseEntity); + + // Act + SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken); + + // Assert + assertNotNull(sessionToken, "SessionToken should not be null"); + assertEquals( + "newAccessToken", sessionToken.getAccessToken(), "Access token should be updated"); + assertEquals( + "newRefreshToken", + sessionToken.getRefreshToken(), + "Refresh token should be updated"); + assertTrue( + sessionToken.getExpires() > System.currentTimeMillis(), + "Token expiration should be in the future"); + } + + /** Test subclass of OAuth2SessionServiceDelegate for testing purposes. */ + class TestOAuth2SessionServiceDelegate extends OAuth2SessionServiceDelegate { + + private OAuth2RestTemplate restTemplate; + private OAuth2Configuration configuration; + private OAuth2AccessToken currentAccessToken; + protected TokenAuthenticationCache authenticationCache; + + public TestOAuth2SessionServiceDelegate() { + super(null, null); // Mocked dependencies + } + + public void setRestTemplate(OAuth2RestTemplate restTemplate) { + this.restTemplate = restTemplate; + } + + public void setConfiguration(OAuth2Configuration configuration) { + this.configuration = configuration; + } + + @Override + protected OAuth2RestTemplate restTemplate() { + return restTemplate; + } + + @Override + protected OAuth2Configuration configuration() { + return configuration; + } + + @Override + protected HttpServletRequest getRequest() { + return mockRequest; + } + + @Override + protected HttpServletResponse getResponse() { + return mockResponse; + } + + @Override + protected TokenDetails getTokenDetails(Authentication authentication) { + // This method is now mocked in the test setup using doReturn() + return super.getTokenDetails(authentication); + } + + @Override + protected OAuth2AccessToken retrieveAccessToken(String accessToken) { + return currentAccessToken; + } + + @Override + protected TokenAuthenticationCache cache() { + return authenticationCache; // Return the mocked cache + } + + @Override + protected void updateAuthToken( + String oldAccessToken, + OAuth2AccessToken newAccessToken, + OAuth2RefreshToken newRefreshToken, + OAuth2Configuration configuration) { + // Update the currentAccessToken to the newAccessToken + this.currentAccessToken = newAccessToken; + + // Simulate updating the authentication in the cache + Authentication newAuthentication = mock(Authentication.class); + TokenDetails newTokenDetails = mock(TokenDetails.class); + when(newTokenDetails.getAccessToken()).thenReturn(newAccessToken); + when(OAuth2Utils.getTokenDetails(newAuthentication)).thenReturn(newTokenDetails); + + // Remove the old token from the cache + authenticationCache.removeEntry(oldAccessToken); + + // Add the new token to the cache + authenticationCache.putCacheEntry(newAccessToken.getValue(), newAuthentication); + } + } +}