diff --git a/server-core/src/main/kotlin/com/lightningkite/lightningserver/schema/LightningServerKSchemaGenerator.kt b/server-core/src/main/kotlin/com/lightningkite/lightningserver/schema/LightningServerKSchemaGenerator.kt index 927492c6..951b806f 100644 --- a/server-core/src/main/kotlin/com/lightningkite/lightningserver/schema/LightningServerKSchemaGenerator.kt +++ b/server-core/src/main/kotlin/com/lightningkite/lightningserver/schema/LightningServerKSchemaGenerator.kt @@ -21,10 +21,27 @@ val lightningServerKSchema: LightningServerKSchema by lazy { val registry = SerializationRegistry(Serialization.module).also { it.registerShared() } - Documentable.endpoints.flatMap { - sequenceOf(it.inputType, it.outputType) + it.route.path.serializers.asSequence() - }.forEach { - registry.registerVirtualDeep(it) + Documentable.endpoints.forEach { + try { + registry.registerVirtualDeep(it.inputType) + registry.registerVirtualDeep(it.outputType) + it.route.path.serializers.forEach { serializer -> + registry.registerVirtualDeep(serializer) + } + } catch(e: Exception) { + throw IllegalStateException("Failed to generate schema for endpoint ${it.route.endpoint}", e) + } + } + Documentable.websockets.forEach { + try { + registry.registerVirtualDeep(it.inputType) + registry.registerVirtualDeep(it.outputType) + it.path.serializers.forEach { serializer -> + registry.registerVirtualDeep(serializer) + } + } catch(e: Exception) { + throw IllegalStateException("Failed to generate schema for websocket ${it.path}", e) + } } LightningServerKSchema( baseUrl = generalSettings().publicUrl, diff --git a/shared/src/commonMain/kotlin/com/lightningkite/serialization/SerializationRegistry.kt b/shared/src/commonMain/kotlin/com/lightningkite/serialization/SerializationRegistry.kt index 11119456..a5aad429 100644 --- a/shared/src/commonMain/kotlin/com/lightningkite/serialization/SerializationRegistry.kt +++ b/shared/src/commonMain/kotlin/com/lightningkite/serialization/SerializationRegistry.kt @@ -72,10 +72,12 @@ class SerializationRegistry(val module: SerializersModule) { } fun register(serializer: KSerializer<*>) { +// println("$this Registered ${serializer.descriptor.serialName}") direct[serializer.descriptor.serialName] = serializer } fun register(name: String, make: (Array>) -> KSerializer<*>) { +// println("$this Registered $name") @Suppress("UNCHECKED_CAST") factory[name] = make as (Array>) -> KSerializer<*> } @@ -142,6 +144,13 @@ class SerializationRegistry(val module: SerializersModule) { it[1] ) } + register(serializer>().descriptor.serialName) { SetSerializer(it[0]) } + register(serializer>().descriptor.serialName) { + MapSerializer( + it[0], + it[1] + ) + } register( MapEntrySerializer( NothingSerializer(), @@ -188,11 +197,15 @@ class SerializationRegistry(val module: SerializersModule) { } fun registerVirtualDeep(type: KSerializer<*>) { - type.nullElement()?.let { return registerVirtualDeep(it) } - if(registerVirtual(type) != null) { - type.tryChildSerializers()?.forEach { registerVirtualDeep(it) } + try { + type.nullElement()?.let { return registerVirtualDeep(it) } + if (registerVirtual(type) != null) { + type.tryChildSerializers()?.forEach { registerVirtualDeep(it) } + } + type.tryTypeParameterSerializers3()?.forEach { registerVirtualDeep(it) } + } catch(e: Exception) { + throw Exception("Failed to register serializer for ${type.descriptor.serialName}", e) } - type.tryTypeParameterSerializers3()?.forEach { registerVirtualDeep(it) } } fun registerVirtual(type: KSerializer<*>): VirtualType? { type.nullElement()?.let { return registerVirtual(it) } diff --git a/shared/src/commonMain/kotlin/com/lightningkite/serialization/VirtualType.kt b/shared/src/commonMain/kotlin/com/lightningkite/serialization/VirtualType.kt index 88b1de78..1d6fc9d9 100644 --- a/shared/src/commonMain/kotlin/com/lightningkite/serialization/VirtualType.kt +++ b/shared/src/commonMain/kotlin/com/lightningkite/serialization/VirtualType.kt @@ -69,11 +69,19 @@ data class VirtualStruct( } val specifiedDefaults by lazy { fields.zip(serializers) { field, serializer -> - field.defaultJson?.let { DefaultDecoder.json.decodeFromString(serializer, it) } ?: DefaultNotPresent + field.defaultJson?.let { + return@zip DefaultDecoder.json.decodeFromString(serializer, it) + } + DefaultNotPresent } } val defaults by lazy { - serializers.map { it.default() } + fields.zip(serializers) { field, serializer -> + field.defaultJson?.let { + return@zip DefaultDecoder.json.decodeFromString(serializer, it) + } + serializer.default() + } } val defaultInstance by lazy { VirtualInstance(this, defaults) } val serializableProperties: Array> by lazy { @@ -81,6 +89,9 @@ data class VirtualStruct( SerializableProperty.FromVirtualField(it, registry, context) }.toTypedArray() } + val ensureNotNull = fields.withIndex().filter { + !it.value.optional && !it.value.type.isNullable + }.map { it.index }.toIntArray() @Transient override val descriptor: SerialDescriptor by lazy { @@ -96,7 +107,9 @@ data class VirtualStruct( } override fun deserialize(decoder: Decoder): VirtualInstance { - val values = Array(fields.size) { null } + val values = Array(fields.size) { + specifiedDefaults[it].takeUnless { it == DefaultNotPresent } + } val s = decoder.beginStructure(descriptor) while (true) { val index = s.decodeElementIndex(descriptor) @@ -120,6 +133,12 @@ data class VirtualStruct( } } s.endStructure(descriptor) + // Ensure we got everything + ensureNotNull.forEach { index -> + if(values[index] == null) { + throw SerializationException("${fields[index].name} required but was not present") + } + } return VirtualInstance(this, values.asList()) } diff --git a/shared/src/commonTest/kotlin/com/lightningkite/lightningdb/testing/VirtualTypesTest.kt b/shared/src/commonTest/kotlin/com/lightningkite/lightningdb/testing/VirtualTypesTest.kt index 955e1fc0..7660c5b0 100644 --- a/shared/src/commonTest/kotlin/com/lightningkite/lightningdb/testing/VirtualTypesTest.kt +++ b/shared/src/commonTest/kotlin/com/lightningkite/lightningdb/testing/VirtualTypesTest.kt @@ -9,6 +9,7 @@ import kotlinx.serialization.Serializable import kotlinx.serialization.builtins.serializer import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonBuilder import kotlin.test.Test import kotlin.test.assertEquals import kotlin.time.Duration @@ -20,16 +21,23 @@ class VirtualTypesTest { prepareModelsShared() prepareModelsSharedTest() } - fun testVirtualVersion(serializer: KSerializer, instance: T) { + fun testVirtualVersion(serializer: KSerializer, instance: T, builderAction: JsonBuilder.()->Unit = {}) { val virtualRegistry = SerializationRegistry.master.virtualize { it.contains("testing") } val vtype = virtualRegistry.virtualTypes[serializer.descriptor.serialName] as VirtualStruct val vtypeSerializer = virtualRegistry[serializer.descriptor.serialName, serializer.tryTypeParameterSerializers3() ?: arrayOf()] as VirtualStruct.Concrete println(vtypeSerializer.serializers) println(vtype.annotations) - val json = Json { serializersModule = ClientModule; encodeDefaults = true; allowStructuredMapKeys = true } + val json = Json { + serializersModule = ClientModule + encodeDefaults = true + allowStructuredMapKeys = true + builderAction() + } println("Schema: ${json.encodeToString(vtype)}") val original = instance - println(original) + // forward + json.decodeFromString(vtypeSerializer, json.encodeToString(serializer, original)) + val string = json.encodeToString(serializer, original) println(string) val vinst = json.decodeFromString(vtypeSerializer, string) @@ -37,13 +45,23 @@ class VirtualTypesTest { println(vinst) val newString = json.encodeToString(vtypeSerializer, vinst) println(newString) - assertEquals(string, newString) + assertEquals( + json.decodeFromString(serializer, string).toString().split(',').joinToString(",\n"), + json.decodeFromString(serializer, newString).toString().split(',').joinToString(",\n") + ) + assertEquals( + json.decodeFromString(serializer, string), + json.decodeFromString(serializer, newString) + ) - measureTime { - repeat(10000) { - json.encodeToString(vtypeSerializer, json.decodeFromString(vtypeSerializer, string)) - } - }.also { println("Performance: ${it / 10000}") } + // reverse + json.decodeFromString(serializer, json.encodeToString(vtypeSerializer, vinst)) + +// measureTime { +// repeat(10000) { +// json.encodeToString(vtypeSerializer, json.decodeFromString(vtypeSerializer, string)) +// } +// }.also { println("Performance: ${it / 10000}") } } @Test fun testSerializableAnnotation() { val serializer = LargeTestModel.serializer() @@ -57,8 +75,20 @@ class VirtualTypesTest { vtypeSerializer.serializableProperties.find { it.name == "string" }!!.serializableAnnotations ) } - @Test fun testStructure() = testVirtualVersion(LargeTestModel.serializer(), LargeTestModel()) - @Test fun testGeneric() = testVirtualVersion(GenericBox.serializer(Int.serializer()), GenericBox(value = 1, nullable = 2, list = listOf(3, 4))) + @Test fun testStructure() { + testVirtualVersion(LargeTestModel.serializer(), LargeTestModel()) + testVirtualVersion(LargeTestModel.serializer(), LargeTestModel()) { encodeDefaults = false } + } + @Test fun testGeneric() { + testVirtualVersion( + GenericBox.serializer(Int.serializer()), + GenericBox(value = 1, nullable = 2, list = listOf(3, 4)) + ) + testVirtualVersion( + GenericBox.serializer(Int.serializer()), + GenericBox(value = 1, nullable = 2, list = listOf(3, 4)) + ) + } // @Test fun testEnum() { // val vtype = SampleA.serializer().makeVirtualType() as VirtualEnum // val json = Json { serializersModule = ClientModule; encodeDefaults = true }