Skip to content

Commit

Permalink
fix provtype
Browse files Browse the repository at this point in the history
  • Loading branch information
auht committed Aug 27, 2024
1 parent 09b5082 commit e1b7fba
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 79 deletions.
22 changes: 7 additions & 15 deletions shared/src/main/scala/mlscript/ConstraintSolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -632,8 +632,6 @@ class ConstraintSolver extends NormalForms { self: Typer =>
case S(tsc) => if (!tsc.tvs.isEmpty && tsc.constraints.isEmpty) reportError()
case N => reportError()
}
// val t = TupleSetConstraints.mk(ov)
// annoying(Nil, LhsRefined(S(t), ts, r, trs), Nil, done_rs)
case (LhsRefined(S(ov: Overload), ts, r, trs), _) =>
annoying(Nil, LhsRefined(S(ov.approximatePos), ts, r, trs), Nil, done_rs) // TODO remove approx. with ambiguous constraints
case (LhsRefined(S(Without(b, ns)), ts, r, _), RhsBases(pts, N | S(L(_)), _)) =>
Expand Down Expand Up @@ -849,9 +847,8 @@ class ConstraintSolver extends NormalForms { self: Typer =>
}
val u = lhs.tsc.filter(_._1.constraints.sizeCompare(1) === 0)
u.foreachEntry { case (k, _) =>
k.tvs.mapValues(_.unwrapProxies).foreach { // TODO less inefficient; remove useless case
k.tvs.mapValues(_.unwrapProxies).foreach {
case (_,tv: TV) => tv.tsc.remove(k)
case (_,ProvType(tv: TV)) => tv.tsc.remove(k)
case _ => ()
}
}
Expand All @@ -877,9 +874,8 @@ class ConstraintSolver extends NormalForms { self: Typer =>
}
val u = rhs.tsc.filter(_._1.constraints.sizeCompare(1) === 0)
u.foreachEntry { case (k, _) =>
k.tvs.mapValues(_.unwrapProxies).foreach { // TODO less inefficient; remove useless case
k.tvs.mapValues(_.unwrapProxies).foreach {
case (_,tv: TV) => tv.tsc.remove(k)
case (_,ProvType(tv: TV)) => tv.tsc.remove(k)
case _ => ()
}
}
Expand Down Expand Up @@ -1612,24 +1608,20 @@ class ConstraintSolver extends NormalForms { self: Typer =>
lvl
})
val freshentsc = tv.tsc.flatMap { case (tsc,_) =>
if (tsc.tvs.forall {
case (_,tv: TV) => !freshened.contains(tv)
case (_,ProvType(tv: TV)) => !freshened.contains(tv)
if (tsc.tvs.map(_._2.unwrapProxies).forall {
case tv: TV => !freshened.contains(tv)
case _ => true
}) S(tsc) else N
}
freshened += tv -> v
v.lowerBounds = tv.lowerBounds.mapConserve(freshen)
v.upperBounds = tv.upperBounds.mapConserve(freshen)
freshentsc.foreach { tsc =>
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)(tsc.prov)
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)
t.constraints = t.constraints.map(_.map(freshen))
t.tvs = t.tvs.map(x => (x._1,freshen(x._2)))
t.tvs.zipWithIndex.foreach {
case ((pol, tv: TV), i) =>
tv.tsc.updateWith(t)(_.map(_ + i).orElse(S(Set(i))))
case ((pol, ProvType(tv: TV)), i) =>
tv.tsc.updateWith(t)(_.map(_ + i).orElse(S(Set(i))))
t.tvs.map(_._2.unwrapProxies).zipWithIndex.foreach {
case (tv: TV, i) => tv.tsc.updateWith(t)(_.map(_ + i).orElse(S(Set(i))))
case _ => ()
}
}
Expand Down
4 changes: 2 additions & 2 deletions shared/src/main/scala/mlscript/TypeSimplifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ trait TypeSimplifier { self: Typer =>
case S(tsc) => (tsc, i)
case N if inPlace => (tsc, i)
case N =>
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)(tsc.prov)
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)
renewedtsc += tsc -> t
t.tvs = t.tvs.map(x => (x._1, process(x._2, N)))
(t, i)
Expand Down Expand Up @@ -1041,7 +1041,7 @@ trait TypeSimplifier { self: Typer =>
res.tsc ++= tv.tsc.map { case (tsc, i) => renewaltsc.get(tsc) match {
case S(tsc) => (tsc, i)
case N =>
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)(tsc.prov)
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)
renewaltsc += tsc -> t
t.tvs = t.tvs.map(x => (x._1, transform(x._2, PolMap.neu, Set.empty)))
(t, i)
Expand Down
13 changes: 4 additions & 9 deletions shared/src/main/scala/mlscript/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -686,10 +686,9 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
}
tscs.foreach { case (typevars, constrs) =>
val tvs = typevars.map(x => (x._1, rec(x._2)))
val tsc = new TupleSetConstraints(constrs.map(_.map(rec)), tvs)(res.prov)
tvs.zipWithIndex.foreach {
case ((_, tv: TV), i) => tv.tsc.updateWith(tsc)(_.map(_ + i).orElse(S(Set(i))))
case ((_, ProvType(tv: TV)), i) => tv.tsc.updateWith(tsc)(_.map(_ + i).orElse(S(Set(i))))
val tsc = new TupleSetConstraints(constrs.map(_.map(rec)), tvs)
tvs.map(_._2.unwrapProxies).zipWithIndex.foreach {
case (tv: TV, i) => tv.tsc.updateWith(tsc)(_.map(_ + i).orElse(S(Set(i))))
case _ => ()
}
}
Expand Down Expand Up @@ -1992,11 +1991,7 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
tv.tsc.foreachEntry {
case (tsc, i) =>
if (seenTscs.add(tsc)) {
val tvs = tsc.tvs.map {
case (pol, tv: TV) => (pol, tv.asTypeVar)
case (pol, ProvType(tv: TV)) => (pol, tv.asTypeVar)
case (pol, t) => (pol, go(t))
}
val tvs = tsc.tvs.map(x => (x._1,go(x._2)))
val constrs = tsc.constraints.map(_.map(go))
tscs ::= tvs -> constrs
}
Expand Down
21 changes: 9 additions & 12 deletions shared/src/main/scala/mlscript/TyperDatatypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
val prov = noProv
}

class TupleSetConstraints(var constraints: Ls[Ls[ST]], var tvs: Ls[(Bool, ST)])(val prov: TypeProvenance) {
class TupleSetConstraints(var constraints: Ls[Ls[ST]], var tvs: Ls[(Bool, ST)]) {
def updateImpl(index: Int, bound: ST)(implicit raise: Raise, ctx: Ctx) : Unit = {
val u0 = constraints.flatMap { c =>
TupleSetConstraints.lcg(tvs(index)._1, bound, c(index)).map(tvs.zip(c)++_)
Expand All @@ -683,17 +683,15 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
(u,l.reduce((x,y) => ComposedType(!p,x,y)(noProv)))
}
}
tvs.foreach {
case (_, tv: TV) => tv.tsc += this -> Set.empty
case (_, ProvType(tv: TV)) => tv.tsc += this -> Set.empty
tvs.map(_._2.unwrapProxies).foreach {
case tv: TV => tv.tsc += this -> Set.empty
case _ => ()
}
if (!u.isEmpty) {
tvs = u.flatMap(_.keys).distinct
constraints = tvs.map(x => u.map(_.getOrElse(x,if (x._1) TopType else BotType))).transpose
tvs.zipWithIndex.foreach {
case ((pol, tv: TV), i) => tv.tsc.updateWith(this)(_.map(_ + i).orElse(S(Set(i))))
case ((pol, ProvType(tv: TV)), i) => tv.tsc.updateWith(this)(_.map(_ + i).orElse(S(Set(i))))
tvs.map(_._2.unwrapProxies).zipWithIndex.foreach {
case (tv: TV, i) => tv.tsc.updateWith(this)(_.map(_ + i).orElse(S(Set(i))))
case _ => ()
}
} else {
Expand Down Expand Up @@ -779,8 +777,8 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
if (u.isEmpty) { return N }
val tvs = u.flatMap(_.keys).distinct
val m = tvs.map(x => u.map(_.getOrElse(x,if (x._1) TopType else BotType)))
val tsc = new TupleSetConstraints(m.transpose, tvs)(ov.prov)
tvs.map(x => (x._1,x._2.unwrapProxies)).zipWithIndex.foreach {
val tsc = new TupleSetConstraints(m.transpose, tvs)
tvs.mapValues(_.unwrapProxies).zipWithIndex.foreach {
case ((true, tv: TV), i) =>
tv.tsc.updateWith(tsc)(_.map(_ + i).orElse(S(Set(i))))
tv.lowerBounds.foreach(tsc.updateImpl(i, _))
Expand All @@ -791,9 +789,8 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
}
println(s"TSC mk: ${tsc.tvs} in ${tsc.constraints}")
if (tsc.constraints.sizeCompare(1) === 0) {
tvs.foreach {
case (_, tv: TV) => tv.tsc.remove(tsc)
case (_, ProvType(tv: TV)) => tv.tsc.remove(tsc)
tvs.map(_._2.unwrapProxies).foreach {
case tv: TV => tv.tsc.remove(tsc)
case _ => ()
}
tsc.constraints.head.zip(tvs).foreach {
Expand Down
6 changes: 2 additions & 4 deletions shared/src/main/scala/mlscript/TyperHelpers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,7 @@ abstract class TyperHelpers { Typer: Typer =>
}
def children(includeBounds: Bool): List[SimpleType] = this match {
case tv @ AssignedVariable(ty) => if (includeBounds) ty :: Nil else Nil
case tv: TypeVariable => if (includeBounds) tv.lowerBounds ::: tv.upperBounds else Nil
case tv: TypeVariable => if (includeBounds) tv.lowerBounds ::: tv.upperBounds ++ tv.tsc.keys.flatMap(_.tvs.map(_._2)) else Nil
case FunctionType(l, r) => l :: r :: Nil
case Overload(as) => as
case ComposedType(_, l, r) => l :: r :: Nil
Expand Down Expand Up @@ -1014,9 +1014,7 @@ abstract class TyperHelpers { Typer: Typer =>
val couldBeDistribbed = bod.varsBetween(polymLevel, MaxLevel)
println(s"could be distribbed: $couldBeDistribbed")
if (couldBeDistribbed.isEmpty) return N
val cannotBeDistribbed = par.varsBetween(polymLevel, MaxLevel).flatMap { v =>
v :: v.tsc.keys.flatMap(_.tvs.flatMap(_._2.getVars)).toList
}
val cannotBeDistribbed = par.varsBetween(polymLevel, MaxLevel)
println(s"cannot be distribbed: $cannotBeDistribbed")
val canBeDistribbed = couldBeDistribbed -- cannotBeDistribbed
if (canBeDistribbed.isEmpty) return N // TODO
Expand Down
23 changes: 11 additions & 12 deletions shared/src/test/diff/fcp/Overloads.mls
Original file line number Diff line number Diff line change
Expand Up @@ -94,23 +94,22 @@ IISS 0
def f = fun x -> (if true then IISS else BBNN) x
//│ f: 'a -> 'b
//│ where
//│ [+'a, -'b] in {[bool, bool], [number, number]}
//│ [+'a, -'b] in {[int, int], [string, string]}
//│ [+'a, -'b] in {[bool, bool], [number, number]}

f(0)
//│ res: number


// FIXME
f(0) + 1
//│ ╔══[ERROR] Type mismatch in operator application:
//│ ║ l.105: f(0) + 1
//│ ║ l.104: f(0) + 1
//│ ║ ^^^^^^
//│ ╟── type `number` is not an instance of type `int`
//│ ║ l.13: def BBNN: bool -> bool & number -> number
//│ ║ ^^^^^^
//│ ╟── but it flows into application with expected type `int`
//│ ║ l.105: f(0) + 1
//│ ║ l.104: f(0) + 1
//│ ╙── ^^^^
//│ res: error | int

Expand All @@ -120,10 +119,10 @@ f : int -> number
:e
f : number -> int
//│ ╔══[ERROR] Type mismatch in type ascription:
//│ ║ l.121: f : number -> int
//│ ║ l.120: f : number -> int
//│ ║ ^
//│ ╟── type `number` does not match type `?a`
//│ ║ l.121: f : number -> int
//│ ║ l.120: f : number -> int
//│ ╙── ^^^^^^
//│ res: number -> int

Expand All @@ -140,16 +139,16 @@ if true then IISS else BBNN
:e
(if true then IISS else BBNN) : (0 | 1 | true) -> number
//│ ╔══[ERROR] Type mismatch in type ascription:
//│ ║ l.141: (if true then IISS else BBNN) : (0 | 1 | true) -> number
//│ ║ l.140: (if true then IISS else BBNN) : (0 | 1 | true) -> number
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//│ ╟── type `int -> int & string -> string` is not a function
//│ ║ l.12: def IISS: int -> int & string -> string
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//│ ╟── but it flows into reference with expected type `(0 | 1 | true) -> number`
//│ ║ l.141: (if true then IISS else BBNN) : (0 | 1 | true) -> number
//│ ║ l.140: (if true then IISS else BBNN) : (0 | 1 | true) -> number
//│ ║ ^^^^
//│ ╟── Note: constraint arises from function type:
//│ ║ l.141: (if true then IISS else BBNN) : (0 | 1 | true) -> number
//│ ║ l.140: (if true then IISS else BBNN) : (0 | 1 | true) -> number
//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^
//│ res: (0 | 1 | true) -> number

Expand All @@ -168,13 +167,13 @@ not test
//│ <: test:
//│ ~(int -> int)
//│ ╔══[ERROR] Type mismatch in application:
//│ ║ l.166: not test
//│ ║ l.165: not test
//│ ║ ^^^^^^^^
//│ ╟── type `~(int -> int)` is not an instance of type `bool`
//│ ║ l.160: def test: ~(int -> int)
//│ ║ l.159: def test: ~(int -> int)
//│ ║ ^^^^^^^^^^^^^
//│ ╟── but it flows into reference with expected type `bool`
//│ ║ l.166: not test
//│ ║ l.165: not test
//│ ╙── ^^^^
//│ res: bool | error

Expand Down
Loading

0 comments on commit e1b7fba

Please sign in to comment.