-
Notifications
You must be signed in to change notification settings - Fork 0
/
linear.ml
104 lines (88 loc) · 2.9 KB
/
linear.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
(** Incremental linear algebra solver *)
(*
Copyright 2008, Nathan Mishra-Linger
License: BSD
*)
(* The implementation is inspired by incremental unification as
found in modern type inference implementations. The classic
reference is "Basic Polymorphic Typechecking" by Luca Cardelli.
The linear expression type `t' is a term representation of a
linear combination of distinct variables plus a constant.
a1*x1 + a2*x2 + ... + aN*xN + b
As in efficient unification algorithms, a variable is represented
by a mutable, optional reference to another term (the `value'
field of `Var.t'). For unsolved variables, this field is `None'.
For solved variables, it is `Some t', where t is the term to which
the variable has been set equal.
The "occurs check" of ordinary first order term unification is
not necessary because we can always isolate the variable in
question. (see the code for `equate', which might also be called
`unify')
*)
open Core.Std
include Int.Replace_polymorphic_compare
(* Generator of variable identifiers: [Var.create] draws a fresh id with
   [Uid.create], and variables are ordered and compared by this id. *)
module Uid = Unique_id.Int (struct end)
(* Variables and linear combinations are mutually recursive: a solved
   variable stores the combination it was set equal to, while
   combinations are built over variables.  [Comb] is produced by
   applying [Linear_comb.Make] to [Var]. *)
module rec Var : sig
type t = {
uid : Uid.t; (* identity; the only field used by [compare] *)
mutable value : Comb.t option; (* None: unsolved; Some t: set equal to t *)
}
include Linear_comb.Var with type t := t
end = struct
type t = {
uid : Uid.t;
mutable value : Comb.t option;
}
(* Order variables by id alone; [value] is ignored, so solving a
   variable does not change its position in the ordering. *)
let compare t1 t2 = Uid.compare t1.uid t2.uid
(* A fresh variable gets a new id and starts out unsolved. *)
let create () = {uid = Uid.create (); value = None}
end
and Comb : Linear_comb.S_concrete with type var = Var.t
= Linear_comb.Make (Var)
(* Re-export [Comb] at top level; the unqualified [Sum], [Prod], [plus],
   [minus], [times], [div], [negate] used below presumably come from here. *)
include Comb
(* [solved x] holds exactly when [x] has been set equal to some
   combination, i.e. its [value] field is [Some _]. *)
let solved x = Option.is_some x.Var.value
(* [subst c] rewrites [c] so that only unsolved variables remain: every
   solved variable is replaced by its own (recursively substituted)
   value, scaled by its coefficient, and the pieces are added back
   onto the unsolved part with [plus].

   Side effect: each solved variable encountered has its [value]
   overwritten with the fully substituted combination, so later calls
   do not have to walk the same chain again. *)
let rec subst (Sum (terms, b)) =
  let is_unsolved (Prod (_, x)) = not (solved x) in
  (* [Some (a * value(x))] for a solved [x], [None] for an unsolved one. *)
  let expand (Prod (a, x)) =
    match x.Var.value with
    | None -> None
    | Some bound ->
      let normalized = subst bound in
      x.Var.value <- Some normalized;
      Some (times a normalized)
  in
  let unsolved = List.filter terms ~f:is_unsolved in
  let expansions = List.filter_map terms ~f:expand in
  List.fold_left expansions ~init:(Sum (unsolved, b)) ~f:plus
(* Raised by [equate] when the new equation contradicts the ones already
   recorded: after substitution it reduces to b = 0 with b nonzero. *)
exception Inconsistent
(* Raised by [equate] when the new equation adds no information: after
   substitution it reduces to 0 = 0. *)
exception Redundant
(* Pivot selection used by [equate]: of two terms, keep the one whose
   coefficient has the larger magnitude, preferring the first on ties.
   Solving for the variable with the largest coefficient means dividing
   by the largest available number, which limits rounding error. *)
let best_coeff t1 t2 =
  let Prod (a1, _) = t1 and Prod (a2, _) = t2 in
  match Float.(abs a1 >= abs a2) with
  | true -> t1
  | false -> t2
(* [equate t1 t2] records the constraint t1 = t2.

   The constraint is rewritten as t = 0 with t = t1 - t2 and fully
   substituted, so that only unsolved variables remain.  The variable
   with the largest-magnitude coefficient (see [best_coeff]) is chosen
   as the pivot and solved for: from a*x + rest = 0 it records
   x = -rest / a in [x]'s [value] field.

   Terms whose coefficient is exactly 0.0 are discarded before pivoting.
   Such a term contributes nothing to the equation.  If it were chosen
   as the pivot, the division by [a] would silently store a non-finite
   solution.

   Raises [Redundant] if the equation reduces to 0 = 0 and
   [Inconsistent] if it reduces to b = 0 with b nonzero. *)
let equate t1 t2 =
  let Sum (ts, b) = subst (minus t1 t2) in
  let ts = List.filter ts ~f:(fun (Prod (a, _)) -> not (Float.equal a 0.0)) in
  match ts with
  | [] ->
    (* No variables left: the equation is either trivially true or false. *)
    raise (if Float.equal b 0.0 then Redundant else Inconsistent)
  | hd :: tl ->
    let Prod (a, x) = List.fold_left tl ~init:hd ~f:best_coeff in
    (* Everything except the pivot; no occurs check is needed because
       [x] has been removed from the right-hand side. *)
    let rest =
      List.filter ts ~f:(fun (Prod (_, y)) ->
        not (Uid.equal x.Var.uid y.Var.uid))
    in
    x.Var.value <- Some (div (negate (Sum (rest, b))) a)
(* [value t] is [Some c] when, after substituting every solved variable,
   [t] has no variables left and reduces to the constant [c]; it is
   [None] while at least one unsolved variable remains. *)
let value t =
  let Sum (unsolved, c) = subst t in
  if List.is_empty unsolved then Some c else None