From 3f251c9229a9960c5554c3eec694b75629447692 Mon Sep 17 00:00:00 2001 From: Eugene Yokota Date: Mon, 17 Jan 2022 18:37:41 -0500 Subject: [PATCH] Scalafmt --- .github/workflows/ci.yml | 2 +- .scalafmt.conf | 9 +- .../main/scala/sbt/contraband/CodeGen.scala | 176 ++++++------- .../scala/sbt/contraband/CodecCodeGen.scala | 124 ++++----- .../src/main/scala/sbt/contraband/Dag.scala | 131 +++++----- .../scala/sbt/contraband/Indentation.scala | 2 + .../scala/sbt/contraband/JavaCodeGen.scala | 55 ++-- .../scala/sbt/contraband/MixedCodeGen.scala | 42 ++- .../scala/sbt/contraband/ScalaCodeGen.scala | 144 ++++++----- .../scala/sbt/contraband/VersionNumber.scala | 5 +- .../scala/sbt/contraband/ast/SchemaAst.scala | 162 ++++++------ .../sbt/contraband/parser/JsonParser.scala | 218 ++++++++-------- .../sbt/contraband/parser/SchemaParser.scala | 244 ++++++++++++------ .../test/scala/GraphQLMixedCodeGenSpec.scala | 18 +- .../src/test/scala/JsonCodecCodeGenSpec.scala | 46 ++-- .../src/test/scala/JsonScalaCodeGenSpec.scala | 52 ++-- library/src/test/scala/JsonSchemaSpec.scala | 20 +- library/src/test/scala/TestUtils.scala | 8 +- plugin/src/main/scala/ContrabandPlugin.scala | 149 +++++++---- project/plugins.sbt | 1 + 20 files changed, 880 insertions(+), 728 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7d21f81..32c533a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,7 +24,7 @@ jobs: run: sbt -v "++2.13.8!" library/test - if: ${{ matrix.jobtype == 2 }} shell: bash - run: sbt -v "++2.12.15!" test scripted + run: sbt -v "++2.12.15!" scalafmtCheckAll test scripted - shell: bash run: | rm -rf "$HOME/.sbt/scripted/" || true diff --git a/.scalafmt.conf b/.scalafmt.conf index 768a031..e0d9a66 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,11 +1,14 @@ -version = 2.3.2 +version = 3.3.1 maxColumn = 140 project.git = true project.excludeFilters = [ /sbt-test/, /input_sources/, /contraband-scala/ ] +runner.dialect = scala213 + # http://docs.scala-lang.org/style/scaladoc.html recommends the JavaDoc style. # scala/scala is written that way too https://github.com/scala/scala/blob/v2.12.2/src/library/scala/Predef.scala -docstrings = JavaDoc +docstrings.style = Asterisk +docstrings.wrap = no # This also seems more idiomatic to include whitespace in import x.{ yyy } spaces.inImportCurlyBraces = true @@ -16,6 +19,6 @@ align.openParenCallSite = false align.openParenDefnSite = false # For better code clarity -danglingParentheses = true +danglingParentheses.preset = true trailingCommas = preserve diff --git a/library/src/main/scala/sbt/contraband/CodeGen.scala b/library/src/main/scala/sbt/contraband/CodeGen.scala index 0f4e043..c85c3c0 100644 --- a/library/src/main/scala/sbt/contraband/CodeGen.scala +++ b/library/src/main/scala/sbt/contraband/CodeGen.scala @@ -10,8 +10,8 @@ import AstUtil._ */ abstract class CodeGenerator { - //make sure that EOL is *not* platform dependent by default, otherwise - //the output of contraband will be platform dependent too. + // make sure that EOL is *not* platform dependent by default, otherwise + // the output of contraband will be platform dependent too. val EOL = "\n" implicit class ListMapOp[T](m: ListMap[T, String]) { @@ -41,103 +41,100 @@ abstract class CodeGenerator { final def indentWith(config: IndentationConfiguration): String = { val buffer = new IndentationAwareBuffer(config) - code.linesIterator foreach buffer .+= + code.linesIterator foreach buffer.+= buffer.toString } } - protected def lookupInterfaces(s: Document, interfaceRefs: List[ast.NamedType]): List[InterfaceTypeDefinition] = - { - val pkg = - s.packageDecl map { case PackageDecl(nameSegments, _, _, _) => - nameSegments.mkString(".") - } - val refs = - interfaceRefs map { ref => - ref.names match { - case Nil => sys.error(s"Invalid named type: $ref") - case xs => - val namespace = xs.init match { - case Nil => pkg - case xs => Some(xs.mkString(".")) - } - (namespace, xs.last) - } - } - refs map { ref => lookupInterface(s, ref) } - } - - protected def lookupInterface(s: Document, ref: (Option[String], String)): InterfaceTypeDefinition = - { - val (ns, name) = ref - val intfs = s.definitions collect { - case i: InterfaceTypeDefinition => i + protected def lookupInterfaces(s: Document, interfaceRefs: List[ast.NamedType]): List[InterfaceTypeDefinition] = { + val pkg = + s.packageDecl map { case PackageDecl(nameSegments, _, _, _) => + nameSegments.mkString(".") } - (intfs find { i => - i.name == name && i.namespace == ns - }) match { - case Some(i) => i - case _ => sys.error(s"$ref not found") + val refs = + interfaceRefs map { ref => + ref.names match { + case Nil => sys.error(s"Invalid named type: $ref") + case xs => + val namespace = xs.init match { + case Nil => pkg + case xs => Some(xs.mkString(".")) + } + (namespace, xs.last) + } } + refs map { ref => lookupInterface(s, ref) } + } + + protected def lookupInterface(s: Document, ref: (Option[String], String)): InterfaceTypeDefinition = { + val (ns, name) = ref + val intfs = s.definitions collect { case i: InterfaceTypeDefinition => + i } + (intfs find { i => + i.name == name && i.namespace == ns + }) match { + case Some(i) => i + case _ => sys.error(s"$ref not found") + } + } - protected def lookupChildLeaves(s: Document, interface: InterfaceTypeDefinition): List[TypeDefinition] = - { - val pkg = - s.packageDecl map { case PackageDecl(nameSegments, _, _, _) => - nameSegments.mkString(".") - } - val tpe = toNamedType(interface, pkg) - def containsTpe(intfs: List[NamedType]): Boolean = - intfs exists { ref => - ref.names.size match { - case 0 => sys.error(s"Invalid reference $intfs") - case 1 => ref.names.head == tpe.names.last - case _ => ref.names == tpe.names - } + protected def lookupChildLeaves(s: Document, interface: InterfaceTypeDefinition): List[TypeDefinition] = { + val pkg = + s.packageDecl map { case PackageDecl(nameSegments, _, _, _) => + nameSegments.mkString(".") + } + val tpe = toNamedType(interface, pkg) + def containsTpe(intfs: List[NamedType]): Boolean = + intfs exists { ref => + ref.names.size match { + case 0 => sys.error(s"Invalid reference $intfs") + case 1 => ref.names.head == tpe.names.last + case _ => ref.names == tpe.names } - s.definitions flatMap { - case r: ObjectTypeDefinition if containsTpe(r.interfaces) => List(r) - case i: InterfaceTypeDefinition if containsTpe(i.interfaces) => lookupChildLeaves(s, i) - case _ => Nil } + s.definitions flatMap { + case r: ObjectTypeDefinition if containsTpe(r.interfaces) => List(r) + case i: InterfaceTypeDefinition if containsTpe(i.interfaces) => lookupChildLeaves(s, i) + case _ => Nil } + } - protected def lookupChildren(s: Document, interface: InterfaceTypeDefinition): List[TypeDefinition] = - { - val pkg = - s.packageDecl map { case PackageDecl(nameSegments, _, _, _) => - nameSegments.mkString(".") - } - val tpe = toNamedType(interface, pkg) - def containsTpe(intfs: List[NamedType]): Boolean = - intfs exists { ref => - ref.names.size match { - case 0 => sys.error(s"Invalid reference $intfs") - case 1 => ref.names.head == tpe.names.last - case _ => ref.names == tpe.names - } + protected def lookupChildren(s: Document, interface: InterfaceTypeDefinition): List[TypeDefinition] = { + val pkg = + s.packageDecl map { case PackageDecl(nameSegments, _, _, _) => + nameSegments.mkString(".") + } + val tpe = toNamedType(interface, pkg) + def containsTpe(intfs: List[NamedType]): Boolean = + intfs exists { ref => + ref.names.size match { + case 0 => sys.error(s"Invalid reference $intfs") + case 1 => ref.names.head == tpe.names.last + case _ => ref.names == tpe.names } - val result = s.definitions collect { - case r: ObjectTypeDefinition if containsTpe(r.interfaces) => r - case i: InterfaceTypeDefinition if containsTpe(i.interfaces) => i } - result + val result = s.definitions collect { + case r: ObjectTypeDefinition if containsTpe(r.interfaces) => r + case i: InterfaceTypeDefinition if containsTpe(i.interfaces) => i } + result + } - protected def localFields(cl: RecordLikeDefinition, parents: List[InterfaceTypeDefinition]): List[FieldDefinition] = - { - val allFields = cl.fields filter { _.arguments.isEmpty } - val parentFields: List[FieldDefinition] = parents flatMap { _.fields } - def inParent(f: FieldDefinition): Boolean = { - val x = parentFields exists { _.name == f.name } - x - } - allFields filterNot inParent + protected def localFields(cl: RecordLikeDefinition, parents: List[InterfaceTypeDefinition]): List[FieldDefinition] = { + val allFields = cl.fields filter { _.arguments.isEmpty } + val parentFields: List[FieldDefinition] = parents flatMap { _.fields } + def inParent(f: FieldDefinition): Boolean = { + val x = parentFields exists { _.name == f.name } + x } + allFields filterNot inParent + } /** Run an operation `op` for each different version number that affects the fields `fields`. */ - protected final def perVersionNumber[T](since: VersionNumber, fields: List[FieldDefinition])(op: (List[FieldDefinition], List[FieldDefinition]) => T): List[T] = { + protected final def perVersionNumber[T](since: VersionNumber, fields: List[FieldDefinition])( + op: (List[FieldDefinition], List[FieldDefinition]) => T + ): List[T] = { val versionNumbers = (since :: fields.map({ f => getSince(f.directives) })).sorted.distinct versionNumbers map { v => val (provided, byDefault) = fields partition { f => getSince(f.directives) <= v } @@ -155,7 +152,7 @@ abstract class CodeGenerator { case "long" | "Long" => "java.lang.Long" case "short" | "Short" => "java.lang.Short" case "double" | "Double" => "java.lang.Double" - case other => other + case other => other } protected def boxedType(tpe: String): String = @@ -168,7 +165,7 @@ abstract class CodeGenerator { case "long" | "Long" => "Long" case "short" | "Short" => "Short" case "double" | "Double" => "Double" - case other => other + case other => other } protected def unboxedType(tpe: String): String = @@ -181,7 +178,7 @@ abstract class CodeGenerator { case "long" | "Long" => "long" case "short" | "Short" => "short" case "double" | "Double" => "double" - case other => other + case other => other } protected def primitiveType(tpe: String): Boolean = @@ -194,7 +191,7 @@ abstract class CodeGenerator { case "long" | "Long" => true case "short" | "Short" => true case "double" | "Double" => true - case other => false + case other => false } protected def isPrimitive(tpe: ast.Type) = @@ -209,25 +206,24 @@ abstract class CodeGenerator { protected def containsStrictOptional(fields: List[FieldDefinition]): Boolean = fields exists { f => f.fieldType.isOptionalType && !f.fieldType.isLazyType } - protected def genJavaEquals(lhs: String, rhs: String, f0: FieldDefinition, - fieldName: String, isJava: Boolean): String = + protected def genJavaEquals(lhs: String, rhs: String, f0: FieldDefinition, fieldName: String, isJava: Boolean): String = f0 match { case f if isPrimitive(f.fieldType) => s"($lhs.$fieldName == $rhs.$fieldName)" case f if isPrimitiveArray(f.fieldType) => s"java.util.Arrays.equals($lhs.$fieldName, $rhs.$fieldName)" - case f if f.fieldType.isListType => + case f if f.fieldType.isListType => if (isJava) s"java.util.Arrays.deepEquals($lhs.$fieldName, $rhs.$fieldName)" else s"java.util.Arrays.deepEquals($lhs.$fieldName.asInstanceOf[Array[Object]], $rhs.$fieldName.asInstanceOf[Array[Object]])" - case f => s"$lhs.$fieldName.equals($rhs.$fieldName)" + case f => s"$lhs.$fieldName.equals($rhs.$fieldName)" } protected def genJavaHashCode(f0: FieldDefinition, fieldName: String, isJava: Boolean): String = f0 match { case f if isPrimitive(f.fieldType) => s"${boxedType(f.fieldType.name)}.valueOf($fieldName).hashCode()" case f if isPrimitiveArray(f.fieldType) => s"java.util.Arrays.hashCode($fieldName)" - case f if f.fieldType.isListType => + case f if f.fieldType.isListType => if (isJava) s"java.util.Arrays.deepHashCode($fieldName)" else s"java.util.Arrays.deepHashCode($fieldName.asInstanceOf[Array[Object]])" - case f => s"$fieldName.hashCode()" + case f => s"$fieldName.hashCode()" } /** Generate the code corresponding to all definitions in `s`. */ diff --git a/library/src/main/scala/sbt/contraband/CodecCodeGen.scala b/library/src/main/scala/sbt/contraband/CodecCodeGen.scala index 61eacd6..a4d56a0 100644 --- a/library/src/main/scala/sbt/contraband/CodecCodeGen.scala +++ b/library/src/main/scala/sbt/contraband/CodecCodeGen.scala @@ -14,21 +14,24 @@ import AstUtil._ * @param formatsForType Given a `TpeRef` t, returns the list of codecs needed to encode t. * @param includedSchemas List of schemas that could be referenced. */ -class CodecCodeGen(codecParents: List[String], - instantiateJavaLazy: String => String, - javaOption: String, - scalaArray: String, - formatsForType: ast.Type => List[String], - includedSchemas: List[Document]) extends CodeGenerator { +class CodecCodeGen( + codecParents: List[String], + instantiateJavaLazy: String => String, + javaOption: String, + scalaArray: String, + formatsForType: ast.Type => List[String], + includedSchemas: List[Document] +) extends CodeGenerator { import CodecCodeGen._ implicit object indentationConfiguration extends IndentationConfiguration { override val indentElement = " " override def augmentIndentAfterTrigger(s: String) = s.endsWith("{") || - (s.contains(" class ") && s.endsWith("(")) // Constructor definition + (s.contains(" class ") && s.endsWith("(")) // Constructor definition override def reduceIndentTrigger(s: String) = s.startsWith("}") - override def reduceIndentAfterTrigger(s: String) = s.endsWith(") {") || s.endsWith("extends Serializable {") // End of constructor definition + override def reduceIndentAfterTrigger(s: String) = + s.endsWith(") {") || s.endsWith("extends Serializable {") // End of constructor definition override def enterMultilineJavadoc(s: String) = s == "/**" override def exitMultilineJavadoc(s: String) = s == "*/" } @@ -98,7 +101,9 @@ class CodecCodeGen(codecParents: List[String], } val fqn = fullyQualifiedName(r) val allFields = r.fields // superFields ++ r.fields - val getFields = allFields map (f => s"""val ${bq(f.name)} = unbuilder.readField[${genRealTpe(f.fieldType, intfLanguage)}]("${f.name}")""") mkString EOL + val getFields = allFields map (f => + s"""val ${bq(f.name)} = unbuilder.readField[${genRealTpe(f.fieldType, intfLanguage)}]("${f.name}")""" + ) mkString EOL val factoryMethodName = "of" val reconstruct = if (targetLang == "Scala") s"$fqn(" + allFields.map(accessField).mkString(", ") + ")" @@ -166,7 +171,8 @@ class CodecCodeGen(codecParents: List[String], case fms => fms.mkString("self: ", " with ", " =>") } val typeFieldName = (toCodecTypeField(i.directives) orElse toCodecTypeField(s)).getOrElse("type") - val flatUnionFormat = s"""flatUnionFormat${xs.length}[$fqn, ${xs map (c => fullyQualifiedName(c)) mkString ", "}]("$typeFieldName")""" + val flatUnionFormat = + s"""flatUnionFormat${xs.length}[$fqn, ${xs map (c => fullyQualifiedName(c)) mkString ", "}]("$typeFieldName")""" s"""${genPackage(s)} | |import _root_.sjsonnew.JsonFormat @@ -182,15 +188,14 @@ class CodecCodeGen(codecParents: List[String], private def interfaceLanguage(parents: List[InterfaceTypeDefinition], targetLang: String): String = if (parents.isEmpty) targetLang - else - { + else { if (parents exists { p => toTarget(p.directives) == Some("Java") }) "Java" else targetLang } override def generate(s: Document): ListMap[File, String] = { - val codecs: ListMap[File, String] = ((s.definitions collect { - case td: TypeDefinition => td + val codecs: ListMap[File, String] = ((s.definitions collect { case td: TypeDefinition => + td } map { d => ListMap(generate(s, d).toSeq: _*) }) reduce (_ merge _)) mapV (_.indented) @@ -218,9 +223,12 @@ class CodecCodeGen(codecParents: List[String], } private def getAllFormatsForSchema(s: Document): List[String] = - getAllRequiredFormats(s, (s.definitions collect { - case td: TypeDefinition if getGenerateCodec(td.directives) => td - })) + getAllRequiredFormats( + s, + (s.definitions collect { + case td: TypeDefinition if getGenerateCodec(td.directives) => td + }) + ) /** * Returns the list of fully qualified codec names that we (transitively) need to generate a codec for `ds`, @@ -233,7 +241,7 @@ class CodecCodeGen(codecParents: List[String], def getAllDefinitions(d: TypeDefinition): List[TypeDefinition] = d match { case i: InterfaceTypeDefinition => - i :: (lookupChildLeaves(s, i) flatMap {getAllDefinitions}) + i :: (lookupChildLeaves(s, i) flatMap { getAllDefinitions }) case _ => d :: Nil } val allDefinitions = ds flatMap getAllDefinitions @@ -241,8 +249,8 @@ class CodecCodeGen(codecParents: List[String], val requiredFormats = getRequiredFormats(s, d) fullFormatsName(s, d) -> (d match { case i: InterfaceTypeDefinition => - lookupChildLeaves(s, i).map( c => fullFormatsName(s, c)) ::: requiredFormats - case _ => requiredFormats + lookupChildLeaves(s, i).map(c => fullFormatsName(s, c)) ::: requiredFormats + case _ => requiredFormats }) }: _*) val xs = Dag.topologicalSortUnchecked[String](seedFormats) { s => dependencies.get(s).getOrElse(Nil) } @@ -265,15 +273,14 @@ class CodecCodeGen(codecParents: List[String], * Returns the list of fully qualified codec names that we (non-transitively) need to generate a codec for `d`, * knowing that it inherits fields `superFields` in the context of schema `s`. */ - private def getRequiredFormats(s: Document, d: TypeDefinition): List[String] = - { - val typeFormats = - d match { - case _: EnumTypeDefinition => Nil - case c: RecordLikeDefinition => c.fields flatMap (f => lookupFormats(f.fieldType)) - } - typeFormats ++ codecParents - } + private def getRequiredFormats(s: Document, d: TypeDefinition): List[String] = { + val typeFormats = + d match { + case _: EnumTypeDefinition => Nil + case c: RecordLikeDefinition => c.fields flatMap (f => lookupFormats(f.fieldType)) + } + typeFormats ++ codecParents + } private def fullyQualifiedName(d: TypeDefinition): String = s"""${d.namespace getOrElse "_root_"}.${bq(d.name)}""" @@ -296,25 +303,25 @@ class CodecCodeGen(codecParents: List[String], private def genRealTpe(tpe: ast.Type, targetLang: String) = { val scalaTpe = lookupTpe(scalaifyType(tpe.name)) tpe match { - case x if x.isListType && targetLang == "Java" => s"Array[${scalaTpe}]" - case x if x.isListType => s"$scalaArray[$scalaTpe]" + case x if x.isListType && targetLang == "Java" => s"Array[${scalaTpe}]" + case x if x.isListType => s"$scalaArray[$scalaTpe]" case x if !x.isNotNullType && targetLang == "Java" => s"$javaOption[${javaLangBoxedType(scalaTpe)}]" - case x if !x.isNotNullType => s"Option[$scalaTpe]" - case _ => scalaTpe + case x if !x.isNotNullType => s"Option[$scalaTpe]" + case _ => scalaTpe } } private def lookupTpe(tpe: String): String = scalaifyType(tpe) match { - case "boolean" => "Boolean" - case "byte" => "Byte" - case "char" => "Char" - case "float" => "Float" - case "int" => "Int" - case "long" => "Long" - case "short" => "Short" - case "double" => "Double" + case "boolean" => "Boolean" + case "byte" => "Byte" + case "char" => "Char" + case "float" => "Float" + case "int" => "Int" + case "long" => "Long" + case "short" => "Short" + case "double" => "Double" case "StringStringMap" => "scala.collection.immutable.Map[String, String]" - case other => other + case other => other } private def generateFullCodec(s: Document, name: String): ListMap[File, String] = { @@ -324,8 +331,7 @@ class CodecCodeGen(codecParents: List[String], s"""${genPackage(s)} |trait $name $parents |object $name extends $name""".stripMargin - val syntheticDefinition = InterfaceTypeDefinition(name, None, Nil, Nil, - Directive.targetScala :: Nil, Nil, Nil, None) + val syntheticDefinition = InterfaceTypeDefinition(name, None, Nil, Nil, Directive.targetScala :: Nil, Nil, Nil, None) ListMap(new File(genFile(s, syntheticDefinition).getParentFile, s"$name.scala") -> code) } @@ -335,15 +341,14 @@ class CodecCodeGen(codecParents: List[String], case _ => formatsForType(tpe) } - private def lookupDefinition(fullName: String): Option[(Document, TypeDefinition)] = - { - val (ns, name) = splitName(fullName) - (includedSchemas flatMap { s => - s.definitions collect { - case d: TypeDefinition if d.name == name && d.namespace == ns => (s, d) - } - }).headOption - } + private def lookupDefinition(fullName: String): Option[(Document, TypeDefinition)] = { + val (ns, name) = splitName(fullName) + (includedSchemas flatMap { s => + s.definitions collect { + case d: TypeDefinition if d.name == name && d.namespace == ns => (s, d) + } + }).headOption + } } object CodecCodeGen { @@ -360,7 +365,7 @@ object CodecCodeGen { extensibleFormatsForType { ref => val tpe = ref.removeTypeParameters val (ns, name) = splitName(tpe.name) - s"${ ns getOrElse "_root_" }.${name.capitalize}Formats" :: Nil + s"${ns getOrElse "_root_"}.${name.capitalize}Formats" :: Nil } private def splitName(fullName: String): (Option[String], String) = @@ -377,12 +382,13 @@ object CodecCodeGen { def extensibleFormatsForType(forOthers: ast.Type => List[String]): ast.Type => List[String] = { tpe => tpe.removeTypeParameters.name match { case "boolean" | "byte" | "char" | "float" | "int" | "long" | "short" | "double" | "String" => Nil - case "Boolean" | "Byte" | "Char" | "Float" | "Int" | "Long" | "Short" | "Double" => Nil - case "java.util.UUID" | "java.net.URI" | "java.net.URL" | "java.util.Calendar" | "java.math.BigInteger" - | "java.math.BigDecimal" | "java.io.File" => Nil - case "StringStringMap" => Nil + case "Boolean" | "Byte" | "Char" | "Float" | "Int" | "Long" | "Short" | "Double" => Nil + case "java.util.UUID" | "java.net.URI" | "java.net.URL" | "java.util.Calendar" | "java.math.BigInteger" | "java.math.BigDecimal" | + "java.io.File" => + Nil + case "StringStringMap" => Nil case "Throwable" | "java.lang.Throwable" => Nil - case _ => forOthers(tpe) + case _ => forOthers(tpe) } } } diff --git a/library/src/main/scala/sbt/contraband/Dag.scala b/library/src/main/scala/sbt/contraband/Dag.scala index d8c111f..ac60577 100644 --- a/library/src/main/scala/sbt/contraband/Dag.scala +++ b/library/src/main/scala/sbt/contraband/Dag.scala @@ -12,49 +12,50 @@ object Dag { def topologicalSort[T](root: T)(dependencies: T => Iterable[T]): List[T] = topologicalSort(root :: Nil)(dependencies) - def topologicalSort[T](nodes: Iterable[T])(dependencies: T => Iterable[T]): List[T] = - { - val discovered = new mutable.HashSet[T] - val finished = (new java.util.LinkedHashSet[T]).asScala - - def visitAll(nodes: Iterable[T]) = nodes foreach visit - def visit(node: T): Unit = { - if (!discovered(node)) { - discovered(node) = true; - try { visitAll(dependencies(node)); } catch { case c: Cyclic => throw node :: c } - finished += node - () - } else if (!finished(node)) - throw new Cyclic(node) - } - - visitAll(nodes) + def topologicalSort[T](nodes: Iterable[T])(dependencies: T => Iterable[T]): List[T] = { + val discovered = new mutable.HashSet[T] + val finished = (new java.util.LinkedHashSet[T]).asScala - finished.toList + def visitAll(nodes: Iterable[T]) = nodes foreach visit + def visit(node: T): Unit = { + if (!discovered(node)) { + discovered(node) = true; + try { visitAll(dependencies(node)); } + catch { case c: Cyclic => throw node :: c } + finished += node + () + } else if (!finished(node)) + throw new Cyclic(node) } + + visitAll(nodes) + + finished.toList + } // doesn't check for cycles def topologicalSortUnchecked[T](node: T)(dependencies: T => Iterable[T]): List[T] = topologicalSortUnchecked(node :: Nil)(dependencies) - def topologicalSortUnchecked[T](nodes: Iterable[T])(dependencies: T => Iterable[T]): List[T] = - { - val discovered = new mutable.HashSet[T] - var finished: List[T] = Nil - - def visitAll(nodes: Iterable[T]) = nodes foreach visit - def visit(node: T): Unit = { - if (!discovered(node)) { - discovered(node) = true - visitAll(dependencies(node)) - finished ::= node - } - } + def topologicalSortUnchecked[T](nodes: Iterable[T])(dependencies: T => Iterable[T]): List[T] = { + val discovered = new mutable.HashSet[T] + var finished: List[T] = Nil - visitAll(nodes); - finished; + def visitAll(nodes: Iterable[T]) = nodes foreach visit + def visit(node: T): Unit = { + if (!discovered(node)) { + discovered(node) = true + visitAll(dependencies(node)) + finished ::= node + } } + + visitAll(nodes); + finished; + } final class Cyclic(val value: Any, val all: List[Any], val complete: Boolean) - extends Exception("Cyclic reference involving " + - (if (complete) all.mkString("\n ", "\n ", "") else value)) { + extends Exception( + "Cyclic reference involving " + + (if (complete) all.mkString("\n ", "\n ", "") else value) + ) { def this(value: Any) = this(value, value :: Nil, false) override def toString = getMessage def ::(a: Any): Cyclic = @@ -68,17 +69,22 @@ object Dag { /** A directed graph with edges labeled positive or negative. */ private[sbt] trait DirectedSignedGraph[Node] { + /** * Directed edge type that tracks the sign and target (head) vertex. * The sign can be obtained via [[isNegative]] and the target vertex via [[head]]. */ type Arrow + /** List of initial nodes. */ def nodes: List[Arrow] + /** Outgoing edges for `n`. */ def dependencies(n: Node): List[Arrow] + /** `true` if the edge `a` is "negative", false if it is "positive". */ def isNegative(a: Arrow): Boolean + /** The target of the directed edge `a`. */ def head(a: Arrow): Node } @@ -89,36 +95,35 @@ object Dag { * If a cycle containing a "negative" edge is detected, its member edges are returned in order. * Otherwise, the empty list is returned. */ - private[sbt] def findNegativeCycle[Node](graph: DirectedSignedGraph[Node]): List[graph.Arrow] = - { - import graph._ - val finished = new mutable.HashSet[Node] - val visited = new mutable.HashSet[Node] - - def visit(edges: List[Arrow], stack: List[Arrow]): List[Arrow] = edges match { - case Nil => Nil - case edge :: tail => - val node = head(edge) - if (!visited(node)) { - visited += node - visit(dependencies(node), edge :: stack) match { - case Nil => - finished += node - visit(tail, stack) - case cycle => cycle - } - } else if (!finished(node)) { - // cycle. If a negative edge is involved, it is an error. - val between = edge :: stack.takeWhile(f => head(f) != node) - if (between exists isNegative) - between - else + private[sbt] def findNegativeCycle[Node](graph: DirectedSignedGraph[Node]): List[graph.Arrow] = { + import graph._ + val finished = new mutable.HashSet[Node] + val visited = new mutable.HashSet[Node] + + def visit(edges: List[Arrow], stack: List[Arrow]): List[Arrow] = edges match { + case Nil => Nil + case edge :: tail => + val node = head(edge) + if (!visited(node)) { + visited += node + visit(dependencies(node), edge :: stack) match { + case Nil => + finished += node visit(tail, stack) - } else + case cycle => cycle + } + } else if (!finished(node)) { + // cycle. If a negative edge is involved, it is an error. + val between = edge :: stack.takeWhile(f => head(f) != node) + if (between exists isNegative) + between + else visit(tail, stack) - } - - visit(graph.nodes, Nil) + } else + visit(tail, stack) } + visit(graph.nodes, Nil) + } + } diff --git a/library/src/main/scala/sbt/contraband/Indentation.scala b/library/src/main/scala/sbt/contraband/Indentation.scala index 34acd1a..c859454 100644 --- a/library/src/main/scala/sbt/contraband/Indentation.scala +++ b/library/src/main/scala/sbt/contraband/Indentation.scala @@ -10,6 +10,7 @@ class IndentationAwareBuffer(val config: IndentationConfiguration, private var l /** Add all the lines of `it` to the buffer. */ def +=(it: Iterator[String]): Unit = it foreach append + /** Add `s` to the buffer */ def +=(s: String): Unit = s.linesIterator foreach append @@ -28,6 +29,7 @@ class IndentationAwareBuffer(val config: IndentationConfiguration, private var l } abstract class IndentationConfiguration { + /** When this predicate holds for `s`, this line and the following should have one more level of indentation. */ def augmentIndentTrigger(s: String): Boolean = false diff --git a/library/src/main/scala/sbt/contraband/JavaCodeGen.scala b/library/src/main/scala/sbt/contraband/JavaCodeGen.scala index 2f0ee0a..c5dcd09 100644 --- a/library/src/main/scala/sbt/contraband/JavaCodeGen.scala +++ b/library/src/main/scala/sbt/contraband/JavaCodeGen.scala @@ -25,8 +25,8 @@ class JavaCodeGen( } override def generate(s: Document): ListMap[File, String] = - ListMap((s.definitions collect { - case td: TypeDefinition => td + ListMap((s.definitions collect { case td: TypeDefinition => + td }) flatMap (generate(s, _).toList): _*) mapV (_.indented) override def generateInterface(s: Document, i: InterfaceTypeDefinition): ListMap[File, String] = { @@ -96,9 +96,8 @@ class JavaCodeGen( val valuesCode = if (values.isEmpty) "" else - (values map { - case EnumValueDefinition(name, dir, comments, _) => - s"""${genDoc(toDoc(comments))} + (values map { case EnumValueDefinition(name, dir, comments, _) => + s"""${genDoc(toDoc(comments))} |$name""".stripMargin }).mkString("", "," + EOL, ";") @@ -317,31 +316,33 @@ class JavaCodeGen( } else s"${f._1.name}()" } - allFields map { - case (f, idx) => - val (before, after) = allFields filterNot (_._2 == idx) splitAt idx - val tpe = f.fieldType - val params = (before map nonParam) ::: f.name :: (after map nonParam) mkString ", " - s"""public ${r.name} with${capitalize(f.name)}(${genRealTpe(tpe)} ${f.name}) { + allFields map { case (f, idx) => + val (before, after) = allFields filterNot (_._2 == idx) splitAt idx + val tpe = f.fieldType + val params = (before map nonParam) ::: f.name :: (after map nonParam) mkString ", " + s"""public ${r.name} with${capitalize(f.name)}(${genRealTpe(tpe)} ${f.name}) { | return new ${r.name}($params); |}""".stripMargin + - (if (tpe.isListType || tpe.isNotNullType) "" - else { - val wrappedParams = (before map nonParam) ::: instantiateJavaOptional(boxedType(tpe.name), f.name) :: (after map nonParam) mkString ", " - s""" + (if (tpe.isListType || tpe.isNotNullType) "" + else { + val wrappedParams = + (before map nonParam) ::: instantiateJavaOptional(boxedType(tpe.name), f.name) :: (after map nonParam) mkString ", " + s""" |public ${r.name} with${capitalize(f.name)}(${genRealTpe(f.fieldType.notNull)} ${f.name}) { | return new ${r.name}($wrappedParams); |}""".stripMargin - }) + }) } mkString (EOL + EOL) } private def genEquals(cl: RecordLikeDefinition) = { val allFields = cl.fields filter { _.arguments.isEmpty } val body = - if (allFields exists { f => - f.fieldType.isLazyType - }) { + if ( + allFields exists { f => + f.fieldType.isLazyType + } + ) { "return this == obj; // We have lazy members, so use object identity to avoid circularity." } else { val comparisonCode = @@ -373,9 +374,11 @@ class JavaCodeGen( val fqcn = cl.namespace.fold("")(_ + ".") + cl.name val seed = s"""37 * (17 + "$fqcn".hashCode())""" val body = - if (allFields exists { f => - f.fieldType.isLazyType - }) { + if ( + allFields exists { f => + f.fieldType.isLazyType + } + ) { "return super.hashCode(); // Avoid evaluating lazy members in hashCode to avoid circularity." } else { val computation = (seed /: allFields) { (acc, f) => @@ -393,9 +396,11 @@ class JavaCodeGen( private def genToString(cl: RecordLikeDefinition, toString: List[String]) = { val body = if (toString.isEmpty) { val allFields = cl.fields filter { _.arguments.isEmpty } - if (allFields exists { f => - f.fieldType.isLazyType - }) { + if ( + allFields exists { f => + f.fieldType.isLazyType + } + ) { "return super.toString(); // Avoid evaluating lazy members in toString to avoid circularity." } else { allFields diff --git a/library/src/main/scala/sbt/contraband/MixedCodeGen.scala b/library/src/main/scala/sbt/contraband/MixedCodeGen.scala index 146c764..d9d43f0 100644 --- a/library/src/main/scala/sbt/contraband/MixedCodeGen.scala +++ b/library/src/main/scala/sbt/contraband/MixedCodeGen.scala @@ -8,20 +8,34 @@ import AstUtil._ /** * Generator that produces both Scala and Java code. */ -class MixedCodeGen(javaLazy: String, javaOptional: String, instantiateJavaOptional: (String, String) => String, - scalaArray: String, genScalaFileName: Any => File, - scalaSealProtocols: Boolean, scalaPrivateConstructor: Boolean, wrapOption: Boolean) extends CodeGenerator { - val javaGen = new JavaCodeGen(javaLazy, javaOptional, instantiateJavaOptional, - wrapOption) - val scalaGen = new ScalaCodeGen(javaLazy, javaOptional, instantiateJavaOptional, - scalaArray, genScalaFileName, scalaSealProtocols, scalaPrivateConstructor, - wrapOption) +class MixedCodeGen( + javaLazy: String, + javaOptional: String, + instantiateJavaOptional: (String, String) => String, + scalaArray: String, + genScalaFileName: Any => File, + scalaSealProtocols: Boolean, + scalaPrivateConstructor: Boolean, + wrapOption: Boolean +) extends CodeGenerator { + val javaGen = new JavaCodeGen(javaLazy, javaOptional, instantiateJavaOptional, wrapOption) + val scalaGen = new ScalaCodeGen( + javaLazy, + javaOptional, + instantiateJavaOptional, + scalaArray, + genScalaFileName, + scalaSealProtocols, + scalaPrivateConstructor, + wrapOption + ) def generate(s: Document): ListMap[File, String] = - s.definitions collect { - case td: TypeDefinition => td - } map (generate (s, _)) reduce (_ merge _) map { case (k, v) => - (k, generateHeader + v) } + s.definitions collect { case td: TypeDefinition => + td + } map (generate(s, _)) reduce (_ merge _) map { case (k, v) => + (k, generateHeader + v) + } def generateInterface(s: Document, i: InterfaceTypeDefinition): ListMap[File, String] = { // We generate the code that corresponds to this protocol, but without its children, because they @@ -41,8 +55,8 @@ class MixedCodeGen(javaLazy: String, javaOptional: String, instantiateJavaOption def generateRecord(s: Document, r: ObjectTypeDefinition): ListMap[File, String] = { toTarget(r.directives) match { - case Some("Java") => javaGen.generateRecord(s, r) mapV (_ indentWith javaGen.indentationConfiguration) - case _ => scalaGen.generateRecord(s, r) mapV (_ indentWith scalaGen.indentationConfiguration) + case Some("Java") => javaGen.generateRecord(s, r) mapV (_ indentWith javaGen.indentationConfiguration) + case _ => scalaGen.generateRecord(s, r) mapV (_ indentWith scalaGen.indentationConfiguration) } } diff --git a/library/src/main/scala/sbt/contraband/ScalaCodeGen.scala b/library/src/main/scala/sbt/contraband/ScalaCodeGen.scala index 146a27c..a1d8252 100644 --- a/library/src/main/scala/sbt/contraband/ScalaCodeGen.scala +++ b/library/src/main/scala/sbt/contraband/ScalaCodeGen.scala @@ -9,27 +9,32 @@ import AstUtil._ /** * Code generator for Scala. */ -class ScalaCodeGen(javaLazy: String, javaOptional: String, instantiateJavaOptional: (String, String) => String, - scalaArray: String, genFile: Any => File, - scalaSealProtocols: Boolean, scalaPrivateConstructor: Boolean, - wrapOption: Boolean) extends CodeGenerator { +class ScalaCodeGen( + javaLazy: String, + javaOptional: String, + instantiateJavaOptional: (String, String) => String, + scalaArray: String, + genFile: Any => File, + scalaSealProtocols: Boolean, + scalaPrivateConstructor: Boolean, + wrapOption: Boolean +) extends CodeGenerator { implicit object indentationConfiguration extends IndentationConfiguration { override val indentElement = " " override def augmentIndentAfterTrigger(s: String) = s.endsWith("{") || - (s.contains(" class ") && s.endsWith("(")) // Constructor definition + (s.contains(" class ") && s.endsWith("(")) // Constructor definition override def reduceIndentTrigger(s: String) = s.startsWith("}") override def reduceIndentAfterTrigger(s: String) = s.endsWith(") {") || s.endsWith(" Serializable {") // End of constructor definition override def enterMultilineJavadoc(s: String) = s == "/**" override def exitMultilineJavadoc(s: String) = s == "*/" } - override def generate(s: Document): ListMap[File, String] = - (s.definitions collect { - case td: TypeDefinition => td - }) map (generate (s, _)) reduce (_ merge _) mapV (_.indented) + (s.definitions collect { case td: TypeDefinition => + td + }) map (generate(s, _)) reduce (_ merge _) mapV (_.indented) override def generateEnum(s: Document, e: EnumTypeDefinition): ListMap[File, String] = { val values = @@ -153,8 +158,7 @@ class ScalaCodeGen(javaLazy: String, javaOptional: String, instantiateJavaOption private def interfaceLanguage(parents: List[InterfaceTypeDefinition]): String = if (parents.isEmpty) "Scala" - else - { + else { if (parents exists { p => toTarget(p.directives) == Some("Java") }) "Java" else "Scala" } @@ -182,28 +186,29 @@ class ScalaCodeGen(javaLazy: String, javaOptional: String, instantiateJavaOption } private def genParam(f: FieldDefinition, intfLang: String): String = genParam(f.name, f.fieldType, intfLang) - private def genParam(name: String, fieldType: Type, intfLang: String): String = s"${bq(name)}: ${genRealTpe(fieldType, isParam = true, intfLang)}" + private def genParam(name: String, fieldType: Type, intfLang: String): String = + s"${bq(name)}: ${genRealTpe(fieldType, isParam = true, intfLang)}" private def lookupTpe(tpe: String): String = tpe match { - case "boolean" => "Boolean" - case "byte" => "Byte" - case "char" => "Char" - case "float" => "Float" - case "int" => "Int" - case "long" => "Long" - case "short" => "Short" - case "double" => "Double" + case "boolean" => "Boolean" + case "byte" => "Byte" + case "char" => "Char" + case "float" => "Float" + case "int" => "Int" + case "long" => "Long" + case "short" => "Short" + case "double" => "Double" case "StringStringMap" => "scala.collection.immutable.Map[String, String]" - case other => other + case other => other } private def genRealTpe(tpe: ast.Type, isParam: Boolean, intfLang: String) = if (intfLang == "Scala") { val scalaTpe = lookupTpe(tpe.name) val base = tpe match { - case x if x.isListType => s"$scalaArray[$scalaTpe]" - case x if !x.isNotNullType => s"Option[$scalaTpe]" - case _ => scalaTpe + case x if x.isListType => s"$scalaArray[$scalaTpe]" + case x if !x.isNotNullType => s"Option[$scalaTpe]" + case _ => scalaTpe } if (tpe.isLazyType && isParam) s"=> $base" else base } else { @@ -230,9 +235,12 @@ class ScalaCodeGen(javaLazy: String, javaOptional: String, instantiateJavaOption case "Scala" => ("x", allFields map (f => s"(this.${bq(f.name)} == x.${bq(f.name)})") mkString " && ") case _ => - ("x", (allFields map { f => - genJavaEquals("this", "x", f, s"${bq(f.name)}", false) - }).mkString(" && ")) + ( + "x", + (allFields map { f => + genJavaEquals("this", "x", f, s"${bq(f.name)}", false) + }).mkString(" && ") + ) } s"""override def equals(o: Any): Boolean = this.eq(o.asInstanceOf[AnyRef]) || (o match { @@ -291,8 +299,8 @@ class ScalaCodeGen(javaLazy: String, javaOptional: String, instantiateJavaOption v match { case x: ObjectValue => val args = x.fields map { f => f.value.renderPretty } - s"""${tpe.name}(${ args.mkString(", ") })""" - case _ => v.renderPretty + s"""${tpe.name}(${args.mkString(", ")})""" + case _ => v.renderPretty } if (tpe.isListType) s"Vector($str)" else if (tpe.isNotNullType) str @@ -312,8 +320,8 @@ class ScalaCodeGen(javaLazy: String, javaOptional: String, instantiateJavaOption v match { case x: ObjectValue => val args = x.fields map { f => f.value.renderPretty } - s"""${tpe.name}(${ args.mkString(", ") })""" - case _ => v.renderPretty + s"""${tpe.name}(${args.mkString(", ")})""" + case _ => v.renderPretty } if (tpe.isListType) "Array(${str})" else if (tpe.isNotNullType) str @@ -340,7 +348,7 @@ class ScalaCodeGen(javaLazy: String, javaOptional: String, instantiateJavaOption case None if f.fieldType.isListType || !f.fieldType.isNotNullType => if (intfLang == "Scala") renderScalaValue(NullValue(), f.fieldType) else renderJavaValue(NullValue(), f.fieldType) - case _ => sys.error(s"Needs a default value for field ${f.name}.") + case _ => sys.error(s"Needs a default value for field ${f.name}.") } private def genApplyOverloads(r: ObjectTypeDefinition, allFields: List[FieldDefinition], intfLang: String): List[String] = @@ -352,8 +360,7 @@ class ScalaCodeGen(javaLazy: String, javaOptional: String, instantiateJavaOption val applyParameters = provided map { f => genParam(f, intfLang) } mkString ", " val ctorCallArguments = provided map (f => bq(f.name)) mkString ", " - s"def apply($applyParameters): ${r.name} = new ${r.name}($ctorCallArguments)" + - { + s"def apply($applyParameters): ${r.name} = new ${r.name}($ctorCallArguments)" + { if (!containsOptional(provided) || !wrapOption) "" else { val applyParameters2 = (provided map { f => @@ -361,19 +368,26 @@ class ScalaCodeGen(javaLazy: String, javaOptional: String, instantiateJavaOption else genParam(f, intfLang) }).mkString(", ") val ctorCallArguments2 = - provided.map { f => - if (f.fieldType.isOptionalType) - mkOptional(bq(f.name), f.fieldType, intfLang) - else - bq(f.name) - }.mkString(", ") + provided + .map { f => + if (f.fieldType.isOptionalType) + mkOptional(bq(f.name), f.fieldType, intfLang) + else + bq(f.name) + } + .mkString(", ") EOL + s"def apply($applyParameters2): ${r.name} = new ${r.name}($ctorCallArguments2)" } } } } - private def genAlternativeConstructors(since: VersionNumber, allFields: List[FieldDefinition], privateConstructor: Boolean, intfLang: String) = + private def genAlternativeConstructors( + since: VersionNumber, + allFields: List[FieldDefinition], + privateConstructor: Boolean, + intfLang: String + ) = perVersionNumber(since, allFields) { case (provided, byDefault) if byDefault.nonEmpty => // Don't duplicate up-to-date constructor val ctorParameters = provided map { f => genParam(f, intfLang) } mkString ", " @@ -393,31 +407,30 @@ class ScalaCodeGen(javaLazy: String, javaOptional: String, instantiateJavaOption // parameter. Because val parameters may not be call-by-name, we prefix the parameter with `_` // and we will create the actual lazy val as a regular class member. // Non-lazy fields that belong to `cl` are made val parameters. - private def genCtorParameters(cl: RecordLikeDefinition, parent: Option[InterfaceTypeDefinition], intfLang: String): List[String] = - { - val allFields = cl.fields filter { _.arguments.isEmpty } - val parentFields: List[FieldDefinition] = - parent match { - case Some(x) => x.fields filter { _.arguments.isEmpty } - case _ => Nil - } - def inParent(f: FieldDefinition): Boolean = { - val x = parentFields exists { _.name == f.name } - x - } - allFields map { - case f if !inParent(f) && f.fieldType.isLazyType => - EOL + "_" + genParam(f, intfLang) - case f if !inParent(f) => - s"""${EOL}val ${genParam(f, intfLang)}""".stripMargin - case f => EOL + genParam(f, intfLang) + private def genCtorParameters(cl: RecordLikeDefinition, parent: Option[InterfaceTypeDefinition], intfLang: String): List[String] = { + val allFields = cl.fields filter { _.arguments.isEmpty } + val parentFields: List[FieldDefinition] = + parent match { + case Some(x) => x.fields filter { _.arguments.isEmpty } + case _ => Nil } + def inParent(f: FieldDefinition): Boolean = { + val x = parentFields exists { _.name == f.name } + x } + allFields map { + case f if !inParent(f) && f.fieldType.isLazyType => + EOL + "_" + genParam(f, intfLang) + case f if !inParent(f) => + s"""${EOL}val ${genParam(f, intfLang)}""".stripMargin + case f => EOL + genParam(f, intfLang) + } + } private def genLazyMembers(fields: List[FieldDefinition], intfLang: String): List[String] = fields filter (_.fieldType.isLazyType) map { f => - val doc = toDoc(f.comments) - s"""${genDoc(doc)} + val doc = toDoc(f.comments) + s"""${genDoc(doc)} |lazy val ${bq(f.name)}: ${genRealTpe(f.fieldType, isParam = false, intfLang)} = _${f.name}""".stripMargin } @@ -456,17 +469,16 @@ class ScalaCodeGen(javaLazy: String, javaOptional: String, instantiateJavaOption } private def genWith(r: ObjectTypeDefinition, intfLang: String) = { - def capitalize(s: String) = { val (fst, rst) = s.splitAt(1) ; fst.toUpperCase + rst } + def capitalize(s: String) = { val (fst, rst) = s.splitAt(1); fst.toUpperCase + rst } r.fields map { f => s"""def with${capitalize(f.name)}(${bq(f.name)}: ${genRealTpe(f.fieldType, isParam = true, intfLang)}): ${r.name} = { | copy(${bq(f.name)} = ${bq(f.name)}) |}""".stripMargin + - ( if (f.fieldType.isListType || f.fieldType.isNotNullType) "" - else s""" + (if (f.fieldType.isListType || f.fieldType.isNotNullType) "" + else s""" |def with${capitalize(f.name)}(${bq(f.name)}: ${genRealTpe(f.fieldType.notNull, isParam = true, intfLang)}): ${r.name} = { | copy(${bq(f.name)} = ${mkOptional(bq(f.name), f.fieldType, intfLang)}) - |}""".stripMargin - ) + |}""".stripMargin) } mkString (EOL + EOL) } } diff --git a/library/src/main/scala/sbt/contraband/VersionNumber.scala b/library/src/main/scala/sbt/contraband/VersionNumber.scala index 9e6e438..518223a 100644 --- a/library/src/main/scala/sbt/contraband/VersionNumber.scala +++ b/library/src/main/scala/sbt/contraband/VersionNumber.scala @@ -2,10 +2,7 @@ package sbt.contraband import scala.annotation.tailrec -final class VersionNumber private[sbt] ( - val numbers: Seq[Long], - val tags: Seq[String], - val extras: Seq[String]) { +final class VersionNumber private[sbt] (val numbers: Seq[Long], val tags: Seq[String], val extras: Seq[String]) { def _1: Option[Long] = get(0) def _2: Option[Long] = get(1) def _3: Option[Long] = get(2) diff --git a/library/src/main/scala/sbt/contraband/ast/SchemaAst.scala b/library/src/main/scala/sbt/contraband/ast/SchemaAst.scala index ab9e46c..57387be 100644 --- a/library/src/main/scala/sbt/contraband/ast/SchemaAst.scala +++ b/library/src/main/scala/sbt/contraband/ast/SchemaAst.scala @@ -15,7 +15,8 @@ final case class Document( directives: List[Directive], trailingComments: List[Comment] = Nil, position: Option[Position] = None -) extends AstNode with WithTrailingComments +) extends AstNode + with WithTrailingComments final case class PackageDecl( nameSegments: List[String], @@ -81,9 +82,9 @@ sealed trait Type extends AstNode { loop(this) } - /** Removes all type parameters from `tpe` */ def removeTypeParameters: ast.Type = { + /** Removes all type parameters from `tpe` */ def removeTp(tpe: String): String = tpe.replaceAll("<.+>", "").replaceAll("\\[.+\\]", "") def loop(tpe: Type): Type = @@ -155,7 +156,7 @@ final case class LazyType(ofType: Type, position: Option[Position] = None) exten override def equals(other: Any): Boolean = other match { case that: LazyType => (that canEqual this) && ofType == that.ofType - case _ => false + case _ => false } override def hashCode: Int = 37 * (17 + ofType.##) + "LazyType".## @@ -166,7 +167,6 @@ sealed trait NameValue extends AstNode with WithComments { def value: Value } - final case class Directive( name: String, arguments: List[Argument], @@ -188,13 +188,13 @@ object Directive { def modifier(value: String): Directive = Directive("modifier", Argument(None, StringValue(value)) :: Nil) } - final case class Argument( nameOpt: Option[String], value: Value, comments: List[Comment] = Nil, position: Option[Position] = None -) extends AstNode with WithComments +) extends AstNode + with WithComments sealed trait Value extends AstNode with WithComments { def renderPretty: String = @@ -231,12 +231,13 @@ final case class RawValue(value: String, comments: List[Comment] = Nil, position final case class ObjectValue(fields: List[ObjectField], comments: List[Comment] = Nil, position: Option[Position] = None) extends Value { lazy val fieldsByName: ListMap[String, Value] = - fields.foldLeft(ListMap.empty[String, Value]) { - case (acc, field) ⇒ acc + (field.name → field.value) + fields.foldLeft(ListMap.empty[String, Value]) { case (acc, field) ⇒ + acc + (field.name → field.value) } } -final case class ObjectField(name: String, value: Value, comments: List[Comment] = Nil, position: Option[Position] = None) extends NameValue { +final case class ObjectField(name: String, value: Value, comments: List[Comment] = Nil, position: Option[Position] = None) + extends NameValue { def renderPretty: String = s"$name" } @@ -255,13 +256,13 @@ final case class CompanionExtraComment(text: String, position: Option[Position] // Schema definitions final case class FieldDefinition( - name: String, - fieldType: Type, - arguments: List[InputValueDefinition], - defaultValue: Option[Value] = None, - directives: List[Directive] = Nil, - comments: List[Comment] = Nil, - position: Option[Position] = None + name: String, + fieldType: Type, + arguments: List[InputValueDefinition], + defaultValue: Option[Value] = None, + directives: List[Directive] = Nil, + comments: List[Comment] = Nil, + position: Option[Position] = None ) extends SchemaAstNode { override def canEqual(other: Any): Boolean = other.isInstanceOf[FieldDefinition] @@ -269,9 +270,9 @@ final case class FieldDefinition( override def equals(other: Any): Boolean = other match { case that: FieldDefinition => (that canEqual this) && - name == that.name && - fieldType == that.fieldType && - arguments == that.arguments + name == that.name && + fieldType == that.fieldType && + arguments == that.arguments case _ => false } @@ -280,12 +281,12 @@ final case class FieldDefinition( } case class InputValueDefinition( - name: String, - valueType: Type, - defaultValue: Option[Value], - directives: List[Directive] = Nil, - comments: List[Comment] = Nil, - position: Option[Position] = None + name: String, + valueType: Type, + defaultValue: Option[Value], + directives: List[Directive] = Nil, + comments: List[Comment] = Nil, + position: Option[Position] = None ) extends SchemaAstNode { override def canEqual(other: Any): Boolean = other.isInstanceOf[InputValueDefinition] @@ -293,8 +294,8 @@ case class InputValueDefinition( override def equals(other: Any): Boolean = other match { case that: InputValueDefinition => (that canEqual this) && - name == that.name && - valueType == that.valueType + name == that.name && + valueType == that.valueType case _ => false } @@ -302,42 +303,45 @@ case class InputValueDefinition( } final case class ObjectTypeDefinition( - name: String, - namespace: Option[String], - interfaces: List[NamedType], - fields: List[FieldDefinition], - directives: List[Directive] = Nil, - comments: List[Comment] = Nil, - trailingComments: List[Comment] = Nil, - position: Option[Position] = None -) extends RecordLikeDefinition with WithTrailingComments + name: String, + namespace: Option[String], + interfaces: List[NamedType], + fields: List[FieldDefinition], + directives: List[Directive] = Nil, + comments: List[Comment] = Nil, + trailingComments: List[Comment] = Nil, + position: Option[Position] = None +) extends RecordLikeDefinition + with WithTrailingComments final case class InterfaceTypeDefinition( - name: String, - namespace: Option[String], - interfaces: List[NamedType], - fields: List[FieldDefinition], - directives: List[Directive] = Nil, - comments: List[Comment] = Nil, - trailingComments: List[Comment] = Nil, - position: Option[Position] = None -) extends RecordLikeDefinition with WithTrailingComments + name: String, + namespace: Option[String], + interfaces: List[NamedType], + fields: List[FieldDefinition], + directives: List[Directive] = Nil, + comments: List[Comment] = Nil, + trailingComments: List[Comment] = Nil, + position: Option[Position] = None +) extends RecordLikeDefinition + with WithTrailingComments final case class EnumTypeDefinition( - name: String, - namespace: Option[String], - values: List[EnumValueDefinition], - directives: List[Directive] = Nil, - comments: List[Comment] = Nil, - trailingComments: List[Comment] = Nil, - position: Option[Position] = None -) extends TypeDefinition with WithTrailingComments + name: String, + namespace: Option[String], + values: List[EnumValueDefinition], + directives: List[Directive] = Nil, + comments: List[Comment] = Nil, + trailingComments: List[Comment] = Nil, + position: Option[Position] = None +) extends TypeDefinition + with WithTrailingComments final case class EnumValueDefinition( - name: String, - directives: List[Directive] = Nil, - comments: List[Comment] = Nil, - position: Option[Position] = None + name: String, + directives: List[Directive] = Nil, + comments: List[Comment] = Nil, + position: Option[Position] = None ) extends SchemaAstNode sealed trait SchemaAstNode extends AstNode with WithComments @@ -361,46 +365,46 @@ sealed trait AstNode object AstUtil { def toDoc(comments: List[Comment]): List[String] = - comments collect { - case DocComment(text, _) => text.trim + comments collect { case DocComment(text, _) => + text.trim } def toExtra(d: TypeDefinition): List[String] = - (d.comments ++ d.trailingComments) collect { - case ExtraComment(text, _) => text + (d.comments ++ d.trailingComments) collect { case ExtraComment(text, _) => + text } def toExtraIntf(d: TypeDefinition): List[String] = - (d.comments ++ d.trailingComments) collect { - case ExtraIntfComment(text, _) => text + (d.comments ++ d.trailingComments) collect { case ExtraIntfComment(text, _) => + text } def toToStringImpl(d: TypeDefinition): List[String] = - (d.comments ++ d.trailingComments) collect { - case ToStringImplComment(text, _) => text + (d.comments ++ d.trailingComments) collect { case ToStringImplComment(text, _) => + text } def toCompanionExtraIntfComment(d: TypeDefinition): List[String] = - (d.comments ++ d.trailingComments) collect { - case CompanionExtraIntfComment(text, _) => text + (d.comments ++ d.trailingComments) collect { case CompanionExtraIntfComment(text, _) => + text } def toCompanionExtra(d: TypeDefinition): List[String] = - (d.comments ++ d.trailingComments) collect { - case CompanionExtraComment(text, _) => text + (d.comments ++ d.trailingComments) collect { case CompanionExtraComment(text, _) => + text } def getTarget(opt: Option[String]): String = opt getOrElse sys.error("@target directive must be set either at the definition or at the package.") def scanSingleStringDirective(dirs: List[Directive], name: String): Option[String] = - scanSingleDirectiveArgumentValue(dirs, name) { - case StringValue(value, _, _) => value + scanSingleDirectiveArgumentValue(dirs, name) { case StringValue(value, _, _) => + value } def scanSingleBooleanDirective(dirs: List[Directive], name: String): Option[Boolean] = - scanSingleDirectiveArgumentValue(dirs, name) { - case BooleanValue(value, _, _) => value + scanSingleDirectiveArgumentValue(dirs, name) { case BooleanValue(value, _, _) => + value } def scanSingleDirective(dirs: List[Directive], name: String): Option[Directive] = { @@ -412,8 +416,8 @@ object AstUtil { } def toTarget(dirs: List[Directive]): Option[String] = - scanSingleDirectiveArgumentValue(dirs, "target") { - case EnumValue(value, _, _) => value + scanSingleDirectiveArgumentValue(dirs, "target") { case EnumValue(value, _, _) => + value } private def scanSingleDirectiveArgumentValue[A](dirs: List[Directive], name: String)( @@ -437,18 +441,18 @@ object AstUtil { toSince(dirs) getOrElse VersionNumber.empty def toCodecPackage(d: Document): Option[String] = { - val dirs = d.directives ++ (d.packageDecl map {_.directives}).toList.flatten + val dirs = d.directives ++ (d.packageDecl map { _.directives }).toList.flatten scanSingleStringDirective(dirs, "codecPackage") orElse - scanSingleStringDirective(dirs, "codecNamespace") + scanSingleStringDirective(dirs, "codecNamespace") } def toFullCodec(d: Document): Option[String] = { - val dirs = d.directives ++ (d.packageDecl map {_.directives}).toList.flatten + val dirs = d.directives ++ (d.packageDecl map { _.directives }).toList.flatten scanSingleStringDirective(dirs, "fullCodec") } def toCodecTypeField(d: Document): Option[String] = - toCodecTypeField(d.directives ++ (d.packageDecl map {_.directives}).toList.flatten) + toCodecTypeField(d.directives ++ (d.packageDecl map { _.directives }).toList.flatten) def toCodecTypeField(dirs: List[Directive]): Option[String] = scanSingleStringDirective(dirs, "codecTypeField") diff --git a/library/src/main/scala/sbt/contraband/parser/JsonParser.scala b/library/src/main/scala/sbt/contraband/parser/JsonParser.scala index 298b76c..2b2f81e 100644 --- a/library/src/main/scala/sbt/contraband/parser/JsonParser.scala +++ b/library/src/main/scala/sbt/contraband/parser/JsonParser.scala @@ -1,7 +1,6 @@ package sbt.contraband package parser - import scala.util.Try import ast.AstUtil.toNamedType import sjsonnew.support.scalajson.unsafe.Parser @@ -26,7 +25,6 @@ trait JsonParser[T] { } } - /** Optionally retrieves the string field `key` from `jValue`. */ def ->?(key: String): Option[String] = lookup(key) map { case JString(value) => value @@ -132,6 +130,7 @@ trait JsonParser[T] { } trait Parse[T] extends JsonParser[T] { + /** Parse an instance of `T` from `input`. */ final def parse(input: String): T = { val json = Parser.parseFromString(input).get @@ -155,6 +154,7 @@ trait ParseWithSuperIntf[A] extends JsonParser[A] { } object JsonParser { + /** * Represents a complete schema definition. * Syntax: @@ -163,18 +163,17 @@ object JsonParser { * (, "fullCodec": string constant)? } */ object Document extends Parse[ast.Document] { - override def parse(json: JValue): ast.Document = - { - val types = json ->* "types" flatMap parser.JsonParser.TypeDefinitions.parse - val directives = CodecPackageDirective(json).toList ++ FullCodecDirective(json).toList - ast.Document( - None, - types, - directives, - Nil, - None - ) - } + override def parse(json: JValue): ast.Document = { + val types = json ->* "types" flatMap parser.JsonParser.TypeDefinitions.parse + val directives = CodecPackageDirective(json).toList ++ FullCodecDirective(json).toList + ast.Document( + None, + types, + directives, + Nil, + None + ) + } } object TypeDefinitions extends ParseWithSuperIntf[List[ast.TypeDefinition]] { @@ -200,12 +199,14 @@ object JsonParser { */ object EnumTypeDefinition extends Parse[ast.EnumTypeDefinition] { override def parse(json: JValue): ast.EnumTypeDefinition = - ast.EnumTypeDefinition(json -> "name", + ast.EnumTypeDefinition( + json -> "name", json ->? "namespace", (json ->* "symbols") map EnumValueDefinition.parse, TargetDirective(json).toList ++ GenerateCodecDirective(json).toList, // json ->? "since" map VersionNumber.apply getOrElse emptyVersion, - DocComment(json) ++ ExtraComment(json)) + DocComment(json) ++ ExtraComment(json) + ) } /** @@ -213,14 +214,14 @@ object JsonParser { * Syntax: * EnumerationValue := ID * | { "name": ID - (, "doc": string constant)? } - */ + * (, "doc": string constant)? } + */ object EnumValueDefinition extends Parse[ast.EnumValueDefinition] { override def parse(json: JValue): ast.EnumValueDefinition = json match { case JString(name) => ast.EnumValueDefinition(name, Nil) case json => ast.EnumValueDefinition(json -> "name", Nil, DocComment(json)) - } + } } /** @@ -234,29 +235,30 @@ object JsonParser { * (, "extra": string constant)? } */ object ObjectTypeDefinition extends ParseWithSuperIntf[ast.ObjectTypeDefinition] { - def parse(json: JValue, superIntf: Option[ast.InterfaceTypeDefinition]): ast.ObjectTypeDefinition = - { - val superFields = - superIntf match { - case Some(s) => s.fields - case _ => Nil - } - val fs = json ->* "fields" map FieldDefinition.parse - val intfs = (superIntf map { i => toNamedType(i, None) }).toList - val directives = TargetDirective(json).toList ++ - SinceDirective(json).toList ++ - ModifierDirective(json).toList ++ - GenerateCodecDirective(json).toList - ast.ObjectTypeDefinition(json -> "name", - json ->? "namespace", - intfs, - superFields ++ fs, - directives, - DocComment(json), - ExtraComment(json) ++ ExtraIntfComment(json) ++ ToStringImplComment(json) ++ + def parse(json: JValue, superIntf: Option[ast.InterfaceTypeDefinition]): ast.ObjectTypeDefinition = { + val superFields = + superIntf match { + case Some(s) => s.fields + case _ => Nil + } + val fs = json ->* "fields" map FieldDefinition.parse + val intfs = (superIntf map { i => toNamedType(i, None) }).toList + val directives = TargetDirective(json).toList ++ + SinceDirective(json).toList ++ + ModifierDirective(json).toList ++ + GenerateCodecDirective(json).toList + ast.ObjectTypeDefinition( + json -> "name", + json ->? "namespace", + intfs, + superFields ++ fs, + directives, + DocComment(json), + ExtraComment(json) ++ ExtraIntfComment(json) ++ ToStringImplComment(json) ++ CompanionExtraIntfComment(json) ++ CompanionExtraComment(json), - None) - } + None + ) + } } /** @@ -280,83 +282,64 @@ object JsonParser { } def parseInterface(json: JValue): List[ast.TypeDefinition] = parseInterface(json, None) - def parseInterface(json: JValue, superIntf: Option[ast.InterfaceTypeDefinition]): List[ast.TypeDefinition] = - { - val superFields = - superIntf match { - case Some(s) => s.fields - case _ => Nil - } - val fs = (json ->* "fields" map FieldDefinition.parse) ++ - (json ->* "messages" map FieldDefinition.parseMessage) - val parents = json multiLineOpt "parents" getOrElse Nil - val intfs = (superIntf map { i => toNamedType(i, None) }).toList - val directives = TargetDirective(json).toList ++ SinceDirective(json).toList ++ - GenerateCodecDirective(json).toList - val intf = ast.InterfaceTypeDefinition(json -> "name", - json ->? "namespace", - intfs, - superFields ++ fs, - directives, - DocComment(json), - ExtraComment(json) ++ ExtraIntfComment(json) ++ ToStringImplComment(json) ++ + def parseInterface(json: JValue, superIntf: Option[ast.InterfaceTypeDefinition]): List[ast.TypeDefinition] = { + val superFields = + superIntf match { + case Some(s) => s.fields + case _ => Nil + } + val fs = (json ->* "fields" map FieldDefinition.parse) ++ + (json ->* "messages" map FieldDefinition.parseMessage) + val parents = json multiLineOpt "parents" getOrElse Nil + val intfs = (superIntf map { i => toNamedType(i, None) }).toList + val directives = TargetDirective(json).toList ++ SinceDirective(json).toList ++ + GenerateCodecDirective(json).toList + val intf = ast.InterfaceTypeDefinition( + json -> "name", + json ->? "namespace", + intfs, + superFields ++ fs, + directives, + DocComment(json), + ExtraComment(json) ++ ExtraIntfComment(json) ++ ToStringImplComment(json) ++ CompanionExtraIntfComment(json) ++ CompanionExtraComment(json), - None) // position - val childTypes = (json ->* "types") flatMap { j => TypeDefinitions.parse(j, Some(intf)) } - intf :: childTypes - } + None + ) // position + val childTypes = (json ->* "types") flatMap { j => TypeDefinitions.parse(j, Some(intf)) } + intf :: childTypes + } } object FieldDefinition extends Parse[ast.FieldDefinition] { - override def parse(json: JValue): ast.FieldDefinition = - { - val arguments = Nil - val defaultValue = (json ->? "default") map { ast.RawValue(_) } - val directives = SinceDirective(json).toList - val tpe = Type.parse(json -> "type") - ast.FieldDefinition(json -> "name", - tpe, - arguments, - defaultValue, - directives, - DocComment(json), - None) - } + override def parse(json: JValue): ast.FieldDefinition = { + val arguments = Nil + val defaultValue = (json ->? "default") map { ast.RawValue(_) } + val directives = SinceDirective(json).toList + val tpe = Type.parse(json -> "type") + ast.FieldDefinition(json -> "name", tpe, arguments, defaultValue, directives, DocComment(json), None) + } def parseMessage(input: String): ast.FieldDefinition = { val json = Parser.parseFromString(input).get parseMessage(json) } - def parseMessage(json: JValue): ast.FieldDefinition = - { - val arguments = (json ->* "request") map InputValueDefinition.parse - val defaultValue = (json ->? "default") map { ast.RawValue(_) } - val directives = SinceDirective(json).toList - val tpe = Type.parse(json -> "response") - ast.FieldDefinition(json -> "name", - tpe, - arguments, - defaultValue, - directives, - DocComment(json), - None) - } + def parseMessage(json: JValue): ast.FieldDefinition = { + val arguments = (json ->* "request") map InputValueDefinition.parse + val defaultValue = (json ->? "default") map { ast.RawValue(_) } + val directives = SinceDirective(json).toList + val tpe = Type.parse(json -> "response") + ast.FieldDefinition(json -> "name", tpe, arguments, defaultValue, directives, DocComment(json), None) + } } object InputValueDefinition extends Parse[ast.InputValueDefinition] { - override def parse(json: JValue): ast.InputValueDefinition = - { - val arguments = Nil - val defaultValue = (json ->? "default") map { ast.RawValue(_) } - val directives = SinceDirective(json).toList - val tpe = Type.parse(json -> "type") - ast.InputValueDefinition(json -> "name", - tpe, - defaultValue, - directives, - DocComment(json), - None) - } + override def parse(json: JValue): ast.InputValueDefinition = { + val arguments = Nil + val defaultValue = (json ->? "default") map { ast.RawValue(_) } + val directives = SinceDirective(json).toList + val tpe = Type.parse(json -> "type") + ast.InputValueDefinition(json -> "name", tpe, defaultValue, directives, DocComment(json), None) + } } object Type { @@ -379,16 +362,17 @@ object JsonParser { } } - def parse(s: String): ast.Type = - { - val r = X(s) - val t0 = ast.NamedType(r.name, None) - val t1 = if (r.repeated) ast.ListType(t0, None) - else if (!r.optional) ast.NotNullType(t0, None) - else t0 - val t2 = if (r.lzy) ast.LazyType(t1) - else t1 - t2 - } + def parse(s: String): ast.Type = { + val r = X(s) + val t0 = ast.NamedType(r.name, None) + val t1 = + if (r.repeated) ast.ListType(t0, None) + else if (!r.optional) ast.NotNullType(t0, None) + else t0 + val t2 = + if (r.lzy) ast.LazyType(t1) + else t1 + t2 + } } } diff --git a/library/src/main/scala/sbt/contraband/parser/SchemaParser.scala b/library/src/main/scala/sbt/contraband/parser/SchemaParser.scala index c99f2d1..e397ed3 100644 --- a/library/src/main/scala/sbt/contraband/parser/SchemaParser.scala +++ b/library/src/main/scala/sbt/contraband/parser/SchemaParser.scala @@ -2,13 +2,12 @@ package sbt.contraband package parser import org.parboiled2._ -import CharPredicate.{HexDigit, Digit19, AlphaNum} -import scala.util.{Failure, Success, Try} - +import CharPredicate.{ HexDigit, Digit19, AlphaNum } +import scala.util.{ Failure, Success, Try } trait Tokens extends StringBuilding with PositionTracking { this: Parser with Ignored => - def Token = rule { Punctuator | Name | NumberValue | StringValue } + def Token = rule { Punctuator | Name | NumberValue | StringValue } val PunctuatorChar = CharPredicate("!$():=@[]{|}") @@ -28,10 +27,13 @@ trait Tokens extends StringBuilding with PositionTracking { this: Parser with Ig def RawNames = rule { atomic("raw\"" ~ capture(Characters) ~ "\"") ~> ((n: String) ⇒ List(n)) } - def NumberValue = rule { atomic(Comments ~ trackPos ~ IntegerValuePart ~ FloatValuePart.? ~ IgnoredNoComment.*) ~> + def NumberValue = rule { + atomic(Comments ~ trackPos ~ IntegerValuePart ~ FloatValuePart.? ~ IgnoredNoComment.*) ~> ((comment, pos, intPart, floatPart) ⇒ floatPart map (f ⇒ ast.BigDecimalValue(BigDecimal(intPart + f), comment, Some(pos))) getOrElse - ast.BigIntValue(BigInt(intPart), comment, Some(pos))) } + ast.BigIntValue(BigInt(intPart), comment, Some(pos)) + ) + } def FloatValuePart = rule { atomic(capture(FractionalPart ~ ExponentPart.? | ExponentPart)) } @@ -47,13 +49,19 @@ trait Tokens extends StringBuilding with PositionTracking { this: Parser with Ig def Sign = rule { ch('-') | '+' } - val NegativeSign = '-' + val NegativeSign = '-' val NonZeroDigit = Digit19 def Digit = rule { ch('0') | NonZeroDigit } - def StringValue = rule { atomic(Comments ~ trackPos ~ '"' ~ clearSB() ~ Characters ~ '"' ~ push(sb.toString) ~ IgnoredNoComment.* ~> ((comment, pos, s) ⇒ ast.StringValue(s, comment, Some(pos))))} + def StringValue = rule { + atomic( + Comments ~ trackPos ~ '"' ~ clearSB() ~ Characters ~ '"' ~ push(sb.toString) ~ IgnoredNoComment.* ~> ((comment, pos, s) ⇒ + ast.StringValue(s, comment, Some(pos)) + ) + ) + } def Characters = rule { (NormalChar | '\\' ~ EscapedChar).* } @@ -63,12 +71,12 @@ trait Tokens extends StringBuilding with PositionTracking { this: Parser with Ig def EscapedChar = rule { QuoteBackslash ~ appendSB() | - 'b' ~ appendSB('\b') | - 'f' ~ appendSB('\f') | - 'n' ~ appendSB('\n') | - 'r' ~ appendSB('\r') | - 't' ~ appendSB('\t') | - Unicode ~> { code ⇒ sb.append(code.asInstanceOf[Char]); () } + 'b' ~ appendSB('\b') | + 'f' ~ appendSB('\f') | + 'n' ~ appendSB('\n') | + 'r' ~ appendSB('\r') | + 't' ~ appendSB('\t') | + Unicode ~> { code ⇒ sb.append(code.asInstanceOf[Char]); () } } def Unicode = rule { 'u' ~ capture(4 times HexDigit) ~> (Integer.parseInt(_, 16)) } @@ -90,23 +98,47 @@ trait Ignored extends PositionTracking { this: Parser ⇒ def IgnoredNoComment = rule { quiet(UnicodeBOM | WhiteSpace | (CRLF | LineTerminator) ~ trackNewLine | ',') } - def Comments = rule { (DocComment | CommentCap).* ~ Ignored.* ~> (_.toList)} + def Comments = rule { (DocComment | CommentCap).* ~ Ignored.* ~> (_.toList) } - def ExtraComments = rule { (ExtraIntfComment | ToStringImplComment | CompanionExtraIntfComment | CompanionExtraComment | ExtraComment | DocComment | CommentCap).* ~ Ignored.* ~> (_.toList)} + def ExtraComments = rule { + (ExtraIntfComment | ToStringImplComment | CompanionExtraIntfComment | CompanionExtraComment | ExtraComment | DocComment | CommentCap).* ~ Ignored.* ~> (_.toList) + } - def CommentCap = rule { trackPos ~ "#" ~ capture(CommentChar.*) ~ IgnoredNoComment.* ~> ((pos, comment) => ast.CommentLine(comment, Some(pos))) } + def CommentCap = rule { + trackPos ~ "#" ~ capture(CommentChar.*) ~ IgnoredNoComment.* ~> ((pos, comment) => ast.CommentLine(comment, Some(pos))) + } - def DocComment = rule { trackPos ~ "##" ~ capture(CommentChar.*) ~ IgnoredNoComment.* ~> ((pos, comment) => ast.DocComment(comment, Some(pos))) } + def DocComment = rule { + trackPos ~ "##" ~ capture(CommentChar.*) ~ IgnoredNoComment.* ~> ((pos, comment) => ast.DocComment(comment, Some(pos))) + } - def ExtraComment = rule { trackPos ~ "#x" ~ capture(CommentChar.*) ~ IgnoredNoComment.* ~> ((pos, comment) => ast.ExtraComment(comment, Some(pos))) } + def ExtraComment = rule { + trackPos ~ "#x" ~ capture(CommentChar.*) ~ IgnoredNoComment.* ~> ((pos, comment) => ast.ExtraComment(comment, Some(pos))) + } - def ExtraIntfComment = rule { trackPos ~ "#xinterface" ~ capture(CommentChar.*) ~ IgnoredNoComment.* ~> ((pos, comment) => ast.ExtraIntfComment(comment.trim, Some(pos))) } + def ExtraIntfComment = rule { + trackPos ~ "#xinterface" ~ capture(CommentChar.*) ~ IgnoredNoComment.* ~> ((pos, comment) => + ast.ExtraIntfComment(comment.trim, Some(pos)) + ) + } - def ToStringImplComment = rule { trackPos ~ "#xtostring" ~ capture(CommentChar.*) ~ IgnoredNoComment.* ~> ((pos, comment) => ast.ToStringImplComment(comment.trim, Some(pos))) } + def ToStringImplComment = rule { + trackPos ~ "#xtostring" ~ capture(CommentChar.*) ~ IgnoredNoComment.* ~> ((pos, comment) => + ast.ToStringImplComment(comment.trim, Some(pos)) + ) + } - def CompanionExtraIntfComment = rule { trackPos ~ "#xcompanioninterface" ~ capture(CommentChar.*) ~ IgnoredNoComment.* ~> ((pos, comment) => ast.CompanionExtraIntfComment(comment.trim, Some(pos))) } + def CompanionExtraIntfComment = rule { + trackPos ~ "#xcompanioninterface" ~ capture(CommentChar.*) ~ IgnoredNoComment.* ~> ((pos, comment) => + ast.CompanionExtraIntfComment(comment.trim, Some(pos)) + ) + } - def CompanionExtraComment = rule { trackPos ~ "#xcompanion" ~ capture(CommentChar.*) ~ IgnoredNoComment.* ~> ((pos, comment) => ast.CompanionExtraComment(comment.trim, Some(pos))) } + def CompanionExtraComment = rule { + trackPos ~ "#xcompanion" ~ capture(CommentChar.*) ~ IgnoredNoComment.* ~> ((pos, comment) => + ast.CompanionExtraComment(comment.trim, Some(pos)) + ) + } def Comment = rule { "#" ~ CommentChar.* } @@ -120,7 +152,12 @@ trait Ignored extends PositionTracking { this: Parser ⇒ } -trait Document { this: Parser with Tokens /* with Operations*/ with Ignored /*with Fragments with Operations with Values*/ with Directives with TypeSystemDefinitions => +trait Document { + this: Parser + with Tokens /* with Operations*/ + with Ignored /*with Fragments with Operations with Values*/ + with Directives + with TypeSystemDefinitions => def `package` = rule { Keyword("package") } @@ -136,11 +173,14 @@ trait Document { this: Parser with Tokens /* with Operations*/ with Ignored /*wi // def InputDocument = rule { IgnoredNoComment.* ~ ValueConst ~ Ignored.* ~ EOI } - def Definition = rule { /*OperationDefinition | FragmentDefinition |*/ TypeSystemDefinition } + def Definition = rule { /*OperationDefinition | FragmentDefinition |*/ + TypeSystemDefinition + } } -trait TypeSystemDefinitions { this: Parser with Tokens with Ignored with Directives with Types with Operations with Values /*with Fragments*/ => +trait TypeSystemDefinitions { + this: Parser with Tokens with Ignored with Directives with Types with Operations with Values /*with Fragments*/ => def scalar = rule { Keyword("scalar") } def `type` = rule { Keyword("type") } @@ -162,8 +202,8 @@ trait TypeSystemDefinitions { this: Parser with Tokens with Ignored with Directi def TypeDefinition = rule { // ScalarTypeDefinition | ObjectTypeDefinition | - InterfaceTypeDefinition | - EnumTypeDefinition // | + InterfaceTypeDefinition | + EnumTypeDefinition // | // UnionTypeDefinition | // InputObjectTypeDefinition } @@ -177,28 +217,36 @@ trait TypeSystemDefinitions { this: Parser with Tokens with Ignored with Directi def ObjectTypeDefinition = rule { Comments ~ trackPos ~ `type` ~ Name ~ (ImplementsInterfaces.? ~> (_ getOrElse Nil)) ~ (Directives.? ~> (_ getOrElse Nil)) ~ wsNoComment('{') ~ FieldDefinition.* ~ ExtraComments ~ wsNoComment('}') ~> ( - (comment, pos, name, interfaces, dirs, fields, tc) => ast.ObjectTypeDefinition(name, None, interfaces, fields.toList, dirs, comment, tc, Some(pos))) + (comment, pos, name, interfaces, dirs, fields, tc) => + ast.ObjectTypeDefinition(name, None, interfaces, fields.toList, dirs, comment, tc, Some(pos)) + ) } def ImplementsInterfaces = rule { implements ~ NamedType.+ ~> (_.toList) } def FieldDefinition = rule { - Comments ~ trackPos ~ Name ~ (ArgumentsDefinition.? ~> (_ getOrElse Nil)) ~ ws(':') ~ Type ~ DefaultValue.? ~ (Directives.? ~> (_ getOrElse Nil)) ~> ( - (comment, pos, name, args, fieldType, default, dirs) => ast.FieldDefinition(name, fieldType, args, default, dirs, comment, Some(pos))) + Comments ~ trackPos ~ Name ~ (ArgumentsDefinition.? ~> (_ getOrElse Nil)) ~ ws( + ':' + ) ~ Type ~ DefaultValue.? ~ (Directives.? ~> (_ getOrElse Nil)) ~> ((comment, pos, name, args, fieldType, default, dirs) => + ast.FieldDefinition(name, fieldType, args, default, dirs, comment, Some(pos)) + ) } def ArgumentsDefinition = rule { wsNoComment('(') ~ InputValueDefinition.+ ~ wsNoComment(')') ~> (_.toList) } def InputValueDefinition = rule { Comments ~ trackPos ~ Name ~ ws(':') ~ Type ~ DefaultValue.? ~ (Directives.? ~> (_ getOrElse Nil)) ~> ( - (comment, pos, name, valueType, default, dirs) => ast.InputValueDefinition(name, valueType, default, dirs, comment, Some(pos))) + (comment, pos, name, valueType, default, dirs) => ast.InputValueDefinition(name, valueType, default, dirs, comment, Some(pos)) + ) } def InterfaceTypeDefinition = rule { // Changed FieldDefinition.+ to FieldDefinition.* Comments ~ trackPos ~ interface ~ Name ~ (ImplementsInterfaces.? ~> (_ getOrElse Nil)) ~ (Directives.? ~> (_ getOrElse Nil)) ~ wsNoComment('{') ~ FieldDefinition.* ~ ExtraComments ~ wsNoComment('}') ~> ( - (comment, pos, name, interfaces, dirs, fields, tc) => ast.InterfaceTypeDefinition(name, None, interfaces, fields.toList, dirs, comment, tc, Some(pos))) + (comment, pos, name, interfaces, dirs, fields, tc) => + ast.InterfaceTypeDefinition(name, None, interfaces, fields.toList, dirs, comment, tc, Some(pos)) + ) } // def UnionTypeDefinition = rule { @@ -209,8 +257,11 @@ trait TypeSystemDefinitions { this: Parser with Tokens with Ignored with Directi // def UnionMembers = rule { NamedType.+(ws('|')) ~> (_.toList) } def EnumTypeDefinition = rule { - Comments ~ trackPos ~ enum ~ Name ~ (Directives.? ~> (_ getOrElse Nil)) ~ wsNoComment('{') ~ EnumValueDefinition.+ ~ ExtraComments ~ wsNoComment('}') ~> ( - (comment, pos, name, dirs, values, tc) ⇒ ast.EnumTypeDefinition(name, None, values.toList, dirs, comment, tc, Some(pos))) + Comments ~ trackPos ~ enum ~ Name ~ (Directives.? ~> (_ getOrElse Nil)) ~ wsNoComment( + '{' + ) ~ EnumValueDefinition.+ ~ ExtraComments ~ wsNoComment('}') ~> ((comment, pos, name, dirs, values, tc) ⇒ + ast.EnumTypeDefinition(name, None, values.toList, dirs, comment, tc, Some(pos)) + ) } def EnumValueDefinition = rule { @@ -249,7 +300,8 @@ trait TypeSystemDefinitions { this: Parser with Tokens with Ignored with Directi } -trait Operations extends PositionTracking { this: Parser with Tokens with Ignored /*with Fragments*/ with Values with Types with Directives ⇒ +trait Operations extends PositionTracking { + this: Parser with Tokens with Ignored /*with Fragments*/ with Values with Types with Directives ⇒ // def OperationDefinition = rule { // Comments ~ trackPos ~ SelectionSet ~> ((comment, pos, s) ⇒ ast.OperationDefinition(selections = s._1, comments = comment, trailingComments = s._2, position = Some(pos))) | @@ -299,8 +351,10 @@ trait Operations extends PositionTracking { this: Parser with Tokens with Ignore def Arguments = rule { Ignored.* ~ wsNoComment('(') ~ Argument.+ ~ wsNoComment(')') ~> (_.toList) } - def Argument = rule { Comments ~ trackPos ~ (Name ~ wsNoComment(':')).? ~ Value ~> - ((comment, pos, name, value) ⇒ ast.Argument(name, value, comment, Some(pos))) } + def Argument = rule { + Comments ~ trackPos ~ (Name ~ wsNoComment(':')).? ~ Value ~> + ((comment, pos, name, value) ⇒ ast.Argument(name, value, comment, Some(pos))) + } } trait Values { this: Parser with Tokens with Ignored with Operations ⇒ @@ -311,19 +365,19 @@ trait Values { this: Parser with Tokens with Ignored with Operations ⇒ def Value: Rule1[ast.Value] = rule { Comments ~ trackPos ~ Variable ~> ((comment, pos, name) ⇒ ast.VariableValue(name, comment, Some(pos))) | - NumberValue | - RawValue | - StringValue | - BooleanValue | - NullValue | - EnumValue | - ListValue | - ObjectValue + NumberValue | + RawValue | + StringValue | + BooleanValue | + NullValue | + EnumValue | + ListValue | + ObjectValue } def BooleanValue = rule { Comments ~ trackPos ~ True ~> ((comment, pos) ⇒ ast.BooleanValue(true, comment, Some(pos))) | - Comments ~ trackPos ~ False ~> ((comment, pos) ⇒ ast.BooleanValue(false, comment, Some(pos))) + Comments ~ trackPos ~ False ~> ((comment, pos) ⇒ ast.BooleanValue(false, comment, Some(pos))) } def True = rule { Keyword("true") } @@ -336,27 +390,57 @@ trait Values { this: Parser with Tokens with Ignored with Operations ⇒ def EnumValue = rule { Comments ~ !True ~ !False ~ trackPos ~ Name ~> ((comment, pos, name) ⇒ ast.EnumValue(name, comment, Some(pos))) } - def ListValueConst = rule { Comments ~ trackPos ~ wsNoComment('[') ~ ValueConst.* ~ wsNoComment(']') ~> ((comment, pos, v) ⇒ ast.ListValue(v.toList, comment, Some(pos))) } + def ListValueConst = rule { + Comments ~ trackPos ~ wsNoComment('[') ~ ValueConst.* ~ wsNoComment(']') ~> ((comment, pos, v) ⇒ + ast.ListValue(v.toList, comment, Some(pos)) + ) + } - def ListValue = rule { Comments ~ trackPos ~ wsNoComment('[') ~ Value.* ~ wsNoComment(']') ~> ((comment, pos, v) ⇒ ast.ListValue(v.toList, comment, Some(pos))) } + def ListValue = rule { + Comments ~ trackPos ~ wsNoComment('[') ~ Value.* ~ wsNoComment(']') ~> ((comment, pos, v) ⇒ ast.ListValue(v.toList, comment, Some(pos))) + } - def ObjectValueConst = rule { Comments ~ trackPos ~ wsNoComment('{') ~ ObjectFieldConst.* ~ wsNoComment('}') ~> ((comment, pos, f) ⇒ ast.ObjectValue(f.toList, comment, Some(pos))) } + def ObjectValueConst = rule { + Comments ~ trackPos ~ wsNoComment('{') ~ ObjectFieldConst.* ~ wsNoComment('}') ~> ((comment, pos, f) ⇒ + ast.ObjectValue(f.toList, comment, Some(pos)) + ) + } - def ObjectValue = rule { Comments ~ trackPos ~ wsNoComment('{') ~ ObjectField.* ~ wsNoComment('}') ~> ((comment, pos, f) ⇒ ast.ObjectValue(f.toList, comment, Some(pos))) } + def ObjectValue = rule { + Comments ~ trackPos ~ wsNoComment('{') ~ ObjectField.* ~ wsNoComment('}') ~> ((comment, pos, f) ⇒ + ast.ObjectValue(f.toList, comment, Some(pos)) + ) + } - def ObjectFieldConst = rule { Comments ~ trackPos ~ Name ~ wsNoComment(':') ~ ValueConst ~> ((comment, pos, name, value) ⇒ ast.ObjectField(name, value, comment, Some(pos))) } + def ObjectFieldConst = rule { + Comments ~ trackPos ~ Name ~ wsNoComment(':') ~ ValueConst ~> ((comment, pos, name, value) ⇒ + ast.ObjectField(name, value, comment, Some(pos)) + ) + } - def ObjectField = rule { Comments ~ trackPos ~ Name ~ wsNoComment(':') ~ Value ~> ((comment, pos, name, value) ⇒ ast.ObjectField(name, value, comment, Some(pos))) } + def ObjectField = rule { + Comments ~ trackPos ~ Name ~ wsNoComment(':') ~ Value ~> ((comment, pos, name, value) ⇒ + ast.ObjectField(name, value, comment, Some(pos)) + ) + } - def RawValue = rule { atomic(Comments ~ trackPos ~ 'r' ~ 'a' ~ 'w' ~ '"' ~ clearSB() ~ Characters ~ '"' ~ push(sb.toString) ~ IgnoredNoComment.* ~> ((comment, pos, s) ⇒ ast.RawValue(s, comment, Some(pos))))} + def RawValue = rule { + atomic( + Comments ~ trackPos ~ 'r' ~ 'a' ~ 'w' ~ '"' ~ clearSB() ~ Characters ~ '"' ~ push(sb.toString) ~ IgnoredNoComment.* ~> ( + (comment, pos, s) ⇒ ast.RawValue(s, comment, Some(pos)) + ) + ) + } } trait Directives { this: Parser with Tokens with Operations with Ignored ⇒ def Directives = rule { Directive.+ ~> (_.toList) } - def Directive = rule { Comments ~ trackPos ~ '@' ~ NameStrict ~ (Arguments.? ~> (_ getOrElse Nil)) ~> - ((comment, pos, name, args) ⇒ ast.Directive(name, args, comment, Some(pos))) } + def Directive = rule { + Comments ~ trackPos ~ '@' ~ NameStrict ~ (Arguments.? ~> (_ getOrElse Nil)) ~> + ((comment, pos, name, args) ⇒ ast.Directive(name, args, comment, Some(pos))) + } } @@ -369,19 +453,26 @@ trait Types { this: Parser with Tokens with Ignored => def TypeName: Rule1[List[String]] = rule { RawNames | DotNames } - def NamedType = rule { Ignored.* ~ trackPos ~ TypeName ~> ((pos, name) ⇒ ast.NamedType(name, Some(pos)))} + def NamedType = rule { Ignored.* ~ trackPos ~ TypeName ~> ((pos, name) ⇒ ast.NamedType(name, Some(pos))) } def ListType = rule { trackPos ~ ws('[') ~ Type ~ wsNoComment(']') ~> ((pos, tpe) ⇒ ast.ListType(tpe, Some(pos))) } def NonNullType = rule { - trackPos ~ TypeName ~ wsNoComment('!') ~> ((pos, name) ⇒ ast.NotNullType(ast.NamedType(name, Some(pos)), Some(pos))) | - trackPos ~ ListType ~ wsNoComment('!') ~> ((pos, tpe) ⇒ ast.NotNullType(tpe, Some(pos))) + trackPos ~ TypeName ~ wsNoComment('!') ~> ((pos, name) ⇒ ast.NotNullType(ast.NamedType(name, Some(pos)), Some(pos))) | + trackPos ~ ListType ~ wsNoComment('!') ~> ((pos, tpe) ⇒ ast.NotNullType(tpe, Some(pos))) } } class SchemaParser(val input: ParserInput) - extends Parser with Tokens with Ignored with Document with Operations // with Fragments - with Values with Directives with Types with TypeSystemDefinitions + extends Parser + with Tokens + with Ignored + with Document + with Operations // with Fragments + with Values + with Directives + with Types + with TypeSystemDefinitions object SchemaParser { def parse(input: String): Try[ast.Document] = @@ -401,22 +492,21 @@ object Transform { def run(doc: ast.Document): ast.Document = propateNamespace(doc) - def propateNamespace(doc: ast.Document): ast.Document = - { - val pkg = - doc.packageDecl map { case ast.PackageDecl(nameSegments, _, _, _) => - nameSegments.mkString(".") - } - val target = - doc.packageDecl flatMap { case ast.PackageDecl(_, dirs, _, _) => - toTarget(dirs) - } - val defns = - doc.definitions map { - toDefinitions(_, pkg, target) - } - doc.copy(definitions = defns) - } + def propateNamespace(doc: ast.Document): ast.Document = { + val pkg = + doc.packageDecl map { case ast.PackageDecl(nameSegments, _, _, _) => + nameSegments.mkString(".") + } + val target = + doc.packageDecl flatMap { case ast.PackageDecl(_, dirs, _, _) => + toTarget(dirs) + } + val defns = + doc.definitions map { + toDefinitions(_, pkg, target) + } + doc.copy(definitions = defns) + } def toDefinitions(d: ast.Definition, ns0: Option[String], packageTarget: Option[String]): ast.TypeDefinition = d match { diff --git a/library/src/test/scala/GraphQLMixedCodeGenSpec.scala b/library/src/test/scala/GraphQLMixedCodeGenSpec.scala index 5a068fa..f1e663c 100644 --- a/library/src/test/scala/GraphQLMixedCodeGenSpec.scala +++ b/library/src/test/scala/GraphQLMixedCodeGenSpec.scala @@ -13,12 +13,19 @@ class GraphQLMixedCodeGenSpec extends AnyFlatSpec with Matchers with Inside with "generate(Record)" should "handle mixed Java-Scala inheritance" in { val Success(ast) = SchemaParser.parse(mixedExample) // println(ast) - val gen = new MixedCodeGen(javaLazy, CodeGen.javaOptional, CodeGen.instantiateJavaOptional, - scalaArray, genFileName, scalaSealProtocols = true, scalaPrivateConstructor = true, - wrapOption = true) + val gen = new MixedCodeGen( + javaLazy, + CodeGen.javaOptional, + CodeGen.instantiateJavaOptional, + scalaArray, + genFileName, + scalaSealProtocols = true, + scalaPrivateConstructor = true, + wrapOption = true + ) val code = gen.generate(ast) - code.mapValues(_.unindent).toMap should equalMapLines ( + code.mapValues(_.unindent).toMap should equalMapLines( ListMap( new File("com/example/Greeting.java") -> """/** @@ -112,7 +119,8 @@ object SimpleGreeting { def apply(message: String, s: String): SimpleGreeting = new SimpleGreeting(message, java.util.Optional.ofNullable[String](s)) } """.stripMargin.unindent - )) + ) + ) } val javaLazy = "com.example.Lazy" diff --git a/library/src/test/scala/JsonCodecCodeGenSpec.scala b/library/src/test/scala/JsonCodecCodeGenSpec.scala index f5249ee..78a2106 100644 --- a/library/src/test/scala/JsonCodecCodeGenSpec.scala +++ b/library/src/test/scala/JsonCodecCodeGenSpec.scala @@ -17,8 +17,7 @@ class JsonCodecCodeGenSpec extends GCodeGenSpec("Codec") { val enumeration = JsonParser.EnumTypeDefinition.parse(simpleEnumerationExample) val code = gen generate enumeration - code.head._2.unindent should equalLines ( - """/** + code.head._2.unindent should equalLines("""/** | * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. | */ | @@ -57,8 +56,7 @@ class JsonCodecCodeGenSpec extends GCodeGenSpec("Codec") { val intf = JsonParser.InterfaceTypeDefinition.parseInterface(simpleInterfaceExample) val code = gen generate intf - code.head._2.unindent should equalLines ( - """/** + code.head._2.unindent should equalLines("""/** | * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. | */ | @@ -84,8 +82,7 @@ class JsonCodecCodeGenSpec extends GCodeGenSpec("Codec") { val intf = JsonParser.InterfaceTypeDefinition.parseInterface(oneChildInterfaceExample) val code = gen generate intf - code(new File("generated", "oneChildInterfaceExampleFormats.scala")).unindent should equalLines ( - """/** + code(new File("generated", "oneChildInterfaceExampleFormats.scala")).unindent should equalLines("""/** | * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. | */ | @@ -97,8 +94,7 @@ class JsonCodecCodeGenSpec extends GCodeGenSpec("Codec") { |trait OneChildInterfaceExampleFormats { self: sjsonnew.BasicJsonProtocol with generated.ChildRecordFormats => | implicit lazy val oneChildInterfaceExampleFormat: JsonFormat[_root_.oneChildInterfaceExample] = flatUnionFormat1[_root_.oneChildInterfaceExample, _root_.childRecord]("type") |}""".stripMargin.unindent) - code(new File("generated", "childRecordFormats.scala")).unindent should equalLines ( - """/** + code(new File("generated", "childRecordFormats.scala")).unindent should equalLines("""/** | * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. | */ | @@ -136,8 +132,7 @@ class JsonCodecCodeGenSpec extends GCodeGenSpec("Codec") { val intf = JsonParser.InterfaceTypeDefinition.parseInterface(nestedInterfaceExample) val code = gen generate intf - code(new File("generated", "nestedProtocolExampleFormats.scala")).unindent should equalLines ( - """/** + code(new File("generated", "nestedProtocolExampleFormats.scala")).unindent should equalLines("""/** | * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. | */ | @@ -157,8 +152,7 @@ class JsonCodecCodeGenSpec extends GCodeGenSpec("Codec") { val gen = new CodecCodeGen(codecParents, instantiateJavaLazy, javaOption, scalaArray, formatsForType, schema :: Nil) val code = gen generate schema - code.head._2.unindent should equalLines ( - """/** + code.head._2.unindent should equalLines("""/** | * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. | */ | @@ -184,8 +178,7 @@ class JsonCodecCodeGenSpec extends GCodeGenSpec("Codec") { val record = JsonParser.ObjectTypeDefinition.parse(simpleRecordExample) val code = gen generate record - code.head._2.unindent should equalLines ( - """/** + code.head._2.unindent should equalLines("""/** | * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. | */ | @@ -221,8 +214,7 @@ class JsonCodecCodeGenSpec extends GCodeGenSpec("Codec") { val record = JsonParser.ObjectTypeDefinition.parse(growableAddOneFieldExample) val code = gen generate record - code.head._2.unindent should equalLines ( - """/** + code.head._2.unindent should equalLines("""/** | * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. | */ | @@ -258,8 +250,7 @@ class JsonCodecCodeGenSpec extends GCodeGenSpec("Codec") { val record = JsonParser.ObjectTypeDefinition.parse(growableZeroToOneToTwoFieldsExample) val code = gen generate record - code.head._2.unindent should equalLines ( - """/** + code.head._2.unindent should equalLines("""/** | * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. | */ | @@ -299,8 +290,7 @@ class JsonCodecCodeGenSpec extends GCodeGenSpec("Codec") { // println(code) - code.head._2.unindent should equalLines ( - """/** + code.head._2.unindent should equalLines("""/** * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. */ @@ -331,16 +321,14 @@ implicit lazy val primitiveTypesExample2Format: JsonFormat[_root_.primitiveTypes }""".stripMargin.unindent) } - override def recordWithModifier: Unit = { - } + override def recordWithModifier: Unit = {} override def schemaGenerateTypeReferences = { val schema = JsonParser.Document.parse(primitiveTypesExample) val gen = new CodecCodeGen(codecParents, instantiateJavaLazy, javaOption, scalaArray, formatsForType, schema :: Nil) val code = gen generate schema - code.head._2.unindent should equalLines ( - """/** + code.head._2.unindent should equalLines("""/** | * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. | */ | @@ -386,8 +374,7 @@ implicit lazy val primitiveTypesExample2Format: JsonFormat[_root_.primitiveTypes val gen = new CodecCodeGen(codecParents, instantiateJavaLazy, javaOption, scalaArray, formatsForType, schema :: Nil) val code = gen generate schema - code.head._2.unindent should equalLines ( - """/** + code.head._2.unindent should equalLines("""/** | * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. | */ | @@ -423,7 +410,7 @@ implicit lazy val primitiveTypesExample2Format: JsonFormat[_root_.primitiveTypes val schema = JsonParser.Document.parse(completeExample) val gen = new CodecCodeGen(codecParents, instantiateJavaLazy, javaOption, scalaArray, formatsForType, schema :: Nil) val code = gen generate schema - code.values.mkString.unindent should equalLines (completeExampleCodeCodec.unindent) + code.values.mkString.unindent should equalLines(completeExampleCodeCodec.unindent) } override def schemaGenerateCompletePlusIndent = { @@ -431,7 +418,7 @@ implicit lazy val primitiveTypesExample2Format: JsonFormat[_root_.primitiveTypes val gen = new CodecCodeGen(codecParents, instantiateJavaLazy, javaOption, scalaArray, formatsForType, schema :: Nil) val code = gen generate schema - code.values.mkString.withoutEmptyLines should equalLines (completeExampleCodeCodec.withoutEmptyLines) + code.values.mkString.withoutEmptyLines should equalLines(completeExampleCodeCodec.withoutEmptyLines) } "The full codec object" should "include the codec of all protocol defined in the schema" in { @@ -447,8 +434,7 @@ implicit lazy val primitiveTypesExample2Format: JsonFormat[_root_.primitiveTypes val gen = new CodecCodeGen(codecParents, instantiateJavaLazy, javaOption, scalaArray, formatsForType, schema :: Nil) val code = gen generate schema - code.head._2.unindent should equalLines ( - """/** + code.head._2.unindent should equalLines("""/** | * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. | */ | diff --git a/library/src/test/scala/JsonScalaCodeGenSpec.scala b/library/src/test/scala/JsonScalaCodeGenSpec.scala index 3633beb..af648a6 100644 --- a/library/src/test/scala/JsonScalaCodeGenSpec.scala +++ b/library/src/test/scala/JsonScalaCodeGenSpec.scala @@ -10,8 +10,7 @@ class JsonScalaCodeGenSpec extends GCodeGenSpec("Scala") { val enumeration = JsonParser.EnumTypeDefinition.parse(simpleEnumerationExample) val code = mkScalaCodeGen generate enumeration - code.head._2.unindent should equalLines ( - """/** Example of simple enumeration */ + code.head._2.unindent should equalLines("""/** Example of simple enumeration */ |sealed abstract class simpleEnumerationExample extends Serializable |object simpleEnumerationExample { | // Some extra code... @@ -26,8 +25,7 @@ class JsonScalaCodeGenSpec extends GCodeGenSpec("Scala") { val protocol = JsonParser.InterfaceTypeDefinition.parseInterface(simpleInterfaceExample) val code = mkScalaCodeGen generate protocol - code.head._2.unindent should equalLines ( - """/** example of simple interface */ + code.head._2.unindent should equalLines("""/** example of simple interface */ |sealed abstract class simpleInterfaceExample( | val field: type) extends Interface1 with Interface2 with Serializable { | // Some extra code... @@ -53,8 +51,7 @@ class JsonScalaCodeGenSpec extends GCodeGenSpec("Scala") { val protocol = JsonParser.InterfaceTypeDefinition.parseInterface(oneChildInterfaceExample) val code = mkScalaCodeGen generate protocol - code.head._2.unindent should equalLines ( - """/** example of interface */ + code.head._2.unindent should equalLines("""/** example of interface */ |sealed abstract class oneChildInterfaceExample( | val field: Int) extends Serializable { | override def equals(o: Any): Boolean = this.eq(o.asInstanceOf[AnyRef]) || (o match { @@ -105,8 +102,7 @@ class JsonScalaCodeGenSpec extends GCodeGenSpec("Scala") { val protocol = JsonParser.InterfaceTypeDefinition.parseInterface(nestedInterfaceExample) val code = mkScalaCodeGen generate protocol - code.head._2.unindent should equalLines ( - """/** example of nested protocols */ + code.head._2.unindent should equalLines("""/** example of nested protocols */ |sealed abstract class nestedProtocolExample() extends Serializable { | override def equals(o: Any): Boolean = this.eq(o.asInstanceOf[AnyRef]) || (o match { | case _: nestedProtocolExample => true @@ -163,8 +159,7 @@ class JsonScalaCodeGenSpec extends GCodeGenSpec("Scala") { val schema = JsonParser.Document.parse(generateArgDocExample) val code = mkScalaCodeGen generate schema - code.head._2.withoutEmptyLines should equalLines ( - """sealed abstract class generateArgDocExample( + code.head._2.withoutEmptyLines should equalLines("""sealed abstract class generateArgDocExample( | val field: Int) extends Serializable { | /** | * A very simple example of a message. @@ -195,8 +190,7 @@ class JsonScalaCodeGenSpec extends GCodeGenSpec("Scala") { val record = JsonParser.ObjectTypeDefinition.parse(simpleRecordExample) val code = mkScalaCodeGen generate record - code.head._2.unindent should equalLines ( - """/** Example of simple record */ + code.head._2.unindent should equalLines("""/** Example of simple record */ |final class simpleRecordExample private ( |val field: java.net.URL) extends Serializable { | // Some extra code... @@ -227,8 +221,7 @@ class JsonScalaCodeGenSpec extends GCodeGenSpec("Scala") { val record = JsonParser.ObjectTypeDefinition.parse(growableAddOneFieldExample) val code = mkScalaCodeGen generate record - code.head._2.unindent should equalLines ( - """final class growableAddOneField private ( + code.head._2.unindent should equalLines("""final class growableAddOneField private ( | val field: Int) extends Serializable { | private def this() = this(0) | override def equals(o: Any): Boolean = this.eq(o.asInstanceOf[AnyRef]) || (o match { @@ -259,8 +252,7 @@ class JsonScalaCodeGenSpec extends GCodeGenSpec("Scala") { val record = JsonParser.ObjectTypeDefinition.parse(growableZeroToOneToTwoFieldsExample) val code = mkScalaCodeGen generate record - code.head._2.unindent should equalLines ( - """final class Foo private ( + code.head._2.unindent should equalLines("""final class Foo private ( | val x: Option[Int], | val y: Vector[Int]) extends Serializable { | private def this() = this(Option(0), Vector(0)) @@ -302,8 +294,7 @@ class JsonScalaCodeGenSpec extends GCodeGenSpec("Scala") { val record = JsonParser.ObjectTypeDefinition.parse(primitiveTypesExample2) val code = mkScalaCodeGen generate record - code.head._2.unindent should equalLines ( - """final class primitiveTypesExample2 private ( + code.head._2.unindent should equalLines("""final class primitiveTypesExample2 private ( val smallBoolean: Boolean, val bigBoolean: Boolean) extends Serializable { @@ -339,8 +330,7 @@ object primitiveTypesExample2 { val record = JsonParser.ObjectTypeDefinition.parse(modifierExample) val code = mkScalaCodeGen generate record - code.head._2.unindent should equalLines ( - """sealed class modifierExample private ( + code.head._2.unindent should equalLines("""sealed class modifierExample private ( |val field: Int) extends Serializable { | override def equals(o: Any): Boolean = this.eq(o.asInstanceOf[AnyRef]) || (o match { | case x: modifierExample => (this.field == x.field) @@ -368,8 +358,7 @@ object primitiveTypesExample2 { override def schemaGenerateTypeReferences = { val schema = JsonParser.Document.parse(primitiveTypesExample) val code = mkScalaCodeGen generate schema - code.head._2.unindent should equalLines ( - """final class primitiveTypesExample private ( + code.head._2.unindent should equalLines("""final class primitiveTypesExample private ( | val simpleInteger: Int, | _lazyInteger: => Int, | val arrayInteger: Vector[Int], @@ -431,8 +420,7 @@ object primitiveTypesExample2 { val schema = JsonParser.Document.parse(primitiveTypesNoLazyExample) val code = mkScalaCodeGen generate schema - code.head._2.unindent should equalLines ( - """final class primitiveTypesNoLazyExample private ( + code.head._2.unindent should equalLines("""final class primitiveTypesNoLazyExample private ( | | val simpleInteger: Int, | @@ -469,18 +457,26 @@ object primitiveTypesExample2 { override def schemaGenerateComplete = { val schema = JsonParser.Document.parse(completeExample) val code = mkScalaCodeGen generate schema - code.head._2.unindent should equalLines (completeExampleCodeScala.unindent) + code.head._2.unindent should equalLines(completeExampleCodeScala.unindent) } override def schemaGenerateCompletePlusIndent = { val schema = JsonParser.Document.parse(completeExample) val code = mkScalaCodeGen generate schema - code.head._2.withoutEmptyLines should equalLines (completeExampleCodeScala.withoutEmptyLines) + code.head._2.withoutEmptyLines should equalLines(completeExampleCodeScala.withoutEmptyLines) } def mkScalaCodeGen: ScalaCodeGen = - new ScalaCodeGen(javaLazy, CodeGen.javaOptional, CodeGen.instantiateJavaOptional, scalaArray, genFileName, - scalaSealProtocols = true, scalaPrivateConstructor = true, wrapOption = true) + new ScalaCodeGen( + javaLazy, + CodeGen.javaOptional, + CodeGen.instantiateJavaOptional, + scalaArray, + genFileName, + scalaSealProtocols = true, + scalaPrivateConstructor = true, + wrapOption = true + ) val javaLazy = "com.example.Lazy" val outputFile = new File("output.scala") val scalaArray = "Vector" diff --git a/library/src/test/scala/JsonSchemaSpec.scala b/library/src/test/scala/JsonSchemaSpec.scala index 0196963..adb9d7e 100644 --- a/library/src/test/scala/JsonSchemaSpec.scala +++ b/library/src/test/scala/JsonSchemaSpec.scala @@ -123,18 +123,20 @@ class JsonSchemaSpec extends AnyFlatSpec with Matchers with Inside { "Enumeration.parse" should "parse simple enumeration" in { JsonParser.EnumTypeDefinition.parse(simpleEnumerationExample) match { - case e@EnumTypeDefinition(name, namespace, values, directives, comments, _, _) => + case e @ EnumTypeDefinition(name, namespace, values, directives, comments, _, _) => val doc = toDoc(comments) val target = toTarget(directives) val extra = toExtra(e) - assert((name === "simpleEnumerationExample") && - (target === Some("Scala")) && - (namespace === None) && - (doc === List("Example of simple enumeration")) && - (values.size === 2) && - (values(0) === EnumValueDefinition("first", Nil, List(DocComment("First symbol")), None)) && - (values(1) === EnumValueDefinition("second", Nil, Nil, None)) && - (extra === List("// Some extra code..."))) + assert( + (name === "simpleEnumerationExample") && + (target === Some("Scala")) && + (namespace === None) && + (doc === List("Example of simple enumeration")) && + (values.size === 2) && + (values(0) === EnumValueDefinition("first", Nil, List(DocComment("First symbol")), None)) && + (values(1) === EnumValueDefinition("second", Nil, Nil, None)) && + (extra === List("// Some extra code...")) + ) } } diff --git a/library/src/test/scala/TestUtils.scala b/library/src/test/scala/TestUtils.scala index c541953..6aaeb7c 100644 --- a/library/src/test/scala/TestUtils.scala +++ b/library/src/test/scala/TestUtils.scala @@ -6,7 +6,13 @@ import scala.collection.JavaConverters._ import difflib._ object TestUtils { - def unifiedDiff(expectedName: String, obtainedName: String, expected: sciSeq[String], obtained: sciSeq[String], contextSize: Int): Vector[String] = { + def unifiedDiff( + expectedName: String, + obtainedName: String, + expected: sciSeq[String], + obtained: sciSeq[String], + contextSize: Int + ): Vector[String] = { val patch = DiffUtils.diff(expected.asJava, obtained.asJava) DiffUtils.generateUnifiedDiff(expectedName, obtainedName, expected.asJava, patch, contextSize).asScala.toVector } diff --git a/plugin/src/main/scala/ContrabandPlugin.scala b/plugin/src/main/scala/ContrabandPlugin.scala index 2d05f9e..8ad3508 100644 --- a/plugin/src/main/scala/ContrabandPlugin.scala +++ b/plugin/src/main/scala/ContrabandPlugin.scala @@ -3,13 +3,14 @@ package sbt.contraband import sbt.Keys._ import sbt._ import sbt.contraband.ast._ -import sbt.contraband.parser.{JsonParser, SchemaParser} +import sbt.contraband.parser.{ JsonParser, SchemaParser } object ContrabandPlugin extends AutoPlugin { private def scalaDef2File(x: Any) = x match { - case d: TypeDefinition => d.namespace map (ns => new File(ns.replace(".", "/"))) map (new File(_, d.name + ".scala")) getOrElse new File(d.name + ".scala") + case d: TypeDefinition => + d.namespace map (ns => new File(ns.replace(".", "/"))) map (new File(_, d.name + ".scala")) getOrElse new File(d.name + ".scala") } object autoImport { @@ -26,8 +27,10 @@ object ContrabandPlugin extends AutoPlugin { val contrabandScalaPrivateConstructor = settingKey[Boolean]("Hide the constructors in Scala.") val contrabandWrapOption = settingKey[Boolean]("Provide constructors that automatically wraps the options.") val contrabandCodecParents = settingKey[List[String]]("Parents to add all o of the codec object.") - val contrabandInstantiateJavaLazy = settingKey[String => String]("Function that instantiate a lazy expression from an expression in Java.") - val contrabandInstantiateJavaOptional = settingKey[(String, String) => String]("Function that instantiate a optional expression from an expression in Java.") + val contrabandInstantiateJavaLazy = + settingKey[String => String]("Function that instantiate a lazy expression from an expression in Java.") + val contrabandInstantiateJavaOptional = + settingKey[(String, String) => String]("Function that instantiate a optional expression from an expression in Java.") val contrabandFormatsForType = settingKey[Type => List[String]]("Function that maps types to the list of required codecs for them.") val contrabandSjsonNewVersion = settingKey[String]("The version of sjson-new to use") @@ -57,7 +60,8 @@ object ContrabandPlugin extends AutoPlugin { generateContrabands / contrabandInstantiateJavaOptional := CodeGen.instantiateJavaOptional, generateContrabands / contrabandFormatsForType := CodecCodeGen.formatsForType, generateContrabands := { - Generate((generateContrabands / contrabandSource).value, + Generate( + (generateContrabands / contrabandSource).value, !(generateContrabands / skipGeneration).value, !(generateJsonCodecs / skipGeneration).value, (generateContrabands / sourceManaged).value, @@ -73,7 +77,8 @@ object ContrabandPlugin extends AutoPlugin { (generateContrabands / contrabandInstantiateJavaLazy).value, (generateContrabands / contrabandInstantiateJavaOptional).value, (generateContrabands / contrabandFormatsForType).value, - streams.value) + streams.value + ) }, Compile / sourceGenerators += generateContrabands.taskValue ) @@ -108,22 +113,24 @@ object ContrabandPlugin extends AutoPlugin { } object Generate { - private def generate(createDatatypes: Boolean, - createCodecs: Boolean, - definitions: Array[File], - target: File, - javaLazy: String, - javaOption: String, - scalaArray: String, - scalaFileNames: Any => File, - scalaSealInterface: Boolean, - scalaPrivateConstructor: Boolean, - wrapOption: Boolean, - codecParents: List[String], - instantiateJavaLazy: String => String, - instantiateJavaOptional: (String, String) => String, - formatsForType: Type => List[String], - log: Logger): Seq[File] = { + private def generate( + createDatatypes: Boolean, + createCodecs: Boolean, + definitions: Array[File], + target: File, + javaLazy: String, + javaOption: String, + scalaArray: String, + scalaFileNames: Any => File, + scalaSealInterface: Boolean, + scalaPrivateConstructor: Boolean, + wrapOption: Boolean, + codecParents: List[String], + instantiateJavaLazy: String => String, + instantiateJavaOptional: (String, String) => String, + formatsForType: Type => List[String], + log: Logger + ): Seq[File] = { val jsonFiles = definitions.toList collect { case f: File if f.getName endsWith ".json" => f } @@ -132,27 +139,36 @@ object Generate { } val input = (jsonFiles map { f => JsonParser.Document.parse(IO read f) }) ++ - (contraFiles map { f => - val ast = SchemaParser.parse(IO read f).get - ast - }) - val generator = new MixedCodeGen(javaLazy, javaOption, instantiateJavaOptional, - scalaArray, scalaFileNames, scalaSealInterface, scalaPrivateConstructor, wrapOption) - val jsonFormatsGenerator = new CodecCodeGen(codecParents, instantiateJavaLazy, - javaOption, scalaArray, formatsForType, input) + (contraFiles map { f => + val ast = SchemaParser.parse(IO read f).get + ast + }) + val generator = new MixedCodeGen( + javaLazy, + javaOption, + instantiateJavaOptional, + scalaArray, + scalaFileNames, + scalaSealInterface, + scalaPrivateConstructor, + wrapOption + ) + val jsonFormatsGenerator = new CodecCodeGen(codecParents, instantiateJavaLazy, javaOption, scalaArray, formatsForType, input) val datatypes = if (createDatatypes) { input flatMap { s => - generator.generate(s).map { - case (file, code) => + generator + .generate(s) + .map { case (file, code) => val outputFile = new File(target, "/" + file.toString) IO.write(outputFile, code) log.debug(s"sbt-contraband created $outputFile") // println(code) // println("---------") outputFile - }.toList + } + .toList } } else { List.empty @@ -161,8 +177,9 @@ object Generate { val formats = if (createCodecs) { input flatMap { s => - jsonFormatsGenerator.generate(s).map { - case (file, code) => + jsonFormatsGenerator + .generate(s) + .map { case (file, code) => // println(code) // println("---------") val outputFile = new File(target, "/" + file.toString) @@ -170,7 +187,8 @@ object Generate { log.debug(s"sbt-contraband created $outputFile") outputFile - }.toList + } + .toList } } else { List.empty @@ -178,31 +196,48 @@ object Generate { datatypes ++ formats } - def apply(base: File, - createDatatypes: Boolean, - createCodecs: Boolean, - target: File, - javaLazy: String, - javaOption: String, - scalaArray: String, - scalaFileNames: Any => File, - scalaSealInterface: Boolean, - scalaPrivateConstructor: Boolean, - scalaVersion: String, - wrapOption: Boolean, - codecParents: List[String], - instantiateJavaLazy: String => String, - instantiateJavaOptional: (String, String) => String, - formatsForType: Type => List[String], - s: TaskStreams): Seq[File] = { + def apply( + base: File, + createDatatypes: Boolean, + createCodecs: Boolean, + target: File, + javaLazy: String, + javaOption: String, + scalaArray: String, + scalaFileNames: Any => File, + scalaSealInterface: Boolean, + scalaPrivateConstructor: Boolean, + scalaVersion: String, + wrapOption: Boolean, + codecParents: List[String], + instantiateJavaLazy: String => String, + instantiateJavaOptional: (String, String) => String, + formatsForType: Type => List[String], + s: TaskStreams + ): Seq[File] = { val definitions = IO listFiles base - def gen() = generate(createDatatypes, createCodecs, definitions, target, javaLazy, javaOption, scalaArray, - scalaFileNames, scalaSealInterface, scalaPrivateConstructor, wrapOption, - codecParents, instantiateJavaLazy, instantiateJavaOptional, formatsForType, s.log) + def gen() = generate( + createDatatypes, + createCodecs, + definitions, + target, + javaLazy, + javaOption, + scalaArray, + scalaFileNames, + scalaSealInterface, + scalaPrivateConstructor, + wrapOption, + codecParents, + instantiateJavaLazy, + instantiateJavaOptional, + formatsForType, + s.log + ) val scalaVersionSubDir = scalaVersion match { case VersionNumber(Seq(x, y, _*), _, _) => s"scala-$x.$y" - case _ => throw new IllegalArgumentException(s"Invalid Scala version: '$scalaVersion'") + case _ => throw new IllegalArgumentException(s"Invalid Scala version: '$scalaVersion'") } val cacheDirectory = s.cacheDirectory / scalaVersionSubDir / "gen-api" diff --git a/project/plugins.sbt b/project/plugins.sbt index 6f4bfba..2fc87fd 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,3 +1,4 @@ addSbtPlugin("com.jsuereth" % "sbt-pgp" % "2.0.1") addSbtPlugin("com.typesafe.sbt" % "sbt-site" % "1.4.0") addSbtPlugin("com.typesafe.sbt" % "sbt-ghpages" % "0.6.3") +addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.6")