Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constraint solving for function overloading #203

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion shared/src/main/scala/mlscript/ConstraintSolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,9 @@ class ConstraintSolver extends NormalForms { self: Typer =>
rec(b.inner.ub, ar.inner.ub, false)
case (LhsRefined(S(b: ArrayBase), ts, r, _), _) => reportError()
case (LhsRefined(S(ov: Overload), ts, r, trs), _) =>
annoying(Nil, LhsRefined(S(ov.approximatePos), ts, r, trs), Nil, done_rs) // TODO remove approx. with ambiguous constraints
val t = TupleSetConstraints.mk(ov)
annoying(Nil, LhsRefined(S(t), ts, r, trs), Nil, done_rs)
// annoying(Nil, LhsRefined(S(ov.approximatePos), ts, r, trs), Nil, done_rs) // TODO remove approx. with ambiguous constraints
case (LhsRefined(S(Without(b, ns)), ts, r, _), RhsBases(pts, N | S(L(_)), _)) =>
rec(b, done_rs.toType(), true)
case (_, RhsBases(pts, S(L(Without(base, ns))), _)) =>
Expand Down Expand Up @@ -818,13 +820,23 @@ class ConstraintSolver extends NormalForms { self: Typer =>
val newBound = (cctx._1 ::: cctx._2.reverse).foldRight(rhs)((c, ty) =>
if (c.prov is noProv) ty else mkProxy(ty, c.prov))
lhs.upperBounds ::= newBound // update the bound
lhs.lbtsc.foreach {
case (tsc, i) =>
tsc.filterUB(i, rhs)
if (tsc.constraints.isEmpty) reportError()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should probably raise a proper error here, which should refer to the failing TSC.

}
lhs.lowerBounds.foreach(rec(_, rhs, true)) // propagate from the bound

case (lhs, rhs: TypeVariable) if lhs.level <= rhs.level =>
println(s"NEW $rhs LB (${lhs.level})")
val newBound = (cctx._1 ::: cctx._2.reverse).foldLeft(lhs)((ty, c) =>
if (c.prov is noProv) ty else mkProxy(ty, c.prov))
rhs.lowerBounds ::= newBound // update the bound
rhs.ubtsc.foreach {
case (tsc, i) =>
tsc.filterLB(i, lhs)
if (tsc.constraints.isEmpty) reportError()
}
rhs.upperBounds.foreach(rec(lhs, _, true)) // propagate from the bound


Expand Down
1 change: 0 additions & 1 deletion shared/src/main/scala/mlscript/TypeSimplifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ trait TypeSimplifier { self: Typer =>
.reduceOption(_ &- _).filterNot(_.isTop).toList
else Nil
}

auht marked this conversation as resolved.
Show resolved Hide resolved
nv

case ComposedType(true, l, r) =>
Expand Down
112 changes: 111 additions & 1 deletion shared/src/main/scala/mlscript/TyperDatatypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,9 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
require(value.forall(_.level <= level))
_assignedTo = value
}

var lbtsc: Opt[(TupleSetConstraints, Int)] = N
var ubtsc: Opt[(TupleSetConstraints, Int)] = N

// * Bounds should always be disregarded when `equatedTo` is defined, as they are then irrelevant:
def lowerBounds: List[SimpleType] = { require(assignedTo.isEmpty, this); _lowerBounds }
Expand Down Expand Up @@ -646,5 +649,112 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
lazy val underlying: SimpleType = tt.neg()
val prov = noProv
}


class TupleSetConstraints(val constraints: MutSet[Ls[ST]], var tvs: Ls[TV])(val prov: TypeProvenance) {
def filterUB(index: Int, ub: ST)(implicit raise: Raise, ctx: Ctx): Unit = {
def go(ub: ST): Unit = ub match {
case ub: TV =>
ub.upperBounds.foreach(go)
ub.lbtsc = S(this, index)
case _ =>
constraints.filterInPlace { constrs =>
val ty = constrs(index)
val dnf = DNF.mk(MaxLevel, Nil, ty & ub.neg(), true)
dnf.isBot || dnf.cs.forall(c => !(c.vars.isEmpty && c.nvars.isEmpty))
}
}
go(ub)
println(s"TSC filterUB: $tvs in $constraints")
if (constraints.sizeCompare(1) === 0) {
constraints.head.zip(tvs).foreach {
case (ty, tv) =>
tv.lbtsc = N
tv.ubtsc = N
constrain(tv, ty)(raise, prov, ctx)
constrain(ty, tv)(raise, prov, ctx)
}
}
}
def filterLB(index: Int, lb: ST)(implicit raise: Raise, ctx: Ctx): Unit = {
constraints.filterInPlace { constrs =>
val ty = constrs(index)
val dnf = DNF.mk(MaxLevel, Nil, lb & ty.neg(), true)
dnf.isBot || dnf.cs.forall(c => !(c.vars.isEmpty && c.nvars.isEmpty))
}
println(s"TSC filterLB: $tvs in $constraints")
if (constraints.sizeCompare(1) === 0) {
constraints.head.zip(tvs).foreach {
case (ty, tv) =>
tv.lbtsc = N
tv.ubtsc = N
constrain(tv, ty)(raise, prov, ctx)
constrain(ty, tv)(raise, prov, ctx)
}
}
}
}
object TupleSetConstraints {
def lcgField(first: FieldType, rest: Ls[FieldType])
(implicit prov: TypeProvenance, lvl: Level)
: (FieldType, Ls[TV], Ls[Ls[ST]]) = {
val (ub, tvs, constrs) = lcg(first.ub, rest.map(_.ub))
if (first.lb.isEmpty && rest.forall(_.lb.isEmpty)) {
(FieldType(N, ub)(prov), tvs, constrs)
} else {
val (lb, ltvs, lconstrs) = lcg(first.lb.getOrElse(BotType), rest.map(_.lb.getOrElse(BotType)))
(FieldType(S(lb), ub)(prov), tvs ++ ltvs, constrs ++ lconstrs)
}
}
def lcg(first: ST, rest: Ls[ST])
(implicit prov: TypeProvenance, lvl: Level)
: (ST, Ls[TV], Ls[Ls[ST]]) = first match {
case a: FunctionType if rest.forall(_.isInstanceOf[FunctionType]) =>
val (lhss, rhss) = rest.collect {
case FunctionType(lhs, rhs) => lhs -> rhs
}.unzip
val (lhs, ltvs, lconstrs) = lcg(a.lhs, lhss)
val (rhs, rtvs, rconstrs) = lcg(a.rhs, rhss)
(FunctionType(lhs, rhs)(prov), ltvs ++ rtvs, lconstrs ++ rconstrs)
case a: ArrayType if rest.forall(_.isInstanceOf[ArrayType]) =>
val inners = rest.collect { case b: ArrayType => b.inner }
val (t, tvs, constrs) = lcgField(a.inner, inners)
(ArrayType(t)(prov), tvs, constrs)
case a: TupleType if rest.forall { case b: TupleType => a.fields.sizeCompare(b.fields.size) === 0; case _ => false } =>
val fields = rest.collect { case TupleType(fields) => fields.map(_._2) }
val (fts, tvss, constrss) = a.fields.map(_._2).zip(fields.transpose).map { case (a, bs) => lcgField(a, bs) }.unzip3
(TupleType(fts.map(N -> _))(prov), tvss.flatten, constrss.flatten)
case a: TR if rest.forall { case b: TR => a.defn === b.defn && a.targs.sizeCompare(b.targs.size) === 0; case _ => false } =>
val targs = rest.collect { case b: TR => b.targs }
val (ts, tvss, constrss) = a.targs.zip(targs.transpose).map { case (a, bs) => lcg(a, bs) }.unzip3
(TypeRef(a.defn, ts)(prov), tvss.flatten, constrss.flatten)
case a: TV if rest.forall { case b: TV => a.compare(b) === 0; case _ => false } => (a, Nil, Nil)
case a if rest.forall(_ === a) => (a, Nil, Nil)
case _ =>
val tv = freshVar(prov, N)
(tv, List(tv), List(first :: rest))
}
def lcgFunction(first: FunctionType, rest: Ls[FunctionType])(implicit prov: TypeProvenance, lvl: Level)
: (FunctionType, Ls[TV], Ls[Ls[ST]]) = {
val (lhss, rhss) = rest.map {
case FunctionType(lhs, rhs) => lhs -> rhs
}.unzip
val (lhs, ltvs, lconstrs) = lcg(first.lhs, lhss)
val (rhs, rtvs, rconstrs) = lcg(first.rhs, rhss)
(FunctionType(lhs, rhs)(prov), ltvs ++ rtvs, lconstrs ++ rconstrs)
}
def mk(ov: Overload)(implicit lvl: Level): FunctionType = {
def unwrap(t: ST): ST = t.map(unwrap)
if (ov.alts.tail.isEmpty) ov.alts.head else {
val f = ov.mapAlts(unwrap)(unwrap)
val (t, tvs, constrs) = lcgFunction(f.alts.head, f.alts.tail)(ov.prov, lvl)
val tsc = new TupleSetConstraints(MutSet.empty ++ constrs.transpose, tvs)(ov.prov)
tvs.zipWithIndex.foreach { case (tv, i) =>
tv.lbtsc = S((tsc, i))
tv.ubtsc = S((tsc, i))
}
println(s"TSC mk: ${tsc.tvs} in ${tsc.constraints}")
t
}
}
}
}
19 changes: 15 additions & 4 deletions shared/src/main/scala/mlscript/TyperHelpers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -983,13 +983,24 @@ abstract class TyperHelpers { Typer: Typer =>
def getVars: SortedSet[TypeVariable] = getVarsImpl(includeBounds = true)

def showBounds: String =
getVars.iterator.filter(tv => tv.assignedTo.nonEmpty || (tv.upperBounds ++ tv.lowerBounds).nonEmpty).map {
getVars.iterator.filter(tv => tv.assignedTo.nonEmpty || (tv.upperBounds ++ tv.lowerBounds).nonEmpty || (tv.lbtsc.isDefined && tv.ubtsc.isEmpty)).map {
case tv @ AssignedVariable(ty) => "\n\t\t" + tv.toString + " := " + ty
case tv => ("\n\t\t" + tv.toString
+ (if (tv.lowerBounds.isEmpty) "" else " :> " + tv.lowerBounds.mkString(" | "))
+ (if (tv.upperBounds.isEmpty) "" else " <: " + tv.upperBounds.mkString(" & ")))
}.mkString

+ (if (tv.upperBounds.isEmpty) "" else " <: " + tv.upperBounds.mkString(" & "))
+ tv.lbtsc.fold(""){ case (tsc, i) => " :> " + tsc.tvs(i) } )
}.mkString + {
val visited: MutSet[TV] = MutSet.empty
getVars.iterator.filter(tv => tv.ubtsc.isDefined).map {
case tv if visited.contains(tv) => ""
case tv =>
visited ++= tv.lbtsc.fold(Nil: Ls[TV])(_._1.tvs)
tv.lbtsc.fold("") { case (tsc, _) => ("\n\t\t[ "
+ tsc.tvs.mkString(", ")
+ " ] in { " + tsc.constraints.mkString(", ") + " }")
}
}.mkString
}
}


Expand Down
2 changes: 1 addition & 1 deletion shared/src/test/diff/nu/ArrayProg.mls
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ module A {
//│ }

A.g(0)
//│ Int | Str
//│ Int
//│ res
//│ = 0

Expand Down
62 changes: 44 additions & 18 deletions shared/src/test/diff/nu/HeungTung.mls
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,16 @@ fun g = h
//│ fun g: (Bool | Int) -> (Int | false | true)

// * In one step
:e // TODO: argument of union type
fun g: (Int | Bool) -> (Int | Bool)
fun g = f
//│ ╔══[ERROR] Type mismatch in definition:
//│ ║ l.71: fun g = f
//│ ║ ^^^^^
//│ ╟── expression of type `Int | false | true` does not match type `?a`
//│ ╟── Note: constraint arises from function type:
//│ ║ l.50: fun f: (Int -> Int) & (Bool -> Bool)
//│ ╙── ^^^^^^^^^^^^^^
//│ fun g: Int -> Int & Bool -> Bool
//│ fun g: (Bool | Int) -> (Int | false | true)

Expand All @@ -88,9 +96,11 @@ fun j = i
fun j: (Int & Bool) -> (Int & Bool)
fun j = f
//│ ╔══[ERROR] Type mismatch in definition:
//│ ║ l.89: fun j = f
//│ ║ l.97: fun j = f
//│ ║ ^^^^^
//│ ╙── expression of type `Int` does not match type `nothing`
//│ ╟── type `?a` does not match type `nothing`
//│ ║ l.50: fun f: (Int -> Int) & (Bool -> Bool)
//│ ╙── ^^^^^^^^^^^^^^
//│ fun j: Int -> Int & Bool -> Bool
//│ fun j: nothing -> nothing

Expand All @@ -106,23 +116,30 @@ fun g = f
// * With match-type-based constraint solving, we could return Int here

f(0)
//│ Int | false | true
//│ Int
//│ res
//│ = 0

// f(0) : case 0 of { Int => Int; Bool => Bool } == Int


x => f(x)
//│ (Bool | Int) -> (Int | false | true)
//│ anything -> nothing
auht marked this conversation as resolved.
Show resolved Hide resolved
//│ res
//│ = [Function: res]

// : forall 'a: 'a -> case 'a of { Int => Int; Bool => Bool } where 'a <: Int | Bool


:e
f(if true then 0 else false)
//│ Int | false | true
//│ ╔══[ERROR] Type mismatch in application:
//│ ║ l.134: f(if true then 0 else false)
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//│ ╟── expression of type `0 | false` does not match type `?a`
//│ ╟── Note: constraint arises from function type:
//│ ║ l.50: fun f: (Int -> Int) & (Bool -> Bool)
//│ ╙── ^^^^^^^^^^^^^^
//│ error
//│ res
//│ = 0

Expand All @@ -132,12 +149,21 @@ f(if true then 0 else false)
:w
f(refined if true then 0 else false) // this one can be precise again!
//│ ╔══[WARNING] Paren-less applications should use the 'of' keyword
//│ ║ l.133: f(refined if true then 0 else false) // this one can be precise again!
//│ ║ l.150: f(refined if true then 0 else false) // this one can be precise again!
//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//│ ╔══[ERROR] identifier not found: refined
//│ ║ l.133: f(refined if true then 0 else false) // this one can be precise again!
//│ ║ l.150: f(refined if true then 0 else false) // this one can be precise again!
//│ ╙── ^^^^^^^
//│ Int | false | true
//│ ╔══[ERROR] Type mismatch in application:
//│ ║ l.150: f(refined if true then 0 else false) // this one can be precise again!
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//│ ╟── application of type `error` does not match type `?a`
//│ ║ l.150: f(refined if true then 0 else false) // this one can be precise again!
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//│ ╟── Note: constraint arises from function type:
//│ ║ l.50: fun f: (Int -> Int) & (Bool -> Bool)
//│ ╙── ^^^^^^^^^^^^^^
//│ error
//│ Code generation encountered an error:
//│ unresolved symbol refined

Expand Down Expand Up @@ -193,7 +219,7 @@ type T = List[Int]
:e // TODO application types
type Res = M(T)
//│ ╔══[ERROR] Wrong number of type arguments – expected 0, found 1
//│ ║ l.194: type Res = M(T)
//│ ║ l.220: type Res = M(T)
//│ ╙── ^^^^
//│ type Res = M

Expand All @@ -216,21 +242,21 @@ fun f: Int -> Int
fun f: Bool -> Bool
fun f = id
//│ ╔══[ERROR] A type signature for 'f' was already given
//│ ║ l.216: fun f: Bool -> Bool
//│ ║ l.242: fun f: Bool -> Bool
//│ ╙── ^^^^^^^^^^^^^^^^^^^
//│ fun f: forall 'a. 'a -> 'a
//│ fun f: Int -> Int

:e // TODO support
f: (Int -> Int) & (Bool -> Bool)
//│ ╔══[ERROR] Type mismatch in type ascription:
//│ ║ l.225: f: (Int -> Int) & (Bool -> Bool)
//│ ║ l.251: f: (Int -> Int) & (Bool -> Bool)
//│ ║ ^
//│ ╟── type `Bool` is not an instance of `Int`
//│ ║ l.225: f: (Int -> Int) & (Bool -> Bool)
//│ ║ l.251: f: (Int -> Int) & (Bool -> Bool)
//│ ║ ^^^^
//│ ╟── Note: constraint arises from type reference:
//│ ║ l.215: fun f: Int -> Int
//│ ║ l.241: fun f: Int -> Int
//│ ╙── ^^^
//│ Int -> Int & Bool -> Bool
//│ res
Expand Down Expand Up @@ -297,14 +323,14 @@ fun test(x) = refined if x is
A then 0
B then 1
//│ ╔══[WARNING] Paren-less applications should use the 'of' keyword
//│ ║ l.296: fun test(x) = refined if x is
//│ ║ l.322: fun test(x) = refined if x is
//│ ║ ^^^^^^^^^^^^^^^
//│ ║ l.297: A then 0
//│ ║ l.323: A then 0
//│ ║ ^^^^^^^^^^
//│ ║ l.298: B then 1
//│ ║ l.324: B then 1
//│ ╙── ^^^^^^^^^^
//│ ╔══[ERROR] identifier not found: refined
//│ ║ l.296: fun test(x) = refined if x is
//│ ║ l.322: fun test(x) = refined if x is
//│ ╙── ^^^^^^^
//│ fun test: (A | B) -> error
//│ Code generation encountered an error:
Expand Down
Loading