From d696259543f5a1881bf659d4363e12d729e19cb4 Mon Sep 17 00:00:00 2001 From: Arkadii Sapozhnikov <47223481+arksap2002@users.noreply.github.com> Date: Thu, 10 Oct 2024 12:41:33 +0200 Subject: [PATCH] Restrict TestSpark action to suitable code types and fix generation for a line (#344) * fix update function * create availableForGeneration * ktlint * feat: add javadoc for `JavaPsiHelper.availableForGeneration` * feat: check for nullness of a PSI file in `TestSparkAction.update` * feat: update javadocs in `PsiComponents.kt` * feat: check for a class or method/func in `KotlinPsiHelper.availableForGeneration` * feat: add TODO to `ToolUtils` about a potential bug The bug is reflected in the issue #375. * feat: make `PsiHelper.getSurroundingLineNumber` return 1-based line numbers Before, the `KotlinPsiHelper` returned a 0-based line number which caused an issue with line-based test generation. The generated prompt contained a line above the selected one. * feat: implement line-based test generation with CUT as a context When there is no surrounding method about the selected line, we use the CUT as a context for this line. The CUT must always be present. Otherwise, the generation action should have been disabled for this line. * refactor: apply ktlint * feat: add `See` in TODO * feat: add TODO and surround $NAME in backticks in `linePrompt` template * feat: collect class constructor signatures in `PsiClassWrapper` * feat: remove backticks from `linePrompt` * feat: fill line-based test generation with additional context The line-based test generation that has a method as a context of the line now also accepts constructors of the containing class. * refactor: use `firstOrNull` for `cut` extraction * refactor: apply ktlint * fix: add required parameter to `ClassRepresentation` in tests * publish: core module version `4.0.0` The major version increased due to the change of the public API of `PromptGenerator.generatePromptForLine` method. --------- Co-authored-by: Vladislav Artiukhov --- core/build.gradle.kts | 2 +- .../generation/llm/prompt/PromptGenerator.kt | 124 +++++++++++++++++- .../llm/prompt/configuration/Configuration.kt | 1 + .../testspark/java/JavaPsiClassWrapper.kt | 2 + .../research/testspark/java/JavaPsiHelper.kt | 20 ++- .../testspark/java/JavaPsiMethodWrapper.kt | 18 ++- .../testspark/kotlin/KotlinPsiClassWrapper.kt | 3 + .../testspark/kotlin/KotlinPsiHelper.kt | 24 +++- .../kotlin/KotlinPsiMethodWrapper.kt | 18 ++- .../testspark/langwrappers/PsiComponents.kt | 20 ++- .../testspark/actions/TestSparkAction.kt | 14 +- .../research/testspark/tools/ToolUtils.kt | 2 + .../tools/llm/generation/PromptManager.kt | 91 ++++++++----- .../properties/llm/LLMDefaults.properties | 2 +- .../llm/prompt/PromptBuilderTest.kt | 3 + 15 files changed, 281 insertions(+), 63 deletions(-) diff --git a/core/build.gradle.kts b/core/build.gradle.kts index a14baef5f..ef8060a1e 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -42,7 +42,7 @@ publishing { create("maven") { groupId = group as String artifactId = "testspark-core" - version = "3.0.1" + version = "4.0.0" from(components["java"]) artifact(tasks["sourcesJar"]) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt index 72340867a..eb40c9ea9 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptGenerator.kt @@ -46,10 +46,11 @@ class PromptGenerator( testSamplesCode: String, packageName: String, ): String { - val name = context.cut?.let { "${it.qualifiedName}.${method.name}" } ?: "$packageName.${method.name}" + val methodQualifiedName = context.cut?.let { "${it.qualifiedName}.${method.name}" } ?: "$packageName.${method.name}" + val prompt = PromptBuilder(promptTemplates.methodPrompt) .insertLanguage(context.promptConfiguration.desiredLanguage) - .insertName(name) + .insertName(methodQualifiedName) .insertTestingPlatform(context.promptConfiguration.desiredTestingPlatform) .insertMockingFramework(context.promptConfiguration.desiredMockingFramework) .insertCodeUnderTest(method.text, context.classesToTest) @@ -62,8 +63,7 @@ class PromptGenerator( } /** - * Generates a prompt for a given line under test. - * It accepts the code of a line under test, a representation of the method that contains the line, and the set of interesting classes (e.g., the containing class of the method, classes listed in parameters of the method and constructors of the containing class). + * Generates a prompt for a given line under test using a surrounding method/function as a context. * * @param lineUnderTest The source code of the line to be tested. * @param method The representation of the method that contains the line. @@ -76,13 +76,25 @@ class PromptGenerator( method: MethodRepresentation, interestingClassesFromMethod: List, testSamplesCode: String, + packageName: String, ): String { + val codeUnderTest = if (context.cut != null) { + // `method` is a method within a class + buildCutDeclaration(context.cut, method) + } else { + // `method` is a top-level function + method.text + } + + val methodQualifiedName = context.cut?.let { "${it.qualifiedName}.${method.name}" } ?: "$packageName.${method.name}" + val lineReference = "`${lineUnderTest.trim()}` within `$methodQualifiedName`" + val prompt = PromptBuilder(promptTemplates.linePrompt) .insertLanguage(context.promptConfiguration.desiredLanguage) - .insertName(lineUnderTest.trim()) + .insertName(lineReference) .insertTestingPlatform(context.promptConfiguration.desiredTestingPlatform) .insertMockingFramework(context.promptConfiguration.desiredMockingFramework) - .insertCodeUnderTest(method.text, context.classesToTest) + .insertCodeUnderTest(codeUnderTest, context.classesToTest) .insertMethodsSignatures(interestingClassesFromMethod) .insertPolymorphismRelations(context.polymorphismRelations) .insertTestSample(testSamplesCode) @@ -90,4 +102,104 @@ class PromptGenerator( return prompt } + + /** + * Generates a prompt for a given line under test using CUT as a context. + * + * **Contract: `context.cut` is not `null`.** + * + * @param lineUnderTest The source code of the line to be tested. + * @param interestingClasses The list of `ClassRepresentation` objects related to the line under test. + * @param testSamplesCode The code snippet that serves as test samples. + * @return The generated prompt as `String`. + * @throws IllegalStateException If any of the required keywords are missing in the prompt template. + */ + fun generatePromptForLine( + lineUnderTest: String, + interestingClasses: List, + testSamplesCode: String, + ): String { + val lineReference = "`${lineUnderTest.trim()}` within `${context.cut!!.qualifiedName}`" + + val prompt = PromptBuilder(promptTemplates.linePrompt) + .insertLanguage(context.promptConfiguration.desiredLanguage) + .insertName(lineReference) + .insertTestingPlatform(context.promptConfiguration.desiredTestingPlatform) + .insertMockingFramework(context.promptConfiguration.desiredMockingFramework) + .insertCodeUnderTest(context.cut.fullText, context.classesToTest) + .insertMethodsSignatures(interestingClasses) + .insertPolymorphismRelations(context.polymorphismRelations) + .insertTestSample(testSamplesCode) + .build() + + return prompt + } +} + +/** + * Builds a cut declaration with constructor declarations and a method under test. + * + * Example when there exist non-default constructors: + * ``` + * [Instruction]: Use the following constructor declarations to instantiate `org.example.CalcKotlin` and call the method under test `add`: + * + * Constructors of the class org.example.CalcKotlin: + * === (val value: Int) + * === constructor(c: Int, d: Int) : this(c+d) + * + * Method: + * fun add(a: Int, b: Int): Int { + * return a + b + * } + * ``` + * + * Example when only a default constructor exists: + * ``` + * [Instruction]: Use a default constructor with zero arguments to instantiate `Calc` and call the method under test `sum`: + * + * Constructors of the class Calc: + * === Default constructor + * + * Method: + * public int sum(int a, int b) { + * return a + b; + * } + * ``` + * + * @param cut The `ClassRepresentation` object representing the class to be instantiated. + * @param method The `MethodRepresentation` object representing the method under test. + * @return A formatted `String` representing the cut declaration, containing constructor declarations and method text. + */ +private fun buildCutDeclaration(cut: ClassRepresentation, method: MethodRepresentation): String { + val instruction = buildString { + val constructorToUse = if (cut.constructorSignatures.isEmpty()) { + "a default constructor with zero arguments" + } else { + "the following constructor declarations" + } + append("Use $constructorToUse to instantiate `${cut.qualifiedName}` and call the method under test `${method.name}`") + } + + val classType = cut.classType.representation + + val constructorDeclarations = buildString { + appendLine("Constructors of the $classType ${cut.qualifiedName}:") + if (cut.constructorSignatures.isEmpty()) { + appendLine("=== Default constructor") + } + for (constructor in cut.constructorSignatures) { + appendLine("\t=== $constructor") + } + }.trim() + + val cutDeclaration = buildString { + appendLine("[Instruction]: $instruction:") + appendLine() + appendLine(constructorDeclarations) + appendLine() + appendLine("Method:") + appendLine(method.text) + }.trim() + + return cutDeclaration } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt index 6b87e8941..a98abe4ad 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/configuration/Configuration.kt @@ -43,6 +43,7 @@ data class PromptConfiguration( data class ClassRepresentation( val qualifiedName: String, val fullText: String, + val constructorSignatures: List, val allMethods: List, val classType: ClassType, ) diff --git a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt index 087485827..9f6a5a28c 100644 --- a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt +++ b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiClassWrapper.kt @@ -27,6 +27,8 @@ class JavaPsiClassWrapper(private val psiClass: PsiClass) : PsiClassWrapper { override val allMethods: List get() = psiClass.allMethods.map { JavaPsiMethodWrapper(it) } + override val constructorSignatures: List get() = psiClass.constructors.map { JavaPsiMethodWrapper.buildSignature(it) } + override val superClass: PsiClassWrapper? get() = psiClass.superClass?.let { JavaPsiClassWrapper(it) } override val virtualFile: VirtualFile get() = psiClass.containingFile.virtualFile diff --git a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt index f6f132a29..8562c382a 100644 --- a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt +++ b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiHelper.kt @@ -26,6 +26,19 @@ class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper { override val language: SupportedLanguage get() = SupportedLanguage.Java + /** + * When dealing with Java PSI files, we expect that only classes and their methods are tested. + * Therefore, we expect a **class** to surround a cursor offset. + * + * This requirement ensures that the user is not trying + * to generate tests for a line of code outside the class scope. + * + * @param e `AnActionEvent` representing the current action event. + * @return `true` if the cursor is inside a class, `false` otherwise. + */ + override fun availableForGeneration(e: AnActionEvent): Boolean = + getCurrentListOfCodeTypes(e).any { it.first == CodeType.CLASS } + private val log = Logger.getInstance(this::class.java) override fun generateMethodDescriptor(psiMethod: PsiMethodWrapper): String { @@ -70,6 +83,12 @@ class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper { override fun getSurroundingLineNumber(caretOffset: Int): Int? { val doc = PsiDocumentManager.getInstance(psiFile.project).getDocument(psiFile) ?: return null + /** + * See `getLineNumber`'s documentation for details on the numbering. + * It returns an index of the line in the document, starting from 0. + * + * Therefore, we need to increase the result by one to get the line number. + */ val selectedLine = doc.getLineNumber(caretOffset) val selectedLineText = doc.getText(TextRange(doc.getLineStartOffset(selectedLine), doc.getLineEndOffset(selectedLine))) @@ -79,7 +98,6 @@ class JavaPsiHelper(private val psiFile: PsiFile) : PsiHelper { return null } log.info("Surrounding line at caret $caretOffset is $selectedLine") - // increase by one is necessary due to different start of numbering return selectedLine + 1 } diff --git a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiMethodWrapper.kt b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiMethodWrapper.kt index d7fd7ba04..fd60cd488 100644 --- a/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiMethodWrapper.kt +++ b/java/src/main/kotlin/org/jetbrains/research/testspark/java/JavaPsiMethodWrapper.kt @@ -35,10 +35,7 @@ class JavaPsiMethodWrapper(private val psiMethod: PsiMethod) : PsiMethodWrapper } override val signature: String - get() { - val bodyStart = psiMethod.body?.startOffsetInParent ?: psiMethod.textLength - return psiMethod.text.substring(0, bodyStart).replace("\\n", "").trim() - } + get() = buildSignature(psiMethod) val parameterList = psiMethod.parameterList @@ -117,4 +114,17 @@ class JavaPsiMethodWrapper(private val psiMethod: PsiMethod) : PsiMethodWrapper } } } + + companion object { + /** + * Builds a signature for a given `PsiMethod`. + * + * @param method the PsiMethod for which to build the signature + * @return the method signature with the text before the method body, excluding newline characters + */ + fun buildSignature(method: PsiMethod): String { + val bodyStart = method.body?.startOffsetInParent ?: method.textLength + return method.text.substring(0, bodyStart).replace("\\n", "").trim() + } + } } diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt index 50cc12f0f..844230b88 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiClassWrapper.kt @@ -14,6 +14,7 @@ import org.jetbrains.kotlin.lexer.KtTokens import org.jetbrains.kotlin.psi.KtClass import org.jetbrains.kotlin.psi.KtClassOrObject import org.jetbrains.kotlin.psi.KtObjectDeclaration +import org.jetbrains.kotlin.psi.allConstructors import org.jetbrains.kotlin.resolve.BindingContext import org.jetbrains.kotlin.resolve.DescriptorToSourceUtils import org.jetbrains.research.testspark.core.data.ClassType @@ -35,6 +36,8 @@ class KotlinPsiClassWrapper(private val psiClass: KtClassOrObject) : PsiClassWra override val allMethods: List get() = methods + override val constructorSignatures: List get() = psiClass.allConstructors.map { KotlinPsiMethodWrapper.buildSignature(it) } + override val superClass: PsiClassWrapper? get() { // Get the superTypeListEntries of the Kotlin class diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt index fd8a78a1b..760568909 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiHelper.kt @@ -25,6 +25,20 @@ class KotlinPsiHelper(private val psiFile: PsiFile) : PsiHelper { override val language: SupportedLanguage get() = SupportedLanguage.Kotlin + /** + * When dealing with Kotlin PSI files, we expect that only classes, their methods, + * top-level functions are tested. + * Therefore, we expect either a class or a method (top-level function) to surround a cursor offset. + * + * This requirement ensures that the user is not trying + * to generate tests for a line of code outside the aforementioned scopes. + * + * @param e `AnActionEvent` representing the current action event. + * @return `true` if the cursor is inside a class or method, `false` otherwise. + */ + override fun availableForGeneration(e: AnActionEvent): Boolean = + getCurrentListOfCodeTypes(e).any { (it.first == CodeType.CLASS) || (it.first == CodeType.METHOD) } + private val log = Logger.getInstance(this::class.java) override fun generateMethodDescriptor(psiMethod: PsiMethodWrapper): String { @@ -64,6 +78,12 @@ class KotlinPsiHelper(private val psiFile: PsiFile) : PsiHelper { override fun getSurroundingLineNumber(caretOffset: Int): Int? { val doc = PsiDocumentManager.getInstance(psiFile.project).getDocument(psiFile) ?: return null + /** + * See `getLineNumber`'s documentation for details on the numbering. + * It returns an index of the line in the document, starting from 0. + * + * Therefore, we need to increase the result by one to get the line number. + */ val selectedLine = doc.getLineNumber(caretOffset) val selectedLineText = doc.getText(TextRange(doc.getLineStartOffset(selectedLine), doc.getLineEndOffset(selectedLine))) @@ -73,7 +93,7 @@ class KotlinPsiHelper(private val psiFile: PsiFile) : PsiHelper { return null } log.info("Surrounding line at caret $caretOffset is $selectedLine") - return selectedLine + return selectedLine + 1 } override fun collectClassesToTest( @@ -150,7 +170,7 @@ class KotlinPsiHelper(private val psiFile: PsiFile) : PsiHelper { val ktClass = getSurroundingClass(caret.offset) val ktFunction = getSurroundingMethod(caret.offset) - val line: Int? = getSurroundingLineNumber(caret.offset)?.plus(1) + val line: Int? = getSurroundingLineNumber(caret.offset) ktClass?.let { result.add(CodeType.CLASS to getClassHTMLDisplayName(it)) } ktFunction?.let { result.add(CodeType.METHOD to getMethodHTMLDisplayName(it)) } diff --git a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt index c993fd808..93f39d6ba 100644 --- a/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt +++ b/kotlin/src/main/kotlin/org/jetbrains/research/testspark/kotlin/KotlinPsiMethodWrapper.kt @@ -35,10 +35,7 @@ class KotlinPsiMethodWrapper(val psiFunction: KtFunction) : PsiMethodWrapper { } override val signature: String - get() = psiFunction.run { - val bodyStart = bodyExpression?.startOffsetInParent ?: textLength - text.substring(0, bodyStart).replace('\n', ' ').trim() - } + get() = buildSignature(psiFunction) val parameterList = psiFunction.valueParameterList @@ -131,4 +128,17 @@ class KotlinPsiMethodWrapper(val psiFunction: KtFunction) : PsiMethodWrapper { else -> "L${type.replace('.', '/')};" } } + + companion object { + /** + * Builds a signature for a given Kotlin function by extracting the method body portion. + * + * @param function The Kotlin function to build the signature for. + * @return The signature of the function. + */ + fun buildSignature(function: KtFunction) = function.run { + val bodyStart = bodyExpression?.startOffsetInParent ?: textLength + text.substring(0, bodyStart).replace('\n', ' ').trim() + } + } } diff --git a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt index 0aa5dfd0f..3f7f1d0c8 100644 --- a/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt +++ b/langwrappers/src/main/kotlin/org/jetbrains/research/testspark/langwrappers/PsiComponents.kt @@ -13,7 +13,7 @@ typealias CodeTypeDisplayName = Pair /** * Interface representing a wrapper for PSI methods, - * providing common API to handle method-related data for different languages. + * providing a common API to handle method-related data for different languages. * * @property name The name of a method * @property methodDescriptor Human-readable method signature @@ -40,13 +40,14 @@ interface PsiMethodWrapper { /** * Interface representing a wrapper for PSI classes, - * providing common API to handle class-related data for different languages. + * providing a common API to handle class-related data for different languages. * @property name The name of a class * @property qualifiedName The qualified name of the class. * @property text The text of the class. * @property methods All methods in the class * @property allMethods All methods in the class and all its superclasses - * @property superClass The super class of the class + * @property constructorSignatures The signatures of all constructors in the class + * @property superClass The superclass of the class * @property virtualFile Virtual file where the class is located * @property containingFile File where the method is located * @property fullText The source code of the class (with package and imports). @@ -59,6 +60,7 @@ interface PsiClassWrapper { val text: String? val methods: List val allMethods: List + val constructorSignatures: List val superClass: PsiClassWrapper? val virtualFile: VirtualFile val containingFile: PsiFile @@ -90,13 +92,21 @@ interface PsiClassWrapper { interface PsiHelper { val language: SupportedLanguage + /** + * Checks if a code construct is valid for unit test generation at the given caret offset. + * + * @param e The AnActionEvent representing the current action event. + * @return `true` if a code construct is valid for unit test generation at the caret offset, `false` otherwise. + */ + fun availableForGeneration(e: AnActionEvent): Boolean + /** * Returns the surrounding PsiClass object based on the caret position within the specified PsiFile. * The surrounding class is determined by finding the PsiClass objects within the PsiFile and checking * if the caret is within any of them. * * @param caretOffset The offset of the caret position within the PsiFile. - * @return The surrounding PsiClass object if found, null otherwise. + * @return The surrounding `PsiClass` object if found, `null` otherwise. */ fun getSurroundingClass(caretOffset: Int): PsiClassWrapper? @@ -111,6 +121,8 @@ interface PsiHelper { /** * Returns the line number of the selected line where the caret is positioned. * + * The returned line number is **1-based**. + * * @param caretOffset The caret offset within the PSI file. * @return The line number of the selected line, otherwise null. */ diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt index 58447b7a8..6304b258d 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt @@ -74,14 +74,18 @@ class TestSparkAction : AnAction() { /** * Updates the state of the action based on the provided event. * - * @param e the AnActionEvent object representing the event + * @param e `AnActionEvent` object representing the event */ override fun update(e: AnActionEvent) { - val file = e.dataContext.getData(CommonDataKeys.PSI_FILE)!! - val psiHelper = PsiHelperProvider.getPsiHelper(file) - if (psiHelper == null) { - // TODO exception + val file = e.dataContext.getData(CommonDataKeys.PSI_FILE) + + if (file == null) { + e.presentation.isEnabledAndVisible = false + return } + + val psiHelper = PsiHelperProvider.getPsiHelper(file) + e.presentation.isEnabledAndVisible = (psiHelper != null) && psiHelper.availableForGeneration(e) } /** diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt index 99e1fc4fa..4d4aff983 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt @@ -114,6 +114,8 @@ object ToolUtils { fun isProcessCanceled(indicator: CustomProgressIndicator): Boolean { if (indicator.isCanceled()) { + // TODO: we must not stop this indicator! cancellation MAY imply stoppage + // See: https://github.com/JetBrains-Research/TestSpark/issues/375 indicator.stop() return true } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt index 0c493f71a..19e7b6092 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt @@ -57,7 +57,7 @@ class PromptManager( /** * The `cut` is null when we work with the function outside the class. */ - private val cut: PsiClassWrapper? = if (classesToTest.isNotEmpty()) classesToTest[0] else null + private val cut: PsiClassWrapper? = classesToTest.firstOrNull() private val llmSettingsState: LLMSettingsState get() = project.getService(LLMSettingsService::class.java).state @@ -124,32 +124,49 @@ class PromptManager( } CodeType.LINE -> { + // two possible cases: the line inside a method/function or inside a class val lineNumber = codeType.objectIndex - val psiMethod = getPsiMethod(cut, getMethodDescriptor(cut, lineNumber))!! - // get code of line under test - val document = psiHelper.getDocumentFromPsiFile() - val lineStartOffset = document!!.getLineStartOffset(lineNumber - 1) - val lineEndOffset = document.getLineEndOffset(lineNumber - 1) - - val lineUnderTest = document.getText(TextRange.create(lineStartOffset, lineEndOffset)) - val method = createMethodRepresentation(psiMethod)!! - val interestingClassesFromMethod = - psiHelper.getInterestingPsiClassesWithQualifiedNames(cut, psiMethod) - .map(this::createClassRepresentation) - .toList() - - promptGenerator.generatePromptForLine( - lineUnderTest, - method, - interestingClassesFromMethod, - testSamplesCode, - ) + val lineUnderTest = psiHelper.getDocumentFromPsiFile()!!.let { document -> + val lineStartOffset = document.getLineStartOffset(lineNumber - 1) + val lineEndOffset = document.getLineEndOffset(lineNumber - 1) + document.getText(TextRange.create(lineStartOffset, lineEndOffset)) + } + + val psiMethod = getMethodDescriptor(cut, lineNumber)?.let { descriptor -> + getPsiMethod(cut, descriptor) + } + /** + * if psiMethod exists, then use it as a context for a line, + * otherwise use the cut as a context + */ + if (psiMethod != null) { + val method = createMethodRepresentation(psiMethod)!! + val interestingClassesFromMethod = + psiHelper.getInterestingPsiClassesWithQualifiedNames(cut, psiMethod) + .map(this::createClassRepresentation) + .toList() + + return@Computable promptGenerator.generatePromptForLine( + lineUnderTest, + method, + interestingClassesFromMethod, + testSamplesCode, + packageName = psiHelper.getPackageName(), + ) + } else { + return@Computable promptGenerator.generatePromptForLine( + lineUnderTest, + interestingClasses, + testSamplesCode, + ) + } } } }, ) + LLMSettingsBundle.get("commonPromptPart") log.info("Prompt is:\n$prompt") + println("Prompt is:\n$prompt") return prompt } @@ -167,6 +184,7 @@ class PromptManager( return ClassRepresentation( psiClass.qualifiedName, psiClass.fullText, + psiClass.constructorSignatures, psiClass.allMethods.map(this::createMethodRepresentation).toList().filterNotNull(), psiClass.classType, ) @@ -282,26 +300,29 @@ class PromptManager( * * @param psiClass the PsiClassWrapper containing the method * @param lineNumber the line number within the file where the method is located - * @return the method descriptor as `String`, or an empty string if no method is found + * @return the method descriptor as `String` if the surrounding method exists, or `null` when no method found */ private fun getMethodDescriptor( psiClass: PsiClassWrapper?, lineNumber: Int, - ): String { - // Processing function outside the class - if (psiClass == null) { - val currentPsiMethod = psiHelper.getSurroundingMethod(caret)!! + ): String? { + if (psiClass != null) { + val containingPsiMethod = psiClass.allMethods.find { it.containsLine(lineNumber) } ?: return null + + val file = psiClass.containingFile + val psiHelper = PsiHelperProvider.getPsiHelper(file) + /** + * psiHelper will not be null here because at this point, + * we already know that the current language is supported + */ + return psiHelper!!.generateMethodDescriptor(containingPsiMethod) + } else { + /** + * When no PSI class provided we are dealing with a top-level function. + * Processing function outside the class + */ + val currentPsiMethod = psiHelper.getSurroundingMethod(caret) ?: return null return psiHelper.generateMethodDescriptor(currentPsiMethod) } - for (currentPsiMethod in psiClass.allMethods) { - if (currentPsiMethod.containsLine(lineNumber)) { - val file = psiClass.containingFile - val psiHelper = PsiHelperProvider.getPsiHelper(file) - // psiHelper will not be null here - // because if we are here, then we already know that the current language is supported - return psiHelper!!.generateMethodDescriptor(currentPsiMethod) - } - } - return "" } } diff --git a/src/main/resources/properties/llm/LLMDefaults.properties b/src/main/resources/properties/llm/LLMDefaults.properties index 1eddae6e2..f95c62ef1 100644 --- a/src/main/resources/properties/llm/LLMDefaults.properties +++ b/src/main/resources/properties/llm/LLMDefaults.properties @@ -13,7 +13,7 @@ maxInputParamsDepth=2 maxPolyDepth=2 classPrompt=["Generate unit tests in $LANGUAGE for $NAME to achieve 100% line coverage for this class.\nDont use @Before and @After test methods.\nMake tests as atomic as possible.\nAll tests should be for $TESTING_PLATFORM.\nIn case of mocking, use $MOCKING_FRAMEWORK. But, do not use mocking for all tests.\nName all methods according to the template - [MethodUnderTest][Scenario]Test, and use only English letters.\nThe source code of class under test is as follows:\n$CODE\n$METHODS\n$POLYMORPHISM\n$TEST_SAMPLE"] methodPrompt=["Generate unit tests in $LANGUAGE for $NAME to achieve 100% line coverage for this method.\nDont use @Before and @After test methods.\nMake tests as atomic as possible.\nAll tests should be for $TESTING_PLATFORM.\nIn case of mocking, use $MOCKING_FRAMEWORK. But, do not use mocking for all tests.\nName all methods according to the template - [MethodUnderTest][Scenario]Test, and use only English letters.\nThe source code of method under test is as follows:\n$CODE\n$METHODS\n$POLYMORPHISM\n$TEST_SAMPLE"] -linePrompt=["Generate unit tests in $LANGUAGE for line $NAME in the following code:\n$CODE\nDont use @Before and @After test methods.\nMake tests as atomic as possible.\nAll tests should be for $TESTING_PLATFORM.\nIn case of mocking, use $MOCKING_FRAMEWORK. But, do not use mocking for all tests.\nName all methods according to the template - [MethodUnderTest][Scenario]Test, and use only English letters.\n$METHODS\n$POLYMORPHISM\n$TEST_SAMPLE"] +linePrompt=["Generate unit tests in $LANGUAGE for the line $NAME in the following code:\n$CODE\nDont use @Before and @After test methods.\nMake tests as atomic as possible.\nAll tests should be for $TESTING_PLATFORM.\nIn case of mocking, use $MOCKING_FRAMEWORK. But, do not use mocking for all tests.\nName all methods according to the template - [MethodUnderTest][Scenario]Test, and use only English letters.\n$METHODS\n$POLYMORPHISM\n$TEST_SAMPLE"] classPromptName=["Class line coverage prompt"] methodPromptName=["Method line coverage prompt"] linePromptName=["Line coverage prompt"] diff --git a/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt b/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt index 97094bb75..3cbe4d238 100644 --- a/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt +++ b/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt @@ -181,6 +181,7 @@ class PromptBuilderTest { } """.trimIndent(), allMethods = listOf(method1, method2), + constructorSignatures = emptyList(), classType = ClassType.CLASS, ) @@ -215,6 +216,7 @@ class PromptBuilderTest { } """.trimIndent(), allMethods = emptyList(), + constructorSignatures = emptyList(), classType = ClassType.INTERFACE, ) val mySubClass = ClassRepresentation( @@ -224,6 +226,7 @@ class PromptBuilderTest { } """.trimIndent(), allMethods = emptyList(), + constructorSignatures = emptyList(), classType = ClassType.CLASS, ) val polymorphicRelations = mapOf(myInterface to listOf(mySubClass))