diff --git a/.gitignore b/.gitignore index 1a49ce5e..363bf1ca 100644 --- a/.gitignore +++ b/.gitignore @@ -14,5 +14,10 @@ build ehthumbs.db Thumbs.db +# SAMT wrapper generated files # +.samt +samtw +samtw.bat + # Random files used for debugging specification/examples/debug.samt diff --git a/build.gradle.kts b/build.gradle.kts index e33109fc..70ad21bc 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -14,6 +14,7 @@ dependencies { kover(project(":cli")) kover(project(":language-server")) kover(project(":samt-config")) + kover(project(":codegen")) } koverReport { diff --git a/cli/build.gradle.kts b/cli/build.gradle.kts index a2cab7e4..8260b821 100644 --- a/cli/build.gradle.kts +++ b/cli/build.gradle.kts @@ -14,6 +14,9 @@ dependencies { implementation(project(":lexer")) implementation(project(":parser")) implementation(project(":semantic")) + implementation(project(":samt-config")) + implementation(project(":codegen")) + implementation(project(":public-api")) } application { diff --git a/cli/src/main/kotlin/tools/samt/cli/CliArgs.kt b/cli/src/main/kotlin/tools/samt/cli/CliArgs.kt index 75d11099..586b66ae 100644 --- a/cli/src/main/kotlin/tools/samt/cli/CliArgs.kt +++ b/cli/src/main/kotlin/tools/samt/cli/CliArgs.kt @@ -10,8 +10,8 @@ class CliArgs { @Parameters(commandDescription = "Compile SAMT files") class CompileCommand { - @Parameter(description = "Files to compile, defaults to all .samt files in the current directory") - var files: List = mutableListOf() + @Parameter(description = "SAMT project to compile, defaults to the 'samt.yaml' file in the current directory") + var file: String = "./samt.yaml" } @Parameters(commandDescription = "Dump SAMT files in various formats for debugging purposes") @@ -25,8 +25,8 @@ class DumpCommand { @Parameter(names = ["--types"], description = "Dump a visual representation of the resolved types") var dumpTypes: Boolean = false - @Parameter(description = "Files to dump, defaults to all .samt files in the current directory") - var files: List = mutableListOf() + @Parameter(description = "SAMT project to dump, defaults to the 'samt.yaml' file in the current directory") + var file: String = "./samt.yaml" } @Parameters(commandDescription = "Initialize or update the SAMT wrapper") diff --git a/cli/src/main/kotlin/tools/samt/cli/CliCompiler.kt b/cli/src/main/kotlin/tools/samt/cli/CliCompiler.kt index 1e26c0dc..b6d10e99 100644 --- a/cli/src/main/kotlin/tools/samt/cli/CliCompiler.kt +++ b/cli/src/main/kotlin/tools/samt/cli/CliCompiler.kt @@ -1,13 +1,26 @@ package tools.samt.cli +import tools.samt.codegen.Codegen import tools.samt.common.DiagnosticController import tools.samt.common.DiagnosticException +import tools.samt.common.collectSamtFiles +import tools.samt.common.readSamtSource import tools.samt.lexer.Lexer import tools.samt.parser.Parser import tools.samt.semantic.SemanticModel +import java.io.IOException +import kotlin.io.path.isDirectory +import kotlin.io.path.notExists internal fun compile(command: CompileCommand, controller: DiagnosticController) { - val sourceFiles = command.files.readSamtSourceFiles(controller) + val (configuration ,_) = CliConfigParser.readConfig(command.file, controller) ?: return + + if (configuration.source.notExists() || !configuration.source.isDirectory()) { + controller.reportGlobalError("Source path '${configuration.source.toUri()}' does not point to valid directory") + return + } + + val sourceFiles = collectSamtFiles(configuration.source.toUri()).readSamtSource(controller) if (controller.hasErrors()) { return @@ -40,7 +53,24 @@ internal fun compile(command: CompileCommand, controller: DiagnosticController) } // build up the semantic model from the AST - SemanticModel.build(fileNodes, controller) + val model = SemanticModel.build(fileNodes, controller) - // Code Generators will be called here + // if the semantic model failed to build, exit + if (controller.hasErrors()) { + return + } + + if (configuration.generators.isEmpty()) { + controller.reportGlobalInfo("No generators configured, did you forget to add a 'generators' section to the 'samt.yaml' configuration?") + return + } + + for (generator in configuration.generators) { + val files = Codegen.generate(model, generator, controller) + try { + OutputWriter.write(generator.output, files) + } catch (e: IOException) { + controller.reportGlobalError("Failed to write output for generator '${generator.name}': ${e.message}") + } + } } diff --git a/cli/src/main/kotlin/tools/samt/cli/CliConfigParser.kt b/cli/src/main/kotlin/tools/samt/cli/CliConfigParser.kt new file mode 100644 index 00000000..db39d7a1 --- /dev/null +++ b/cli/src/main/kotlin/tools/samt/cli/CliConfigParser.kt @@ -0,0 +1,41 @@ +package tools.samt.cli + +import tools.samt.common.DiagnosticController +import tools.samt.common.SamtConfiguration +import tools.samt.common.SamtLinterConfiguration +import tools.samt.config.SamtConfigurationParser +import java.nio.file.InvalidPathException +import kotlin.io.path.Path +import kotlin.io.path.notExists + +internal object CliConfigParser { + fun readConfig(file: String, controller: DiagnosticController): Pair? { + val configFile = try { + Path(file) + } catch (e: InvalidPathException) { + controller.reportGlobalError("Invalid path '${file}': ${e.message}") + return null + } + if (configFile.notExists()) { + controller.reportGlobalInfo("Configuration file '${configFile.toUri()}' does not exist, using default configuration") + } + val configuration = try { + SamtConfigurationParser.parseConfiguration(configFile) + } catch (e: Exception) { + controller.reportGlobalError("Failed to parse configuration file '${configFile.toUri()}': ${e.message}") + return null + } + val samtLintConfigFile = configFile.resolveSibling(".samtrc.yaml") + if (samtLintConfigFile.notExists()) { + controller.reportGlobalInfo("Lint configuration file '${samtLintConfigFile.toUri()}' does not exist, using default lint configuration") + } + val linterConfiguration = try { + SamtConfigurationParser.parseLinterConfiguration(samtLintConfigFile) + } catch (e: Exception) { + controller.reportGlobalError("Failed to parse lint configuration file '${samtLintConfigFile.toUri()}': ${e.message}") + return null + } + + return Pair(configuration, linterConfiguration) + } +} diff --git a/cli/src/main/kotlin/tools/samt/cli/CliDumper.kt b/cli/src/main/kotlin/tools/samt/cli/CliDumper.kt index cb4f95c1..39ed2428 100644 --- a/cli/src/main/kotlin/tools/samt/cli/CliDumper.kt +++ b/cli/src/main/kotlin/tools/samt/cli/CliDumper.kt @@ -3,12 +3,27 @@ package tools.samt.cli import com.github.ajalt.mordant.terminal.Terminal import tools.samt.common.DiagnosticController import tools.samt.common.DiagnosticException +import tools.samt.common.collectSamtFiles +import tools.samt.common.readSamtSource import tools.samt.lexer.Lexer import tools.samt.parser.Parser import tools.samt.semantic.SemanticModel +import kotlin.io.path.isDirectory +import kotlin.io.path.notExists internal fun dump(command: DumpCommand, terminal: Terminal, controller: DiagnosticController) { - val sourceFiles = command.files.readSamtSourceFiles(controller) + val (configuration ,_) = CliConfigParser.readConfig(command.file, controller) ?: return + + if (configuration.source.notExists() || !configuration.source.isDirectory()) { + controller.reportGlobalError("Source path '${configuration.source.toUri()}' does not point to valid directory") + return + } + + val sourceFiles = collectSamtFiles(configuration.source.toUri()).readSamtSource(controller) + + if (controller.hasErrors()) { + return + } if (controller.hasErrors()) { return diff --git a/cli/src/main/kotlin/tools/samt/cli/CliFileResolution.kt b/cli/src/main/kotlin/tools/samt/cli/CliFileResolution.kt deleted file mode 100644 index 4a959abe..00000000 --- a/cli/src/main/kotlin/tools/samt/cli/CliFileResolution.kt +++ /dev/null @@ -1,12 +0,0 @@ -package tools.samt.cli - -import tools.samt.common.DiagnosticController -import tools.samt.common.SourceFile -import tools.samt.common.collectSamtFiles -import tools.samt.common.readSamtSource -import java.io.File - -internal fun List.readSamtSourceFiles(controller: DiagnosticController): List = - map { File(it) }.ifEmpty { collectSamtFiles(controller.workingDirectory) } - .readSamtSource(controller) - diff --git a/cli/src/main/kotlin/tools/samt/cli/DiagnosticFormatter.kt b/cli/src/main/kotlin/tools/samt/cli/DiagnosticFormatter.kt index 167e11b2..e40b2781 100644 --- a/cli/src/main/kotlin/tools/samt/cli/DiagnosticFormatter.kt +++ b/cli/src/main/kotlin/tools/samt/cli/DiagnosticFormatter.kt @@ -15,9 +15,6 @@ internal class DiagnosticFormatter( companion object { private const val CONTEXT_ROW_COUNT = 3 - // FIXME: this is a bit of a hack to get the terminal width - // it also means we're assuming this output will only ever be printed in a terminal - // i don't actually know what happens if it doesn't run in a tty setting fun format(controller: DiagnosticController, startTimestamp: Long, currentTimestamp: Long, terminalWidth: Int = Terminal().info.width): String { val formatter = DiagnosticFormatter(controller, startTimestamp, currentTimestamp, terminalWidth) return formatter.format() diff --git a/cli/src/main/kotlin/tools/samt/cli/OutputWriter.kt b/cli/src/main/kotlin/tools/samt/cli/OutputWriter.kt new file mode 100644 index 00000000..7a031474 --- /dev/null +++ b/cli/src/main/kotlin/tools/samt/cli/OutputWriter.kt @@ -0,0 +1,39 @@ +package tools.samt.cli + +import tools.samt.api.plugin.CodegenFile +import java.io.IOException +import java.nio.file.InvalidPathException +import java.nio.file.Path +import kotlin.io.path.* + +internal object OutputWriter { + @Throws(IOException::class) + fun write(outputDirectory: Path, files: List) { + if (!outputDirectory.exists()) { + try { + outputDirectory.createDirectories() + } catch (e: IOException) { + throw IOException("Failed to create output directory '${outputDirectory}'", e) + } + } + if (!outputDirectory.isDirectory()) { + throw IOException("Path '${outputDirectory}' does not point to a directory") + } + for (file in files) { + val outputFile = try { + outputDirectory.resolve(file.filepath) + } catch (e: InvalidPathException) { + throw IOException("Invalid path '${file.filepath}'", e) + } + try { + outputFile.parent.createDirectories() + if (outputFile.notExists()) { + outputFile.createFile() + } + outputFile.writeText(file.source) + } catch (e: IOException) { + throw IOException("Failed to write file '${outputFile.toUri()}'", e) + } + } + } +} diff --git a/cli/src/main/kotlin/tools/samt/cli/TypePrinter.kt b/cli/src/main/kotlin/tools/samt/cli/TypePrinter.kt index 15fe1ea1..0a13e3bf 100644 --- a/cli/src/main/kotlin/tools/samt/cli/TypePrinter.kt +++ b/cli/src/main/kotlin/tools/samt/cli/TypePrinter.kt @@ -6,7 +6,12 @@ import tools.samt.semantic.Package internal object TypePrinter { fun dump(samtPackage: Package): String = buildString { - appendLine(blue(samtPackage.name.ifEmpty { "" })) + if (samtPackage.isRootPackage) { + appendLine(red("")) + } else { + appendLine(blue(samtPackage.name)) + } + for (enum in samtPackage.enums) { appendLine(" ${bold("enum")} ${yellow(enum.name)}") } diff --git a/codegen/build.gradle.kts b/codegen/build.gradle.kts new file mode 100644 index 00000000..dbbdcad2 --- /dev/null +++ b/codegen/build.gradle.kts @@ -0,0 +1,13 @@ +plugins { + id("samt-core.kotlin-conventions") + alias(libs.plugins.kover) +} + +dependencies { + implementation(project(":common")) + implementation(project(":parser")) + implementation(project(":semantic")) + implementation(project(":public-api")) + testImplementation(project(":lexer")) + testImplementation(project(":samt-config")) +} diff --git a/codegen/src/main/kotlin/tools/samt/codegen/Codegen.kt b/codegen/src/main/kotlin/tools/samt/codegen/Codegen.kt new file mode 100644 index 00000000..f9f28e6e --- /dev/null +++ b/codegen/src/main/kotlin/tools/samt/codegen/Codegen.kt @@ -0,0 +1,61 @@ +package tools.samt.codegen + +import tools.samt.api.plugin.CodegenFile +import tools.samt.api.plugin.Generator +import tools.samt.api.plugin.GeneratorParams +import tools.samt.api.plugin.TransportConfigurationParser +import tools.samt.api.types.SamtPackage +import tools.samt.codegen.http.HttpTransportConfigurationParser +import tools.samt.codegen.kotlin.KotlinTypesGenerator +import tools.samt.codegen.kotlin.ktor.KotlinKtorConsumerGenerator +import tools.samt.codegen.kotlin.ktor.KotlinKtorProviderGenerator +import tools.samt.common.DiagnosticController +import tools.samt.common.SamtGeneratorConfiguration +import tools.samt.semantic.SemanticModel + +object Codegen { + private val generators: List = listOf( + KotlinTypesGenerator, + KotlinKtorProviderGenerator, + KotlinKtorConsumerGenerator, + ) + + private val transports: List = listOf( + HttpTransportConfigurationParser, + ) + + internal class SamtGeneratorParams( + semanticModel: SemanticModel, + private val controller: DiagnosticController, + override val options: Map, + ) : GeneratorParams { + private val apiMapper = PublicApiMapper(transports, controller) + override val packages: List = semanticModel.global.allSubPackages.map { apiMapper.toPublicApi(it) } + + override fun reportError(message: String) { + controller.reportGlobalError(message) + } + + override fun reportWarning(message: String) { + controller.reportGlobalWarning(message) + } + + override fun reportInfo(message: String) { + controller.reportGlobalInfo(message) + } + } + + fun generate( + semanticModel: SemanticModel, + configuration: SamtGeneratorConfiguration, + controller: DiagnosticController, + ): List { + val matchingGenerators = generators.filter { it.name == configuration.name } + when (matchingGenerators.size) { + 0 -> controller.reportGlobalError("No matching generator found for '${configuration.name}'") + 1 -> return matchingGenerators.single().generate(SamtGeneratorParams(semanticModel, controller, configuration.options)) + else -> controller.reportGlobalError("Multiple matching generators found for '${configuration.name}'") + } + return emptyList() + } +} diff --git a/codegen/src/main/kotlin/tools/samt/codegen/PublicApiMapper.kt b/codegen/src/main/kotlin/tools/samt/codegen/PublicApiMapper.kt new file mode 100644 index 00000000..447ef220 --- /dev/null +++ b/codegen/src/main/kotlin/tools/samt/codegen/PublicApiMapper.kt @@ -0,0 +1,292 @@ +package tools.samt.codegen + +import tools.samt.api.plugin.* +import tools.samt.api.types.* +import tools.samt.common.DiagnosticController +import tools.samt.parser.reportError +import tools.samt.parser.reportInfo +import tools.samt.parser.reportWarning + +class PublicApiMapper( + private val transportParsers: List, + private val controller: DiagnosticController, +) { + private val typeCache = mutableMapOf() + + /** + * Returns a lazy delegate that will initialize its value only once, without synchronization. + * Because we are in a single-threaded environment, this is safe and significantly faster. + */ + fun unsafeLazy(initializer: () -> T) = lazy(LazyThreadSafetyMode.NONE, initializer) + + fun toPublicApi(samtPackage: tools.samt.semantic.Package) = object : SamtPackage { + override val name = samtPackage.name + override val qualifiedName = samtPackage.nameComponents.joinToString(".") + override val records = samtPackage.records.map { it.toPublicRecord() } + override val enums = samtPackage.enums.map { it.toPublicEnum() } + override val services = samtPackage.services.map { it.toPublicService() } + override val providers = samtPackage.providers.map { it.toPublicProvider() } + override val consumers = samtPackage.consumers.map { it.toPublicConsumer() } + override val aliases = samtPackage.aliases.map { it.toPublicAlias() } + } + + private fun tools.samt.semantic.RecordType.toPublicRecord() = object : RecordType { + override val name get() = this@toPublicRecord.name + override val qualifiedName by unsafeLazy { this@toPublicRecord.getQualifiedName() } + override val fields by unsafeLazy { this@toPublicRecord.fields.map { it.toPublicField() } } + } + + private fun tools.samt.semantic.RecordType.Field.toPublicField() = object : RecordField { + override val name get() = this@toPublicField.name + override val type by unsafeLazy { this@toPublicField.type.toPublicTypeReference() } + } + + private fun tools.samt.semantic.EnumType.toPublicEnum() = object : EnumType { + override val name get() = this@toPublicEnum.name + override val qualifiedName by unsafeLazy { this@toPublicEnum.getQualifiedName() } + override val values get() = this@toPublicEnum.values + } + + private fun tools.samt.semantic.ServiceType.toPublicService() = object : ServiceType { + override val name get() = this@toPublicService.name + override val qualifiedName by unsafeLazy { this@toPublicService.getQualifiedName() } + override val operations by unsafeLazy { this@toPublicService.operations.map { it.toPublicOperation() } } + } + + private fun tools.samt.semantic.ServiceType.Operation.toPublicOperation() = when (this) { + is tools.samt.semantic.ServiceType.OnewayOperation -> toPublicOnewayOperation() + is tools.samt.semantic.ServiceType.RequestResponseOperation -> toPublicRequestResponseOperation() + } + + private fun tools.samt.semantic.ServiceType.OnewayOperation.toPublicOnewayOperation() = object : OnewayOperation { + override val name get() = this@toPublicOnewayOperation.name + override val parameters by unsafeLazy { this@toPublicOnewayOperation.parameters.map { it.toPublicParameter() } } + } + + private fun tools.samt.semantic.ServiceType.RequestResponseOperation.toPublicRequestResponseOperation() = + object : RequestResponseOperation { + override val name get() = this@toPublicRequestResponseOperation.name + override val parameters by unsafeLazy { this@toPublicRequestResponseOperation.parameters.map { it.toPublicParameter() } } + override val returnType by unsafeLazy { this@toPublicRequestResponseOperation.returnType?.toPublicTypeReference() } + override val isAsync get() = this@toPublicRequestResponseOperation.isAsync + } + + private fun tools.samt.semantic.ServiceType.Operation.Parameter.toPublicParameter() = + object : ServiceOperationParameter { + override val name get() = this@toPublicParameter.name + override val type by unsafeLazy { this@toPublicParameter.type.toPublicTypeReference() } + } + + private fun tools.samt.semantic.ProviderType.toPublicProvider() = object : ProviderType { + override val name get() = this@toPublicProvider.name + override val qualifiedName by unsafeLazy { this@toPublicProvider.getQualifiedName() } + override val implements by unsafeLazy { this@toPublicProvider.implements.map { it.toPublicImplements() } } + override val transport by unsafeLazy { this@toPublicProvider.transport.toPublicTransport(this) } + } + + private class Params( + override val config: ConfigurationObject, + val controller: DiagnosticController + ) : TransportConfigurationParserParams { + + override fun reportError(message: String, context: ConfigurationElement?) { + if (context != null && context is PublicApiConfigurationMapping) { + context.original.reportError(controller) { + message(message) + highlight("offending configuration", context.original.location) + } + } else { + controller.reportGlobalError(message) + } + } + + override fun reportWarning(message: String, context: ConfigurationElement?) { + if (context != null && context is PublicApiConfigurationMapping) { + context.original.reportWarning(controller) { + message(message) + highlight("offending configuration", context.original.location) + } + } else { + controller.reportGlobalWarning(message) + } + } + + override fun reportInfo(message: String, context: ConfigurationElement?) { + if (context != null && context is PublicApiConfigurationMapping) { + context.original.reportInfo(controller) { + message(message) + highlight("related configuration", context.original.location) + } + } else { + controller.reportGlobalInfo(message) + } + } + } + + private fun tools.samt.semantic.ProviderType.Transport.toPublicTransport(provider: ProviderType): TransportConfiguration { + val transportConfigurationParsers = transportParsers.filter { it.transportName == name } + when (transportConfigurationParsers.size) { + 0 -> controller.reportGlobalWarning("No transport configuration parser found for transport '$name'") + 1 -> { + val transportConfigurationParser = transportConfigurationParsers.single() + if (configuration != null) { + val transportConfigNode = TransportConfigurationMapper(provider, controller).parse(configuration!!) + val config = Params(transportConfigNode, controller) + try { + return transportConfigurationParser.parse(config) + } catch (e: Exception) { + controller.reportGlobalError("Failed to parse transport configuration for transport '$name': ${e.message}") + } + } else { + return transportConfigurationParser.default() + } + } + + else -> controller.reportGlobalError("Multiple transport configuration parsers found for transport '$name'") + } + + return object : TransportConfiguration {} + } + + private fun tools.samt.semantic.ProviderType.Implements.toPublicImplements() = object : ProvidedService { + override val service = this@toPublicImplements.service.toPublicTypeReference().type as ServiceType + private val implementedOperationNames by unsafeLazy { this@toPublicImplements.operations.mapTo(mutableSetOf()) { it.name } } + override val implementedOperations by unsafeLazy { service.operations.filter { it.name in implementedOperationNames } } + override val unimplementedOperations by unsafeLazy { service.operations.filter { it.name !in implementedOperationNames } } + } + + private fun tools.samt.semantic.ConsumerType.toPublicConsumer() = object : ConsumerType { + override val provider by unsafeLazy { this@toPublicConsumer.provider.toPublicTypeReference().type as ProviderType } + override val uses by unsafeLazy { this@toPublicConsumer.uses.map { it.toPublicUses() } } + override val samtPackage by unsafeLazy { this@toPublicConsumer.parentPackage.nameComponents.joinToString(".") } + } + + private fun tools.samt.semantic.ConsumerType.Uses.toPublicUses() = object : ConsumedService { + override val service by unsafeLazy { this@toPublicUses.service.toPublicTypeReference().type as ServiceType } + private val consumedOperationNames by unsafeLazy { this@toPublicUses.operations.mapTo(mutableSetOf()) { it.name } } + override val consumedOperations by unsafeLazy { service.operations.filter { it.name in consumedOperationNames } } + override val unconsumedOperations by unsafeLazy { service.operations.filter { it.name !in consumedOperationNames } } + } + + private fun tools.samt.semantic.AliasType.toPublicAlias() = object : AliasType { + override val name get() = this@toPublicAlias.name + override val qualifiedName by unsafeLazy { this@toPublicAlias.getQualifiedName() } + override val aliasedType by unsafeLazy { this@toPublicAlias.aliasedType.toPublicTypeReference() } + override val fullyResolvedType by unsafeLazy { this@toPublicAlias.fullyResolvedType.toPublicTypeReference() } + } + + private inline fun List.findConstraint() = + firstOrNull { it is T } as T? + + private fun tools.samt.semantic.TypeReference?.toPublicTypeReference(): TypeReference { + check(this is tools.samt.semantic.ResolvedTypeReference) + val typeReference: tools.samt.semantic.ResolvedTypeReference = this@toPublicTypeReference + val runtimeTypeReference = when (val type = typeReference.type) { + is tools.samt.semantic.AliasType -> checkNotNull(type.fullyResolvedType) { "Found unresolved alias when generating code" } + else -> typeReference + } + return object : TypeReference { + override val type by lazy { typeReference.type.toPublicType() } + override val isOptional get() = typeReference.isOptional + override val rangeConstraint by unsafeLazy { + typeReference.constraints.findConstraint() + ?.toPublicRangeConstraint() + } + override val sizeConstraint by unsafeLazy { + typeReference.constraints.findConstraint() + ?.toPublicSizeConstraint() + } + override val patternConstraint by unsafeLazy { + typeReference.constraints.findConstraint() + ?.toPublicPatternConstraint() + } + override val valueConstraint by unsafeLazy { + typeReference.constraints.findConstraint() + ?.toPublicValueConstraint() + } + + override val runtimeType by unsafeLazy { runtimeTypeReference.type.toPublicType() } + override val isRuntimeOptional get() = isOptional || runtimeTypeReference.isOptional + override val runtimeRangeConstraint by unsafeLazy { + rangeConstraint + ?: runtimeTypeReference.constraints.findConstraint() + ?.toPublicRangeConstraint() + } + override val runtimeSizeConstraint by unsafeLazy { + sizeConstraint + ?: runtimeTypeReference.constraints.findConstraint() + ?.toPublicSizeConstraint() + } + override val runtimePatternConstraint by unsafeLazy { + patternConstraint + ?: runtimeTypeReference.constraints.findConstraint() + ?.toPublicPatternConstraint() + } + override val runtimeValueConstraint by unsafeLazy { + valueConstraint + ?: runtimeTypeReference.constraints.findConstraint() + ?.toPublicValueConstraint() + } + } + } + + private fun tools.samt.semantic.Type.toPublicType() = typeCache.computeIfAbsent(this@toPublicType) { + when (this) { + tools.samt.semantic.IntType -> object : IntType {} + tools.samt.semantic.LongType -> object : LongType {} + tools.samt.semantic.FloatType -> object : FloatType {} + tools.samt.semantic.DoubleType -> object : DoubleType {} + tools.samt.semantic.DecimalType -> object : DecimalType {} + tools.samt.semantic.BooleanType -> object : BooleanType {} + tools.samt.semantic.StringType -> object : StringType {} + tools.samt.semantic.BytesType -> object : BytesType {} + tools.samt.semantic.DateType -> object : DateType {} + tools.samt.semantic.DateTimeType -> object : DateTimeType {} + tools.samt.semantic.DurationType -> object : DurationType {} + is tools.samt.semantic.ListType -> object : ListType { + override val elementType by unsafeLazy { this@toPublicType.elementType.toPublicTypeReference() } + } + + is tools.samt.semantic.MapType -> object : MapType { + override val keyType by unsafeLazy { this@toPublicType.keyType.toPublicTypeReference() } + override val valueType by unsafeLazy { this@toPublicType.valueType.toPublicTypeReference() } + } + + is tools.samt.semantic.AliasType -> toPublicAlias() + is tools.samt.semantic.ConsumerType -> toPublicConsumer() + is tools.samt.semantic.EnumType -> toPublicEnum() + is tools.samt.semantic.ProviderType -> toPublicProvider() + is tools.samt.semantic.RecordType -> toPublicRecord() + is tools.samt.semantic.ServiceType -> toPublicService() + is tools.samt.semantic.PackageType -> error("Package type cannot be converted to public API") + tools.samt.semantic.UnknownType -> error("Unknown type cannot be converted to public API") + } + } + + private fun tools.samt.semantic.ResolvedTypeReference.Constraint.Range.toPublicRangeConstraint() = + object : Constraint.Range { + override val lowerBound get() = this@toPublicRangeConstraint.lowerBound + override val upperBound get() = this@toPublicRangeConstraint.upperBound + } + + private fun tools.samt.semantic.ResolvedTypeReference.Constraint.Size.toPublicSizeConstraint() = + object : Constraint.Size { + override val lowerBound get() = this@toPublicSizeConstraint.lowerBound + override val upperBound get() = this@toPublicSizeConstraint.upperBound + } + + private fun tools.samt.semantic.ResolvedTypeReference.Constraint.Pattern.toPublicPatternConstraint() = + object : Constraint.Pattern { + override val pattern get() = this@toPublicPatternConstraint.pattern + } + + private fun tools.samt.semantic.ResolvedTypeReference.Constraint.Value.toPublicValueConstraint() = + object : Constraint.Value { + override val value get() = this@toPublicValueConstraint.value + } + + private fun tools.samt.semantic.UserDeclaredNamedType.getQualifiedName(): String { + val components = parentPackage.nameComponents + name + return components.joinToString(".") + } +} diff --git a/codegen/src/main/kotlin/tools/samt/codegen/TransportConfigurationMapper.kt b/codegen/src/main/kotlin/tools/samt/codegen/TransportConfigurationMapper.kt new file mode 100644 index 00000000..a6bcb132 --- /dev/null +++ b/codegen/src/main/kotlin/tools/samt/codegen/TransportConfigurationMapper.kt @@ -0,0 +1,207 @@ +package tools.samt.codegen + +import tools.samt.api.plugin.ConfigurationElement +import tools.samt.api.plugin.ConfigurationList +import tools.samt.api.plugin.ConfigurationObject +import tools.samt.api.plugin.ConfigurationValue +import tools.samt.api.types.ProviderType +import tools.samt.api.types.ServiceOperation +import tools.samt.api.types.ServiceType +import tools.samt.common.DiagnosticController +import tools.samt.parser.reportError + +interface PublicApiConfigurationMapping { + val original: tools.samt.parser.Node +} + +class TransportConfigurationMapper( + private val provider: ProviderType, + private val controller: DiagnosticController, +) { + fun parse(configuration: tools.samt.parser.ObjectNode): ConfigurationObject { + return configuration.toConfigurationObject() + } + + private fun tools.samt.parser.Node.reportAndThrow(message: String): Nothing { + reportError(controller) { + message(message) + highlight("offending configuration", location) + } + error(message) + } + + private fun tools.samt.parser.ExpressionNode.toConfigurationElement(): ConfigurationElement = when (this) { + is tools.samt.parser.ArrayNode -> toConfigurationList() + is tools.samt.parser.BooleanNode -> toConfigurationValue() + is tools.samt.parser.BundleIdentifierNode -> components.last().toConfigurationValue() + is tools.samt.parser.IdentifierNode -> toConfigurationValue() + is tools.samt.parser.FloatNode -> toConfigurationValue() + is tools.samt.parser.IntegerNode -> toConfigurationValue() + is tools.samt.parser.ObjectNode -> toConfigurationObject() + is tools.samt.parser.StringNode -> toConfigurationValue() + else -> reportAndThrow("Unexpected expression") + } + + private fun tools.samt.parser.IntegerNode.toConfigurationValue() = + object : ConfigurationValue, PublicApiConfigurationMapping { + override val original = this@toConfigurationValue + override val asString: String get() = reportAndThrow("Unexpected integer, expected a string") + override val asIdentifier: String get() = reportAndThrow("Unexpected integer, expected an identifier") + + override fun > asEnum(enum: Class): T = + reportAndThrow("Unexpected integer, expected an enum (${enum.simpleName})") + + override val asLong: Long get() = original.value + override val asDouble: Double = original.value.toDouble() + override val asBoolean: Boolean get() = reportAndThrow("Unexpected integer, expected a boolean") + override val asServiceName: ServiceType get() = reportAndThrow("Unexpected integer, expected a service name") + override fun asOperationName(service: ServiceType): ServiceOperation = + reportAndThrow("Unexpected integer, expected an operation name") + + override val asObject: ConfigurationObject get() = reportAndThrow("Unexpected integer, expected an object") + override val asValue: ConfigurationValue get() = this + override val asList: ConfigurationList get() = reportAndThrow("Unexpected integer, expected a list") + } + + private fun tools.samt.parser.FloatNode.toConfigurationValue() = + object : ConfigurationValue, PublicApiConfigurationMapping { + override val original = this@toConfigurationValue + override val asString: String get() = reportAndThrow("Unexpected float, expected a string") + override val asIdentifier: String get() = reportAndThrow("Unexpected float, expected an identifier") + + override fun > asEnum(enum: Class): T = + reportAndThrow("Unexpected float, expected an enum (${enum.simpleName})") + + override val asLong: Long get() = reportAndThrow("Unexpected float, expected an integer") + override val asDouble: Double = original.value + override val asBoolean: Boolean get() = reportAndThrow("Unexpected float, expected a boolean") + override val asServiceName: ServiceType get() = reportAndThrow("Unexpected float, expected a service name") + override fun asOperationName(service: ServiceType): ServiceOperation = + reportAndThrow("Unexpected float, expected an operation name") + + override val asObject: ConfigurationObject get() = reportAndThrow("Unexpected float, expected an object") + override val asValue: ConfigurationValue get() = this + override val asList: ConfigurationList get() = reportAndThrow("Unexpected float, expected a list") + } + + private fun tools.samt.parser.StringNode.toConfigurationValue() = + object : ConfigurationValue, PublicApiConfigurationMapping { + override val original = this@toConfigurationValue + override val asString: String get() = original.value + override val asIdentifier: String get() = reportAndThrow("Unexpected string, expected an identifier") + + override fun > asEnum(enum: Class): T { + check(enum.isEnum) + return enum.enumConstants.find { it.name.equals(original.value, ignoreCase = true) } + ?: reportAndThrow("Illegal enum value, expected one of ${enum.enumConstants.joinToString { it.name }}") + } + + override val asLong: Long get() = reportAndThrow("Unexpected string, expected an integer") + override val asDouble: Double get() = reportAndThrow("Unexpected string, expected a float") + override val asBoolean: Boolean get() = reportAndThrow("Unexpected string, expected a boolean") + override val asServiceName: ServiceType get() = reportAndThrow("Unexpected string, expected a service name") + override fun asOperationName(service: ServiceType): ServiceOperation = + reportAndThrow("Unexpected string, expected an operation name") + + override val asObject: ConfigurationObject get() = reportAndThrow("Unexpected string, expected an object") + override val asValue: ConfigurationValue get() = this + override val asList: ConfigurationList get() = reportAndThrow("Unexpected string, expected a list") + } + + private fun tools.samt.parser.BooleanNode.toConfigurationValue() = + object : ConfigurationValue, PublicApiConfigurationMapping { + override val original = this@toConfigurationValue + override val asString: String get() = reportAndThrow("Unexpected boolean, expected a string") + override val asIdentifier: String get() = reportAndThrow("Unexpected boolean, expected an identifier") + + override fun > asEnum(enum: Class): T = + reportAndThrow("Unexpected boolean, expected an enum (${enum.simpleName})") + + override val asLong: Long get() = reportAndThrow("Unexpected boolean, expected an integer") + override val asDouble: Double get() = reportAndThrow("Unexpected boolean, expected a float") + override val asBoolean: Boolean get() = value + override val asServiceName: ServiceType get() = reportAndThrow("Unexpected boolean, expected a service name") + override fun asOperationName(service: ServiceType): ServiceOperation = + reportAndThrow("Unexpected boolean, expected an operation name") + + override val asObject: ConfigurationObject get() = reportAndThrow("Unexpected boolean, expected an object") + override val asValue: ConfigurationValue get() = this + override val asList: ConfigurationList get() = reportAndThrow("Unexpected boolean, expected a list") + } + + private fun tools.samt.parser.IdentifierNode.toConfigurationValue() = + object : ConfigurationValue, PublicApiConfigurationMapping { + override val original = this@toConfigurationValue + override val asString: String get() = reportAndThrow("Unexpected identifier, expected a string") + override val asIdentifier: String get() = original.name + + override fun > asEnum(enum: Class): T = + reportAndThrow("Unexpected identifier, expected an enum (${enum.simpleName})") + + override val asLong: Long get() = reportAndThrow("Unexpected identifier, expected an integer") + override val asDouble: Double get() = reportAndThrow("Unexpected identifier, expected a float") + override val asBoolean: Boolean get() = reportAndThrow("Unexpected identifier, expected a boolean") + override val asServiceName: ServiceType + get() = provider.implements.find { it.service.name == original.name }?.service + ?: reportAndThrow("No service with name '${original.name}' found in provider '${provider.name}'") + + override fun asOperationName(service: ServiceType): ServiceOperation = + provider.implements.find { it.service.qualifiedName == service.qualifiedName }?.implementedOperations?.find { it.name == original.name } + ?: reportAndThrow("No operation with name '${original.name}' found in service '${service.name}' of provider '${provider.name}'") + + override val asObject: ConfigurationObject get() = reportAndThrow("Unexpected identifier, expected an object") + override val asValue: ConfigurationValue get() = this + override val asList: ConfigurationList get() = reportAndThrow("Unexpected identifier, expected a list") + } + + private fun tools.samt.parser.ArrayNode.toConfigurationList() = + object : ConfigurationList, PublicApiConfigurationMapping { + override val original = this@toConfigurationList + override val entries: List + get() = original.values.map { it.toConfigurationElement() } + override val asObject: ConfigurationObject + get() = reportAndThrow("Unexpected array, expected an object") + override val asValue: ConfigurationValue + get() = reportAndThrow("Unexpected array, expected a value") + override val asList: ConfigurationList + get() = this + } + + private fun tools.samt.parser.ObjectNode.toConfigurationObject() = + object : ConfigurationObject, PublicApiConfigurationMapping { + override val original = this@toConfigurationObject + override val fields: Map + get() = original.fields.associate { it.name.toConfigurationValue() to it.value.toConfigurationElement() } + + override fun getField(name: String): ConfigurationElement = + getFieldOrNull(name) ?: run { + original.reportError(controller) { + message("No field with name '$name' found") + highlight("related object", original.location) + } + throw NoSuchElementException("No field with name '$name' found") + } + + override fun getFieldOrNull(name: String): ConfigurationElement? = + original.fields.find { it.name.name == name }?.value?.toConfigurationElement() + + override val asObject: ConfigurationObject + get() = this + override val asValue: ConfigurationValue + get() { + original.reportError(controller) { + message("Object is not a value") + highlight("unexpected object, expected value", original.location) + } + error("Object is not a value") + } + override val asList: ConfigurationList + get() { + original.reportError(controller) { + message("Object is not a list") + highlight("unexpected object, expected list", original.location) + } + error("Object is not a list") + } + } +} diff --git a/codegen/src/main/kotlin/tools/samt/codegen/http/HttpTransport.kt b/codegen/src/main/kotlin/tools/samt/codegen/http/HttpTransport.kt new file mode 100644 index 00000000..ad0d1181 --- /dev/null +++ b/codegen/src/main/kotlin/tools/samt/codegen/http/HttpTransport.kt @@ -0,0 +1,228 @@ +package tools.samt.codegen.http + +import tools.samt.api.plugin.TransportConfiguration +import tools.samt.api.plugin.TransportConfigurationParser +import tools.samt.api.plugin.TransportConfigurationParserParams +import tools.samt.api.plugin.asEnum + +object HttpTransportConfigurationParser : TransportConfigurationParser { + override val transportName: String + get() = "http" + + override fun default(): HttpTransportConfiguration = HttpTransportConfiguration( + serializationMode = HttpTransportConfiguration.SerializationMode.Json, + services = emptyList(), + ) + + private val isValidRegex = Regex("""\w+\s+\S+(\s+\{.*?\s+in\s+\S+})*""") + private val methodEndpointRegex = Regex("""(\w+)\s+(\S+)(.*)""") + private val parameterRegex = Regex("""\{(.*?)\s+in\s+(\S+)}""") + + override fun parse(params: TransportConfigurationParserParams): HttpTransportConfiguration { + val config = params.config + val serializationMode = + config.getFieldOrNull("serialization")?.asValue?.asEnum() + ?: HttpTransportConfiguration.SerializationMode.Json + + val services = config.getFieldOrNull("operations")?.asObject?.let { operations -> + + operations.asObject.fields.map { (operationsKey, operationsField) -> + val servicePath = operations.getFieldOrNull("basePath")?.asValue?.asString ?: "" + val service = operationsKey.asServiceName + val serviceName = service.name + val operationConfiguration = operationsField.asObject + + val parsedOperations = operationConfiguration.fields + .filterKeys { it.asIdentifier != "basePath" } + .mapNotNull { (key, value) -> + val operationConfig = value.asValue + val operation = key.asOperationName(service) + val operationName = operation.name + + if (!(operationConfig.asString matches isValidRegex)) { + params.reportError( + "Invalid operation config for '$operationName', expected ' '. A valid example: 'POST /${operationName} {parameter1, parameter2 in query}'", + operationConfig + ) + return@mapNotNull null + } + + val methodEndpointResult = methodEndpointRegex.matchEntire(operationConfig.asString) + if (methodEndpointResult == null) { + params.reportError( + "Invalid operation config for '$operationName', expected ' '", + operationConfig + ) + return@mapNotNull null + } + + val (method, path, parameterPart) = methodEndpointResult.destructured + + val methodEnum = when (method) { + "GET" -> HttpTransportConfiguration.HttpMethod.Get + "POST" -> HttpTransportConfiguration.HttpMethod.Post + "PUT" -> HttpTransportConfiguration.HttpMethod.Put + "DELETE" -> HttpTransportConfiguration.HttpMethod.Delete + "PATCH" -> HttpTransportConfiguration.HttpMethod.Patch + else -> { + params.reportError("Invalid http method '$method'", operationConfig) + return@mapNotNull null + } + } + + val parameters = mutableListOf() + + // parse path and path parameters + val pathComponents = path.split("/") + for (component in pathComponents) { + if (!component.startsWith("{") || !component.endsWith("}")) continue + + val pathParameterName = component.substring(1, component.length - 1) + + if (pathParameterName.isEmpty()) { + params.reportError( + "Expected parameter name between curly braces in '$path'", + operationConfig + ) + continue + } + + if (operation.parameters.none { it.name == pathParameterName }) { + params.reportError("Path parameter '$pathParameterName' not found in operation '$operationName'", operationConfig) + continue + } + + parameters += HttpTransportConfiguration.ParameterConfiguration( + name = pathParameterName, + transportMode = HttpTransportConfiguration.TransportMode.Path, + ) + } + + val parameterResults = parameterRegex.findAll(parameterPart) + // parse parameter declarations + for (parameterResult in parameterResults) { + val (names, type) = parameterResult.destructured + val transportMode = when (type) { + "query" -> HttpTransportConfiguration.TransportMode.Query + "header" -> HttpTransportConfiguration.TransportMode.Header + "body" -> HttpTransportConfiguration.TransportMode.Body + "cookie" -> HttpTransportConfiguration.TransportMode.Cookie + else -> { + params.reportError("Invalid transport mode '$type'", operationConfig) + continue + } + } + + for (name in names.split(",").map { it.trim() }) { + if (operation.parameters.none { it.name == name }) { + params.reportError("Parameter '$name' not found in operation '$operationName'", operationConfig) + continue + } + parameters += HttpTransportConfiguration.ParameterConfiguration( + name = name, + transportMode = transportMode, + ) + } + } + + HttpTransportConfiguration.OperationConfiguration( + name = operationName, + method = methodEnum, + path = path, + parameters = parameters, + ) + } + + HttpTransportConfiguration.ServiceConfiguration( + name = serviceName, + operations = parsedOperations, + path = servicePath + ) + } + } ?: emptyList() + + return HttpTransportConfiguration( + serializationMode = serializationMode, + services = services, + ) + } +} + +class HttpTransportConfiguration( + val serializationMode: SerializationMode, + val services: List, +) : TransportConfiguration { + class ServiceConfiguration( + val name: String, + val path: String, + val operations: List + ) { + fun getOperation(name: String): OperationConfiguration? { + return operations.firstOrNull { it.name == name } + } + } + + class OperationConfiguration( + val name: String, + val method: HttpMethod, + val path: String, + val parameters: List, + ) { + fun getParameter(name: String): ParameterConfiguration? { + return parameters.firstOrNull { it.name == name } + } + } + + class ParameterConfiguration( + val name: String, + val transportMode: TransportMode, + ) + + enum class SerializationMode { + Json, + } + + enum class TransportMode { + Body, // encoded in request body via serializationMode + Query, // encoded as url query parameter + Path, // encoded as part of url path + Header, // encoded as HTTP header + Cookie, // encoded as HTTP cookie + } + + enum class HttpMethod { + Get, + Post, + Put, + Delete, + Patch, + } + + private fun getService(name: String): ServiceConfiguration? { + return services.firstOrNull { it.name == name } + } + + fun getMethod(serviceName: String, operationName: String): HttpMethod { + val service = getService(serviceName) + val operation = service?.getOperation(operationName) + return operation?.method ?: HttpMethod.Post + } + + fun getPath(serviceName: String, operationName: String): String { + val service = getService(serviceName) + val operation = service?.getOperation(operationName) + return operation?.path ?: "/$operationName" + } + + fun getPath(serviceName: String): String { + val service = getService(serviceName) + return service?.path ?: "" + } + + fun getTransportMode(serviceName: String, operationName: String, parameterName: String): TransportMode { + val service = getService(serviceName) + val operation = service?.getOperation(operationName) + val parameter = operation?.getParameter(parameterName) + return parameter?.transportMode ?: TransportMode.Body + } +} diff --git a/codegen/src/main/kotlin/tools/samt/codegen/kotlin/KotlinGeneratorUtils.kt b/codegen/src/main/kotlin/tools/samt/codegen/kotlin/KotlinGeneratorUtils.kt new file mode 100644 index 00000000..94a6b72c --- /dev/null +++ b/codegen/src/main/kotlin/tools/samt/codegen/kotlin/KotlinGeneratorUtils.kt @@ -0,0 +1,71 @@ +package tools.samt.codegen.kotlin + +import tools.samt.api.types.* + +object KotlinGeneratorConfig { + const val removePrefixFromSamtPackage = "removePrefixFromSamtPackage" + const val addPrefixToKotlinPackage = "addPrefixToKotlinPackage" +} + +val GeneratedFilePreamble = """ + @file:Suppress("RemoveRedundantQualifierName", "unused", "UnusedImport", "LocalVariableName", "FunctionName", "ConvertTwoComparisonsToRangeCheck", "ReplaceSizeCheckWithIsNotEmpty", "NAME_SHADOWING", "UNUSED_VARIABLE", "NestedLambdaShadowedImplicitParameter", "KotlinRedundantDiagnosticSuppress") + + /* + * This file is generated by SAMT, manual changes will be overwritten. + * Visit the SAMT GitHub for more details: https://github.com/samtkit/core + */ +""".trimIndent() + +internal fun String.replacePackage(options: Map): String { + val removePrefix = options[KotlinGeneratorConfig.removePrefixFromSamtPackage] + val addPrefix = options[KotlinGeneratorConfig.addPrefixToKotlinPackage] + + var result = this + + if (removePrefix != null) { + result = result.removePrefix(removePrefix).removePrefix(".") + } + + if (addPrefix != null) { + result = "$addPrefix.$result" + } + + return result +} + +internal fun SamtPackage.getQualifiedName(options: Map): String = qualifiedName.replacePackage(options) + +internal fun TypeReference.getQualifiedName(options: Map): String { + val qualifiedName = type.getQualifiedName(options) + return if (isOptional) { + "$qualifiedName?" + } else { + qualifiedName + } +} + +internal fun Type.getQualifiedName(options: Map): String = when (this) { + is LiteralType -> when (this) { + is StringType -> "String" + is BytesType -> "ByteArray" + is IntType -> "Int" + is LongType -> "Long" + is FloatType -> "Float" + is DoubleType -> "Double" + is DecimalType -> "java.math.BigDecimal" + is BooleanType -> "Boolean" + is DateType -> "java.time.LocalDate" + is DateTimeType -> "java.time.LocalDateTime" + is DurationType -> "java.time.Duration" + else -> error("Unsupported literal type: ${this.javaClass.simpleName}") + } + + is ListType -> "List<${elementType.getQualifiedName(options)}>" + is MapType -> "Map<${keyType.getQualifiedName(options)}, ${valueType.getQualifiedName(options)}>" + + is UserType -> qualifiedName.replacePackage(options) + + else -> error("Unsupported type: ${javaClass.simpleName}") +} + +internal fun UserType.getTargetPackage(options: Map): String = qualifiedName.replacePackage(options).dropLastWhile { it != '.' } diff --git a/codegen/src/main/kotlin/tools/samt/codegen/kotlin/KotlinTypesGenerator.kt b/codegen/src/main/kotlin/tools/samt/codegen/kotlin/KotlinTypesGenerator.kt new file mode 100644 index 00000000..decd712b --- /dev/null +++ b/codegen/src/main/kotlin/tools/samt/codegen/kotlin/KotlinTypesGenerator.kt @@ -0,0 +1,139 @@ +package tools.samt.codegen.kotlin + +import tools.samt.api.plugin.CodegenFile +import tools.samt.api.plugin.Generator +import tools.samt.api.plugin.GeneratorParams +import tools.samt.api.types.* + +object KotlinTypesGenerator : Generator { + override val name: String = "kotlin-types" + override fun generate(generatorParams: GeneratorParams): List { + generatorParams.packages.forEach { + generatePackage(it, generatorParams.options) + } + val result = emittedFiles.toList() + emittedFiles.clear() + return result + } + + private val emittedFiles = mutableListOf() + + private fun generatePackage(pack: SamtPackage, options: Map) { + if (pack.hasModelTypes()) { + val packageSource = buildString { + appendLine(GeneratedFilePreamble) + appendLine() + appendLine("package ${pack.getQualifiedName(options)}") + appendLine() + + pack.records.forEach { + appendRecord(it, options) + } + + pack.enums.forEach { + appendEnum(it) + } + + pack.aliases.forEach { + appendAlias(it, options) + } + + pack.services.forEach { + appendService(it, options) + } + } + + val filePath = "${pack.getQualifiedName(options).replace('.', '/')}/Types.kt" + val file = CodegenFile(filePath, packageSource) + emittedFiles.add(file) + } + } + + private fun StringBuilder.appendRecord(record: RecordType, options: Map) { + if (record.fields.isEmpty()) { + appendLine("class ${record.name}") + appendLine() + return + } + + appendLine("data class ${record.name}(") + record.fields.forEach { field -> + val fullyQualifiedName = field.type.getQualifiedName(options) + + if (field.type.isRuntimeOptional) { + appendLine(" val ${field.name}: $fullyQualifiedName = null,") + } else { + appendLine(" val ${field.name}: $fullyQualifiedName,") + } + } + appendLine(")") + appendLine() + } + + private fun StringBuilder.appendEnum(enum: EnumType) { + appendLine("enum class ${enum.name} {") + appendLine(" /** Default value used when the enum could not be parsed */") + appendLine(" FAILED_TO_PARSE,") + enum.values.forEach { + appendLine(" ${it},") + } + appendLine("}") + } + + private fun StringBuilder.appendAlias(alias: AliasType, options: Map) { + appendLine("typealias ${alias.name} = ${alias.aliasedType.getQualifiedName(options)}") + } + + private fun StringBuilder.appendService(service: ServiceType, options: Map) { + appendLine("interface ${service.name} {") + service.operations.forEach { operation -> + appendServiceOperation(operation, options) + } + appendLine("}") + } + + private fun StringBuilder.appendServiceOperation(operation: ServiceOperation, options: Map) { + when (operation) { + is RequestResponseOperation -> { + // method head + if (operation.isAsync) { + appendLine(" suspend fun ${operation.name}(") + } else { + appendLine(" fun ${operation.name}(") + } + + // parameters + appendServiceOperationParameterList(operation.parameters, options) + + // return type + if (operation.returnType != null) { + appendLine(" ): ${operation.returnType!!.getQualifiedName(options)}") + } else { + appendLine(" )") + } + } + + is OnewayOperation -> { + appendLine(" fun ${operation.name}(") + appendServiceOperationParameterList(operation.parameters, options) + appendLine(" )") + } + } + } + + private fun StringBuilder.appendServiceOperationParameterList(parameters: List, options: Map) { + parameters.forEach { parameter -> + val fullyQualifiedName = parameter.type.getQualifiedName(options) + + if (parameter.type.isRuntimeOptional) { + appendLine(" ${parameter.name}: $fullyQualifiedName = null,") + } else { + appendLine(" ${parameter.name}: $fullyQualifiedName,") + } + } + } + + private fun SamtPackage.hasModelTypes(): Boolean { + return records.isNotEmpty() || enums.isNotEmpty() || services.isNotEmpty() || aliases.isNotEmpty() + } +} diff --git a/codegen/src/main/kotlin/tools/samt/codegen/kotlin/ktor/KotlinKtorConsumerGenerator.kt b/codegen/src/main/kotlin/tools/samt/codegen/kotlin/ktor/KotlinKtorConsumerGenerator.kt new file mode 100644 index 00000000..1b20f0a8 --- /dev/null +++ b/codegen/src/main/kotlin/tools/samt/codegen/kotlin/ktor/KotlinKtorConsumerGenerator.kt @@ -0,0 +1,254 @@ +package tools.samt.codegen.kotlin.ktor + +import tools.samt.api.plugin.CodegenFile +import tools.samt.api.plugin.Generator +import tools.samt.api.plugin.GeneratorParams +import tools.samt.api.types.* +import tools.samt.codegen.http.HttpTransportConfiguration +import tools.samt.codegen.kotlin.GeneratedFilePreamble +import tools.samt.codegen.kotlin.KotlinTypesGenerator +import tools.samt.codegen.kotlin.getQualifiedName + +object KotlinKtorConsumerGenerator : Generator { + override val name: String = "kotlin-ktor-consumer" + + override fun generate(generatorParams: GeneratorParams): List { + generatorParams.packages.forEach { + generateMappings(it, generatorParams.options) + generatePackage(it, generatorParams.options) + } + val result = KotlinTypesGenerator.generate(generatorParams) + emittedFiles + emittedFiles.clear() + return result + } + + private val emittedFiles = mutableListOf() + + private fun generateMappings(pack: SamtPackage, options: Map) { + val packageSource = mappingFileContent(pack, options) + if (packageSource.isNotEmpty()) { + val filePath = "${pack.getQualifiedName(options).replace('.', '/')}/KtorMappings.kt" + val file = CodegenFile(filePath, packageSource) + emittedFiles.add(file) + } + } + + private fun generatePackage(pack: SamtPackage, options: Map) { + val relevantConsumers = pack.consumers.filter { it.provider.transport is HttpTransportConfiguration } + if (relevantConsumers.isNotEmpty()) { + // generate ktor consumers + relevantConsumers.forEach { consumer -> + val transportConfiguration = consumer.provider.transport as HttpTransportConfiguration + + val packageSource = buildString { + appendLine(GeneratedFilePreamble) + appendLine() + appendLine("package ${pack.getQualifiedName(options)}") + appendLine() + + appendConsumer(consumer, transportConfiguration, options) + } + + val filePath = "${pack.getQualifiedName(options).replace('.', '/')}/Consumer.kt" + val file = CodegenFile(filePath, packageSource) + emittedFiles.add(file) + } + } + } + + data class ConsumerInfo(val consumer: ConsumerType, val uses: ConsumedService) { + val service = uses.service + val consumedOperations = uses.consumedOperations + val unconsumedOperations = uses.unconsumedOperations + } + + private fun StringBuilder.appendConsumer(consumer: ConsumerType, transportConfiguration: HttpTransportConfiguration, options: Map) { + appendLine("import io.ktor.client.*") + appendLine("import io.ktor.client.engine.cio.*") + appendLine("import io.ktor.client.plugins.contentnegotiation.*") + appendLine("import io.ktor.client.request.*") + appendLine("import io.ktor.client.statement.*") + appendLine("import io.ktor.http.*") + appendLine("import io.ktor.serialization.kotlinx.json.*") + appendLine("import io.ktor.util.*") + appendLine("import kotlinx.coroutines.runBlocking") + appendLine("import kotlinx.serialization.json.*") + appendLine("import kotlinx.coroutines.*") + appendLine() + + val implementedServices = consumer.uses.map { ConsumerInfo(consumer, it) } + appendLine("class ${consumer.className}(private val baseUrl: String) : ${implementedServices.joinToString { it.service.getQualifiedName(options) }} {") + implementedServices.forEach { info -> + appendConsumerOperations(info, transportConfiguration, options) + } + appendLine("}") + } + + private fun StringBuilder.appendConsumerOperations(info: ConsumerInfo, transportConfiguration: HttpTransportConfiguration, options: Map) { + appendLine(" private val client = HttpClient(CIO) {") + appendLine(" install(ContentNegotiation) {") + appendLine(" json()") + appendLine(" }") + appendLine(" }") + appendLine() + appendLine(" /** Used to launch oneway operations asynchronously */") + appendLine(" private val onewayScope = CoroutineScope(Dispatchers.IO)") + appendLine() + + info.consumedOperations.forEach { operation -> + val operationParameters = operation.parameters.joinToString { "${it.name}: ${it.type.getQualifiedName(options)}" } + + when (operation) { + is RequestResponseOperation -> { + if (operation.isAsync) { + appendLine(" override suspend fun ${operation.name}($operationParameters): ${operation.returnType?.getQualifiedName(options) ?: "Unit"} = run {") + } else { + appendLine(" override fun ${operation.name}($operationParameters): ${operation.returnType?.getQualifiedName(options) ?: "Unit"} = runBlocking {") + } + + appendConsumerServiceCall(info, operation, transportConfiguration, options) + appendCheckResponseStatus(operation) + appendConsumerResponseParsing(operation, options) + + appendLine(" }") + } + + is OnewayOperation -> { + appendLine(" override fun ${operation.name}($operationParameters): Unit {") + appendLine(" onewayScope.launch {") + + appendConsumerServiceCall(info, operation, transportConfiguration, options) + appendCheckResponseStatus(operation) + + appendLine(" }") + appendLine(" }") + } + } + appendLine() + } + + info.unconsumedOperations.forEach { operation -> + val operationParameters = operation.parameters.joinToString { "${it.name}: ${it.type.getQualifiedName(options)}" } + + when (operation) { + is RequestResponseOperation -> { + if (operation.isAsync) { + appendLine(" override suspend fun ${operation.name}($operationParameters): ${operation.returnType?.getQualifiedName(options) ?: "Unit"}") + } else { + appendLine(" override fun ${operation.name}($operationParameters): ${operation.returnType?.getQualifiedName(options) ?: "Unit"}") + } + } + + is OnewayOperation -> { + appendLine(" override fun ${operation.name}($operationParameters): Unit") + } + } + appendLine(" = error(\"Not used in SAMT consumer and therefore not generated\")") + } + } + + private fun StringBuilder.appendConsumerServiceCall(info: ConsumerInfo, operation: ServiceOperation, transport: HttpTransportConfiguration, options: Map) { + // collect parameters for each transport type + val headerParameters = mutableMapOf() + val cookieParameters = mutableMapOf() + val bodyParameters = mutableMapOf() + val pathParameters = mutableMapOf() + val queryParameters = mutableMapOf() + operation.parameters.forEach { + val name = it.name + when (transport.getTransportMode(info.service.name, operation.name, name)) { + HttpTransportConfiguration.TransportMode.Header -> { + headerParameters[name] = it + } + HttpTransportConfiguration.TransportMode.Cookie -> { + cookieParameters[name] = it + } + HttpTransportConfiguration.TransportMode.Body -> { + bodyParameters[name] = it + } + HttpTransportConfiguration.TransportMode.Path -> { + pathParameters[name] = it + } + HttpTransportConfiguration.TransportMode.Query -> { + queryParameters[name] = it + } + } + } + + // build request headers and body + appendLine(" // Make actual network call") + appendLine(" val `client response` = client.request(this@${info.consumer.className}.baseUrl) {") + + // build request path + // need to split transport path into path segments and query parameter slots + // remove first empty component (paths start with a / so the first component is always empty) + val transportPath = transport.getPath(info.service.name, operation.name) + val transportPathComponents = transportPath.split("/") + appendLine(" url {") + appendLine(" // Construct path and encode path parameters") + transportPathComponents.drop(1).forEach { + if (it.startsWith("{") && it.endsWith("}")) { + val parameterName = it.substring(1, it.length - 1) + require(parameterName in pathParameters) { "${operation.name}: path parameter $parameterName is not a known path parameter" } + appendLine(" appendPathSegments($parameterName, encodeSlash = true)") + } else { + appendLine(" appendPathSegments(\"$it\", encodeSlash = true)") + } + } + appendLine() + + appendLine(" // Encode query parameters") + queryParameters.forEach { (name, queryParameter) -> + appendLine(" this.parameters.append(\"$name\", (${encodeJsonElement(queryParameter.type, options, valueName = name)}).toString())") + } + appendLine(" }") + + // serialization mode + when (transport.serializationMode) { + HttpTransportConfiguration.SerializationMode.Json -> appendLine(" contentType(ContentType.Application.Json)") + } + + // transport method + val transportMethod = transport.getMethod(info.service.name, operation.name) + appendLine(" this.method = HttpMethod.$transportMethod") + + // header parameters + headerParameters.forEach { (name, headerParameter) -> + appendLine(" header(\"${name}\", ${encodeJsonElement(headerParameter.type, options, valueName = name)})") + } + + // cookie parameters + cookieParameters.forEach { (name, cookieParameter) -> + appendLine(" cookie(\"${name}\", (${encodeJsonElement(cookieParameter.type, options, valueName = name)}).toString())") + } + + // body parameters + appendLine(" setBody(") + appendLine(" buildJsonObject {") + bodyParameters.forEach { (name, bodyParameter) -> + appendLine(" put(\"$name\", ${encodeJsonElement(bodyParameter.type, options, valueName = name)})") + } + appendLine(" }") + appendLine(" )") + + appendLine(" }") + } + + private fun StringBuilder.appendCheckResponseStatus(operation: ServiceOperation) { + appendLine(" check(`client response`.status.isSuccess()) { \"${operation.name} failed with status \${`client response`.status}\" }") + } + + private fun StringBuilder.appendConsumerResponseParsing( + operation: RequestResponseOperation, + options: Map + ) { + operation.returnType?.let { returnType -> + appendLine(" val bodyAsText = `client response`.bodyAsText()") + appendLine(" val jsonElement = Json.parseToJsonElement(bodyAsText)") + appendLine() + appendLine(" ${decodeJsonElement(returnType, options)}") + } + } + + private val ConsumerType.className get() = "${provider.name}Impl" +} diff --git a/codegen/src/main/kotlin/tools/samt/codegen/kotlin/ktor/KotlinKtorGeneratorUtilities.kt b/codegen/src/main/kotlin/tools/samt/codegen/kotlin/ktor/KotlinKtorGeneratorUtilities.kt new file mode 100644 index 00000000..e9495b3f --- /dev/null +++ b/codegen/src/main/kotlin/tools/samt/codegen/kotlin/ktor/KotlinKtorGeneratorUtilities.kt @@ -0,0 +1,249 @@ +package tools.samt.codegen.kotlin.ktor + +import tools.samt.api.types.* +import tools.samt.codegen.kotlin.GeneratedFilePreamble +import tools.samt.codegen.kotlin.getQualifiedName +import tools.samt.codegen.kotlin.getTargetPackage + +fun mappingFileContent(pack: SamtPackage, options: Map) = buildString { + if (pack.records.isNotEmpty() || pack.enums.isNotEmpty() || pack.aliases.isNotEmpty()) { + appendLine(GeneratedFilePreamble) + appendLine() + appendLine("package ${pack.getQualifiedName(options)}") + appendLine() + appendLine("import io.ktor.util.*") + appendLine("import kotlinx.serialization.json.*") + appendLine() + + pack.records.forEach { record -> + appendEncodeRecord(record, options) + appendDecodeRecord(record, options) + appendLine() + } + + pack.enums.forEach { enum -> + appendEncodeEnum(enum, options) + appendDecodeEnum(enum, options) + appendLine() + } + + pack.aliases.forEach { alias -> + appendEncodeAlias(alias, options) + appendDecodeAlias(alias, options) + appendLine() + } + } +} + +private fun StringBuilder.appendEncodeRecord( + record: RecordType, + options: Map, +) { + appendLine("/** Encode and validate record ${record.qualifiedName} to JSON */") + appendLine("fun `encode ${record.name}`(record: ${record.getQualifiedName(options)}): JsonElement {") + for (field in record.fields) { + appendEncodeRecordField(field, options) + } + appendLine(" // Create JSON for ${record.qualifiedName}") + appendLine(" return buildJsonObject {") + for (field in record.fields) { + appendLine(" put(\"${field.name}\", `field ${field.name}`)") + } + appendLine(" }") + appendLine("}") +} + +private fun StringBuilder.appendDecodeRecord( + record: RecordType, + options: Map, +) { + appendLine("/** Decode and validate record ${record.qualifiedName} from JSON */") + appendLine("fun `decode ${record.name}`(json: JsonElement): ${record.getQualifiedName(options)} {") + for (field in record.fields) { + appendDecodeRecordField(field, options) + } + appendLine(" // Create record ${record.qualifiedName}") + appendLine(" return ${record.getQualifiedName(options)}(") + for (field in record.fields) { + appendLine(" ${field.name} = `field ${field.name}`,") + } + appendLine(" )") + appendLine("}") +} + +private fun StringBuilder.appendEncodeEnum(enum: EnumType, options: Map) { + val enumName = enum.getQualifiedName(options) + appendLine("/** Encode enum ${enum.qualifiedName} to JSON */") + appendLine("fun `encode ${enum.name}`(value: ${enumName}?): JsonElement = when(value) {") + appendLine(" null -> JsonNull") + enum.values.forEach { value -> + appendLine(" ${enumName}.${value} -> JsonPrimitive(\"${value}\")") + } + appendLine(" ${enumName}.FAILED_TO_PARSE -> error(\"Cannot encode FAILED_TO_PARSE value\")") + appendLine("}") +} + +private fun StringBuilder.appendDecodeEnum(enum: EnumType, options: Map) { + val enumName = enum.getQualifiedName(options) + appendLine("/** Decode enum ${enum.qualifiedName} from JSON */") + appendLine("fun `decode ${enum.name}`(json: JsonElement): $enumName = when(json.jsonPrimitive.content) {") + enum.values.forEach { value -> + appendLine(" \"${value}\" -> ${enumName}.${value}") + } + appendLine(" // Value not found in enum ${enum.qualifiedName}") + appendLine(" else -> ${enumName}.FAILED_TO_PARSE") + appendLine("}") +} + +private fun StringBuilder.appendEncodeRecordField(field: RecordField, options: Map) { + appendLine(" // Encode field ${field.name}") + appendLine(" val `field ${field.name}` = run {") + append(" val value = record.${field.name}") + appendLine() + appendLine(" ${encodeJsonElement(field.type, options)}") + appendLine(" }") +} + +private fun StringBuilder.appendDecodeRecordField(field: RecordField, options: Map) { + appendLine(" // Decode field ${field.name}") + appendLine(" val `field ${field.name}` = run {") + append(" val jsonElement = ") + if (field.type.isRuntimeOptional) { + append("json.jsonObject[\"${field.name}\"] ?: JsonNull") + } else { + append("json.jsonObject[\"${field.name}\"]!!") + } + appendLine() + appendLine(" ${decodeJsonElement(field.type, options)}") + appendLine(" }") +} + +private fun StringBuilder.appendEncodeAlias(alias: AliasType, options: Map) { + appendLine("/** Encode alias ${alias.qualifiedName} to JSON */") + appendLine("fun `encode ${alias.name}`(value: ${alias.getQualifiedName(options)}): JsonElement =") + appendLine(" ${encodeJsonElement(alias.fullyResolvedType, options, valueName = "value")}") +} + +private fun StringBuilder.appendDecodeAlias(alias: AliasType, options: Map) { + appendLine("/** Decode alias ${alias.qualifiedName} from JSON */") + appendLine("fun `decode ${alias.name}`(json: JsonElement): ${alias.fullyResolvedType.getQualifiedName(options)} {") + if (alias.fullyResolvedType.isRuntimeOptional) { + appendLine(" if (json is JsonNull) return null") + } + appendLine(" return ${decodeJsonElement(alias.fullyResolvedType, options, valueName = "json")}") + appendLine("}") +} + +/** + * Encode a [typeReference] to a JSON element. + * The resulting expression will always be a JsonElement. + */ +fun encodeJsonElement(typeReference: TypeReference, options: Map, valueName: String = "value"): String { + val convertExpression = when (val type = typeReference.type) { + is LiteralType -> { + val getContent = when (type) { + is StringType, + is IntType, + is LongType, + is FloatType, + is DoubleType, + is BooleanType -> valueName + is BytesType -> "${valueName}.encodeBase64()" + is DecimalType -> "${valueName}.toPlainString()" + is DateType, + is DateTimeType, + is DurationType -> "${valueName}.toString()" + else -> error("Unsupported literal type: ${type.javaClass.simpleName}") + } + "JsonPrimitive($getContent${validateLiteralConstraintsSuffix(typeReference)})" + } + + is ListType -> "JsonArray(${valueName}.map { ${encodeJsonElement(type.elementType, options, valueName = "it")} })" + is MapType -> "JsonObject(${valueName}.mapValues { (_, value) -> ${encodeJsonElement(type.valueType, options, valueName = "value")} })" + + is UserType -> "${type.getTargetPackage(options)}`encode ${type.name}`(${valueName})" + + else -> error("Unsupported type: ${type.javaClass.simpleName}") + } + + return if (typeReference.isRuntimeOptional) { + "$valueName?.let { $valueName -> $convertExpression } ?: JsonNull" + } else { + convertExpression + } +} + +/** + * Decode a [typeReference] from a JSON element. + * The resulting expression will always be a value of the type. + */ +fun decodeJsonElement(typeReference: TypeReference, options: Map, valueName: String = "jsonElement"): String = + when (val type = typeReference.type) { + is LiteralType -> when (type) { + is StringType -> "${valueName}.jsonPrimitive.content" + is BytesType -> "${valueName}.jsonPrimitive.content.decodeBase64Bytes()" + is IntType -> "${valueName}.jsonPrimitive.int" + is LongType -> "${valueName}.jsonPrimitive.long" + is FloatType -> "${valueName}.jsonPrimitive.float" + is DoubleType -> "${valueName}.jsonPrimitive.double" + is DecimalType -> "${valueName}.jsonPrimitive.content.let { java.math.BigDecimal(it) }" + is BooleanType -> "${valueName}.jsonPrimitive.boolean" + is DateType -> "${valueName}.jsonPrimitive.content.let { java.time.LocalDate.parse(it) }" + is DateTimeType -> "${valueName}.jsonPrimitive.content.let { java.time.LocalDateTime.parse(it) }" + is DurationType -> "${valueName}.jsonPrimitive.content.let { java.time.Duration.parse(it) }" + else -> error("Unsupported literal type: ${type.javaClass.simpleName}") + } + validateLiteralConstraintsSuffix(typeReference) + + is ListType -> { + val elementDecodeStatement = decodeJsonElement(type.elementType, options, valueName = "it") + if (type.elementType.isRuntimeOptional) + "${valueName}.jsonArray.map { it.takeUnless { it is JsonNull }?.let { $elementDecodeStatement } }" + else + "${valueName}.jsonArray.map { $elementDecodeStatement }" + } + is MapType -> { + val valueDecodeStatement = decodeJsonElement(type.valueType, options, valueName = "value") + if (type.valueType.isRuntimeOptional) + "${valueName}.jsonObject.mapValues { (_, value) -> value.takeUnless { it is JsonNull }?.let { value -> $valueDecodeStatement } }" + else + "${valueName}.jsonObject.mapValues { (_, value) -> $valueDecodeStatement }" + } + + is UserType -> "${type.getTargetPackage(options)}`decode ${type.name}`(${valueName})" + + else -> error("Unsupported type: ${type.javaClass.simpleName}") + } + +private fun validateLiteralConstraintsSuffix(typeReference: TypeReference): String { + val conditions = buildList { + typeReference.rangeConstraint?.let { constraint -> + constraint.lowerBound?.let { + add("it >= ${constraint.lowerBound}") + } + constraint.upperBound?.let { + add("it <= ${constraint.upperBound}") + } + } + typeReference.sizeConstraint?.let { constraint -> + val property = if (typeReference.type is StringType) "length" else "size" + constraint.lowerBound?.let { + add("it.${property} >= ${constraint.lowerBound}") + } + constraint.upperBound?.let { + add("it.${property} <= ${constraint.upperBound}") + } + } + typeReference.patternConstraint?.let { constraint -> + add("it.matches(Regex(\"${constraint.pattern}\"))") + } + typeReference.valueConstraint?.let { constraint -> + add("it == ${constraint.value})") + } + } + + if (conditions.isEmpty()) { + return "" + } + + return ".also { require(${conditions.joinToString(" && ")}) }" +} diff --git a/codegen/src/main/kotlin/tools/samt/codegen/kotlin/ktor/KotlinKtorProviderGenerator.kt b/codegen/src/main/kotlin/tools/samt/codegen/kotlin/ktor/KotlinKtorProviderGenerator.kt new file mode 100644 index 00000000..19c1da95 --- /dev/null +++ b/codegen/src/main/kotlin/tools/samt/codegen/kotlin/ktor/KotlinKtorProviderGenerator.kt @@ -0,0 +1,311 @@ +package tools.samt.codegen.kotlin.ktor + +import tools.samt.api.plugin.CodegenFile +import tools.samt.api.plugin.Generator +import tools.samt.api.plugin.GeneratorParams +import tools.samt.api.types.* +import tools.samt.codegen.http.HttpTransportConfiguration +import tools.samt.codegen.kotlin.GeneratedFilePreamble +import tools.samt.codegen.kotlin.KotlinTypesGenerator +import tools.samt.codegen.kotlin.getQualifiedName + +object KotlinKtorProviderGenerator : Generator { + override val name: String = "kotlin-ktor-provider" + private const val skipKtorServer = "skipKtorServer" + + override fun generate(generatorParams: GeneratorParams): List { + generatorParams.packages.forEach { + generateMappings(it, generatorParams.options) + generatePackage(it, generatorParams.options) + } + val result = KotlinTypesGenerator.generate(generatorParams) + emittedFiles + emittedFiles.clear() + return result + } + + private val emittedFiles = mutableListOf() + + private fun generateMappings(pack: SamtPackage, options: Map) { + val packageSource = mappingFileContent(pack, options) + if (packageSource.isNotEmpty()) { + val filePath = "${pack.getQualifiedName(options).replace('.', '/')}/KtorMappings.kt" + val file = CodegenFile(filePath, packageSource) + emittedFiles.add(file) + } + } + + private fun generatePackage(pack: SamtPackage, options: Map) { + val relevantProviders = pack.providers.filter { it.transport is HttpTransportConfiguration } + if (relevantProviders.isNotEmpty()) { + if (options[skipKtorServer] != "true") { + // generate general ktor files + generateKtorServer(pack, options) + } + + // generate ktor providers + relevantProviders.forEach { provider -> + val transportConfiguration = provider.transport + check(transportConfiguration is HttpTransportConfiguration) + + val packageSource = buildString { + appendLine(GeneratedFilePreamble) + appendLine() + appendLine("package ${pack.getQualifiedName(options)}") + appendLine() + + appendProvider(provider, transportConfiguration, options) + } + + val filePath = "${pack.getQualifiedName(options).replace('.', '/')}/${provider.name}.kt" + val file = CodegenFile(filePath, packageSource) + emittedFiles.add(file) + } + } + } + + private fun generateKtorServer(pack: SamtPackage, options: Map) { + val packageSource = buildString { + appendLine(GeneratedFilePreamble) + appendLine() + appendLine("package ${pack.getQualifiedName(options)}") + appendLine() + + appendLine("import io.ktor.http.*") + appendLine("import io.ktor.serialization.kotlinx.json.*") + appendLine("import io.ktor.server.plugins.contentnegotiation.*") + appendLine("import io.ktor.server.response.*") + appendLine("import io.ktor.server.application.*") + appendLine("import io.ktor.server.request.*") + appendLine("import io.ktor.server.routing.*") + appendLine("import kotlinx.serialization.json.*") + appendLine() + + appendLine("fun Application.configureSerialization() {") + appendLine(" install(ContentNegotiation) {") + appendLine(" json()") + appendLine(" }") + appendLine(" routing {") + + for (provider in pack.providers) { + val implementedServices = provider.implements.map { ProviderInfo(it) } + appendLine(" route${provider.name}(") + for (info in implementedServices) { + provider.implements.joinToString(" */, /* ") { it.service.getQualifiedName(options) } + appendLine( + " ${info.serviceArgumentName} = TODO(\"Implement ${ + info.service.getQualifiedName( + options + ) + }\")," + ) + } + appendLine(" )") + } + + appendLine(" }") + appendLine("}") + } + + val filePath = "${pack.getQualifiedName(options).replace('.', '/')}/KtorServer.kt" + val file = CodegenFile(filePath, packageSource) + emittedFiles.add(file) + } + + data class ProviderInfo(val implements: ProvidedService) { + val service = implements.service + val serviceArgumentName = implements.service.name.replaceFirstChar { it.lowercase() } + } + + private fun StringBuilder.appendProvider( + provider: ProviderType, + transportConfiguration: HttpTransportConfiguration, + options: Map, + ) { + appendLine("import io.ktor.http.*") + appendLine("import io.ktor.serialization.kotlinx.json.*") + appendLine("import io.ktor.server.application.*") + appendLine("import io.ktor.server.plugins.contentnegotiation.*") + appendLine("import io.ktor.server.request.*") + appendLine("import io.ktor.server.response.*") + appendLine("import io.ktor.server.routing.*") + appendLine("import io.ktor.util.*") + appendLine("import kotlinx.serialization.json.*") + appendLine() + + val implementedServices = provider.implements.map { ProviderInfo(it) } + appendLine("/** Connector for SAMT provider ${provider.name} */") + appendLine("fun Routing.route${provider.name}(") + for (info in implementedServices) { + appendLine(" ${info.serviceArgumentName}: ${info.service.getQualifiedName(options)},") + } + appendLine(") {") + appendUtilities() + implementedServices.forEach { info -> + appendProviderOperations(info, transportConfiguration, options) + } + appendLine("}") + } + + private fun StringBuilder.appendUtilities() { + appendLine(" /** Utility used to convert string to JSON element */") + appendLine(" fun String.toJson() = Json.parseToJsonElement(this)") + appendLine(" /** Utility used to convert string to JSON element or null */") + appendLine(" fun String.toJsonOrNull() = Json.parseToJsonElement(this).takeUnless { it is JsonNull }") + appendLine() + } + + private fun StringBuilder.appendProviderOperations( + info: ProviderInfo, + transportConfiguration: HttpTransportConfiguration, + options: Map, + ) { + val service = info.service + appendLine(" // Handler for SAMT Service ${info.service.name}") + appendLine(" route(\"${transportConfiguration.getPath(service.name)}\") {") + info.implements.implementedOperations.forEach { operation -> + appendProviderOperation(operation, info, service, transportConfiguration, options) + } + appendLine(" }") + appendLine() + } + + private fun StringBuilder.appendProviderOperation( + operation: ServiceOperation, + info: ProviderInfo, + service: ServiceType, + transportConfiguration: HttpTransportConfiguration, + options: Map, + ) { + when (operation) { + is RequestResponseOperation -> { + appendLine(" // Handler for SAMT operation ${operation.name}") + appendLine(" ${getKtorRoute(service, operation, transportConfiguration)} {") + + appendParsingPreamble() + + operation.parameters.forEach { parameter -> + appendParameterDecoding(service, operation, parameter, transportConfiguration, options) + } + + appendLine(" // Call user provided implementation") + val returnType = operation.returnType + if (returnType != null) { + appendLine(" val value = ${getServiceCall(info, operation)}") + appendLine() + appendLine(" // Encode response") + appendLine(" val response = ${encodeJsonElement(returnType, options)}") + appendLine() + appendLine(" // Return response with 200 OK") + appendLine(" call.respond(HttpStatusCode.OK, response)") + } else { + appendLine(" ${getServiceCall(info, operation)}") + appendLine() + appendLine(" // Return 204 No Content") + appendLine(" call.respond(HttpStatusCode.NoContent)") + } + + appendLine(" }") + appendLine() + } + is OnewayOperation -> { + appendLine(" // Handler for SAMT oneway operation ${operation.name}") + appendLine(" ${getKtorRoute(service, operation, transportConfiguration)} {") + + appendParsingPreamble() + + operation.parameters.forEach { parameter -> + appendParameterDecoding(service, operation, parameter, transportConfiguration, options) + } + + appendLine(" // Use launch to handle the request asynchronously, not waiting for the response") + appendLine(" launch {") + appendLine(" // Call user provided implementation") + appendLine(" ${getServiceCall(info, operation)}") + appendLine(" }") + appendLine() + + appendLine(" // Oneway operation always returns 204 No Content") + appendLine(" call.respond(HttpStatusCode.NoContent)") + appendLine(" }") + } + } + } + + private fun StringBuilder.appendParsingPreamble() { + appendLine(" // Parse body lazily in case no parameter is transported in the body") + appendLine(" val bodyAsText = call.receiveText()") + appendLine(" val body by lazy { bodyAsText.toJson() }") + appendLine() + } + + private fun getKtorRoute( + service: ServiceType, + operation: ServiceOperation, + transportConfiguration: HttpTransportConfiguration, + ): String { + val method = when (transportConfiguration.getMethod(service.name, operation.name)) { + HttpTransportConfiguration.HttpMethod.Get -> "get" + HttpTransportConfiguration.HttpMethod.Post -> "post" + HttpTransportConfiguration.HttpMethod.Put -> "put" + HttpTransportConfiguration.HttpMethod.Delete -> "delete" + HttpTransportConfiguration.HttpMethod.Patch -> "patch" + } + val path = transportConfiguration.getPath(service.name, operation.name) + return "${method}(\"${path}\")" + } + + private fun getServiceCall(info: ProviderInfo, operation: ServiceOperation): String { + return "${info.serviceArgumentName}.${operation.name}(${operation.parameters.joinToString { "`parameter ${it.name}`" }})" + } + + private fun StringBuilder.appendParameterDecoding( + service: ServiceType, + operation: ServiceOperation, + parameter: ServiceOperationParameter, + transportConfiguration: HttpTransportConfiguration, + options: Map, + ) { + appendLine(" // Decode parameter ${parameter.name}") + appendLine(" val `parameter ${parameter.name}` = run {") + val transportMode = transportConfiguration.getTransportMode(service.name, operation.name, parameter.name) + appendParameterDeserialization(parameter, transportMode, options) + appendLine(" }") + appendLine() + } + + private fun StringBuilder.appendParameterDeserialization( + parameter: ServiceOperationParameter, + transportMode: HttpTransportConfiguration.TransportMode, + options: Map, + ) { + appendReadParameterJsonElement(parameter, transportMode) + appendLine(" ${decodeJsonElement(parameter.type, options)}") + } + + private fun StringBuilder.appendReadParameterJsonElement( + parameter: ServiceOperationParameter, + transportMode: HttpTransportConfiguration.TransportMode, + ) { + appendLine(" // Read from ${transportMode.name.lowercase()}") + append(" val jsonElement = ") + if (parameter.type.isRuntimeOptional) { + when (transportMode) { + HttpTransportConfiguration.TransportMode.Body -> append("body.jsonObject[\"${parameter.name}\"]?.takeUnless { it is JsonNull }") + HttpTransportConfiguration.TransportMode.Query -> append("call.request.queryParameters[\"${parameter.name}\"]?.toJsonOrNull()") + HttpTransportConfiguration.TransportMode.Path -> append("call.parameters[\"${parameter.name}\"]?.toJsonOrNull()") + HttpTransportConfiguration.TransportMode.Header -> append("call.request.headers[\"${parameter.name}\"]?.toJsonOrNull()") + HttpTransportConfiguration.TransportMode.Cookie -> append("call.request.cookies[\"${parameter.name}\"]?.toJsonOrNull()") + } + append(" ?: return@run null") + } else { + when (transportMode) { + HttpTransportConfiguration.TransportMode.Body -> append("body.jsonObject[\"${parameter.name}\"]!!") + HttpTransportConfiguration.TransportMode.Query -> append("call.request.queryParameters[\"${parameter.name}\"]!!.toJson()") + HttpTransportConfiguration.TransportMode.Path -> append("call.parameters[\"${parameter.name}\"]!!.toJson()") + HttpTransportConfiguration.TransportMode.Header -> append("call.request.headers[\"${parameter.name}\"]!!.toJson()") + HttpTransportConfiguration.TransportMode.Cookie -> append("call.request.cookies[\"${parameter.name}\"]!!.toJson()") + } + } + appendLine() + } +} diff --git a/codegen/src/test/kotlin/tools/samt/codegen/CodegenTest.kt b/codegen/src/test/kotlin/tools/samt/codegen/CodegenTest.kt new file mode 100644 index 00000000..1d76b6ff --- /dev/null +++ b/codegen/src/test/kotlin/tools/samt/codegen/CodegenTest.kt @@ -0,0 +1,58 @@ +package tools.samt.codegen + +import tools.samt.api.plugin.CodegenFile +import tools.samt.common.DiagnosticController +import tools.samt.common.collectSamtFiles +import tools.samt.common.readSamtSource +import tools.samt.config.SamtConfigurationParser +import tools.samt.lexer.Lexer +import tools.samt.parser.Parser +import tools.samt.semantic.SemanticModel +import java.net.URI +import kotlin.io.path.Path +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse + +class CodegenTest { + private val testDirectory = Path("src/test/resources/generator-test-model") + + @Test + fun `correctly compiles test model`() { + val controller = DiagnosticController(URI("file:///tmp")) + + val configuration = SamtConfigurationParser.parseConfiguration(testDirectory.resolve("samt.yaml")) + val sourceFiles = collectSamtFiles(configuration.source.toUri()).readSamtSource(controller) + + assertFalse(controller.hasErrors()) + + // attempt to parse each source file into an AST + val fileNodes = buildList { + for (source in sourceFiles) { + val context = controller.getOrCreateContext(source) + val tokenStream = Lexer.scan(source.content.reader(), context) + + add(Parser.parse(source, tokenStream, context)) + } + } + + assertFalse(controller.hasErrors()) + + // build up the semantic model from the AST + val model = SemanticModel.build(fileNodes, controller) + + assertFalse(controller.hasErrors()) + + val actualFiles = mutableListOf() + for (generator in configuration.generators) { + actualFiles += Codegen.generate(model, generator, controller).map { it.copy(filepath = generator.output.resolve(it.filepath).toString()) } + } + + val expectedFiles = testDirectory.toFile().walkTopDown().filter { it.isFile && it.extension == "kt" }.toList() + + val expected = expectedFiles.associate { it.toPath().normalize() to it.readText().replace("\r\n", "\n") }.toSortedMap() + val actual = actualFiles.associate { Path(it.filepath).normalize() to it.source.replace("\r\n", "\n") }.toSortedMap() + + assertEquals(expected, actual) + } +} diff --git a/codegen/src/test/kotlin/tools/samt/codegen/http/HttpTransportTest.kt b/codegen/src/test/kotlin/tools/samt/codegen/http/HttpTransportTest.kt new file mode 100644 index 00000000..4ac7acf3 --- /dev/null +++ b/codegen/src/test/kotlin/tools/samt/codegen/http/HttpTransportTest.kt @@ -0,0 +1,302 @@ +package tools.samt.codegen.http + +import tools.samt.api.plugin.TransportConfiguration +import tools.samt.codegen.PublicApiMapper +import tools.samt.common.DiagnosticController +import tools.samt.common.SourceFile +import tools.samt.lexer.Lexer +import tools.samt.parser.Parser +import tools.samt.semantic.SemanticModel +import java.net.URI +import kotlin.test.* + +class HttpTransportTest { + private val diagnosticController = DiagnosticController(URI("file:///tmp")) + + @BeforeTest + fun setup() { + diagnosticController.contexts.clear() + diagnosticController.globalMessages.clear() + } + + @Test + fun `default configuration return default values for operations`() { + val config = HttpTransportConfigurationParser.default() + assertEquals(HttpTransportConfiguration.SerializationMode.Json, config.serializationMode) + assertEquals(emptyList(), config.services) + assertEquals(HttpTransportConfiguration.HttpMethod.Post, config.getMethod("service", "operation")) + assertEquals("", config.getPath("service")) + assertEquals("/operation", config.getPath("service", "operation")) + assertEquals(HttpTransportConfiguration.TransportMode.Body, config.getTransportMode("service", "operation", "parameter")) + } + + @Test + fun `correctly parses complex example`() { + val source = """ + package tools.samt.greeter + + typealias ID = String? (1..50) + + record Greeting { + message: String (0..128) + } + + enum GreetingType { + HELLO, + HI, + HEY + } + + service Greeter { + greet(id: ID, + name: String (1..50), + type: GreetingType, + reference: Greeting + ): Greeting + greetAll(names: List): Map + get() + put() + oneway delete() + patch() + default() + } + + provide GreeterEndpoint { + implements Greeter + + transport http { + operations: { + Greeter: { + greet: "POST /greet/{id} {name in header} {type in cookie}", + greetAll: "GET /greet/all {names in query}", + get: "GET /", + put: "PUT /", + delete: "DELETE /", + patch: "PATCH /" + } + } + } + } + """.trimIndent() + + val transport = parseAndCheck(source to emptyList()) + assertIs(transport) + + assertEquals(HttpTransportConfiguration.SerializationMode.Json, transport.serializationMode) + assertEquals(listOf("Greeter"), transport.services.map { it.name }) + + assertEquals(HttpTransportConfiguration.HttpMethod.Post, transport.getMethod("Greeter", "greet")) + assertEquals("/greet/{id}", transport.getPath("Greeter", "greet")) + assertEquals(HttpTransportConfiguration.TransportMode.Path, transport.getTransportMode("Greeter", "greet", "id")) + assertEquals(HttpTransportConfiguration.TransportMode.Header, transport.getTransportMode("Greeter", "greet", "name")) + assertEquals(HttpTransportConfiguration.TransportMode.Cookie, transport.getTransportMode("Greeter", "greet", "type")) + assertEquals(HttpTransportConfiguration.TransportMode.Body, transport.getTransportMode("Greeter", "greet", "reference")) + + assertEquals(HttpTransportConfiguration.HttpMethod.Get, transport.getMethod("Greeter", "greetAll")) + assertEquals("/greet/all", transport.getPath("Greeter", "greetAll")) + assertEquals(HttpTransportConfiguration.TransportMode.Query, transport.getTransportMode("Greeter", "greetAll", "names")) + + assertEquals(HttpTransportConfiguration.HttpMethod.Get, transport.getMethod("Greeter", "get")) + assertEquals("/", transport.getPath("Greeter", "get")) + assertEquals(HttpTransportConfiguration.HttpMethod.Put, transport.getMethod("Greeter", "put")) + assertEquals("/", transport.getPath("Greeter", "put")) + assertEquals(HttpTransportConfiguration.HttpMethod.Delete, transport.getMethod("Greeter", "delete")) + assertEquals("/", transport.getPath("Greeter", "delete")) + assertEquals(HttpTransportConfiguration.HttpMethod.Patch, transport.getMethod("Greeter", "patch")) + assertEquals("/", transport.getPath("Greeter", "patch")) + assertEquals(HttpTransportConfiguration.HttpMethod.Post, transport.getMethod("Greeter", "default")) + assertEquals("/default", transport.getPath("Greeter", "default")) + } + + @Test + fun `fails for invalid HTTP method`() { + val source = """ + package tools.samt.greeter + + service Greeter { + greet(name: String): String + } + + provide GreeterEndpoint { + implements Greeter + + transport http { + operations: { + Greeter: { + greet: "YEET /greet" + } + } + } + } + """.trimIndent() + + parseAndCheck(source to listOf("Error: Invalid http method 'YEET'")) + } + + @Test + fun `fails for invalid parameter binding`() { + val source = """ + package tools.samt.greeter + + service Greeter { + greet(name: String): String + foo() + } + + provide GreeterEndpoint { + implements Greeter + + transport http { + operations: { + Greeter: { + greet: "POST /greet {name in yeet}", + foo: "POST /foo {name in header}" + } + } + } + } + """.trimIndent() + + parseAndCheck(source to listOf("Error: Invalid transport mode 'yeet'", "Error: Parameter 'name' not found in operation 'foo'")) + } + + @Test + fun `fails for invalid path parameter binding`() { + val source = """ + package tools.samt.greeter + + service Greeter { + greet(name: String): String + foo() + } + + provide GreeterEndpoint { + implements Greeter + + transport http { + operations: { + Greeter: { + greet: "POST /greet/{}/me", + foo: "POST /foo/{name}" + } + } + } + } + """.trimIndent() + + parseAndCheck(source to listOf("Error: Expected parameter name between curly braces in '/greet/{}/me'", "Error: Path parameter 'name' not found in operation 'foo'")) + } + + @Test + fun `fails for invalid syntax`() { + val source = """ + package tools.samt.greeter + + service Greeter { + greet(name: String): String + } + + provide GreeterEndpoint { + implements Greeter + + transport http { + operations: { + Greeter: { + greet: "POST /greet {header:name}" + } + } + } + } + """.trimIndent() + + parseAndCheck(source to listOf("Error: Invalid operation config for 'greet', expected ' '. A valid example: 'POST /greet {parameter1, parameter2 in query}'")) + } + + @Test + fun `fails for non-existent service`() { + val source = """ + package tools.samt.greeter + + service Greeter { + greet(name: String): String + } + + service Foo { + bar() + } + + provide GreeterEndpoint { + implements Greeter + + transport http { + operations: { + Foo: { + bar: "PUT /bar" + } + } + } + } + """.trimIndent() + + parseAndCheck(source to listOf("Error: No service with name 'Foo' found in provider 'GreeterEndpoint'")) + } + + @Test + fun `fails for non-implemented operation`() { + val source = """ + package tools.samt.greeter + + service Greeter { + greet(name: String): String + bar() + } + + provide GreeterEndpoint { + implements Greeter { greet } + + transport http { + operations: { + Greeter: { + bar: "PUT /bar" + } + } + } + } + """.trimIndent() + + parseAndCheck(source to listOf("Error: No operation with name 'bar' found in service 'Greeter' of provider 'GreeterEndpoint'")) + } + + private fun parseAndCheck( + vararg sourceAndExpectedMessages: Pair>, + ): TransportConfiguration { + val fileTree = sourceAndExpectedMessages.mapIndexed { index, (source) -> + val filePath = URI("file:///tmp/HttpTransportTest-${index}.samt") + val sourceFile = SourceFile(filePath, source) + val parseContext = diagnosticController.getOrCreateContext(sourceFile) + val stream = Lexer.scan(source.reader(), parseContext) + val fileTree = Parser.parse(sourceFile, stream, parseContext) + assertFalse(parseContext.hasErrors(), "Expected no parse errors, but had errors: ${parseContext.messages}}") + fileTree + } + + val parseMessageCount = diagnosticController.contexts.associate { it.source.content to it.messages.size } + + val semanticModel = SemanticModel.build(fileTree, diagnosticController) + + val publicApiMapper = PublicApiMapper(listOf(HttpTransportConfigurationParser), diagnosticController) + + val transport = semanticModel.global.allSubPackages.map { publicApiMapper.toPublicApi(it) }.flatMap { it.providers }.single().transport + + for ((source, expectedMessages) in sourceAndExpectedMessages) { + val messages = diagnosticController.contexts + .first { it.source.content == source } + .messages + .drop(parseMessageCount.getValue(source)) + .map { "${it.severity}: ${it.message}" } + assertEquals(expectedMessages, messages) + } + + return transport + } +} diff --git a/codegen/src/test/resources/generator-test-model/.samtrc.yaml b/codegen/src/test/resources/generator-test-model/.samtrc.yaml new file mode 100644 index 00000000..c0d46ba1 --- /dev/null +++ b/codegen/src/test/resources/generator-test-model/.samtrc.yaml @@ -0,0 +1 @@ +extends: recommended diff --git a/codegen/src/test/resources/generator-test-model/README.md b/codegen/src/test/resources/generator-test-model/README.md new file mode 100644 index 00000000..26d76652 --- /dev/null +++ b/codegen/src/test/resources/generator-test-model/README.md @@ -0,0 +1,4 @@ +# Test Project + +This is a test project for the code generator. +It ensures that the code generator produces the expected output for the given input. diff --git a/codegen/src/test/resources/generator-test-model/out/ktor-client/tools/samt/client/generated/consumer/Consumer.kt b/codegen/src/test/resources/generator-test-model/out/ktor-client/tools/samt/client/generated/consumer/Consumer.kt new file mode 100644 index 00000000..ea986154 --- /dev/null +++ b/codegen/src/test/resources/generator-test-model/out/ktor-client/tools/samt/client/generated/consumer/Consumer.kt @@ -0,0 +1,158 @@ +@file:Suppress("RemoveRedundantQualifierName", "unused", "UnusedImport", "LocalVariableName", "FunctionName", "ConvertTwoComparisonsToRangeCheck", "ReplaceSizeCheckWithIsNotEmpty", "NAME_SHADOWING", "UNUSED_VARIABLE", "NestedLambdaShadowedImplicitParameter", "KotlinRedundantDiagnosticSuppress") + +/* + * This file is generated by SAMT, manual changes will be overwritten. + * Visit the SAMT GitHub for more details: https://github.com/samtkit/core + */ + +package tools.samt.client.generated.consumer + +import io.ktor.client.* +import io.ktor.client.engine.cio.* +import io.ktor.client.plugins.contentnegotiation.* +import io.ktor.client.request.* +import io.ktor.client.statement.* +import io.ktor.http.* +import io.ktor.serialization.kotlinx.json.* +import io.ktor.util.* +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.* +import kotlinx.coroutines.* + +class GreeterEndpointImpl(private val baseUrl: String) : tools.samt.client.generated.greeter.Greeter { + private val client = HttpClient(CIO) { + install(ContentNegotiation) { + json() + } + } + + /** Used to launch oneway operations asynchronously */ + private val onewayScope = CoroutineScope(Dispatchers.IO) + + override fun greet(id: tools.samt.client.generated.greeter.ID, name: String, type: tools.samt.client.generated.greeter.GreetingType?): tools.samt.client.generated.greeter.Greeting = runBlocking { + // Make actual network call + val `client response` = client.request(this@GreeterEndpointImpl.baseUrl) { + url { + // Construct path and encode path parameters + appendPathSegments("greet", encodeSlash = true) + appendPathSegments(name, encodeSlash = true) + + // Encode query parameters + this.parameters.append("type", (type?.let { type -> tools.samt.client.generated.greeter.`encode GreetingType`(type) } ?: JsonNull).toString()) + } + contentType(ContentType.Application.Json) + this.method = HttpMethod.Post + header("id", id?.let { id -> tools.samt.client.generated.greeter.`encode ID`(id) } ?: JsonNull) + setBody( + buildJsonObject { + } + ) + } + check(`client response`.status.isSuccess()) { "greet failed with status ${`client response`.status}" } + val bodyAsText = `client response`.bodyAsText() + val jsonElement = Json.parseToJsonElement(bodyAsText) + + tools.samt.client.generated.greeter.`decode Greeting`(jsonElement) + } + + override fun greetAll(names: List): Map = runBlocking { + // Make actual network call + val `client response` = client.request(this@GreeterEndpointImpl.baseUrl) { + url { + // Construct path and encode path parameters + appendPathSegments("greet", encodeSlash = true) + appendPathSegments("all", encodeSlash = true) + + // Encode query parameters + this.parameters.append("names", (JsonArray(names.map { it?.let { it -> JsonPrimitive(it.also { require(it.length >= 1 && it.length <= 50) }) } ?: JsonNull })).toString()) + } + contentType(ContentType.Application.Json) + this.method = HttpMethod.Get + setBody( + buildJsonObject { + } + ) + } + check(`client response`.status.isSuccess()) { "greetAll failed with status ${`client response`.status}" } + val bodyAsText = `client response`.bodyAsText() + val jsonElement = Json.parseToJsonElement(bodyAsText) + + jsonElement.jsonObject.mapValues { (_, value) -> value.takeUnless { it is JsonNull }?.let { value -> tools.samt.client.generated.greeter.`decode Greeting`(value) } } + } + + override fun greeting(who: tools.samt.client.generated.greeter.Person): String = runBlocking { + // Make actual network call + val `client response` = client.request(this@GreeterEndpointImpl.baseUrl) { + url { + // Construct path and encode path parameters + appendPathSegments("greeting", encodeSlash = true) + + // Encode query parameters + } + contentType(ContentType.Application.Json) + this.method = HttpMethod.Post + setBody( + buildJsonObject { + put("who", tools.samt.client.generated.greeter.`encode Person`(who)) + } + ) + } + check(`client response`.status.isSuccess()) { "greeting failed with status ${`client response`.status}" } + val bodyAsText = `client response`.bodyAsText() + val jsonElement = Json.parseToJsonElement(bodyAsText) + + jsonElement.jsonPrimitive.content.also { require(it.length >= 1 && it.length <= 100) } + } + + override fun allTheTypes(long: Long, float: Float, double: Double, decimal: java.math.BigDecimal, boolean: Boolean, date: java.time.LocalDate, dateTime: java.time.LocalDateTime, duration: java.time.Duration): Unit = runBlocking { + // Make actual network call + val `client response` = client.request(this@GreeterEndpointImpl.baseUrl) { + url { + // Construct path and encode path parameters + appendPathSegments("allTheTypes", encodeSlash = true) + + // Encode query parameters + } + contentType(ContentType.Application.Json) + this.method = HttpMethod.Post + setBody( + buildJsonObject { + put("long", JsonPrimitive(long)) + put("float", JsonPrimitive(float)) + put("double", JsonPrimitive(double)) + put("decimal", JsonPrimitive(decimal.toPlainString())) + put("boolean", JsonPrimitive(boolean)) + put("date", JsonPrimitive(date.toString())) + put("dateTime", JsonPrimitive(dateTime.toString())) + put("duration", JsonPrimitive(duration.toString())) + } + ) + } + check(`client response`.status.isSuccess()) { "allTheTypes failed with status ${`client response`.status}" } + } + + override fun fireAndForget(deleteWorld: Boolean): Unit { + onewayScope.launch { + // Make actual network call + val `client response` = client.request(this@GreeterEndpointImpl.baseUrl) { + url { + // Construct path and encode path parameters + appendPathSegments("world", encodeSlash = true) + + // Encode query parameters + } + contentType(ContentType.Application.Json) + this.method = HttpMethod.Put + cookie("deleteWorld", (JsonPrimitive(deleteWorld.also { require(it == true)) })).toString()) + setBody( + buildJsonObject { + } + ) + } + check(`client response`.status.isSuccess()) { "fireAndForget failed with status ${`client response`.status}" } + } + } + + override suspend fun legacy(): Unit + = error("Not used in SAMT consumer and therefore not generated") +} diff --git a/codegen/src/test/resources/generator-test-model/out/ktor-client/tools/samt/client/generated/greeter/KtorMappings.kt b/codegen/src/test/resources/generator-test-model/out/ktor-client/tools/samt/client/generated/greeter/KtorMappings.kt new file mode 100644 index 00000000..abb943d1 --- /dev/null +++ b/codegen/src/test/resources/generator-test-model/out/ktor-client/tools/samt/client/generated/greeter/KtorMappings.kt @@ -0,0 +1,112 @@ +@file:Suppress("RemoveRedundantQualifierName", "unused", "UnusedImport", "LocalVariableName", "FunctionName", "ConvertTwoComparisonsToRangeCheck", "ReplaceSizeCheckWithIsNotEmpty", "NAME_SHADOWING", "UNUSED_VARIABLE", "NestedLambdaShadowedImplicitParameter", "KotlinRedundantDiagnosticSuppress") + +/* + * This file is generated by SAMT, manual changes will be overwritten. + * Visit the SAMT GitHub for more details: https://github.com/samtkit/core + */ + +package tools.samt.client.generated.greeter + +import io.ktor.util.* +import kotlinx.serialization.json.* + +/** Encode and validate record tools.samt.greeter.Greeting to JSON */ +fun `encode Greeting`(record: tools.samt.client.generated.greeter.Greeting): JsonElement { + // Encode field message + val `field message` = run { + val value = record.message + JsonPrimitive(value.also { require(it.length >= 0 && it.length <= 128) }) + } + // Create JSON for tools.samt.greeter.Greeting + return buildJsonObject { + put("message", `field message`) + } +} +/** Decode and validate record tools.samt.greeter.Greeting from JSON */ +fun `decode Greeting`(json: JsonElement): tools.samt.client.generated.greeter.Greeting { + // Decode field message + val `field message` = run { + val jsonElement = json.jsonObject["message"]!! + jsonElement.jsonPrimitive.content.also { require(it.length >= 0 && it.length <= 128) } + } + // Create record tools.samt.greeter.Greeting + return tools.samt.client.generated.greeter.Greeting( + message = `field message`, + ) +} + +/** Encode and validate record tools.samt.greeter.Person to JSON */ +fun `encode Person`(record: tools.samt.client.generated.greeter.Person): JsonElement { + // Encode field id + val `field id` = run { + val value = record.id + value?.let { value -> tools.samt.client.generated.greeter.`encode ID`(value) } ?: JsonNull + } + // Encode field name + val `field name` = run { + val value = record.name + JsonPrimitive(value) + } + // Encode field age + val `field age` = run { + val value = record.age + JsonPrimitive(value.also { require(it >= 1) }) + } + // Create JSON for tools.samt.greeter.Person + return buildJsonObject { + put("id", `field id`) + put("name", `field name`) + put("age", `field age`) + } +} +/** Decode and validate record tools.samt.greeter.Person from JSON */ +fun `decode Person`(json: JsonElement): tools.samt.client.generated.greeter.Person { + // Decode field id + val `field id` = run { + val jsonElement = json.jsonObject["id"] ?: JsonNull + tools.samt.client.generated.greeter.`decode ID`(jsonElement) + } + // Decode field name + val `field name` = run { + val jsonElement = json.jsonObject["name"]!! + jsonElement.jsonPrimitive.content + } + // Decode field age + val `field age` = run { + val jsonElement = json.jsonObject["age"]!! + jsonElement.jsonPrimitive.int.also { require(it >= 1) } + } + // Create record tools.samt.greeter.Person + return tools.samt.client.generated.greeter.Person( + id = `field id`, + name = `field name`, + age = `field age`, + ) +} + +/** Encode enum tools.samt.greeter.GreetingType to JSON */ +fun `encode GreetingType`(value: tools.samt.client.generated.greeter.GreetingType?): JsonElement = when(value) { + null -> JsonNull + tools.samt.client.generated.greeter.GreetingType.HELLO -> JsonPrimitive("HELLO") + tools.samt.client.generated.greeter.GreetingType.HI -> JsonPrimitive("HI") + tools.samt.client.generated.greeter.GreetingType.HEY -> JsonPrimitive("HEY") + tools.samt.client.generated.greeter.GreetingType.FAILED_TO_PARSE -> error("Cannot encode FAILED_TO_PARSE value") +} +/** Decode enum tools.samt.greeter.GreetingType from JSON */ +fun `decode GreetingType`(json: JsonElement): tools.samt.client.generated.greeter.GreetingType = when(json.jsonPrimitive.content) { + "HELLO" -> tools.samt.client.generated.greeter.GreetingType.HELLO + "HI" -> tools.samt.client.generated.greeter.GreetingType.HI + "HEY" -> tools.samt.client.generated.greeter.GreetingType.HEY + // Value not found in enum tools.samt.greeter.GreetingType + else -> tools.samt.client.generated.greeter.GreetingType.FAILED_TO_PARSE +} + +/** Encode alias tools.samt.greeter.ID to JSON */ +fun `encode ID`(value: tools.samt.client.generated.greeter.ID): JsonElement = + value?.let { value -> JsonPrimitive(value.also { require(it.length >= 1 && it.length <= 50) }) } ?: JsonNull +/** Decode alias tools.samt.greeter.ID from JSON */ +fun `decode ID`(json: JsonElement): String? { + if (json is JsonNull) return null + return json.jsonPrimitive.content.also { require(it.length >= 1 && it.length <= 50) } +} + diff --git a/codegen/src/test/resources/generator-test-model/out/ktor-client/tools/samt/client/generated/greeter/Types.kt b/codegen/src/test/resources/generator-test-model/out/ktor-client/tools/samt/client/generated/greeter/Types.kt new file mode 100644 index 00000000..eeeab92b --- /dev/null +++ b/codegen/src/test/resources/generator-test-model/out/ktor-client/tools/samt/client/generated/greeter/Types.kt @@ -0,0 +1,55 @@ +@file:Suppress("RemoveRedundantQualifierName", "unused", "UnusedImport", "LocalVariableName", "FunctionName", "ConvertTwoComparisonsToRangeCheck", "ReplaceSizeCheckWithIsNotEmpty", "NAME_SHADOWING", "UNUSED_VARIABLE", "NestedLambdaShadowedImplicitParameter", "KotlinRedundantDiagnosticSuppress") + +/* + * This file is generated by SAMT, manual changes will be overwritten. + * Visit the SAMT GitHub for more details: https://github.com/samtkit/core + */ + +package tools.samt.client.generated.greeter + +data class Greeting( + val message: String, +) + +data class Person( + val id: tools.samt.client.generated.greeter.ID = null, + val name: String, + val age: Int, +) + +enum class GreetingType { + /** Default value used when the enum could not be parsed */ + FAILED_TO_PARSE, + HELLO, + HI, + HEY, +} +typealias ID = String? +interface Greeter { + fun greet( + id: tools.samt.client.generated.greeter.ID = null, + name: String, + type: tools.samt.client.generated.greeter.GreetingType? = null, + ): tools.samt.client.generated.greeter.Greeting + fun greetAll( + names: List, + ): Map + fun greeting( + who: tools.samt.client.generated.greeter.Person, + ): String + fun allTheTypes( + long: Long, + float: Float, + double: Double, + decimal: java.math.BigDecimal, + boolean: Boolean, + date: java.time.LocalDate, + dateTime: java.time.LocalDateTime, + duration: java.time.Duration, + ) + fun fireAndForget( + deleteWorld: Boolean, + ) + suspend fun legacy( + ) +} diff --git a/codegen/src/test/resources/generator-test-model/out/ktor-server/tools/samt/server/generated/greeter/GreeterEndpoint.kt b/codegen/src/test/resources/generator-test-model/out/ktor-server/tools/samt/server/generated/greeter/GreeterEndpoint.kt new file mode 100644 index 00000000..5efeaf86 --- /dev/null +++ b/codegen/src/test/resources/generator-test-model/out/ktor-server/tools/samt/server/generated/greeter/GreeterEndpoint.kt @@ -0,0 +1,220 @@ +@file:Suppress("RemoveRedundantQualifierName", "unused", "UnusedImport", "LocalVariableName", "FunctionName", "ConvertTwoComparisonsToRangeCheck", "ReplaceSizeCheckWithIsNotEmpty", "NAME_SHADOWING", "UNUSED_VARIABLE", "NestedLambdaShadowedImplicitParameter", "KotlinRedundantDiagnosticSuppress") + +/* + * This file is generated by SAMT, manual changes will be overwritten. + * Visit the SAMT GitHub for more details: https://github.com/samtkit/core + */ + +package tools.samt.server.generated.greeter + +import io.ktor.http.* +import io.ktor.serialization.kotlinx.json.* +import io.ktor.server.application.* +import io.ktor.server.plugins.contentnegotiation.* +import io.ktor.server.request.* +import io.ktor.server.response.* +import io.ktor.server.routing.* +import io.ktor.util.* +import kotlinx.serialization.json.* + +/** Connector for SAMT provider GreeterEndpoint */ +fun Routing.routeGreeterEndpoint( + greeter: tools.samt.server.generated.greeter.Greeter, +) { + /** Utility used to convert string to JSON element */ + fun String.toJson() = Json.parseToJsonElement(this) + /** Utility used to convert string to JSON element or null */ + fun String.toJsonOrNull() = Json.parseToJsonElement(this).takeUnless { it is JsonNull } + + // Handler for SAMT Service Greeter + route("") { + // Handler for SAMT operation greet + post("/greet/{name}") { + // Parse body lazily in case no parameter is transported in the body + val bodyAsText = call.receiveText() + val body by lazy { bodyAsText.toJson() } + + // Decode parameter id + val `parameter id` = run { + // Read from header + val jsonElement = call.request.headers["id"]?.toJsonOrNull() ?: return@run null + tools.samt.server.generated.greeter.`decode ID`(jsonElement) + } + + // Decode parameter name + val `parameter name` = run { + // Read from path + val jsonElement = call.parameters["name"]!!.toJson() + jsonElement.jsonPrimitive.content.also { require(it.length >= 1 && it.length <= 50) } + } + + // Decode parameter type + val `parameter type` = run { + // Read from query + val jsonElement = call.request.queryParameters["type"]?.toJsonOrNull() ?: return@run null + tools.samt.server.generated.greeter.`decode GreetingType`(jsonElement) + } + + // Call user provided implementation + val value = greeter.greet(`parameter id`, `parameter name`, `parameter type`) + + // Encode response + val response = tools.samt.server.generated.greeter.`encode Greeting`(value) + + // Return response with 200 OK + call.respond(HttpStatusCode.OK, response) + } + + // Handler for SAMT operation greetAll + get("/greet/all") { + // Parse body lazily in case no parameter is transported in the body + val bodyAsText = call.receiveText() + val body by lazy { bodyAsText.toJson() } + + // Decode parameter names + val `parameter names` = run { + // Read from query + val jsonElement = call.request.queryParameters["names"]!!.toJson() + jsonElement.jsonArray.map { it.takeUnless { it is JsonNull }?.let { it.jsonPrimitive.content.also { require(it.length >= 1 && it.length <= 50) } } } + } + + // Call user provided implementation + val value = greeter.greetAll(`parameter names`) + + // Encode response + val response = JsonObject(value.mapValues { (_, value) -> value?.let { value -> tools.samt.server.generated.greeter.`encode Greeting`(value) } ?: JsonNull }) + + // Return response with 200 OK + call.respond(HttpStatusCode.OK, response) + } + + // Handler for SAMT operation greeting + post("/greeting") { + // Parse body lazily in case no parameter is transported in the body + val bodyAsText = call.receiveText() + val body by lazy { bodyAsText.toJson() } + + // Decode parameter who + val `parameter who` = run { + // Read from body + val jsonElement = body.jsonObject["who"]!! + tools.samt.server.generated.greeter.`decode Person`(jsonElement) + } + + // Call user provided implementation + val value = greeter.greeting(`parameter who`) + + // Encode response + val response = JsonPrimitive(value.also { require(it.length >= 1 && it.length <= 100) }) + + // Return response with 200 OK + call.respond(HttpStatusCode.OK, response) + } + + // Handler for SAMT operation allTheTypes + post("/allTheTypes") { + // Parse body lazily in case no parameter is transported in the body + val bodyAsText = call.receiveText() + val body by lazy { bodyAsText.toJson() } + + // Decode parameter long + val `parameter long` = run { + // Read from body + val jsonElement = body.jsonObject["long"]!! + jsonElement.jsonPrimitive.long + } + + // Decode parameter float + val `parameter float` = run { + // Read from body + val jsonElement = body.jsonObject["float"]!! + jsonElement.jsonPrimitive.float + } + + // Decode parameter double + val `parameter double` = run { + // Read from body + val jsonElement = body.jsonObject["double"]!! + jsonElement.jsonPrimitive.double + } + + // Decode parameter decimal + val `parameter decimal` = run { + // Read from body + val jsonElement = body.jsonObject["decimal"]!! + jsonElement.jsonPrimitive.content.let { java.math.BigDecimal(it) } + } + + // Decode parameter boolean + val `parameter boolean` = run { + // Read from body + val jsonElement = body.jsonObject["boolean"]!! + jsonElement.jsonPrimitive.boolean + } + + // Decode parameter date + val `parameter date` = run { + // Read from body + val jsonElement = body.jsonObject["date"]!! + jsonElement.jsonPrimitive.content.let { java.time.LocalDate.parse(it) } + } + + // Decode parameter dateTime + val `parameter dateTime` = run { + // Read from body + val jsonElement = body.jsonObject["dateTime"]!! + jsonElement.jsonPrimitive.content.let { java.time.LocalDateTime.parse(it) } + } + + // Decode parameter duration + val `parameter duration` = run { + // Read from body + val jsonElement = body.jsonObject["duration"]!! + jsonElement.jsonPrimitive.content.let { java.time.Duration.parse(it) } + } + + // Call user provided implementation + greeter.allTheTypes(`parameter long`, `parameter float`, `parameter double`, `parameter decimal`, `parameter boolean`, `parameter date`, `parameter dateTime`, `parameter duration`) + + // Return 204 No Content + call.respond(HttpStatusCode.NoContent) + } + + // Handler for SAMT oneway operation fireAndForget + put("/world") { + // Parse body lazily in case no parameter is transported in the body + val bodyAsText = call.receiveText() + val body by lazy { bodyAsText.toJson() } + + // Decode parameter deleteWorld + val `parameter deleteWorld` = run { + // Read from cookie + val jsonElement = call.request.cookies["deleteWorld"]!!.toJson() + jsonElement.jsonPrimitive.boolean.also { require(it == true)) } + } + + // Use launch to handle the request asynchronously, not waiting for the response + launch { + // Call user provided implementation + greeter.fireAndForget(`parameter deleteWorld`) + } + + // Oneway operation always returns 204 No Content + call.respond(HttpStatusCode.NoContent) + } + // Handler for SAMT operation legacy + post("/legacy") { + // Parse body lazily in case no parameter is transported in the body + val bodyAsText = call.receiveText() + val body by lazy { bodyAsText.toJson() } + + // Call user provided implementation + greeter.legacy() + + // Return 204 No Content + call.respond(HttpStatusCode.NoContent) + } + + } + +} diff --git a/codegen/src/test/resources/generator-test-model/out/ktor-server/tools/samt/server/generated/greeter/KtorMappings.kt b/codegen/src/test/resources/generator-test-model/out/ktor-server/tools/samt/server/generated/greeter/KtorMappings.kt new file mode 100644 index 00000000..1ae105d4 --- /dev/null +++ b/codegen/src/test/resources/generator-test-model/out/ktor-server/tools/samt/server/generated/greeter/KtorMappings.kt @@ -0,0 +1,112 @@ +@file:Suppress("RemoveRedundantQualifierName", "unused", "UnusedImport", "LocalVariableName", "FunctionName", "ConvertTwoComparisonsToRangeCheck", "ReplaceSizeCheckWithIsNotEmpty", "NAME_SHADOWING", "UNUSED_VARIABLE", "NestedLambdaShadowedImplicitParameter", "KotlinRedundantDiagnosticSuppress") + +/* + * This file is generated by SAMT, manual changes will be overwritten. + * Visit the SAMT GitHub for more details: https://github.com/samtkit/core + */ + +package tools.samt.server.generated.greeter + +import io.ktor.util.* +import kotlinx.serialization.json.* + +/** Encode and validate record tools.samt.greeter.Greeting to JSON */ +fun `encode Greeting`(record: tools.samt.server.generated.greeter.Greeting): JsonElement { + // Encode field message + val `field message` = run { + val value = record.message + JsonPrimitive(value.also { require(it.length >= 0 && it.length <= 128) }) + } + // Create JSON for tools.samt.greeter.Greeting + return buildJsonObject { + put("message", `field message`) + } +} +/** Decode and validate record tools.samt.greeter.Greeting from JSON */ +fun `decode Greeting`(json: JsonElement): tools.samt.server.generated.greeter.Greeting { + // Decode field message + val `field message` = run { + val jsonElement = json.jsonObject["message"]!! + jsonElement.jsonPrimitive.content.also { require(it.length >= 0 && it.length <= 128) } + } + // Create record tools.samt.greeter.Greeting + return tools.samt.server.generated.greeter.Greeting( + message = `field message`, + ) +} + +/** Encode and validate record tools.samt.greeter.Person to JSON */ +fun `encode Person`(record: tools.samt.server.generated.greeter.Person): JsonElement { + // Encode field id + val `field id` = run { + val value = record.id + value?.let { value -> tools.samt.server.generated.greeter.`encode ID`(value) } ?: JsonNull + } + // Encode field name + val `field name` = run { + val value = record.name + JsonPrimitive(value) + } + // Encode field age + val `field age` = run { + val value = record.age + JsonPrimitive(value.also { require(it >= 1) }) + } + // Create JSON for tools.samt.greeter.Person + return buildJsonObject { + put("id", `field id`) + put("name", `field name`) + put("age", `field age`) + } +} +/** Decode and validate record tools.samt.greeter.Person from JSON */ +fun `decode Person`(json: JsonElement): tools.samt.server.generated.greeter.Person { + // Decode field id + val `field id` = run { + val jsonElement = json.jsonObject["id"] ?: JsonNull + tools.samt.server.generated.greeter.`decode ID`(jsonElement) + } + // Decode field name + val `field name` = run { + val jsonElement = json.jsonObject["name"]!! + jsonElement.jsonPrimitive.content + } + // Decode field age + val `field age` = run { + val jsonElement = json.jsonObject["age"]!! + jsonElement.jsonPrimitive.int.also { require(it >= 1) } + } + // Create record tools.samt.greeter.Person + return tools.samt.server.generated.greeter.Person( + id = `field id`, + name = `field name`, + age = `field age`, + ) +} + +/** Encode enum tools.samt.greeter.GreetingType to JSON */ +fun `encode GreetingType`(value: tools.samt.server.generated.greeter.GreetingType?): JsonElement = when(value) { + null -> JsonNull + tools.samt.server.generated.greeter.GreetingType.HELLO -> JsonPrimitive("HELLO") + tools.samt.server.generated.greeter.GreetingType.HI -> JsonPrimitive("HI") + tools.samt.server.generated.greeter.GreetingType.HEY -> JsonPrimitive("HEY") + tools.samt.server.generated.greeter.GreetingType.FAILED_TO_PARSE -> error("Cannot encode FAILED_TO_PARSE value") +} +/** Decode enum tools.samt.greeter.GreetingType from JSON */ +fun `decode GreetingType`(json: JsonElement): tools.samt.server.generated.greeter.GreetingType = when(json.jsonPrimitive.content) { + "HELLO" -> tools.samt.server.generated.greeter.GreetingType.HELLO + "HI" -> tools.samt.server.generated.greeter.GreetingType.HI + "HEY" -> tools.samt.server.generated.greeter.GreetingType.HEY + // Value not found in enum tools.samt.greeter.GreetingType + else -> tools.samt.server.generated.greeter.GreetingType.FAILED_TO_PARSE +} + +/** Encode alias tools.samt.greeter.ID to JSON */ +fun `encode ID`(value: tools.samt.server.generated.greeter.ID): JsonElement = + value?.let { value -> JsonPrimitive(value.also { require(it.length >= 1 && it.length <= 50) }) } ?: JsonNull +/** Decode alias tools.samt.greeter.ID from JSON */ +fun `decode ID`(json: JsonElement): String? { + if (json is JsonNull) return null + return json.jsonPrimitive.content.also { require(it.length >= 1 && it.length <= 50) } +} + diff --git a/codegen/src/test/resources/generator-test-model/out/ktor-server/tools/samt/server/generated/greeter/KtorServer.kt b/codegen/src/test/resources/generator-test-model/out/ktor-server/tools/samt/server/generated/greeter/KtorServer.kt new file mode 100644 index 00000000..0544f452 --- /dev/null +++ b/codegen/src/test/resources/generator-test-model/out/ktor-server/tools/samt/server/generated/greeter/KtorServer.kt @@ -0,0 +1,28 @@ +@file:Suppress("RemoveRedundantQualifierName", "unused", "UnusedImport", "LocalVariableName", "FunctionName", "ConvertTwoComparisonsToRangeCheck", "ReplaceSizeCheckWithIsNotEmpty", "NAME_SHADOWING", "UNUSED_VARIABLE", "NestedLambdaShadowedImplicitParameter", "KotlinRedundantDiagnosticSuppress") + +/* + * This file is generated by SAMT, manual changes will be overwritten. + * Visit the SAMT GitHub for more details: https://github.com/samtkit/core + */ + +package tools.samt.server.generated.greeter + +import io.ktor.http.* +import io.ktor.serialization.kotlinx.json.* +import io.ktor.server.plugins.contentnegotiation.* +import io.ktor.server.response.* +import io.ktor.server.application.* +import io.ktor.server.request.* +import io.ktor.server.routing.* +import kotlinx.serialization.json.* + +fun Application.configureSerialization() { + install(ContentNegotiation) { + json() + } + routing { + routeGreeterEndpoint( + greeter = TODO("Implement tools.samt.server.generated.greeter.Greeter"), + ) + } +} diff --git a/codegen/src/test/resources/generator-test-model/out/ktor-server/tools/samt/server/generated/greeter/Types.kt b/codegen/src/test/resources/generator-test-model/out/ktor-server/tools/samt/server/generated/greeter/Types.kt new file mode 100644 index 00000000..59a8973e --- /dev/null +++ b/codegen/src/test/resources/generator-test-model/out/ktor-server/tools/samt/server/generated/greeter/Types.kt @@ -0,0 +1,55 @@ +@file:Suppress("RemoveRedundantQualifierName", "unused", "UnusedImport", "LocalVariableName", "FunctionName", "ConvertTwoComparisonsToRangeCheck", "ReplaceSizeCheckWithIsNotEmpty", "NAME_SHADOWING", "UNUSED_VARIABLE", "NestedLambdaShadowedImplicitParameter", "KotlinRedundantDiagnosticSuppress") + +/* + * This file is generated by SAMT, manual changes will be overwritten. + * Visit the SAMT GitHub for more details: https://github.com/samtkit/core + */ + +package tools.samt.server.generated.greeter + +data class Greeting( + val message: String, +) + +data class Person( + val id: tools.samt.server.generated.greeter.ID = null, + val name: String, + val age: Int, +) + +enum class GreetingType { + /** Default value used when the enum could not be parsed */ + FAILED_TO_PARSE, + HELLO, + HI, + HEY, +} +typealias ID = String? +interface Greeter { + fun greet( + id: tools.samt.server.generated.greeter.ID = null, + name: String, + type: tools.samt.server.generated.greeter.GreetingType? = null, + ): tools.samt.server.generated.greeter.Greeting + fun greetAll( + names: List, + ): Map + fun greeting( + who: tools.samt.server.generated.greeter.Person, + ): String + fun allTheTypes( + long: Long, + float: Float, + double: Double, + decimal: java.math.BigDecimal, + boolean: Boolean, + date: java.time.LocalDate, + dateTime: java.time.LocalDateTime, + duration: java.time.Duration, + ) + fun fireAndForget( + deleteWorld: Boolean, + ) + suspend fun legacy( + ) +} diff --git a/codegen/src/test/resources/generator-test-model/samt.yaml b/codegen/src/test/resources/generator-test-model/samt.yaml new file mode 100644 index 00000000..16047642 --- /dev/null +++ b/codegen/src/test/resources/generator-test-model/samt.yaml @@ -0,0 +1,11 @@ +generators: + - name: kotlin-ktor-provider + output: ./out/ktor-server/ + options: + removePrefixFromSamtPackage: tools.samt + addPrefixToKotlinPackage: tools.samt.server.generated + - name: kotlin-ktor-consumer + output: ./out/ktor-client/ + options: + removePrefixFromSamtPackage: tools.samt + addPrefixToKotlinPackage: tools.samt.client.generated diff --git a/codegen/src/test/resources/generator-test-model/src/greeter-consumer.samt b/codegen/src/test/resources/generator-test-model/src/greeter-consumer.samt new file mode 100644 index 00000000..c58726ba --- /dev/null +++ b/codegen/src/test/resources/generator-test-model/src/greeter-consumer.samt @@ -0,0 +1,9 @@ +import tools.samt.greeter.GreeterEndpoint +import tools.samt.greeter.Greeter + +// Usually belongs to another package +package tools.samt.consumer + +consume GreeterEndpoint { + uses Greeter { greet, greetAll, greeting, allTheTypes, fireAndForget } +} diff --git a/codegen/src/test/resources/generator-test-model/src/greeter-provider.samt b/codegen/src/test/resources/generator-test-model/src/greeter-provider.samt new file mode 100644 index 00000000..d1418ba1 --- /dev/null +++ b/codegen/src/test/resources/generator-test-model/src/greeter-provider.samt @@ -0,0 +1,16 @@ +package tools.samt.greeter + +provide GreeterEndpoint { + implements Greeter + + transport http { + operations: { + Greeter: { + greet: "POST /greet/{name} {id in header} {type in query}", + greetAll: "GET /greet/all {names in query}", + greeting: "POST /greeting", + fireAndForget: "PUT /world {deleteWorld in cookie}" + } + } + } +} diff --git a/codegen/src/test/resources/generator-test-model/src/greeter.samt b/codegen/src/test/resources/generator-test-model/src/greeter.samt new file mode 100644 index 00000000..d22b7b62 --- /dev/null +++ b/codegen/src/test/resources/generator-test-model/src/greeter.samt @@ -0,0 +1,46 @@ +package tools.samt.greeter + +typealias ID = String? (1..50) + +record Greeting { + message: String (0..128) +} + +record Person { + id: ID + name: String + age: Int (1..*) +} + +enum GreetingType { + HELLO, + HI, + HEY +} + +service Greeter { + greet(id: ID, + name: String (1..50), + type: GreetingType? + ): Greeting + // Nullability to verify edge-cases + greetAll(names: List): Map + greeting(who: Person): String (1..100) + + @Description("Used to test all the types") + allTheTypes( + long: Long, + float: Float, + double: Double, + decimal: Decimal, + boolean: Boolean, + date: Date, + dateTime: DateTime, + duration: Duration + ) + + oneway fireAndForget(deleteWorld: Boolean (value(true))) + + @Deprecated("Do not use anymore!") + async legacy() +} diff --git a/common/src/main/kotlin/tools/samt/common/SamtConfiguration.kt b/common/src/main/kotlin/tools/samt/common/SamtConfiguration.kt index 6ec74f5d..ba957b0e 100644 --- a/common/src/main/kotlin/tools/samt/common/SamtConfiguration.kt +++ b/common/src/main/kotlin/tools/samt/common/SamtConfiguration.kt @@ -1,7 +1,9 @@ package tools.samt.common +import java.nio.file.Path + data class SamtConfiguration( - val source: String, + val source: Path, val plugins: List, val generators: List, ) @@ -9,7 +11,7 @@ data class SamtConfiguration( sealed interface SamtPluginConfiguration data class SamtLocalPluginConfiguration( - val path: String, + val path: Path, ) : SamtPluginConfiguration data class SamtMavenPluginConfiguration( @@ -21,6 +23,6 @@ data class SamtMavenPluginConfiguration( data class SamtGeneratorConfiguration( val name: String, - val output: String, + val output: Path, val options: Map, ) diff --git a/parser/src/main/kotlin/tools/samt/parser/Nodes.kt b/parser/src/main/kotlin/tools/samt/parser/Nodes.kt index 92a197d0..18bbbc9c 100644 --- a/parser/src/main/kotlin/tools/samt/parser/Nodes.kt +++ b/parser/src/main/kotlin/tools/samt/parser/Nodes.kt @@ -1,12 +1,27 @@ package tools.samt.parser -import tools.samt.common.Location -import tools.samt.common.SourceFile +import tools.samt.common.* sealed interface Node { val location: Location } +inline fun Node.report(controller: DiagnosticController, severity: DiagnosticSeverity, block: DiagnosticMessageBuilder.() -> Unit) { + controller.getOrCreateContext(location.source).report(severity, block) +} + +inline fun Node.reportError(controller: DiagnosticController, block: DiagnosticMessageBuilder.() -> Unit) { + report(controller, DiagnosticSeverity.Error, block) +} + +inline fun Node.reportWarning(controller: DiagnosticController, block: DiagnosticMessageBuilder.() -> Unit) { + report(controller, DiagnosticSeverity.Warning, block) +} + +inline fun Node.reportInfo(controller: DiagnosticController, block: DiagnosticMessageBuilder.() -> Unit) { + report(controller, DiagnosticSeverity.Info, block) +} + sealed interface AnnotatedNode : Node { val annotations: List } diff --git a/public-api/build.gradle.kts b/public-api/build.gradle.kts new file mode 100644 index 00000000..d194b3df --- /dev/null +++ b/public-api/build.gradle.kts @@ -0,0 +1,3 @@ +plugins { + id("samt-core.kotlin-conventions") +} diff --git a/public-api/src/main/kotlin/tools/samt/api/plugin/Generator.kt b/public-api/src/main/kotlin/tools/samt/api/plugin/Generator.kt new file mode 100644 index 00000000..8e93c7c6 --- /dev/null +++ b/public-api/src/main/kotlin/tools/samt/api/plugin/Generator.kt @@ -0,0 +1,59 @@ +package tools.samt.api.plugin + +import tools.samt.api.types.SamtPackage + +/** + * A code generator. + * This interface is intended to be implemented by a code generator, for example Kotlin-Ktor. + */ +interface Generator { + /** + * The name of the generator, used to identify it in the configuration + */ + val name: String + + /** + * Generate code for the given packages + * @param generatorParams The parameters for the generator + * @return A list of generated files, which will be written to disk + */ + fun generate(generatorParams: GeneratorParams): List +} + +/** + * This class represents a file generated by a [Generator]. + */ +data class CodegenFile(val filepath: String, val source: String) + +/** + * The parameters for a [Generator]. + */ +interface GeneratorParams { + /** + * The packages to generate code for, includes all SAMT subpackages + */ + val packages: List + + /** + * The configuration for the generator as specified in the SAMT configuration + */ + val options: Map + + /** + * Report an error + * @param message The error message + */ + fun reportError(message: String) + + /** + * Report a warning + * @param message The warning message + */ + fun reportWarning(message: String) + + /** + * Report an info message + * @param message The info message + */ + fun reportInfo(message: String) +} diff --git a/public-api/src/main/kotlin/tools/samt/api/plugin/Transport.kt b/public-api/src/main/kotlin/tools/samt/api/plugin/Transport.kt new file mode 100644 index 00000000..1b37f851 --- /dev/null +++ b/public-api/src/main/kotlin/tools/samt/api/plugin/Transport.kt @@ -0,0 +1,178 @@ +package tools.samt.api.plugin + +import tools.samt.api.types.ServiceOperation +import tools.samt.api.types.ServiceType + +/** + * A transport configuration parser. + * This interface is intended to be implemented by a transport configuration parser, for example HTTP. + * It is used to parse the configuration body into a specific [TransportConfiguration]. + */ +interface TransportConfigurationParser { + /** + * The name of the transport, used to identify it in the configuration + */ + val transportName: String + + /** + * Create the default configuration for this transport, used when no configuration body is specified + * @return Default configuration + */ + fun default(): TransportConfiguration + + /** + * Parses the configuration body and returns the configuration object + * @throws RuntimeException if the configuration is invalid and graceful error handling is not possible + * @return Parsed configuration + */ + fun parse(params: TransportConfigurationParserParams): TransportConfiguration +} + +/** + * A base interface for transport configurations. + * This interface is intended to be sub-typed and extended by transport configuration implementations. + */ +interface TransportConfiguration + +/** + * The parameters for a [TransportConfigurationParser]. + */ +interface TransportConfigurationParserParams { + /** + * The configuration body to parse + */ + val config: ConfigurationObject + + /** + * Report an error + * @param message The error message + * @param context The configuration element that caused the error, will be highlighted in the editor + */ + fun reportError(message: String, context: ConfigurationElement? = null) + + /** + * Report a warning + * @param message The warning message + * @param context The configuration element that caused the warning, will be highlighted in the editor + */ + fun reportWarning(message: String, context: ConfigurationElement? = null) + + /** + * Report an info message + * @param message The info message + * @param context The configuration element that caused the info message, will be highlighted in the editor + */ + fun reportInfo(message: String, context: ConfigurationElement? = null) +} + +/** + * A configuration element + */ +interface ConfigurationElement { + /** + * This element as an [ConfigurationObject] + * @throws RuntimeException if this element is not an object + */ + val asObject: ConfigurationObject + + /** + * This element as an [ConfigurationValue] + * @throws RuntimeException if this element is not a primitive value + */ + val asValue: ConfigurationValue + + /** + * This element as an [ConfigurationList] + * @throws RuntimeException if this element is not a list + */ + val asList: ConfigurationList +} + +/** + * A configuration object, contains a map of fields + */ +interface ConfigurationObject : ConfigurationElement { + /** + * The fields of this object + */ + val fields: Map + + /** + * Get a field by name + * @throws RuntimeException if the field does not exist + */ + fun getField(name: String): ConfigurationElement + + /** + * Get a field by name, or null if it does not exist + */ + fun getFieldOrNull(name: String): ConfigurationElement? +} + +/** + * A configuration list, contains a list of elements + */ +interface ConfigurationList : ConfigurationElement { + /** + * The entries of this list + */ + val entries: List +} + +/** + * A primitive configuration value + */ +interface ConfigurationValue : ConfigurationElement { + /** + * This value as a string + * @throws RuntimeException if this value is not a string + */ + val asString: String + + /** + * This value as an identifier + * @throws RuntimeException if this value is not an identifier + */ + val asIdentifier: String + + /** + * This value as an enum, matches the enum value by name case-insensitively (e.g. "get" matches HttpMethod.GET) + * @throws RuntimeException if this value is not convertible to the provided [enum] + */ + fun > asEnum(enum: Class): T + + /** + * This value as a long + * @throws RuntimeException if this value is not a long + */ + val asLong: Long + + /** + * This value as a double + * @throws RuntimeException if this value is not a double + */ + val asDouble: Double + + /** + * This value as a boolean + * @throws RuntimeException if this value is not a boolean + */ + val asBoolean: Boolean + + /** + * This value as a service name, matches against services in the current provider context + * @throws RuntimeException if this value is not a service within the current provider context + */ + val asServiceName: ServiceType + + /** + * This value as a service operation name, matches against operations in the current provider context and [service] + * @throws RuntimeException if this value is not a service operation within the current provider context and [service] + */ + fun asOperationName(service: ServiceType): ServiceOperation +} + +/** + * Convenience wrapper for [ConfigurationValue.asEnum] + */ +inline fun > ConfigurationValue.asEnum() = asEnum(T::class.java) diff --git a/public-api/src/main/kotlin/tools/samt/api/types/StandardLibraryTypes.kt b/public-api/src/main/kotlin/tools/samt/api/types/StandardLibraryTypes.kt new file mode 100644 index 00000000..4b18c888 --- /dev/null +++ b/public-api/src/main/kotlin/tools/samt/api/types/StandardLibraryTypes.kt @@ -0,0 +1,40 @@ +package tools.samt.api.types + +interface LiteralType : Type + +interface IntType : LiteralType +interface LongType : LiteralType +interface FloatType : LiteralType +interface DoubleType : LiteralType +interface DecimalType : LiteralType +interface BooleanType : LiteralType +interface StringType : LiteralType +interface BytesType : LiteralType +interface DateType : LiteralType +interface DateTimeType : LiteralType +interface DurationType : LiteralType + +/** + * An ordered list of elements + */ +interface ListType : Type { + /** + * The type of the elements in the list + */ + val elementType: TypeReference +} + +/** + * A map of key-value pairs + */ +interface MapType : Type { + /** + * The type of the keys in the map + */ + val keyType: TypeReference + + /** + * The type of the values in the map + */ + val valueType: TypeReference +} diff --git a/public-api/src/main/kotlin/tools/samt/api/types/Types.kt b/public-api/src/main/kotlin/tools/samt/api/types/Types.kt new file mode 100644 index 00000000..b4fbfbb9 --- /dev/null +++ b/public-api/src/main/kotlin/tools/samt/api/types/Types.kt @@ -0,0 +1,108 @@ +package tools.samt.api.types + +interface SamtPackage { + val name: String + val qualifiedName: String + val records: List + val enums: List + val services: List + val providers: List + val consumers: List + val aliases: List +} + +/** + * A SAMT type + */ +interface Type + + +/** + * A type reference + */ +interface TypeReference { + /** + * The type this reference points to + */ + val type: Type + + /** + * Is true if this type reference is optional, meaning it can be null + */ + val isOptional: Boolean + + /** + * The range constraints placed on this type, if any + */ + val rangeConstraint: Constraint.Range? + + /** + * The size constraints placed on this type, if any + */ + val sizeConstraint: Constraint.Size? + + /** + * The pattern constraints placed on this type, if any + */ + val patternConstraint: Constraint.Pattern? + + /** + * The value constraints placed on this type, if any + */ + val valueConstraint: Constraint.Value? + + /** + * The runtime type this reference points to, could be different from [type] if this is an alias + */ + val runtimeType: Type + + /** + * Is true if this type reference or underlying type is optional, meaning it can be null at runtime + * This is different from [isOptional] in that it will return true for an alias that points to an optional type + */ + val isRuntimeOptional: Boolean + + /** + * The runtime range constraints placed on this type, if any. + * Will differ from [rangeConstraint] if this is an alias + */ + val runtimeRangeConstraint: Constraint.Range? + + /** + * The runtime size constraints placed on this type, if any. + * Will differ from [sizeConstraint] if this is an alias + */ + val runtimeSizeConstraint: Constraint.Size? + + /** + * The runtime pattern constraints placed on this type, if any. + * Will differ from [patternConstraint] if this is an alias + */ + val runtimePatternConstraint: Constraint.Pattern? + + /** + * The runtime value constraints placed on this type, if any. + * Will differ from [valueConstraint] if this is an alias + */ + val runtimeValueConstraint: Constraint.Value? +} + +interface Constraint { + interface Range : Constraint { + val lowerBound: Number? + val upperBound: Number? + } + + interface Size : Constraint { + val lowerBound: Long? + val upperBound: Long? + } + + interface Pattern : Constraint { + val pattern: String + } + + interface Value : Constraint { + val value: Any + } +} diff --git a/public-api/src/main/kotlin/tools/samt/api/types/UserTypes.kt b/public-api/src/main/kotlin/tools/samt/api/types/UserTypes.kt new file mode 100644 index 00000000..80a5b7e9 --- /dev/null +++ b/public-api/src/main/kotlin/tools/samt/api/types/UserTypes.kt @@ -0,0 +1,155 @@ +package tools.samt.api.types + +import tools.samt.api.plugin.TransportConfiguration + +interface UserType : Type { + val name: String + val qualifiedName: String +} + +interface AliasType : UserType { + /** + * The type this alias stands for, could be another alias + */ + val aliasedType: TypeReference + + /** + * The fully resolved type, will not contain any type aliases anymore, just the underlying merged type + */ + val fullyResolvedType: TypeReference +} + +/** + * A SAMT record + */ +interface RecordType : UserType { + val fields: List +} + +/** + * A field in a record + */ +interface RecordField { + val name: String + val type: TypeReference +} + +/** + * A SAMT enum + */ +interface EnumType : UserType { + val values: List +} + +/** + * A SAMT service + */ +interface ServiceType : UserType { + val operations: List +} + +/** + * An operation in a service + */ +interface ServiceOperation { + val name: String + val parameters: List +} + +/** + * A parameter in a service operation + */ +interface ServiceOperationParameter { + val name: String + val type: TypeReference +} + +/** + * A service operation that returns a response + */ +interface RequestResponseOperation : ServiceOperation { + /** + * The return type of this operation. + * If null, this operation returns nothing. + */ + val returnType: TypeReference? + + /** + * Is true if this operation is asynchronous. + * This could mean that the operation returns a future in Java, or a Promise in JavaScript. + */ + val isAsync: Boolean +} + +/** + * A service operation that is fire-and-forget, never returning a response + */ +interface OnewayOperation : ServiceOperation + +/** + * A SAMT provider + */ +interface ProviderType : UserType { + val implements: List + val transport: TransportConfiguration +} + +/** + * Connects a provider to a service + */ +interface ProvidedService { + /** + * The underlying service this provider implements + */ + val service: ServiceType + + /** + * The operations that are implemented by this provider + */ + val implementedOperations: List + + /** + * The operations that are not implemented by this provider + */ + val unimplementedOperations: List +} + +/** + * A SAMT consumer + */ +interface ConsumerType : Type { + /** + * The provider this consumer is connected to + */ + val provider: ProviderType + + /** + * The services this consumer uses + */ + val uses: List + + /** + * The package this consumer is located in + */ + val samtPackage: String +} + +/** + * Connects a consumer to a service + */ +interface ConsumedService { + /** + * The underlying service this consumer uses + */ + val service: ServiceType + + /** + * The operations that are consumed by this consumer + */ + val consumedOperations: List + + /** + * The operations that are not consumed by this consumer + */ + val unconsumedOperations: List +} diff --git a/samt-config/src/main/kotlin/tools/samt/config/SamtConfigurationParser.kt b/samt-config/src/main/kotlin/tools/samt/config/SamtConfigurationParser.kt index 7f3ec9c9..871877bb 100644 --- a/samt-config/src/main/kotlin/tools/samt/config/SamtConfigurationParser.kt +++ b/samt-config/src/main/kotlin/tools/samt/config/SamtConfigurationParser.kt @@ -37,12 +37,14 @@ object SamtConfigurationParser { SamtConfiguration() } + val projectDirectory = path.parent + return CommonSamtConfiguration( - source = parsedConfiguration.source, + source = projectDirectory.resolve(parsedConfiguration.source).normalize(), plugins = parsedConfiguration.plugins.map { plugin -> when (plugin) { is SamtLocalPluginConfiguration -> CommonLocalPluginConfiguration( - path = plugin.path + path = projectDirectory.resolve(plugin.path).normalize() ) is SamtMavenPluginConfiguration -> CommonMavenPluginConfiguration( @@ -63,7 +65,7 @@ object SamtConfigurationParser { generators = parsedConfiguration.generators.map { generator -> CommonGeneratorConfiguration( name = generator.name, - output = generator.output, + output = projectDirectory.resolve(generator.output).normalize(), options = generator.options ) } diff --git a/samt-config/src/test/kotlin/tools/samt/config/SamtConfigurationParserTest.kt b/samt-config/src/test/kotlin/tools/samt/config/SamtConfigurationParserTest.kt index 517d2b4c..5abfbf77 100644 --- a/samt-config/src/test/kotlin/tools/samt/config/SamtConfigurationParserTest.kt +++ b/samt-config/src/test/kotlin/tools/samt/config/SamtConfigurationParserTest.kt @@ -29,10 +29,10 @@ class SamtConfigurationParserTest { assertEquals( tools.samt.common.SamtConfiguration( - source = "./some/other/src", + source = testDirectory.resolve("some/other/src"), plugins = listOf( tools.samt.common.SamtLocalPluginConfiguration( - path = "./path/to/plugin.jar" + path = testDirectory.resolve("path/to/plugin.jar") ), tools.samt.common.SamtMavenPluginConfiguration( groupId = "com.example", @@ -50,7 +50,7 @@ class SamtConfigurationParserTest { generators = listOf( tools.samt.common.SamtGeneratorConfiguration( name = "samt-kotlin-ktor", - output = "./some/other/out", + output = testDirectory.resolve("some/other/out"), options = mapOf( "removePrefixFromSamtPackage" to "tools.samt", "addPrefixToKotlinPackage" to "tools.samt.example.generated", @@ -67,12 +67,12 @@ class SamtConfigurationParserTest { assertEquals( tools.samt.common.SamtConfiguration( - source = "./src", + source = testDirectory.resolve("src"), plugins = emptyList(), generators = listOf( tools.samt.common.SamtGeneratorConfiguration( name = "samt-kotlin-ktor", - output = "./out", + output = testDirectory.resolve("out"), options = mapOf( "addPrefixToKotlinPackage" to "com.company.samt.generated", ) diff --git a/semantic/src/main/kotlin/tools/samt/semantic/ConstraintBuilder.kt b/semantic/src/main/kotlin/tools/samt/semantic/ConstraintBuilder.kt index 92877ecd..e750bef9 100644 --- a/semantic/src/main/kotlin/tools/samt/semantic/ConstraintBuilder.kt +++ b/semantic/src/main/kotlin/tools/samt/semantic/ConstraintBuilder.kt @@ -12,7 +12,7 @@ internal class ConstraintBuilder(private val controller: DiagnosticController) { is NumberNode -> expressionNode.value is WildcardNode -> null else -> { - controller.getOrCreateContext(expressionNode.location.source).error { + expressionNode.reportError(controller) { message("Range constraint argument must be a valid number range") highlight("neither a number nor '*'", expressionNode.location) help("A valid constraint would be range(1..10.5) or range(1..*)") @@ -25,7 +25,7 @@ internal class ConstraintBuilder(private val controller: DiagnosticController) { val higher = resolveSide(argument.right) if (lower == null && higher == null) { - controller.getOrCreateContext(argument.location.source).error { + argument.reportError(controller) { message("Range constraint must have at least one valid number") highlight("invalid constraint", argument.location) help("A valid constraint would be range(1..10.5) or range(1..*)") @@ -36,11 +36,10 @@ internal class ConstraintBuilder(private val controller: DiagnosticController) { if (lower is Double && higher is Double && lower > higher || lower is Long && higher is Long && lower > higher ) { - controller.getOrCreateContext(argument.location.source) - .error { - message("Range constraint must have a lower bound lower than the upper bound") - highlight("invalid constraint", argument.location) - } + argument.reportError(controller) { + message("Range constraint must have a lower bound lower than the upper bound") + highlight("invalid constraint", argument.location) + } return null } @@ -60,7 +59,7 @@ internal class ConstraintBuilder(private val controller: DiagnosticController) { is WildcardNode -> null else -> { - controller.getOrCreateContext(expressionNode.location.source).error { + expressionNode.reportError(controller) { message("Expected size constraint argument to be a whole number or wildcard") highlight("expected whole number or wildcard '*'", expressionNode.location) help("A valid constraint would be size(1..10), size(1..*) or size(*..10)") @@ -73,7 +72,7 @@ internal class ConstraintBuilder(private val controller: DiagnosticController) { val higher = resolveSide(argument.right) if (lower == null && higher == null) { - controller.getOrCreateContext(argument.location.source).error { + argument.reportError(controller) { message("Constraint parameters cannot both be wildcards") highlight("invalid constraint", argument.location) help("A valid constraint would be range(1..10.5) or range(1..*)") @@ -82,7 +81,7 @@ internal class ConstraintBuilder(private val controller: DiagnosticController) { } if (lower != null && higher != null && lower > higher) { - controller.getOrCreateContext(argument.location.source).error { + argument.reportError(controller) { message("Size constraint lower bound must be lower than or equal to the upper bound") highlight("invalid constraint", argument.location) } @@ -113,7 +112,7 @@ internal class ConstraintBuilder(private val controller: DiagnosticController) { is NumberNode -> ResolvedTypeReference.Constraint.Value(expression, argument.value) is BooleanNode -> ResolvedTypeReference.Constraint.Value(expression, argument.value) else -> { - controller.getOrCreateContext(argument.location.source).error { + argument.reportError(controller) { message("Value constraint must be a string, integer, float or boolean") highlight("invalid constraint", argument.location) help("A valid constraint would be value(\"foo\"), value(42) or value(false)") @@ -136,7 +135,7 @@ internal class ConstraintBuilder(private val controller: DiagnosticController) { when (name) { "range" -> { if (expression.arguments.size != 1 || expression.arguments.firstOrNull() !is RangeExpressionNode) { - controller.getOrCreateContext(expression.location.source).error { + expression.reportError(controller) { message("Range constraint must have exactly one range argument") highlight("invalid constraint", expression.location) help("A valid constraint would be range(1..10.5)") @@ -148,7 +147,7 @@ internal class ConstraintBuilder(private val controller: DiagnosticController) { "size" -> { if (expression.arguments.size != 1 || expression.arguments.firstOrNull() !is RangeExpressionNode) { - controller.getOrCreateContext(expression.location.source).error { + expression.reportError(controller) { message("Size constraint must have exactly one size argument") highlight("invalid constraint", expression.location) help("A valid constraint would be size(1..10)") @@ -163,7 +162,7 @@ internal class ConstraintBuilder(private val controller: DiagnosticController) { "pattern" -> { if (expression.arguments.size != 1 || expression.arguments.firstOrNull() !is StringNode) { - controller.getOrCreateContext(expression.location.source).error { + expression.reportError(controller) { message("Pattern constraint must have exactly one string argument") highlight("invalid constraint", expression.location) help("A valid constraint would be pattern(\"a-z\")") @@ -175,7 +174,7 @@ internal class ConstraintBuilder(private val controller: DiagnosticController) { "value" -> { if (expression.arguments.size != 1) { - controller.getOrCreateContext(expression.location.source).error { + expression.reportError(controller) { message("value constraint must have exactly one argument") highlight("invalid constraint", expression.location) } @@ -185,7 +184,7 @@ internal class ConstraintBuilder(private val controller: DiagnosticController) { } is String -> { - controller.getOrCreateContext(expression.location.source).error { + expression.reportError(controller) { message("Constraint with name '${name}' does not exist") highlight("unknown constraint", expression.base.location) help("A valid constraint would be range(1..10.5), size(1..10), pattern(\"a-z\") or value(\"foo\")") @@ -220,7 +219,7 @@ internal class ConstraintBuilder(private val controller: DiagnosticController) { is StringNode -> return createPattern(expression = expression, argument = expression) else -> Unit } - controller.getOrCreateContext(expression.location.source).error { + expression.reportError(controller) { message("Invalid constraint") highlight("invalid constraint", expression.location) } @@ -247,7 +246,7 @@ internal class ConstraintBuilder(private val controller: DiagnosticController) { return if (validateConstraintMatches(constraint, baseType)) { constraint } else { - controller.getOrCreateContext(expression.location.source).error { + expression.reportError(controller) { message("Constraint '${constraint.humanReadableName}' is not allowed for type '${baseType.humanReadableName}'") highlight("illegal constraint", expression.location) diff --git a/semantic/src/main/kotlin/tools/samt/semantic/Package.kt b/semantic/src/main/kotlin/tools/samt/semantic/Package.kt index ce37db5f..bb497a4c 100644 --- a/semantic/src/main/kotlin/tools/samt/semantic/Package.kt +++ b/semantic/src/main/kotlin/tools/samt/semantic/Package.kt @@ -2,7 +2,7 @@ package tools.samt.semantic import tools.samt.parser.* -class Package(val name: String) { +class Package(val name: String, private val parent: Package?) { val subPackages: MutableList = mutableListOf() val records: MutableList = mutableListOf() @@ -45,43 +45,59 @@ class Package(val name: String) { } operator fun plusAssign(record: RecordType) { + require(!isRootPackage) records.add(record) types[record.name] = record - typeByNode[record.declaration] = record + linkType(record.declaration, record) } operator fun plusAssign(enum: EnumType) { + require(!isRootPackage) enums.add(enum) types[enum.name] = enum - typeByNode[enum.declaration] = enum + linkType(enum.declaration, enum) } operator fun plusAssign(service: ServiceType) { + require(!isRootPackage) services.add(service) types[service.name] = service - typeByNode[service.declaration] = service + linkType(service.declaration, service) } operator fun plusAssign(provider: ProviderType) { + require(!isRootPackage) providers.add(provider) types[provider.name] = provider - typeByNode[provider.declaration] = provider + linkType(provider.declaration, provider) } operator fun plusAssign(consumer: ConsumerType) { + require(!isRootPackage) consumers.add(consumer) - typeByNode[consumer.declaration] = consumer + linkType(consumer.declaration, consumer) } operator fun plusAssign(alias: AliasType) { + require(!isRootPackage) aliases.add(alias) types[alias.name] = alias - typeByNode[alias.declaration] = alias + linkType(alias.declaration, alias) } operator fun contains(identifier: IdentifierNode): Boolean = types.containsKey(identifier.name) + val isRootPackage: Boolean + get() = parent == null + val allSubPackages: List get() = subPackages + subPackages.flatMap { it.allSubPackages } + + val nameComponents: List + get() = if (isRootPackage) { + emptyList() // root package + } else { + parent!!.nameComponents + name + } } diff --git a/semantic/src/main/kotlin/tools/samt/semantic/SemanticModel.kt b/semantic/src/main/kotlin/tools/samt/semantic/SemanticModel.kt index 4c6df4b2..10be0c11 100644 --- a/semantic/src/main/kotlin/tools/samt/semantic/SemanticModel.kt +++ b/semantic/src/main/kotlin/tools/samt/semantic/SemanticModel.kt @@ -2,14 +2,11 @@ package tools.samt.semantic import tools.samt.common.DiagnosticController import tools.samt.common.SourceFile -import tools.samt.parser.FileNode -import tools.samt.parser.NamedDeclarationNode -import tools.samt.parser.TypeImportNode -import tools.samt.parser.WildcardImportNode +import tools.samt.parser.* class SemanticModel( - val global: Package, - val userMetadata: UserMetadata, + val global: Package, + val userMetadata: UserMetadata, ) { companion object { fun build(files: List, controller: DiagnosticController): SemanticModel { @@ -26,11 +23,11 @@ class SemanticModel( * - Resolve all references to types * - Resolve all references to their declarations in the AST * */ -internal class SemanticModelBuilder ( +internal class SemanticModelBuilder( private val files: List, private val controller: DiagnosticController, ) { - private val global = Package(name = "") + private val global = Package(name = "root", null) private val preProcessor = SemanticModelPreProcessor(controller) private val postProcessor = SemanticModelPostProcessor(controller) private val referenceResolver = SemanticModelReferenceResolver(controller, global) @@ -77,15 +74,16 @@ internal class SemanticModelBuilder ( check(typeReference is ResolvedTypeReference) fun merge(base: ResolvedTypeReference, inner: ResolvedTypeReference): ResolvedTypeReference { if (base.isOptional && inner.isOptional) { - controller.getOrCreateContext(base.fullNode.location.source).warn { + base.fullNode.reportWarning(controller) { message("Type is already optional, ignoring '?'") highlight("duplicate optional", base.fullNode.location) highlight("declared optional here", inner.fullNode.location) } } - val overlappingConstraints = base.constraints.filter { baseConstraint -> inner.constraints.any { innerConstraint -> baseConstraint::class == innerConstraint::class } } + val overlappingConstraints = + base.constraints.filter { baseConstraint -> inner.constraints.any { innerConstraint -> baseConstraint::class == innerConstraint::class } } for (overlappingConstraint in overlappingConstraints) { - controller.getOrCreateContext(base.fullNode.location.source).error { + base.fullNode.reportError(controller) { message("Cannot have multiple constraints of the same type") val baseConstraint = base.constraints.first { it::class == overlappingConstraint::class } highlight("duplicate constraint", baseConstraint.node.location) @@ -117,6 +115,7 @@ internal class SemanticModelBuilder ( null } } + is MapType -> { val keyType = getFullyResolvedType(type.keyType) val valueType = getFullyResolvedType(type.valueType) @@ -126,13 +125,15 @@ internal class SemanticModelBuilder ( null } } + is PackageType -> { - controller.getOrCreateContext(typeReference.typeNode.location.source).error { + typeReference.typeNode.reportError(controller) { message("Type alias cannot reference package") highlight("illegal package", typeReference.typeNode.location) } typeReference } + is ConsumerType -> error("Consumer type cannot be referenced by name, this should never happen") } } @@ -196,7 +197,7 @@ internal class SemanticModelBuilder ( file.imports.forEach { import -> fun addImportedType(name: String, type: Type) { putIfAbsent(name, type)?.let { existingType -> - controller.getOrCreateContext(file.sourceFile).error { + file.reportError(controller) { message("Import '$name' conflicts with locally defined type with same name") highlight("conflicting import", import.location) if (existingType is UserDeclared) { @@ -232,7 +233,7 @@ internal class SemanticModelBuilder ( addImportedType(name, type) } } else { - controller.getOrCreateContext(file.sourceFile).error { + file.reportError(controller) { message("Import '${import.name.name}.*' must point to a package and not a type") highlight( "illegal wildcard import", import.location, suggestChange = "import ${ @@ -252,7 +253,7 @@ internal class SemanticModelBuilder ( // Add built-in types fun addBuiltIn(name: String, type: Type) { putIfAbsent(name, type)?.let { existingType -> - controller.getOrCreateContext(file.sourceFile).error { + file.reportError(controller) { message("Type '$name' shadows built-in type with same name") if (existingType is UserDeclared) { val definition = existingType.declaration diff --git a/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelAnnotationProcessor.kt b/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelAnnotationProcessor.kt index 7a25fca5..cec0dc0a 100644 --- a/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelAnnotationProcessor.kt +++ b/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelAnnotationProcessor.kt @@ -3,6 +3,7 @@ package tools.samt.semantic import tools.samt.common.DiagnosticController import tools.samt.parser.AnnotationNode import tools.samt.parser.StringNode +import tools.samt.parser.reportError internal class SemanticModelAnnotationProcessor( private val controller: DiagnosticController @@ -12,11 +13,10 @@ internal class SemanticModelAnnotationProcessor( val deprecations = mutableMapOf() for (element in global.getAnnotatedElements()) { for (annotation in element.annotations) { - val context = controller.getOrCreateContext(annotation.location.source) when (val name = annotation.name.name) { "Description" -> { if (element in descriptions) { - context.error { + annotation.reportError(controller) { message("Duplicate @Description annotation") highlight("duplicate annotation", annotation.location) highlight("previous annotation", element.annotations.first { it.name.name == "Description" }.location) @@ -26,7 +26,7 @@ internal class SemanticModelAnnotationProcessor( } "Deprecated" -> { if (element in deprecations) { - context.error { + annotation.reportError(controller) { message("Duplicate @Deprecated annotation") highlight("duplicate annotation", annotation.location) highlight("previous annotation", element.annotations.first { it.name.name == "Deprecated" }.location) @@ -35,7 +35,7 @@ internal class SemanticModelAnnotationProcessor( deprecations[element] = getDeprecation(annotation) } else -> { - context.error { + annotation.reportError(controller) { message("Unknown annotation @${name}, allowed annotations are @Description and @Deprecated") highlight("invalid annotation", annotation.location) } @@ -62,9 +62,8 @@ internal class SemanticModelAnnotationProcessor( private fun getDescription(annotation: AnnotationNode): String { check(annotation.name.name == "Description") val arguments = annotation.arguments - val context = controller.getOrCreateContext(annotation.location.source) if (arguments.isEmpty()) { - context.error { + annotation.reportError(controller) { message("Missing argument for @Description") highlight("invalid annotation", annotation.location) } @@ -72,7 +71,7 @@ internal class SemanticModelAnnotationProcessor( } if (arguments.size > 1) { val errorLocation = arguments[1].location.copy(end = arguments.last().location.end) - context.error { + annotation.reportError(controller) { message("@Description expects exactly one string argument") highlight("extraneous arguments", errorLocation) } @@ -80,7 +79,7 @@ internal class SemanticModelAnnotationProcessor( return when (val description = arguments.first()) { is StringNode -> description.value else -> { - context.error { + annotation.reportError(controller) { message("Argument for @Description must be a string") highlight("invalid argument type", description.location) } @@ -91,17 +90,16 @@ internal class SemanticModelAnnotationProcessor( private fun getDeprecation(annotation: AnnotationNode): UserMetadata.Deprecation { check(annotation.name.name == "Deprecated") - val context = controller.getOrCreateContext(annotation.location.source) val description = annotation.arguments.firstOrNull() if (description != null && description !is StringNode) { - context.error { + annotation.reportError(controller) { message("Argument for @Deprecated must be a string") highlight("invalid argument type", description.location) } } if (annotation.arguments.size > 1) { val errorLocation = annotation.arguments[1].location.copy(end = annotation.arguments.last().location.end) - context.error { + annotation.reportError(controller) { message("@Deprecated expects at most one string argument") highlight("extraneous arguments", errorLocation) } diff --git a/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelPostProcessor.kt b/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelPostProcessor.kt index 451b881d..0422a777 100644 --- a/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelPostProcessor.kt +++ b/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelPostProcessor.kt @@ -2,6 +2,8 @@ package tools.samt.semantic import tools.samt.common.DiagnosticController import tools.samt.common.Location +import tools.samt.parser.reportError +import tools.samt.parser.reportWarning internal class SemanticModelPostProcessor(private val controller: DiagnosticController) { /** @@ -27,7 +29,7 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont check(typeReference is ResolvedTypeReference) when (val type = typeReference.type) { is ServiceType -> { - controller.getOrCreateContext(typeReference.typeNode.location.source).error { + typeReference.typeNode.reportError(controller) { // error message applies to both record fields and return types message("Cannot use service '${type.name}' as type") highlight("service type not allowed here", typeReference.typeNode.location) @@ -35,14 +37,14 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont } is ProviderType -> { - controller.getOrCreateContext(typeReference.typeNode.location.source).error { + typeReference.typeNode.reportError(controller) { message("Cannot use provider '${type.name}' as type") highlight("provider type not allowed here", typeReference.typeNode.location) } } is PackageType -> { - controller.getOrCreateContext(typeReference.typeNode.location.source).error { + typeReference.typeNode.reportError(controller) { message("Cannot use package '${type.packageName}' as type") highlight("package type not allowed here", typeReference.typeNode.location) } @@ -61,7 +63,7 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont val underlyingTypeReference = type.fullyResolvedType ?: return val underlyingType = underlyingTypeReference.type if (underlyingType is ServiceType || underlyingType is ProviderType || underlyingType is PackageType) { - controller.getOrCreateContext(typeReference.typeNode.location.source).error { + typeReference.typeNode.reportError(controller) { message("Type alias refers to '${underlyingType.humanReadableName}', which cannot be used in this context") highlight("type alias", typeReference.typeNode.location) highlight("underlying type", underlyingTypeReference.typeNode.location) @@ -69,7 +71,7 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont } if (typeReference.isOptional && underlyingTypeReference.isOptional) { - controller.getOrCreateContext(typeReference.typeNode.location.source).warn { + typeReference.typeNode.reportWarning(controller) { message("Type alias refers to type which is already optional, ignoring '?'") highlight("duplicate optional", typeReference.fullNode.location) highlight("declared optional here", underlyingTypeReference.fullNode.location) @@ -96,7 +98,7 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont if (aliasedType is ServiceType) { block(aliasedType) } else { - controller.getOrCreateContext(typeReference.typeNode.location.source).error { + typeReference.typeNode.reportError(controller) { message("Expected a service but type alias '${type.name}' points to '${aliasedType.humanReadableName}'") highlight("type alias", typeReference.typeNode.location) highlight("underlying type", aliasedTypeReference.typeNode.location) @@ -106,7 +108,7 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont is UnknownType -> Unit else -> { - controller.getOrCreateContext(typeReference.typeNode.location.source).error { + typeReference.typeNode.reportError(controller) { message("Expected a service but got '${type.humanReadableName}'") highlight("illegal type", typeReference.typeNode.location) } @@ -129,7 +131,7 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont if (aliasedType is ProviderType) { block(aliasedType) } else { - controller.getOrCreateContext(typeReference.typeNode.location.source).error { + typeReference.typeNode.reportError(controller) { message("Expected a provider but type alias '${type.name}' points to '${aliasedType.humanReadableName}'") highlight("type alias", typeReference.typeNode.location) highlight("underlying type", aliasedTypeReference.typeNode.location) @@ -139,7 +141,7 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont is UnknownType -> Unit else -> { - controller.getOrCreateContext(typeReference.typeNode.location.source).error { + typeReference.typeNode.reportError(controller) { message("Expected a provider but got '${type.humanReadableName}'") highlight("illegal type", typeReference.typeNode.location) } @@ -152,7 +154,7 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont var isBlank = true if (typeReference.constraints.isNotEmpty()) { isBlank = false - controller.getOrCreateContext(typeReference.fullNode.location.source).error { + typeReference.fullNode.reportError(controller) { message("Cannot have constraints on $what") for (constraint in typeReference.constraints) { highlight("illegal constraint", constraint.node.location) @@ -161,7 +163,7 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont } if (typeReference.isOptional) { isBlank = false - controller.getOrCreateContext(typeReference.fullNode.location.source).error { + typeReference.fullNode.reportError(controller) { message("Cannot have optional $what") highlight("illegal optional", typeReference.fullNode.location) } @@ -190,7 +192,7 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont provider.implements.forEach { implements -> checkServiceType(implements.service) { type -> implementsTypes.putIfAbsent(type, implements.node.location)?.let { existingLocation -> - controller.getOrCreateContext(implements.node.location.source).error { + implements.node.reportError(controller) { message("Service '${type.name}' already implemented") highlight("duplicate implements", implements.node.location) highlight("previous implements", existingLocation) @@ -206,7 +208,7 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont if (matchingOperation != null) { matchingOperation } else { - controller.getOrCreateContext(provider.declaration.location.source).error { + provider.declaration.reportError(controller) { message("Operation '${serviceOperationName.name}' not found in service '${type.name}'") highlight("unknown operation", serviceOperationName.location) } @@ -224,7 +226,7 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont consumer.uses.forEach { uses -> checkServiceType(uses.service) { type -> usesTypes.putIfAbsent(type, uses.node.location)?.let { existingLocation -> - controller.getOrCreateContext(uses.node.location.source).error { + uses.node.reportError(controller) { message("Service '${type.name}' already used") highlight("duplicate uses", uses.node.location) highlight("previous uses", existingLocation) @@ -235,7 +237,7 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont val matchingImplements = providerType.implements.find { (it.service as ResolvedTypeReference).type == type } if (matchingImplements == null) { - controller.getOrCreateContext(uses.node.location.source).error { + uses.node.reportError(controller) { message("Service '${type.name}' is not implemented by provider '${providerType.name}'") highlight("unavailable service", uses.node.serviceName.location) } @@ -251,12 +253,12 @@ internal class SemanticModelPostProcessor(private val controller: DiagnosticCont matchingOperation } else { if (type.operations.any { it.name == serviceOperationName.name }) { - controller.getOrCreateContext(uses.node.location.source).error { + uses.node.reportError(controller) { message("Operation '${serviceOperationName.name}' in service '${type.name}' is not implemented by provider '${providerType.name}'") highlight("unavailable operation", serviceOperationName.location) } } else { - controller.getOrCreateContext(uses.node.location.source).error { + uses.node.reportError(controller) { message("Operation '${serviceOperationName.name}' not found in service '${type.name}'") highlight("unknown operation", serviceOperationName.location) } diff --git a/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelPreProcessor.kt b/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelPreProcessor.kt index b9710b76..2c9db0a7 100644 --- a/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelPreProcessor.kt +++ b/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelPreProcessor.kt @@ -12,7 +12,7 @@ internal class SemanticModelPreProcessor(private val controller: DiagnosticContr ) { if (statement.name in parentPackage) { val existingType = parentPackage.types.getValue(statement.name.name) - controller.getOrCreateContext(statement.location.source).error { + statement.reportError(controller) { message("'${statement.name.name}' is already declared") highlight("duplicate declaration", statement.name.location) if (existingType is UserDeclared) { @@ -32,7 +32,7 @@ internal class SemanticModelPreProcessor(private val controller: DiagnosticContr val name = identifierGetter(item).name val existingLocation = existingItems.putIfAbsent(name, item.location) if (existingLocation != null) { - controller.getOrCreateContext(item.location.source).error { + item.reportError(controller) { message("$what '$name' is defined more than once") highlight("duplicate declaration", identifierGetter(item).location) highlight("previous declaration", existingLocation) @@ -47,7 +47,7 @@ internal class SemanticModelPreProcessor(private val controller: DiagnosticContr for (component in file.packageDeclaration.name.components) { var subPackage = parentPackage.subPackages.find { it.name == component.name } if (subPackage == null) { - subPackage = Package(component.name) + subPackage = Package(component.name, parentPackage) parentPackage.subPackages.add(subPackage) } parentPackage = subPackage @@ -59,7 +59,7 @@ internal class SemanticModelPreProcessor(private val controller: DiagnosticContr reportDuplicateDeclaration(parentPackage, statement) reportDuplicates(statement.fields, "Record field") { it.name } if (statement.extends.isNotEmpty()) { - controller.getOrCreateContext(statement.location.source).error { + statement.reportError(controller) { message("Record extends are not yet supported") highlight("cannot extend other records", statement.extends.first().location) } @@ -73,7 +73,8 @@ internal class SemanticModelPreProcessor(private val controller: DiagnosticContr } parentPackage += RecordType( fields = fields, - declaration = statement + declaration = statement, + parentPackage = parentPackage, ) } @@ -81,7 +82,7 @@ internal class SemanticModelPreProcessor(private val controller: DiagnosticContr reportDuplicateDeclaration(parentPackage, statement) reportDuplicates(statement.values, "Enum value") { it } val values = statement.values.map { it.name } - parentPackage += EnumType(values, statement) + parentPackage += EnumType(values, statement, parentPackage) } is ServiceDeclarationNode -> { @@ -106,12 +107,6 @@ internal class SemanticModelPreProcessor(private val controller: DiagnosticContr } is RequestResponseOperationNode -> { - if (operation.isAsync) { - controller.getOrCreateContext(operation.location.source).error { - message("Async operations are not yet supported") - highlight("unsupported async operation", operation.location) - } - } ServiceType.RequestResponseOperation( name = operation.name.name, parameters = parameters, @@ -123,7 +118,7 @@ internal class SemanticModelPreProcessor(private val controller: DiagnosticContr } } } - parentPackage += ServiceType(operations, statement) + parentPackage += ServiceType(operations, statement, parentPackage) } is ProviderDeclarationNode -> { @@ -135,11 +130,12 @@ internal class SemanticModelPreProcessor(private val controller: DiagnosticContr implements ) } + val transport = ProviderType.Transport( - name = statement.transport.protocolName.name, + name = statement.transport.protocolName.name.lowercase(), configuration = statement.transport.configuration ) - parentPackage += ProviderType(implements, transport, statement) + parentPackage += ProviderType(implements, transport, statement, parentPackage) } is ConsumerDeclarationNode -> { @@ -152,7 +148,8 @@ internal class SemanticModelPreProcessor(private val controller: DiagnosticContr node = it ) }, - declaration = statement + declaration = statement, + parentPackage = parentPackage, ) } @@ -160,7 +157,8 @@ internal class SemanticModelPreProcessor(private val controller: DiagnosticContr reportDuplicateDeclaration(parentPackage, statement) parentPackage += AliasType( aliasedType = UnresolvedTypeReference(statement.type), - declaration = statement + declaration = statement, + parentPackage = parentPackage, ) } diff --git a/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelReferenceResolver.kt b/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelReferenceResolver.kt index a57e0363..9412ddf9 100644 --- a/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelReferenceResolver.kt +++ b/semantic/src/main/kotlin/tools/samt/semantic/SemanticModelReferenceResolver.kt @@ -28,7 +28,7 @@ internal class SemanticModelReferenceResolver( return ResolvedTypeReference(type, expression) } - controller.getOrCreateContext(expression.location.source).error { + expression.reportError(controller) { message("Type '${expression.name}' could not be resolved") highlight("unresolved type", expression.location) } @@ -54,14 +54,14 @@ internal class SemanticModelReferenceResolver( } null -> { - controller.getOrCreateContext(expression.location.source).error { + expression.reportError(controller) { message("Type '${expression.name}' could not be resolved") highlight("unresolved type", expression.location) } } else -> { - controller.getOrCreateContext(expression.location.source).error { + expression.reportError(controller) { message("Type '${expression.components.first().name}' is not a package, cannot access sub-types") highlight("not a package", expression.components.first().location) } @@ -73,14 +73,14 @@ internal class SemanticModelReferenceResolver( val baseType = resolveAndLinkExpression(scope, expression.base) val constraints = expression.arguments.mapNotNull { constraintBuilder.build(baseType.type, it) } if (baseType.constraints.isNotEmpty()) { - controller.getOrCreateContext(expression.location.source).error { + expression.reportError(controller) { message("Cannot have nested constraints") highlight("illegal nested constraint", expression.location) } } for (constraintInstances in constraints.groupBy { it::class }.values) { if (constraintInstances.size > 1) { - controller.getOrCreateContext(expression.location.source).error { + expression.reportError(controller) { message("Cannot have multiple constraints of the same type") highlight("first constraint", constraintInstances.first().node.location) for (duplicateConstraints in constraintInstances.drop(1)) { @@ -128,7 +128,7 @@ internal class SemanticModelReferenceResolver( } } } - controller.getOrCreateContext(expression.location.source).error { + expression.reportError(controller) { message("Unsupported generic type") highlight(expression.location) help("Valid generic types are List and Map") @@ -138,7 +138,7 @@ internal class SemanticModelReferenceResolver( is OptionalDeclarationNode -> { val baseType = resolveAndLinkExpression(scope, expression.base) if (baseType.isOptional) { - controller.getOrCreateContext(expression.location.source).warn { + expression.reportWarning(controller) { message("Type is already optional, ignoring '?'") highlight("already optional", expression.base.location) } @@ -149,7 +149,7 @@ internal class SemanticModelReferenceResolver( is BooleanNode, is NumberNode, is StringNode, - -> controller.getOrCreateContext(expression.location.source).error { + -> expression.reportError(controller) { message("Cannot use literal value as type") highlight("not a type expression", expression.location) } @@ -158,7 +158,7 @@ internal class SemanticModelReferenceResolver( is ArrayNode, is RangeExpressionNode, is WildcardNode, - -> controller.getOrCreateContext(expression.location.source).error { + -> expression.reportError(controller) { message("Invalid type expression") highlight("not a type expression", expression.location) } @@ -180,7 +180,7 @@ internal class SemanticModelReferenceResolver( } null -> { - controller.getOrCreateContext(component.location.source).error { + component.reportError(controller) { message("Could not resolve reference '${component.name}'") highlight("unresolved reference", component.location) } @@ -191,7 +191,7 @@ internal class SemanticModelReferenceResolver( if (iterator.hasNext()) { // We resolved a non-package type but there are still components left - controller.getOrCreateContext(component.location.source).error { + component.reportError(controller) { message("Type '${component.name}' is not a package, cannot access sub-types") highlight("must be a package", component.location) } diff --git a/semantic/src/main/kotlin/tools/samt/semantic/Types.kt b/semantic/src/main/kotlin/tools/samt/semantic/Types.kt index 50ed4d7b..5cbae41b 100644 --- a/semantic/src/main/kotlin/tools/samt/semantic/Types.kt +++ b/semantic/src/main/kotlin/tools/samt/semantic/Types.kt @@ -113,6 +113,7 @@ object DurationType : LiteralType { sealed interface UserDeclaredNamedType : UserDeclared, Type { override val humanReadableName: String get() = name override val declaration: NamedDeclarationNode + val parentPackage: Package val name: String get() = declaration.name.name } @@ -146,11 +147,13 @@ class AliasType( /** The fully resolved type, will not contain any type aliases anymore, just the underlying merged type */ var fullyResolvedType: ResolvedTypeReference? = null, override val declaration: TypeAliasNode, + override val parentPackage: Package, ) : UserDeclaredNamedType, UserAnnotated class RecordType( val fields: List, override val declaration: RecordDeclarationNode, + override val parentPackage: Package, ) : UserDeclaredNamedType, UserAnnotated { class Field( val name: String, @@ -162,11 +165,13 @@ class RecordType( class EnumType( val values: List, override val declaration: EnumDeclarationNode, + override val parentPackage: Package, ) : UserDeclaredNamedType, UserAnnotated class ServiceType( val operations: List, override val declaration: ServiceDeclarationNode, + override val parentPackage: Package, ) : UserDeclaredNamedType, UserAnnotated { sealed interface Operation : UserAnnotated { val name: String @@ -198,8 +203,9 @@ class ServiceType( class ProviderType( val implements: List, - @Suppress("unused") val transport: Transport, + val transport: Transport, override val declaration: ProviderDeclarationNode, + override val parentPackage: Package, ) : UserDeclaredNamedType { class Implements( var service: TypeReference, @@ -209,13 +215,14 @@ class ProviderType( class Transport( val name: String, - @Suppress("unused") val configuration: Any?, + val configuration: ObjectNode?, ) } class ConsumerType( var provider: TypeReference, var uses: List, + val parentPackage: Package, override val declaration: ConsumerDeclarationNode, ) : Type, UserDeclared { class Uses( diff --git a/semantic/src/test/kotlin/tools/samt/semantic/SemanticModelTest.kt b/semantic/src/test/kotlin/tools/samt/semantic/SemanticModelTest.kt index c0a4e012..dced9482 100644 --- a/semantic/src/test/kotlin/tools/samt/semantic/SemanticModelTest.kt +++ b/semantic/src/test/kotlin/tools/samt/semantic/SemanticModelTest.kt @@ -644,20 +644,6 @@ class SemanticModelTest { source to listOf("Error: Record extends are not yet supported") ) } - - @Test - fun `cannot use async operations`() { - val source = """ - package color - - service ColorService { - async get(): Int - } - """.trimIndent() - parseAndCheck( - source to listOf("Error: Async operations are not yet supported") - ) - } } @Nested @@ -821,7 +807,7 @@ class SemanticModelTest { package services service FooService { - foo() + oneway foo() } """.trimIndent() parseAndCheck( @@ -901,7 +887,7 @@ class SemanticModelTest { @Deprecated("service deprecation") service UserService { @Deprecated("operation deprecation") - get(@Deprecated("parameter deprecation") id: Id): User + async get(@Deprecated("parameter deprecation") id: Id): User } """.trimIndent() val model = parseAndCheck( diff --git a/settings.gradle.kts b/settings.gradle.kts index 9ebe2560..bbf7bfa5 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -7,6 +7,8 @@ include( ":semantic", ":language-server", ":samt-config", + ":codegen", + ":public-api", ) dependencyResolutionManagement { diff --git a/specification/examples/.gitignore b/specification/examples/.gitignore new file mode 100644 index 00000000..78d42947 --- /dev/null +++ b/specification/examples/.gitignore @@ -0,0 +1 @@ +*.kt diff --git a/specification/examples/.samtrc.yaml b/specification/examples/.samtrc.yaml new file mode 100644 index 00000000..c35faf33 --- /dev/null +++ b/specification/examples/.samtrc.yaml @@ -0,0 +1 @@ +extends: strict diff --git a/specification/examples/greeter.samt b/specification/examples/greeter.samt index 6f01a3f8..3c57a764 100644 --- a/specification/examples/greeter.samt +++ b/specification/examples/greeter.samt @@ -18,7 +18,7 @@ record GreetResponse { @Description("foo bar This is some very long comment which describes the service. - baz") + ") timestamp: DateTime } diff --git a/specification/examples/petstore/.samtrc.yaml b/specification/examples/petstore/.samtrc.yaml new file mode 100644 index 00000000..c35faf33 --- /dev/null +++ b/specification/examples/petstore/.samtrc.yaml @@ -0,0 +1 @@ +extends: strict diff --git a/specification/examples/petstore/samt.yaml b/specification/examples/petstore/samt.yaml new file mode 100644 index 00000000..68b29443 --- /dev/null +++ b/specification/examples/petstore/samt.yaml @@ -0,0 +1,2 @@ +generators: + - name: kotlin-ktor-provider diff --git a/specification/examples/petstore/src/common.samt b/specification/examples/petstore/src/common.samt new file mode 100644 index 00000000..0bc70507 --- /dev/null +++ b/specification/examples/petstore/src/common.samt @@ -0,0 +1,7 @@ +package org.openapitools.examples.petstore + +typealias ID = Long + +typealias UUID = String ( pattern("[0-9A-F]{8}-[0-9A-F]{4}-[0-9A-F]{4}-[0-9A-F]{4}-[0-9A-F]{12}") ) + +record NotFoundFault diff --git a/specification/examples/petstore/src/pet-provider.samt b/specification/examples/petstore/src/pet-provider.samt new file mode 100644 index 00000000..06904843 --- /dev/null +++ b/specification/examples/petstore/src/pet-provider.samt @@ -0,0 +1,21 @@ +package org.openapitools.examples.petstore + +provide PetEndpointHTTP { + implements PetService + + transport http { + serialization: "json", + operations: { + PetService: { + basePath: "/pet", + addPet: "POST /", + updatePet: "PUT /", + findPetsByStatus: "GET /findByStatus {status in query}", + findPetsByTags: "GET /findByTags {tags in query}", + getPetById: "GET /{petId}", + updatePetWithForm: "POST /{petId} {name, status in query}", + deletePet: "DELETE /petId}" + } + } + } +} diff --git a/specification/examples/petstore/src/pet.samt b/specification/examples/petstore/src/pet.samt new file mode 100644 index 00000000..b15a1bc6 --- /dev/null +++ b/specification/examples/petstore/src/pet.samt @@ -0,0 +1,43 @@ +package org.openapitools.examples.petstore + +record Pet { + id: ID? + name: String + category: Category? + photoUrls: List + tags: List? + status: PetStatus +} + +record Category { + id: ID? + name: String? +} + +record Tag { + id: ID? + name: String? +} + +record ApiResponse { + code: Int? + type: String? + message: String? +} + +enum PetStatus { + available, + pending, + sold +} + +service PetService { + addPet(newPet: Pet): Pet + updatePet(updatedPet: Pet): Pet + findPetsByStatus(status: PetStatus): List + findPetsByTags(tags: List): List + getPetById(petId: ID): Pet + updatePetWithForm(petId: ID, name: String?, status: PetStatus?): Pet + deletePet(petId: ID): Pet + uploadImage(petId: ID, additionalMetadata: String?, file: Bytes): ApiResponse +} diff --git a/specification/examples/petstore/src/store-provider.samt b/specification/examples/petstore/src/store-provider.samt new file mode 100644 index 00000000..6204bd3c --- /dev/null +++ b/specification/examples/petstore/src/store-provider.samt @@ -0,0 +1,18 @@ +package org.openapitools.examples.petstore + +provide StoreEndpointHTTP { + implements StoreService + + transport http { + serialization: "json", + operations: { + StoreService: { + basePath: "/store", + getInventory: "GET /inventory", + placeOrder: "POST /order", + getOrderById: "GET /order/{orderId}", + deleteOrder: "DELETE /order/{orderId}" + } + } + } +} diff --git a/specification/examples/petstore/src/store.samt b/specification/examples/petstore/src/store.samt new file mode 100644 index 00000000..90b70298 --- /dev/null +++ b/specification/examples/petstore/src/store.samt @@ -0,0 +1,23 @@ +package org.openapitools.examples.petstore + +record Order { + id: ID? + petId: ID? + quantity: Int? + shipDate: DateTime? + status: OrderStatus? + complete: Boolean? +} + +enum OrderStatus { + placed, + approved, + delivered +} + +service StoreService { + getInventory(): Map + placeOrder(order: Order): Order + getOrderById(orderId: ID): Order + deleteOrder(orderId: ID) +} diff --git a/specification/examples/petstore/src/user-provider.samt b/specification/examples/petstore/src/user-provider.samt new file mode 100644 index 00000000..dad6ace6 --- /dev/null +++ b/specification/examples/petstore/src/user-provider.samt @@ -0,0 +1,21 @@ +package org.openapitools.examples.petstore + +provide UserEndpointHTTP { + implements UserService + + transport http { + serialization: "json", + operations: { + UserService: { + basePath: "/user", + createUser: "POST /", + createUsers: "POST /createWithList", + login: "GET /login {username, password in query}", + logout: "GET /logout", + getUserByUsername: "GET /{username}", + updateUser: "PUT /{username}", + deleteUser: "DELETE /{username}" + } + } + } +} diff --git a/specification/examples/petstore/src/user.samt b/specification/examples/petstore/src/user.samt new file mode 100644 index 00000000..ce343784 --- /dev/null +++ b/specification/examples/petstore/src/user.samt @@ -0,0 +1,22 @@ +package org.openapitools.examples.petstore + +record User { + id: ID? + username: String? + firstName: String? + lastName: String? + email: String? + password: String? + phone: String? + userStatus: Int? +} + +service UserService { + createUser(user: User): User + createUsers(users: List): User + login(username: String, password: String): String + logout() + getUserByUsername(username: String): User + updateUser(username: String, user: User): User + deleteUser(username: String) +} diff --git a/specification/examples/samt.yaml b/specification/examples/samt.yaml new file mode 100644 index 00000000..67a45950 --- /dev/null +++ b/specification/examples/samt.yaml @@ -0,0 +1,5 @@ +source: ./todo-service + +generators: + - name: kotlin-ktor-consumer + output: ./out diff --git a/specification/examples/todo-service/common.samt b/specification/examples/todo-service/common.samt index a0a898a4..8be8dfce 100644 --- a/specification/examples/todo-service/common.samt +++ b/specification/examples/todo-service/common.samt @@ -2,5 +2,10 @@ package tools.samt.examples.common typealias UUID = String ( pattern("[0-9A-F]{8}-[0-9A-F]{4}-[0-9A-F]{4}-[0-9A-F]{4}-[0-9A-F]{12}") ) +/* record NotFoundFault extends Fault record MissingPermissionsFault extends Fault +*/ + +record NotFoundFault +record MissingPermissionsFault diff --git a/specification/examples/todo-service/todo-provider-http.samt b/specification/examples/todo-service/todo-provider-http.samt index 4a3a16fd..0f3677c9 100644 --- a/specification/examples/todo-service/todo-provider-http.samt +++ b/specification/examples/todo-service/todo-provider-http.samt @@ -4,38 +4,32 @@ provide TodoEndpointHTTP { implements TodoManager implements TodoListManager - transport HTTP { - serialization: "JSON", + transport http { + serialization: "json", operations: { TodoManager: { - createTodo: "POST /todo", - searchTodo: "GET /todo?title={title}", + createTodo: "POST /todo {cookie:session}", + searchTodo: "GET /todo {query:title}", getTodo: "GET /todo/{id}", getTodos: "GET /todo", updateTodo: "PUT /todo/{id}", deleteTodo: "DELETE /todo/{id}", - markAsCompleted: "PUT /todo/{id}/completed" + markAsCompleted: "PUT /todo/{id} {query:completed}" }, TodoListManager: { createTodoList: "POST /todo-list", - searchTodoList: "GET /todo-list?title={title}", + searchTodoList: "GET /todo-list {query:title}", getTodoList: "GET /todo-list/{id}", getTodoLists: "GET /todo-list", updateTodoList: "PUT /todo-list/{id}", deleteTodoList: "DELETE /todo-list/{id}", - addTodoToList: "PUT /todo-list/{id}/todo/{todoId}", - removeTodoFromList: "DELETE /todo-list/{id}/todo/{todoId}" + addTodoToList: "PUT /todo-list/{listId}/todo/{todoId}", + removeTodoFromList: "DELETE /todo-list/{listId}/todo/{todoId}" } }, faults: { - NotFoundFault: { - code: 404, - message: "Todo not found" - }, - MissingPermissionsFault: { - code: 403, - message: "Missing permissions" - } + NotFoundFault: 404, + MissingPermissionsFault: 403 } } } diff --git a/specification/examples/todo-service/todo-service.samt b/specification/examples/todo-service/todo-service.samt index e37b252a..0a50cd25 100644 --- a/specification/examples/todo-service/todo-service.samt +++ b/specification/examples/todo-service/todo-service.samt @@ -19,7 +19,7 @@ record TodoList { @Description("A service for managing todo items") service TodoManager { - createTodo(title: String, description: String): TodoItem + createTodo(title: String, description: String, session: String): TodoItem searchTodo(title: String): TodoItem? getTodo(id: UUID): TodoItem raises NotFoundFault getTodos(): List