diff --git a/cli/src/main/kotlin/tools/samt/cli/ASTPrinter.kt b/cli/src/main/kotlin/tools/samt/cli/ASTPrinter.kt index aac32b37..72008581 100644 --- a/cli/src/main/kotlin/tools/samt/cli/ASTPrinter.kt +++ b/cli/src/main/kotlin/tools/samt/cli/ASTPrinter.kt @@ -56,7 +56,7 @@ internal object ASTPrinter { } private fun dumpInfo(node: Node): String? = when (node) { - is FileNode -> gray(node.sourceFile.path.path) + is FileNode -> gray(node.sourceFile.path.toString()) is RequestResponseOperationNode -> if (node.isAsync) red("async") else null is IdentifierNode -> yellow(node.name) is ImportBundleIdentifierNode -> yellow(node.name) + if (node.isWildcard) yellow(".*") else "" diff --git a/cli/src/main/kotlin/tools/samt/cli/DiagnosticFormatter.kt b/cli/src/main/kotlin/tools/samt/cli/DiagnosticFormatter.kt index 57e05bcd..167e11b2 100644 --- a/cli/src/main/kotlin/tools/samt/cli/DiagnosticFormatter.kt +++ b/cli/src/main/kotlin/tools/samt/cli/DiagnosticFormatter.kt @@ -141,7 +141,7 @@ internal class DiagnosticFormatter( // -----> : append(gray(" ---> ")) - append(diagnosticController.workingDirectory.relativize(errorSourceFilePath)) + append(errorSourceFilePath.toString()) if (message.highlights.isNotEmpty()) { val firstHighlight = message.highlights.first() val firstHighlightLocation = firstHighlight.location diff --git a/cli/src/test/kotlin/tools/samt/cli/ASTPrinterTest.kt b/cli/src/test/kotlin/tools/samt/cli/ASTPrinterTest.kt index 11e7d86f..99bcab8e 100644 --- a/cli/src/test/kotlin/tools/samt/cli/ASTPrinterTest.kt +++ b/cli/src/test/kotlin/tools/samt/cli/ASTPrinterTest.kt @@ -36,7 +36,7 @@ class ASTPrinterTest { val dumpWithoutColorCodes = dump.replace(Regex("\u001B\\[[;\\d]*m"), "") assertEquals(""" - FileNode /tmp/ASTPrinterTest.samt <1:1> + FileNode file:///tmp/ASTPrinterTest.samt <1:1> ├─WildcardImportNode <1:1> │ └─ImportBundleIdentifierNode foo.bar.baz.* <1:8> │ ├─IdentifierNode foo <1:8> diff --git a/cli/src/test/kotlin/tools/samt/cli/DiagnosticFormatterTest.kt b/cli/src/test/kotlin/tools/samt/cli/DiagnosticFormatterTest.kt index d3363f45..052054b4 100644 --- a/cli/src/test/kotlin/tools/samt/cli/DiagnosticFormatterTest.kt +++ b/cli/src/test/kotlin/tools/samt/cli/DiagnosticFormatterTest.kt @@ -8,7 +8,6 @@ import tools.samt.parser.EnumDeclarationNode import tools.samt.parser.FileNode import tools.samt.parser.Parser import java.net.URI -import kotlin.io.path.Path import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFalse @@ -40,8 +39,8 @@ class DiagnosticFormatterTest { @Test fun `file messages with no highlights`() { - val baseDirectory = Path("/tmp").toUri() - val filePath = Path("/tmp", "test.txt").toUri() + val baseDirectory = URI("file:///tmp") + val filePath = URI("file:///tmp/test.txt") val controller = DiagnosticController(baseDirectory) val source = "" val sourceFile = SourceFile(filePath, source) @@ -65,15 +64,15 @@ class DiagnosticFormatterTest { assertEquals(""" ──────────────────────────────────────── ERROR: some error - ---> test.txt + ---> file:///tmp/test.txt ──────────────────────────────────────── WARNING: some warning - ---> test.txt + ---> file:///tmp/test.txt ──────────────────────────────────────── INFO: some info - ---> test.txt + ---> file:///tmp/test.txt ──────────────────────────────────────── FAILED in 0ms (1 error(s), 1 warning(s)) @@ -104,7 +103,7 @@ class DiagnosticFormatterTest { assertEquals(""" ──────────────────────────────────────── ERROR: some error - ---> DiagnosticFormatterTest.samt:2:1 + ---> file:///tmp/DiagnosticFormatterTest.samt:2:1 1 │ package debug |> 2 │ enum Test { @@ -140,7 +139,7 @@ class DiagnosticFormatterTest { assertEquals(""" ──────────────────────────────────────── ERROR: some error - ---> DiagnosticFormatterTest.samt:2:1 + ---> file:///tmp/DiagnosticFormatterTest.samt:2:1 1 │ package debug |> 2 │ enum Test { @@ -183,7 +182,7 @@ class DiagnosticFormatterTest { assertEquals(""" ──────────────────────────────────────── ERROR: some error - ---> DiagnosticFormatterTest.samt:3:5 + ---> file:///tmp/DiagnosticFormatterTest.samt:3:5 1 │ package debug 2 │ enum Test { @@ -239,7 +238,7 @@ class DiagnosticFormatterTest { assertEquals(""" ──────────────────────────────────────── ERROR: some error - ---> DiagnosticFormatterTest.samt:3:5 + ---> file:///tmp/DiagnosticFormatterTest.samt:3:5 1 │ package debug 2 │ enum Test { @@ -293,7 +292,7 @@ class DiagnosticFormatterTest { assertEquals(""" ──────────────────────────────────────── ERROR: some error - ---> DiagnosticFormatterTest.samt:3:5 + ---> file:///tmp/DiagnosticFormatterTest.samt:3:5 1 │ package debug 2 │ enum Test { @@ -339,7 +338,7 @@ class DiagnosticFormatterTest { assertEquals(""" ──────────────────────────────────────── ERROR: some error - ---> DiagnosticFormatterTest.samt:3:5 + ---> file:///tmp/DiagnosticFormatterTest.samt:3:5 1 │ package debug 2 │ enum Test { @@ -382,7 +381,7 @@ class DiagnosticFormatterTest { assertEquals(""" ──────────────────────────────────────── ERROR: some error - ---> DiagnosticFormatterTest.samt:2:1 + ---> file:///tmp/DiagnosticFormatterTest.samt:2:1 1 │ package debug 2 │ enum Test { @@ -420,7 +419,7 @@ class DiagnosticFormatterTest { assertEquals(""" ──────────────────────────────────────── ERROR: some error - ---> DiagnosticFormatterTest.samt:2:1 + ---> file:///tmp/DiagnosticFormatterTest.samt:2:1 1 │ package debug 2 │ enum Test { @@ -461,7 +460,7 @@ class DiagnosticFormatterTest { assertEquals(""" ──────────────────────────────────────── ERROR: some error - ---> DiagnosticFormatterTest.samt:2:1 + ---> file:///tmp/DiagnosticFormatterTest.samt:2:1 1 │ package debug 2 │ enum Test { @@ -480,8 +479,8 @@ class DiagnosticFormatterTest { } private fun parse(source: String): Triple { - val baseDirectory = Path("/tmp").toUri() - val filePath = Path("/tmp", "DiagnosticFormatterTest.samt").toUri() + val baseDirectory = URI("file:///tmp") + val filePath = URI("file:///tmp/DiagnosticFormatterTest.samt") val sourceFile = SourceFile(filePath, source) val diagnosticController = DiagnosticController(baseDirectory) val diagnosticContext = diagnosticController.getOrCreateContext(sourceFile) diff --git a/language-server/src/main/kotlin/tools/samt/ls/Mapping.kt b/language-server/src/main/kotlin/tools/samt/ls/Mapping.kt index 710a3507..4de60438 100644 --- a/language-server/src/main/kotlin/tools/samt/ls/Mapping.kt +++ b/language-server/src/main/kotlin/tools/samt/ls/Mapping.kt @@ -28,6 +28,6 @@ fun DiagnosticMessage.toDiagnostic(): Diagnostic? { fun SamtLocation.toRange(): Range { return Range( Position(start.row, start.col), - Position(start.row, end.col) + Position(end.row, end.col) ) } diff --git a/language-server/src/main/kotlin/tools/samt/ls/SamtDeclarationLookup.kt b/language-server/src/main/kotlin/tools/samt/ls/SamtDeclarationLookup.kt new file mode 100644 index 00000000..bb694db4 --- /dev/null +++ b/language-server/src/main/kotlin/tools/samt/ls/SamtDeclarationLookup.kt @@ -0,0 +1,52 @@ +package tools.samt.ls + +import tools.samt.common.Location +import tools.samt.parser.BundleIdentifierNode +import tools.samt.parser.ExpressionNode +import tools.samt.parser.FileNode +import tools.samt.parser.IdentifierNode +import tools.samt.semantic.* + +class SamtDeclarationLookup private constructor() : SamtSemanticLookup() { + override fun markType(node: ExpressionNode, type: Type) { + super.markType(node, type) + + if (type is UserDeclared) { + if (node is BundleIdentifierNode) { + this[node.components.last().location] = type + } else { + this[node.location] = type + } + } + } + + override fun markOperationReference(operation: ServiceType.Operation, reference: IdentifierNode) { + super.markOperationReference(operation, reference) + this[reference.location] = operation + } + + override fun markProviderDeclaration(providerType: ProviderType) { + super.markProviderDeclaration(providerType) + this[providerType.declaration.name.location] = providerType + } + + override fun markServiceDeclaration(serviceType: ServiceType) { + super.markServiceDeclaration(serviceType) + this[serviceType.declaration.name.location] = serviceType + } + + override fun markRecordDeclaration(recordType: RecordType) { + super.markRecordDeclaration(recordType) + this[recordType.declaration.name.location] = recordType + } + + override fun markOperationDeclaration(operation: ServiceType.Operation) { + super.markOperationDeclaration(operation) + this[operation.declaration.name.location] = operation + } + + companion object { + fun analyze(fileNode: FileNode, samtPackage: Package) = + SamtDeclarationLookup().also { it.analyze(fileNode, samtPackage) } + } +} diff --git a/language-server/src/main/kotlin/tools/samt/ls/SamtLanguageServer.kt b/language-server/src/main/kotlin/tools/samt/ls/SamtLanguageServer.kt index ce0d194a..a59a5f73 100644 --- a/language-server/src/main/kotlin/tools/samt/ls/SamtLanguageServer.kt +++ b/language-server/src/main/kotlin/tools/samt/ls/SamtLanguageServer.kt @@ -1,8 +1,11 @@ package tools.samt.ls import org.eclipse.lsp4j.* +import org.eclipse.lsp4j.jsonrpc.messages.Either import org.eclipse.lsp4j.services.* -import tools.samt.common.* +import tools.samt.common.DiagnosticController +import tools.samt.common.collectSamtFiles +import tools.samt.common.readSamtSource import java.io.Closeable import java.net.URI import java.util.concurrent.CompletableFuture @@ -19,7 +22,14 @@ class SamtLanguageServer : LanguageServer, LanguageClientAware, Closeable { CompletableFuture.supplyAsync { buildSamtModel(params) val capabilities = ServerCapabilities().apply { - setTextDocumentSync(TextDocumentSyncKind.Full) + textDocumentSync = Either.forLeft(TextDocumentSyncKind.Full) + semanticTokensProvider = SemanticTokensWithRegistrationOptions().apply { + legend = SamtSemanticTokens.legend + range = Either.forLeft(false) + full = Either.forLeft(true) + } + definitionProvider = Either.forLeft(true) + referencesProvider = Either.forLeft(true) } InitializeResult(capabilities) } diff --git a/language-server/src/main/kotlin/tools/samt/ls/SamtReferencesLookup.kt b/language-server/src/main/kotlin/tools/samt/ls/SamtReferencesLookup.kt new file mode 100644 index 00000000..dc4fe6d9 --- /dev/null +++ b/language-server/src/main/kotlin/tools/samt/ls/SamtReferencesLookup.kt @@ -0,0 +1,47 @@ +package tools.samt.ls + +import tools.samt.common.Location +import tools.samt.parser.BundleIdentifierNode +import tools.samt.parser.ExpressionNode +import tools.samt.parser.FileNode +import tools.samt.parser.IdentifierNode +import tools.samt.semantic.Package +import tools.samt.semantic.ServiceType +import tools.samt.semantic.Type +import tools.samt.semantic.UserDeclared + +class SamtReferencesLookup private constructor() : SamtSemanticLookup>() { + private fun addUsage(declaration: UserDeclared, usage: Location) { + if (this[declaration] == null) { + this[declaration] = mutableListOf() + } + (this[declaration] as MutableList) += usage + } + + override fun markType(node: ExpressionNode, type: Type) { + super.markType(node, type) + + if (type is UserDeclared) { + if (node is BundleIdentifierNode) { + addUsage(type, node.components.last().location) + } else { + addUsage(type, node.location) + } + } + } + + override fun markOperationReference(operation: ServiceType.Operation, reference: IdentifierNode) { + super.markOperationReference(operation, reference) + addUsage(operation, reference.location) + } + + companion object { + fun analyze(filesAndPackages: List>): SamtReferencesLookup { + val lookup = SamtReferencesLookup() + for ((fileInfo, samtPackage) in filesAndPackages) { + lookup.analyze(fileInfo, samtPackage) + } + return lookup + } + } +} diff --git a/language-server/src/main/kotlin/tools/samt/ls/SamtSemanticLookup.kt b/language-server/src/main/kotlin/tools/samt/ls/SamtSemanticLookup.kt new file mode 100644 index 00000000..55095440 --- /dev/null +++ b/language-server/src/main/kotlin/tools/samt/ls/SamtSemanticLookup.kt @@ -0,0 +1,142 @@ +package tools.samt.ls + +import tools.samt.parser.* +import tools.samt.semantic.* + +abstract class SamtSemanticLookup protected constructor() { + protected fun analyze(fileNode: FileNode, samtPackage: Package) { + for (import in fileNode.imports) { + markStatement(samtPackage, import) + } + markStatement(samtPackage, fileNode.packageDeclaration) + for (statement in fileNode.statements) { + markStatement(samtPackage, statement) + } + } + + operator fun get(key: TKey) = lookup[key] + operator fun set(key: TKey, value: TValue) { + lookup[key] = value + } + + private val lookup = mutableMapOf() + + protected open fun markType(node: ExpressionNode, type: Type) { + when (type) { + is ListType -> { + markTypeReference(type.elementType) + } + + is MapType -> { + markTypeReference(type.keyType) + markTypeReference(type.valueType) + } + + is ConsumerType, + is EnumType, + is ProviderType, + is RecordType, + is ServiceType, + is LiteralType, + is PackageType, + UnknownType, + -> Unit + } + } + + protected open fun markTypeReference(reference: TypeReference) { + check(reference is ResolvedTypeReference) { "Unresolved type reference shouldn't be here" } + markType(reference.typeNode, reference.type) + markConstraints(reference.constraints) + } + + protected open fun markConstraints(constraints: List) {} + + protected open fun markAnnotations(annotations: List) {} + + protected open fun markStatement(samtPackage: Package, statement: StatementNode) { + when (statement) { + is ConsumerDeclarationNode -> markConsumerDeclaration(samtPackage.getTypeByNode(statement)) + is ProviderDeclarationNode -> markProviderDeclaration(samtPackage.getTypeByNode(statement)) + is EnumDeclarationNode -> markEnumDeclaration(samtPackage.getTypeByNode(statement)) + is RecordDeclarationNode -> markRecordDeclaration(samtPackage.getTypeByNode(statement)) + is ServiceDeclarationNode -> markServiceDeclaration(samtPackage.getTypeByNode(statement)) + is TypeAliasNode -> Unit + is PackageDeclarationNode -> markPackageDeclaration(statement) + is ImportNode -> markImport(statement,samtPackage.typeByNode[statement] ?: UnknownType) + } + } + + protected open fun markServiceDeclaration(serviceType: ServiceType) { + markAnnotations(serviceType.declaration.annotations) + for (operation in serviceType.operations) { + markOperationDeclaration(operation) + } + } + + protected open fun markOperationDeclaration(operation: ServiceType.Operation) { + markAnnotations(operation.declaration.annotations) + for (parameter in operation.parameters) { + markOperationParameterDeclaration(parameter) + } + when (operation) { + is ServiceType.OnewayOperation -> Unit + is ServiceType.RequestResponseOperation -> { + operation.raisesTypes.forEach { markTypeReference(it) } + operation.returnType?.let { markTypeReference(it) } + } + } + } + + protected open fun markOperationParameterDeclaration(parameter: ServiceType.Operation.Parameter) { + markAnnotations(parameter.declaration.annotations) + markTypeReference(parameter.type) + } + + protected open fun markRecordDeclaration(recordType: RecordType) { + markAnnotations(recordType.declaration.annotations) + for (field in recordType.fields) { + markRecordFieldDeclaration(field) + } + } + + protected open fun markRecordFieldDeclaration(field: RecordType.Field) { + markAnnotations(field.declaration.annotations) + markTypeReference(field.type) + } + + protected open fun markEnumDeclaration(enumType: EnumType) { + markAnnotations(enumType.declaration.annotations) + } + + protected open fun markProviderDeclaration(providerType: ProviderType) { + for (implements in providerType.implements) { + markTypeReference(implements.service) + markOperationReference(implements.operations, implements.node.serviceOperationNames) + } + } + + protected open fun markConsumerDeclaration(consumerType: ConsumerType) { + markTypeReference(consumerType.provider) + for (use in consumerType.uses) { + markTypeReference(use.service) + markOperationReference(use.operations, use.node.serviceOperationNames) + } + } + + private fun markOperationReference(operations: List, operationReferences: List) { + val opLookup = operations.associateBy { it.name } + for (operationName in operationReferences) { + val operation = opLookup[operationName.name] ?: continue + markOperationReference(operation, operationName) + } + } + + protected open fun markOperationReference(operation: ServiceType.Operation, reference: IdentifierNode) {} + + protected open fun markPackageDeclaration(packageDeclaration: PackageDeclarationNode) {} + + protected open fun markImport(import: ImportNode, importedType: Type) { + markType(import.name, importedType) + } +} diff --git a/language-server/src/main/kotlin/tools/samt/ls/SamtSemanticTokens.kt b/language-server/src/main/kotlin/tools/samt/ls/SamtSemanticTokens.kt new file mode 100644 index 00000000..46add33a --- /dev/null +++ b/language-server/src/main/kotlin/tools/samt/ls/SamtSemanticTokens.kt @@ -0,0 +1,186 @@ +package tools.samt.ls + +import org.eclipse.lsp4j.SemanticTokensLegend +import tools.samt.common.Location +import tools.samt.parser.* +import tools.samt.semantic.* + +class SamtSemanticTokens private constructor() : SamtSemanticLookup() { + override fun markType(node: ExpressionNode, type: Type) { + super.markType(node, type) + val location = if (node is BundleIdentifierNode) { + node.components.last().location + } else { + node.location + } + when (type) { + is ConsumerType -> this[location] = Metadata(TokenType.type) + + is EnumType -> this[location] = Metadata(TokenType.enum) + + is ListType -> { + this[type.node.base.location] = + Metadata(TokenType.type, TokenModifier.defaultLibrary) + } + + is MapType -> { + this[type.node.base.location] = + Metadata(TokenType.type, TokenModifier.defaultLibrary) + } + + is ProviderType -> this[location] = Metadata(TokenType.type) + is RecordType -> this[location] = Metadata(TokenType.`class`) + is ServiceType -> this[location] = Metadata(TokenType.`interface`) + is LiteralType -> this[location] = + Metadata(TokenType.type, TokenModifier.defaultLibrary) + + is PackageType -> this[location] = Metadata(TokenType.namespace) + UnknownType -> this[location] = Metadata(TokenType.type) + } + } + + override fun markConstraints(constraints: List) { + super.markConstraints(constraints) + for (constraint in constraints.map { it.node }.filterIsInstance()) { + this[constraint.base.location] = Metadata(TokenType.function, TokenModifier.defaultLibrary) + } + } + + override fun markAnnotations(annotations: List) { + super.markAnnotations(annotations) + for (annotation in annotations) { + this[annotation.name.location] = Metadata(TokenType.type, TokenModifier.defaultLibrary) + } + } + + override fun markServiceDeclaration(serviceType: ServiceType) { + super.markServiceDeclaration(serviceType) + this[serviceType.declaration.name.location] = Metadata(TokenType.`interface`, TokenModifier.declaration) + } + + override fun markOperationDeclaration(operation: ServiceType.Operation) { + super.markOperationDeclaration(operation) + this[operation.declaration.name.location] = Metadata( + type = TokenType.method, + modifier = if (operation is ServiceType.RequestResponseOperation && operation.isAsync) { + TokenModifier.declaration and TokenModifier.async + } else { + TokenModifier.declaration + } + ) + } + + override fun markOperationParameterDeclaration(parameter: ServiceType.Operation.Parameter) { + super.markOperationParameterDeclaration(parameter) + this[parameter.declaration.name.location] = Metadata(TokenType.parameter, TokenModifier.declaration) + } + + override fun markRecordDeclaration(recordType: RecordType) { + super.markRecordDeclaration(recordType) + this[recordType.declaration.name.location] = Metadata(TokenType.`class`, TokenModifier.declaration) + } + + override fun markRecordFieldDeclaration(field: RecordType.Field) { + super.markRecordFieldDeclaration(field) + this[field.declaration.name.location] = Metadata(TokenType.property, TokenModifier.declaration) + } + + override fun markEnumDeclaration(enumType: EnumType) { + super.markEnumDeclaration(enumType) + this[enumType.declaration.name.location] = Metadata(TokenType.enum, TokenModifier.declaration) + for (enumMember in enumType.declaration.values) { + this[enumMember.location] = Metadata(TokenType.enumMember, TokenModifier.declaration) + } + } + + override fun markProviderDeclaration(providerType: ProviderType) { + super.markProviderDeclaration(providerType) + this[providerType.declaration.name.location] = Metadata(TokenType.type, TokenModifier.declaration) + } + + override fun markOperationReference(operation: ServiceType.Operation, reference: IdentifierNode) { + super.markOperationReference(operation, reference) + this[reference.location] = Metadata( + type = TokenType.method, + modifier = if (operation is ServiceType.RequestResponseOperation && operation.isAsync) { + TokenModifier.async + } else { + TokenModifier.none + } + ) + } + + override fun markPackageDeclaration(packageDeclaration: PackageDeclarationNode) { + super.markPackageDeclaration(packageDeclaration) + for (component in packageDeclaration.name.components) { + this[component.location] = Metadata(TokenType.namespace) + } + } + + override fun markImport(import: ImportNode, importedType: Type) { + super.markImport(import, importedType) + val typeLocation = import.name.components.last().location + if (import is TypeImportNode && import.alias != null) { + this[import.alias!!.location] = this[typeLocation]!!.copy(modifier = TokenModifier.declaration) + } + } + + data class Metadata(val type: TokenType, val modifier: TokenModifier = TokenModifier.none) + + @Suppress("EnumEntryName") + enum class TokenType { + /** SAMT Operations */ + method, + + /** SAMT Constraints */ + function, + + /** SAMT Enum Member */ + enumMember, + + /** SAMT Record Field */ + property, + + /** SAMT Operation Parameter */ + parameter, + + /** SAMT Record */ + `class`, + + /** SAMT Enum */ + `enum`, + + /** SAMT Service */ + `interface`, + + /** SAMT Consumer & Provider */ + type, + + /** SAMT Package */ + namespace, + } + + @JvmInline + value class TokenModifier private constructor(val bitmask: Int) { + + infix fun and(other: TokenModifier) = TokenModifier(this.bitmask.or(other.bitmask)) + + companion object { + val none = TokenModifier(0) + val declaration = TokenModifier(1 shl 0) + val async = TokenModifier(1 shl 1) + val defaultLibrary = TokenModifier(1 shl 2) + fun values() = arrayOf(::declaration, ::async, ::defaultLibrary) + } + } + + companion object { + val legend = SemanticTokensLegend( + TokenType.values().map { it.name }, + TokenModifier.values().map { it.name }, + ) + + fun analyze(fileNode: FileNode, samtPackage: Package) = + SamtSemanticTokens().also { it.analyze(fileNode, samtPackage) } + } +} diff --git a/language-server/src/main/kotlin/tools/samt/ls/SamtTextDocumentService.kt b/language-server/src/main/kotlin/tools/samt/ls/SamtTextDocumentService.kt index ed11a423..8addda78 100644 --- a/language-server/src/main/kotlin/tools/samt/ls/SamtTextDocumentService.kt +++ b/language-server/src/main/kotlin/tools/samt/ls/SamtTextDocumentService.kt @@ -1,11 +1,18 @@ package tools.samt.ls import org.eclipse.lsp4j.* +import org.eclipse.lsp4j.jsonrpc.messages.Either import org.eclipse.lsp4j.services.LanguageClient import org.eclipse.lsp4j.services.LanguageClientAware import org.eclipse.lsp4j.services.TextDocumentService import tools.samt.common.SourceFile +import tools.samt.lexer.Token +import tools.samt.parser.FileNode +import tools.samt.parser.NamedDeclarationNode +import tools.samt.parser.OperationNode +import tools.samt.semantic.Package import java.net.URI +import java.util.concurrent.CompletableFuture import java.util.logging.Logger class SamtTextDocumentService(private val workspaces: Map) : TextDocumentService, @@ -23,7 +30,7 @@ class SamtTextDocumentService(private val workspaces: Map) : val path = params.textDocument.uri.toPathUri() val newText = params.contentChanges.single().text val fileInfo = parseFile(SourceFile(path, newText)) - val workspace = getWorkspace(path) + val workspace = getWorkspace(path) ?: return workspace.add(fileInfo) workspace.buildSemanticModel() @@ -46,10 +53,111 @@ class SamtTextDocumentService(private val workspaces: Map) : logger.info("Saved document ${params.textDocument.uri}") } + override fun definition(params: DefinitionParams): CompletableFuture, List>> = + CompletableFuture.supplyAsync { + val path = params.textDocument.uri.toPathUri() + val workspace = getWorkspace(path) + + val fileInfo = workspace?.get(path) ?: return@supplyAsync Either.forRight(emptyList()) + + val fileNode: FileNode = fileInfo.fileNode ?: return@supplyAsync Either.forRight(emptyList()) + val globalPackage: Package = workspace.samtPackage ?: return@supplyAsync Either.forRight(emptyList()) + + val token = fileInfo.tokens.findAt(params.position) ?: return@supplyAsync Either.forRight(emptyList()) + + val samtPackage = globalPackage.resolveSubPackage(fileNode.packageDeclaration.name) + + val typeLookup = SamtDeclarationLookup.analyze(fileNode, samtPackage) + val type = typeLookup[token.location] ?: return@supplyAsync Either.forRight(emptyList()) + + val definition = type.declaration + val location = definition.location + + val targetLocation = when (definition) { + is NamedDeclarationNode -> definition.name.location + is OperationNode -> definition.name.location + else -> error("Unexpected definition type") + } + val locationLink = LocationLink().apply { + targetUri = location.source.path.toString() + targetRange = location.toRange() + targetSelectionRange = targetLocation.toRange() + } + return@supplyAsync Either.forRight(listOf(locationLink)) + } + + override fun references(params: ReferenceParams): CompletableFuture> = + CompletableFuture.supplyAsync { + val path = params.textDocument.uri.toPathUri() + val workspace = getWorkspace(path) ?: return@supplyAsync emptyList() + + val relevantFileInfo = workspace[path] ?: return@supplyAsync emptyList() + val relevantFileNode = relevantFileInfo.fileNode ?: return@supplyAsync emptyList() + val token = relevantFileInfo.tokens.findAt(params.position) ?: return@supplyAsync emptyList() + + val globalPackage: Package = workspace.samtPackage ?: return@supplyAsync emptyList() + + val typeLookup = SamtDeclarationLookup.analyze(relevantFileNode, globalPackage.resolveSubPackage(relevantFileInfo.fileNode.packageDeclaration.name)) + val type = typeLookup[token.location] ?: return@supplyAsync emptyList() + + val filesAndPackages = buildList { + for (fileInfo in workspace) { + val fileNode: FileNode = fileInfo.fileNode ?: continue + val samtPackage = globalPackage.resolveSubPackage(fileNode.packageDeclaration.name) + add(fileNode to samtPackage) + } + } + + val typeReferencesLookup = SamtReferencesLookup.analyze(filesAndPackages) + + val references = typeReferencesLookup[type] ?: emptyList() + + return@supplyAsync references.map { Location(it.source.path.toString(), it.toRange()) } + } + + override fun semanticTokensFull(params: SemanticTokensParams): CompletableFuture = + CompletableFuture.supplyAsync { + val path = params.textDocument.uri.toPathUri() + val workspace = getWorkspace(path) + + val fileInfo = workspace?.get(path) ?: return@supplyAsync SemanticTokens(emptyList()) + + val tokens: List = fileInfo.tokens + val fileNode: FileNode = fileInfo.fileNode ?: return@supplyAsync SemanticTokens(emptyList()) + val globalPackage: Package = workspace.samtPackage ?: return@supplyAsync SemanticTokens(emptyList()) + val samtPackage = globalPackage.resolveSubPackage(fileNode.packageDeclaration.name) + + val semanticTokens = SamtSemanticTokens.analyze(fileNode, samtPackage) + + var lastLine = 0 + var lastStartChar = 0 + + val encodedData = buildList { + for (token in tokens) { + val (tokenType, modifier) = semanticTokens[token.location] ?: continue + val (_, start, end) = token.location + val line = start.row + val deltaLine = line - lastLine + val startChar = start.col + val deltaStartChar = if (deltaLine == 0) startChar - lastStartChar else startChar + val length = end.charIndex - start.charIndex + add(deltaLine) + add(deltaStartChar) + add(length) + add(tokenType.ordinal) + add(modifier.bitmask) + lastLine = line + lastStartChar = startChar + } + } + + SemanticTokens(encodedData) + } + override fun connect(client: LanguageClient) { this.client = client } - private fun getWorkspace(filePath: URI): SamtWorkspace = - workspaces.values.first { filePath in it } + private fun getWorkspace(filePath: URI): SamtWorkspace? = + workspaces.values.singleOrNull { filePath in it } } diff --git a/language-server/src/main/kotlin/tools/samt/ls/SamtWorkspace.kt b/language-server/src/main/kotlin/tools/samt/ls/SamtWorkspace.kt index 8c03d857..7a181f80 100644 --- a/language-server/src/main/kotlin/tools/samt/ls/SamtWorkspace.kt +++ b/language-server/src/main/kotlin/tools/samt/ls/SamtWorkspace.kt @@ -8,7 +8,8 @@ import java.net.URI class SamtWorkspace(private val parserController: DiagnosticController) : Iterable { private val files = mutableMapOf() - private var samtPackage: Package? = null + var samtPackage: Package? = null + private set private var semanticController: DiagnosticController = DiagnosticController(parserController.workingDirectory) diff --git a/language-server/src/main/kotlin/tools/samt/ls/Tokens.kt b/language-server/src/main/kotlin/tools/samt/ls/Tokens.kt new file mode 100644 index 00000000..7e14d1b0 --- /dev/null +++ b/language-server/src/main/kotlin/tools/samt/ls/Tokens.kt @@ -0,0 +1,34 @@ +package tools.samt.ls + +import org.eclipse.lsp4j.Position +import tools.samt.lexer.StructureToken +import tools.samt.lexer.Token + +/** + * Finds a non-structure token at the given position. + * A token is considered to be at the given position if the position is within the token's location. + * + * For example, given the following source code: + * + * ```samt + *package foo.bar.baz + *``` + * + * Any position within the token `package` will return that token, including bordering positions (before the p or after the e). + * + * @return the token at the given position, or null if there is no token at the given position + */ +fun List.findAt(position: Position): Token? { + val relevantTokens = this.filter { it !is StructureToken } + val tokenIndex = relevantTokens.binarySearch { + when { + it.location.end.row < position.line -> -1 + it.location.start.row > position.line -> 1 + it.location.end.col < position.character -> -1 + it.location.start.col > position.character -> 1 + else -> 0 + } + } + if (tokenIndex < 0) return null + return relevantTokens[tokenIndex] +} diff --git a/language-server/src/test/kotlin/tools/samt/ls/SamtDeclarationLookupTest.kt b/language-server/src/test/kotlin/tools/samt/ls/SamtDeclarationLookupTest.kt new file mode 100644 index 00000000..11a45793 --- /dev/null +++ b/language-server/src/test/kotlin/tools/samt/ls/SamtDeclarationLookupTest.kt @@ -0,0 +1,140 @@ +package tools.samt.ls + +import tools.samt.common.DiagnosticController +import tools.samt.common.SourceFile +import tools.samt.lexer.Lexer +import tools.samt.parser.* +import tools.samt.semantic.SemanticModelBuilder +import java.net.URI +import kotlin.test.Test +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class SamtDeclarationLookupTest { + @Test + fun `correctly find definition in complex model`() { + val serviceSource = """ + package test + + record Person { + name: List + age: Int + } + + service PersonService { + async getNeighbors(person: Person): Map (*..100) + } + """.trimIndent() + val providerSource = """ + package test + + provide PersonEndpoint { + implements PersonService { getNeighbors } + + transport HTTP + } + """.trimIndent() + val consumerOneSource = """ + import test.* + import test.PersonService as Service + + package some.other.^package + + consume PersonEndpoint { + uses Service { getNeighbors } + } + """.trimIndent() + val consumerTwoSource = """ + package somewhere.else + + consume test.PersonEndpoint { uses test.PersonService } + """.trimIndent() + parseAndCheck( + serviceSource to listOf( + ExpectedDefinition("8:31" to "8:37") { it is RecordDeclarationNode && it.name.name == "Person" }, + ExpectedDefinition("8:52" to "8:58") { it is RecordDeclarationNode && it.name.name == "Person" }, + ), + providerSource to listOf( + ExpectedDefinition("3:15" to "3:28") { it is ServiceDeclarationNode && it.name.name == "PersonService" }, + ExpectedDefinition("3:31" to "3:43") { it is OperationNode && it.name.name == "getNeighbors" }, + ), + consumerOneSource to listOf( + ExpectedDefinition("1:12" to "1:25") { it is ServiceDeclarationNode && it.name.name == "PersonService" }, + ExpectedDefinition("5:8" to "5:22") { it is ProviderDeclarationNode && it.name.name == "PersonEndpoint" }, + ExpectedDefinition("6:9" to "6:16") { it is ServiceDeclarationNode && it.name.name == "PersonService" }, + ExpectedDefinition("6:19" to "6:31") { it is OperationNode && it.name.name == "getNeighbors" }, + ), + consumerTwoSource to listOf( + ExpectedDefinition("2:13" to "2:27") { it is ProviderDeclarationNode && it.name.name == "PersonEndpoint" }, + ExpectedDefinition("2:40" to "2:53") { it is ServiceDeclarationNode && it.name.name == "PersonService" }, + ), + ) + } + + @Test + fun `finds definition for name of the user defined types themselves`() { + val serviceSource = """ + package test + + record Person { + name: List + age: Int + } + + service PersonService { + foo() + } + """.trimIndent() + val providerSource = """ + package test + + provide PersonEndpoint { + implements PersonService + + transport HTTP + } + """.trimIndent() + parseAndCheck( + serviceSource to listOf( + ExpectedDefinition("2:7" to "2:13") { it is RecordDeclarationNode && it.name.name == "Person" }, + ExpectedDefinition("7:8" to "7:21") { it is ServiceDeclarationNode && it.name.name == "PersonService" }, + ExpectedDefinition("8:4" to "8:7") { it is OperationNode && it.name.name == "foo" }, + ), + providerSource to listOf( + ExpectedDefinition("2:8" to "2:22") { it is ProviderDeclarationNode && it.name.name == "PersonEndpoint" }, + ), + ) + } + + private data class ExpectedDefinition(val range: Pair, val matcher: (definition: Node) -> Boolean) { + val testLocation = TestLocation(range) + } + + private fun parseAndCheck( + vararg sourceAndExpectedMessages: Pair>, + ) { + val diagnosticController = DiagnosticController(URI("file:///tmp")) + val fileTree = sourceAndExpectedMessages.mapIndexed { index, (source) -> + val filePath = URI("file:///tmp/SamtDeclarationLookupTest-${index}.samt") + val sourceFile = SourceFile(filePath, source) + val parseContext = diagnosticController.getOrCreateContext(sourceFile) + val stream = Lexer.scan(source.reader(), parseContext) + val fileTree = Parser.parse(sourceFile, stream, parseContext) + assertFalse(parseContext.hasErrors(), "Expected no parse errors, but had errors: ${parseContext.messages}}") + fileTree + } + + val samtPackage = SemanticModelBuilder.build(fileTree, diagnosticController) + + for ((fileNode, expectedMetadata) in fileTree.zip(sourceAndExpectedMessages.map { it.second })) { + val filePackage = samtPackage.resolveSubPackage(fileNode.packageDeclaration.name) + val definitionLookup = SamtDeclarationLookup.analyze(fileNode, filePackage) + for (expected in expectedMetadata) { + val actual = definitionLookup[expected.testLocation.getLocation(fileNode.sourceFile)] + assertNotNull(actual, "No definition found for ${expected.range}") + assertTrue(expected.matcher(actual.declaration), "Matcher for ${expected.range} did not match") + } + } + } +} diff --git a/language-server/src/test/kotlin/tools/samt/ls/SamtReferencesLookupTest.kt b/language-server/src/test/kotlin/tools/samt/ls/SamtReferencesLookupTest.kt new file mode 100644 index 00000000..7f7acea4 --- /dev/null +++ b/language-server/src/test/kotlin/tools/samt/ls/SamtReferencesLookupTest.kt @@ -0,0 +1,98 @@ +package tools.samt.ls + +import tools.samt.common.DiagnosticController +import tools.samt.common.SourceFile +import tools.samt.lexer.Lexer +import tools.samt.parser.Parser +import tools.samt.semantic.Package +import tools.samt.semantic.SemanticModelBuilder +import java.net.URI +import kotlin.test.* + +class SamtReferencesLookupTest { + @Test + fun `correctly find references in complex model`() { + val serviceSource = """ + package test + + record Person { } + + service PersonService { + getNeighbors(person: Person): Map? (*..100) + } + """.trimIndent().let { SourceFile(URI("file:///tmp/SamtSemanticTokensTest-serviceSource.samt"), it) } + val providerSource = """ + package test + + provide PersonEndpoint { + implements PersonService { getNeighbors } + + transport HTTP + } + """.trimIndent().let { SourceFile(URI("file:///tmp/SamtSemanticTokensTest-providerSource.samt"), it) } + val consumerOneSource = """ + import test.* + import test.PersonService as Service + + package some.other.^package + + consume PersonEndpoint { + uses Service { getNeighbors } + } + """.trimIndent().let { SourceFile(URI("file:///tmp/SamtSemanticTokensTest-consumerOneSource.samt"), it) } + val consumerTwoSource = """ + package somewhere.else + + consume test.PersonEndpoint { uses test.PersonService } + """.trimIndent().let { SourceFile(URI("file:///tmp/SamtSemanticTokensTest-consumerTwoSource.samt"), it) } + val (samtPackage, referencesLookup) = parse(serviceSource, providerSource, consumerOneSource, consumerTwoSource) + + val testPackage = samtPackage.subPackages.single { it.name == "test" } + val person = testPackage.records.single { it.name == "Person" } + val personService = testPackage.services.single { it.name == "PersonService" } + val personEndpoint = testPackage.providers.single { it.name == "PersonEndpoint" } + val getNeighbors = personService.operations.single { it.name == "getNeighbors" } + + val personReferences = referencesLookup[person] + assertNotNull(personReferences) + assertEquals(2, personReferences.size, "Following list had unexpected amount of entries: $personReferences") + assertContains(personReferences, TestLocation("5:25" to "5:31").getLocation(serviceSource)) + assertContains(personReferences, TestLocation("5:46" to "5:52").getLocation(serviceSource)) + + val personServiceReferences = referencesLookup[personService] + assertNotNull(personServiceReferences) + assertEquals(4, personServiceReferences.size, "Following list had unexpected amount of entries: $personServiceReferences") + assertContains(personServiceReferences, TestLocation("3:15" to "3:28").getLocation(providerSource)) + assertContains(personServiceReferences, TestLocation("1:12" to "1:25").getLocation(consumerOneSource)) + assertContains(personServiceReferences, TestLocation("6:9" to "6:16").getLocation(consumerOneSource)) + assertContains(personServiceReferences, TestLocation("2:40" to "2:53").getLocation(consumerTwoSource)) + + val personEndpointReferences = referencesLookup[personEndpoint] + assertNotNull(personEndpointReferences) + assertEquals(2, personEndpointReferences.size, "Following list had unexpected amount of entries: $personEndpointReferences") + assertContains(personEndpointReferences, TestLocation("5:8" to "5:22").getLocation(consumerOneSource)) + assertContains(personEndpointReferences, TestLocation("2:13" to "2:27").getLocation(consumerTwoSource)) + + val getNeighborsReferences = referencesLookup[getNeighbors] + assertNotNull(getNeighborsReferences) + assertEquals(2, getNeighborsReferences.size, "Following list had unexpected amount of entries: $getNeighborsReferences") + assertContains(getNeighborsReferences, TestLocation("3:31" to "3:43").getLocation(providerSource)) + assertContains(getNeighborsReferences, TestLocation("6:19" to "6:31").getLocation(consumerOneSource)) + } + + private fun parse(vararg sourceAndExpectedMessages: SourceFile): Pair { + val diagnosticController = DiagnosticController(URI("file:///tmp")) + val fileTree = sourceAndExpectedMessages.map { sourceFile -> + val parseContext = diagnosticController.getOrCreateContext(sourceFile) + val stream = Lexer.scan(sourceFile.content.reader(), parseContext) + val fileTree = Parser.parse(sourceFile, stream, parseContext) + assertFalse(parseContext.hasErrors(), "Expected no parse errors, but had errors: ${parseContext.messages}}") + fileTree + } + + val samtPackage = SemanticModelBuilder.build(fileTree, diagnosticController) + + val filesAndPackages = fileTree.map { it to samtPackage.resolveSubPackage(it.packageDeclaration.name) } + return Pair(samtPackage, SamtReferencesLookup.analyze(filesAndPackages)) + } +} diff --git a/language-server/src/test/kotlin/tools/samt/ls/SamtSemanticTokensTest.kt b/language-server/src/test/kotlin/tools/samt/ls/SamtSemanticTokensTest.kt new file mode 100644 index 00000000..bbbce1d4 --- /dev/null +++ b/language-server/src/test/kotlin/tools/samt/ls/SamtSemanticTokensTest.kt @@ -0,0 +1,174 @@ +package tools.samt.ls + +import tools.samt.common.DiagnosticController +import tools.samt.common.SourceFile +import tools.samt.lexer.Lexer +import tools.samt.parser.Parser +import tools.samt.semantic.SemanticModelBuilder +import java.net.URI +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import tools.samt.ls.SamtSemanticTokens.Metadata as Meta +import tools.samt.ls.SamtSemanticTokens.TokenModifier as Mod +import tools.samt.ls.SamtSemanticTokens.TokenType as T + +class SamtSemanticTokensTest { + @Test + fun `correctly tokenizes complex model`() { + val serviceSource = """ + package test + + enum Age { Underage, Legal, Senior } + + record Person { + name: List + age: Age + } + + service PersonService { + async getNeighbors(person: Person): Map (*..100) + oneway reloadNeighbors() + } + """.trimIndent() + val providerSource = """ + package test + + provide PersonEndpoint { + implements PersonService { getNeighbors, reloadNeighbors } + + transport HTTP + } + """.trimIndent() + val consumerOneSource = """ + import test.* + import test.PersonService as Service + + package some.other.^package + + consume PersonEndpoint { + uses Service { getNeighbors } + } + """.trimIndent() + val consumerTwoSource = """ + package somewhere.else + + consume test.PersonEndpoint { uses test.PersonService } + """.trimIndent() + parseAndCheck( + serviceSource to listOf( + ExpectedMetadata("0:8" to "0:12", Meta(T.namespace)), + ExpectedMetadata("2:5" to "2:8", Meta(T.enum, Mod.declaration)), + ExpectedMetadata("2:11" to "2:19", Meta(T.enumMember, Mod.declaration)), + ExpectedMetadata("2:21" to "2:26", Meta(T.enumMember, Mod.declaration)), + ExpectedMetadata("2:28" to "2:34", Meta(T.enumMember, Mod.declaration)), + ExpectedMetadata("4:7" to "4:13", Meta(T.`class`, Mod.declaration)), + ExpectedMetadata("5:4" to "5:8", Meta(T.property, Mod.declaration)), + ExpectedMetadata("5:10" to "5:14", Meta(T.type, Mod.defaultLibrary)), + ExpectedMetadata("5:15" to "5:21", Meta(T.type, Mod.defaultLibrary)), + ExpectedMetadata("5:24" to "5:28", Meta(T.function, Mod.defaultLibrary)), + ExpectedMetadata("6:4" to "6:7", Meta(T.property, Mod.declaration)), + ExpectedMetadata("6:9" to "6:12", Meta(T.enum)), + ExpectedMetadata("9:8" to "9:21", Meta(T.`interface`, Mod.declaration)), + ExpectedMetadata("10:10" to "10:22", Meta(T.method, Mod.declaration and Mod.async)), + ExpectedMetadata("10:23" to "10:29", Meta(T.parameter, Mod.declaration)), + ExpectedMetadata("10:31" to "10:37", Meta(T.`class`)), + ExpectedMetadata("10:40" to "10:43", Meta(T.type, Mod.defaultLibrary)), + ExpectedMetadata("10:44" to "10:50", Meta(T.type, Mod.defaultLibrary)), + ExpectedMetadata("10:52" to "10:58", Meta(T.`class`)), + ExpectedMetadata("11:11" to "11:26", Meta(T.method, Mod.declaration)), + ), + providerSource to listOf( + ExpectedMetadata("0:8" to "0:12", Meta(T.namespace)), + ExpectedMetadata("2:8" to "2:22", Meta(T.type, Mod.declaration)), + ExpectedMetadata("3:15" to "3:28", Meta(T.`interface`)), + ExpectedMetadata("3:31" to "3:43", Meta(T.method, Mod.async)), + ExpectedMetadata("3:45" to "3:60", Meta(T.method)), + ), + consumerOneSource to listOf( + ExpectedMetadata("0:7" to "0:11", Meta(T.namespace)), + ExpectedMetadata("1:12" to "1:25", Meta(T.`interface`)), + ExpectedMetadata("1:29" to "1:36", Meta(T.`interface`, Mod.declaration)), + ExpectedMetadata("3:19" to "3:27", Meta(T.namespace)), + ExpectedMetadata("5:8" to "5:22", Meta(T.type)), + ExpectedMetadata("6:9" to "6:16", Meta(T.`interface`)), + ExpectedMetadata("6:19" to "6:31", Meta(T.method, Mod.async)), + ), + consumerTwoSource to listOf( + ExpectedMetadata("0:18" to "0:22", Meta(T.namespace)), + ExpectedMetadata("2:13" to "2:27", Meta(T.type)), + ExpectedMetadata("2:40" to "2:53", Meta(T.`interface`)), + ), + ) + } + + @Test + fun `correctly tokenizes somewhat broken models`() { + val serviceSource = """ + package broken + + record Person { + name: Name + } + + @Description("A enum for people") + enum Person { + foo, + bar + } + + service Person { + async foo() + oneway foo() + foo(): Name + } + """.trimIndent() + parseAndCheck( + serviceSource to listOf( + ExpectedMetadata("0:8" to "0:14", Meta(T.namespace)), + ExpectedMetadata("2:7" to "2:13", Meta(T.`class`, Mod.declaration)), + ExpectedMetadata("3:4" to "3:8", Meta(T.property, Mod.declaration)), + ExpectedMetadata("3:10" to "3:14", Meta(T.type)), + ExpectedMetadata("6:1" to "6:12", Meta(T.type, Mod.defaultLibrary)), + ExpectedMetadata("7:5" to "7:11", Meta(T.enum, Mod.declaration)), + ExpectedMetadata("8:4" to "8:7", Meta(T.enumMember, Mod.declaration)), + ExpectedMetadata("9:4" to "9:7", Meta(T.enumMember, Mod.declaration)), + ExpectedMetadata("12:8" to "12:14", Meta(T.`interface`, Mod.declaration)), + ExpectedMetadata("13:10" to "13:13", Meta(T.method, Mod.declaration and Mod.async)), + ExpectedMetadata("14:11" to "14:14", Meta(T.method, Mod.declaration)), + ExpectedMetadata("15:4" to "15:7", Meta(T.method, Mod.declaration)), + ExpectedMetadata("15:11" to "15:15", Meta(T.type)), + ), + ) + } + + private data class ExpectedMetadata(val range: Pair, val metadata: Meta) { + val testLocation = TestLocation(range) + } + + private fun parseAndCheck( + vararg sourceAndExpectedMessages: Pair>, + ) { + val diagnosticController = DiagnosticController(URI("file:///tmp")) + val fileTree = sourceAndExpectedMessages.mapIndexed { index, (source) -> + val filePath = URI("file:///tmp/SamtSemanticTokensTest-${index}.samt") + val sourceFile = SourceFile(filePath, source) + val parseContext = diagnosticController.getOrCreateContext(sourceFile) + val stream = Lexer.scan(source.reader(), parseContext) + val fileTree = Parser.parse(sourceFile, stream, parseContext) + assertFalse(parseContext.hasErrors(), "Expected no parse errors, but had errors: ${parseContext.messages}}") + fileTree + } + + val samtPackage = SemanticModelBuilder.build(fileTree, diagnosticController) + + for ((fileNode, expectedMetadata) in fileTree.zip(sourceAndExpectedMessages.map { it.second })) { + val filePackage = samtPackage.resolveSubPackage(fileNode.packageDeclaration.name) + val semanticTokens = SamtSemanticTokens.analyze(fileNode, filePackage) + for (expected in expectedMetadata) { + val actual = semanticTokens[expected.testLocation.getLocation(fileNode.sourceFile)] + assertEquals(expected.metadata, actual, "Metadata for ${expected.range} did not match") + } + } + } +} diff --git a/language-server/src/test/kotlin/tools/samt/ls/TestLocation.kt b/language-server/src/test/kotlin/tools/samt/ls/TestLocation.kt new file mode 100644 index 00000000..dac997b4 --- /dev/null +++ b/language-server/src/test/kotlin/tools/samt/ls/TestLocation.kt @@ -0,0 +1,44 @@ +package tools.samt.ls + +import tools.samt.common.FileOffset +import tools.samt.common.Location +import tools.samt.common.SourceFile + +data class TestLocation(val range: Pair) { + private val startRow = range.first.substringBefore(":").toInt() + private val startCol = range.first.substringAfter(":").toInt() + private val endRow = range.second.substringBefore(":").toInt() + private val endCol = range.second.substringAfter(":").toInt() + private fun countUntil(source: String, row: Int, col: Int): Int { + var currentRow = 0 + var currentCol = 0 + var currentIndex = 0 + for (c in source) { + if (currentRow == row && currentCol == col) { + return currentIndex + } + currentIndex++ + if (c == '\n') { + currentRow++ + currentCol = 0 + } else { + currentCol++ + } + } + return -1 + } + + fun getLocation(file: SourceFile) = Location( + source = file, + start = FileOffset( + charIndex = countUntil(file.content, startRow, startCol), + row = startRow, + col = startCol, + ), + end = FileOffset( + charIndex = countUntil(file.content, endRow, endCol), + row = endRow, + col = endCol, + ), + ) +} diff --git a/language-server/src/test/kotlin/tools/samt/ls/TokensTest.kt b/language-server/src/test/kotlin/tools/samt/ls/TokensTest.kt new file mode 100644 index 00000000..c3bc5537 --- /dev/null +++ b/language-server/src/test/kotlin/tools/samt/ls/TokensTest.kt @@ -0,0 +1,36 @@ +package tools.samt.ls + +import org.eclipse.lsp4j.Position +import tools.samt.common.DiagnosticContext +import tools.samt.common.SourceFile +import tools.samt.lexer.Lexer +import java.net.URI +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull + +class TokensTest { + @Test + fun `finds token at position`() { + val source = """ + package foo.bar.baz + + record Foo { + bar: Int + } + """.trimIndent() + val tokens = Lexer.scan(source.reader(), DiagnosticContext(SourceFile(URI("file:///tmp/test.samt"), source))).toList() + + assertEquals(tokens[0], tokens.findAt(Position(0, 0))) + assertEquals(tokens[0], tokens.findAt(Position(0, 3))) + assertEquals(tokens[0], tokens.findAt(Position(0, 7))) + assertEquals(tokens[1], tokens.findAt(Position(0, 8))) + assertEquals(tokens[3], tokens.findAt(Position(0, 12))) + assertEquals(tokens[3], tokens.findAt(Position(0, 15))) + assertEquals(tokens[5], tokens.findAt(Position(0, 19))) + assertEquals(tokens[11], tokens.findAt(Position(3, 10))) + + assertNull(tokens.findAt(Position(1, 0))) + assertNull(tokens.findAt(Position(2, 11))) + } +} diff --git a/semantic/src/main/kotlin/tools/samt/semantic/Package.kt b/semantic/src/main/kotlin/tools/samt/semantic/Package.kt index 8a010f43..fa8e26e5 100644 --- a/semantic/src/main/kotlin/tools/samt/semantic/Package.kt +++ b/semantic/src/main/kotlin/tools/samt/semantic/Package.kt @@ -1,7 +1,6 @@ package tools.samt.semantic -import tools.samt.parser.IdentifierNode -import tools.samt.parser.NamedDeclarationNode +import tools.samt.parser.* class Package(val name: String) { val subPackages: MutableList = mutableListOf() @@ -14,44 +13,69 @@ class Package(val name: String) { val providers: MutableList = mutableListOf() val consumers: MutableList = mutableListOf() + val typeByNode: MutableMap = mutableMapOf() + val types: MutableMap = mutableMapOf() + inline fun getTypeOrNullByNode(node: Node): T? { + val type = typeByNode[node] + check(type is T?) { "Expected type ${T::class.simpleName} for ${node.javaClass.simpleName} at ${node.location} but got ${type!!.javaClass.simpleName}" } + return type + } + + inline fun getTypeByNode(node: Node): T { + val type = getTypeOrNullByNode(node) + checkNotNull(type) { "No type found for node of type ${node.javaClass.simpleName} at ${node.location}" } + return type + } + + fun resolveSubPackage(name: BundleIdentifierNode): Package { + var samtPackage = this + for (namespace in name.components) { + samtPackage = samtPackage.subPackages.first { it.name == namespace.name } + } + return samtPackage + } + fun resolveType(identifier: IdentifierNode): Type? = subPackages.find { it.name == identifier.name }?.let { PackageType(it) } ?: types[identifier.name] + fun linkType(source: Node, type: Type) { + typeByNode[source] = type + } + operator fun plusAssign(record: RecordType) { records.add(record) types[record.name] = record + typeByNode[record.declaration] = record } operator fun plusAssign(enum: EnumType) { enums.add(enum) types[enum.name] = enum + typeByNode[enum.declaration] = enum } operator fun plusAssign(service: ServiceType) { services.add(service) types[service.name] = service + typeByNode[service.declaration] = service } operator fun plusAssign(provider: ProviderType) { providers.add(provider) types[provider.name] = provider + typeByNode[provider.declaration] = provider } operator fun plusAssign(consumer: ConsumerType) { consumers.add(consumer) + typeByNode[consumer.declaration] = consumer } - operator fun contains(declaration: NamedDeclarationNode): Boolean = - types.containsKey(declaration.name.name) - operator fun contains(identifier: IdentifierNode): Boolean = types.containsKey(identifier.name) - operator fun contains(name: String): Boolean = - types.containsKey(name) - val allSubPackages: List get() = subPackages + subPackages.flatMap { it.allSubPackages } } diff --git a/semantic/src/main/kotlin/tools/samt/semantic/SemanticModel.kt b/semantic/src/main/kotlin/tools/samt/semantic/SemanticModel.kt index 8258ef55..ba08c2bb 100644 --- a/semantic/src/main/kotlin/tools/samt/semantic/SemanticModel.kt +++ b/semantic/src/main/kotlin/tools/samt/semantic/SemanticModel.kt @@ -1,7 +1,6 @@ package tools.samt.semantic import tools.samt.common.DiagnosticController -import tools.samt.common.Location import tools.samt.common.SourceFile import tools.samt.parser.* @@ -17,49 +16,12 @@ class SemanticModelBuilder private constructor( private val controller: DiagnosticController, ) { private val global = Package(name = "") - private val constraintBuilder = ConstraintBuilder(controller) + private val preProcessor = SemanticModelPreProcessor(controller) private val postProcessor = SemanticModelPostProcessor(controller) - - private inline fun ensureNameIsAvailable( - parentPackage: Package, - statement: NamedDeclarationNode, - block: () -> Unit, - ) { - if (statement.name !in parentPackage) { - block() - } else { - val existingType = parentPackage.types.getValue(statement.name.name) - controller.getOrCreateContext(statement.location.source).error { - message("'${statement.name.name}' is already declared") - highlight("duplicate declaration", statement.name.location) - if (existingType is UserDefinedType) { - highlight("previous declaration", existingType.definition.location) - } - } - } - } - - private inline fun reportDuplicates( - items: List, - what: String, - identifierGetter: (node: T) -> IdentifierNode, - ) { - val existingItems = mutableMapOf() - for (item in items) { - val name = identifierGetter(item).name - val existingLocation = existingItems.putIfAbsent(name, item.location) - if (existingLocation != null) { - controller.getOrCreateContext(item.location.source).error { - message("$what '$name' is defined more than once") - highlight("duplicate declaration", identifierGetter(item).location) - highlight("previous declaration", existingLocation) - } - } - } - } + private val referenceResolver = SemanticModelReferenceResolver(controller, global) private fun build(): Package { - buildPackages() + preProcessor.fillPackage(global, files) val fileScopeBySource = files.associate { it.sourceFile to createFileScope(it) } @@ -70,134 +32,12 @@ class SemanticModelBuilder private constructor( return global } - private fun buildPackages() { - for (file in files) { - var parentPackage = global - for (component in file.packageDeclaration.name.components) { - var subPackage = parentPackage.subPackages.find { it.name == component.name } - if (subPackage == null) { - subPackage = Package(component.name) - parentPackage.subPackages.add(subPackage) - } - parentPackage = subPackage - } - - for (statement in file.statements) { - when (statement) { - is RecordDeclarationNode -> { - ensureNameIsAvailable(parentPackage, statement) { - reportDuplicates(statement.fields, "Record field") { it.name } - if (statement.extends.isNotEmpty()) { - controller.getOrCreateContext(statement.location.source).error { - message("Record extends are not yet supported") - highlight("cannot extend other records", statement.extends.first().location) - } - } - val fields = statement.fields.map { field -> - RecordType.Field(field.name.name, UnresolvedTypeReference(field.type)) - } - parentPackage += RecordType(statement.name.name, fields, statement) - } - } - - is EnumDeclarationNode -> { - ensureNameIsAvailable(parentPackage, statement) { - reportDuplicates(statement.values, "Enum value") { it } - val values = statement.values.map { it.name } - parentPackage += EnumType(statement.name.name, values, statement) - } - } - - is ServiceDeclarationNode -> { - ensureNameIsAvailable(parentPackage, statement) { - reportDuplicates(statement.operations, "Operation") { it.name } - val operations = statement.operations.map { operation -> - reportDuplicates(operation.parameters, "Parameter") { it.name } - val parameters = operation.parameters.map { parameter -> - ServiceType.Operation.Parameter( - name = parameter.name.name, - type = UnresolvedTypeReference(parameter.type) - ) - } - when (operation) { - is OnewayOperationNode -> { - ServiceType.OnewayOperation( - name = operation.name.name, - parameters = parameters, - ) - } - - is RequestResponseOperationNode -> { - if (operation.isAsync) { - controller.getOrCreateContext(operation.location.source).error { - message("Async operations are not yet supported") - highlight("unsupported async operation", operation.location) - } - } - ServiceType.RequestResponseOperation( - name = operation.name.name, - parameters = parameters, - returnType = operation.returnType?.let { UnresolvedTypeReference(it) }, - raisesTypes = operation.raises.map { UnresolvedTypeReference(it) }, - ) - } - } - } - parentPackage += ServiceType(statement.name.name, operations, statement) - } - } - - is ProviderDeclarationNode -> { - ensureNameIsAvailable(parentPackage, statement) { - val implements = statement.implements.map { implements -> - ProviderType.Implements( - UnresolvedTypeReference(implements.serviceName), - emptyList(), - implements - ) - } - val transport = ProviderType.Transport( - name = statement.transport.protocolName.name, - configuration = statement.transport.configuration - ) - parentPackage += ProviderType(statement.name.name, implements, transport, statement) - } - } - - is ConsumerDeclarationNode -> { - parentPackage += ConsumerType( - provider = UnresolvedTypeReference(statement.providerName), - uses = statement.usages.map { - ConsumerType.Uses( - service = UnresolvedTypeReference(it.serviceName), - operations = emptyList(), - definition = it - ) - }, - definition = statement - ) - } - - is TypeAliasNode -> { - controller.getOrCreateContext(statement.location.source).error { - message("Type aliases are not yet supported") - highlight("unsupported feature", statement.location) - } - } - - is PackageDeclarationNode, - is ImportNode, - -> Unit - } - } - } - } - private fun resolveTypes(fileScopeBySource: Map) { fun TypeReference.resolve(): ResolvedTypeReference { check(this is UnresolvedTypeReference) { "Type reference must be unresolved" } - return resolveExpression(fileScopeBySource, expression) + val fileScope = fileScopeBySource.getValue(expression.location.source) + return referenceResolver.resolveAndLinkExpression(fileScope, expression) } for (subPackage in global.allSubPackages) { @@ -234,51 +74,14 @@ class SemanticModelBuilder private constructor( } } - private fun resolveType(bundleIdentifierNode: BundleIdentifierNode) = resolveType(bundleIdentifierNode.components) - private fun resolveType(components: List, start: Package = global): Type? { - var currentPackage = start - val iterator = components.listIterator() - while (iterator.hasNext()) { - val component = iterator.next() - when (val resolvedType = currentPackage.resolveType(component)) { - is PackageType -> { - currentPackage = resolvedType.sourcePackage - } - - null -> { - controller.getOrCreateContext(component.location.source).error { - message("Could not resolve reference '${component.name}'") - highlight("unresolved reference", component.location) - } - return null - } - - else -> { - if (iterator.hasNext()) { - // We resolved a non-package type but there are still components left - - controller.getOrCreateContext(component.location.source).error { - message("Type '${component.name}' is not a package, cannot access sub-types") - highlight("must be a package", component.location) - } - return null - } - return resolvedType - } - } - } - - return PackageType(currentPackage) - } - data class FileScope(val filePackage: PackageType, val typeLookup: Map) private fun createFileScope(file: FileNode): FileScope { - // Add all types from the file package - val filePackage = resolveType(file.packageDeclaration.name) + val filePackage = referenceResolver.resolveType(file.packageDeclaration.name) check(filePackage is PackageType) val typeLookup: Map = buildMap { + // Add all types from the file package putAll(filePackage.sourcePackage.types) // Add all imports to scope @@ -288,8 +91,8 @@ class SemanticModelBuilder private constructor( controller.getOrCreateContext(file.sourceFile).error { message("Import '$name' conflicts with locally defined type with same name") highlight("conflicting import", import.location) - if (existingType is UserDefinedType) { - highlight("local type with same name", existingType.definition.location) + if (existingType is UserDeclared) { + highlight("local type with same name", existingType.declaration.location) } } } @@ -297,8 +100,10 @@ class SemanticModelBuilder private constructor( when (import) { is TypeImportNode -> { // Just import one type - val type = resolveType(import.name) + val type = referenceResolver.resolveType(import.name) if (type != null) { + filePackage.sourcePackage.linkType(import, type) + val name = if (import.alias != null) { import.alias!!.name } else { @@ -311,8 +116,9 @@ class SemanticModelBuilder private constructor( is WildcardImportNode -> { // Import all types from the package - val type = resolveType(import.name) + val type = referenceResolver.resolveType(import.name) if (type != null) { + filePackage.sourcePackage.linkType(import, type) if (type is PackageType) { type.sourcePackage.types.forEach { (name, type) -> addImportedType(name, type) @@ -340,8 +146,8 @@ class SemanticModelBuilder private constructor( putIfAbsent(name, type)?.let { existingType -> controller.getOrCreateContext(file.sourceFile).error { message("Type '$name' shadows built-in type with same name") - if (existingType is UserDefinedType) { - val definition = existingType.definition + if (existingType is UserDeclared) { + val definition = existingType.declaration if (definition is NamedDeclarationNode) { highlight("Shadows built-in type", definition.name.location) } else { @@ -367,141 +173,6 @@ class SemanticModelBuilder private constructor( return FileScope(filePackage, typeLookup) } - private fun resolveExpression( - fileScopes: Map, - rootExpression: ExpressionNode, - ): ResolvedTypeReference { - fun resolveExpression(expression: ExpressionNode): ResolvedTypeReference { - val scope = fileScopes.getValue(expression.location.source) - when (expression) { - is IdentifierNode -> { - scope.typeLookup[expression.name]?.let { - return ResolvedTypeReference(expression, it) - } - - controller.getOrCreateContext(expression.location.source).error { - message("Type '${expression.name}' could not be resolved") - highlight("unresolved type", expression.location) - } - } - - is BundleIdentifierNode -> { - // Bundle identifiers with one component are treated like normal identifiers - if (expression.components.size == 1) { - return resolveExpression(expression.components.first()) - } - // Type is foo.bar.Baz - // Resolve foo first, it must be a package - when (val expectedPackageType = scope.typeLookup[expression.components.first().name]) { - is PackageType -> { - resolveType( - expression.components.subList(1, expression.components.size), - expectedPackageType.sourcePackage - )?.let { - return ResolvedTypeReference(expression, it) - } - } - - null -> { - controller.getOrCreateContext(expression.location.source).error { - message("Type '${expression.name}' could not be resolved") - highlight("unresolved type", expression.location) - } - } - - else -> { - controller.getOrCreateContext(expression.location.source).error { - message("Type '${expression.components.first().name}' is not a package, cannot access sub-types") - highlight("not a package", expression.components.first().location) - } - } - } - } - - is CallExpressionNode -> { - val baseType = resolveExpression(expression.base) - val constraints = expression.arguments.mapNotNull { constraintBuilder.build(baseType.type, it) } - if (baseType.constraints.isNotEmpty()) { - controller.getOrCreateContext(expression.location.source).error { - message("Cannot have nested constraints") - highlight("illegal nested constraint", expression.location) - } - } - return baseType.copy(constraints = constraints) - } - - is GenericSpecializationNode -> { - val name = expression.base.let { - when (it) { - is IdentifierNode -> it.name - is BundleIdentifierNode -> it.name - else -> null - } - } - when (name) { - "List" -> { - if (expression.arguments.size == 1) { - return ResolvedTypeReference( - expression, - ListType(resolveExpression(expression.arguments[0])) - ) - } - } - - "Map" -> { - if (expression.arguments.size == 2) { - return ResolvedTypeReference( - expression, - MapType( - keyType = resolveExpression(expression.arguments[0]), - valueType = resolveExpression(expression.arguments[1]) - ) - ) - } - } - } - controller.getOrCreateContext(expression.location.source).error { - message("Unsupported generic type") - highlight(expression.location) - help("Valid generic types are List and Map") - } - } - - is OptionalDeclarationNode -> { - val baseType = resolveExpression(expression.base) - if (baseType.isOptional) { - controller.getOrCreateContext(expression.location.source).warn { - message("Type is already optional, ignoring '?'") - highlight("already optional", expression.base.location) - } - } - return baseType.copy(isOptional = true) - } - - is BooleanNode, - is NumberNode, - is StringNode, - -> controller.getOrCreateContext(expression.location.source).error { - message("Cannot use literal value as type") - highlight("not a type expression", expression.location) - } - - is ObjectNode, - is ArrayNode, - is RangeExpressionNode, - is WildcardNode, - -> controller.getOrCreateContext(expression.location.source).error { - message("Invalid type expression") - highlight("not a type expression", expression.location) - } - } - - return ResolvedTypeReference(expression, UnknownType) - } - - return resolveExpression(rootExpression) - } - companion object { fun build(files: List, controller: DiagnosticController): Package { // Sort by path to ensure deterministic order diff --git a/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelPostProcessor.kt b/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelPostProcessor.kt index 14e056b3..7243c2f7 100644 --- a/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelPostProcessor.kt +++ b/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelPostProcessor.kt @@ -27,24 +27,24 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont check(typeReference is ResolvedTypeReference) when (val type = typeReference.type) { is ServiceType -> { - controller.getOrCreateContext(typeReference.definition.location.source).error { + controller.getOrCreateContext(typeReference.typeNode.location.source).error { // error message applies to both record fields and return types message("Cannot use service '${type.name}' as type") - highlight("service type not allowed here", typeReference.definition.location) + highlight("service type not allowed here", typeReference.typeNode.location) } } is ProviderType -> { - controller.getOrCreateContext(typeReference.definition.location.source).error { + controller.getOrCreateContext(typeReference.typeNode.location.source).error { message("Cannot use provider '${type.name}' as type") - highlight("provider type not allowed here", typeReference.definition.location) + highlight("provider type not allowed here", typeReference.typeNode.location) } } is PackageType -> { - controller.getOrCreateContext(typeReference.definition.location.source).error { + controller.getOrCreateContext(typeReference.typeNode.location.source).error { message("Cannot use package '${type.packageName}' as type") - highlight("package type not allowed here", typeReference.definition.location) + highlight("package type not allowed here", typeReference.typeNode.location) } } @@ -64,7 +64,7 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont private inline fun checkServiceType(typeReference: TypeReference, block: (serviceType: ServiceType) -> Unit) { check(typeReference is ResolvedTypeReference) if (typeReference.constraints.isNotEmpty()) { - controller.getOrCreateContext(typeReference.definition.location.source).error { + controller.getOrCreateContext(typeReference.fullNode.location.source).error { message("Cannot have constraints on service") for (constraint in typeReference.constraints) { highlight("illegal constraint", constraint.node.location) @@ -72,9 +72,9 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont } } if (typeReference.isOptional) { - controller.getOrCreateContext(typeReference.definition.location.source).error { + controller.getOrCreateContext(typeReference.fullNode.location.source).error { message("Cannot have optional service") - highlight("illegal optional", typeReference.definition.location) + highlight("illegal optional", typeReference.fullNode.location) } } when (val type = typeReference.type) { @@ -84,9 +84,9 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont is UnknownType -> Unit else -> { - controller.getOrCreateContext(typeReference.definition.location.source).error { + controller.getOrCreateContext(typeReference.typeNode.location.source).error { message("Expected a service but got '${type.humanReadableName}'") - highlight("illegal type", typeReference.definition.location) + highlight("illegal type", typeReference.typeNode.location) } } } @@ -95,7 +95,7 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont private inline fun checkProviderType(typeReference: TypeReference, block: (providerType: ProviderType) -> Unit) { check(typeReference is ResolvedTypeReference) if (typeReference.constraints.isNotEmpty()) { - controller.getOrCreateContext(typeReference.definition.location.source).error { + controller.getOrCreateContext(typeReference.fullNode.location.source).error { message("Cannot have constraints on provider") for (constraint in typeReference.constraints) { highlight("illegal constraint", constraint.node.location) @@ -103,9 +103,9 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont } } if (typeReference.isOptional) { - controller.getOrCreateContext(typeReference.definition.location.source).error { + controller.getOrCreateContext(typeReference.fullNode.location.source).error { message("Cannot have optional provider") - highlight("illegal optional", typeReference.definition.location) + highlight("illegal optional", typeReference.fullNode.location) } } when (val type = typeReference.type) { @@ -115,9 +115,9 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont is UnknownType -> Unit else -> { - controller.getOrCreateContext(typeReference.definition.location.source).error { + controller.getOrCreateContext(typeReference.typeNode.location.source).error { message("Expected a provider but got '${type.humanReadableName}'") - highlight("illegal type", typeReference.definition.location) + highlight("illegal type", typeReference.typeNode.location) } } } @@ -143,24 +143,24 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont val implementsTypes = mutableMapOf() provider.implements.forEach { implements -> checkServiceType(implements.service) { type -> - implementsTypes.putIfAbsent(type, implements.definition.location)?.let { existingLocation -> - controller.getOrCreateContext(implements.definition.location.source).error { + implementsTypes.putIfAbsent(type, implements.node.location)?.let { existingLocation -> + controller.getOrCreateContext(implements.node.location.source).error { message("Service '${type.name}' already implemented") - highlight("duplicate implements", implements.definition.location) + highlight("duplicate implements", implements.node.location) highlight("previous implements", existingLocation) } return@forEach } - implements.operations = if (implements.definition.serviceOperationNames.isEmpty()) { + implements.operations = if (implements.node.serviceOperationNames.isEmpty()) { type.operations } else { - implements.definition.serviceOperationNames.mapNotNull { serviceOperationName -> + implements.node.serviceOperationNames.mapNotNull { serviceOperationName -> val matchingOperation = type.operations.find { it.name == serviceOperationName.name } if (matchingOperation != null) { matchingOperation } else { - controller.getOrCreateContext(provider.definition.location.source).error { + controller.getOrCreateContext(provider.declaration.location.source).error { message("Operation '${serviceOperationName.name}' not found in service '${type.name}'") highlight("unknown operation", serviceOperationName.location) } @@ -177,10 +177,10 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont checkProviderType(consumer.provider) { providerType -> consumer.uses.forEach { uses -> checkServiceType(uses.service) { type -> - usesTypes.putIfAbsent(type, uses.definition.location)?.let { existingLocation -> - controller.getOrCreateContext(uses.definition.location.source).error { + usesTypes.putIfAbsent(type, uses.node.location)?.let { existingLocation -> + controller.getOrCreateContext(uses.node.location.source).error { message("Service '${type.name}' already used") - highlight("duplicate uses", uses.definition.location) + highlight("duplicate uses", uses.node.location) highlight("previous uses", existingLocation) } return@forEach @@ -189,28 +189,28 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont val matchingImplements = providerType.implements.find { (it.service as ResolvedTypeReference).type == type } if (matchingImplements == null) { - controller.getOrCreateContext(uses.definition.location.source).error { + controller.getOrCreateContext(uses.node.location.source).error { message("Service '${type.name}' is not implemented by provider '${providerType.name}'") - highlight("unavailable service", uses.definition.serviceName.location) + highlight("unavailable service", uses.node.serviceName.location) } return@forEach } - uses.operations = if (uses.definition.serviceOperationNames.isEmpty()) { + uses.operations = if (uses.node.serviceOperationNames.isEmpty()) { matchingImplements.operations } else { - uses.definition.serviceOperationNames.mapNotNull { serviceOperationName -> + uses.node.serviceOperationNames.mapNotNull { serviceOperationName -> val matchingOperation = matchingImplements.operations.find { it.name == serviceOperationName.name } if (matchingOperation != null) { matchingOperation } else { if (type.operations.any { it.name == serviceOperationName.name }) { - controller.getOrCreateContext(uses.definition.location.source).error { + controller.getOrCreateContext(uses.node.location.source).error { message("Operation '${serviceOperationName.name}' in service '${type.name}' is not implemented by provider '${providerType.name}'") highlight("unavailable operation", serviceOperationName.location) } } else { - controller.getOrCreateContext(uses.definition.location.source).error { + controller.getOrCreateContext(uses.node.location.source).error { message("Operation '${serviceOperationName.name}' not found in service '${type.name}'") highlight("unknown operation", serviceOperationName.location) } diff --git a/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelPreProcessor.kt b/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelPreProcessor.kt new file mode 100644 index 00000000..068388eb --- /dev/null +++ b/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelPreProcessor.kt @@ -0,0 +1,174 @@ +package tools.samt.semantic + +import tools.samt.common.DiagnosticController +import tools.samt.common.Location +import tools.samt.parser.* + +internal class SemanticModelPreProcessor(private val controller: DiagnosticController) { + + private fun reportDuplicateDeclaration( + parentPackage: Package, + statement: NamedDeclarationNode, + ) { + if (statement.name in parentPackage) { + val existingType = parentPackage.types.getValue(statement.name.name) + controller.getOrCreateContext(statement.location.source).error { + message("'${statement.name.name}' is already declared") + highlight("duplicate declaration", statement.name.location) + if (existingType is UserDeclared) { + highlight("previous declaration", existingType.declaration.location) + } + } + } + } + + private inline fun reportDuplicates( + items: List, + what: String, + identifierGetter: (node: T) -> IdentifierNode, + ) { + val existingItems = mutableMapOf() + for (item in items) { + val name = identifierGetter(item).name + val existingLocation = existingItems.putIfAbsent(name, item.location) + if (existingLocation != null) { + controller.getOrCreateContext(item.location.source).error { + message("$what '$name' is defined more than once") + highlight("duplicate declaration", identifierGetter(item).location) + highlight("previous declaration", existingLocation) + } + } + } + } + + fun fillPackage(samtPackage: Package, files: List) { + for (file in files) { + var parentPackage = samtPackage + for (component in file.packageDeclaration.name.components) { + var subPackage = parentPackage.subPackages.find { it.name == component.name } + if (subPackage == null) { + subPackage = Package(component.name) + parentPackage.subPackages.add(subPackage) + } + parentPackage = subPackage + } + + for (statement in file.statements) { + when (statement) { + is RecordDeclarationNode -> { + reportDuplicateDeclaration(parentPackage, statement) + reportDuplicates(statement.fields, "Record field") { it.name } + if (statement.extends.isNotEmpty()) { + controller.getOrCreateContext(statement.location.source).error { + message("Record extends are not yet supported") + highlight("cannot extend other records", statement.extends.first().location) + } + } + val fields = statement.fields.map { field -> + RecordType.Field( + name = field.name.name, + type = UnresolvedTypeReference(field.type), + declaration = field + ) + } + parentPackage += RecordType( + name = statement.name.name, + fields = fields, + declaration = statement + ) + } + + is EnumDeclarationNode -> { + reportDuplicateDeclaration(parentPackage, statement) + reportDuplicates(statement.values, "Enum value") { it } + val values = statement.values.map { it.name } + parentPackage += EnumType(statement.name.name, values, statement) + } + + is ServiceDeclarationNode -> { + reportDuplicateDeclaration(parentPackage, statement) + reportDuplicates(statement.operations, "Operation") { it.name } + val operations = statement.operations.map { operation -> + reportDuplicates(operation.parameters, "Parameter") { it.name } + val parameters = operation.parameters.map { parameter -> + ServiceType.Operation.Parameter( + name = parameter.name.name, + type = UnresolvedTypeReference(parameter.type), + declaration = parameter, + ) + } + when (operation) { + is OnewayOperationNode -> { + ServiceType.OnewayOperation( + name = operation.name.name, + parameters = parameters, + declaration = operation, + ) + } + + is RequestResponseOperationNode -> { + if (operation.isAsync) { + controller.getOrCreateContext(operation.location.source).error { + message("Async operations are not yet supported") + highlight("unsupported async operation", operation.location) + } + } + ServiceType.RequestResponseOperation( + name = operation.name.name, + parameters = parameters, + declaration = operation, + returnType = operation.returnType?.let { UnresolvedTypeReference(it) }, + raisesTypes = operation.raises.map { UnresolvedTypeReference(it) }, + isAsync = operation.isAsync, + ) + } + } + } + parentPackage += ServiceType(statement.name.name, operations, statement) + } + + is ProviderDeclarationNode -> { + reportDuplicateDeclaration(parentPackage, statement) + val implements = statement.implements.map { implements -> + ProviderType.Implements( + UnresolvedTypeReference(implements.serviceName), + emptyList(), + implements + ) + } + val transport = ProviderType.Transport( + name = statement.transport.protocolName.name, + configuration = statement.transport.configuration + ) + parentPackage += ProviderType(statement.name.name, implements, transport, statement) + } + + is ConsumerDeclarationNode -> { + parentPackage += ConsumerType( + provider = UnresolvedTypeReference(statement.providerName), + uses = statement.usages.map { + ConsumerType.Uses( + service = UnresolvedTypeReference(it.serviceName), + operations = emptyList(), + node = it + ) + }, + declaration = statement + ) + } + + is TypeAliasNode -> { + controller.getOrCreateContext(statement.location.source).error { + message("Type aliases are not yet supported") + highlight("unsupported feature", statement.location) + } + } + + is PackageDeclarationNode, + is ImportNode, + -> Unit + } + } + } + } +} diff --git a/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelReferenceResolver.kt b/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelReferenceResolver.kt new file mode 100644 index 00000000..6a58ee34 --- /dev/null +++ b/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelReferenceResolver.kt @@ -0,0 +1,196 @@ +package tools.samt.semantic + +import tools.samt.common.DiagnosticController +import tools.samt.parser.* + +internal class SemanticModelReferenceResolver( + private val controller: DiagnosticController, + private val global: Package, +) { + private val constraintBuilder = ConstraintBuilder(controller) + + fun resolveAndLinkExpression( + scope: SemanticModelBuilder.FileScope, + rootExpression: ExpressionNode, + ): ResolvedTypeReference { + val resolvedTypeReference = resolveExpressionInternal(scope, rootExpression) + scope.filePackage.sourcePackage.linkType(resolvedTypeReference.typeNode, resolvedTypeReference.type) + return resolvedTypeReference + } + + private fun resolveExpressionInternal( + scope: SemanticModelBuilder.FileScope, + expression: ExpressionNode, + ): ResolvedTypeReference { + when (expression) { + is IdentifierNode -> { + scope.typeLookup[expression.name]?.let { type -> + return ResolvedTypeReference(type, expression) + } + + controller.getOrCreateContext(expression.location.source).error { + message("Type '${expression.name}' could not be resolved") + highlight("unresolved type", expression.location) + } + } + + is BundleIdentifierNode -> { + // Bundle identifiers with one component are treated like normal identifiers + if (expression.components.size == 1) { + return resolveAndLinkExpression(scope, expression.components.first()).also { + scope.filePackage.sourcePackage.linkType(expression, it.type) + } + } + // Type is foo.bar.Baz + // Resolve foo first, it must be a package + when (val expectedPackageType = scope.typeLookup[expression.components.first().name]) { + is PackageType -> { + resolveType( + expression.components.subList(1, expression.components.size), + expectedPackageType.sourcePackage + )?.let { type -> + return ResolvedTypeReference(type, expression) + } + } + + null -> { + controller.getOrCreateContext(expression.location.source).error { + message("Type '${expression.name}' could not be resolved") + highlight("unresolved type", expression.location) + } + } + + else -> { + controller.getOrCreateContext(expression.location.source).error { + message("Type '${expression.components.first().name}' is not a package, cannot access sub-types") + highlight("not a package", expression.components.first().location) + } + } + } + } + + is CallExpressionNode -> { + val baseType = resolveAndLinkExpression(scope, expression.base) + val constraints = expression.arguments.mapNotNull { constraintBuilder.build(baseType.type, it) } + if (baseType.constraints.isNotEmpty()) { + controller.getOrCreateContext(expression.location.source).error { + message("Cannot have nested constraints") + highlight("illegal nested constraint", expression.location) + } + } + return baseType.copy(constraints = constraints, fullNode = expression) + } + + is GenericSpecializationNode -> { + val name = expression.base.let { + when (it) { + is IdentifierNode -> it.name + is BundleIdentifierNode -> it.name + else -> null + } + } + when (name) { + "List" -> { + if (expression.arguments.size == 1) { + return ResolvedTypeReference( + type = ListType( + elementType = resolveAndLinkExpression(scope, expression.arguments[0]), + node = expression, + ), + typeNode = expression.base, + fullNode = expression, + ) + } + } + + "Map" -> { + if (expression.arguments.size == 2) { + return ResolvedTypeReference( + type = MapType( + keyType = resolveAndLinkExpression(scope, expression.arguments[0]), + valueType = resolveAndLinkExpression(scope, expression.arguments[1]), + node = expression, + ), + typeNode = expression.base, + fullNode = expression, + ) + } + } + } + controller.getOrCreateContext(expression.location.source).error { + message("Unsupported generic type") + highlight(expression.location) + help("Valid generic types are List and Map") + } + } + + is OptionalDeclarationNode -> { + val baseType = resolveAndLinkExpression(scope, expression.base) + if (baseType.isOptional) { + controller.getOrCreateContext(expression.location.source).warn { + message("Type is already optional, ignoring '?'") + highlight("already optional", expression.base.location) + } + } + return baseType.copy(isOptional = true, fullNode = expression) + } + + is BooleanNode, + is NumberNode, + is StringNode, + -> controller.getOrCreateContext(expression.location.source).error { + message("Cannot use literal value as type") + highlight("not a type expression", expression.location) + } + + is ObjectNode, + is ArrayNode, + is RangeExpressionNode, + is WildcardNode, + -> controller.getOrCreateContext(expression.location.source).error { + message("Invalid type expression") + highlight("not a type expression", expression.location) + } + } + + return ResolvedTypeReference(UnknownType, expression) + } + + + fun resolveType(bundleIdentifierNode: BundleIdentifierNode) = resolveType(bundleIdentifierNode.components) + private fun resolveType(components: List, start: Package = global): Type? { + var currentPackage = start + val iterator = components.listIterator() + while (iterator.hasNext()) { + val component = iterator.next() + when (val resolvedType = currentPackage.resolveType(component)) { + is PackageType -> { + currentPackage = resolvedType.sourcePackage + } + + null -> { + controller.getOrCreateContext(component.location.source).error { + message("Could not resolve reference '${component.name}'") + highlight("unresolved reference", component.location) + } + return null + } + + else -> { + if (iterator.hasNext()) { + // We resolved a non-package type but there are still components left + + controller.getOrCreateContext(component.location.source).error { + message("Type '${component.name}' is not a package, cannot access sub-types") + highlight("must be a package", component.location) + } + return null + } + return resolvedType + } + } + } + + return PackageType(currentPackage) + } +} diff --git a/semantic/src/main/kotlin/tools/samt/semantic/Types.kt b/semantic/src/main/kotlin/tools/samt/semantic/Types.kt index 99e82750..d80b9a63 100644 --- a/semantic/src/main/kotlin/tools/samt/semantic/Types.kt +++ b/semantic/src/main/kotlin/tools/samt/semantic/Types.kt @@ -112,12 +112,13 @@ object DurationType : LiteralType { sealed interface CompoundType : Type -sealed interface UserDefinedType : Type { - val definition: Node +sealed interface UserDeclared { + val declaration: Node } data class ListType( val elementType: TypeReference, + val node: GenericSpecializationNode, ) : CompoundType { override val humanReadableName: String = "List<${elementType.humanReadableName}>" } @@ -125,6 +126,7 @@ data class ListType( data class MapType( val keyType: TypeReference, val valueType: TypeReference, + val node: GenericSpecializationNode, ) : CompoundType { override val humanReadableName: String = "Map<${keyType.humanReadableName}, ${valueType.humanReadableName}>" } @@ -132,11 +134,12 @@ data class MapType( data class RecordType( val name: String, val fields: List, - override val definition: RecordDeclarationNode, -) : CompoundType, UserDefinedType { + override val declaration: RecordDeclarationNode, +) : CompoundType, UserDeclared { data class Field( val name: String, var type: TypeReference, + val declaration: RecordFieldNode, ) override val humanReadableName: String = name @@ -145,37 +148,41 @@ data class RecordType( data class EnumType( val name: String, val values: List, - override val definition: EnumDeclarationNode, -) : CompoundType, UserDefinedType { + override val declaration: EnumDeclarationNode, +) : CompoundType, UserDeclared { override val humanReadableName: String = name } data class ServiceType( val name: String, val operations: List, - override val definition: ServiceDeclarationNode, -) : CompoundType, UserDefinedType { - sealed class Operation( - val name: String, - val parameters: List, - ) { + override val declaration: ServiceDeclarationNode, +) : CompoundType, UserDeclared { + sealed interface Operation : UserDeclared { + val name: String + val parameters: List + override val declaration: OperationNode data class Parameter( val name: String, var type: TypeReference, - ) + override val declaration: OperationParameterNode, + ): UserDeclared } - class RequestResponseOperation( - name: String, - parameters: List, + data class RequestResponseOperation( + override val name: String, + override val parameters: List, + override val declaration: RequestResponseOperationNode, var returnType: TypeReference?, var raisesTypes: List, - ) : Operation(name, parameters) + val isAsync: Boolean, + ) : Operation - class OnewayOperation( - name: String, - parameters: List, - ) : Operation(name, parameters) + data class OnewayOperation( + override val name: String, + override val parameters: List, + override val declaration: OnewayOperationNode, + ) : Operation override val humanReadableName: String = name } @@ -184,12 +191,12 @@ data class ProviderType( val name: String, val implements: List, val transport: Transport, - override val definition: ProviderDeclarationNode, -) : CompoundType, UserDefinedType { + override val declaration: ProviderDeclarationNode, +) : CompoundType, UserDeclared { data class Implements( var service: TypeReference, var operations: List, - val definition: ProviderImplementsNode, + val node: ProviderImplementsNode, ) data class Transport( @@ -203,12 +210,12 @@ data class ProviderType( data class ConsumerType( var provider: TypeReference, var uses: List, - override val definition: ConsumerDeclarationNode, -) : CompoundType, UserDefinedType { + override val declaration: ConsumerDeclarationNode, +) : CompoundType, UserDeclared { data class Uses( var service: TypeReference, var operations: List, - val definition: ConsumerUsesNode, + val node: ConsumerUsesNode, ) override val humanReadableName: String = "consumer for ${provider.humanReadableName}" @@ -226,8 +233,11 @@ data class UnresolvedTypeReference( } data class ResolvedTypeReference( - val definition: ExpressionNode, val type: Type, + /** Includes only the type reference, e.g. "foo.bar.Baz", "Map" or "String" */ + val typeNode: ExpressionNode, + /** Includes the full type reference, e.g. "List? (1..100)" */ + val fullNode: ExpressionNode = typeNode, val isOptional: Boolean = false, val constraints: List = emptyList(), ) : TypeReference { diff --git a/semantic/src/test/kotlin/tools/samt/semantic/SemanticModelTest.kt b/semantic/src/test/kotlin/tools/samt/semantic/SemanticModelTest.kt index b9bd330c..03163d6b 100644 --- a/semantic/src/test/kotlin/tools/samt/semantic/SemanticModelTest.kt +++ b/semantic/src/test/kotlin/tools/samt/semantic/SemanticModelTest.kt @@ -72,8 +72,9 @@ class SemanticModelTest { enum B { } service C { } + service D { } provide C { - implements C + implements D transport HTTP } """.trimIndent()