Skip to content

Commit

Permalink
Restrict TestSpark action to suitable code types and fix generation f…
Browse files Browse the repository at this point in the history
…or a line (#344)

* fix update function

* create availableForGeneration

* ktlint

* feat: add javadoc for `JavaPsiHelper.availableForGeneration`

* feat: check for nullness of a PSI file in `TestSparkAction.update`

* feat: update javadocs in `PsiComponents.kt`

* feat: check for a class or method/func in `KotlinPsiHelper.availableForGeneration`

* feat: add TODO to `ToolUtils` about a potential bug

The bug is reflected in the issue #375.

* feat: make `PsiHelper.getSurroundingLineNumber` return 1-based line numbers

Before, the `KotlinPsiHelper` returned a 0-based line number which caused an issue with line-based test generation.
The generated prompt contained a line above the selected one.

* feat: implement line-based test generation with CUT as a context

When there is no surrounding method about the selected line,
we use the CUT as a context for this line. The CUT must always be present.
Otherwise, the generation action should have been disabled for this line.

* refactor: apply ktlint

* feat: add `See` in TODO

* feat: add TODO and surround $NAME in backticks in `linePrompt` template

* feat: collect class constructor signatures in `PsiClassWrapper`

* feat: remove backticks from `linePrompt`

* feat: fill line-based test generation with additional context

The line-based test generation that has a method as a context
of the line now also accepts constructors of the containing class.

* refactor: use `firstOrNull` for `cut` extraction

* refactor: apply ktlint

* fix: add required parameter to `ClassRepresentation` in tests

* publish: core module version `4.0.0`

The major version increased due to the change of the public API of `PromptGenerator.generatePromptForLine` method.

---------

Co-authored-by: Vladislav Artiukhov <[email protected]>
  • Loading branch information
arksap2002 and Vladislav0Art authored Oct 10, 2024
1 parent 7ba322b commit d696259
Show file tree
Hide file tree
Showing 15 changed files with 281 additions and 63 deletions.
2 changes: 1 addition & 1 deletion core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ publishing {
create<MavenPublication>("maven") {
groupId = group as String
artifactId = "testspark-core"
version = "3.0.1"
version = "4.0.0"
from(components["java"])

artifact(tasks["sourcesJar"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ class PromptGenerator(
testSamplesCode: String,
packageName: String,
): String {
val name = context.cut?.let { "${it.qualifiedName}.${method.name}" } ?: "$packageName.${method.name}"
val methodQualifiedName = context.cut?.let { "${it.qualifiedName}.${method.name}" } ?: "$packageName.${method.name}"

val prompt = PromptBuilder(promptTemplates.methodPrompt)
.insertLanguage(context.promptConfiguration.desiredLanguage)
.insertName(name)
.insertName(methodQualifiedName)
.insertTestingPlatform(context.promptConfiguration.desiredTestingPlatform)
.insertMockingFramework(context.promptConfiguration.desiredMockingFramework)
.insertCodeUnderTest(method.text, context.classesToTest)
Expand All @@ -62,8 +63,7 @@ class PromptGenerator(
}

/**
* Generates a prompt for a given line under test.
* It accepts the code of a line under test, a representation of the method that contains the line, and the set of interesting classes (e.g., the containing class of the method, classes listed in parameters of the method and constructors of the containing class).
* Generates a prompt for a given line under test using a surrounding method/function as a context.
*
* @param lineUnderTest The source code of the line to be tested.
* @param method The representation of the method that contains the line.
Expand All @@ -76,18 +76,130 @@ class PromptGenerator(
method: MethodRepresentation,
interestingClassesFromMethod: List<ClassRepresentation>,
testSamplesCode: String,
packageName: String,
): String {
val codeUnderTest = if (context.cut != null) {
// `method` is a method within a class
buildCutDeclaration(context.cut, method)
} else {
// `method` is a top-level function
method.text
}

val methodQualifiedName = context.cut?.let { "${it.qualifiedName}.${method.name}" } ?: "$packageName.${method.name}"
val lineReference = "`${lineUnderTest.trim()}` within `$methodQualifiedName`"

val prompt = PromptBuilder(promptTemplates.linePrompt)
.insertLanguage(context.promptConfiguration.desiredLanguage)
.insertName(lineUnderTest.trim())
.insertName(lineReference)
.insertTestingPlatform(context.promptConfiguration.desiredTestingPlatform)
.insertMockingFramework(context.promptConfiguration.desiredMockingFramework)
.insertCodeUnderTest(method.text, context.classesToTest)
.insertCodeUnderTest(codeUnderTest, context.classesToTest)
.insertMethodsSignatures(interestingClassesFromMethod)
.insertPolymorphismRelations(context.polymorphismRelations)
.insertTestSample(testSamplesCode)
.build()

return prompt
}

/**
* Generates a prompt for a given line under test using CUT as a context.
*
* **Contract: `context.cut` is not `null`.**
*
* @param lineUnderTest The source code of the line to be tested.
* @param interestingClasses The list of `ClassRepresentation` objects related to the line under test.
* @param testSamplesCode The code snippet that serves as test samples.
* @return The generated prompt as `String`.
* @throws IllegalStateException If any of the required keywords are missing in the prompt template.
*/
fun generatePromptForLine(
lineUnderTest: String,
interestingClasses: List<ClassRepresentation>,
testSamplesCode: String,
): String {
val lineReference = "`${lineUnderTest.trim()}` within `${context.cut!!.qualifiedName}`"

val prompt = PromptBuilder(promptTemplates.linePrompt)
.insertLanguage(context.promptConfiguration.desiredLanguage)
.insertName(lineReference)
.insertTestingPlatform(context.promptConfiguration.desiredTestingPlatform)
.insertMockingFramework(context.promptConfiguration.desiredMockingFramework)
.insertCodeUnderTest(context.cut.fullText, context.classesToTest)
.insertMethodsSignatures(interestingClasses)
.insertPolymorphismRelations(context.polymorphismRelations)
.insertTestSample(testSamplesCode)
.build()

return prompt
}
}

/**
* Builds a cut declaration with constructor declarations and a method under test.
*
* Example when there exist non-default constructors:
* ```
* [Instruction]: Use the following constructor declarations to instantiate `org.example.CalcKotlin` and call the method under test `add`:
*
* Constructors of the class org.example.CalcKotlin:
* === (val value: Int)
* === constructor(c: Int, d: Int) : this(c+d)
*
* Method:
* fun add(a: Int, b: Int): Int {
* return a + b
* }
* ```
*
* Example when only a default constructor exists:
* ```
* [Instruction]: Use a default constructor with zero arguments to instantiate `Calc` and call the method under test `sum`:
*
* Constructors of the class Calc:
* === Default constructor
*
* Method:
* public int sum(int a, int b) {
* return a + b;
* }
* ```
*
* @param cut The `ClassRepresentation` object representing the class to be instantiated.
* @param method The `MethodRepresentation` object representing the method under test.
* @return A formatted `String` representing the cut declaration, containing constructor declarations and method text.
*/
private fun buildCutDeclaration(cut: ClassRepresentation, method: MethodRepresentation): String {
val instruction = buildString {
val constructorToUse = if (cut.constructorSignatures.isEmpty()) {
"a default constructor with zero arguments"
} else {
"the following constructor declarations"
}
append("Use $constructorToUse to instantiate `${cut.qualifiedName}` and call the method under test `${method.name}`")
}

val classType = cut.classType.representation

val constructorDeclarations = buildString {
appendLine("Constructors of the $classType ${cut.qualifiedName}:")
if (cut.constructorSignatures.isEmpty()) {
appendLine("=== Default constructor")
}
for (constructor in cut.constructorSignatures) {
appendLine("\t=== $constructor")
}
}.trim()

val cutDeclaration = buildString {
appendLine("[Instruction]: $instruction:")
appendLine()
appendLine(constructorDeclarations)
appendLine()
appendLine("Method:")
appendLine(method.text)
}.trim()

return cutDeclaration
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ data class PromptConfiguration(
data class ClassRepresentation(
val qualifiedName: String,
val fullText: String,
val constructorSignatures: List<String>,
val allMethods: List<MethodRepresentation>,
val classType: ClassType,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class JavaPsiClassWrapper(private val psiClass: PsiClass) : PsiClassWrapper {

override val allMethods: List<PsiMethodWrapper> get() = psiClass.allMethods.map { JavaPsiMethodWrapper(it) }

override val constructorSignatures: List<String> get() = psiClass.constructors.map { JavaPsiMethodWrapper.buildSignature(it) }

override val superClass: PsiClassWrapper? get() = psiClass.superClass?.let { JavaPsiClassWrapper(it) }

override val virtualFile: VirtualFile get() = psiClass.containingFile.virtualFile
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,19 @@ class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper {

override val language: SupportedLanguage get() = SupportedLanguage.Java

/**
* When dealing with Java PSI files, we expect that only classes and their methods are tested.
* Therefore, we expect a **class** to surround a cursor offset.
*
* This requirement ensures that the user is not trying
* to generate tests for a line of code outside the class scope.
*
* @param e `AnActionEvent` representing the current action event.
* @return `true` if the cursor is inside a class, `false` otherwise.
*/
override fun availableForGeneration(e: AnActionEvent): Boolean =
getCurrentListOfCodeTypes(e).any { it.first == CodeType.CLASS }

private val log = Logger.getInstance(this::class.java)

override fun generateMethodDescriptor(psiMethod: PsiMethodWrapper): String {
Expand Down Expand Up @@ -70,6 +83,12 @@ class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper {
override fun getSurroundingLineNumber(caretOffset: Int): Int? {
val doc = PsiDocumentManager.getInstance(psiFile.project).getDocument(psiFile) ?: return null

/**
* See `getLineNumber`'s documentation for details on the numbering.
* It returns an index of the line in the document, starting from 0.
*
* Therefore, we need to increase the result by one to get the line number.
*/
val selectedLine = doc.getLineNumber(caretOffset)
val selectedLineText =
doc.getText(TextRange(doc.getLineStartOffset(selectedLine), doc.getLineEndOffset(selectedLine)))
Expand All @@ -79,7 +98,6 @@ class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper {
return null
}
log.info("Surrounding line at caret $caretOffset is $selectedLine")

// increase by one is necessary due to different start of numbering
return selectedLine + 1
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ class JavaPsiMethodWrapper(private val psiMethod: PsiMethod) : PsiMethodWrapper
}

override val signature: String
get() {
val bodyStart = psiMethod.body?.startOffsetInParent ?: psiMethod.textLength
return psiMethod.text.substring(0, bodyStart).replace("\\n", "").trim()
}
get() = buildSignature(psiMethod)

val parameterList = psiMethod.parameterList

Expand Down Expand Up @@ -117,4 +114,17 @@ class JavaPsiMethodWrapper(private val psiMethod: PsiMethod) : PsiMethodWrapper
}
}
}

companion object {
/**
* Builds a signature for a given `PsiMethod`.
*
* @param method the PsiMethod for which to build the signature
* @return the method signature with the text before the method body, excluding newline characters
*/
fun buildSignature(method: PsiMethod): String {
val bodyStart = method.body?.startOffsetInParent ?: method.textLength
return method.text.substring(0, bodyStart).replace("\\n", "").trim()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.psi.KtClass
import org.jetbrains.kotlin.psi.KtClassOrObject
import org.jetbrains.kotlin.psi.KtObjectDeclaration
import org.jetbrains.kotlin.psi.allConstructors
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.DescriptorToSourceUtils
import org.jetbrains.research.testspark.core.data.ClassType
Expand All @@ -35,6 +36,8 @@ class KotlinPsiClassWrapper(private val psiClass: KtClassOrObject) : PsiClassWra

override val allMethods: List<PsiMethodWrapper> get() = methods

override val constructorSignatures: List<String> get() = psiClass.allConstructors.map { KotlinPsiMethodWrapper.buildSignature(it) }

override val superClass: PsiClassWrapper?
get() {
// Get the superTypeListEntries of the Kotlin class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@ class KotlinPsiHelper(private val psiFile: PsiFile) : PsiHelper {

override val language: SupportedLanguage get() = SupportedLanguage.Kotlin

/**
* When dealing with Kotlin PSI files, we expect that only classes, their methods,
* top-level functions are tested.
* Therefore, we expect either a class or a method (top-level function) to surround a cursor offset.
*
* This requirement ensures that the user is not trying
* to generate tests for a line of code outside the aforementioned scopes.
*
* @param e `AnActionEvent` representing the current action event.
* @return `true` if the cursor is inside a class or method, `false` otherwise.
*/
override fun availableForGeneration(e: AnActionEvent): Boolean =
getCurrentListOfCodeTypes(e).any { (it.first == CodeType.CLASS) || (it.first == CodeType.METHOD) }

private val log = Logger.getInstance(this::class.java)

override fun generateMethodDescriptor(psiMethod: PsiMethodWrapper): String {
Expand Down Expand Up @@ -64,6 +78,12 @@ class KotlinPsiHelper(private val psiFile: PsiFile) : PsiHelper {
override fun getSurroundingLineNumber(caretOffset: Int): Int? {
val doc = PsiDocumentManager.getInstance(psiFile.project).getDocument(psiFile) ?: return null

/**
* See `getLineNumber`'s documentation for details on the numbering.
* It returns an index of the line in the document, starting from 0.
*
* Therefore, we need to increase the result by one to get the line number.
*/
val selectedLine = doc.getLineNumber(caretOffset)
val selectedLineText =
doc.getText(TextRange(doc.getLineStartOffset(selectedLine), doc.getLineEndOffset(selectedLine)))
Expand All @@ -73,7 +93,7 @@ class KotlinPsiHelper(private val psiFile: PsiFile) : PsiHelper {
return null
}
log.info("Surrounding line at caret $caretOffset is $selectedLine")
return selectedLine
return selectedLine + 1
}

override fun collectClassesToTest(
Expand Down Expand Up @@ -150,7 +170,7 @@ class KotlinPsiHelper(private val psiFile: PsiFile) : PsiHelper {

val ktClass = getSurroundingClass(caret.offset)
val ktFunction = getSurroundingMethod(caret.offset)
val line: Int? = getSurroundingLineNumber(caret.offset)?.plus(1)
val line: Int? = getSurroundingLineNumber(caret.offset)

ktClass?.let { result.add(CodeType.CLASS to getClassHTMLDisplayName(it)) }
ktFunction?.let { result.add(CodeType.METHOD to getMethodHTMLDisplayName(it)) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ class KotlinPsiMethodWrapper(val psiFunction: KtFunction) : PsiMethodWrapper {
}

override val signature: String
get() = psiFunction.run {
val bodyStart = bodyExpression?.startOffsetInParent ?: textLength
text.substring(0, bodyStart).replace('\n', ' ').trim()
}
get() = buildSignature(psiFunction)

val parameterList = psiFunction.valueParameterList

Expand Down Expand Up @@ -131,4 +128,17 @@ class KotlinPsiMethodWrapper(val psiFunction: KtFunction) : PsiMethodWrapper {
else -> "L${type.replace('.', '/')};"
}
}

companion object {
/**
* Builds a signature for a given Kotlin function by extracting the method body portion.
*
* @param function The Kotlin function to build the signature for.
* @return The signature of the function.
*/
fun buildSignature(function: KtFunction) = function.run {
val bodyStart = bodyExpression?.startOffsetInParent ?: textLength
text.substring(0, bodyStart).replace('\n', ' ').trim()
}
}
}
Loading

0 comments on commit d696259

Please sign in to comment.