Skip to content

Commit

Permalink
Add source to load registration metrics (#5075)
Browse files Browse the repository at this point in the history
* initial commit

* add source to loading metrics

* detekt

* fix tests

* fix more tests
  • Loading branch information
samgst-amazon authored Nov 12, 2024
1 parent 1b6357a commit 5f3cb69
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ class DiskCache(
clientRegistrationCache(ssoRegion).tryDeleteIfExists()
}

override fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey): ClientRegistration? {
LOG.info { "loadClientRegistration for $cacheKey" }
override fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey, source: String): ClientRegistration? {
LOG.info { "loadClientRegistration:$source for $cacheKey" }
val inputStream = clientRegistrationCache(cacheKey).tryInputStreamIfExists()
if (inputStream == null) {
val stage = LoadCredentialStage.ACCESS_FILE
LOG.info { "Failed to load Client Registration: cache file does not exist" }
AuthTelemetry.modifyConnection(
action = "Load cache file",
source = "loadClientRegistration",
source = "loadClientRegistration:$source",
result = Result.Failed,
reason = "Failed to load Client Registration",
reasonDesc = "Load Step:$stage failed. Cache file does not exist"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ class SsoAccessTokenProvider(

@Deprecated("Device authorization grant flow is deprecated")
private fun registerDAGClient(): ClientRegistration {
loadDagClientRegistration()?.let {
loadDagClientRegistration(SourceOfLoadRegistration.REGISTER_CLIENT.toString())?.let {
return it
}

Expand Down Expand Up @@ -235,7 +235,7 @@ class SsoAccessTokenProvider(
}

private fun registerPkceClient(): PKCEClientRegistration {
loadPkceClientRegistration()?.let {
loadPkceClientRegistration(SourceOfLoadRegistration.REGISTER_CLIENT.toString())?.let {
return it
}

Expand Down Expand Up @@ -431,8 +431,8 @@ class SsoAccessTokenProvider(
stageName = RefreshCredentialStage.LOAD_REGISTRATION
val registration = try {
when (currentToken) {
is DeviceAuthorizationGrantToken -> loadDagClientRegistration()
is PKCEAuthorizationGrantToken -> loadPkceClientRegistration()
is DeviceAuthorizationGrantToken -> loadDagClientRegistration(SourceOfLoadRegistration.REFRESH_TOKEN.toString())
is PKCEAuthorizationGrantToken -> loadPkceClientRegistration(SourceOfLoadRegistration.REFRESH_TOKEN.toString())
}
} catch (e: Exception) {
val message = e.message ?: "$stageName: ${e::class.java.name}"
Expand Down Expand Up @@ -505,6 +505,11 @@ class SsoAccessTokenProvider(
}
}

enum class SourceOfLoadRegistration {
REGISTER_CLIENT,
REFRESH_TOKEN,
}

private enum class RefreshCredentialStage {
VALIDATE_REFRESH_TOKEN,
LOAD_REGISTRATION,
Expand All @@ -514,13 +519,13 @@ class SsoAccessTokenProvider(
SAVE_TOKEN,
}

private fun loadDagClientRegistration(): ClientRegistration? =
cache.loadClientRegistration(dagClientRegistrationCacheKey)?.let {
private fun loadDagClientRegistration(source: String): ClientRegistration? =
cache.loadClientRegistration(dagClientRegistrationCacheKey, source)?.let {
return it
}

private fun loadPkceClientRegistration(): PKCEClientRegistration? =
cache.loadClientRegistration(pkceClientRegistrationCacheKey)?.let {
private fun loadPkceClientRegistration(source: String): PKCEClientRegistration? =
cache.loadClientRegistration(pkceClientRegistrationCacheKey, source)?.let {
return it as PKCEClientRegistration
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ interface SsoCache {
fun invalidateClientRegistration(ssoRegion: String)
fun invalidateAccessToken(ssoUrl: String)

fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey): ClientRegistration?
fun loadClientRegistration(cacheKey: ClientRegistrationCacheKey, source: String): ClientRegistration?
fun saveClientRegistration(cacheKey: ClientRegistrationCacheKey, registration: ClientRegistration)
fun invalidateClientRegistration(cacheKey: ClientRegistrationCacheKey)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ class DiskCacheTest {
startUrl = ssoUrl,
scopes = scopes,
region = ssoRegion
)
),
"testSource"
)
).isNull()
}
Expand All @@ -71,7 +72,7 @@ class DiskCacheTest {
)
cacheLocation.resolve("223224b6f0b4702c1a984be8284fe2c9d9718759.json").writeText("badData")

assertThat(sut.loadClientRegistration(key)).isNull()
assertThat(sut.loadClientRegistration(key, "testSource")).isNull()
}

@Test
Expand All @@ -91,7 +92,7 @@ class DiskCacheTest {
""".trimIndent()
)

assertThat(sut.loadClientRegistration(key)).isNull()
assertThat(sut.loadClientRegistration(key, "testSource")).isNull()
}

@Test
Expand All @@ -112,7 +113,7 @@ class DiskCacheTest {
""".trimIndent()
)

assertThat(sut.loadClientRegistration(key)).isNull()
assertThat(sut.loadClientRegistration(key, "testSource")).isNull()
}

@Test
Expand All @@ -134,7 +135,7 @@ class DiskCacheTest {
""".trimIndent()
)

assertThat(sut.loadClientRegistration(key))
assertThat(sut.loadClientRegistration(key, "testSource"))
.usingRecursiveComparison()
.isEqualTo(
DeviceAuthorizationClientRegistration(
Expand Down Expand Up @@ -217,7 +218,7 @@ class DiskCacheTest {
""".trimIndent()
)

assertThat(sut.loadClientRegistration(key))
assertThat(sut.loadClientRegistration(key, "testSource"))
.usingRecursiveComparison()
.isEqualTo(
PKCEClientRegistration(
Expand Down Expand Up @@ -323,10 +324,10 @@ class DiskCacheTest {
)
)

assertThat(sut.loadClientRegistration(key1))
assertThat(sut.loadClientRegistration(key1, "testSource"))
.usingRecursiveComparison()
.isEqualTo(
sut.loadClientRegistration(key2)
sut.loadClientRegistration(key2, "testSource")
)
}

Expand All @@ -350,11 +351,11 @@ class DiskCacheTest {
region = ssoRegion
)

assertThat(sut.loadClientRegistration(key)).isNotNull()
assertThat(sut.loadClientRegistration(key, "testSource")).isNotNull()

sut.invalidateClientRegistration(key)

assertThat(sut.loadClientRegistration(key)).isNull()
assertThat(sut.loadClientRegistration(key, "testSource")).isNull()
assertThat(cacheFile).doesNotExist()
}

Expand Down Expand Up @@ -619,7 +620,7 @@ class DiskCacheTest {
registration.setPosixFilePermissions(emptySet())
assertPosixPermissions(registration, "---------")

assertThat(sut.loadClientRegistration(key)).isNotNull()
assertThat(sut.loadClientRegistration(key, "testSource")).isNotNull()

assertPosixPermissions(registration, "rw-------")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class SsoAccessTokenProviderTest {
verify(ssoOidcClient).startDeviceAuthorization(any<StartDeviceAuthorizationRequest>())
verify(ssoOidcClient).createToken(any<CreateTokenRequest>())
verify(ssoCache).loadAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl })
verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion })
verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }, any<String>())
verify(ssoCache).saveAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl }, eq(accessToken))
}

Expand Down Expand Up @@ -170,7 +170,7 @@ class SsoAccessTokenProviderTest {
verify(ssoOidcClient).startDeviceAuthorization(any<StartDeviceAuthorizationRequest>())
verify(ssoOidcClient).createToken(any<CreateTokenRequest>())
verify(ssoCache).loadAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl })
verify(ssoCache).loadClientRegistration(argThat<DeviceAuthorizationClientRegistrationCacheKey> { region == ssoRegion })
verify(ssoCache).loadClientRegistration(argThat<DeviceAuthorizationClientRegistrationCacheKey> { region == ssoRegion }, any<String>())
verify(ssoCache).saveClientRegistration(argThat<DeviceAuthorizationClientRegistrationCacheKey> { region == ssoRegion }, any())
verify(ssoCache).saveAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl }, eq(accessToken))
}
Expand Down Expand Up @@ -267,7 +267,7 @@ class SsoAccessTokenProviderTest {
verify(ssoOidcClient).startDeviceAuthorization(any<StartDeviceAuthorizationRequest>())
verify(ssoOidcClient, times(2)).createToken(any<CreateTokenRequest>())
verify(ssoCache).loadAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl })
verify(ssoCache).loadClientRegistration(argThat<DeviceAuthorizationClientRegistrationCacheKey> { region == ssoRegion })
verify(ssoCache).loadClientRegistration(argThat<DeviceAuthorizationClientRegistrationCacheKey> { region == ssoRegion }, any<String>())
verify(ssoCache).saveAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl }, eq(accessToken))
}

Expand Down Expand Up @@ -296,7 +296,7 @@ class SsoAccessTokenProviderTest {
val refreshedToken = runBlocking { sut.refreshToken(sut.accessToken()) }

verify(ssoCache).loadAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl })
verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion })
verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }, any<String>())
verify(ssoOidcClient).createToken(any<CreateTokenRequest>())
verify(ssoCache).saveAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl }, eq(refreshedToken))
}
Expand Down Expand Up @@ -342,7 +342,7 @@ class SsoAccessTokenProviderTest {
)

on(
ssoCache.loadClientRegistration(any<PKCEClientRegistrationCacheKey>())
ssoCache.loadClientRegistration(any<PKCEClientRegistrationCacheKey>(), any<String>())
).thenReturn(
PKCEClientRegistration(
clientType = "public",
Expand All @@ -369,7 +369,7 @@ class SsoAccessTokenProviderTest {
val refreshedToken = runBlocking { sut.refreshToken(sut.accessToken()) }

verify(ssoCache).loadAccessToken(any<PKCEAccessTokenCacheKey>())
verify(ssoCache).loadClientRegistration(any<PKCEClientRegistrationCacheKey>())
verify(ssoCache).loadClientRegistration(any<PKCEClientRegistrationCacheKey>(), any<String>())
verify(ssoOidcClient).createToken(any<CreateTokenRequest>())
verify(ssoCache).saveAccessToken(any<PKCEAccessTokenCacheKey>(), eq(refreshedToken))
}
Expand All @@ -390,7 +390,7 @@ class SsoAccessTokenProviderTest {
verify(ssoOidcClient).startDeviceAuthorization(any<StartDeviceAuthorizationRequest>())
verify(ssoOidcClient).createToken(any<CreateTokenRequest>())
verify(ssoCache).loadAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl })
verify(ssoCache).loadClientRegistration(argThat<DeviceAuthorizationClientRegistrationCacheKey> { region == ssoRegion })
verify(ssoCache).loadClientRegistration(argThat<DeviceAuthorizationClientRegistrationCacheKey> { region == ssoRegion }, any<String>())
}

@Test
Expand Down Expand Up @@ -432,7 +432,7 @@ class SsoAccessTokenProviderTest {
verify(ssoOidcClient).startDeviceAuthorization(any<StartDeviceAuthorizationRequest>())
verify(ssoOidcClient, times(2)).createToken(any<CreateTokenRequest>())
verify(ssoCache).loadAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl })
verify(ssoCache).loadClientRegistration(argThat<DeviceAuthorizationClientRegistrationCacheKey> { region == ssoRegion })
verify(ssoCache).loadClientRegistration(argThat<DeviceAuthorizationClientRegistrationCacheKey> { region == ssoRegion }, any<String>())
verify(ssoCache).saveAccessToken(argThat<DeviceGrantAccessTokenCacheKey> { startUrl == ssoUrl }, eq(accessToken))
}

Expand All @@ -452,7 +452,7 @@ class SsoAccessTokenProviderTest {

verify(ssoOidcClient).registerClient(any<RegisterClientRequest>())
verify(ssoCache).loadAccessToken(any())
verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion })
verify(ssoCache).loadClientRegistration(argThat { region == ssoRegion }, any<String>())
}

@Test
Expand Down Expand Up @@ -492,7 +492,7 @@ class SsoAccessTokenProviderTest {
)

on(
ssoCache.loadClientRegistration(argThat { region == ssoRegion })
ssoCache.loadClientRegistration(argThat { region == ssoRegion }, any<String>())
).thenReturn(
returnValue
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import org.junit.jupiter.api.assertThrows
import org.mockito.Mockito
import org.mockito.kotlin.any
import org.mockito.kotlin.argThat
import org.mockito.kotlin.eq
import org.mockito.kotlin.mock
import org.mockito.kotlin.spy
import org.mockito.kotlin.times
Expand Down Expand Up @@ -273,7 +274,7 @@ class InteractiveBearerTokenProviderTest {
)

private fun stubClientRegistration() {
whenever(diskCache.loadClientRegistration(any<DeviceAuthorizationClientRegistrationCacheKey>())).thenReturn(
whenever(diskCache.loadClientRegistration(any<DeviceAuthorizationClientRegistrationCacheKey>(), eq("testSource"))).thenReturn(
DeviceAuthorizationClientRegistration(
"",
"",
Expand Down

0 comments on commit 5f3cb69

Please sign in to comment.