Skip to content

Commit

Permalink
Non-empty collection schemas (zio#717)
Browse files Browse the repository at this point in the history
  • Loading branch information
987Nabil committed Aug 10, 2024
1 parent eb71680 commit be70abe
Show file tree
Hide file tree
Showing 12 changed files with 400 additions and 91 deletions.
2 changes: 2 additions & 0 deletions tests/shared/src/test/scala/zio/schema/DynamicValueGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ object DynamicValueGen {
case Schema.Enum22(_, case1, case2, case3, case4, case5, case6, case7, case8, case9, case10, case11, case12, case13, case14, case15, case16, case17, case18, case19, case20, case21, case22, _) => anyDynamicValueOfEnum(Chunk(case1, case2, case3, case4, case5, case6, case7, case8, case9, case10, case11, case12, case13, case14, case15, case16, case17, case18, case19, case20, case21, case22))
case Schema.EnumN(_, cases, _) => anyDynamicValueOfEnum(Chunk.fromIterable(cases.toSeq))
case Schema.Sequence(schema, _, _, _, _) => Gen.chunkOfBounded(0, 2)(anyDynamicValueOfSchema(schema)).map(DynamicValue.Sequence(_))
case Schema.NonEmptySequence(schema, _, _, _, _) => Gen.chunkOfBounded(1, 2)(anyDynamicValueOfSchema(schema)).map(DynamicValue.Sequence(_))
case Schema.Map(ks, vs, _) => Gen.chunkOfBounded(0, 2)(anyDynamicValueOfSchema(ks).zip(anyDynamicValueOfSchema(vs))).map(DynamicValue.Dictionary(_))
case Schema.NonEmptyMap(ks, vs, _) => Gen.chunkOfBounded(1, 2)(anyDynamicValueOfSchema(ks).zip(anyDynamicValueOfSchema(vs))).map(DynamicValue.Dictionary(_))
case Schema.Set(schema, _) => Gen.setOfBounded(0, 2)(anyDynamicValueOfSchema(schema)).map(DynamicValue.SetValue(_))
case Schema.Optional(schema, _) => Gen.oneOf(anyDynamicSomeValueOfSchema(schema), Gen.const(DynamicValue.NoneValue))
case Schema.Tuple2(left, right, _) => anyDynamicTupleValue(left, right)
Expand Down
33 changes: 26 additions & 7 deletions zio-schema-avro/src/main/scala/zio/schema/codec/AvroCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@ package zio.schema.codec
import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer
import java.util.UUID

import scala.collection.immutable.ListMap
import scala.jdk.CollectionConverters._
import scala.util.Try

import org.apache.avro.generic.{
GenericData,
GenericDatumReader,
Expand All @@ -18,7 +16,7 @@ import org.apache.avro.generic.{
import org.apache.avro.io.{ DecoderFactory, EncoderFactory }
import org.apache.avro.util.Utf8
import org.apache.avro.{ Conversions, LogicalTypes, Schema => SchemaAvro }

import zio.prelude.NonEmptyMap
import zio.schema.{ Fallback, FieldSet, Schema, StandardType, TypeId }
import zio.stream.ZPipeline
import zio.{ Chunk, Unsafe, ZIO }
Expand Down Expand Up @@ -201,9 +199,20 @@ object AvroCodec {
case record: Schema.Record[_] => decodeRecord(raw, record).map(_.asInstanceOf[A])
case Schema.Sequence(element, f, _, _, _) =>
decodeSequence(raw, element.asInstanceOf[Schema[Any]]).map(f.asInstanceOf[Chunk[Any] => A])
case nes @ Schema.NonEmptySequence(element, _, _, _, _) =>
decodeSequence(raw, element.asInstanceOf[Schema[Any]]).map(nes.fromChunk.asInstanceOf[Chunk[Any] => A])
case Schema.Set(element, _) => decodeSequence(raw, element.asInstanceOf[Schema[Any]]).map(_.toSet.asInstanceOf[A])
case mapSchema: Schema.Map[_, _] =>
decodeMap(raw, mapSchema.asInstanceOf[Schema.Map[Any, Any]]).map(_.asInstanceOf[A])
case mapSchema: Schema.NonEmptyMap[_, _] =>
decodeMap(
raw,
Schema.Map(
mapSchema.keySchema.asInstanceOf[Schema[Any]],
mapSchema.valueSchema.asInstanceOf[Schema[Any]],
mapSchema.annotations
)
).map(mapSchema.asInstanceOf[Schema.NonEmptyMap[Any, Any]].fromMap(_).asInstanceOf[A])
case Schema.Transform(schema, f, _, _, _) =>
decodeValue(raw, schema).flatMap(
a => f(a).left.map(msg => DecodeError.MalformedFieldWithPath(Chunk.single("Error"), msg))
Expand Down Expand Up @@ -662,12 +671,22 @@ object AvroCodec {
c21,
c22
)
case Schema.GenericRecord(typeId, structure, _) => encodeGenericRecord(a, typeId, structure)
case Schema.Primitive(standardType, _) => encodePrimitive(a, standardType)
case Schema.Sequence(element, _, g, _, _) => encodeSequence(element, g(a))
case Schema.Set(element, _) => encodeSet(element, a)
case Schema.GenericRecord(typeId, structure, _) => encodeGenericRecord(a, typeId, structure)
case Schema.Primitive(standardType, _) => encodePrimitive(a, standardType)
case Schema.Sequence(element, _, g, _, _) => encodeSequence(element, g(a))
case Schema.NonEmptySequence(element, _, g, _, _) => encodeSequence(element, g(a))
case Schema.Set(element, _) => encodeSet(element, a)
case mapSchema: Schema.Map[_, _] =>
encodeMap(mapSchema.asInstanceOf[Schema.Map[Any, Any]], a.asInstanceOf[scala.collection.immutable.Map[Any, Any]])
case mapSchema: Schema.NonEmptyMap[_, _] =>
encodeMap(
Schema.Map(
mapSchema.keySchema.asInstanceOf[Schema[Any]],
mapSchema.valueSchema.asInstanceOf[Schema[Any]],
mapSchema.annotations
),
a.asInstanceOf[NonEmptyMap[Any, Any]].toMap
)
case Schema.Transform(schema, _, g, _, _) =>
g(a).map(encodeValue(_, schema)).getOrElse(throw new Exception("Transform failed."))
case Schema.Optional(schema, _) => encodeOption(schema, a)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,22 +457,24 @@ object BsonSchemaCodec {
//scalafmt: { maxColumn = 400, optIn.configStyleArguments = false }
private[codec] def schemaEncoder[A](schema: Schema[A]): BsonEncoder[A] =
schema match {
case Schema.Primitive(standardType, _) => primitiveCodec(standardType).encoder
case Schema.Sequence(schema, _, g, _, _) => chunkEncoder(schemaEncoder(schema)).contramap(g)
case Schema.Map(ks, vs, _) => mapEncoder(ks, vs)
case Schema.Set(s, _) => chunkEncoder(schemaEncoder(s)).contramap(m => Chunk.fromIterable(m))
case Schema.Transform(c, _, g, _, _) => transformEncoder(c, g)
case Schema.Tuple2(l, r, _) => tuple2Encoder(schemaEncoder(l), schemaEncoder(r))
case Schema.Optional(schema, _) => BsonEncoder.option(schemaEncoder(schema))
case Schema.Fail(_, _) => unitEncoder.contramap(_ => ())
case Schema.GenericRecord(_, structure, _) => genericRecordEncoder(structure.toChunk)
case Schema.Either(left, right, _) => eitherEncoder(schemaEncoder(left), schemaEncoder(right))
case Schema.Fallback(left, right, _, _) => fallbackEncoder(schemaEncoder(left), schemaEncoder(right))
case l @ Schema.Lazy(_) => schemaEncoder(l.schema)
case r: Schema.Record[A] => caseClassEncoder(r)
case e: Schema.Enum[A] => enumEncoder(e, e.cases)
case d @ Schema.Dynamic(_) => dynamicEncoder(d)
case null => throw new Exception(s"A captured schema is null, most likely due to wrong field initialization order")
case Schema.Primitive(standardType, _) => primitiveCodec(standardType).encoder
case Schema.Sequence(schema, _, g, _, _) => chunkEncoder(schemaEncoder(schema)).contramap(g)
case Schema.NonEmptySequence(schema, _, g, _, _) => chunkEncoder(schemaEncoder(schema)).contramap(g)
case Schema.Map(ks, vs, _) => mapEncoder(ks, vs)
case Schema.NonEmptyMap(ks, vs, _) => mapEncoder(ks, vs).contramap(_.toMap)
case Schema.Set(s, _) => chunkEncoder(schemaEncoder(s)).contramap(m => Chunk.fromIterable(m))
case Schema.Transform(c, _, g, _, _) => transformEncoder(c, g)
case Schema.Tuple2(l, r, _) => tuple2Encoder(schemaEncoder(l), schemaEncoder(r))
case Schema.Optional(schema, _) => BsonEncoder.option(schemaEncoder(schema))
case Schema.Fail(_, _) => unitEncoder.contramap(_ => ())
case Schema.GenericRecord(_, structure, _) => genericRecordEncoder(structure.toChunk)
case Schema.Either(left, right, _) => eitherEncoder(schemaEncoder(left), schemaEncoder(right))
case Schema.Fallback(left, right, _, _) => fallbackEncoder(schemaEncoder(left), schemaEncoder(right))
case l @ Schema.Lazy(_) => schemaEncoder(l.schema)
case r: Schema.Record[A] => caseClassEncoder(r)
case e: Schema.Enum[A] => enumEncoder(e, e.cases)
case d @ Schema.Dynamic(_) => dynamicEncoder(d)
case null => throw new Exception(s"A captured schema is null, most likely due to wrong field initialization order")
}
//scalafmt: { maxColumn = 120, optIn.configStyleArguments = true }

Expand Down Expand Up @@ -773,22 +775,24 @@ object BsonSchemaCodec {

//scalafmt: { maxColumn = 400, optIn.configStyleArguments = false }
private[codec] def schemaDecoder[A](schema: Schema[A]): BsonDecoder[A] = schema match {
case Schema.Primitive(standardType, _) => primitiveCodec(standardType).decoder
case Schema.Optional(codec, _) => BsonDecoder.option(schemaDecoder(codec))
case Schema.Tuple2(left, right, _) => tuple2Decoder(schemaDecoder(left), schemaDecoder(right))
case Schema.Transform(codec, f, _, _, _) => schemaDecoder(codec).mapOrFail(f)
case Schema.Sequence(codec, f, _, _, _) => chunkDecoder(schemaDecoder(codec)).map(f)
case Schema.Map(ks, vs, _) => mapDecoder(ks, vs)
case Schema.Set(s, _) => chunkDecoder(schemaDecoder(s)).map(entries => entries.toSet)
case Schema.Fail(message, _) => failDecoder(message)
case Schema.GenericRecord(_, structure, _) => recordDecoder(structure.toChunk)
case Schema.Either(left, right, _) => eitherDecoder(schemaDecoder(left), schemaDecoder(right))
case Schema.Fallback(left, right, _, _) => fallbackDecoder(schemaDecoder(left), schemaDecoder(right))
case l @ Schema.Lazy(_) => schemaDecoder(l.schema)
case s: Schema.Record[A] => caseClassDecoder(s)
case e: Schema.Enum[A] => enumDecoder(e)
case d @ Schema.Dynamic(_) => dynamicDecoder(d)
case null => throw new Exception(s"Missing a handler for decoding of schema $schema.")
case Schema.Primitive(standardType, _) => primitiveCodec(standardType).decoder
case Schema.Optional(codec, _) => BsonDecoder.option(schemaDecoder(codec))
case Schema.Tuple2(left, right, _) => tuple2Decoder(schemaDecoder(left), schemaDecoder(right))
case Schema.Transform(codec, f, _, _, _) => schemaDecoder(codec).mapOrFail(f)
case Schema.Sequence(codec, f, _, _, _) => chunkDecoder(schemaDecoder(codec)).map(f)
case s @ Schema.NonEmptySequence(codec, _, _, _, _) => chunkDecoder(schemaDecoder(codec)).map(s.fromChunk)
case Schema.Map(ks, vs, _) => mapDecoder(ks, vs)
case s @ Schema.NonEmptyMap(ks, vs, _) => mapDecoder(ks, vs).map(s.fromMap)
case Schema.Set(s, _) => chunkDecoder(schemaDecoder(s)).map(entries => entries.toSet)
case Schema.Fail(message, _) => failDecoder(message)
case Schema.GenericRecord(_, structure, _) => recordDecoder(structure.toChunk)
case Schema.Either(left, right, _) => eitherDecoder(schemaDecoder(left), schemaDecoder(right))
case Schema.Fallback(left, right, _, _) => fallbackDecoder(schemaDecoder(left), schemaDecoder(right))
case l @ Schema.Lazy(_) => schemaDecoder(l.schema)
case s: Schema.Record[A] => caseClassDecoder(s)
case e: Schema.Enum[A] => enumDecoder(e)
case d @ Schema.Dynamic(_) => dynamicDecoder(d)
case _ => throw new Exception(s"Missing a handler for decoding of schema $schema.")
}
//scalafmt: { maxColumn = 120, optIn.configStyleArguments = true }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import zio.json.{
JsonFieldDecoder,
JsonFieldEncoder
}
import zio.prelude.NonEmptyMap
import zio.schema._
import zio.schema.annotation._
import zio.schema.codec.DecodeError.ReadError
Expand Down Expand Up @@ -182,9 +183,11 @@ object JsonCodec {
//scalafmt: { maxColumn = 400, optIn.configStyleArguments = false }
private[codec] def schemaEncoder[A](schema: Schema[A], cfg: Config, discriminatorTuple: DiscriminatorTuple = Chunk.empty): ZJsonEncoder[A] =
schema match {
case Schema.Primitive(standardType, _) => primitiveCodec(standardType).encoder
case Schema.Sequence(schema, _, g, _, _) => ZJsonEncoder.chunk(schemaEncoder(schema, cfg, discriminatorTuple)).contramap(g)
case Schema.Map(ks, vs, _) => mapEncoder(ks, vs, discriminatorTuple, cfg)
case Schema.Primitive(standardType, _) => primitiveCodec(standardType).encoder
case Schema.Sequence(schema, _, g, _, _) => ZJsonEncoder.chunk(schemaEncoder(schema, cfg, discriminatorTuple)).contramap(g)
case Schema.NonEmptySequence(schema, _, g, _, _) => ZJsonEncoder.chunk(schemaEncoder(schema, cfg, discriminatorTuple)).contramap(g)
case Schema.Map(ks, vs, _) => mapEncoder(ks, vs, discriminatorTuple, cfg)
case Schema.NonEmptyMap(ks, vs, _) => mapEncoder(ks, vs, discriminatorTuple, cfg).contramap[NonEmptyMap[Any, Any]](_.toMap.asInstanceOf[Map[Any, Any]]).asInstanceOf[ZJsonEncoder[A]]
case Schema.Set(s, _) =>
ZJsonEncoder.chunk(schemaEncoder(s, cfg, discriminatorTuple)).contramap(m => Chunk.fromIterable(m))
case Schema.Transform(c, _, g, a, _) => transformEncoder(a.foldLeft(c)((s, a) => s.annotate(a)), g, cfg)
Expand Down Expand Up @@ -544,18 +547,20 @@ object JsonCodec {

//scalafmt: { maxColumn = 400, optIn.configStyleArguments = false }
private[codec] def schemaDecoder[A](schema: Schema[A], discriminator: Int = -1): ZJsonDecoder[A] = schema match {
case Schema.Primitive(standardType, _) => primitiveCodec(standardType).decoder
case Schema.Optional(codec, _) => option(schemaDecoder(codec, discriminator))
case Schema.Tuple2(left, right, _) => ZJsonDecoder.tuple2(schemaDecoder(left, -1), schemaDecoder(right, -1))
case Schema.Transform(c, f, _, a, _) => schemaDecoder(a.foldLeft(c)((s, a) => s.annotate(a)), discriminator).mapOrFail(f)
case Schema.Sequence(codec, f, _, _, _) => ZJsonDecoder.chunk(schemaDecoder(codec, -1)).map(f)
case Schema.Map(ks, vs, _) => mapDecoder(ks, vs)
case Schema.Set(s, _) => ZJsonDecoder.chunk(schemaDecoder(s, -1)).map(entries => entries.toSet)
case Schema.Fail(message, _) => failDecoder(message)
case Schema.GenericRecord(_, structure, _) => recordDecoder(structure.toChunk, schema.annotations.contains(rejectExtraFields()))
case Schema.Either(left, right, _) => ZJsonDecoder.either(schemaDecoder(left, -1), schemaDecoder(right, -1))
case s @ Schema.Fallback(_, _, _, _) => fallbackDecoder(s)
case l @ Schema.Lazy(_) => schemaDecoder(l.schema, discriminator)
case Schema.Primitive(standardType, _) => primitiveCodec(standardType).decoder
case Schema.Optional(codec, _) => option(schemaDecoder(codec, discriminator))
case Schema.Tuple2(left, right, _) => ZJsonDecoder.tuple2(schemaDecoder(left, -1), schemaDecoder(right, -1))
case Schema.Transform(c, f, _, a, _) => schemaDecoder(a.foldLeft(c)((s, a) => s.annotate(a)), discriminator).mapOrFail(f)
case Schema.Sequence(codec, f, _, _, _) => ZJsonDecoder.chunk(schemaDecoder(codec, -1)).map(f)
case s @ Schema.NonEmptySequence(codec, _, _, _, _) => ZJsonDecoder.chunk(schemaDecoder(codec, -1)).map(s.fromChunk)
case Schema.Map(ks, vs, _) => mapDecoder(ks, vs)
case Schema.NonEmptyMap(ks, vs, _) => mapDecoder(ks, vs).mapOrFail(m => NonEmptyMap.fromMapOption(m).toRight("NonEmptyMap expected"))
case Schema.Set(s, _) => ZJsonDecoder.chunk(schemaDecoder(s, -1)).map(entries => entries.toSet)
case Schema.Fail(message, _) => failDecoder(message)
case Schema.GenericRecord(_, structure, _) => recordDecoder(structure.toChunk, schema.annotations.contains(rejectExtraFields()))
case Schema.Either(left, right, _) => ZJsonDecoder.either(schemaDecoder(left, -1), schemaDecoder(right, -1))
case s @ Schema.Fallback(_, _, _, _) => fallbackDecoder(s)
case l @ Schema.Lazy(_) => schemaDecoder(l.schema, discriminator)
//case Schema.Meta(_, _) => astDecoder
case s @ Schema.CaseClass0(_, _, _) => caseClass0Decoder(discriminator, s)
case s @ Schema.CaseClass1(_, _, _, _) => caseClass1Decoder(discriminator, s)
Expand Down
Loading

0 comments on commit be70abe

Please sign in to comment.