Skip to content

Commit

Permalink
Register HashMap and HashSet
Browse files Browse the repository at this point in the history
  • Loading branch information
UnknownJoe796 committed Nov 19, 2024
1 parent 813bbac commit 4cf8bde
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,27 @@ val lightningServerKSchema: LightningServerKSchema by lazy {
val registry = SerializationRegistry(Serialization.module).also {
it.registerShared()
}
Documentable.endpoints.flatMap {
sequenceOf(it.inputType, it.outputType) + it.route.path.serializers.asSequence()
}.forEach {
registry.registerVirtualDeep(it)
Documentable.endpoints.forEach {
try {
registry.registerVirtualDeep(it.inputType)
registry.registerVirtualDeep(it.outputType)
it.route.path.serializers.forEach { serializer ->
registry.registerVirtualDeep(serializer)
}
} catch(e: Exception) {
throw IllegalStateException("Failed to generate schema for endpoint ${it.route.endpoint}", e)
}
}
Documentable.websockets.forEach {
try {
registry.registerVirtualDeep(it.inputType)
registry.registerVirtualDeep(it.outputType)
it.path.serializers.forEach { serializer ->
registry.registerVirtualDeep(serializer)
}
} catch(e: Exception) {
throw IllegalStateException("Failed to generate schema for websocket ${it.path}", e)
}
}
LightningServerKSchema(
baseUrl = generalSettings().publicUrl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,12 @@ class SerializationRegistry(val module: SerializersModule) {
}

fun register(serializer: KSerializer<*>) {
// println("$this Registered ${serializer.descriptor.serialName}")
direct[serializer.descriptor.serialName] = serializer
}

fun register(name: String, make: (Array<KSerializer<Nothing>>) -> KSerializer<*>) {
// println("$this Registered $name")
@Suppress("UNCHECKED_CAST")
factory[name] = make as (Array<KSerializer<*>>) -> KSerializer<*>
}
Expand Down Expand Up @@ -142,6 +144,13 @@ class SerializationRegistry(val module: SerializersModule) {
it[1]
)
}
register(serializer<HashSet<Int>>().descriptor.serialName) { SetSerializer(it[0]) }
register(serializer<HashMap<String, Int>>().descriptor.serialName) {
MapSerializer(
it[0],
it[1]
)
}
register(
MapEntrySerializer(
NothingSerializer(),
Expand Down Expand Up @@ -188,11 +197,15 @@ class SerializationRegistry(val module: SerializersModule) {
}

fun registerVirtualDeep(type: KSerializer<*>) {
type.nullElement()?.let { return registerVirtualDeep(it) }
if(registerVirtual(type) != null) {
type.tryChildSerializers()?.forEach { registerVirtualDeep(it) }
try {
type.nullElement()?.let { return registerVirtualDeep(it) }
if (registerVirtual(type) != null) {
type.tryChildSerializers()?.forEach { registerVirtualDeep(it) }
}
type.tryTypeParameterSerializers3()?.forEach { registerVirtualDeep(it) }
} catch(e: Exception) {
throw Exception("Failed to register serializer for ${type.descriptor.serialName}", e)
}
type.tryTypeParameterSerializers3()?.forEach { registerVirtualDeep(it) }
}
fun registerVirtual(type: KSerializer<*>): VirtualType? {
type.nullElement()?.let { return registerVirtual(it) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ data class VirtualStruct(
}
val specifiedDefaults by lazy {
fields.zip(serializers) { field, serializer ->
field.defaultJson?.let { DefaultDecoder.json.decodeFromString(serializer, it) } ?: DefaultNotPresent
field.defaultJson?.let {
return@zip DefaultDecoder.json.decodeFromString(serializer, it)
}
DefaultNotPresent
}
}
val defaults by lazy {
Expand All @@ -81,6 +84,9 @@ data class VirtualStruct(
SerializableProperty.FromVirtualField(it, registry, context)
}.toTypedArray()
}
val ensureNotNull = fields.withIndex().filter {
!it.value.optional && !it.value.type.isNullable
}.map { it.index }.toIntArray()

@Transient
override val descriptor: SerialDescriptor by lazy {
Expand All @@ -96,7 +102,7 @@ data class VirtualStruct(
}

override fun deserialize(decoder: Decoder): VirtualInstance {
val values = Array<Any?>(fields.size) { null }
val values = Array<Any?>(fields.size) { specifiedDefaults[it] }
val s = decoder.beginStructure(descriptor)
while (true) {
val index = s.decodeElementIndex(descriptor)
Expand All @@ -120,6 +126,12 @@ data class VirtualStruct(
}
}
s.endStructure(descriptor)
// Ensure we got everything
ensureNotNull.forEach { index ->
if(values[index] == null) {
throw SerializationException("${fields[index].name} required but was not present")
}
}
return VirtualInstance(this, values.asList())
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import kotlinx.serialization.Serializable
import kotlinx.serialization.builtins.serializer
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonBuilder
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.time.Duration
Expand All @@ -20,30 +21,47 @@ class VirtualTypesTest {
prepareModelsShared()
prepareModelsSharedTest()
}
fun <T> testVirtualVersion(serializer: KSerializer<T>, instance: T) {
fun <T> testVirtualVersion(serializer: KSerializer<T>, instance: T, builderAction: JsonBuilder.()->Unit = {}) {
val virtualRegistry = SerializationRegistry.master.virtualize { it.contains("testing") }
val vtype = virtualRegistry.virtualTypes[serializer.descriptor.serialName] as VirtualStruct
val vtypeSerializer = virtualRegistry[serializer.descriptor.serialName, serializer.tryTypeParameterSerializers3() ?: arrayOf()] as VirtualStruct.Concrete
println(vtypeSerializer.serializers)
println(vtype.annotations)
val json = Json { serializersModule = ClientModule; encodeDefaults = true; allowStructuredMapKeys = true }
val json = Json {
serializersModule = ClientModule
encodeDefaults = true
allowStructuredMapKeys = true
builderAction()
}
println("Schema: ${json.encodeToString(vtype)}")
val original = instance
println(original)
// forward
json.decodeFromString(vtypeSerializer, json.encodeToString(serializer, original))

val string = json.encodeToString(serializer, original)
println(string)
val vinst = json.decodeFromString(vtypeSerializer, string)
println(vtype)
println(vinst)
val newString = json.encodeToString(vtypeSerializer, vinst)
println(newString)
assertEquals(string, newString)
assertEquals(
json.decodeFromString(serializer, string).toString().split(',').joinToString(",\n"),
json.decodeFromString(serializer, newString).toString().split(',').joinToString(",\n")
)
assertEquals(
json.decodeFromString(serializer, string),
json.decodeFromString(serializer, newString)
)

measureTime {
repeat(10000) {
json.encodeToString(vtypeSerializer, json.decodeFromString(vtypeSerializer, string))
}
}.also { println("Performance: ${it / 10000}") }
// reverse
json.decodeFromString(serializer, json.encodeToString(vtypeSerializer, vinst))

// measureTime {
// repeat(10000) {
// json.encodeToString(vtypeSerializer, json.decodeFromString(vtypeSerializer, string))
// }
// }.also { println("Performance: ${it / 10000}") }
}
@Test fun testSerializableAnnotation() {
val serializer = LargeTestModel.serializer()
Expand All @@ -57,8 +75,20 @@ class VirtualTypesTest {
vtypeSerializer.serializableProperties.find { it.name == "string" }!!.serializableAnnotations
)
}
@Test fun testStructure() = testVirtualVersion(LargeTestModel.serializer(), LargeTestModel())
@Test fun testGeneric() = testVirtualVersion(GenericBox.serializer(Int.serializer()), GenericBox(value = 1, nullable = 2, list = listOf(3, 4)))
@Test fun testStructure() {
testVirtualVersion(LargeTestModel.serializer(), LargeTestModel())
testVirtualVersion(LargeTestModel.serializer(), LargeTestModel()) { encodeDefaults = false }
}
@Test fun testGeneric() {
testVirtualVersion(
GenericBox.serializer(Int.serializer()),
GenericBox(value = 1, nullable = 2, list = listOf(3, 4))
)
testVirtualVersion(
GenericBox.serializer(Int.serializer()),
GenericBox(value = 1, nullable = 2, list = listOf(3, 4))
)
}
// @Test fun testEnum() {
// val vtype = SampleA.serializer().makeVirtualType() as VirtualEnum
// val json = Json { serializersModule = ClientModule; encodeDefaults = true }
Expand Down

0 comments on commit 4cf8bde

Please sign in to comment.