Skip to content

Commit

Permalink
Constraint solving for function overloading (#213)
Browse files Browse the repository at this point in the history
Co-authored-by: Lionel Parreaux <[email protected]>
  • Loading branch information
auht and LPTK authored Oct 4, 2024
1 parent c389926 commit 5841627
Show file tree
Hide file tree
Showing 14 changed files with 969 additions and 92 deletions.
16 changes: 12 additions & 4 deletions compiler/shared/main/scala/mlscript/compiler/ClassLifter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ class ClassLifter(logDebugMsg: Boolean = false) {
val nlhs = liftType(lb)
val nrhs = liftType(ub)
Bounds(nlhs._1, nrhs._1) -> (nlhs._2 ++ nrhs._2)
case Constrained(base: Type, bounds, where) =>
case Constrained(base: Type, bounds, where, tscs) =>
val (nTargs, nCtx) = bounds.map { case (tv, Bounds(lb, ub)) =>
val nlhs = liftType(lb)
val nrhs = liftType(ub)
Expand All @@ -521,10 +521,18 @@ class ClassLifter(logDebugMsg: Boolean = false) {
val nrhs = liftType(ub)
Bounds(nlhs._1, nrhs._1) -> (nlhs._2 ++ nrhs._2)
}.unzip
val (tscs0, nCtx3) = tscs.map { case (tvs, cs) =>
val (ntvs,c0) = tvs.map { case (p,v) =>
val (nv, c) = liftType(v)
(p,nv) -> c
}.unzip
val (ncs,c1) = cs.map(_.map(liftType).unzip).unzip
(ntvs,ncs) -> (c0 ++ c1.flatten)
}.unzip
val (nBase, bCtx) = liftType(base)
Constrained(nBase, nTargs, bounds2) ->
((nCtx ++ nCtx2).fold(emptyCtx)(_ ++ _) ++ bCtx)
case Constrained(_, _, _) => die
Constrained(nBase, nTargs, bounds2, tscs0) ->
((nCtx ++ nCtx2 ++ nCtx3.flatten).fold(emptyCtx)(_ ++ _) ++ bCtx)
case Constrained(_, _, _, _) => die
case Function(lhs, rhs) =>
val nlhs = liftType(lhs)
val nrhs = liftType(rhs)
Expand Down
88 changes: 88 additions & 0 deletions shared/src/main/scala/mlscript/ConstraintSolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,35 @@ class ConstraintSolver extends NormalForms { self: Typer =>
recLb(ar.inner, b.inner)
rec(b.inner.ub, ar.inner.ub, false)
case (LhsRefined(S(b: ArrayBase), ts, r, _), _) => reportError()
case (LhsRefined(S(ov: Overload), ts, r, trs), RhsBases(_, S(L(f: FunctionType)), _)) if noApproximateOverload =>
TupleSetConstraints.mk(ov, f) match {
case S(tsc) =>
if (tsc.tvs.nonEmpty) {
tsc.tvs.mapValuesIter(_.unwrapProxies).zipWithIndex.flatMap {
case ((true, tv: TV), i) => tv.lowerBounds.iterator.map((_,tv,i,true))
case ((false, tv: TV), i) => tv.upperBounds.iterator.map((_,tv,i,false))
case _ => Nil
}.find {
case (b,_,i,_) =>
tsc.updateImpl(i,b)
tsc.constraints.isEmpty
}.foreach {
case (b,tv,_,p) => if (p) rec(b,tv,false) else rec(tv,b,false)
}
if (tsc.constraints.sizeCompare(1) === 0) {
tsc.tvs.values.map(_.unwrapProxies).foreach {
case tv: TV => tv.tsc.remove(tsc)
case _ => ()
}
tsc.constraints.head.iterator.zip(tsc.tvs).foreach {
case (c, (pol, t)) =>
if (!pol) rec(c, t, false)
if (pol) rec(t, c, false)
}
}
}
case N => reportError(S(msg"is not an instance of `${f.expNeg}`"))
}
case (LhsRefined(S(ov: Overload), ts, r, trs), _) =>
annoying(Nil, LhsRefined(S(ov.approximatePos), ts, r, trs), Nil, done_rs) // TODO remove approx. with ambiguous constraints
case (LhsRefined(S(Without(b, ns)), ts, r, _), RhsBases(pts, N | S(L(_)), _)) =>
Expand Down Expand Up @@ -832,13 +861,57 @@ class ConstraintSolver extends NormalForms { self: Typer =>
val newBound = (cctx._1 ::: cctx._2.reverse).foldRight(rhs)((c, ty) =>
if (c.prov is noProv) ty else mkProxy(ty, c.prov))
lhs.upperBounds ::= newBound // update the bound
if (noApproximateOverload) {
lhs.tsc.foreachEntry { (tsc, v) =>
v.foreach { i =>
if (!tsc.tvs(i)._1) {
tsc.updateOn(i, rhs)
if (tsc.constraints.isEmpty) reportError()
}
}
}
val u = lhs.tsc.keysIterator.filter(_.constraints.sizeCompare(1)===0).duplicate
u._1.foreach { k =>
k.tvs.mapValuesIter(_.unwrapProxies).foreach {
case (_,tv: TV) => tv.tsc.remove(k)
case _ => ()
}
}
u._2.foreach { k =>
k.constraints.head.iterator.zip(k.tvs).foreach {
case (c, (pol, t)) => if (pol) rec(t, c, false) else rec(c, t, false)
}
}
}
lhs.lowerBounds.foreach(rec(_, rhs, true)) // propagate from the bound

case (lhs, rhs: TypeVariable) if lhs.level <= rhs.level =>
println(s"NEW $rhs LB (${lhs.level})")
val newBound = (cctx._1 ::: cctx._2.reverse).foldLeft(lhs)((ty, c) =>
if (c.prov is noProv) ty else mkProxy(ty, c.prov))
rhs.lowerBounds ::= newBound // update the bound
if (noApproximateOverload) {
rhs.tsc.foreachEntry { (tsc, v) =>
v.foreach { i =>
if(tsc.tvs(i)._1) {
tsc.updateOn(i, lhs)
if (tsc.constraints.isEmpty) reportError()
}
}
}
val u = rhs.tsc.keysIterator.filter(_.constraints.sizeCompare(1)===0).duplicate
u._1.foreach { k =>
k.tvs.mapValuesIter(_.unwrapProxies).foreach {
case (_,tv: TV) => tv.tsc.remove(k)
case _ => ()
}
}
u._2.foreach { k =>
k.constraints.head.iterator.zip(k.tvs).foreach {
case (c, (pol, t)) => if (pol) rec(t, c, false) else rec(c, t, false)
}
}
}
rhs.upperBounds.foreach(rec(lhs, _, true)) // propagate from the bound


Expand Down Expand Up @@ -1562,9 +1635,24 @@ class ConstraintSolver extends NormalForms { self: Typer =>
assert(lvl <= below, "this condition should be false for the result to be correct")
lvl
})
val freshentsc = tv.tsc.flatMap { case (tsc,_) =>
if (tsc.tvs.values.map(_.unwrapProxies).forall {
case tv: TV => !freshened.contains(tv)
case _ => true
}) S(tsc) else N
}
freshened += tv -> v
v.lowerBounds = tv.lowerBounds.mapConserve(freshen)
v.upperBounds = tv.upperBounds.mapConserve(freshen)
freshentsc.foreach { tsc =>
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)
t.constraints = t.constraints.map(_.map(freshen))
t.tvs = t.tvs.map(x => (x._1,freshen(x._2)))
t.tvs.values.map(_.unwrapProxies).zipWithIndex.foreach {
case (tv: TV, i) => tv.tsc.updateWith(t)(_.map(_ + i).orElse(S(Set(i))))
case _ => ()
}
}
v
}

Expand Down
14 changes: 13 additions & 1 deletion shared/src/main/scala/mlscript/NuTypeDefs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1165,7 +1165,19 @@ class NuTypeDefs extends ConstraintSolver { self: Typer =>
ctx.nextLevel { implicit ctx: Ctx =>
assert(fd.tparams.sizeCompare(tparamsSkolems) === 0, (fd.tparams, tparamsSkolems))
vars ++ tparamsSkolems |> { implicit vars =>
typeTerm(body)
val ty = typeTerm(body)
if (noApproximateOverload) {
val ambiguous = ty.getVars.unsorted.flatMap(_.tsc.keys.flatMap(_.tvs))
.groupBy(_._2)
.filter { case (v,pvs) => pvs.sizeIs > 1 }
if (ambiguous.nonEmpty) raise(ErrorReport(
msg"ambiguous" -> N ::
ambiguous.map { case (v,_) =>
msg"cannot determine satisfiability of type ${v.expPos}" -> v.prov.loco
}.toList
, true))
}
ty
}
}
} else {
Expand Down
48 changes: 41 additions & 7 deletions shared/src/main/scala/mlscript/TypeSimplifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ trait TypeSimplifier { self: Typer =>
println(s"allVarPols: ${printPols(allVarPols)}")

val renewed = MutMap.empty[TypeVariable, TypeVariable]
val renewedtsc = MutMap.empty[TupleSetConstraints, TupleSetConstraints]

def renew(tv: TypeVariable): TypeVariable =
renewed.getOrElseUpdate(tv,
Expand Down Expand Up @@ -78,8 +79,17 @@ trait TypeSimplifier { self: Typer =>
).map(process(_, S(false -> tv)))
.reduceOption(_ &- _).filterNot(_.isTop).toList
else Nil
if (noApproximateOverload)
nv.tsc ++= tv.tsc.iterator.map { case (tsc, i) => renewedtsc.get(tsc) match {
case S(tsc) => (tsc, i)
case N if inPlace => (tsc, i)
case N =>
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)
renewedtsc += tsc -> t
t.tvs = t.tvs.map(x => (x._1, process(x._2, N)))
(t, i)
}}
}

nv

case ComposedType(true, l, r) =>
Expand Down Expand Up @@ -549,9 +559,17 @@ trait TypeSimplifier { self: Typer =>
analyzed1.setAndIfUnset(tv -> pol(tv).getOrElse(false)) { apply(pol)(ty) }
case N =>
if (pol(tv) =/= S(false))
analyzed1.setAndIfUnset(tv -> true) { tv.lowerBounds.foreach(apply(pol.at(tv.level, true))) }
analyzed1.setAndIfUnset(tv -> true) {
tv.lowerBounds.foreach(apply(pol.at(tv.level, true)))
if (noApproximateOverload)
tv.tsc.keys.flatMap(_.tvs).foreach(u => apply(pol.at(tv.level,u._1))(u._2))
}
if (pol(tv) =/= S(true))
analyzed1.setAndIfUnset(tv -> false) { tv.upperBounds.foreach(apply(pol.at(tv.level, false))) }
analyzed1.setAndIfUnset(tv -> false) {
tv.upperBounds.foreach(apply(pol.at(tv.level, false)))
if (noApproximateOverload)
tv.tsc.keys.flatMap(_.tvs).foreach(u => apply(pol.at(tv.level,u._1))(u._2))
}
}
case _ =>
super.apply(pol)(st)
Expand Down Expand Up @@ -643,8 +661,11 @@ trait TypeSimplifier { self: Typer =>
case tv: TypeVariable =>
pol(tv) match {
case S(pol_tv) =>
if (analyzed2.add(pol_tv -> tv))
if (analyzed2.add(pol_tv -> tv)) {
processImpl(st, pol, pol_tv)
if (noApproximateOverload)
tv.tsc.keys.flatMap(_.tvs).foreach(u => processImpl(u._2,pol.at(tv.level,u._1),pol_tv))
}
case N =>
if (analyzed2.add(true -> tv))
// * To compute the positive co-occurrences
Expand Down Expand Up @@ -690,6 +711,7 @@ trait TypeSimplifier { self: Typer =>
case S(p) =>
(if (p) tv2.lowerBounds else tv2.upperBounds).foreach(go)
// (if (p) getLbs(tv2) else getUbs(tv2)).foreach(go)
if (noApproximateOverload) tv2.tsc.keys.flatMap(_.tvs).foreach(u => go(u._2))
case N =>
trace(s"Analyzing invar-occ of $tv2") {
analyze2(tv2, pol)
Expand Down Expand Up @@ -789,7 +811,7 @@ trait TypeSimplifier { self: Typer =>

// * Remove variables that are 'dominated' by another type or variable
// * A variable v dominated by T if T is in both of v's positive and negative cooccurrences
allVars.foreach { case v => if (v.assignedTo.isEmpty && !varSubst.contains(v)) {
allVars.foreach { case v => if (v.assignedTo.isEmpty && !varSubst.contains(v) && v.tsc.isEmpty) {
println(s"2[v] $v ${coOccurrences.get(true -> v)} ${coOccurrences.get(false -> v)}")

coOccurrences.get(true -> v).iterator.flatMap(_.iterator).foreach {
Expand All @@ -807,6 +829,7 @@ trait TypeSimplifier { self: Typer =>

case w: TV if !(w is v) && !varSubst.contains(w) && !varSubst.contains(v) && !recVars(v)
&& coOccurrences.get(false -> v).exists(_(w))
&& w.tsc.isEmpty
=>
// * Here we know that v is 'dominated' by w, so v can be inlined.
// * Note that we don't want to unify the two variables here
Expand All @@ -833,7 +856,7 @@ trait TypeSimplifier { self: Typer =>

// * Unify equivalent variables based on polar co-occurrence analysis:
allVars.foreach { case v =>
if (!v.assignedTo.isDefined && !varSubst.contains(v)) // TODO also handle v.assignedTo.isDefined?
if (!v.assignedTo.isDefined && !varSubst.contains(v) && v.tsc.isEmpty) // TODO also handle v.assignedTo.isDefined?
trace(s"3[v] $v +${coOccurrences.get(true -> v).mkString} -${coOccurrences.get(false -> v).mkString}") {

def go(pol: Bool): Unit = coOccurrences.get(pol -> v).iterator.flatMap(_.iterator).foreach {
Expand All @@ -850,6 +873,7 @@ trait TypeSimplifier { self: Typer =>
)
&& (v.level === w.level)
// ^ Don't merge variables of differing levels
&& w.tsc.isEmpty
=>
trace(s"[w] $w ${printPol(S(pol))}${coOccurrences.get(pol -> w).mkString}") {

Expand Down Expand Up @@ -923,6 +947,7 @@ trait TypeSimplifier { self: Typer =>
println(s"[rec] ${recVars}")

val renewals = MutMap.empty[TypeVariable, TypeVariable]
val renewaltsc = MutMap.empty[TupleSetConstraints, TupleSetConstraints]

val semp = Set.empty[TV]

Expand Down Expand Up @@ -999,7 +1024,7 @@ trait TypeSimplifier { self: Typer =>
nv
})
pol(tv) match {
case S(p) if inlineBounds && !occursInvariantly(tv) && !recVars.contains(tv) =>
case S(p) if inlineBounds && !occursInvariantly(tv) && !recVars.contains(tv) && tv.tsc.isEmpty =>
// * Inline the bounds of non-rec non-invar-occ type variables
println(s"Inlining [${printPol(p)}] bounds of $tv (~> $res)")
// if (p) mergeTransform(true, pol, tv, Set.single(tv), canDistribForall) | res
Expand All @@ -1017,6 +1042,15 @@ trait TypeSimplifier { self: Typer =>
res.lowerBounds = tv.lowerBounds.map(transform(_, pol.at(tv.level, true), Set.single(tv)))
if (occNums.contains(false -> tv))
res.upperBounds = tv.upperBounds.map(transform(_, pol.at(tv.level, false), Set.single(tv)))
if (noApproximateOverload)
res.tsc ++= tv.tsc.map { case (tsc, i) => renewaltsc.get(tsc) match {
case S(tsc) => (tsc, i)
case N =>
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)
renewaltsc += tsc -> t
t.tvs = t.tvs.map(x => (x._1, transform(x._2, PolMap.neu, Set.empty)))
(t, i)
}}
}
res
}()
Expand Down
Loading

0 comments on commit 5841627

Please sign in to comment.