Skip to content

Commit

Permalink
guard with flag
Browse files Browse the repository at this point in the history
  • Loading branch information
auht committed Oct 4, 2024
1 parent 4f21762 commit e740159
Show file tree
Hide file tree
Showing 5 changed files with 310 additions and 115 deletions.
64 changes: 34 additions & 30 deletions shared/src/main/scala/mlscript/ConstraintSolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -861,24 +861,26 @@ class ConstraintSolver extends NormalForms { self: Typer =>
val newBound = (cctx._1 ::: cctx._2.reverse).foldRight(rhs)((c, ty) =>
if (c.prov is noProv) ty else mkProxy(ty, c.prov))
lhs.upperBounds ::= newBound // update the bound
lhs.tsc.foreachEntry { (tsc, v) =>
v.foreach { i =>
if (!tsc.tvs(i)._1) {
tsc.updateOn(i, rhs)
if (tsc.constraints.isEmpty) reportError()
if (noApproximateOverload) {
lhs.tsc.foreachEntry { (tsc, v) =>
v.foreach { i =>
if (!tsc.tvs(i)._1) {
tsc.updateOn(i, rhs)
if (tsc.constraints.isEmpty) reportError()
}
}
}
}
val u = lhs.tsc.keysIterator.filter(_.constraints.sizeCompare(1)===0).duplicate
u._1.foreach { k =>
k.tvs.mapValuesIter(_.unwrapProxies).foreach {
case (_,tv: TV) => tv.tsc.remove(k)
case _ => ()
val u = lhs.tsc.keysIterator.filter(_.constraints.sizeCompare(1)===0).duplicate
u._1.foreach { k =>
k.tvs.mapValuesIter(_.unwrapProxies).foreach {
case (_,tv: TV) => tv.tsc.remove(k)
case _ => ()
}
}
}
u._2.foreach { k =>
k.constraints.head.iterator.zip(k.tvs).foreach {
case (c, (pol, t)) => if (pol) rec(t, c, false) else rec(c, t, false)
u._2.foreach { k =>
k.constraints.head.iterator.zip(k.tvs).foreach {
case (c, (pol, t)) => if (pol) rec(t, c, false) else rec(c, t, false)
}
}
}
lhs.lowerBounds.foreach(rec(_, rhs, true)) // propagate from the bound
Expand All @@ -888,24 +890,26 @@ class ConstraintSolver extends NormalForms { self: Typer =>
val newBound = (cctx._1 ::: cctx._2.reverse).foldLeft(lhs)((ty, c) =>
if (c.prov is noProv) ty else mkProxy(ty, c.prov))
rhs.lowerBounds ::= newBound // update the bound
rhs.tsc.foreachEntry { (tsc, v) =>
v.foreach { i =>
if(tsc.tvs(i)._1) {
tsc.updateOn(i, lhs)
if (tsc.constraints.isEmpty) reportError()
if (noApproximateOverload) {
rhs.tsc.foreachEntry { (tsc, v) =>
v.foreach { i =>
if(tsc.tvs(i)._1) {
tsc.updateOn(i, lhs)
if (tsc.constraints.isEmpty) reportError()
}
}
}
}
val u = rhs.tsc.keysIterator.filter(_.constraints.sizeCompare(1)===0).duplicate
u._1.foreach { k =>
k.tvs.mapValuesIter(_.unwrapProxies).foreach {
case (_,tv: TV) => tv.tsc.remove(k)
case _ => ()
val u = rhs.tsc.keysIterator.filter(_.constraints.sizeCompare(1)===0).duplicate
u._1.foreach { k =>
k.tvs.mapValuesIter(_.unwrapProxies).foreach {
case (_,tv: TV) => tv.tsc.remove(k)
case _ => ()
}
}
}
u._2.foreach { k =>
k.constraints.head.iterator.zip(k.tvs).foreach {
case (c, (pol, t)) => if (pol) rec(t, c, false) else rec(c, t, false)
u._2.foreach { k =>
k.constraints.head.iterator.zip(k.tvs).foreach {
case (c, (pol, t)) => if (pol) rec(t, c, false) else rec(c, t, false)
}
}
}
rhs.upperBounds.foreach(rec(lhs, _, true)) // propagate from the bound
Expand Down
47 changes: 26 additions & 21 deletions shared/src/main/scala/mlscript/TypeSimplifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,16 @@ trait TypeSimplifier { self: Typer =>
).map(process(_, S(false -> tv)))
.reduceOption(_ &- _).filterNot(_.isTop).toList
else Nil
nv.tsc ++= tv.tsc.map { case (tsc, i) => renewedtsc.get(tsc) match {
case S(tsc) => (tsc, i)
case N if inPlace => (tsc, i)
case N =>
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)
renewedtsc += tsc -> t
t.tvs = t.tvs.map(x => (x._1, process(x._2, N)))
(t, i)
}}
if (noApproximateOverload)
nv.tsc ++= tv.tsc.iterator.map { case (tsc, i) => renewedtsc.get(tsc) match {
case S(tsc) => (tsc, i)
case N if inPlace => (tsc, i)
case N =>
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)
renewedtsc += tsc -> t
t.tvs = t.tvs.map(x => (x._1, process(x._2, N)))
(t, i)
}}
}
nv

Expand Down Expand Up @@ -560,12 +561,14 @@ trait TypeSimplifier { self: Typer =>
if (pol(tv) =/= S(false))
analyzed1.setAndIfUnset(tv -> true) {
tv.lowerBounds.foreach(apply(pol.at(tv.level, true)))
tv.tsc.keys.flatMap(_.tvs).foreach(u => apply(pol.at(tv.level,u._1))(u._2))
if (noApproximateOverload)
tv.tsc.keys.flatMap(_.tvs).foreach(u => apply(pol.at(tv.level,u._1))(u._2))
}
if (pol(tv) =/= S(true))
analyzed1.setAndIfUnset(tv -> false) {
tv.upperBounds.foreach(apply(pol.at(tv.level, false)))
tv.tsc.keys.flatMap(_.tvs).foreach(u => apply(pol.at(tv.level,u._1))(u._2))
if (noApproximateOverload)
tv.tsc.keys.flatMap(_.tvs).foreach(u => apply(pol.at(tv.level,u._1))(u._2))
}
}
case _ =>
Expand Down Expand Up @@ -660,7 +663,8 @@ trait TypeSimplifier { self: Typer =>
case S(pol_tv) =>
if (analyzed2.add(pol_tv -> tv)) {
processImpl(st, pol, pol_tv)
tv.tsc.keys.flatMap(_.tvs).foreach(u => processImpl(u._2,pol.at(tv.level,u._1),pol_tv))
if (noApproximateOverload)
tv.tsc.keys.flatMap(_.tvs).foreach(u => processImpl(u._2,pol.at(tv.level,u._1),pol_tv))
}
case N =>
if (analyzed2.add(true -> tv))
Expand Down Expand Up @@ -707,7 +711,7 @@ trait TypeSimplifier { self: Typer =>
case S(p) =>
(if (p) tv2.lowerBounds else tv2.upperBounds).foreach(go)
// (if (p) getLbs(tv2) else getUbs(tv2)).foreach(go)
tv2.tsc.keys.flatMap(_.tvs).foreach(u => go(u._2))
if (noApproximateOverload) tv2.tsc.keys.flatMap(_.tvs).foreach(u => go(u._2))
case N =>
trace(s"Analyzing invar-occ of $tv2") {
analyze2(tv2, pol)
Expand Down Expand Up @@ -1038,14 +1042,15 @@ trait TypeSimplifier { self: Typer =>
res.lowerBounds = tv.lowerBounds.map(transform(_, pol.at(tv.level, true), Set.single(tv)))
if (occNums.contains(false -> tv))
res.upperBounds = tv.upperBounds.map(transform(_, pol.at(tv.level, false), Set.single(tv)))
res.tsc ++= tv.tsc.map { case (tsc, i) => renewaltsc.get(tsc) match {
case S(tsc) => (tsc, i)
case N =>
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)
renewaltsc += tsc -> t
t.tvs = t.tvs.map(x => (x._1, transform(x._2, PolMap.neu, Set.empty)))
(t, i)
}}
if (noApproximateOverload)
res.tsc ++= tv.tsc.map { case (tsc, i) => renewaltsc.get(tsc) match {
case S(tsc) => (tsc, i)
case N =>
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)
renewaltsc += tsc -> t
t.tvs = t.tvs.map(x => (x._1, transform(x._2, PolMap.neu, Set.empty)))
(t, i)
}}
}
res
}()
Expand Down
16 changes: 9 additions & 7 deletions shared/src/main/scala/mlscript/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,15 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
if (p) (b, tv) else (tv, b) }
}.toList, innerTy)

val ambiguous = innerTy.getVars.unsorted.flatMap(_.tsc.keys.flatMap(_.tvs))
.groupBy(_._2)
.filter { case (v,pvs) => pvs.sizeIs > 1 }
if (ambiguous.nonEmpty) raise(ErrorReport(
msg"ambiguous" -> N ::
ambiguous.map { case (v,_) => msg"cannot determine satisfiability of type ${v.expPos}" -> v.prov.loco }.toList
, true))
if (noApproximateOverload) {
val ambiguous = innerTy.getVars.unsorted.flatMap(_.tsc.keys.flatMap(_.tvs))
.groupBy(_._2)
.filter { case (v,pvs) => pvs.sizeIs > 1 }
if (ambiguous.nonEmpty) raise(ErrorReport(
msg"ambiguous" -> N ::
ambiguous.map { case (v,_) => msg"cannot determine satisfiability of type ${v.expPos}" -> v.prov.loco }.toList
, true))
}

println(s"Inferred poly constr: $cty —— where ${cty.showBounds}")

Expand Down
101 changes: 44 additions & 57 deletions shared/src/test/diff/fcp/Overloads.mls
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

:NoJS
:NoApproximateOverload


type IISS = int -> int & string -> string
type BBNN = bool -> bool & number -> number
Expand Down Expand Up @@ -31,37 +31,22 @@ IISS : ZZII
//│ ╔══[ERROR] Type mismatch in type ascription:
//│ ║ l.30: IISS : ZZII
//│ ║ ^^^^
//│ ╟── type `int -> int & string -> string` is not an instance of `0 -> 0`
//│ ╟── type `int` does not match type `0`
//│ ║ l.12: def IISS: int -> int & string -> string
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//│ ╟── but it flows into reference with expected type `0 -> 0`
//│ ║ l.30: IISS : ZZII
//│ ║ ^^^^
//│ ╟── Note: constraint arises from function type:
//│ ║ ^^^
//│ ╟── Note: constraint arises from literal type:
//│ ║ l.7: type ZZII = 0 -> 0 & int -> int
//│ ║ ^^^^^^
//│ ╟── from type reference:
//│ ║ l.30: IISS : ZZII
//│ ╙── ^^^^
//│ ╙── ^
//│ res: ZZII

:e
IISS : BBNN
//│ ╔══[ERROR] Type mismatch in type ascription:
//│ ║ l.49: IISS : BBNN
//│ ║ ^^^^
//│ ╟── type `int -> int & string -> string` is not an instance of `bool -> bool`
//│ ║ l.12: def IISS: int -> int & string -> string
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//│ ╟── but it flows into reference with expected type `bool -> bool`
//│ ║ l.49: IISS : BBNN
//│ ║ l.43: IISS : BBNN
//│ ║ ^^^^
//│ ╟── Note: constraint arises from function type:
//│ ╟── type `bool` does not match type `int | string`
//│ ║ l.6: type BBNN = bool -> bool & number -> number
//│ ║ ^^^^^^^^^^^^
//│ ╟── from type reference:
//│ ║ l.49: IISS : BBNN
//│ ╙── ^^^^
//│ ╙── ^^^^
//│ res: BBNN


Expand All @@ -74,18 +59,16 @@ IISS : (0 | 1) -> number
//│ res: (0 | 1) -> number

IISS : 'a -> 'a
//│ res: 'a -> 'a
//│ where
//│ [-'a, +'a] in {[int, int], [string, string]}
//│ res: ('a & (int | string)) -> (int | string | 'a)

IISS 0
//│ res: int
//│ res: int | string

(IISS : int -> int) 0
//│ res: int

(if true then IISS else BBNN) 0
//│ res: number
//│ res: bool | number | string

// * Note that this is not considered ambiguous
// * because the type variable occurrences are polar,
Expand All @@ -94,38 +77,48 @@ IISS 0
// * Conceptually, we'd expect this inferred type to reduce to `int -> number`,
// * but it's tricky to do such simplifications in general.
def f = fun x -> (if true then IISS else BBNN) x
//│ f: 'a -> 'b
//│ where
//│ [+'a, -'b] in {[int, int], [string, string]}
//│ [+'a, -'b] in {[bool, bool], [number, number]}
//│ f: int -> (bool | number | string)

f(0)
//│ res: number
//│ res: bool | number | string

// FIXME
:e
f(0) + 1
//│ ╔══[ERROR] Type mismatch in operator application:
//│ ║ l.106: f(0) + 1
//│ ║ ^^^^^^
//│ ╟── type `number` is not an instance of type `int`
//│ ║ l.86: f(0) + 1
//│ ║ ^^^^^^
//│ ╟── type `bool` is not an instance of type `int`
//│ ║ l.13: def BBNN: bool -> bool & number -> number
//│ ║ ^^^^^^
//│ ║ ^^^^
//│ ╟── but it flows into application with expected type `int`
//│ ║ l.106: f(0) + 1
//│ ╙── ^^^^
//│ ║ l.86: f(0) + 1
//│ ╙── ^^^^
//│ res: error | int

:e
f : int -> number
//│ ╔══[ERROR] Type mismatch in type ascription:
//│ ║ l.99: f : int -> number
//│ ║ ^
//│ ╟── type `bool` is not an instance of type `number`
//│ ║ l.13: def BBNN: bool -> bool & number -> number
//│ ║ ^^^^
//│ ╟── Note: constraint arises from type reference:
//│ ║ l.99: f : int -> number
//│ ╙── ^^^^^^
//│ res: int -> number

:e
f : number -> int
//│ ╔══[ERROR] Type mismatch in type ascription:
//│ ║ l.122: f : number -> int
//│ ║ l.112: f : number -> int
//│ ║ ^
//│ ╟── type `number` does not match type `?a`
//│ ║ l.122: f : number -> int
//│ ╙── ^^^^^^
//│ ╟── type `number` does not match type `int | string`
//│ ║ l.112: f : number -> int
//│ ║ ^^^^^^
//│ ╟── Note: constraint arises from reference:
//│ ║ l.79: def f = fun x -> (if true then IISS else BBNN) x
//│ ╙── ^
//│ res: number -> int


Expand All @@ -141,17 +134,11 @@ if true then IISS else BBNN
:e
(if true then IISS else BBNN) : (0 | 1 | true) -> number
//│ ╔══[ERROR] Type mismatch in type ascription:
//│ ║ l.142: (if true then IISS else BBNN) : (0 | 1 | true) -> number
//│ ║ l.135: (if true then IISS else BBNN) : (0 | 1 | true) -> number
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//│ ╟── type `int -> int & string -> string` is not an instance of `(0 | 1 | true) -> number`
//│ ║ l.12: def IISS: int -> int & string -> string
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//│ ╟── but it flows into reference with expected type `(0 | 1 | true) -> number`
//│ ║ l.142: (if true then IISS else BBNN) : (0 | 1 | true) -> number
//│ ║ ^^^^
//│ ╟── Note: constraint arises from function type:
//│ ║ l.142: (if true then IISS else BBNN) : (0 | 1 | true) -> number
//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^
//│ ╟── type `true` does not match type `int | string`
//│ ║ l.135: (if true then IISS else BBNN) : (0 | 1 | true) -> number
//│ ╙── ^^^^
//│ res: (0 | 1 | true) -> number


Expand All @@ -169,13 +156,13 @@ not test
//│ <: test:
//│ ~(int -> int)
//│ ╔══[ERROR] Type mismatch in application:
//│ ║ l.167: not test
//│ ║ l.154: not test
//│ ║ ^^^^^^^^
//│ ╟── type `~(int -> int)` is not an instance of type `bool`
//│ ║ l.161: def test: ~(int -> int)
//│ ║ l.148: def test: ~(int -> int)
//│ ║ ^^^^^^^^^^^^^
//│ ╟── but it flows into reference with expected type `bool`
//│ ║ l.167: not test
//│ ║ l.154: not test
//│ ╙── ^^^^
//│ res: bool | error

Expand Down
Loading

0 comments on commit e740159

Please sign in to comment.