From ce8aee079aeea137acfa7412be097e9936f4f2e8 Mon Sep 17 00:00:00 2001 From: Pascal Honegger Date: Sun, 7 May 2023 15:51:37 +0200 Subject: [PATCH] feat(ls): implement "find references" --- .../tools/samt/ls/SamtLanguageServer.kt | 1 + .../tools/samt/ls/SamtSemanticLookup.kt | 142 ++++++++++++++++++ .../tools/samt/ls/SamtTextDocumentService.kt | 29 ++++ .../tools/samt/ls/SamtReferencesLookupTest.kt | 98 ++++++++++++ 4 files changed, 270 insertions(+) create mode 100644 language-server/src/main/kotlin/tools/samt/ls/SamtSemanticLookup.kt create mode 100644 language-server/src/test/kotlin/tools/samt/ls/SamtReferencesLookupTest.kt diff --git a/language-server/src/main/kotlin/tools/samt/ls/SamtLanguageServer.kt b/language-server/src/main/kotlin/tools/samt/ls/SamtLanguageServer.kt index ad4fa0d7..a59a5f73 100644 --- a/language-server/src/main/kotlin/tools/samt/ls/SamtLanguageServer.kt +++ b/language-server/src/main/kotlin/tools/samt/ls/SamtLanguageServer.kt @@ -29,6 +29,7 @@ class SamtLanguageServer : LanguageServer, LanguageClientAware, Closeable { full = Either.forLeft(true) } definitionProvider = Either.forLeft(true) + referencesProvider = Either.forLeft(true) } InitializeResult(capabilities) } diff --git a/language-server/src/main/kotlin/tools/samt/ls/SamtSemanticLookup.kt b/language-server/src/main/kotlin/tools/samt/ls/SamtSemanticLookup.kt new file mode 100644 index 00000000..629842c3 --- /dev/null +++ b/language-server/src/main/kotlin/tools/samt/ls/SamtSemanticLookup.kt @@ -0,0 +1,142 @@ +package tools.samt.ls + +import tools.samt.parser.* +import tools.samt.semantic.* + +abstract class SamtSemanticLookup protected constructor() { + protected fun analyze(fileNode: FileNode, samtPackage: Package) { + for (import in fileNode.imports) { + markStatement(samtPackage, import) + } + markStatement(samtPackage, fileNode.packageDeclaration) + for (statement in fileNode.statements) { + markStatement(samtPackage, statement) + } + } + + operator fun get(key: TKey) = lookup[key] + operator fun set(key: TKey, value: TValue) { + lookup[key] = value + } + + private val lookup = mutableMapOf() + + protected open fun markType(node: ExpressionNode, type: Type) { + when (type) { + is ListType -> { + markTypeReference(type.elementType) + } + + is MapType -> { + markTypeReference(type.keyType) + markTypeReference(type.valueType) + } + + is ConsumerType, + is EnumType, + is ProviderType, + is RecordType, + is ServiceType, + is LiteralType, + is PackageType, + UnknownType, + -> Unit + } + } + + protected open fun markTypeReference(reference: TypeReference) { + check(reference is ResolvedTypeReference) { "Unresolved type reference shouldn't be here" } + markType(reference.typeNode, reference.type) + markConstraints(reference.constraints) + } + + protected open fun markConstraints(constraints: List) {} + + protected open fun markAnnotations(annotations: List) {} + + protected open fun markStatement(samtPackage: Package, statement: StatementNode) { + when (statement) { + is ConsumerDeclarationNode -> markConsumerDeclaration(samtPackage.getTypeByNode(statement)) + is ProviderDeclarationNode -> markProviderDeclaration(samtPackage.getTypeByNode(statement)) + is EnumDeclarationNode -> markEnumDeclaration(samtPackage.getTypeByNode(statement)) + is RecordDeclarationNode -> markRecordDeclaration(samtPackage.getTypeByNode(statement)) + is ServiceDeclarationNode -> markServiceDeclaration(samtPackage.getTypeByNode(statement)) + is TypeAliasNode -> Unit + is PackageDeclarationNode -> markPackageDeclaration(statement) + is ImportNode -> markImport(statement,samtPackage.typeByNode[statement] ?: UnknownType) + } + } + + protected open fun markServiceDeclaration(serviceType: ServiceType) { + markAnnotations(serviceType.declaration.annotations) + for (operation in serviceType.operations) { + markOperationDeclaration(operation) + } + } + + protected open fun markOperationDeclaration(operation: ServiceType.Operation) { + markAnnotations(operation.declaration.annotations) + for (parameter in operation.parameters) { + markOperationParameterDeclaration(parameter) + } + when (operation) { + is ServiceType.OnewayOperation -> Unit + is ServiceType.RequestResponseOperation -> { + operation.raisesTypes.forEach { markTypeReference(it) } + operation.returnType?.let { markTypeReference(it) } + } + } + } + + protected open fun markOperationParameterDeclaration(parameter: ServiceType.Operation.Parameter) { + markAnnotations(parameter.declaration.annotations) + markTypeReference(parameter.type) + } + + protected open fun markRecordDeclaration(recordType: RecordType) { + markAnnotations(recordType.declaration.annotations) + for (field in recordType.fields) { + markRecordFieldDeclaration(field) + } + } + + protected open fun markRecordFieldDeclaration(field: RecordType.Field) { + markAnnotations(field.declaration.annotations) + markTypeReference(field.type) + } + + protected open fun markEnumDeclaration(enumType: EnumType) { + markAnnotations(enumType.declaration.annotations) + } + + protected open fun markProviderDeclaration(providerType: ProviderType) { + for (implements in providerType.implements) { + markTypeReference(implements.service) + markOperationReference(implements.operations, implements.node.serviceOperationNames) + } + } + + protected open fun markConsumerDeclaration(consumerType: ConsumerType) { + markTypeReference(consumerType.provider) + for (use in consumerType.uses) { + markTypeReference(use.service) + markOperationReference(use.operations, use.node.serviceOperationNames) + } + } + + private fun markOperationReference(operations: List, operationReferences: List) { + val opLookup = operations.associateBy { it.name } + for (operationName in operationReferences) { + val operation = opLookup[operationName.name] ?: continue + markOperationReference(operation, operationName) + } + } + + protected open fun markOperationReference(operation: ServiceType.Operation, reference: IdentifierNode) {} + + protected open fun markPackageDeclaration(packageDeclaration: PackageDeclarationNode) {} + + protected open fun markImport(import: ImportNode, importedType: Type) { + markType(import.name, importedType) + } +} diff --git a/language-server/src/main/kotlin/tools/samt/ls/SamtTextDocumentService.kt b/language-server/src/main/kotlin/tools/samt/ls/SamtTextDocumentService.kt index 444e643f..cd6ab1d2 100644 --- a/language-server/src/main/kotlin/tools/samt/ls/SamtTextDocumentService.kt +++ b/language-server/src/main/kotlin/tools/samt/ls/SamtTextDocumentService.kt @@ -86,6 +86,35 @@ class SamtTextDocumentService(private val workspaces: Map) : return@supplyAsync Either.forRight(listOf(locationLink)) } + override fun references(params: ReferenceParams): CompletableFuture> = + CompletableFuture.supplyAsync { + val path = params.textDocument.uri.toPathUri() + val workspace = getWorkspace(path) + + val relevantFileInfo = workspace[path] ?: return@supplyAsync emptyList() + val relevantFileNode = relevantFileInfo.fileNode ?: return@supplyAsync emptyList() + val token = relevantFileInfo.tokens.findAt(params.position) ?: return@supplyAsync emptyList() + + val globalPackage: Package = workspace.samtPackage ?: return@supplyAsync emptyList() + + val typeLookup = SamtDeclarationLookup.analyze(relevantFileNode, globalPackage.resolveSubPackage(relevantFileInfo.fileNode.packageDeclaration.name)) + val type = typeLookup[token.location] ?: return@supplyAsync emptyList() + + val filesAndPackages = buildList { + for (fileInfo in workspace) { + val fileNode: FileNode = fileInfo.fileNode ?: continue + val samtPackage = globalPackage.resolveSubPackage(fileNode.packageDeclaration.name) + add(fileNode to samtPackage) + } + } + + val typeReferencesLookup = SamtReferencesLookup.analyze(filesAndPackages) + + val references = typeReferencesLookup[type] ?: emptyList() + + return@supplyAsync references.map { Location(it.source.path.toString(), it.toRange()) } + } + override fun semanticTokensFull(params: SemanticTokensParams): CompletableFuture = CompletableFuture.supplyAsync { val path = params.textDocument.uri.toPathUri() diff --git a/language-server/src/test/kotlin/tools/samt/ls/SamtReferencesLookupTest.kt b/language-server/src/test/kotlin/tools/samt/ls/SamtReferencesLookupTest.kt new file mode 100644 index 00000000..7f7acea4 --- /dev/null +++ b/language-server/src/test/kotlin/tools/samt/ls/SamtReferencesLookupTest.kt @@ -0,0 +1,98 @@ +package tools.samt.ls + +import tools.samt.common.DiagnosticController +import tools.samt.common.SourceFile +import tools.samt.lexer.Lexer +import tools.samt.parser.Parser +import tools.samt.semantic.Package +import tools.samt.semantic.SemanticModelBuilder +import java.net.URI +import kotlin.test.* + +class SamtReferencesLookupTest { + @Test + fun `correctly find references in complex model`() { + val serviceSource = """ + package test + + record Person { } + + service PersonService { + getNeighbors(person: Person): Map? (*..100) + } + """.trimIndent().let { SourceFile(URI("file:///tmp/SamtSemanticTokensTest-serviceSource.samt"), it) } + val providerSource = """ + package test + + provide PersonEndpoint { + implements PersonService { getNeighbors } + + transport HTTP + } + """.trimIndent().let { SourceFile(URI("file:///tmp/SamtSemanticTokensTest-providerSource.samt"), it) } + val consumerOneSource = """ + import test.* + import test.PersonService as Service + + package some.other.^package + + consume PersonEndpoint { + uses Service { getNeighbors } + } + """.trimIndent().let { SourceFile(URI("file:///tmp/SamtSemanticTokensTest-consumerOneSource.samt"), it) } + val consumerTwoSource = """ + package somewhere.else + + consume test.PersonEndpoint { uses test.PersonService } + """.trimIndent().let { SourceFile(URI("file:///tmp/SamtSemanticTokensTest-consumerTwoSource.samt"), it) } + val (samtPackage, referencesLookup) = parse(serviceSource, providerSource, consumerOneSource, consumerTwoSource) + + val testPackage = samtPackage.subPackages.single { it.name == "test" } + val person = testPackage.records.single { it.name == "Person" } + val personService = testPackage.services.single { it.name == "PersonService" } + val personEndpoint = testPackage.providers.single { it.name == "PersonEndpoint" } + val getNeighbors = personService.operations.single { it.name == "getNeighbors" } + + val personReferences = referencesLookup[person] + assertNotNull(personReferences) + assertEquals(2, personReferences.size, "Following list had unexpected amount of entries: $personReferences") + assertContains(personReferences, TestLocation("5:25" to "5:31").getLocation(serviceSource)) + assertContains(personReferences, TestLocation("5:46" to "5:52").getLocation(serviceSource)) + + val personServiceReferences = referencesLookup[personService] + assertNotNull(personServiceReferences) + assertEquals(4, personServiceReferences.size, "Following list had unexpected amount of entries: $personServiceReferences") + assertContains(personServiceReferences, TestLocation("3:15" to "3:28").getLocation(providerSource)) + assertContains(personServiceReferences, TestLocation("1:12" to "1:25").getLocation(consumerOneSource)) + assertContains(personServiceReferences, TestLocation("6:9" to "6:16").getLocation(consumerOneSource)) + assertContains(personServiceReferences, TestLocation("2:40" to "2:53").getLocation(consumerTwoSource)) + + val personEndpointReferences = referencesLookup[personEndpoint] + assertNotNull(personEndpointReferences) + assertEquals(2, personEndpointReferences.size, "Following list had unexpected amount of entries: $personEndpointReferences") + assertContains(personEndpointReferences, TestLocation("5:8" to "5:22").getLocation(consumerOneSource)) + assertContains(personEndpointReferences, TestLocation("2:13" to "2:27").getLocation(consumerTwoSource)) + + val getNeighborsReferences = referencesLookup[getNeighbors] + assertNotNull(getNeighborsReferences) + assertEquals(2, getNeighborsReferences.size, "Following list had unexpected amount of entries: $getNeighborsReferences") + assertContains(getNeighborsReferences, TestLocation("3:31" to "3:43").getLocation(providerSource)) + assertContains(getNeighborsReferences, TestLocation("6:19" to "6:31").getLocation(consumerOneSource)) + } + + private fun parse(vararg sourceAndExpectedMessages: SourceFile): Pair { + val diagnosticController = DiagnosticController(URI("file:///tmp")) + val fileTree = sourceAndExpectedMessages.map { sourceFile -> + val parseContext = diagnosticController.getOrCreateContext(sourceFile) + val stream = Lexer.scan(sourceFile.content.reader(), parseContext) + val fileTree = Parser.parse(sourceFile, stream, parseContext) + assertFalse(parseContext.hasErrors(), "Expected no parse errors, but had errors: ${parseContext.messages}}") + fileTree + } + + val samtPackage = SemanticModelBuilder.build(fileTree, diagnosticController) + + val filesAndPackages = fileTree.map { it to samtPackage.resolveSubPackage(it.packageDeclaration.name) } + return Pair(samtPackage, SamtReferencesLookup.analyze(filesAndPackages)) + } +}