From 5841627c91e3fe246cf33fb31ae491ec437a654e Mon Sep 17 00:00:00 2001 From: auht <101095686+auht@users.noreply.github.com> Date: Fri, 4 Oct 2024 22:45:15 +0800 Subject: [PATCH] Constraint solving for function overloading (#213) Co-authored-by: Lionel Parreaux --- .../scala/mlscript/compiler/ClassLifter.scala | 16 +- .../scala/mlscript/ConstraintSolver.scala | 88 ++++ .../src/main/scala/mlscript/NuTypeDefs.scala | 14 +- .../main/scala/mlscript/TypeSimplifier.scala | 48 ++- shared/src/main/scala/mlscript/Typer.scala | 44 +- .../main/scala/mlscript/TyperDatatypes.scala | 118 ++++- .../main/scala/mlscript/TyperHelpers.scala | 48 +-- .../codegen/typescript/TsTypegen.scala | 2 +- shared/src/main/scala/mlscript/helpers.scala | 14 +- shared/src/main/scala/mlscript/syntax.scala | 2 +- shared/src/test/diff/fcp/Overloads.mls | 63 ++- .../src/test/diff/fcp/Overloads_Precise.mls | 197 +++++++++ shared/src/test/diff/nu/HeungTung.mls | 402 +++++++++++++++++- .../src/test/scala/mlscript/DiffTests.scala | 5 +- 14 files changed, 969 insertions(+), 92 deletions(-) create mode 100644 shared/src/test/diff/fcp/Overloads_Precise.mls diff --git a/compiler/shared/main/scala/mlscript/compiler/ClassLifter.scala b/compiler/shared/main/scala/mlscript/compiler/ClassLifter.scala index 533ab83b16..b6aa7bcce3 100644 --- a/compiler/shared/main/scala/mlscript/compiler/ClassLifter.scala +++ b/compiler/shared/main/scala/mlscript/compiler/ClassLifter.scala @@ -510,7 +510,7 @@ class ClassLifter(logDebugMsg: Boolean = false) { val nlhs = liftType(lb) val nrhs = liftType(ub) Bounds(nlhs._1, nrhs._1) -> (nlhs._2 ++ nrhs._2) - case Constrained(base: Type, bounds, where) => + case Constrained(base: Type, bounds, where, tscs) => val (nTargs, nCtx) = bounds.map { case (tv, Bounds(lb, ub)) => val nlhs = liftType(lb) val nrhs = liftType(ub) @@ -521,10 +521,18 @@ class ClassLifter(logDebugMsg: Boolean = false) { val nrhs = liftType(ub) Bounds(nlhs._1, nrhs._1) -> (nlhs._2 ++ nrhs._2) }.unzip + val (tscs0, nCtx3) = tscs.map { case (tvs, cs) => + val (ntvs,c0) = tvs.map { case (p,v) => + val (nv, c) = liftType(v) + (p,nv) -> c + }.unzip + val (ncs,c1) = cs.map(_.map(liftType).unzip).unzip + (ntvs,ncs) -> (c0 ++ c1.flatten) + }.unzip val (nBase, bCtx) = liftType(base) - Constrained(nBase, nTargs, bounds2) -> - ((nCtx ++ nCtx2).fold(emptyCtx)(_ ++ _) ++ bCtx) - case Constrained(_, _, _) => die + Constrained(nBase, nTargs, bounds2, tscs0) -> + ((nCtx ++ nCtx2 ++ nCtx3.flatten).fold(emptyCtx)(_ ++ _) ++ bCtx) + case Constrained(_, _, _, _) => die case Function(lhs, rhs) => val nlhs = liftType(lhs) val nrhs = liftType(rhs) diff --git a/shared/src/main/scala/mlscript/ConstraintSolver.scala b/shared/src/main/scala/mlscript/ConstraintSolver.scala index 54b7e91210..5d60ac1ff5 100644 --- a/shared/src/main/scala/mlscript/ConstraintSolver.scala +++ b/shared/src/main/scala/mlscript/ConstraintSolver.scala @@ -627,6 +627,35 @@ class ConstraintSolver extends NormalForms { self: Typer => recLb(ar.inner, b.inner) rec(b.inner.ub, ar.inner.ub, false) case (LhsRefined(S(b: ArrayBase), ts, r, _), _) => reportError() + case (LhsRefined(S(ov: Overload), ts, r, trs), RhsBases(_, S(L(f: FunctionType)), _)) if noApproximateOverload => + TupleSetConstraints.mk(ov, f) match { + case S(tsc) => + if (tsc.tvs.nonEmpty) { + tsc.tvs.mapValuesIter(_.unwrapProxies).zipWithIndex.flatMap { + case ((true, tv: TV), i) => tv.lowerBounds.iterator.map((_,tv,i,true)) + case ((false, tv: TV), i) => tv.upperBounds.iterator.map((_,tv,i,false)) + case _ => Nil + }.find { + case (b,_,i,_) => + tsc.updateImpl(i,b) + tsc.constraints.isEmpty + }.foreach { + case (b,tv,_,p) => if (p) rec(b,tv,false) else rec(tv,b,false) + } + if (tsc.constraints.sizeCompare(1) === 0) { + tsc.tvs.values.map(_.unwrapProxies).foreach { + case tv: TV => tv.tsc.remove(tsc) + case _ => () + } + tsc.constraints.head.iterator.zip(tsc.tvs).foreach { + case (c, (pol, t)) => + if (!pol) rec(c, t, false) + if (pol) rec(t, c, false) + } + } + } + case N => reportError(S(msg"is not an instance of `${f.expNeg}`")) + } case (LhsRefined(S(ov: Overload), ts, r, trs), _) => annoying(Nil, LhsRefined(S(ov.approximatePos), ts, r, trs), Nil, done_rs) // TODO remove approx. with ambiguous constraints case (LhsRefined(S(Without(b, ns)), ts, r, _), RhsBases(pts, N | S(L(_)), _)) => @@ -832,6 +861,28 @@ class ConstraintSolver extends NormalForms { self: Typer => val newBound = (cctx._1 ::: cctx._2.reverse).foldRight(rhs)((c, ty) => if (c.prov is noProv) ty else mkProxy(ty, c.prov)) lhs.upperBounds ::= newBound // update the bound + if (noApproximateOverload) { + lhs.tsc.foreachEntry { (tsc, v) => + v.foreach { i => + if (!tsc.tvs(i)._1) { + tsc.updateOn(i, rhs) + if (tsc.constraints.isEmpty) reportError() + } + } + } + val u = lhs.tsc.keysIterator.filter(_.constraints.sizeCompare(1)===0).duplicate + u._1.foreach { k => + k.tvs.mapValuesIter(_.unwrapProxies).foreach { + case (_,tv: TV) => tv.tsc.remove(k) + case _ => () + } + } + u._2.foreach { k => + k.constraints.head.iterator.zip(k.tvs).foreach { + case (c, (pol, t)) => if (pol) rec(t, c, false) else rec(c, t, false) + } + } + } lhs.lowerBounds.foreach(rec(_, rhs, true)) // propagate from the bound case (lhs, rhs: TypeVariable) if lhs.level <= rhs.level => @@ -839,6 +890,28 @@ class ConstraintSolver extends NormalForms { self: Typer => val newBound = (cctx._1 ::: cctx._2.reverse).foldLeft(lhs)((ty, c) => if (c.prov is noProv) ty else mkProxy(ty, c.prov)) rhs.lowerBounds ::= newBound // update the bound + if (noApproximateOverload) { + rhs.tsc.foreachEntry { (tsc, v) => + v.foreach { i => + if(tsc.tvs(i)._1) { + tsc.updateOn(i, lhs) + if (tsc.constraints.isEmpty) reportError() + } + } + } + val u = rhs.tsc.keysIterator.filter(_.constraints.sizeCompare(1)===0).duplicate + u._1.foreach { k => + k.tvs.mapValuesIter(_.unwrapProxies).foreach { + case (_,tv: TV) => tv.tsc.remove(k) + case _ => () + } + } + u._2.foreach { k => + k.constraints.head.iterator.zip(k.tvs).foreach { + case (c, (pol, t)) => if (pol) rec(t, c, false) else rec(c, t, false) + } + } + } rhs.upperBounds.foreach(rec(lhs, _, true)) // propagate from the bound @@ -1562,9 +1635,24 @@ class ConstraintSolver extends NormalForms { self: Typer => assert(lvl <= below, "this condition should be false for the result to be correct") lvl }) + val freshentsc = tv.tsc.flatMap { case (tsc,_) => + if (tsc.tvs.values.map(_.unwrapProxies).forall { + case tv: TV => !freshened.contains(tv) + case _ => true + }) S(tsc) else N + } freshened += tv -> v v.lowerBounds = tv.lowerBounds.mapConserve(freshen) v.upperBounds = tv.upperBounds.mapConserve(freshen) + freshentsc.foreach { tsc => + val t = new TupleSetConstraints(tsc.constraints, tsc.tvs) + t.constraints = t.constraints.map(_.map(freshen)) + t.tvs = t.tvs.map(x => (x._1,freshen(x._2))) + t.tvs.values.map(_.unwrapProxies).zipWithIndex.foreach { + case (tv: TV, i) => tv.tsc.updateWith(t)(_.map(_ + i).orElse(S(Set(i)))) + case _ => () + } + } v } diff --git a/shared/src/main/scala/mlscript/NuTypeDefs.scala b/shared/src/main/scala/mlscript/NuTypeDefs.scala index 3470c6807b..45a19986aa 100644 --- a/shared/src/main/scala/mlscript/NuTypeDefs.scala +++ b/shared/src/main/scala/mlscript/NuTypeDefs.scala @@ -1165,7 +1165,19 @@ class NuTypeDefs extends ConstraintSolver { self: Typer => ctx.nextLevel { implicit ctx: Ctx => assert(fd.tparams.sizeCompare(tparamsSkolems) === 0, (fd.tparams, tparamsSkolems)) vars ++ tparamsSkolems |> { implicit vars => - typeTerm(body) + val ty = typeTerm(body) + if (noApproximateOverload) { + val ambiguous = ty.getVars.unsorted.flatMap(_.tsc.keys.flatMap(_.tvs)) + .groupBy(_._2) + .filter { case (v,pvs) => pvs.sizeIs > 1 } + if (ambiguous.nonEmpty) raise(ErrorReport( + msg"ambiguous" -> N :: + ambiguous.map { case (v,_) => + msg"cannot determine satisfiability of type ${v.expPos}" -> v.prov.loco + }.toList + , true)) + } + ty } } } else { diff --git a/shared/src/main/scala/mlscript/TypeSimplifier.scala b/shared/src/main/scala/mlscript/TypeSimplifier.scala index f444498dad..f5a0bd8136 100644 --- a/shared/src/main/scala/mlscript/TypeSimplifier.scala +++ b/shared/src/main/scala/mlscript/TypeSimplifier.scala @@ -25,6 +25,7 @@ trait TypeSimplifier { self: Typer => println(s"allVarPols: ${printPols(allVarPols)}") val renewed = MutMap.empty[TypeVariable, TypeVariable] + val renewedtsc = MutMap.empty[TupleSetConstraints, TupleSetConstraints] def renew(tv: TypeVariable): TypeVariable = renewed.getOrElseUpdate(tv, @@ -78,8 +79,17 @@ trait TypeSimplifier { self: Typer => ).map(process(_, S(false -> tv))) .reduceOption(_ &- _).filterNot(_.isTop).toList else Nil + if (noApproximateOverload) + nv.tsc ++= tv.tsc.iterator.map { case (tsc, i) => renewedtsc.get(tsc) match { + case S(tsc) => (tsc, i) + case N if inPlace => (tsc, i) + case N => + val t = new TupleSetConstraints(tsc.constraints, tsc.tvs) + renewedtsc += tsc -> t + t.tvs = t.tvs.map(x => (x._1, process(x._2, N))) + (t, i) + }} } - nv case ComposedType(true, l, r) => @@ -549,9 +559,17 @@ trait TypeSimplifier { self: Typer => analyzed1.setAndIfUnset(tv -> pol(tv).getOrElse(false)) { apply(pol)(ty) } case N => if (pol(tv) =/= S(false)) - analyzed1.setAndIfUnset(tv -> true) { tv.lowerBounds.foreach(apply(pol.at(tv.level, true))) } + analyzed1.setAndIfUnset(tv -> true) { + tv.lowerBounds.foreach(apply(pol.at(tv.level, true))) + if (noApproximateOverload) + tv.tsc.keys.flatMap(_.tvs).foreach(u => apply(pol.at(tv.level,u._1))(u._2)) + } if (pol(tv) =/= S(true)) - analyzed1.setAndIfUnset(tv -> false) { tv.upperBounds.foreach(apply(pol.at(tv.level, false))) } + analyzed1.setAndIfUnset(tv -> false) { + tv.upperBounds.foreach(apply(pol.at(tv.level, false))) + if (noApproximateOverload) + tv.tsc.keys.flatMap(_.tvs).foreach(u => apply(pol.at(tv.level,u._1))(u._2)) + } } case _ => super.apply(pol)(st) @@ -643,8 +661,11 @@ trait TypeSimplifier { self: Typer => case tv: TypeVariable => pol(tv) match { case S(pol_tv) => - if (analyzed2.add(pol_tv -> tv)) + if (analyzed2.add(pol_tv -> tv)) { processImpl(st, pol, pol_tv) + if (noApproximateOverload) + tv.tsc.keys.flatMap(_.tvs).foreach(u => processImpl(u._2,pol.at(tv.level,u._1),pol_tv)) + } case N => if (analyzed2.add(true -> tv)) // * To compute the positive co-occurrences @@ -690,6 +711,7 @@ trait TypeSimplifier { self: Typer => case S(p) => (if (p) tv2.lowerBounds else tv2.upperBounds).foreach(go) // (if (p) getLbs(tv2) else getUbs(tv2)).foreach(go) + if (noApproximateOverload) tv2.tsc.keys.flatMap(_.tvs).foreach(u => go(u._2)) case N => trace(s"Analyzing invar-occ of $tv2") { analyze2(tv2, pol) @@ -789,7 +811,7 @@ trait TypeSimplifier { self: Typer => // * Remove variables that are 'dominated' by another type or variable // * A variable v dominated by T if T is in both of v's positive and negative cooccurrences - allVars.foreach { case v => if (v.assignedTo.isEmpty && !varSubst.contains(v)) { + allVars.foreach { case v => if (v.assignedTo.isEmpty && !varSubst.contains(v) && v.tsc.isEmpty) { println(s"2[v] $v ${coOccurrences.get(true -> v)} ${coOccurrences.get(false -> v)}") coOccurrences.get(true -> v).iterator.flatMap(_.iterator).foreach { @@ -807,6 +829,7 @@ trait TypeSimplifier { self: Typer => case w: TV if !(w is v) && !varSubst.contains(w) && !varSubst.contains(v) && !recVars(v) && coOccurrences.get(false -> v).exists(_(w)) + && w.tsc.isEmpty => // * Here we know that v is 'dominated' by w, so v can be inlined. // * Note that we don't want to unify the two variables here @@ -833,7 +856,7 @@ trait TypeSimplifier { self: Typer => // * Unify equivalent variables based on polar co-occurrence analysis: allVars.foreach { case v => - if (!v.assignedTo.isDefined && !varSubst.contains(v)) // TODO also handle v.assignedTo.isDefined? + if (!v.assignedTo.isDefined && !varSubst.contains(v) && v.tsc.isEmpty) // TODO also handle v.assignedTo.isDefined? trace(s"3[v] $v +${coOccurrences.get(true -> v).mkString} -${coOccurrences.get(false -> v).mkString}") { def go(pol: Bool): Unit = coOccurrences.get(pol -> v).iterator.flatMap(_.iterator).foreach { @@ -850,6 +873,7 @@ trait TypeSimplifier { self: Typer => ) && (v.level === w.level) // ^ Don't merge variables of differing levels + && w.tsc.isEmpty => trace(s"[w] $w ${printPol(S(pol))}${coOccurrences.get(pol -> w).mkString}") { @@ -923,6 +947,7 @@ trait TypeSimplifier { self: Typer => println(s"[rec] ${recVars}") val renewals = MutMap.empty[TypeVariable, TypeVariable] + val renewaltsc = MutMap.empty[TupleSetConstraints, TupleSetConstraints] val semp = Set.empty[TV] @@ -999,7 +1024,7 @@ trait TypeSimplifier { self: Typer => nv }) pol(tv) match { - case S(p) if inlineBounds && !occursInvariantly(tv) && !recVars.contains(tv) => + case S(p) if inlineBounds && !occursInvariantly(tv) && !recVars.contains(tv) && tv.tsc.isEmpty => // * Inline the bounds of non-rec non-invar-occ type variables println(s"Inlining [${printPol(p)}] bounds of $tv (~> $res)") // if (p) mergeTransform(true, pol, tv, Set.single(tv), canDistribForall) | res @@ -1017,6 +1042,15 @@ trait TypeSimplifier { self: Typer => res.lowerBounds = tv.lowerBounds.map(transform(_, pol.at(tv.level, true), Set.single(tv))) if (occNums.contains(false -> tv)) res.upperBounds = tv.upperBounds.map(transform(_, pol.at(tv.level, false), Set.single(tv))) + if (noApproximateOverload) + res.tsc ++= tv.tsc.map { case (tsc, i) => renewaltsc.get(tsc) match { + case S(tsc) => (tsc, i) + case N => + val t = new TupleSetConstraints(tsc.constraints, tsc.tvs) + renewaltsc += tsc -> t + t.tvs = t.tvs.map(x => (x._1, transform(x._2, PolMap.neu, Set.empty))) + (t, i) + }} } res }() diff --git a/shared/src/main/scala/mlscript/Typer.scala b/shared/src/main/scala/mlscript/Typer.scala index ad79b34a47..af88e803bc 100644 --- a/shared/src/main/scala/mlscript/Typer.scala +++ b/shared/src/main/scala/mlscript/Typer.scala @@ -34,6 +34,7 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne var constrainedTypes: Boolean = false var recordProvenances: Boolean = true + var noApproximateOverload: Boolean = false type Binding = Str -> SimpleType type Bindings = Map[Str, SimpleType] @@ -168,7 +169,17 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne assert(b.level > lvl) if (p) (b, tv) else (tv, b) } }.toList, innerTy) - + + if (noApproximateOverload) { + val ambiguous = innerTy.getVars.unsorted.flatMap(_.tsc.keys.flatMap(_.tvs)) + .groupBy(_._2) + .filter { case (v,pvs) => pvs.sizeIs > 1 } + if (ambiguous.nonEmpty) raise(ErrorReport( + msg"ambiguous" -> N :: + ambiguous.map { case (v,_) => msg"cannot determine satisfiability of type ${v.expPos}" -> v.prov.loco }.toList + , true)) + } + println(s"Inferred poly constr: $cty —— where ${cty.showBounds}") val cty_fresh = @@ -662,7 +673,7 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne tv.assignedTo = S(bod) tv case Rem(base, fs) => Without(rec(base), fs.toSortedSet)(tyTp(ty.toLoc, "field removal type")) - case Constrained(base, tvbs, where) => + case Constrained(base, tvbs, where, tscs) => val res = rec(base match { case ty: Type => ty case _ => die @@ -675,6 +686,14 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne constrain(rec(lo), rec(hi))(raise, tp(mergeOptions(lo.toLoc, hi.toLoc)(_ ++ _), "constraint specifiation"), ctx) } + tscs.foreach { case (typevars, constrs) => + val tvs = typevars.map(x => (x._1, rec(x._2))) + val tsc = new TupleSetConstraints(constrs.map(_.map(rec)), tvs) + tvs.values.map(_.unwrapProxies).zipWithIndex.foreach { + case (tv: TV, i) => tv.tsc.updateWith(tsc)(_.map(_ + i).orElse(S(Set(i)))) + case _ => () + } + } res case PolyType(vars, ty) => val oldLvl = ctx.lvl @@ -1860,8 +1879,10 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne val expandType = () var bounds: Ls[TypeVar -> Bounds] = Nil + var tscs: Ls[Ls[(Bool, Type)] -> Ls[Ls[Type]]] = Nil val seenVars = mutable.Set.empty[TV] + val seenTscs = mutable.Set.empty[TupleSetConstraints] def field(ft: FieldType)(implicit ectx: ExpCtx): Field = ft match { case FieldType(S(l: TV), u: TV) if l === u => @@ -1969,6 +1990,14 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne if (l =/= Bot || u =/= Top) bounds ::= nv -> Bounds(l, u) } + tv.tsc.foreachEntry { + case (tsc, i) => + if (seenTscs.add(tsc)) { + val tvs = tsc.tvs.map(x => (x._1,go(x._2))) + val constrs = tsc.constraints.map(_.map(go)) + tscs ::= tvs -> constrs + } + } } nv }) @@ -2018,17 +2047,20 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne case Overload(as) => as.map(go).reduce(Inter) case PolymorphicType(lvl, bod) => val boundsSize = bounds.size + val tscsSize = tscs.size val b = go(bod) // This is not completely correct: if we've already traversed TVs as part of a previous sibling PolymorphicType, // the bounds of these TVs won't be registered again... // FIXME in principle we'd want to compute a transitive closure... val newBounds = bounds.reverseIterator.drop(boundsSize).toBuffer + val newTscs = tscs.reverseIterator.drop(tscsSize).toBuffer val qvars = bod.varsBetween(lvl, MaxLevel).iterator val ftvs = b.freeTypeVariables ++ newBounds.iterator.map(_._1) ++ - newBounds.iterator.flatMap(_._2.freeTypeVariables) + newBounds.iterator.flatMap(_._2.freeTypeVariables) ++ + newTscs.iterator.flatMap(_._1.map(_._2)) val fvars = qvars.filter(tv => ftvs.contains(tv.asTypeVar)) if (fvars.isEmpty) b else PolyType(fvars @@ -2042,7 +2074,7 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne val lbs = groups2.toList val bounds = (ubs.mapValues(_.reduce(_ &- _)) ++ lbs.mapValues(_.reduce(_ | _)).map(_.swap)) val processed = bounds.map { case (lo, hi) => Bounds(go(lo), go(hi)) } - Constrained(go(bod), Nil, processed) + Constrained(go(bod), Nil, processed, Nil) // case DeclType(lvl, info) => @@ -2050,8 +2082,8 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne // }(r => s"~> $r") val res = goLike(st)(new ExpCtx(Map.empty)) - if (bounds.isEmpty) res - else Constrained(res, bounds, Nil) + if (bounds.isEmpty && tscs.isEmpty) res + else Constrained(res, bounds, Nil, tscs) // goLike(st) } diff --git a/shared/src/main/scala/mlscript/TyperDatatypes.scala b/shared/src/main/scala/mlscript/TyperDatatypes.scala index 755ba9bfe0..17a95fc8ff 100644 --- a/shared/src/main/scala/mlscript/TyperDatatypes.scala +++ b/shared/src/main/scala/mlscript/TyperDatatypes.scala @@ -1,7 +1,7 @@ package mlscript import scala.collection.mutable -import scala.collection.mutable.{Map => MutMap, Set => MutSet, Buffer} +import scala.collection.mutable.{Map => MutMap, Set => MutSet, Buffer, LinkedHashMap} import scala.collection.immutable.{SortedSet, SortedMap} import scala.util.chaining._ import scala.annotation.tailrec @@ -558,6 +558,8 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer => require(value.forall(_.level <= level)) _assignedTo = value } + + val tsc: LinkedHashMap[TupleSetConstraints, Set[Int]] = LinkedHashMap.empty // * Bounds should always be disregarded when `equatedTo` is defined, as they are then irrelevant: def lowerBounds: List[SimpleType] = { require(assignedTo.isEmpty, this); _lowerBounds } @@ -670,5 +672,117 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer => lazy val underlying: SimpleType = tt.neg() val prov = noProv } - + + class TupleSetConstraints(var constraints: Ls[Ls[ST]], var tvs: Ls[(Bool, ST)]) { + def updateImpl(index: Int, bound: ST)(implicit raise: Raise, ctx: Ctx) : Unit = { + val u0 = constraints.flatMap { c => + TupleSetConstraints.lcg(tvs(index)._1, bound, c(index)).map(tvs.zip(c)++_) + } + val u = u0.map { x => + x.groupMap(_._1)(_._2).map { case (u@(p,_),l) => + (u,l.reduce((x,y) => ComposedType(!p,x,y)(noProv))) + } + } + if (!u.isEmpty) { + tvs.values.map(_.unwrapProxies).foreach { + case tv: TV => tv.tsc += this -> Set.empty + case _ => () + } + tvs = u.flatMap(_.keys).distinct + constraints = tvs.map(x => u.map(_.getOrElse(x,if (x._1) TopType else BotType))).transpose + tvs.values.map(_.unwrapProxies).zipWithIndex.foreach { + case (tv: TV, i) => tv.tsc.updateWith(this)(_.map(_ + i).orElse(S(Set(i)))) + case _ => () + } + } else { + constraints = Nil + } + } + def updateOn(index: Int, bound: ST)(implicit raise: Raise, ctx: Ctx) : Unit = { + updateImpl(index, bound) + println(s"TSC update: $tvs in $constraints") + } + } + object TupleSetConstraints { + def lcgField(pol: Bool, first: FieldType, rest: FieldType)(implicit ctx: Ctx) + : Opt[Ls[(Bool, ST) -> ST]] = { + for { + ubm <- lcg(pol, first.ub, rest.ub) + lbm <- { + if (first.lb.isEmpty && rest.lb.isEmpty) + S(Nil) + else + lcg(!pol, first.lb.getOrElse(BotType), rest.lb.getOrElse(BotType)) + } + } yield { + ubm ++ lbm + } + } + def lcg(pol: Bool, first: ST, rest: ST)(implicit ctx: Ctx) + : Opt[Ls[(Bool, ST) -> ST]] = (first.unwrapProxies, rest.unwrapProxies) match { + case (a, ExtrType(p)) if p =/= pol => S(Nil) + case (a, ComposedType(p,l,r)) if p =/= pol => + for { + lm <- lcg(pol,a,l) + rm <- lcg(pol,a,r) + } yield { + lm ++ rm + } + case (a: TV, b: TV) if a.compare(b) === 0 => S(Nil) + case (a: TV, b) => S(List((pol, first) -> rest)) + case (a, b: TV) => S(List((pol, first) -> rest)) + case (a: FT, b: FT) => lcgFunction(pol, a, b) + case (a: ArrayType, b: ArrayType) => lcgField(pol, a.inner, b.inner) + case (a: TupleType, b: TupleType) if a.fields.sizeCompare(b.fields) === 0 => + val fs = a.fields.map(_._2).zip(b.fields.map(_._2)).map(u => lcgField(pol, u._1, u._2)) + if (!fs.contains(N)) { + S(fs.flatten.reduce(_++_)) + } else N + case (a: TupleType, b: RecordType) if pol => lcg(pol, a.toRecord, b) + case (a: RecordType, b: RecordType) => + val default = FieldType(N, if (pol) TopType else BotType)(noProv) + if (b.fields.map(_._1).forall(a.fields.map(_._1).contains)) { + val u = a.fields.map { + case (v, f) => lcgField(pol, f, b.fields.find(_._1 === v).fold(default)(_._2)) + } + if (!u.contains(N)) { + S(u.flatten.reduce(_++_)) + } else N + } else N + case (a, b) if a === b => S(Nil) + case (a, b) => + val dnf = DNF.mk(MaxLevel, Nil, if (pol) a & b.neg() else b & a.neg(), true) + if (dnf.isBot) + S(Nil) + else if (dnf.cs.forall(c => !(c.vars.isEmpty && c.nvars.isEmpty))) + S(List((pol, first) -> rest)) + else N + } + def lcgFunction(pol: Bool, first: FT, rest: FT)(implicit ctx: Ctx) + : Opt[Ls[(Bool, ST) -> ST]] = { + for { + lm <- lcg(!pol, first.lhs, rest.lhs) + rm <- lcg(pol, first.rhs, rest.rhs) + } yield { + lm ++ rm + } + } + def mk(ov: Overload, f: FT)(implicit raise: Raise, ctx: Ctx): Opt[TupleSetConstraints] = { + val u = ov.alts.flatMap(lcgFunction(false, f, _)).map { x => + x.groupMap(_._1)(_._2).map { case (u@(p,_),l) => + (u,l.reduce((x,y) => ComposedType(!p,x,y)(noProv))) + } + } + if (u.isEmpty) { return N } + val tvs = u.flatMap(_.keys).distinct + val m = tvs.map(x => u.map(_.getOrElse(x,if (x._1) TopType else BotType))) + val tsc = new TupleSetConstraints(m.transpose, tvs) + tvs.values.map(_.unwrapProxies).zipWithIndex.foreach { + case (tv: TV, i) => tv.tsc.updateWith(tsc)(_.map(_ + i).orElse(S(Set(i)))) + case _ => () + } + println(s"TSC mk: ${tsc.tvs} in ${tsc.constraints}") + S(tsc) + } + } } diff --git a/shared/src/main/scala/mlscript/TyperHelpers.scala b/shared/src/main/scala/mlscript/TyperHelpers.scala index 9b967a1d99..c4e9c2299b 100644 --- a/shared/src/main/scala/mlscript/TyperHelpers.scala +++ b/shared/src/main/scala/mlscript/TyperHelpers.scala @@ -695,36 +695,6 @@ abstract class TyperHelpers { Typer: Typer => case _ => this :: Nil } - def childrenPol(pol: Opt[Bool])(implicit ctx: Ctx): List[Opt[Bool] -> SimpleType] = { - def childrenPolField(fld: FieldType): List[Opt[Bool] -> SimpleType] = - fld.lb.map(pol.map(!_) -> _).toList ::: pol -> fld.ub :: Nil - this match { - case tv @ AssignedVariable(ty) => - pol -> ty :: Nil - case tv: TypeVariable => - (if (pol =/= S(false)) tv.lowerBounds.map(S(true) -> _) else Nil) ::: - (if (pol =/= S(true)) tv.upperBounds.map(S(false) -> _) else Nil) - case FunctionType(l, r) => pol.map(!_) -> l :: pol -> r :: Nil - case Overload(as) => as.map(pol -> _) - case ComposedType(_, l, r) => pol -> l :: pol -> r :: Nil - case RecordType(fs) => fs.unzip._2.flatMap(childrenPolField) - case TupleType(fs) => fs.unzip._2.flatMap(childrenPolField) - case ArrayType(fld) => childrenPolField(fld) - case SpliceType(elems) => elems flatMap {case L(l) => pol -> l :: Nil case R(r) => childrenPolField(r)} - case NegType(n) => pol.map(!_) -> n :: Nil - case ExtrType(_) => Nil - case ProxyType(und) => pol -> und :: Nil - // case _: TypeTag => Nil - case _: ObjectTag | _: Extruded => Nil - case SkolemTag(id) => pol -> id :: Nil - case tr: TypeRef => tr.mapTargs(pol)(_ -> _) - case Without(b, ns) => pol -> b :: Nil - case TypeBounds(lb, ub) => S(false) -> lb :: S(true) -> ub :: Nil - case PolymorphicType(_, und) => pol -> und :: Nil - case ConstrainedType(cs, bod) => - cs.flatMap(vbs => S(true) -> vbs._1 :: S(false) -> vbs._2 :: Nil) ::: pol -> bod :: Nil - }} - /** (exclusive, inclusive) */ def varsBetween(lb: Level, ub: Level): Set[TV] = { val res = MutSet.empty[TypeVariable] @@ -789,7 +759,8 @@ abstract class TyperHelpers { Typer: Typer => case tv: TypeVariable => val poltv = pol(tv) (if (poltv =/= S(false)) tv.lowerBounds.map(pol.at(tv.level, true) -> _) else Nil) ::: - (if (poltv =/= S(true)) tv.upperBounds.map(pol.at(tv.level, false) -> _) else Nil) + (if (poltv =/= S(true)) tv.upperBounds.map(pol.at(tv.level, false) -> _) else Nil) ++ + tv.tsc.keys.flatMap(_.tvs).map(u => pol.at(tv.level,u._1) -> u._2) case FunctionType(l, r) => pol.contravar -> l :: pol.covar -> r :: Nil case Overload(as) => as.map(pol -> _) case ComposedType(_, l, r) => pol -> l :: pol -> r :: Nil @@ -946,7 +917,7 @@ abstract class TyperHelpers { Typer: Typer => } def children(includeBounds: Bool): List[SimpleType] = this match { case tv @ AssignedVariable(ty) => if (includeBounds) ty :: Nil else Nil - case tv: TypeVariable => if (includeBounds) tv.lowerBounds ::: tv.upperBounds else Nil + case tv: TypeVariable => if (includeBounds) tv.lowerBounds ::: tv.upperBounds ++ tv.tsc.keys.flatMap(_.tvs.values) else Nil case FunctionType(l, r) => l :: r :: Nil case Overload(as) => as case ComposedType(_, l, r) => l :: r :: Nil @@ -990,8 +961,16 @@ abstract class TyperHelpers { Typer: Typer => case tv => ("\n\t\t" + tv.toString + (if (tv.lowerBounds.isEmpty) "" else " :> " + tv.lowerBounds.mkString(" | ")) + (if (tv.upperBounds.isEmpty) "" else " <: " + tv.upperBounds.mkString(" & "))) - }.mkString - + }.mkString + { + val visited: MutSet[TupleSetConstraints] = MutSet.empty + getVars.iterator.flatMap(_.tsc).map { case (tsc, i) => + if (visited.add(tsc)) + ("\n\t\t[ " + + tsc.tvs.map(t => s"${printPol(t._1)}${t._2}").mkString(", ") + + " ] in { " + tsc.constraints.mkString(", ") + " }") + else "" + }.mkString + } } @@ -1336,6 +1315,7 @@ abstract class TyperHelpers { Typer: Typer => val poltv = pol(tv) if (poltv =/= S(false)) tv.lowerBounds.foreach(apply(pol.at(tv.level, true))) if (poltv =/= S(true)) tv.upperBounds.foreach(apply(pol.at(tv.level, false))) + tv.tsc.keys.flatMap(_.tvs).foreach(u => apply(pol.at(tv.level,u._1))(u._2)) case FunctionType(l, r) => apply(pol.contravar)(l); apply(pol)(r) case Overload(as) => as.foreach(apply(pol)) case ComposedType(_, l, r) => apply(pol)(l); apply(pol)(r) diff --git a/shared/src/main/scala/mlscript/codegen/typescript/TsTypegen.scala b/shared/src/main/scala/mlscript/codegen/typescript/TsTypegen.scala index 51a41cfa35..c75f8025e8 100644 --- a/shared/src/main/scala/mlscript/codegen/typescript/TsTypegen.scala +++ b/shared/src/main/scala/mlscript/codegen/typescript/TsTypegen.scala @@ -572,7 +572,7 @@ final class TsTypegenCodeBuilder { typeScope.getTypeAliasSymbol(tvarName).map { taliasInfo => SourceCode(taliasInfo.lexicalName) ++ SourceCode.paramList(taliasInfo.params.map(SourceCode(_))) }.getOrElse(SourceCode(tvarName)) - case Constrained(base, tvbs, where) => + case Constrained(base, tvbs, where, _) => throw CodeGenError(s"Cannot generate type for `where` clause $tvbs $where") case _: Splice | _: TypeTag | _: PolyType | _: Selection => throw CodeGenError(s"Cannot yet generate type for: $mlType") diff --git a/shared/src/main/scala/mlscript/helpers.scala b/shared/src/main/scala/mlscript/helpers.scala index aa9505c371..8131ff246d 100644 --- a/shared/src/main/scala/mlscript/helpers.scala +++ b/shared/src/main/scala/mlscript/helpers.scala @@ -1,5 +1,4 @@ package mlscript - import scala.util.chaining._ import scala.collection.mutable.{Map => MutMap, SortedMap => SortedMutMap, Set => MutSet, Buffer} import scala.collection.immutable.SortedMap @@ -115,7 +114,7 @@ trait TypeLikeImpl extends Located { self: TypeLike => .mkString("forall ", " ", ".")} ${body.showIn(0)}", outerPrec > 1 // or 0? ) - case Constrained(b, bs, ws) => + case Constrained(b, bs, ws, tscs) => val oldCtx = ctx val bStr = b.showIn(0).stripSuffix("\n") val multiline = bStr.contains('\n') @@ -138,6 +137,13 @@ trait TypeLikeImpl extends Located { self: TypeLike => }.mkString }${ws.map{ case Bounds(lo, hi) => s"\n${ctx.indStr}${lo.showIn(0)} <: ${hi.showIn(0)}" // TODO print differently from bs? + }.mkString + }${tscs.map{ + case (tvs, constrs) => + val s = tvs.map(u => (if (u._1) "+" else "-") ++ u._2.showIn(0)) + .mkString("[", ", ", "]") + s"\n${ctx.indStr}" + s + + s" in ${constrs.map(_.map(_.showIn(0)).mkString("[", ", ", "]")).mkString("{", ", ", "}")}" }.mkString}" }, outerPrec > 0) case fd @ NuFunDef(isLetRec, nme, snme, targs, rhs) => @@ -207,7 +213,7 @@ trait TypeLikeImpl extends Located { self: TypeLike => case WithExtension(b, r) => b :: r :: Nil case PolyType(targs, body) => targs.map(_.fold(identity, identity)) :+ body case Splice(fs) => fs.flatMap{ case L(l) => l :: Nil case R(r) => r.in.toList ++ (r.out :: Nil) } - case Constrained(b, bs, ws) => b :: bs.flatMap(c => c._1 :: c._2 :: Nil) ::: ws.flatMap(c => c.lb :: c.ub :: Nil) + case Constrained(b, bs, ws, tscs) => b :: bs.flatMap(c => c._1 :: c._2 :: Nil) ::: ws.flatMap(c => c.lb :: c.ub :: Nil) ::: tscs.flatMap(tsc => tsc._1.map(_._2) ::: tsc._2.flatten) case Signature(xs, res) => xs ::: res.toList case NuFunDef(isLetRec, nme, snme, targs, rhs) => targs ::: rhs.toOption.toList case NuTypeDef(kind, nme, tparams, params, ctor, sig, parents, sup, ths, body) => @@ -782,7 +788,7 @@ trait TermImpl extends StatementImpl { self: Term => Constrained(body.toType_!, Nil, where.map { case Asc(l, r) => Bounds(l.toType_!, r) case s => throw new NotAType(s) - }) + }, Nil) case Forall(ps, bod) => PolyType(ps.map(R(_)), bod.toType_!) // diff --git a/shared/src/main/scala/mlscript/syntax.scala b/shared/src/main/scala/mlscript/syntax.scala index 5cab302df7..44a57eeeee 100644 --- a/shared/src/main/scala/mlscript/syntax.scala +++ b/shared/src/main/scala/mlscript/syntax.scala @@ -161,7 +161,7 @@ final case class Rem(base: Type, names: Ls[Var]) extends Type final case class Bounds(lb: Type, ub: Type) extends Type final case class WithExtension(base: Type, rcd: Record) extends Type final case class Splice(fields: Ls[Either[Type, Field]]) extends Type -final case class Constrained(base: TypeLike, tvBounds: Ls[TypeVar -> Bounds], where: Ls[Bounds]) extends Type +final case class Constrained(base: TypeLike, tvBounds: Ls[TypeVar -> Bounds], where: Ls[Bounds], tscs: Ls[Ls[(Bool, Type)] -> Ls[Ls[Type]]]) extends Type // final case class FirstClassDefn(defn: NuTypeDef) extends Type // TODO // final case class Refinement(base: Type, decls: TypingUnit) extends Type // TODO diff --git a/shared/src/test/diff/fcp/Overloads.mls b/shared/src/test/diff/fcp/Overloads.mls index 2e0e51a78c..770079c253 100644 --- a/shared/src/test/diff/fcp/Overloads.mls +++ b/shared/src/test/diff/fcp/Overloads.mls @@ -70,8 +70,51 @@ IISS 0 (if true then IISS else BBNN) 0 //│ res: bool | number | string -fun x -> (if true then IISS else BBNN) x -//│ res: int -> (bool | number | string) +def f = fun x -> (if true then IISS else BBNN) x +//│ f: int -> (bool | number | string) + +f(0) +//│ res: bool | number | string + +:e +f(0) + 1 +//│ ╔══[ERROR] Type mismatch in operator application: +//│ ║ l.80: f(0) + 1 +//│ ║ ^^^^^^ +//│ ╟── type `bool` is not an instance of type `int` +//│ ║ l.13: def BBNN: bool -> bool & number -> number +//│ ║ ^^^^ +//│ ╟── but it flows into application with expected type `int` +//│ ║ l.80: f(0) + 1 +//│ ╙── ^^^^ +//│ res: error | int + +:e +f : int -> number +//│ ╔══[ERROR] Type mismatch in type ascription: +//│ ║ l.93: f : int -> number +//│ ║ ^ +//│ ╟── type `bool` is not an instance of type `number` +//│ ║ l.13: def BBNN: bool -> bool & number -> number +//│ ║ ^^^^ +//│ ╟── Note: constraint arises from type reference: +//│ ║ l.93: f : int -> number +//│ ╙── ^^^^^^ +//│ res: int -> number + +:e +f : number -> int +//│ ╔══[ERROR] Type mismatch in type ascription: +//│ ║ l.106: f : number -> int +//│ ║ ^ +//│ ╟── type `number` does not match type `int | string` +//│ ║ l.106: f : number -> int +//│ ║ ^^^^^^ +//│ ╟── Note: constraint arises from reference: +//│ ║ l.73: def f = fun x -> (if true then IISS else BBNN) x +//│ ╙── ^ +//│ res: number -> int + if true then IISS else BBNN //│ res: bool -> bool & number -> number | int -> int & string -> string @@ -85,11 +128,11 @@ if true then IISS else BBNN :e (if true then IISS else BBNN) : (0 | 1 | true) -> number //│ ╔══[ERROR] Type mismatch in type ascription: -//│ ║ l.86: (if true then IISS else BBNN) : (0 | 1 | true) -> number -//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +//│ ║ l.129: (if true then IISS else BBNN) : (0 | 1 | true) -> number +//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ //│ ╟── type `true` does not match type `int | string` -//│ ║ l.86: (if true then IISS else BBNN) : (0 | 1 | true) -> number -//│ ╙── ^^^^ +//│ ║ l.129: (if true then IISS else BBNN) : (0 | 1 | true) -> number +//│ ╙── ^^^^ //│ res: (0 | 1 | true) -> number @@ -107,13 +150,13 @@ not test //│ <: test: //│ ~(int -> int) //│ ╔══[ERROR] Type mismatch in application: -//│ ║ l.105: not test +//│ ║ l.148: not test //│ ║ ^^^^^^^^ //│ ╟── type `~(int -> int)` is not an instance of type `bool` -//│ ║ l.99: def test: ~(int -> int) -//│ ║ ^^^^^^^^^^^^^ +//│ ║ l.142: def test: ~(int -> int) +//│ ║ ^^^^^^^^^^^^^ //│ ╟── but it flows into reference with expected type `bool` -//│ ║ l.105: not test +//│ ║ l.148: not test //│ ╙── ^^^^ //│ res: bool | error diff --git a/shared/src/test/diff/fcp/Overloads_Precise.mls b/shared/src/test/diff/fcp/Overloads_Precise.mls new file mode 100644 index 0000000000..d24fc5dfc0 --- /dev/null +++ b/shared/src/test/diff/fcp/Overloads_Precise.mls @@ -0,0 +1,197 @@ + +:NoJS +:NoApproximateOverload + +type IISS = int -> int & string -> string +type BBNN = bool -> bool & number -> number +type ZZII = 0 -> 0 & int -> int +//│ Defined type alias IISS +//│ Defined type alias BBNN +//│ Defined type alias ZZII + +def IISS: int -> int & string -> string +def BBNN: bool -> bool & number -> number +def ZZII: 0 -> 0 & int -> int +//│ IISS: int -> int & string -> string +//│ BBNN: bool -> bool & number -> number +//│ ZZII: 0 -> 0 & int -> int + + +IISS : IISS +//│ res: IISS + +IISS : int -> int & string -> string +//│ res: int -> int & string -> string + +IISS : IISS | BBNN +//│ res: BBNN | IISS + +:e +IISS : ZZII +//│ ╔══[ERROR] Type mismatch in type ascription: +//│ ║ l.30: IISS : ZZII +//│ ║ ^^^^ +//│ ╟── type `int -> int & string -> string` is not an instance of `0 -> 0` +//│ ║ l.12: def IISS: int -> int & string -> string +//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +//│ ╟── but it flows into reference with expected type `0 -> 0` +//│ ║ l.30: IISS : ZZII +//│ ║ ^^^^ +//│ ╟── Note: constraint arises from function type: +//│ ║ l.7: type ZZII = 0 -> 0 & int -> int +//│ ║ ^^^^^^ +//│ ╟── from type reference: +//│ ║ l.30: IISS : ZZII +//│ ╙── ^^^^ +//│ res: ZZII + +:e +IISS : BBNN +//│ ╔══[ERROR] Type mismatch in type ascription: +//│ ║ l.49: IISS : BBNN +//│ ║ ^^^^ +//│ ╟── type `int -> int & string -> string` is not an instance of `bool -> bool` +//│ ║ l.12: def IISS: int -> int & string -> string +//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +//│ ╟── but it flows into reference with expected type `bool -> bool` +//│ ║ l.49: IISS : BBNN +//│ ║ ^^^^ +//│ ╟── Note: constraint arises from function type: +//│ ║ l.6: type BBNN = bool -> bool & number -> number +//│ ║ ^^^^^^^^^^^^ +//│ ╟── from type reference: +//│ ║ l.49: IISS : BBNN +//│ ╙── ^^^^ +//│ res: BBNN + + +// * These tests show that we currently throw away information when constraining LHS overloading sets: + +IISS : int -> int +//│ res: int -> int + +IISS : (0 | 1) -> number +//│ res: (0 | 1) -> number + +IISS : 'a -> 'a +//│ res: 'a -> 'a +//│ where +//│ [-'a, +'a] in {[int, int], [string, string]} + +IISS 0 +//│ res: int + +(IISS : int -> int) 0 +//│ res: int + +(if true then IISS else BBNN) 0 +//│ res: number + +// * Note that this is not considered ambiguous +// * because the type variable occurrences are polar, +// * meaning that the TSCs are always trivially satisfiable +// * and thus the code is well-typed. +// * Conceptually, we'd expect this inferred type to reduce to `int -> number`, +// * but it's tricky to do such simplifications in general. +def f = fun x -> (if true then IISS else BBNN) x +//│ f: 'a -> 'b +//│ where +//│ [+'a, -'b] in {[int, int], [string, string]} +//│ [+'a, -'b] in {[bool, bool], [number, number]} + +f(0) +//│ res: number + +:e +f(0) + 1 +//│ ╔══[ERROR] Type mismatch in operator application: +//│ ║ l.106: f(0) + 1 +//│ ║ ^^^^^^ +//│ ╟── type `number` is not an instance of type `int` +//│ ║ l.13: def BBNN: bool -> bool & number -> number +//│ ║ ^^^^^^ +//│ ╟── but it flows into application with expected type `int` +//│ ║ l.106: f(0) + 1 +//│ ╙── ^^^^ +//│ res: error | int + +f : int -> number +//│ res: int -> number + +:e +f : number -> int +//│ ╔══[ERROR] Type mismatch in type ascription: +//│ ║ l.122: f : number -> int +//│ ║ ^ +//│ ╟── type `number` does not match type `?a` +//│ ║ l.122: f : number -> int +//│ ╙── ^^^^^^ +//│ res: number -> int + + +if true then IISS else BBNN +//│ res: bool -> bool & number -> number | int -> int & string -> string + +(if true then IISS else ZZII) : int -> int +//│ res: int -> int + +(if true then IISS else BBNN) : (0 | 1) -> number +//│ res: (0 | 1) -> number + +:e +(if true then IISS else BBNN) : (0 | 1 | true) -> number +//│ ╔══[ERROR] Type mismatch in type ascription: +//│ ║ l.142: (if true then IISS else BBNN) : (0 | 1 | true) -> number +//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +//│ ╟── type `int -> int & string -> string` is not an instance of `(0 | 1 | true) -> number` +//│ ║ l.12: def IISS: int -> int & string -> string +//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +//│ ╟── but it flows into reference with expected type `(0 | 1 | true) -> number` +//│ ║ l.142: (if true then IISS else BBNN) : (0 | 1 | true) -> number +//│ ║ ^^^^ +//│ ╟── Note: constraint arises from function type: +//│ ║ l.142: (if true then IISS else BBNN) : (0 | 1 | true) -> number +//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^ +//│ res: (0 | 1 | true) -> number + + +// * Note that type normalization used to be very aggressive at approximating non-tag type negations, +// * to simplify the result, but this was changed as it was unsound + +def test: ~(int -> int) +//│ test: ~(int -> int) + +// * See also test file BooleanFail.mls about this previous unsoundness +:e +test = 42 +not test +//│ 42 +//│ <: test: +//│ ~(int -> int) +//│ ╔══[ERROR] Type mismatch in application: +//│ ║ l.167: not test +//│ ║ ^^^^^^^^ +//│ ╟── type `~(int -> int)` is not an instance of type `bool` +//│ ║ l.161: def test: ~(int -> int) +//│ ║ ^^^^^^^^^^^^^ +//│ ╟── but it flows into reference with expected type `bool` +//│ ║ l.167: not test +//│ ╙── ^^^^ +//│ res: bool | error + +def test: ~(int -> int) & ~bool +//│ test: ~bool & ~(int -> int) + +def test: ~(int -> int) & bool +//│ test: bool + +def test: ~(int -> int) & ~(bool -> bool) +//│ test: ~(nothing -> (bool | int)) + +def test: ~(int -> int | bool -> bool) +//│ test: ~(nothing -> (bool | int)) + +def test: ~(int -> int & string -> string) & ~(bool -> bool & number -> number) +//│ test: in ~(nothing -> (number | string) & int -> number & nothing -> (bool | string) & nothing -> (bool | int)) out ~(nothing -> (bool | int) & nothing -> (bool | string) & int -> number & nothing -> (number | string)) + + diff --git a/shared/src/test/diff/nu/HeungTung.mls b/shared/src/test/diff/nu/HeungTung.mls index cace8c821d..3785f32dd9 100644 --- a/shared/src/test/diff/nu/HeungTung.mls +++ b/shared/src/test/diff/nu/HeungTung.mls @@ -1,4 +1,5 @@ :NewDefs +:NoApproximateOverload @@ -66,8 +67,21 @@ fun g = h //│ fun g: (Int | false | true) -> (Int | false | true) // * In one step +:e // TODO: argument of union type fun g: (Int | Bool) -> (Int | Bool) fun g = f +//│ ╔══[ERROR] Type mismatch in definition: +//│ ║ l.72: fun g = f +//│ ║ ^^^^^ +//│ ╟── type `Int -> Int & Bool -> Bool` is not an instance of `(Int | false | true) -> (Int | false | true)` +//│ ║ l.51: fun f: (Int -> Int) & (Bool -> Bool) +//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +//│ ╟── but it flows into reference with expected type `(Int | false | true) -> (Int | false | true)` +//│ ║ l.72: fun g = f +//│ ║ ^ +//│ ╟── Note: constraint arises from function type: +//│ ║ l.71: fun g: (Int | Bool) -> (Int | Bool) +//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ //│ fun g: Int -> Int & Bool -> Bool //│ fun g: (Int | false | true) -> (Int | false | true) @@ -88,9 +102,17 @@ fun j = i fun j: (Int & Bool) -> (Int & Bool) fun j = f //│ ╔══[ERROR] Type mismatch in definition: -//│ ║ l.89: fun j = f -//│ ║ ^^^^^ -//│ ╙── expression of type `Int` does not match type `nothing` +//│ ║ l.103: fun j = f +//│ ║ ^^^^^ +//│ ╟── type `Int -> Int & Bool -> Bool` is not an instance of `nothing -> nothing` +//│ ║ l.51: fun f: (Int -> Int) & (Bool -> Bool) +//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +//│ ╟── but it flows into reference with expected type `nothing -> nothing` +//│ ║ l.103: fun j = f +//│ ║ ^ +//│ ╟── Note: constraint arises from function type: +//│ ║ l.102: fun j: (Int & Bool) -> (Int & Bool) +//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ //│ fun j: Int -> Int & Bool -> Bool //│ fun j: nothing -> nothing @@ -106,7 +128,7 @@ fun g = f // * With match-type-based constraint solving, we could return Int here f(0) -//│ Int | false | true +//│ Int //│ res //│ = 0 @@ -114,15 +136,26 @@ f(0) x => f(x) -//│ (Int | false | true) -> (Int | false | true) +//│ forall 'a 'b. 'a -> 'b +//│ where +//│ [+'a, -'b] in {[Int, Int], [Bool, Bool]} //│ res //│ = [Function: res] // : forall 'a: 'a -> case 'a of { Int => Int; Bool => Bool } where 'a <: Int | Bool - +:e f(if true then 0 else false) -//│ Int | false | true +//│ ╔══[ERROR] Type mismatch in application: +//│ ║ l.148: f(if true then 0 else false) +//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +//│ ╟── type `Int -> Int & Bool -> Bool` is not an instance of `(0 | false) -> ?a` +//│ ║ l.51: fun f: (Int -> Int) & (Bool -> Bool) +//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +//│ ╟── but it flows into reference with expected type `(0 | false) -> ?a` +//│ ║ l.148: f(if true then 0 else false) +//│ ╙── ^ +//│ error //│ res //│ = 0 @@ -132,15 +165,25 @@ f(if true then 0 else false) :w f(refined if true then 0 else false) // this one can be precise again! //│ ╔══[WARNING] Paren-less applications should use the 'of' keyword -//│ ║ l.133: f(refined if true then 0 else false) // this one can be precise again! +//│ ║ l.166: f(refined if true then 0 else false) // this one can be precise again! //│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ //│ ╔══[ERROR] Illegal use of reserved operator: refined -//│ ║ l.133: f(refined if true then 0 else false) // this one can be precise again! +//│ ║ l.166: f(refined if true then 0 else false) // this one can be precise again! //│ ╙── ^^^^^^^ //│ ╔══[ERROR] identifier not found: refined -//│ ║ l.133: f(refined if true then 0 else false) // this one can be precise again! +//│ ║ l.166: f(refined if true then 0 else false) // this one can be precise again! //│ ╙── ^^^^^^^ -//│ Int | false | true +//│ ╔══[ERROR] Type mismatch in application: +//│ ║ l.166: f(refined if true then 0 else false) // this one can be precise again! +//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +//│ ╟── application of type `error` does not match type `?a` +//│ ║ l.166: f(refined if true then 0 else false) // this one can be precise again! +//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +//│ 'a +//│ where +//│ 'b :> error +//│ 'a :> error +//│ [+'b, -'a] in {} //│ Code generation encountered an error: //│ unresolved symbol refined @@ -196,7 +239,7 @@ type T = List[Int] :e // TODO application types type Res = M(T) //│ ╔══[ERROR] Wrong number of type arguments – expected 0, found 1 -//│ ║ l.197: type Res = M(T) +//│ ║ l.240: type Res = M(T) //│ ╙── ^^^^ //│ type Res = M @@ -219,7 +262,7 @@ fun f: Int -> Int fun f: Bool -> Bool fun f = id //│ ╔══[ERROR] A type signature for 'f' was already given -//│ ║ l.219: fun f: Bool -> Bool +//│ ║ l.262: fun f: Bool -> Bool //│ ╙── ^^^^^^^^^^^^^^^^^^^ //│ fun f: forall 'a. 'a -> 'a //│ fun f: Int -> Int @@ -227,13 +270,13 @@ fun f = id :e // TODO support f: (Int -> Int) & (Bool -> Bool) //│ ╔══[ERROR] Type mismatch in type ascription: -//│ ║ l.228: f: (Int -> Int) & (Bool -> Bool) +//│ ║ l.271: f: (Int -> Int) & (Bool -> Bool) //│ ║ ^ //│ ╟── type `Bool` is not an instance of type `Int` -//│ ║ l.228: f: (Int -> Int) & (Bool -> Bool) +//│ ║ l.271: f: (Int -> Int) & (Bool -> Bool) //│ ║ ^^^^ //│ ╟── Note: constraint arises from type reference: -//│ ║ l.218: fun f: Int -> Int +//│ ║ l.261: fun f: Int -> Int //│ ╙── ^^^ //│ Int -> Int & Bool -> Bool //│ res @@ -300,17 +343,17 @@ fun test(x) = refined if x is A then 0 B then 1 //│ ╔══[WARNING] Paren-less applications should use the 'of' keyword -//│ ║ l.299: fun test(x) = refined if x is +//│ ║ l.342: fun test(x) = refined if x is //│ ║ ^^^^^^^^^^^^^^^ -//│ ║ l.300: A then 0 +//│ ║ l.343: A then 0 //│ ║ ^^^^^^^^^^ -//│ ║ l.301: B then 1 +//│ ║ l.344: B then 1 //│ ╙── ^^^^^^^^^^ //│ ╔══[ERROR] Illegal use of reserved operator: refined -//│ ║ l.299: fun test(x) = refined if x is +//│ ║ l.342: fun test(x) = refined if x is //│ ╙── ^^^^^^^ //│ ╔══[ERROR] identifier not found: refined -//│ ║ l.299: fun test(x) = refined if x is +//│ ║ l.342: fun test(x) = refined if x is //│ ╙── ^^^^^^^ //│ fun test: (A | B) -> error //│ Code generation encountered an error: @@ -320,3 +363,320 @@ fun test(x) = refined if x is +fun q: (0|1) -> true & (1|2) -> false +//│ fun q: (0 | 1) -> true & (1 | 2) -> false + +q(0) +//│ true +//│ res +//│ = +//│ q is not implemented + +q(0) : true +//│ true +//│ res +//│ = +//│ q is not implemented + +q(1) +//│ 'a +//│ where +//│ [-'a] in {[true], [false]} +//│ res +//│ = +//│ q is not implemented + +q(1) : Bool +//│ Bool +//│ res +//│ = +//│ q is not implemented + +x => q(x): true +//│ (0 | 1) -> true +//│ res +//│ = +//│ q is not implemented + +x => q(x) +//│ forall 'a 'b. 'a -> 'b +//│ where +//│ [+'a, -'b] in {[0 | 1, true], [1 | 2, false]} +//│ res +//│ = +//│ q is not implemented + +:e +(x => q(x))(1):Int +//│ ╔══[ERROR] Type mismatch in type ascription: +//│ ║ l.410: (x => q(x))(1):Int +//│ ║ ^^^^^^^^^^^^^^ +//│ ╟── application of type `?a` does not match type `Int` +//│ ╟── Note: constraint arises from type reference: +//│ ║ l.410: (x => q(x))(1):Int +//│ ╙── ^^^ +//│ Int +//│ res +//│ = +//│ q is not implemented + +:e +q(1):int +//│ ╔══[ERROR] Type mismatch in type ascription: +//│ ║ l.424: q(1):int +//│ ║ ^^^^ +//│ ╟── application of type `?a` does not match type `int` +//│ ╟── Note: constraint arises from type reference: +//│ ║ l.424: q(1):int +//│ ╙── ^^^ +//│ int +//│ res +//│ = +//│ q is not implemented + +fun w = x => q(x) +//│ fun w: forall 'a 'b. 'a -> 'b +//│ where +//│ [+'a, -'b] in {[0 | 1, true], [1 | 2, false]} + +w(0) +//│ true +//│ res +//│ = +//│ w and q are not implemented + +x => (f: forall a: ((0, Int) -> 'a & (1, Str) -> ['a])) => f(0, x) + 1 +//│ Int -> (f: (0, Int) -> Int & (1, Str) -> [Int]) -> Int +//│ res +//│ = [Function: res] + +fun r: Int -> Int & Bool -> Bool +//│ fun r: Int -> Int & Bool -> Bool + +:e +x => r(r(x)) +//│ ╔══[ERROR] ambiguous +//│ ╟── cannot determine satisfiability of type ?a +//│ ║ l.457: x => r(r(x)) +//│ ╙── ^^^^ +//│ forall 'a 'b 'c. 'a -> 'c +//│ where +//│ [+'a, -'b] in {[Int, Int], [Bool, Bool]} +//│ [-'c, +'b] in {[Int, Int], [Bool, Bool]} +//│ res +//│ = +//│ r is not implemented + + +r(r(0)) +//│ Int +//│ res +//│ = +//│ r is not implemented + +x => r(r(x))+1 +//│ Int -> Int +//│ res +//│ = +//│ r is not implemented + +fun u: {x:0, y:Int} -> Int & {x:1, z: Str} -> Str +//│ fun u: {x: 0, y: Int} -> Int & {x: 1, z: Str} -> Str + +(a, b, c) => u({x: a, y: b, z: c}) +//│ forall 'a 'b 'c 'd. ('a, 'c, 'd) -> 'b +//│ where +//│ [-'b, +'a, +'c, +'d] in {[Int, 0, Int, anything], [Str, 1, anything, Str]} +//│ res +//│ = +//│ u is not implemented + +(a, b) => u({x: a, y: "abc", z: b}) +//│ (1, Str) -> Str +//│ res +//│ = +//│ u is not implemented + +fun s: Str -> Str & AA -> AA +//│ fun s: Str -> Str & AA -> AA + +:e +let g = x => s(r(x)) +//│ ╔══[ERROR] ambiguous +//│ ╟── cannot determine satisfiability of type ?a +//│ ║ l.504: let g = x => s(r(x)) +//│ ╙── ^^^^ +//│ let g: forall 'a 'b 'c. 'a -> 'c +//│ where +//│ [+'a, -'b] in {[Int, Int], [Bool, Bool]} +//│ [+'b, -'c] in {[Str, Str], [AA, AA]} +//│ g +//│ = +//│ s is not implemented + +:e +fun g(x) = s(r(x)) +//│ ╔══[ERROR] ambiguous +//│ ╟── cannot determine satisfiability of type ?a +//│ ║ l.518: fun g(x) = s(r(x)) +//│ ╙── ^^^^ +//│ fun g: forall 'a 'b 'c. 'a -> 'c +//│ where +//│ [+'a, -'b] in {[Int, Int], [Bool, Bool]} +//│ [-'c, +'b] in {[Str, Str], [AA, AA]} + +:e +x => s(r(x)) +//│ ╔══[ERROR] ambiguous +//│ ╟── cannot determine satisfiability of type ?a +//│ ║ l.529: x => s(r(x)) +//│ ╙── ^^^^ +//│ forall 'a 'b 'c. 'a -> 'c +//│ where +//│ [+'a, -'b] in {[Int, Int], [Bool, Bool]} +//│ [-'c, +'b] in {[Str, Str], [AA, AA]} +//│ res +//│ = +//│ s is not implemented + +:e +g(0) +//│ ╔══[ERROR] Type mismatch in application: +//│ ║ l.543: g(0) +//│ ║ ^^^^ +//│ ╟── expression of type `Int` does not match type `?a` +//│ ╟── Note: constraint arises from application: +//│ ║ l.518: fun g(x) = s(r(x)) +//│ ╙── ^^^^ +//│ error +//│ res +//│ = +//│ g and s are not implemented + +fun rt: {0: Int} -> Int & {0: Str} -> Str +//│ fun rt: {0: Int} -> Int & {0: Str} -> Str + +rt([1,"str"]) +//│ Int +//│ res +//│ = +//│ rt is not implemented + +rt(["str",1]) +//│ Str +//│ res +//│ = +//│ rt is not implemented + +fun app2: ('a -> 'a -> 'a) -> 'a -> 'a +//│ fun app2: forall 'a. ('a -> 'a -> 'a) -> 'a -> 'a + +fun snd: A -> Int -> Int & Str -> Str -> Str +//│ fun snd: A -> Int -> Int & Str -> Str -> Str + +:e +x => app2(snd)(x):Int +//│ ╔══[ERROR] Type mismatch in type ascription: +//│ ║ l.578: x => app2(snd)(x):Int +//│ ║ ^^^^^^^^^^^^ +//│ ╟── type `Int` is not an instance of type `A` +//│ ║ l.571: fun app2: ('a -> 'a -> 'a) -> 'a -> 'a +//│ ║ ^^ +//│ ╟── Note: constraint arises from type reference: +//│ ║ l.574: fun snd: A -> Int -> Int & Str -> Str -> Str +//│ ╙── ^ +//│ nothing -> Int +//│ res +//│ = +//│ app2 is not implemented + +fun app2_ (f:'a -> 'a -> 'a)(x) = f(x)(x) +//│ fun app2_: forall 'a. (f: 'a -> 'a -> 'a) -> 'a -> 'a + +app2_(snd) +//│ 'a -> 'b +//│ where +//│ 'a <: 'b +//│ [-'b, -'a, +'a] in {[Int, Int, A & Int], [Str, Str, Str]} +//│ res +//│ = +//│ snd is not implemented + +// * Example from WeirdUnions.mls. +// * This type merges the input tuples: +fun f: (Str => Str) & ((Str, Int) => Str) +//│ fun f: (...Array[Int | Str] & {0: Str}) -> Str + +f("abc", "abc") +//│ Str +//│ res +//│ = +//│ f is not implemented + +fun f: (Str => Str) & ((Str, Int) => Int) +//│ fun f: Str -> Str & (Str, Int) -> Int + +// * Different from WeirdUnions.mls: +:e +f("abc", "abc") +//│ ╔══[ERROR] Type mismatch in application: +//│ ║ l.621: f("abc", "abc") +//│ ║ ^^^^^^^^^^^^^^^ +//│ ╟── type `Str -> Str & (Str, Int) -> Int` is not an instance of `("abc", "abc") -> ?a` +//│ ║ l.616: fun f: (Str => Str) & ((Str, Int) => Int) +//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +//│ ╟── but it flows into reference with expected type `("abc", "abc") -> ?a` +//│ ║ l.621: f("abc", "abc") +//│ ╙── ^ +//│ error +//│ res +//│ = +//│ f is not implemented + +f("abcabc") +//│ Str +//│ res +//│ = +//│ f is not implemented + +:e +x => rt([not(x)]) +//│ ╔══[ERROR] Type mismatch in application: +//│ ║ l.643: x => rt([not(x)]) +//│ ║ ^^^^^^^^^^^^ +//│ ╟── application of type `Bool` does not match type `?a` +//│ ║ l.643: x => rt([not(x)]) +//│ ╙── ^^^^^^ +//│ forall 'a 'b. Bool -> 'a +//│ where +//│ 'b :> Bool +//│ 'a :> error +//│ [-'a, +'b] in {} +//│ res +//│ = +//│ rt is not implemented + +:e +rt(0) +//│ ╔══[ERROR] Type mismatch in application: +//│ ║ l.660: rt(0) +//│ ║ ^^^^^ +//│ ╟── type `{0: Int} -> Int & {0: Str} -> Str` is not an instance of `0 -> ?a` +//│ ║ l.556: fun rt: {0: Int} -> Int & {0: Str} -> Str +//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +//│ ╟── but it flows into reference with expected type `0 -> ?a` +//│ ║ l.660: rt(0) +//│ ╙── ^^ +//│ error +//│ res +//│ = +//│ rt is not implemented + +fun z: {0:Int} -> nothing & Str -> Str +//│ fun z: {0: Int} -> nothing & Str -> Str + +z([1]) +//│ nothing +//│ res +//│ = +//│ z is not implemented diff --git a/shared/src/test/scala/mlscript/DiffTests.scala b/shared/src/test/scala/mlscript/DiffTests.scala index 5e3076d28f..6965d73c28 100644 --- a/shared/src/test/scala/mlscript/DiffTests.scala +++ b/shared/src/test/scala/mlscript/DiffTests.scala @@ -228,6 +228,7 @@ class DiffTests(state: DiffTests.State) // Enable this to see the errors from unfinished `PreTyper`. var showPreTyperErrors = false var noTailRec = false + var noApproximateOverload = false // * This option makes some test cases pass which assume generalization should happen in arbitrary arguments // * but it's way too aggressive to be ON by default, as it leads to more extrusion, cycle errors, etc. @@ -299,6 +300,7 @@ class DiffTests(state: DiffTests.State) case "GeneralizeArguments" => generalizeArguments = true; mode case "DontGeneralizeArguments" => generalizeArguments = false; mode case "IrregularTypes" => irregularTypes = true; mode + case "NoApproximateOverload" => noApproximateOverload = true; mode case str @ "Fuel" => // println("'"+line.drop(str.length + 2)+"'") typer.startingFuel = line.drop(str.length + 2).toInt; mode @@ -559,6 +561,7 @@ class DiffTests(state: DiffTests.State) typer.explainErrors = mode.explainErrors stdout = mode.stdout typer.preciselyTypeRecursion = mode.preciselyTypeRecursion + typer.noApproximateOverload = noApproximateOverload val oldCtx = ctx @@ -588,7 +591,7 @@ class DiffTests(state: DiffTests.State) exp match { // * Strip top-level implicitly-quantified type variables case pt: PolyType => stripPoly(pt) - case Constrained(pt: PolyType, bs, cs) => Constrained(stripPoly(pt), bs, cs) + case Constrained(pt: PolyType, bs, cs, tscs) => Constrained(stripPoly(pt), bs, cs, tscs) case ty => ty } }