Skip to content

Commit

Permalink
feat(dev): Stop code generation action (#4948)
Browse files Browse the repository at this point in the history
Currently /dev don't support stop code generation. This PR introduces this functionality (watch the video below).

Mynah UI provides onStopChatResponse API which we can hook in the cancellation token provided on VS Code, sharing across an active session, aborting current progress.
  • Loading branch information
tverney authored Oct 23, 2024
1 parent 6a849a8 commit 8dc4b81
Show file tree
Hide file tree
Showing 32 changed files with 2,816 additions and 2,120 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type" : "feature",
"description" : "Amazon Q /dev: Add stop generation action"
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ class FeatureDevApp : AmazonQApp {
"response-body-link-click" to IncomingFeatureDevMessage.ClickedLink::class,
"insert_code_at_cursor_position" to IncomingFeatureDevMessage.InsertCodeAtCursorPosition::class,
"open-diff" to IncomingFeatureDevMessage.OpenDiff::class,
"file-click" to IncomingFeatureDevMessage.FileClicked::class
"file-click" to IncomingFeatureDevMessage.FileClicked::class,
"stop-response" to IncomingFeatureDevMessage.StopResponse::class
)

scope.launch {
Expand Down Expand Up @@ -82,6 +83,7 @@ class FeatureDevApp : AmazonQApp {
is IncomingFeatureDevMessage.InsertCodeAtCursorPosition -> inboundAppMessagesHandler.processInsertCodeAtCursorPosition(message)
is IncomingFeatureDevMessage.OpenDiff -> inboundAppMessagesHandler.processOpenDiff(message)
is IncomingFeatureDevMessage.FileClicked -> inboundAppMessagesHandler.processFileClicked(message)
is IncomingFeatureDevMessage.StopResponse -> inboundAppMessagesHandler.processStopMessage(message)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ interface InboundAppMessagesHandler {
suspend fun processInsertCodeAtCursorPosition(message: IncomingFeatureDevMessage.InsertCodeAtCursorPosition)
suspend fun processOpenDiff(message: IncomingFeatureDevMessage.OpenDiff)
suspend fun processFileClicked(message: IncomingFeatureDevMessage.FileClicked)
suspend fun processStopMessage(message: IncomingFeatureDevMessage.StopResponse)
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,107 +39,135 @@ import java.time.Instant
import software.amazon.awssdk.services.codewhispererruntime.model.ChatTriggerType as SyncChatTriggerType

@Service(Service.Level.PROJECT)
class FeatureDevClient(private val project: Project) {
class FeatureDevClient(
private val project: Project,
) {
fun getTelemetryOptOutPreference() =
if (AwsSettings.getInstance().isTelemetryEnabled) {
OptOutPreference.OPTIN
} else {
OptOutPreference.OPTOUT
}

private val featureDevUserContext = ClientMetadata.getDefault().let {
val osForFeatureDev: OperatingSystem =
when {
SystemInfo.isWindows -> OperatingSystem.WINDOWS
SystemInfo.isMac -> OperatingSystem.MAC
// For now, categorize everything else as "Linux" (Linux/FreeBSD/Solaris/etc.)
else -> OperatingSystem.LINUX
}
private val featureDevUserContext =
ClientMetadata.getDefault().let {
val osForFeatureDev: OperatingSystem =
when {
SystemInfo.isWindows -> OperatingSystem.WINDOWS
SystemInfo.isMac -> OperatingSystem.MAC
// For now, categorize everything else as "Linux" (Linux/FreeBSD/Solaris/etc.)
else -> OperatingSystem.LINUX
}

UserContext.builder()
.ideCategory(IdeCategory.JETBRAINS)
.operatingSystem(osForFeatureDev)
.product(FEATURE_EVALUATION_PRODUCT_NAME)
.clientId(it.clientId)
.ideVersion(it.awsVersion)
.build()
}
UserContext
.builder()
.ideCategory(IdeCategory.JETBRAINS)
.operatingSystem(osForFeatureDev)
.product(FEATURE_EVALUATION_PRODUCT_NAME)
.clientId(it.clientId)
.ideVersion(it.awsVersion)
.build()
}

private fun connection() = ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance())
?: error("Attempted to use connection while one does not exist")
private fun connection() =
ToolkitConnectionManager.getInstance(project).activeConnectionForFeature(QConnection.getInstance())
?: error("Attempted to use connection while one does not exist")

private fun bearerClient() = connection().getConnectionSettings().awsClient<CodeWhispererRuntimeClient>()

private val amazonQStreamingClient
get() = AmazonQStreamingClient.getInstance(project)

fun sendFeatureDevTelemetryEvent(conversationId: String): SendTelemetryEventResponse = bearerClient().sendTelemetryEvent { requestBuilder ->
requestBuilder.telemetryEvent { telemetryEventBuilder ->
telemetryEventBuilder.featureDevEvent {
it.conversationId(conversationId)
fun sendFeatureDevTelemetryEvent(conversationId: String): SendTelemetryEventResponse =
bearerClient().sendTelemetryEvent { requestBuilder ->
requestBuilder.telemetryEvent { telemetryEventBuilder ->
telemetryEventBuilder.featureDevEvent {
it.conversationId(conversationId)
}
}
requestBuilder.optOutPreference(getTelemetryOptOutPreference())
requestBuilder.userContext(featureDevUserContext)
}
requestBuilder.optOutPreference(getTelemetryOptOutPreference())
requestBuilder.userContext(featureDevUserContext)
}

fun createTaskAssistConversation(): CreateTaskAssistConversationResponse = bearerClient().createTaskAssistConversation(
CreateTaskAssistConversationRequest.builder().build()
)

fun createTaskAssistUploadUrl(conversationId: String, contentChecksumSha256: String, contentLength: Long): CreateUploadUrlResponse =
fun createTaskAssistConversation(): CreateTaskAssistConversationResponse =
bearerClient().createTaskAssistConversation(
CreateTaskAssistConversationRequest.builder().build(),
)

fun createTaskAssistUploadUrl(
conversationId: String,
contentChecksumSha256: String,
contentLength: Long,
uploadId: String,
): CreateUploadUrlResponse =
bearerClient().createUploadUrl {
it.contentChecksumType(ContentChecksumType.SHA_256)
it
.contentChecksumType(ContentChecksumType.SHA_256)
.uploadId(uploadId)
.contentChecksum(contentChecksumSha256)
.contentLength(contentLength)
.artifactType(ArtifactType.SOURCE_CODE)
.uploadIntent(UploadIntent.TASK_ASSIST_PLANNING)
.uploadContext(
UploadContext.builder()
UploadContext
.builder()
.taskAssistPlanningUploadContext(
TaskAssistPlanningUploadContext.builder()
TaskAssistPlanningUploadContext
.builder()
.conversationId(conversationId)
.build()
)
.build()
.build(),
).build(),
)
}

fun startTaskAssistCodeGeneration(conversationId: String, uploadId: String, userMessage: String): StartTaskAssistCodeGenerationResponse = bearerClient()
.startTaskAssistCodeGeneration {
request ->
request
.conversationState {
it
.conversationId(conversationId)
.chatTriggerType(SyncChatTriggerType.MANUAL)
.currentMessage { cm -> cm.userInputMessage { um -> um.content(userMessage) } }
}
.workspaceState {
it
.programmingLanguage { pl -> pl.languageName("javascript") } // This parameter is omitted by featureDev but required in the request
.uploadId(uploadId)
}
}
fun startTaskAssistCodeGeneration(
conversationId: String,
uploadId: String,
userMessage: String,
codeGenerationId: String?,
currentCodeGenerationId: String?,
): StartTaskAssistCodeGenerationResponse =
bearerClient()
.startTaskAssistCodeGeneration { request ->
request
.conversationState {
it
.conversationId(conversationId)
.chatTriggerType(SyncChatTriggerType.MANUAL)
.currentMessage { cm -> cm.userInputMessage { um -> um.content(userMessage) } }
}.workspaceState {
it
.programmingLanguage { pl -> pl.languageName("javascript") } // This parameter is omitted by featureDev but required in the request
.uploadId(uploadId)
}.codeGenerationId(codeGenerationId.toString())
.currentCodeGenerationId(currentCodeGenerationId)
}

fun getTaskAssistCodeGeneration(conversationId: String, codeGenerationId: String): GetTaskAssistCodeGenerationResponse = bearerClient()
.getTaskAssistCodeGeneration {
it
.conversationId(conversationId)
.codeGenerationId(codeGenerationId)
}
fun getTaskAssistCodeGeneration(
conversationId: String,
codeGenerationId: String,
): GetTaskAssistCodeGenerationResponse =
bearerClient()
.getTaskAssistCodeGeneration {
it
.conversationId(conversationId)
.codeGenerationId(codeGenerationId)
}

suspend fun exportTaskAssistResultArchive(conversationId: String): MutableList<ByteArray> = amazonQStreamingClient.exportResultArchive(
conversationId,
ExportIntent.TASK_ASSIST,
null,
{ e ->
LOG.error(e) { "TaskAssist - ExportResultArchive stream exportId=$conversationId exportIntent=${ExportIntent.TASK_ASSIST} Failed: ${e.message} " }
},
{ startTime ->
LOG.info { "TaskAssist - ExportResultArchive latency: ${calculateTotalLatency(startTime, Instant.now())}" }
}
)
suspend fun exportTaskAssistResultArchive(conversationId: String): MutableList<ByteArray> =
amazonQStreamingClient.exportResultArchive(
conversationId,
ExportIntent.TASK_ASSIST,
null,
{ e ->
LOG.error(
e,
) { "TaskAssist - ExportResultArchive stream exportId=$conversationId exportIntent=${ExportIntent.TASK_ASSIST} Failed: ${e.message} " }
},
{ startTime ->
LOG.info { "TaskAssist - ExportResultArchive latency: ${calculateTotalLatency(startTime, Instant.now())}" }
},
)

companion object {
private val LOG = getLogger<FeatureDevClient>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import com.intellij.openapi.command.WriteCommandAction
import com.intellij.openapi.editor.Caret
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.fileEditor.FileEditorManager
import com.intellij.openapi.project.Project
import com.intellij.openapi.vfs.VfsUtil
import com.intellij.openapi.wm.ToolWindowManager
import kotlinx.coroutines.withContext
Expand Down Expand Up @@ -74,6 +75,7 @@ import software.aws.toolkits.jetbrains.utils.notifyError
import software.aws.toolkits.resources.message
import software.aws.toolkits.telemetry.AmazonqTelemetry
import software.aws.toolkits.telemetry.Result
import software.aws.toolkits.telemetry.UiTelemetry
import java.util.UUID

class FeatureDevController(
Expand All @@ -92,6 +94,10 @@ class FeatureDevController(
)
}

override suspend fun processStopMessage(message: IncomingFeatureDevMessage.StopResponse) {
handleStopMessage(message)
}

override suspend fun processNewTabCreatedMessage(message: IncomingFeatureDevMessage.NewTabCreated) {
newTabOpened(message.tabId)
}
Expand Down Expand Up @@ -284,6 +290,26 @@ class FeatureDevController(
}
}

private suspend fun handleStopMessage(message: IncomingFeatureDevMessage.StopResponse) {
val session: Session?
UiTelemetry.click(null as Project?, "amazonq_stopCodeGeneration")
messenger.sendAnswer(
tabId = message.tabId,
message("amazonqFeatureDev.code_generation.stopping_code_generation"),
messageType = FeatureDevMessageType.Answer,
canBeVoted = false
)
messenger.sendUpdatePlaceholder(
tabId = message.tabId,
newPlaceholder = message("amazonqFeatureDev.code_generation.stopping_code_generation")
)
messenger.sendChatInputEnabledMessage(tabId = message.tabId, enabled = false)
session = getSessionInfo(message.tabId)

if (session.sessionState.token?.token !== null) {
session.sessionState.token?.cancel()
}
}
private suspend fun insertCode(tabId: String) {
var session: Session? = null
try {
Expand Down
Loading

0 comments on commit 8dc4b81

Please sign in to comment.