diff --git a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/DiskCache.kt b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/DiskCache.kt index fd38ef3dd9..e39a9f27c8 100644 --- a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/DiskCache.kt +++ b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/DiskCache.kt @@ -101,15 +101,15 @@ class DiskCache( clientRegistrationCache(ssoRegion).tryDeleteIfExists() } - override fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey): ClientRegistration? { - LOG.info { "loadClientRegistration for $cacheKey" } + override fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey, source: String): ClientRegistration? { + LOG.info { "loadClientRegistration:$source for $cacheKey" } val inputStream = clientRegistrationCache(cacheKey).tryInputStreamIfExists() if (inputStream == null) { val stage = LoadCredentialStage.ACCESS_FILE LOG.info { "Failed to load Client Registration: cache file does not exist" } AuthTelemetry.modifyConnection( action = "Load cache file", - source = "loadClientRegistration", + source = "loadClientRegistration:$source", result = Result.Failed, reason = "Failed to load Client Registration", reasonDesc = "Load Step:$stage failed. Cache file does not exist" diff --git a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/SsoAccessTokenProvider.kt b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/SsoAccessTokenProvider.kt index bb694f3042..7fd718a822 100644 --- a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/SsoAccessTokenProvider.kt +++ b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/SsoAccessTokenProvider.kt @@ -194,7 +194,7 @@ class SsoAccessTokenProvider( @Deprecated("Device authorization grant flow is deprecated") private fun registerDAGClient(): ClientRegistration { - loadDagClientRegistration()?.let { + loadDagClientRegistration(SourceOfLoadRegistration.REGISTER_CLIENT.toString())?.let { return it } @@ -235,7 +235,7 @@ class SsoAccessTokenProvider( } private fun registerPkceClient(): PKCEClientRegistration { - loadPkceClientRegistration()?.let { + loadPkceClientRegistration(SourceOfLoadRegistration.REGISTER_CLIENT.toString())?.let { return it } @@ -431,8 +431,8 @@ class SsoAccessTokenProvider( stageName = RefreshCredentialStage.LOAD_REGISTRATION val registration = try { when (currentToken) { - is DeviceAuthorizationGrantToken -> loadDagClientRegistration() - is PKCEAuthorizationGrantToken -> loadPkceClientRegistration() + is DeviceAuthorizationGrantToken -> loadDagClientRegistration(SourceOfLoadRegistration.REFRESH_TOKEN.toString()) + is PKCEAuthorizationGrantToken -> loadPkceClientRegistration(SourceOfLoadRegistration.REFRESH_TOKEN.toString()) } } catch (e: Exception) { val message = e.message ?: "$stageName: ${e::class.java.name}" @@ -505,6 +505,11 @@ class SsoAccessTokenProvider( } } + enum class SourceOfLoadRegistration { + REGISTER_CLIENT, + REFRESH_TOKEN, + } + private enum class RefreshCredentialStage { VALIDATE_REFRESH_TOKEN, LOAD_REGISTRATION, @@ -514,13 +519,13 @@ class SsoAccessTokenProvider( SAVE_TOKEN, } - private fun loadDagClientRegistration(): ClientRegistration? = - cache.loadClientRegistration(dagClientRegistrationCacheKey)?.let { + private fun loadDagClientRegistration(source: String): ClientRegistration? = + cache.loadClientRegistration(dagClientRegistrationCacheKey, source)?.let { return it } - private fun loadPkceClientRegistration(): PKCEClientRegistration? = - cache.loadClientRegistration(pkceClientRegistrationCacheKey)?.let { + private fun loadPkceClientRegistration(source: String): PKCEClientRegistration? = + cache.loadClientRegistration(pkceClientRegistrationCacheKey, source)?.let { return it as PKCEClientRegistration } diff --git a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/SsoCache.kt b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/SsoCache.kt index 4aa845e388..b3a7496935 100644 --- a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/SsoCache.kt +++ b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/core/credentials/sso/SsoCache.kt @@ -7,7 +7,7 @@ interface SsoCache { fun invalidateClientRegistration(ssoRegion: String) fun invalidateAccessToken(ssoUrl: String) - fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey): ClientRegistration? + fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey, source: String): ClientRegistration? fun saveClientRegistration(cacheKey: ClientRegistrationCacheKey, registration: ClientRegistration) fun invalidateClientRegistration(cacheKey: ClientRegistrationCacheKey) diff --git a/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/sso/DiskCacheTest.kt b/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/sso/DiskCacheTest.kt index bdb30af210..e5ac374d36 100644 --- a/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/sso/DiskCacheTest.kt +++ b/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/sso/DiskCacheTest.kt @@ -57,7 +57,8 @@ class DiskCacheTest { startUrl = ssoUrl, scopes = scopes, region = ssoRegion - ) + ), + "testSource" ) ).isNull() } @@ -71,7 +72,7 @@ class DiskCacheTest { ) cacheLocation.resolve("223224b6f0b4702c1a984be8284fe2c9d9718759.json").writeText("badData") - assertThat(sut.loadClientRegistration(key)).isNull() + assertThat(sut.loadClientRegistration(key, "testSource")).isNull() } @Test @@ -91,7 +92,7 @@ class DiskCacheTest { """.trimIndent() ) - assertThat(sut.loadClientRegistration(key)).isNull() + assertThat(sut.loadClientRegistration(key, "testSource")).isNull() } @Test @@ -112,7 +113,7 @@ class DiskCacheTest { """.trimIndent() ) - assertThat(sut.loadClientRegistration(key)).isNull() + assertThat(sut.loadClientRegistration(key, "testSource")).isNull() } @Test @@ -134,7 +135,7 @@ class DiskCacheTest { """.trimIndent() ) - assertThat(sut.loadClientRegistration(key)) + assertThat(sut.loadClientRegistration(key, "testSource")) .usingRecursiveComparison() .isEqualTo( DeviceAuthorizationClientRegistration( @@ -217,7 +218,7 @@ class DiskCacheTest { """.trimIndent() ) - assertThat(sut.loadClientRegistration(key)) + assertThat(sut.loadClientRegistration(key, "testSource")) .usingRecursiveComparison() .isEqualTo( PKCEClientRegistration( @@ -323,10 +324,10 @@ class DiskCacheTest { ) ) - assertThat(sut.loadClientRegistration(key1)) + assertThat(sut.loadClientRegistration(key1, "testSource")) .usingRecursiveComparison() .isEqualTo( - sut.loadClientRegistration(key2) + sut.loadClientRegistration(key2, "testSource") ) } @@ -350,11 +351,11 @@ class DiskCacheTest { region = ssoRegion ) - assertThat(sut.loadClientRegistration(key)).isNotNull() + assertThat(sut.loadClientRegistration(key, "testSource")).isNotNull() sut.invalidateClientRegistration(key) - assertThat(sut.loadClientRegistration(key)).isNull() + assertThat(sut.loadClientRegistration(key, "testSource")).isNull() assertThat(cacheFile).doesNotExist() } @@ -619,7 +620,7 @@ class DiskCacheTest { registration.setPosixFilePermissions(emptySet()) assertPosixPermissions(registration, "---------") - assertThat(sut.loadClientRegistration(key)).isNotNull() + assertThat(sut.loadClientRegistration(key, "testSource")).isNotNull() assertPosixPermissions(registration, "rw-------") } diff --git a/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/sso/SsoAccessTokenProviderTest.kt b/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/sso/SsoAccessTokenProviderTest.kt index 7291d561ab..5f6f267687 100644 --- a/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/sso/SsoAccessTokenProviderTest.kt +++ b/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/sso/SsoAccessTokenProviderTest.kt @@ -124,7 +124,7 @@ class SsoAccessTokenProviderTest { verify(ssoOidcClient).startDeviceAuthorization(any()) verify(ssoOidcClient).createToken(any()) verify(ssoCache).loadAccessToken(argThat { startUrl == ssoUrl }) - verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }) + verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }, any()) verify(ssoCache).saveAccessToken(argThat { startUrl == ssoUrl }, eq(accessToken)) } @@ -170,7 +170,7 @@ class SsoAccessTokenProviderTest { verify(ssoOidcClient).startDeviceAuthorization(any()) verify(ssoOidcClient).createToken(any()) verify(ssoCache).loadAccessToken(argThat { startUrl == ssoUrl }) - verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }) + verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }, any()) verify(ssoCache).saveClientRegistration(argThat { region == ssoRegion }, any()) verify(ssoCache).saveAccessToken(argThat { startUrl == ssoUrl }, eq(accessToken)) } @@ -267,7 +267,7 @@ class SsoAccessTokenProviderTest { verify(ssoOidcClient).startDeviceAuthorization(any()) verify(ssoOidcClient, times(2)).createToken(any()) verify(ssoCache).loadAccessToken(argThat { startUrl == ssoUrl }) - verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }) + verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }, any()) verify(ssoCache).saveAccessToken(argThat { startUrl == ssoUrl }, eq(accessToken)) } @@ -296,7 +296,7 @@ class SsoAccessTokenProviderTest { val refreshedToken = runBlocking { sut.refreshToken(sut.accessToken()) } verify(ssoCache).loadAccessToken(argThat { startUrl == ssoUrl }) - verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }) + verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }, any()) verify(ssoOidcClient).createToken(any()) verify(ssoCache).saveAccessToken(argThat { startUrl == ssoUrl }, eq(refreshedToken)) } @@ -342,7 +342,7 @@ class SsoAccessTokenProviderTest { ) on( - ssoCache.loadClientRegistration(any()) + ssoCache.loadClientRegistration(any(), any()) ).thenReturn( PKCEClientRegistration( clientType = "public", @@ -369,7 +369,7 @@ class SsoAccessTokenProviderTest { val refreshedToken = runBlocking { sut.refreshToken(sut.accessToken()) } verify(ssoCache).loadAccessToken(any()) - verify(ssoCache).loadClientRegistration(any()) + verify(ssoCache).loadClientRegistration(any(), any()) verify(ssoOidcClient).createToken(any()) verify(ssoCache).saveAccessToken(any(), eq(refreshedToken)) } @@ -390,7 +390,7 @@ class SsoAccessTokenProviderTest { verify(ssoOidcClient).startDeviceAuthorization(any()) verify(ssoOidcClient).createToken(any()) verify(ssoCache).loadAccessToken(argThat { startUrl == ssoUrl }) - verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }) + verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }, any()) } @Test @@ -432,7 +432,7 @@ class SsoAccessTokenProviderTest { verify(ssoOidcClient).startDeviceAuthorization(any()) verify(ssoOidcClient, times(2)).createToken(any()) verify(ssoCache).loadAccessToken(argThat { startUrl == ssoUrl }) - verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }) + verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }, any()) verify(ssoCache).saveAccessToken(argThat { startUrl == ssoUrl }, eq(accessToken)) } @@ -452,7 +452,7 @@ class SsoAccessTokenProviderTest { verify(ssoOidcClient).registerClient(any()) verify(ssoCache).loadAccessToken(any()) - verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }) + verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }, any()) } @Test @@ -492,7 +492,7 @@ class SsoAccessTokenProviderTest { ) on( - ssoCache.loadClientRegistration(argThat { region == ssoRegion }) + ssoCache.loadClientRegistration(argThat { region == ssoRegion }, any()) ).thenReturn( returnValue ) diff --git a/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/InteractiveBearerTokenProviderTest.kt b/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/InteractiveBearerTokenProviderTest.kt index 4ca91a5fad..ca5661b1f2 100644 --- a/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/InteractiveBearerTokenProviderTest.kt +++ b/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/core/credentials/sso/bearer/InteractiveBearerTokenProviderTest.kt @@ -16,6 +16,7 @@ import org.junit.jupiter.api.assertThrows import org.mockito.Mockito import org.mockito.kotlin.any import org.mockito.kotlin.argThat +import org.mockito.kotlin.eq import org.mockito.kotlin.mock import org.mockito.kotlin.spy import org.mockito.kotlin.times @@ -273,7 +274,7 @@ class InteractiveBearerTokenProviderTest { ) private fun stubClientRegistration() { - whenever(diskCache.loadClientRegistration(any())).thenReturn( + whenever(diskCache.loadClientRegistration(any(), eq("testSource"))).thenReturn( DeviceAuthorizationClientRegistration( "", "",