Skip to content

Commit

Permalink
config(amazonq): Add project context to inline completion (#4976)
Browse files Browse the repository at this point in the history
  • Loading branch information
Will-ShaoHua authored Oct 24, 2024
1 parent 4eb0b30 commit 41965f8
Show file tree
Hide file tree
Showing 12 changed files with 619 additions and 110 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type" : "bugfix",
"description" : "Update `@workspace` index when adding or deleting a file"
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,37 +13,56 @@ import com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo
import com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig
import com.github.tomakehurst.wiremock.http.Body
import com.github.tomakehurst.wiremock.junit.WireMockRule
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.project.Project
import com.intellij.testFramework.DisposableRule
import com.intellij.testFramework.replaceService
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.test.StandardTestDispatcher
import kotlinx.coroutines.test.TestScope
import kotlinx.coroutines.test.advanceUntilIdle
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.withContext
import org.assertj.core.api.Assertions.assertThat
import org.junit.Before
import org.junit.Rule
import org.junit.jupiter.api.assertThrows
import org.mockito.kotlin.any
import org.mockito.kotlin.doReturn
import org.mockito.kotlin.mock
import org.mockito.kotlin.spy
import org.mockito.kotlin.stub
import org.mockito.kotlin.times
import org.mockito.kotlin.verify
import org.mockito.kotlin.whenever
import software.aws.toolkits.jetbrains.core.coroutines.getCoroutineBgContext
import software.aws.toolkits.jetbrains.services.amazonq.project.EncoderServer
import software.aws.toolkits.jetbrains.services.amazonq.project.IndexRequest
import software.aws.toolkits.jetbrains.services.amazonq.project.IndexUpdateMode
import software.aws.toolkits.jetbrains.services.amazonq.project.InlineBm25Chunk
import software.aws.toolkits.jetbrains.services.amazonq.project.LspMessage
import software.aws.toolkits.jetbrains.services.amazonq.project.ProjectContextProvider
import software.aws.toolkits.jetbrains.services.amazonq.project.QueryChatRequest
import software.aws.toolkits.jetbrains.services.amazonq.project.QueryInlineCompletionRequest
import software.aws.toolkits.jetbrains.services.amazonq.project.RelevantDocument
import software.aws.toolkits.jetbrains.services.amazonq.project.UpdateIndexRequest
import software.aws.toolkits.jetbrains.settings.CodeWhispererSettings
import software.aws.toolkits.jetbrains.utils.rules.CodeInsightTestFixtureRule
import software.aws.toolkits.jetbrains.utils.rules.JavaCodeInsightTestFixtureRule
import java.net.ConnectException
import kotlin.test.Test

@OptIn(ExperimentalCoroutinesApi::class)
class ProjectContextProviderTest {
@Rule
@JvmField
val projectRule: CodeInsightTestFixtureRule = JavaCodeInsightTestFixtureRule()

@Rule
@JvmField
val disposableRule: DisposableRule = DisposableRule()

@Rule
@JvmField
val wireMock: WireMockRule = createMockServer()
Expand All @@ -56,21 +75,23 @@ class ProjectContextProviderTest {

private val mapper = jacksonObjectMapper()

private val dispatcher = StandardTestDispatcher()

@Before
fun setup() {
encoderServer = spy(EncoderServer(project))
encoderServer.stub { on { port } doReturn wireMock.port() }

sut = ProjectContextProvider(project, encoderServer, TestScope())
sut = ProjectContextProvider(project, encoderServer, TestScope(context = dispatcher))

// initialization
stubFor(any(urlPathEqualTo("/initialize")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response"))))

// build index
stubFor(any(urlPathEqualTo("/indexFiles")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response"))))
stubFor(any(urlPathEqualTo("/buildIndex")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response"))))

// update index
stubFor(any(urlPathEqualTo("/updateIndex")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response"))))
stubFor(any(urlPathEqualTo("/updateIndexV2")).willReturn(aResponse().withStatus(200).withResponseBody(Body("initialize response"))))

// query
stubFor(
Expand All @@ -80,6 +101,15 @@ class ProjectContextProviderTest {
.withResponseBody(Body(validQueryChatResponse))
)
)
stubFor(
any(urlPathEqualTo("/queryInlineProjectContext")).willReturn(
aResponse()
.withStatus(200)
.withResponseBody(
Body(validQueryInlineResponse)
)
)
)

stubFor(
any(urlPathEqualTo("/getUsage"))
Expand All @@ -92,32 +122,73 @@ class ProjectContextProviderTest {
}

@Test
fun `Lsp endpoint are correct`() {
fun `Lsp endpoint correctness`() {
assertThat(LspMessage.Initialize.endpoint).isEqualTo("initialize")
assertThat(LspMessage.Index.endpoint).isEqualTo("indexFiles")
assertThat(LspMessage.Index.endpoint).isEqualTo("buildIndex")
assertThat(LspMessage.UpdateIndex.endpoint).isEqualTo("updateIndexV2")
assertThat(LspMessage.QueryChat.endpoint).isEqualTo("query")
assertThat(LspMessage.QueryInlineCompletion.endpoint).isEqualTo("queryInlineProjectContext")
assertThat(LspMessage.GetUsageMetrics.endpoint).isEqualTo("getUsage")
}

@Test
fun `index should send files within the project to lsp`() {
fun `index should send files within the project to lsp - vector index enabled`() {
ApplicationManager.getApplication().replaceService(
CodeWhispererSettings::class.java,
mock { on { isProjectContextEnabled() } doReturn true },
disposableRule.disposable
)

projectRule.fixture.addFileToProject("Foo.java", "foo")
projectRule.fixture.addFileToProject("Bar.java", "bar")
projectRule.fixture.addFileToProject("Baz.java", "baz")

sut.index()

val request = IndexRequest(listOf("/src/Foo.java", "/src/Bar.java", "/src/Baz.java"), "/src", "all", "")
assertThat(request.filePaths).hasSize(3)
assertThat(request.filePaths).satisfies({
it.contains("/src/Foo.java") &&
it.contains("/src/Baz.java") &&
it.contains("/src/Bar.java")
})
assertThat(request.config).isEqualTo("all")

wireMock.verify(
1,
postRequestedFor(urlPathEqualTo("/buildIndex"))
.withHeader("Content-Type", equalTo("text/plain"))
// comment it out because order matters and will cause json string different
// .withRequestBody(equalTo(encryptedRequest))
)
}

@Test
fun `index should send files within the project to lsp - vector index disabled`() {
ApplicationManager.getApplication().replaceService(
CodeWhispererSettings::class.java,
mock { on { isProjectContextEnabled() } doReturn false },
disposableRule.disposable
)

projectRule.fixture.addFileToProject("Foo.java", "foo")
projectRule.fixture.addFileToProject("Bar.java", "bar")
projectRule.fixture.addFileToProject("Baz.java", "baz")

sut.index()

val request = IndexRequest(listOf("/src/Foo.java", "/src/Bar.java", "/src/Baz.java"), "/src", false)
val request = IndexRequest(listOf("/src/Foo.java", "/src/Bar.java", "/src/Baz.java"), "/src", "default", "")
assertThat(request.filePaths).hasSize(3)
assertThat(request.filePaths).satisfies({
it.contains("/src/Foo.java") &&
it.contains("/src/Baz.java") &&
it.contains("/src/Bar.java")
})
assertThat(request.config).isEqualTo("default")

wireMock.verify(
1,
postRequestedFor(urlPathEqualTo("/indexFiles"))
postRequestedFor(urlPathEqualTo("/buildIndex"))
.withHeader("Content-Type", equalTo("text/plain"))
// comment it out because order matters and will cause json string different
// .withRequestBody(equalTo(encryptedRequest))
Expand All @@ -126,17 +197,17 @@ class ProjectContextProviderTest {

@Test
fun `updateIndex should send correct encrypted request to lsp`() {
sut.updateIndex("foo.java")
val request = UpdateIndexRequest("foo.java")
sut.updateIndex(listOf("foo.java"), IndexUpdateMode.UPDATE)
val request = UpdateIndexRequest(listOf("foo.java"), IndexUpdateMode.UPDATE.command)
val requestJson = mapper.writeValueAsString(request)

assertThat(mapper.readTree(requestJson)).isEqualTo(mapper.readTree("""{ "filePath": "foo.java" }"""))
assertThat(mapper.readTree(requestJson)).isEqualTo(mapper.readTree("""{ "filePaths": ["foo.java"], "mode": "update" }"""))

val encryptedRequest = encoderServer.encrypt(requestJson)

wireMock.verify(
1,
postRequestedFor(urlPathEqualTo("/updateIndex"))
postRequestedFor(urlPathEqualTo("/updateIndexV2"))
.withHeader("Content-Type", equalTo("text/plain"))
.withRequestBody(equalTo(encryptedRequest))
)
Expand All @@ -161,6 +232,26 @@ class ProjectContextProviderTest {
)
}

@Test
fun `queryInline should send correct encrypted request to lsp`() = runTest {
sut = ProjectContextProvider(project, encoderServer, this)
sut.queryInline("foo", "Foo.java")
advanceUntilIdle()

val request = QueryInlineCompletionRequest("foo", "Foo.java")
val requestJson = mapper.writeValueAsString(request)

assertThat(mapper.readTree(requestJson)).isEqualTo(mapper.readTree("""{ "query": "foo", "filePath": "Foo.java" }"""))

val encryptedRequest = encoderServer.encrypt(requestJson)
wireMock.verify(
1,
postRequestedFor(urlPathEqualTo("/queryInlineProjectContext"))
.withHeader("Content-Type", equalTo("text/plain"))
.withRequestBody(equalTo(encryptedRequest))
)
}

@Test
fun `query chat should return empty if result set non deserializable`() = runTest {
stubFor(
Expand Down Expand Up @@ -200,12 +291,92 @@ class ProjectContextProviderTest {
)
}

@Test
fun `query inline should throw if resultset not deserializable`() {
assertThrows<Exception> {
runTest {
sut = ProjectContextProvider(project, encoderServer, this)
stubFor(
any(urlPathEqualTo("/queryInlineProjectContext")).willReturn(
aResponse().withStatus(200).withResponseBody(
Body(
"""
[
"foo", "bar"
]
""".trimIndent()
)
)
)
)

assertThrows<Exception> {
sut.queryInline("foo", "filepath")
advanceUntilIdle()
}
}
}
}

@Test
fun `query inline should return deserialized bm25 chunks`() = runTest {
sut = ProjectContextProvider(project, encoderServer, this)
advanceUntilIdle()
val r = sut.queryInline("foo", "filepath")
assertThat(r).hasSize(3)
assertThat(r[0]).isEqualTo(
InlineBm25Chunk(
"content1",
"file1",
0.1
)
)
assertThat(r[1]).isEqualTo(
InlineBm25Chunk(
"content2",
"file2",
0.2
)
)
assertThat(r[2]).isEqualTo(
InlineBm25Chunk(
"content3",
"file3",
0.3
)
)
}

@Test
fun `get usage should return memory, cpu usage`() = runTest {
val r = sut.getUsage()
assertThat(r).isEqualTo(ProjectContextProvider.Usage(123, 456))
}

@Test
fun `queryInline should throw if time elapsed is greater than 50ms`() = runTest {
assertThrows<TimeoutCancellationException> {
sut = ProjectContextProvider(project, encoderServer, this)
stubFor(
any(urlPathEqualTo("/queryInlineProjectContext")).willReturn(
aResponse()
.withStatus(200)
.withResponseBody(
Body(validQueryInlineResponse)
)
.withFixedDelay(51) // 10 sec
)
)

// it won't throw if it's executed within TestDispatcher context
withContext(getCoroutineBgContext()) {
sut.queryInline("foo", "bar")
}

advanceUntilIdle()
}
}

@Test
fun `test index payload is encrypted`() = runTest {
whenever(encoderServer.port).thenReturn(3000)
Expand All @@ -231,6 +402,27 @@ class ProjectContextProviderTest {
private fun createMockServer() = WireMockRule(wireMockConfig().dynamicPort())
}

// language=JSON
val validQueryInlineResponse = """
[
{
"content": "content1",
"filePath": "file1",
"score": 0.1
},
{
"content": "content2",
"filePath": "file2",
"score": 0.2
},
{
"content": "content3",
"filePath": "file3",
"score": 0.3
}
]
""".trimIndent()

// language=JSON
val validQueryChatResponse = """
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ object CodeWhispererConstants {
const val POPUP_DELAY_CHECK_INTERVAL: Long = 25
const val IDLE_TIME_CHECK_INTERVAL: Long = 25
const val SUPPLEMENTAL_CONTEXT_TIMEOUT = 50L
const val SUPPLEMETAL_CONTEXT_BUFFER = 10L

val AWSTemplateKeyWordsRegex = Regex("(AWSTemplateFormatVersion|Resources|AWS::|Description)")
val AWSTemplateCaseInsensitiveKeyWordsRegex = Regex("(cloudformation|cfn|template|description)")
Expand Down
Loading

0 comments on commit 41965f8

Please sign in to comment.