Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

re-enable refined module #821

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trait CollectionDecoders:
given[T](using decoder: Decoder[T]): Decoder[Seq[T]] = iterableDecoder(decoder, _.toSeq)
given[T](using decoder: Decoder[T]): Decoder[Set[T]] = iterableDecoder(decoder, _.toSet)
given[T](using decoder: Decoder[T]): Decoder[Vector[T]] = iterableDecoder(decoder, _.toVector)
given[T](using decoder: Decoder[T]): Decoder[Map[String, T]] = new MapDecoder[T](decoder)
given mapDecoder[T](using decoder: Decoder[T]): Decoder[Map[String, T]] = new MapDecoder[T](decoder)

def iterableDecoder[T, C[X] <: Iterable[X]](decoder: Decoder[T],
build: Iterable[T] => C[T]): Decoder[C[T]] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,5 @@ trait CollectionSchemas:

given[T](using schemaFor: SchemaFor[T]): SchemaFor[List[T]] = buildIterableSchemaFor[List, T]

given[V](using schemaFor: SchemaFor[V]): SchemaFor[Map[String, V]] =
given mapSchemaFor[V](using schemaFor: SchemaFor[V]): SchemaFor[Map[String, V]] =
schemaFor.map(SchemaBuilder.map().values(_))

Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package com.sksamuel.avro4s

import eu.timepit.refined.api.{RefType, Validate}

package object refined:

given[T, P, F[_, _]](using schemaFor: SchemaFor[T]): SchemaFor[F[T, P]] = schemaFor.forType

given[T: Encoder, P, F[_, _] : RefType]: Encoder[F[T, P]] = Encoder[T].contramap(RefType[F].unwrap)

given[T: Decoder, P, F[_, _] : RefType](using validate: Validate[T, P]): Decoder[F[T, P]] = Decoder[T].map(RefType[F].refine[P].unsafeFrom[T])

given[A, P, F[_, _]: RefType, B](using schemaForA: SchemaFor[A], schemaForB: SchemaFor[B], isString: A <:< String): SchemaFor[Map[F[A, P], B]] =
SchemaFor.mapSchemaFor[B].forType

given[A: Encoder, B: Encoder, P, F[_, _]: RefType](using isString: A <:< String): Encoder[Map[F[A, P], B]] =
Encoder.mapEncoder[B].contramap[Map[F[A, P], B]]: theMap =>
theMap.map:
case (k, v) => RefType[F].unwrap(k).asInstanceOf[String] -> v

given[A: Decoder, B: Decoder, P, F[_, _]: RefType](using validate: Validate[A, P], isString: A <:< String): Decoder[Map[F[A, P], B]] =
Decoder.mapDecoder[B].map: theMap =>
theMap.map:
case (str, b) => (RefType[F].refine[P].unsafeFrom[A](str.asInstanceOf[A]), b)

// implicit def refinedTypeGuardedDecoding[T: WeakTypeTag, P, F[_, _]: RefType]: TypeGuardedDecoding[F[T, P]] = new TypeGuardedDecoding[F[T, P]] {
// override final def guard(decoderT: Decoder[F[T, P]]): PartialFunction[Any, F[T, P]] =
// TypeGuardedDecoding[T].guard(decoderT.map(RefType[F].unwrap)).andThen(RefType[F].unsafeWrap(_))
// }
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.sksamuel.avro4s.refined

import com.sksamuel.avro4s.streams.input.InputStreamTest
import eu.timepit.refined.types.numeric.PosInt
import eu.timepit.refined.types.string.NonEmptyString
import shapeless.*

import scala.util.Failure

class RefinedRoundtripTest extends InputStreamTest:

type C1 = NonEmptyString :+: CNil
case class Container1(c1: C1)
// type C2 = Int :+: NonEmptyString :+: CNil
// case class Container2(c2: C2)
// type C3 = PosInt :+: NonEmptyString :+: CNil
// case class Container3(c3: C3)
// case class Container4(map: Map[String, NonEmptyString], c3: C3, list: List[(Int, PosInt)])
// type C1b = String :+: CNil
// case class Container1b(c1: C1b)
// case class Container5(c5: Either[NonEmptyString, Int])
// case class Container6(c6: Map[NonEmptyString, PosInt])

// test("a union of one refined type inside a record should rountrip"):
// writeRead(Container1(Coproduct[C1](NonEmptyString.unsafeFrom("a"))))

// test("a union of one refined type and more standard types inside a record should rountrip") {
// writeRead(Container2(Coproduct[C2](NonEmptyString("a"))))
// }

// test("a union of more than one refined type inside a record should rountrip") {
// writeRead(Container3(Coproduct[C3](PosInt(42))))
// }

// test("a more complex record should rountrip") {
// writeRead(Container4(Map("bla" -> NonEmptyString("a")), Coproduct[C3](NonEmptyString("b")), List(23 -> PosInt(42), 42 -> PosInt(23))))
// }

// test("a broken encoder will not decode") {
// val out = writeData(Container1b(Coproduct[C1b]("")))
// val result = tryReadData[Container1](out.toByteArray).next()
// result should matchPattern { case Failure(iae: IllegalArgumentException) if iae.getMessage == "Predicate isEmpty() did not fail." => }
// }

// test("an either of one refined type inside a record should roundtrip") {
// writeRead(Container5(Left(NonEmptyString("a"))))
// }

// test("a map with refined types on both key and value should roundtrip") {
// val key: NonEmptyString = NonEmptyString("foo")
// val value: PosInt = PosInt(1)
// writeRead(Container6(Map(key -> value)))
// }
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package com.sksamuel.avro4s.refined

import com.sksamuel.avro4s.*
import eu.timepit.refined.api.Refined
import eu.timepit.refined.collection.NonEmpty
import org.apache.avro.Schema
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
import eu.timepit.refined.auto.*
import eu.timepit.refined.types.string.NonEmptyString
import eu.timepit.refined.types.numeric.NonNegInt

case class Foo(nonEmptyStr: String Refined NonEmpty)
case class FooMap(nonEmptyStrKeyMap: Map[NonEmptyString, NonNegInt])

class RefinedTest extends AnyWordSpec with Matchers:
val fooSchema: Schema = AvroSchema[Foo]
val fooMapSchema: Schema = AvroSchema[FooMap]

"refinedSchemaFor" should :
"use the schema for the underlying type" in:
AvroSchema[Foo] shouldBe new Schema.Parser().parse(
"""
|{
| "type": "record",
| "name": "Foo",
| "namespace": "com.sksamuel.avro4s.refined",
| "fields": [{
| "name": "nonEmptyStr",
| "type": "string"
| }]
|}
""".stripMargin)

"generate correct schemas for a Map when refined instances are in scope" in:
case class Test(map: Map[String, Int], nonEmptyStr: String Refined NonEmpty)
val schema = AvroSchema[Test]

println(s"schema: $schema")

schema.getField("map").schema().getType shouldBe Schema.Type.MAP
schema.getField("nonEmptyStr").schema().getType shouldBe Schema.Type.STRING

"refinedStringMapKeySchemaFor" should:
"use the schema for the underlying type" in:
AvroSchema[FooMap] shouldBe new Schema.Parser().parse(
"""
|{
| "type": "record",
| "name": "FooMap",
| "namespace": "com.sksamuel.avro4s.refined",
| "fields": [{
| "name": "nonEmptyStrKeyMap",
| "type": {
| "type": "map",
| "values": "int"
| }
| }]
|}
""".stripMargin
)

"refinedEncoder" should:
"use the encoder for the underlying type" in:
val expected: String Refined NonEmpty = NonEmptyString.unsafeFrom("foo")
val record = ToRecord[Foo](fooSchema).to(Foo(expected))
record.get("nonEmptyStr").toString shouldBe expected.value

"refinedStringMapKeyEncoder" should:
"use the encoder for the underlying type" in:
val key: NonEmptyString = NonEmptyString.unsafeFrom("foo")
val value: NonNegInt = NonNegInt.unsafeFrom(1)
val expected: Map[NonEmptyString, NonNegInt] = Map(key -> value)
val record = ToRecord[FooMap](fooMapSchema).to(FooMap(expected))
val encodedMap = record.get("nonEmptyStrKeyMap").asInstanceOf[java.util.Map[String, Int]]
encodedMap.get(key.value) shouldBe value.value

"refinedDecoder" should:
"use the decoder for the underlying type" in:
val expected: String Refined NonEmpty = NonEmptyString.unsafeFrom("foo")
val record = ImmutableRecord(AvroSchema[Foo], Vector(expected.value))
FromRecord[Foo](fooSchema).from(record) shouldBe Foo(expected)

"throw when the value does not conform to the refined predicate" in:
val record = ImmutableRecord(AvroSchema[Foo], Vector(""))
assertThrows[IllegalArgumentException](FromRecord[Foo](fooSchema).from(record))

"refinedStringMapKeyDecoder" should:
"use the decoder for the underlying type" in:
val key: NonEmptyString = NonEmptyString.unsafeFrom("foo")
val value: NonNegInt = NonNegInt.unsafeFrom(1)

val jMap = new java.util.HashMap[String, Int]()
jMap.put(key.value, value.value)

val expected = Map(key -> value)
val record = ImmutableRecord(AvroSchema[FooMap], Vector(jMap))

FromRecord[FooMap](fooMapSchema).from(record) shouldBe FooMap(expected)
13 changes: 11 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ lazy val root = Project("avro4s", file("."))
)
.aggregate(
`avro4s-core`,
`avro4s-cats`
// `avro4s-kafka`
`avro4s-cats`,
// `avro4s-kafka`
`avro4s-refined`
)

val `avro4s-core` = project.in(file("avro4s-core"))
Expand Down Expand Up @@ -44,6 +45,14 @@ val `avro4s-cats` = project.in(file("avro4s-cats"))
// )
// )

val `avro4s-refined` = project.in(file("avro4s-refined"))
.dependsOn(`avro4s-core` % "compile->compile;test->test")
.settings(
libraryDependencies ++= Seq(
"eu.timepit" %% "refined" % RefinedVersion
)
)

val benchmarks = project
.in(file("benchmarks"))
.dependsOn(`avro4s-core`)
Expand Down