-
Notifications
You must be signed in to change notification settings - Fork 0
/
partial_derivation.ml
109 lines (100 loc) · 3.68 KB
/
partial_derivation.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
open Core
module T = struct
(* A partial derivation: a tree of expansions built by [find]. [Unseen] nodes
   carry an expanded expression and its argument derivations; [Seen] leaves
   stop the tree without expanding their type. *)
type unseen =
(* [terminal]: the expression chosen at this node ([find] stores the
   unification result's [expr] here). *)
{ terminal: Program.t
(* [nonterminals]: derivations of the arguments of [terminal], in parameter
   order. *)
; nonterminals: t list
(* [nonterminal]: the type this node derives (the requested type in [find]);
   [unify] rewrites it through a typing context. *)
; nonterminal: Type.t
(* [log_likelihood]: copied from the unification result in [find];
   presumably the log-probability of choosing [terminal] - TODO confirm. *)
; log_likelihood: float }
(* [Seen ty]: a leaf whose type is not expanded further; [find] emits it when
   the request unifies with one of its [completed_nts]. It contributes no
   productions in [to_productions]. *)
and t = Unseen of unseen | Seen of Type.t
(* [fields] generates [Fields_of_unseen], which [find] uses to build nodes. *)
[@@deriving equal, compare, sexp_of, fields]
end
include T
(* Comparator witness so [t] can be used as a Core Set/Map element. *)
include Comparator.Make (T)
(* [unify (pd, cxt)] resolves every type stored in the derivation [pd]
   against the typing context [cxt]. Each [Unseen] node's [nonterminal] and
   each [Seen] leaf's type is replaced by the type returned from
   [Type_context.apply], recursively through all child derivations. The
   context returned by [Type_context.apply] is discarded, so every node is
   resolved against the same [cxt]. All other fields are kept as they are. *)
let rec unify (partial_deriv, cxt) =
  let resolve ty = snd (Type_context.apply cxt ty) in
  match partial_deriv with
  | Seen ty ->
      Seen (resolve ty)
  | Unseen node ->
      let children =
        List.map node.nonterminals ~f:(fun child -> unify (child, cxt))
      in
      Unseen
        {node with nonterminal= resolve node.nonterminal; nonterminals= children}
(* [to_type pd] is the type derived at the root of [pd]: the [nonterminal]
   of an [Unseen] node, or the type carried by a [Seen] leaf. *)
let to_type partial_deriv =
  match partial_deriv with
  | Unseen {nonterminal= ty; _} | Seen ty -> ty
(* [to_productions pd] flattens the derivation [pd] into a list of
   (derived type, production) pairs in pre-order. The pair for a node comes
   first, followed by those of its children, left to right. Each production
   records the node's terminal, the root types of its child derivations, and
   its log-likelihood. [Seen] leaves contribute nothing. *)
let rec to_productions partial_deriv =
  match partial_deriv with
  | Seen _ ->
      []
  | Unseen node ->
      let production =
        Production.Fields.create ~terminal:node.terminal
          ~nonterminals:(List.map node.nonterminals ~f:to_type)
          ~log_likelihood:node.log_likelihood
      in
      let descendants = List.concat_map node.nonterminals ~f:to_productions in
      (node.nonterminal, production) :: descendants
module Transition = struct
(* An edge explored by [find]: (requesting expression, argument index,
   expression being expanded). [find] keeps a set of these and abandons a
   request whose incoming edge is already in that set, so an edge is not
   followed twice along one search path. *)
module T = struct
type t = Program.t * int * Program.t [@@deriving equal, compare, sexp_of]
end
include T
(* Comparator witness required by [Set.empty (module Transition)]. *)
include Comparator.Make (T)
end
(* [find dsl cxt req] enumerates partial derivations of the requested type
   [req] under the typing context [cxt]. It expands [req] with the
   expressions that [Dsl_unification.expressions] offers for it. Each result
   pairs a derivation with the typing context reached after all of its nodes
   have been unified.

   The search stops at a request when any of the following holds:
   - [seen_nts] is non-empty (always true on recursive calls) and [req]
     unifies with a type in [completed_nts]. The matches are returned as
     [Seen] leaves instead of being expanded.
   - [Type.size req] exceeds [type_size_limit] (default 100). The limit is
     forwarded to every recursive call, so it bounds the whole search.
   - The structural form of [req] is already in [seen_nts], i.e. the same
     type was requested further up the current path.
   - The incoming edge [trans] is already in [transitions].

   Candidate expressions are tried in increasing order of arity, so results
   built from lower-arity expressions come first. *)
let rec find ?(type_size_limit = 100)
    ?(transitions = Set.empty (module Transition))
    ?(seen_nts = Set.empty (module Structural_type)) ?(completed_nts = [])
    ?(trans : Transition.t option = None) dsl cxt req =
  let completed =
    if Set.is_empty seen_nts then []
    else
      List.filter_map completed_nts ~f:(fun ty ->
          try
            let cxt' = Type_unification.unify cxt req ty in
            let cxt'', req' = Type_context.apply cxt' req in
            Some (Seen req', cxt'')
          with Type_unification.UnificationFailure _ -> None )
  in
  if not (List.is_empty completed) then completed
  else if Type.size req > type_size_limit then []
  else
    (* One key for both the cycle check and the insertion below, so the
       membership test matches exactly what earlier calls added. *)
    let req_key = Structural_type.of_contextual_type cxt req in
    if
      Set.mem seen_nts req_key
      || Option.value_map trans ~default:false ~f:(Set.mem transitions)
    then []
    else
      let seen_nts' = Set.add seen_nts req_key in
      (* Marking the incoming edge as visited does not depend on the
         candidate expression, so do it once rather than per candidate. *)
      let transitions' =
        Option.value_map trans ~default:transitions ~f:(Set.add transitions)
      in
      Dsl_unification.expressions dsl [] req cxt
      |> List.sort ~compare:(fun u_1 u_2 ->
             Int.compare (List.length u_1.parameters)
               (List.length u_2.parameters) )
      |> List.concat_map ~f:(fun unified ->
             (* Edge handed to the recursive call for argument [i]:
                (expression that requested [req], [i], [unified.expr]). At
                the root there is no requester, so a placeholder primitive of
                type [req] stands in. *)
             let trans' i =
               match trans with
               | Some (_, _, parent) ->
                   Some (parent, i, unified.expr)
               | None ->
                   Some (Primitive {name= "UNK"; ty= req}, i, unified.expr)
             in
             (* Derive the parameters left to right, threading the typing
                context from one argument into the next. The argument
                derivations accumulate in reverse and are flipped below. *)
             List.fold unified.parameters
               ~init:[([], unified.context, 0)]
               ~f:(fun params_derivs param ->
                 List.concat_map params_derivs
                   ~f:(fun (params_deriv, cxt', position) ->
                     let cxt'', param' = Type_context.apply cxt' param in
                     find ~type_size_limit ~transitions:transitions'
                       ~seen_nts:seen_nts' ~completed_nts
                       ~trans:(trans' position) dsl cxt'' param'
                     |> List.map ~f:(fun (param_deriv, cxt''') ->
                            ( param_deriv :: params_deriv
                            , cxt'''
                            , position + 1 ) ) ) )
             |> List.map ~f:(fun (params_deriv_rev, cxt', _) ->
                    ( Unseen
                        (Fields_of_unseen.create ~terminal:unified.expr
                           ~nonterminals:(List.rev params_deriv_rev)
                           ~nonterminal:req
                           ~log_likelihood:unified.log_likelihood )
                    , cxt' ) ) )