Skip to content

Commit

Permalink
feat: Added unit tests for Q FeatureDev planning phase (#4235)
Browse files Browse the repository at this point in the history
* feat: Added unit tests for FeatureDev planning phase
  • Loading branch information
kumsmrit authored Apr 9, 2024
1 parent 804dca8 commit 4a9a50c
Show file tree
Hide file tree
Showing 7 changed files with 450 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ import java.util.UUID

class FeatureDevController(
private val context: AmazonQAppInitContext,
private val chatSessionStorage: ChatSessionStorage
private val chatSessionStorage: ChatSessionStorage,
private val authController: AuthController = AuthController()
) : InboundAppMessagesHandler {

private val authController = AuthController()
private val messenger = context.messagesFromAppToUi
private val toolWindow = ToolWindowManager.getInstance(context.project).getToolWindow(AmazonQToolWindowFactory.WINDOW_ID)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import org.mockito.kotlin.doReturn
import org.mockito.kotlin.mock
import org.mockito.kotlin.spy
import org.mockito.kotlin.whenever
import software.amazon.awssdk.awscore.DefaultAwsResponseMetadata
import software.amazon.awssdk.awscore.util.AwsHeader
import software.amazon.awssdk.services.codewhispererruntime.model.CodeGenerationStatus
import software.amazon.awssdk.services.codewhispererruntime.model.CreateTaskAssistConversationResponse
import software.amazon.awssdk.services.codewhispererruntime.model.CreateUploadUrlResponse
Expand All @@ -30,6 +32,7 @@ import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.sso.AccessToken
import software.aws.toolkits.jetbrains.core.credentials.sso.bearer.BearerTokenProvider
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.clients.FeatureDevClient
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.clients.GenerateTaskAssistPlanResult
import software.aws.toolkits.jetbrains.utils.rules.CodeInsightTestFixtureRule
import software.aws.toolkits.jetbrains.utils.rules.HeavyJavaCodeInsightTestFixtureRule
import software.aws.toolkits.jetbrains.utils.rules.JavaCodeInsightTestFixtureRule
Expand All @@ -50,26 +53,36 @@ open class FeatureDevTestBase(
internal lateinit var clientAdaptorSpy: FeatureDevClient
internal lateinit var toolkitConnectionManager: ToolkitConnectionManager

internal val testRequestId = "test_aws_request_id"
internal val testConversationId = "1234"
internal val userMessage = "test-user-message"
internal val testChecksumSha = "test-sha"
internal val testContentLength: Long = 40

internal val exampleCreateTaskAssistConversationResponse = CreateTaskAssistConversationResponse.builder()
.conversationId(testConversationId)
.responseMetadata(DefaultAwsResponseMetadata.create(mapOf(AwsHeader.AWS_REQUEST_ID to testRequestId)))
.build() as CreateTaskAssistConversationResponse

internal val exampleCreateUploadUrlResponse = CreateUploadUrlResponse.builder()
.uploadUrl("https://smth.com")
.uploadId("1234")
.kmsKeyArn("0000000000000000000000000000000000:key/1234abcd")
.responseMetadata(DefaultAwsResponseMetadata.create(mapOf(AwsHeader.AWS_REQUEST_ID to testRequestId)))
.build() as CreateUploadUrlResponse

internal val exampleGenerateTaskAssistPlanResult = GenerateTaskAssistPlanResult(approach = "Generated approach for plan", succeededPlanning = true)

internal val exampleStartTaskAssistConversationResponse = StartTaskAssistCodeGenerationResponse.builder()
.conversationId(testConversationId)
.codeGenerationId("1234")
.responseMetadata(DefaultAwsResponseMetadata.create(mapOf(AwsHeader.AWS_REQUEST_ID to testRequestId)))
.build() as StartTaskAssistCodeGenerationResponse

internal val exampleGetTaskAssistConversationResponse = GetTaskAssistCodeGenerationResponse.builder()
.conversationId(testConversationId)
.codeGenerationStatus(CodeGenerationStatus.builder().status("InitialCodeGeneration").currentStage("InProgress").build())
.responseMetadata(DefaultAwsResponseMetadata.create(mapOf(AwsHeader.AWS_REQUEST_ID to testRequestId)))
.build() as GetTaskAssistCodeGenerationResponse

internal val exampleExportResultArchiveResponse = mutableListOf(byteArrayOf(100))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package software.aws.toolkits.jetbrains.services.amazonqFeatureDev.controller

import com.intellij.testFramework.RuleChain
import com.intellij.testFramework.replaceService
import io.mockk.every
import io.mockk.just
import io.mockk.mockkStatic
import io.mockk.runs
import kotlinx.coroutines.test.runTest
import org.assertj.core.api.Assertions.assertThat
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.mockito.kotlin.any
import org.mockito.kotlin.doReturn
import org.mockito.kotlin.mock
import org.mockito.kotlin.spy
import org.mockito.kotlin.times
import org.mockito.kotlin.verify
import org.mockito.kotlin.whenever
import software.aws.toolkits.jetbrains.services.amazonq.apps.AmazonQAppInitContext
import software.aws.toolkits.jetbrains.services.amazonq.auth.AuthController
import software.aws.toolkits.jetbrains.services.amazonq.auth.AuthNeededStates
import software.aws.toolkits.jetbrains.services.amazonq.messages.MessagePublisher
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.FeatureDevTestBase
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.clients.FeatureDevClient
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.messages.IncomingFeatureDevMessage
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.session.RefinementState
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.session.Session
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.storage.ChatSessionStorage
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.util.uploadArtifactToS3

class FeatureDevControllerTest : FeatureDevTestBase() {
@Rule
@JvmField
val ruleChain = RuleChain(projectRule, disposableRule)

private lateinit var controller: FeatureDevController
private lateinit var messenger: MessagePublisher
private lateinit var chatSessionStorage: ChatSessionStorage
private lateinit var appContext: AmazonQAppInitContext
private lateinit var authController: AuthController
private lateinit var spySession: Session
private lateinit var featureDevClient: FeatureDevClient

private val tabId = "tabId"

@Before
override fun setup() {
super.setup()
featureDevClient = mock()
messenger = mock()
chatSessionStorage = mock()
appContext = mock<AmazonQAppInitContext> {
on { project }.thenReturn(project)
on { messagesFromAppToUi }.thenReturn(messenger)
}
authController = spy(AuthController())
mockkStatic("software.aws.toolkits.jetbrains.services.amazonqFeatureDev.util.UploadArtifactKt")
every { uploadArtifactToS3(any(), any(), any(), any(), any()) } just runs

controller = FeatureDevController(appContext, chatSessionStorage, authController)
}

@Test
fun `test new tab opened`() {
val message = IncomingFeatureDevMessage.NewTabCreated("new-tab-created", tabId)
spySession = spy(Session("tabId", project))
whenever(chatSessionStorage.getSession(any(), any())).thenReturn(spySession)

runTest {
controller.processNewTabCreatedMessage(message)
}
verify(authController, times(1)).getAuthNeededStates(project)
verify(chatSessionStorage, times(1)).getSession(tabId, project)
assertThat(spySession.isAuthenticating).isTrue()
}

@Test
fun `test handle chat for planning phase`() {
val testAuth = AuthNeededStates(amazonQ = null)
val message: IncomingFeatureDevMessage.ChatPrompt = IncomingFeatureDevMessage.ChatPrompt(userMessage, "chat-prompt", tabId)
projectRule.project.replaceService(FeatureDevClient::class.java, featureDevClient, disposableRule.disposable)
spySession = spy(Session("tabId", project))

doReturn(testAuth).`when`(authController).getAuthNeededStates(any())
whenever(chatSessionStorage.getSession(any(), any())).thenReturn(spySession)
whenever(featureDevClient.createTaskAssistConversation()).thenReturn(exampleCreateTaskAssistConversationResponse)
whenever(featureDevClient.createTaskAssistUploadUrl(any(), any(), any())).thenReturn(exampleCreateUploadUrlResponse)

runTest {
whenever(featureDevClient.generateTaskAssistPlan(any(), any(), any())).thenReturn(exampleGenerateTaskAssistPlanResult)

controller.processPromptChatMessage(message)

verify(spySession, times(1)).preloader(any(), any())
verify(spySession, times(1)).send(any())
assertThat(spySession.send(userMessage).content?.trim()).isEqualTo(exampleGenerateTaskAssistPlanResult.approach)
assertThat(spySession.send(userMessage).interactionSucceeded).isTrue()
}
verify(authController, times(1)).getAuthNeededStates(any())
assertThat(spySession.sessionState).isInstanceOf(RefinementState::class.java)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package software.aws.toolkits.jetbrains.services.amazonqFeatureDev.session

import com.intellij.testFramework.RuleChain
import io.mockk.every
import io.mockk.just
import io.mockk.mockkStatic
import io.mockk.runs
import kotlinx.coroutines.test.runTest
import org.assertj.core.api.Assertions.assertThat
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.mockito.kotlin.any
import org.mockito.kotlin.mock
import org.mockito.kotlin.times
import org.mockito.kotlin.verify
import org.mockito.kotlin.whenever
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.FeatureDevTestBase
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.clients.FeatureDevClient
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.model.ZipCreationResult
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.util.uploadArtifactToS3
import java.io.File

class PrepareRefinementStateTest : FeatureDevTestBase() {
@Rule
@JvmField
val ruleChain = RuleChain(projectRule, disposableRule)

private lateinit var prepareRefinementState: PrepareRefinementState
private lateinit var repoContext: FeatureDevSessionContext
private lateinit var sessionStateConfig: SessionStateConfig
private lateinit var featureDevClient: FeatureDevClient

@Before
override fun setup() {
repoContext = mock()
featureDevClient = mock()
sessionStateConfig = SessionStateConfig(testConversationId, featureDevClient, repoContext)
prepareRefinementState = PrepareRefinementState("", "tabId", sessionStateConfig)
mockkStatic("software.aws.toolkits.jetbrains.services.amazonqFeatureDev.util.UploadArtifactKt")
every { uploadArtifactToS3(any(), any(), any(), any(), any()) } just runs
}

@Test
fun `test interact`() {
val mockFile: File = mock()
val repoZipResult = ZipCreationResult(mockFile, testChecksumSha, testContentLength)
val action = SessionStateAction("test-task", userMessage)

whenever(repoContext.getProjectZip()).thenReturn(repoZipResult)
whenever(featureDevClient.createTaskAssistUploadUrl(testConversationId, testChecksumSha, testContentLength)).thenReturn(exampleCreateUploadUrlResponse)

runTest {
whenever(
featureDevClient.generateTaskAssistPlan(testConversationId, exampleCreateUploadUrlResponse.uploadId(), userMessage)
).thenReturn(exampleGenerateTaskAssistPlanResult)

val actual = prepareRefinementState.interact(action)
assertThat(actual.nextState).isInstanceOf(RefinementState::class.java)
}
assertThat(prepareRefinementState.phase).isEqualTo(SessionStatePhase.APPROACH)
verify(repoContext, times(1)).getProjectZip()
verify(featureDevClient, times(1)).createTaskAssistUploadUrl(any(), any(), any())
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package software.aws.toolkits.jetbrains.services.amazonqFeatureDev.session

import com.intellij.testFramework.RuleChain
import kotlinx.coroutines.test.runTest
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.mockito.kotlin.any
import org.mockito.kotlin.mock
import org.mockito.kotlin.times
import org.mockito.kotlin.verify
import org.mockito.kotlin.whenever
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.FeatureDevException
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.FeatureDevTestBase
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.clients.FeatureDevClient
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.clients.GenerateTaskAssistPlanResult
import software.aws.toolkits.resources.message

class RefinementStateTest : FeatureDevTestBase() {
@Rule
@JvmField
val ruleChain = RuleChain(projectRule, disposableRule)

private lateinit var refinementState: RefinementState
private lateinit var sessionStateConfig: SessionStateConfig
private lateinit var featureDevClient: FeatureDevClient
private lateinit var repoContext: FeatureDevSessionContext

private val action = SessionStateAction("test-task", userMessage)

@Before
override fun setup() {
featureDevClient = mock()
repoContext = mock()
sessionStateConfig = SessionStateConfig(testConversationId, featureDevClient, repoContext)
refinementState = RefinementState("", "tabId", sessionStateConfig, exampleCreateUploadUrlResponse.uploadId(), 0)
}

@Test
fun `test refinement state with no userMssg`() {
val actionNoMssg = SessionStateAction("test-task", "")
assertThatThrownBy {
runTest {
refinementState.interact(actionNoMssg)
}
}.isInstanceOf(FeatureDevException::class.java).hasMessage(message("amazonqFeatureDev.exception.message_not_found"))
}

@Test
fun `test refinement state with successful approach`() = runTest {
whenever(
featureDevClient.generateTaskAssistPlan(testConversationId, exampleCreateUploadUrlResponse.uploadId(), userMessage)
).thenReturn(exampleGenerateTaskAssistPlanResult)

val actual = refinementState.interact(action)
assertThat(actual.nextState).isInstanceOf(RefinementState::class.java)
assertThat(actual.interaction.interactionSucceeded).isTrue()

verify(featureDevClient, times(1)).generateTaskAssistPlan(any(), any(), any())
assertThat(refinementState.phase).isEqualTo(SessionStatePhase.APPROACH)
assertThat(refinementState.approach).isEqualTo(exampleGenerateTaskAssistPlanResult.approach)
}

@Test
fun `test refinement state with failed approach`() = runTest {
val generateTaskAssistPlanResult = GenerateTaskAssistPlanResult(approach = "There has been a problem generating approach", succeededPlanning = false)
whenever(
featureDevClient.generateTaskAssistPlan(testConversationId, exampleCreateUploadUrlResponse.uploadId(), userMessage)
).thenReturn(generateTaskAssistPlanResult)

val actual = refinementState.interact(action)
assertThat(actual.nextState).isInstanceOf(RefinementState::class.java)
assertThat(actual.interaction.interactionSucceeded).isFalse()

verify(featureDevClient, times(1)).generateTaskAssistPlan(any(), any(), any())
assertThat(refinementState.phase).isEqualTo(SessionStatePhase.APPROACH)
assertThat(refinementState.approach).isEqualTo(generateTaskAssistPlanResult.approach)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package software.aws.toolkits.jetbrains.services.amazonqFeatureDev.session

import com.intellij.testFramework.RuleChain
import com.intellij.testFramework.replaceService
import kotlinx.coroutines.test.runTest
import org.assertj.core.api.Assertions.assertThat
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.mockito.kotlin.mock
import org.mockito.kotlin.times
import org.mockito.kotlin.verify
import org.mockito.kotlin.whenever
import software.aws.toolkits.jetbrains.services.amazonq.messages.MessagePublisher
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.FeatureDevTestBase
import software.aws.toolkits.jetbrains.services.amazonqFeatureDev.clients.FeatureDevClient

class SessionTest : FeatureDevTestBase() {
@Rule
@JvmField
val ruleChain = RuleChain(projectRule, disposableRule)

private lateinit var featureDevClient: FeatureDevClient
private lateinit var session: Session
private lateinit var messenger: MessagePublisher

@Before
override fun setup() {
featureDevClient = mock()
projectRule.project.replaceService(FeatureDevClient::class.java, featureDevClient, disposableRule.disposable)
session = Session("tabId", projectRule.project)
messenger = mock()
}

@Test
fun `test session before preloader`() {
assertThat(session.sessionState).isInstanceOf(ConversationNotStartedState::class.java)
assertThat(session.isAuthenticating).isFalse()
}

@Test
fun `test preloader`() = runTest {
whenever(featureDevClient.createTaskAssistConversation()).thenReturn(exampleCreateTaskAssistConversationResponse)

session.preloader(userMessage, messenger)
assertThat(session.conversationId).isEqualTo(testConversationId)
assertThat(session.sessionState).isInstanceOf(PrepareRefinementState::class.java)
verify(featureDevClient, times(1)).createTaskAssistConversation()
}
}
Loading

0 comments on commit 4a9a50c

Please sign in to comment.