(* tychk_nbe.ml *)
(* Source: https://gist.github.com/mb64/f49ccb1bbf2349c8026d8ccf29bd158e#file-tychk_nbe-ml courtesy of MBones,
with some modifications *)
(* Build with: ocamlfind ocamlc -package angstrom,stdio -linkpkg tychk_nbe.ml -o tychk *)
(* Surface syntax tree: types and expressions that still refer to every
   variable by name. The [Infer] module below consumes these values. *)
module AST = struct
(* Types as written in annotations. *)
type ty =
(* A type referred to by name; [Infer.ast_ty_to_ty] resolves it against the
   [TForall] binders that enclose it. *)
| TNamed of string
(* Function type: argument type, then result type. *)
| TFun of ty * ty
(* Universal quantifier: bound type-variable name, then the body. *)
| TForall of string * ty
(* Expressions. *)
type exp =
(* Reference to a term variable by name. *)
| Var of string
(* Application: function, then argument. *)
| App of exp * exp
(* Expression carrying an explicit type annotation. *)
| Annot of exp * ty
(* Lambda abstraction: parameter name, then body. *)
| Lam of string * exp
(* Let binding: bound name, bound expression, then the body that uses it. *)
| Let of string * exp * exp
end
(** [elem_index a xs] is [Some i] where [i] is the zero-based position of the
    first element of [xs] that is structurally equal to [a], or [None] when
    no element matches. (The stdlib has no such function in the versions this
    file targets, hence the local helper.) *)
let elem_index a =
  let rec scan pos remaining =
    match remaining with
    | [] -> None
    | hd :: _ when hd = a -> Some pos
    | _ :: tl -> scan (pos + 1) tl
  in
  scan 0
(** Type inference and checking for [AST] expressions.

    Syntactic types ([ty]) are evaluated into semantic types ([vty]) in which
    quantifier bodies are OCaml closures and unknown types are mutable holes
    that unification fills in. *)
module Infer = struct
  (** Position of a bound type variable counted from the innermost enclosing
      binder (0 = innermost). Used by syntactic types. *)
  type idx = int

  (** Position of a bound type variable counted from the outermost binder
      (0 = the first binder entered). Used by evaluated types. *)
  type lvl = int

  (** Syntactic types: names already resolved to indices. *)
  type ty =
    | TVar of idx
    | TFun of ty * ty
    | TForall of string * ty

  (** Evaluated types. A [VForall] body is a closure awaiting the value of the
      bound variable; a [VHole] is a unification variable. *)
  and vty =
    | VVar of lvl
    | VFun of vty * vty
    | VForall of string * (vty -> vty)
    | VHole of hole ref

  (** State of a unification variable. For an [Empty] hole, [scope] is the
      context level the hole belongs to: a solution may only mention type
      variables whose level is below it (see [unify_hole_prechecks]). *)
  and hole =
    | Empty of { scope: lvl }
    | Filled of vty

  (** Typing context.
      - [type_names]: names of the bound type variables, innermost first;
      - [lvl]: how many type variables have been bound so far;
      - [env]: term variables with their types, innermost first. *)
  type ctx = { type_names: string list; lvl: lvl; env: (string * vty) list }

  (** The empty context: nothing bound, level 0. *)
  let initial_ctx: ctx = { type_names = []; lvl = 0; env = [] }

  (** Raised for every user-facing typing error; carries the message. *)
  exception TypeError of string

  (** [add_ty_to_ctx name ctx] binds one more type variable called [name],
      which increases the context level by one. *)
  let add_ty_to_ctx (name: string) (ctx: ctx): ctx =
    { ctx with type_names = name :: ctx.type_names; lvl = ctx.lvl + 1 }

  (** [add_var_to_ctx name ty ctx] binds the term variable [name] at type
      [ty], shadowing any earlier binding of the same name. *)
  let add_var_to_ctx (name: string) (ty: vty) (ctx: ctx): ctx =
    { ctx with env = (name, ty) :: ctx.env }

  (** [lookup_var name ctx] is the type of the innermost binding of [name].
      @raise TypeError if [name] is not bound in [ctx]. *)
  let lookup_var (name: string) (ctx: ctx) =
    match List.assoc_opt name ctx.env with
    | Some ty -> ty
    | None -> raise (TypeError ("variable " ^ name ^ " not in scope"))

  (** [ast_ty_to_ty t] resolves every [AST.TNamed] in [t] to the index of the
      [AST.TForall] that binds it. Resolution starts from an empty scope, so
      the annotation has to be closed.
      @raise TypeError if a name is not bound by an enclosing quantifier. *)
  let ast_ty_to_ty (ast_ty: AST.ty) =
    (* [bound] lists the quantified names in scope, innermost first, so a
       name's position in it is exactly its index. *)
    let rec helper (bound: string list) (ast_ty: AST.ty) =
      match ast_ty with
      | AST.TNamed n ->
        begin match elem_index n bound with
          | Some idx -> TVar idx
          | None ->
            raise (TypeError ("type variable " ^ n ^ " not in scope"))
        end
      | AST.TFun (a, b) -> TFun (helper bound a, helper bound b)
      | AST.TForall (n, a) -> TForall (n, helper (n :: bound) a)
    in
    helper [] ast_ty

  (** [eval env t] evaluates the syntactic type [t]; [env] supplies the value
      of each index, innermost first. A quantifier becomes a closure that
      extends [env] with whatever value it is later applied to. *)
  let rec eval (env: vty list) = function
    | TVar idx -> List.nth env idx
    | TFun (a, b) -> VFun (eval env a, eval env b)
    | TForall (name, ty) -> VForall (name, fun x -> eval (x :: env) ty)

  (** [deref t] follows a chain of filled holes starting at [t]. The result is
      never a [VHole] whose content is [Filled]: it is either a non-hole type
      or an empty hole. Non-hole inputs are returned unchanged. *)
  let deref = function
    | VHole hole ->
      let rec helper h = match !h with
        | Filled (VHole h') ->
          (* path compression: point [h] straight at the end of the chain *)
          let a = helper h' in
          h := Filled a;
          a
        | Filled a -> a
        | Empty _ -> VHole h
      in
      helper hole
    | a -> a

  (** [print_ty ctx ty] renders [ty] as text, naming type variables after
      [ctx.type_names]. Quantified names that clash with a name already in
      scope get primes appended; unsolved holes print as [?[at lvl N]]. *)
  let print_ty (ctx: ctx) ty =
    let parens p s = if p then "(" ^ s ^ ")" else s in
    (* [p] says whether a function or quantified type must be parenthesised
       at this position. *)
    let rec helper ctx p t = match deref t with
      | VVar lvl ->
        (* [type_names] is innermost-first while levels count from the
           outside, hence the reversal. *)
        List.nth ctx.type_names (ctx.lvl - lvl - 1)
      | VFun (a, b) ->
        parens p (helper ctx true a ^ " -> " ^ helper ctx false b)
      | VForall (n, a) ->
        let rec freshen_name n =
          if List.mem n ctx.type_names then freshen_name (n ^ "'") else n in
        let n' = freshen_name n in
        let pr_a = helper (add_ty_to_ctx n' ctx) false (a (VVar ctx.lvl)) in
        parens p ("forall " ^ n' ^ ". " ^ pr_a)
      | VHole { contents = Empty { scope = lvl } } ->
        Printf.sprintf "?[at lvl %d]" lvl
      | VHole { contents = Filled _ } ->
        invalid_arg "this should've been handled by deref"
    in
    helper ctx false ty

  (** [unify_hole_prechecks ctx hole scope ty] validates [ty] as a solution
      for the empty hole [hole], whose scope is [scope]:
      - occurs check: [hole] itself must not appear inside [ty], otherwise
        the solution would be an infinite type;
      - scope check: [ty] must not mention a type variable bound at or above
        [scope] (variables bound by quantifiers inside [ty] itself are fine).

      As a side effect, every other empty hole found in [ty] has its scope
      lowered to [scope] when it was larger.
      @raise TypeError if either check fails. *)
  let unify_hole_prechecks (ctx: ctx) (hole: hole ref) (scope: lvl) ty =
    (* Levels from [initial_lvl] upwards belong to quantifiers entered during
       this traversal; they are local to [ty] and cannot escape. *)
    let initial_lvl = ctx.lvl in
    let rec helper ctx t = match deref t with
      | VVar lvl ->
        if lvl >= scope && lvl < initial_lvl
        then raise (TypeError ("type variable " ^ print_ty ctx (VVar lvl) ^ " escaping its scope"))
      | VFun (a, b) -> helper ctx a; helper ctx b
      | VForall (n, a) ->
        helper (add_ty_to_ctx n ctx) (a (VVar ctx.lvl))
      | VHole ({ contents = Empty { scope = l } } as h) ->
        (* Physical equality on purpose: structural [=] compares the cells'
           contents, so any two empty holes with the same scope would be
           mistaken for the same hole. *)
        if h == hole
        then raise (TypeError "occurs check: can't make infinite type")
        else if l > scope then h := Empty { scope }
      | VHole { contents = Filled _ } ->
        (* unreachable: [deref] never returns a filled hole *)
        invalid_arg "unify_hole_prechecks"
    in
    helper ctx ty

  (** [unify ctx a b] makes [a] and [b] equal by filling holes.
      @raise TypeError if the two types cannot be made equal. *)
  let rec unify (ctx: ctx) a b = match deref a, deref b with
    | VHole hole_a, b' -> unify_hole_ty ctx hole_a b'
    | a', VHole hole_b -> unify_hole_ty ctx hole_b a'
    | VVar lvl_a, VVar lvl_b when lvl_a = lvl_b -> ()
    | VFun (a1, a2), VFun (b1, b2) -> unify ctx a1 b1; unify ctx a2 b2
    | VForall (n, a_fun), VForall (_, b_fun) ->
      (* Open both bodies with the same fresh variable and compare them. *)
      let new_ctx = add_ty_to_ctx n ctx in
      unify new_ctx (a_fun (VVar ctx.lvl)) (b_fun (VVar ctx.lvl))
    | _ ->
      let a', b' = print_ty ctx a, print_ty ctx b in
      raise (TypeError ("mismatch between " ^ a' ^ " and " ^ b'))

  (** [unify_hole_ty ctx hole ty] solves the empty hole [hole] with [ty],
      after running [unify_hole_prechecks]. Unifying a hole with itself is a
      no-op.
      @raise TypeError if the prechecks reject [ty].
      @raise Invalid_argument if [hole] is already filled. *)
  and unify_hole_ty (ctx: ctx) (hole: hole ref) ty =
    match !hole with
    | Empty { scope } ->
      (* Dereference first so that a chain of solved holes ending in [hole]
         is recognised as [hole] itself, and compare by identity: two
         distinct empty holes are structurally equal whenever their scopes
         coincide, yet they still have to be linked. *)
      begin match deref ty with
        | VHole h when h == hole -> ()
        | solution ->
          unify_hole_prechecks ctx hole scope solution;
          hole := Filled solution
      end
    | Filled _ -> invalid_arg "unify_hole_ty"

  (** [eagerly_instantiate ctx ty] strips every outermost quantifier of [ty],
      replacing each bound variable with a fresh hole scoped at [ctx.lvl].
      The type is dereferenced first, so a hole that has already been solved
      to a quantified type is instantiated as well. *)
  let rec eagerly_instantiate (ctx: ctx) ty =
    match deref ty with
    | VForall (_, body) ->
      let new_hole = ref (Empty { scope = ctx.lvl }) in
      eagerly_instantiate ctx (body (VHole new_hole))
    | other -> other

  (* The mutually-recursive typechecking functions *)

  (** [check ctx term ty] verifies that [term] has type [ty].
      @raise TypeError if it does not. *)
  let rec check (ctx: ctx) (term: AST.exp) (ty: vty) =
    match term, deref ty with
    | _, VForall (n, a) ->
      (* Checking against a quantified type: bind the variable rigidly and
         check against the opened body. *)
      check (add_ty_to_ctx n ctx) term (a (VVar ctx.lvl))
    | AST.Lam (var, body), VFun (a, b) ->
      check (add_var_to_ctx var a ctx) body b
    | AST.Let (var, exp, body), a ->
      let exp_ty = infer ctx exp in
      check (add_var_to_ctx var exp_ty ctx) body a
    | _, a ->
      (* No checking rule applies: infer a type and unify it with the
         expected one. *)
      let inferred_ty = infer_and_inst ctx term in
      unify ctx inferred_ty a

  (** [infer ctx term] synthesises a type for [term]. The result may still
      carry outer quantifiers; see [infer_and_inst].
      @raise TypeError if [term] is ill-typed. *)
  and infer (ctx: ctx) (term: AST.exp) = match term with
    | AST.Var var -> lookup_var var ctx
    | AST.Annot (e, ast_ty) ->
      let ty = eval [] (ast_ty_to_ty ast_ty) in
      check ctx e ty;
      ty
    | AST.App (f, arg) ->
      let f_ty = infer_and_inst ctx f in
      begin match deref f_ty with
        | VFun (a, b) -> check ctx arg a; b
        | VHole ({ contents = Empty { scope } } as hole) ->
          (* Unknown function type: refine the hole to [?a -> ?b], with both
             new holes inheriting its scope. *)
          let a = VHole (ref (Empty { scope })) in
          let b = VHole (ref (Empty { scope })) in
          hole := Filled (VFun (a, b));
          check ctx arg a;
          b
        | VVar _ | VForall _ | VHole { contents = Filled _ } ->
          raise (TypeError "not a function type")
      end
    | AST.Lam (var, body) ->
      let arg_ty = VHole (ref (Empty { scope = ctx.lvl })) in
      let res_ty = infer_and_inst (add_var_to_ctx var arg_ty ctx) body in
      VFun (arg_ty, res_ty)
    | AST.Let (var, exp, body) ->
      let exp_ty = infer ctx exp in
      infer (add_var_to_ctx var exp_ty ctx) body

  (** [infer_and_inst ctx term] is [infer ctx term] with every outermost
      quantifier instantiated by a fresh hole. *)
  and infer_and_inst (ctx: ctx) (term: AST.exp) =
    let ty = infer ctx term in
    eagerly_instantiate ctx ty
end
(* module Parser = struct
open AST
open Angstrom (* parser combinators library *)
let keywords = ["forall"; "let"; "in"; "fun"]
let whitespace = take_while (String.contains " \n\t")
let lexeme a = a <* whitespace
let ident = lexeme (
let is_ident_char c =
c = '_' || ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') in
let* i = take_while is_ident_char in
if String.length i > 0 then return i else fail "expected ident")
let str s = lexeme (string s) *> return ()
let name =
let* i = ident in
if List.mem i keywords then fail (i ^ " is a keyword") else return i
let keyword k =
let* i = ident in
if i = k then return () else fail ("expected " ^ k)
let parens p = str "(" *> p <* str ")"
let ty = fix (fun ty ->
let simple_ty = parens ty <|> lift (fun n -> TNamed n) name in
let forall_ty =
let+ () = keyword "forall"
and+ names = many1 name
and+ () = str "."
and+ a = ty in
List.fold_right (fun n a -> TForall(n, a)) names a in
let fun_ty =
let+ arg_ty = simple_ty
and+ () = str "->"
and+ res_ty = ty in
TFun(arg_ty, res_ty) in
forall_ty <|> fun_ty <|> simple_ty <?> "type")
let exp = fix (fun exp ->
let atomic_exp = parens exp <|> lift (fun n -> Var n) name in
let make_app (f::args) =
List.fold_left (fun f arg -> App(f,arg)) f args in
let simple_exp = lift make_app (many1 atomic_exp) in
let annot_exp =
let+ e = simple_exp
and+ annot = option (fun e -> e)
(lift (fun t e -> Annot(e,t)) (str ":" *> ty)) in
annot e in
let let_exp =
let+ () = keyword "let"
and+ n = name
and+ () = str "="
and+ e = exp
and+ () = keyword "in"
and+ body = exp in
Let(n, e, body) in
let fun_exp =
let+ () = keyword "fun"
and+ args = many1 name
and+ () = str "->"
and+ body = exp in
List.fold_right (fun arg body -> Lam(arg, body)) args body in
let_exp <|> fun_exp <|> annot_exp <?> "expression")
let parse (s: string) =
match parse_string ~consume:All (whitespace *> exp) s with
| Ok e -> e
| Error msg -> failwith msg
end
let main () =
let stdin = Stdio.In_channel.(input_all stdin) in
let exp = Parser.parse stdin in
let () = print_endline "parsed" in
let open Infer in
let ctx = initial_ctx in
let ty = infer ctx exp in
print_endline ("input : " ^ print_ty ctx ty)
let () = main () *)