Skip to content

Commit

Permalink
feat(ls): implement "find references"
Browse files Browse the repository at this point in the history
  • Loading branch information
PascalHonegger committed May 7, 2023
1 parent c57babb commit ce8aee0
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class SamtLanguageServer : LanguageServer, LanguageClientAware, Closeable {
full = Either.forLeft(true)
}
definitionProvider = Either.forLeft(true)
referencesProvider = Either.forLeft(true)
}
InitializeResult(capabilities)
}
Expand Down
142 changes: 142 additions & 0 deletions language-server/src/main/kotlin/tools/samt/ls/SamtSemanticLookup.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package tools.samt.ls

import tools.samt.parser.*
import tools.samt.semantic.*

abstract class SamtSemanticLookup<TKey, TValue> protected constructor() {
protected fun analyze(fileNode: FileNode, samtPackage: Package) {
for (import in fileNode.imports) {
markStatement(samtPackage, import)
}
markStatement(samtPackage, fileNode.packageDeclaration)
for (statement in fileNode.statements) {
markStatement(samtPackage, statement)
}
}

operator fun get(key: TKey) = lookup[key]
operator fun set(key: TKey, value: TValue) {
lookup[key] = value
}

private val lookup = mutableMapOf<TKey, TValue>()

protected open fun markType(node: ExpressionNode, type: Type) {
when (type) {
is ListType -> {
markTypeReference(type.elementType)
}

is MapType -> {
markTypeReference(type.keyType)
markTypeReference(type.valueType)
}

is ConsumerType,
is EnumType,
is ProviderType,
is RecordType,
is ServiceType,
is LiteralType,
is PackageType,
UnknownType,
-> Unit
}
}

protected open fun markTypeReference(reference: TypeReference) {
check(reference is ResolvedTypeReference) { "Unresolved type reference shouldn't be here" }
markType(reference.typeNode, reference.type)
markConstraints(reference.constraints)
}

protected open fun markConstraints(constraints: List<ResolvedTypeReference.Constraint>) {}

protected open fun markAnnotations(annotations: List<AnnotationNode>) {}

protected open fun markStatement(samtPackage: Package, statement: StatementNode) {
when (statement) {
is ConsumerDeclarationNode -> markConsumerDeclaration(samtPackage.getTypeByNode<ConsumerType>(statement))
is ProviderDeclarationNode -> markProviderDeclaration(samtPackage.getTypeByNode<ProviderType>(statement))
is EnumDeclarationNode -> markEnumDeclaration(samtPackage.getTypeByNode<EnumType>(statement))
is RecordDeclarationNode -> markRecordDeclaration(samtPackage.getTypeByNode<RecordType>(statement))
is ServiceDeclarationNode -> markServiceDeclaration(samtPackage.getTypeByNode<ServiceType>(statement))
is TypeAliasNode -> Unit
is PackageDeclarationNode -> markPackageDeclaration(statement)
is ImportNode -> markImport(statement,samtPackage.typeByNode[statement] ?: UnknownType)
}
}

protected open fun markServiceDeclaration(serviceType: ServiceType) {
markAnnotations(serviceType.declaration.annotations)
for (operation in serviceType.operations) {
markOperationDeclaration(operation)
}
}

protected open fun markOperationDeclaration(operation: ServiceType.Operation) {
markAnnotations(operation.declaration.annotations)
for (parameter in operation.parameters) {
markOperationParameterDeclaration(parameter)
}
when (operation) {
is ServiceType.OnewayOperation -> Unit
is ServiceType.RequestResponseOperation -> {
operation.raisesTypes.forEach { markTypeReference(it) }
operation.returnType?.let { markTypeReference(it) }
}
}
}

protected open fun markOperationParameterDeclaration(parameter: ServiceType.Operation.Parameter) {
markAnnotations(parameter.declaration.annotations)
markTypeReference(parameter.type)
}

protected open fun markRecordDeclaration(recordType: RecordType) {
markAnnotations(recordType.declaration.annotations)
for (field in recordType.fields) {
markRecordFieldDeclaration(field)
}
}

protected open fun markRecordFieldDeclaration(field: RecordType.Field) {
markAnnotations(field.declaration.annotations)
markTypeReference(field.type)
}

protected open fun markEnumDeclaration(enumType: EnumType) {
markAnnotations(enumType.declaration.annotations)
}

protected open fun markProviderDeclaration(providerType: ProviderType) {
for (implements in providerType.implements) {
markTypeReference(implements.service)
markOperationReference(implements.operations, implements.node.serviceOperationNames)
}
}

protected open fun markConsumerDeclaration(consumerType: ConsumerType) {
markTypeReference(consumerType.provider)
for (use in consumerType.uses) {
markTypeReference(use.service)
markOperationReference(use.operations, use.node.serviceOperationNames)
}
}

private fun markOperationReference(operations: List<ServiceType.Operation>, operationReferences: List<IdentifierNode>) {
val opLookup = operations.associateBy { it.name }
for (operationName in operationReferences) {
val operation = opLookup[operationName.name] ?: continue
markOperationReference(operation, operationName)
}
}

protected open fun markOperationReference(operation: ServiceType.Operation, reference: IdentifierNode) {}

protected open fun markPackageDeclaration(packageDeclaration: PackageDeclarationNode) {}

protected open fun markImport(import: ImportNode, importedType: Type) {
markType(import.name, importedType)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,35 @@ class SamtTextDocumentService(private val workspaces: Map<URI, SamtWorkspace>) :
return@supplyAsync Either.forRight(listOf(locationLink))
}

override fun references(params: ReferenceParams): CompletableFuture<List<Location>> =
CompletableFuture.supplyAsync {
val path = params.textDocument.uri.toPathUri()
val workspace = getWorkspace(path)

val relevantFileInfo = workspace[path] ?: return@supplyAsync emptyList()
val relevantFileNode = relevantFileInfo.fileNode ?: return@supplyAsync emptyList()
val token = relevantFileInfo.tokens.findAt(params.position) ?: return@supplyAsync emptyList()

val globalPackage: Package = workspace.samtPackage ?: return@supplyAsync emptyList()

val typeLookup = SamtDeclarationLookup.analyze(relevantFileNode, globalPackage.resolveSubPackage(relevantFileInfo.fileNode.packageDeclaration.name))
val type = typeLookup[token.location] ?: return@supplyAsync emptyList()

val filesAndPackages = buildList {
for (fileInfo in workspace) {
val fileNode: FileNode = fileInfo.fileNode ?: continue
val samtPackage = globalPackage.resolveSubPackage(fileNode.packageDeclaration.name)
add(fileNode to samtPackage)
}
}

val typeReferencesLookup = SamtReferencesLookup.analyze(filesAndPackages)

val references = typeReferencesLookup[type] ?: emptyList()

return@supplyAsync references.map { Location(it.source.path.toString(), it.toRange()) }
}

override fun semanticTokensFull(params: SemanticTokensParams): CompletableFuture<SemanticTokens> =
CompletableFuture.supplyAsync {
val path = params.textDocument.uri.toPathUri()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package tools.samt.ls

import tools.samt.common.DiagnosticController
import tools.samt.common.SourceFile
import tools.samt.lexer.Lexer
import tools.samt.parser.Parser
import tools.samt.semantic.Package
import tools.samt.semantic.SemanticModelBuilder
import java.net.URI
import kotlin.test.*

class SamtReferencesLookupTest {
@Test
fun `correctly find references in complex model`() {
val serviceSource = """
package test
record Person { }
service PersonService {
getNeighbors(person: Person): Map<String, Person?>? (*..100)
}
""".trimIndent().let { SourceFile(URI("file:///tmp/SamtSemanticTokensTest-serviceSource.samt"), it) }
val providerSource = """
package test
provide PersonEndpoint {
implements PersonService { getNeighbors }
transport HTTP
}
""".trimIndent().let { SourceFile(URI("file:///tmp/SamtSemanticTokensTest-providerSource.samt"), it) }
val consumerOneSource = """
import test.*
import test.PersonService as Service
package some.other.^package
consume PersonEndpoint {
uses Service { getNeighbors }
}
""".trimIndent().let { SourceFile(URI("file:///tmp/SamtSemanticTokensTest-consumerOneSource.samt"), it) }
val consumerTwoSource = """
package somewhere.else
consume test.PersonEndpoint { uses test.PersonService }
""".trimIndent().let { SourceFile(URI("file:///tmp/SamtSemanticTokensTest-consumerTwoSource.samt"), it) }
val (samtPackage, referencesLookup) = parse(serviceSource, providerSource, consumerOneSource, consumerTwoSource)

val testPackage = samtPackage.subPackages.single { it.name == "test" }
val person = testPackage.records.single { it.name == "Person" }
val personService = testPackage.services.single { it.name == "PersonService" }
val personEndpoint = testPackage.providers.single { it.name == "PersonEndpoint" }
val getNeighbors = personService.operations.single { it.name == "getNeighbors" }

val personReferences = referencesLookup[person]
assertNotNull(personReferences)
assertEquals(2, personReferences.size, "Following list had unexpected amount of entries: $personReferences")
assertContains(personReferences, TestLocation("5:25" to "5:31").getLocation(serviceSource))
assertContains(personReferences, TestLocation("5:46" to "5:52").getLocation(serviceSource))

val personServiceReferences = referencesLookup[personService]
assertNotNull(personServiceReferences)
assertEquals(4, personServiceReferences.size, "Following list had unexpected amount of entries: $personServiceReferences")
assertContains(personServiceReferences, TestLocation("3:15" to "3:28").getLocation(providerSource))
assertContains(personServiceReferences, TestLocation("1:12" to "1:25").getLocation(consumerOneSource))
assertContains(personServiceReferences, TestLocation("6:9" to "6:16").getLocation(consumerOneSource))
assertContains(personServiceReferences, TestLocation("2:40" to "2:53").getLocation(consumerTwoSource))

val personEndpointReferences = referencesLookup[personEndpoint]
assertNotNull(personEndpointReferences)
assertEquals(2, personEndpointReferences.size, "Following list had unexpected amount of entries: $personEndpointReferences")
assertContains(personEndpointReferences, TestLocation("5:8" to "5:22").getLocation(consumerOneSource))
assertContains(personEndpointReferences, TestLocation("2:13" to "2:27").getLocation(consumerTwoSource))

val getNeighborsReferences = referencesLookup[getNeighbors]
assertNotNull(getNeighborsReferences)
assertEquals(2, getNeighborsReferences.size, "Following list had unexpected amount of entries: $getNeighborsReferences")
assertContains(getNeighborsReferences, TestLocation("3:31" to "3:43").getLocation(providerSource))
assertContains(getNeighborsReferences, TestLocation("6:19" to "6:31").getLocation(consumerOneSource))
}

private fun parse(vararg sourceAndExpectedMessages: SourceFile): Pair<Package, SamtReferencesLookup> {
val diagnosticController = DiagnosticController(URI("file:///tmp"))
val fileTree = sourceAndExpectedMessages.map { sourceFile ->
val parseContext = diagnosticController.getOrCreateContext(sourceFile)
val stream = Lexer.scan(sourceFile.content.reader(), parseContext)
val fileTree = Parser.parse(sourceFile, stream, parseContext)
assertFalse(parseContext.hasErrors(), "Expected no parse errors, but had errors: ${parseContext.messages}}")
fileTree
}

val samtPackage = SemanticModelBuilder.build(fileTree, diagnosticController)

val filesAndPackages = fileTree.map { it to samtPackage.resolveSubPackage(it.packageDeclaration.name) }
return Pair(samtPackage, SamtReferencesLookup.analyze(filesAndPackages))
}
}

0 comments on commit ce8aee0

Please sign in to comment.