From e4d5c4d286ecc3b32acda881fd86cfb81ce00a5c Mon Sep 17 00:00:00 2001 From: Owen Lynch Date: Tue, 2 Jul 2024 18:23:54 -0700 Subject: [PATCH] directories and symbolic resource sharers --- .editorconfig | 2 + .gitignore | 1 + Project.toml | 2 + docs/src/concepts/theory_composition.md | 51 ++ docs/src/examples/springs.md | 5 + docs/src/nonstdlib/resource_sharers.md | 9 + dynamics.md | 111 +++++ src/models/ModelInterface.jl | 12 +- src/models/SymbolicModels.jl | 14 +- src/nonstdlib/dynamics/ResourceSharers.jl | 487 +++++++++++++++++++ src/nonstdlib/module.jl | 2 + src/syntax/GATContexts.jl | 19 - src/syntax/GATs.jl | 3 + src/syntax/Scopes.jl | 3 +- src/syntax/TheoryInterface.jl | 41 +- src/syntax/TheoryMaps.jl | 2 +- src/syntax/gats/algorithms.jl | 39 +- src/syntax/gats/ast.jl | 153 +++++- src/syntax/gats/closures.jl | 241 ++++++++++ src/syntax/gats/exprinterop.jl | 38 +- src/syntax/gats/gat.jl | 2 + src/util/Dtrys.jl | 550 ++++++++++++++++++++++ src/util/MetaUtils.jl | 4 +- src/util/module.jl | 2 + test/Project.toml | 8 + test/nonstdlib/ResourceSharers.jl | 87 ++++ test/nonstdlib/tests.jl | 1 + test/stdlib/Arithmetic.jl | 3 +- test/syntax/GATs.jl | 15 +- test/util/Dtrys.jl | 123 +++++ test/util/MetaUtils.jl | 6 +- test/util/tests.jl | 4 + 32 files changed, 1956 insertions(+), 84 deletions(-) create mode 100644 .editorconfig create mode 100644 docs/src/concepts/theory_composition.md create mode 100644 docs/src/examples/springs.md create mode 100644 docs/src/nonstdlib/resource_sharers.md create mode 100644 dynamics.md create mode 100644 src/nonstdlib/dynamics/ResourceSharers.jl create mode 100644 src/syntax/gats/closures.jl create mode 100644 src/util/Dtrys.jl create mode 100644 test/nonstdlib/ResourceSharers.jl create mode 100644 test/util/Dtrys.jl diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000..32bad656 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,2 @@ +[*.jl] +indent_size = 2 diff --git a/.gitignore b/.gitignore index 5d6c467f..17044cba 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,4 @@ docs/site/ # committed for packages, but should be committed for applications that require a static # environment. Manifest.toml +/coverage/ diff --git a/Project.toml b/Project.toml index 403f9ca7..5bdf7dd5 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["AlgebraicJulia Developers"] version = "0.1.3" [deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" AlgebraicInterfaces = "23cfdc9f-0504-424a-be1f-4892b28e2f0c" Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" @@ -11,6 +12,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c" diff --git a/docs/src/concepts/theory_composition.md b/docs/src/concepts/theory_composition.md new file mode 100644 index 00000000..e0be15e0 --- /dev/null +++ b/docs/src/concepts/theory_composition.md @@ -0,0 +1,51 @@ +# Theory Composition + +As theories get larger, it becomes more and more important to not build the +entire theory from scratch. Not only is this tedious, it is also error-prone. +From the beginning, Catlab and GATlab have supported single inheritance, which +helps to some extent. In this document, we lay out other approaches to composing +theories. + +## Multiple Inheritance + +In a GATlab `@theory`, one can use `using` to take the *union* of one theory +with another theory. + +The way this works is the following. Every time a new theory is created, the new +definitions for that theory form a new scope, with a unique UUID. Union of +theories operates on a scope tag level, taking the union of the sets of UUIDs +and then producing a theory with all the bindings from the scopes tagged by +those UUIDs. + +If we never had to parse user-supplied expressions, then the names of the +operations in the theories wouldn't matter, because identifiers come with scope +tags. However, as a practical matter, we disallow unioning two theories with the +same name declaration. + +That being said, it is fine to union two theories which *overload* the same +declaration. That is, if two theories have the declaration of a name in common, +then they can overload that name as long as they don't give conflicting +overloads, in the same way that overloading methods in Julia works. + +This is akin to the way multiple inheritance works in frameworks such as + +- Haskell typeclasses +- [Object-oriented systems with multiple inheritance, like Scala](https://docs.scala-lang.org/scala3/book/domain-modeling-tools.html#traits) +- [Module inclusion in OCaml](https://cs3110.github.io/textbook/chapters/modules/includes.html) + +## Nesting + +However, there are other ways of composing things like GATlab theories. In +dependently typed languages used for theorem proving, algebraic structures are +often represented by dependent records. For instance, in the agda unimath +library, the [definition of a group](https://github.com/UniMath/agda-unimath/blob/master/src/group-theory/groups.lagda.md) is + +```agda +Semigroup : + (l : Level) → UU (lsuc l) +Semigroup l = Σ (Set l) has-associative-mul-Set + +Group : + (l : Level) → UU (lsuc l) +Group l = Σ (Semigroup l) is-group +``` diff --git a/docs/src/examples/springs.md b/docs/src/examples/springs.md new file mode 100644 index 00000000..f5d2d593 --- /dev/null +++ b/docs/src/examples/springs.md @@ -0,0 +1,5 @@ +# Composition of resource sharers + +```julia + +``` diff --git a/docs/src/nonstdlib/resource_sharers.md b/docs/src/nonstdlib/resource_sharers.md new file mode 100644 index 00000000..84ce3b3f --- /dev/null +++ b/docs/src/nonstdlib/resource_sharers.md @@ -0,0 +1,9 @@ +# Resource Sharers + +```@docs +GATlab.NonStdlib.ResourceSharers.Rhizome +GATlab.NonStdlib.ResourceSharers.ResourceSharer +GATlab.NonStdlib.ResourceSharers.Variable +GATlab.NonStdlib.ResourceSharers.PortVariable +GATlab.NonStdlib.ResourceSharers.ocompose +``` diff --git a/dynamics.md b/dynamics.md new file mode 100644 index 00000000..6930c2df --- /dev/null +++ b/dynamics.md @@ -0,0 +1,111 @@ +# The Road to Dynamical Systems + +## Basic steps + +- [x] Tuple types +- [-] Symbolic functions + Data type: + + ```julia + struct AlgebraicFunction + theory::GAT + args::TypeScope + ret::AlgType + body::AlgTerm + end + ``` + + Affordances: + - [x] DSL for writing down functions, composing, etc. + - [ ] A function `tcompose(t::Dtry{AlgebraicFunction})::AlgebraicFunction`, implementing the Dtry-algebra structure on morphisms + - [ ] Interpret/compile a symbolic function into a real function + - [ ] Serialize symbolic functions +- [ ] Compilation +- [ ] Serialization + +## Lens-based dynamical systems + +- [ ] Arenas + Sketch: + ```julia + struct Arena + in::AlgType + out::AlgType + end + ``` + + Affordances: + - A function `tcompose(arena::Dtry{Arena})::Arena`, implementing the Dtry-algebra structure on objects +- [ ] Multilenses + Sketch: + ```julia + struct MultiLens + inner_boxes::Dtry{Arena} + outer_box::Arena + # used for namespacing `params` in composition, must not overlap with `inner_boxes` + name::Symbol + params::AlgType + # (params, tcompose(inner_boxes[...].out)) -> outer_box.out + output::AlgebraicFunction + # (params, tcompose(inner_boxes[...].out), outer_box.in) -> tcompose(inner_boxes[...].in) + update::AlgebraicFunction + end + ``` + + Affordances: + - A function `ocompose(l::MultiLens, args::Dtry{MultiLens})::MultiLens` implementing the Dtry-multicategory structure +- [ ] Systems + Sketch: + ```julia + struct System + interface::Arena + state::AlgType + params::AlgType + # (params, state) -> interface.out + output::AlgebraicFunction + # (params, state, interface.in) -> state + input::AlgebraicFunction + end + ``` + + Affordances: + - A function `oapply(l::MultiLens, args::Dtry{System})::System` implementing the action of the Dtry-multicategory of multilenses on systems. + +## Resource sharers + +- [ ] Interfaces +- [ ] Rhizomes (epi-mono uwds) + ```julia + struct VariableType + type::AlgType + exposed::Bool + end + + struct Rhizome + boxes::Dtry{Interface} + junctions::Dtry{VariableType} + mapping::Dict{DtryVar, DtryVar} + end + ``` + + Affordances: + - `ocompose(r::Rhizome, rs::Dtry{Rhizome})::Rhizome` + + In `ocompose`, the names of the junctions in the top-level rhizome dominate. +- [ ] Systems + ```julia + struct ResourceSharer + variables::Dtry{VariableType} + params::AlgType + output::AlgType + # (params, state) -> state + update::AlgebraicFunction + # (params, state) -> output + readout::AlgebraicFunction + end + ``` + + Affordances: + - `oapply(r::Rhizome, sharers::Dtry{ResourceSharer})::ResourceSharer` + + In `oapply`, variables get renamed to the junctions that they are attached to. diff --git a/src/models/ModelInterface.jl b/src/models/ModelInterface.jl index 6d008bab..3e5dabe9 100644 --- a/src/models/ModelInterface.jl +++ b/src/models/ModelInterface.jl @@ -644,12 +644,12 @@ function migrator(tmap, dom_module, codom_module, dom_theory, codom_theory) _x = gensym("val") # Map CODOM sorts to whereparam symbols - whereparamdict = OrderedDict(s=>gensym(s.head.name) for s in sorts(codom_theory)) + whereparamdict = OrderedDict(s=>gensym(headof(s).name) for s in sorts(codom_theory)) # New model is parameterized by these types whereparams = collect(values(whereparamdict)) # Julia types of domain sorts determined by theorymap jltype_by_sort = Dict(map(sorts(dom_theory)) do v - v => whereparamdict[AlgSort(tmap(v.method).val)] + v => whereparamdict[AlgSort(tmap(methodof(v)).val)] end) # Create input for instance_code @@ -758,13 +758,9 @@ function to_call_impl(t::AlgTerm, theory::GAT, mod::Union{Symbol,Module}, migrat b = bodyof(t) if GATs.isvariable(t) nameof(b) - elseif GATs.isdot(t) + elseif GATs.isdot(t) impl = to_call_impl(b.body, theory, mod, migrate) - if isnamed(b.head) - Expr(:., impl, QuoteNode(nameof(b.head))) - else - Expr(:ref, impl, getlid(b.head).val) - end + Expr(:., impl, QuoteNode(b.head)) else args = to_call_impl.(argsof(b), Ref(theory), Ref(mod), migrate) name = nameof(headof(b)) diff --git a/src/models/SymbolicModels.jl b/src/models/SymbolicModels.jl index 26f33c38..52dd1c43 100644 --- a/src/models/SymbolicModels.jl +++ b/src/models/SymbolicModels.jl @@ -328,13 +328,13 @@ end function internal_accessors(theory::GAT) map(theory.sorts) do sort - typecon = getvalue(theory[sort.method]) - map(collect(pairs(theory.accessors[sort.method]))) do (i, acc) + typecon = getvalue(theory[methodof(sort)]) + map(collect(pairs(theory.accessors[methodof(sort)]))) do (i, acc) accessor = getvalue(theory[acc]) return_type = getvalue(typecon[typecon.args[i]]) JuliaFunction( name=esc(nameof(getdecl(accessor))), - args=[:(x::$(esc(nameof(sort.head))))], + args=[:(x::$(esc(nameof(sort))))], return_type = typename(theory, return_type), impl=:(x.type_args[$i]) ) @@ -415,11 +415,11 @@ function symbolic_instance_methods( type_con_funs = [] accessors_funs = [] for sort in sorts(theory) - type_con = getvalue(theory[sort.method]) - symgen = symbolic_generator(theorymodule, syntaxname, sort.method, type_con, theory) + type_con = getvalue(theory[methodof(sort)]) + symgen = symbolic_generator(theorymodule, syntaxname, methodof(sort), type_con, theory) push!(type_con_funs, symgen) for binding in argsof(type_con) - push!(accessors_funs, symbolic_accessor(theorymodule, theory, syntaxname, sort.method, binding)) + push!(accessors_funs, symbolic_accessor(theorymodule, theory, syntaxname, methodof(sort), binding)) end end @@ -614,7 +614,7 @@ function parse_json_sexpr(syntax_module::Module, sexpr; theory = theory_module.Meta.theory type_lens = Dict( nameof(getdecl(getvalue(binding))) => length(getvalue(binding).args) - for binding in [theory[sort.method] for sort in sorts(theory)] + for binding in [theory[methodof(sort)] for sort in sorts(theory)] ) function parse_impl(sexpr::Vector, ::Type{Val{:expr}}) diff --git a/src/nonstdlib/dynamics/ResourceSharers.jl b/src/nonstdlib/dynamics/ResourceSharers.jl new file mode 100644 index 00000000..9a3e82f2 --- /dev/null +++ b/src/nonstdlib/dynamics/ResourceSharers.jl @@ -0,0 +1,487 @@ +module ResourceSharers +export Rhizome, @rhizome, ResourceSharer, @resource_sharer, Variable + +using ...Syntax +using ...Util.Dtrys +using ...Util.Dtrys: flatten, node +using ...Syntax.GATs: tcompose +using ...Util +using MLStyle +using OrderedCollections + +""" +A state variable. + +This is used for both the variables of [`ResourceSharer`](@ref) and the +junctions of [`Rhizome`](@ref). + +This simplifies naming, because instead of having tries and an injective +map between them, we can just have a single trie. +""" +struct Variable + exposed::Bool + type::AlgType +end + +""" +A variable in the interface to an inner box of a rhizome. + +Fields: + +- `type`: the type of the variable +- `junction`: the junction that the variable is attached to. +""" +struct PortVariable + type::AlgType + junction::DtryVar +end + +const Interface = Dtry{AlgType} + +""" +A rhizome is a variant of a UWD where the underlying cospan is epi-monic. + +The mathematical theory is developed within (provide link to Lohmayer-Lynch +paper). + +This also differs from the ACSet-based UWDs in the following ways + +1. We use tries instead of numbering things sequentially. +2. All of the "functions out" of the tries are handled by just storing data in +the leaves of the tries, rather than having external "column vectors". +So e.g. we would use `Dtry{Int}` rather than a pair of `Dtry{Nothing}`, +`Dict{DtryVar, Int}`. +3. We use an indexed rather than fibered representation for inner ports +on boxes. That is, we have a function Boxes -> Set rather than a function +InnerPorts -> Boxes. Following 2., this is just stored in the leaf nodes +of the trie of boxes, hence `Dtry{Dtry{PortVariable}}`. +4. We don't have separate sets for outer ports and junctions: rather each +junction is marked as "exposed" or not. This simplifies naming. Also see +[`Variable`](@ref). + +All of this adds up to a convenient and compact representation. + +The namespaces of `boxes` and `junctions` must be disjoint. One way +of enforcing this would be to do something like + +``` +struct Box + ports::Dtry{PortVariable} +end + +struct Rhizome + stuff::Dtry{Union{Variable, Box}} +end +``` + +The only problem is that I'm not sure then what to call the single field in +`Rhizome`. Also, this would require a slight refactor of the rest of the code. +""" +struct Rhizome + theory::GAT + mcompose::AlgClosure + mzero::AlgClosure + boxes::Dtry{Dtry{PortVariable}} + junctions::Dtry{Variable} +end + +""" + ocompose(r::Rhizome, rs::Dtry{Rhizome}) + +This implements the composition operation for the Dtry-multicategory of +rhizomes. This is the better way of doing the [operad of undirected wiring][1]. + +TODO: add link to arXiv paper on Dtry-multicategories once it goes up. + +See [2] for a reference on general T-multicategories, then trust me for now +that Dtry is a cartesian monoid. + +[1]: [The operad of wiring diagrams: formalizing a graphical language for +databases, recursion, and plug-and-play circuits](https://arxiv.org/abs/1305.0297) +[2]: [Higher Operads, Higher Categories](https://arxiv.org/abs/math/0305049) +""" +function ocompose(r::Rhizome, rs::Dtry{Rhizome}) + # paired :: Dtry{Tuple{Dtry{PortVariable}, Rhizome}} + paired = zip(r.boxes, rs) + boxes = flatten(mapwithkey(Dtry{Dtry{PortVariable}}, paired) do k, (interface, r′) + # k :: DtryVar + # interface :: Dtry{PortVariable} + # r′ :: Rhizome + # We want to create the new collection of boxes + + map(Dtry{PortVariable}, r′.boxes) do b + # b :: Dtry{PortVariable} + map(PortVariable, b) do p + # p :: PortVariable + # p.junction :: namespace(b.junctions) + jvar = p.junction + j = r′.junctions[jvar] + if j.exposed == true + # If exposed, then use the junction that the port is connected to + PortVariable(p.type, interface[jvar].junction) + else + # Otherwise, attach to a newly added junction from the rhizome `r′` + # which is attached at path `k` + PortVariable(p.type, k * jvar) + end + end + end + end) + # Add all unexposed junctions + newjunctions = flatten( + map(Dtry{Variable}, rs) do r′ + internal_junctions = filter(j -> !j.exposed, r′.junctions) + if isnothing(internal_junctions) + Dtrys.node(OrderedDict{Symbol, Dtry{Variable}}()) + else + internal_junctions + end + end + ) + + junctions = merge(r.junctions, newjunctions) + + Rhizome(r.theory, r.mcompose, r.mzero, boxes, junctions) +end + +# TODO: this should just use a `toexpr` method +function Base.show(io::IO, r::Rhizome) + print(io, "Rhizome(") + comma = false + traversewithkey(r.junctions) do k, j + if comma + print(io, ", ") + end + comma = true + print(io, k, "::", j.type) + end + println(io, ")") + newline = false + traversewithkey(r.boxes) do k, box + if newline + println(io) + end + newline = true + print(io, k, "(") + comma = false + traversewithkey(box) do k′, port + if comma + print(io, ", ") + end + comma = true + print(io, k′, "::", port.type, " = ", port.junction) + end + print(io, ")") + end +end + +function parse_var(e::Union{Expr, Symbol}) + @match e begin + :_ => PACKAGE_ROOT + a::Symbol => getproperty(PACKAGE_ROOT, a) + Expr(:(.), e′, QuoteNode(x)) => parse_var(e′).x + end +end + +""" +A modified version of the relation macro supporting namespacing + +TODO: this currently does not support unexposed junctions. We should support +unexposed junctions. + +See some examples below: + +``` +@rhizome ThRing R(a, b) begin + X(a, b) + Y(a, c.x = a) +end + +@rhizome ThRing Id(a, b) begin + _(a, b) +end +``` + +We can interpret the first rhizome as follows: + +1. The rhizome has ports that are typed by types within ThRing. Because +`ThRing` has a `default` type, ports that are not explicitly annotated +will be assumed to be of that type. +2. The name of the rhizome is `R`. +3. The namespace for the external ports of the rhizome is `[a, b]`. +As noted in 1., each of these ports is typed with `default`. +""" +macro rhizome(theorymod, head, body) + # macroexpand `theory` to get the actual theory + theory = macroexpand(__module__, :($theorymod.Meta.@theory)) + # parse the name and junctions out of `head` + junctions = OrderedDict{DtryVar, Variable}() + (name, args) = @match head begin + Expr(:call, name::Symbol, args...) => (name, args) + end + for arg in args + (jname, typeexpr) = @match arg begin + jname::Symbol => (jname, :default) + Expr(:(.), _, _) => (arg, :default) + Expr(:(::), jname, type) => (jname, type) + end + v = parse_var(jname) + type = fromexpr(theory, typeexpr, AlgType) + junctions[v] = Variable(true, type) + end + junctions = Dtry(junctions) + + boxes = OrderedDict{DtryVar, Dtry{PortVariable}}() + # for each line in body, add a box to boxes + for line in body.args + (box, args) = @match line begin + _::LineNumberNode => continue + Expr(:call, name, args...) => (parse_var(name), args) + end + interface = OrderedDict{DtryVar, PortVariable}() + for arg in args + (pname, junction) = @match arg begin + pname::Symbol => (pname, pname) + Expr(:(.), _, _) => (arg, arg) + Expr(:(=), pname, junction) => (pname, junction) + Expr(:kw, pname, junction) => (pname, junction) + _ => error("unknown port pattern for box $box: $arg") + end + v = parse_var(pname) + jvar = parse_var(junction) + j = junctions[jvar] + interface[v] = PortVariable(j.type, jvar) + end + boxes[box] = Dtry(interface) + end + :( + $name = $Rhizome( + $theory, + $theorymod.Meta.Constructors.:(+), + $theorymod.Meta.Constructors.zero, + $(Dtry(boxes)), + $(junctions), + ) + ) |> esc +end + +""" +A resource sharer whose variables are namespaced in a trie +and whose update function is symbolic + +TODO: there should be a smart constructor that takes an AlgClosure and finds +the right method of it for the variables and params. +""" +struct ResourceSharer + variables::Dtry{Variable} + params::AlgType + # (tcompose(variables[..].type), params) -> tcompose(variables) + update::AlgMethod + # output::AlgType + # (tcompose(variables[..].type), params) -> output + # observe::AlgClosure +end + +""" +A DSL for writing down resource sharers + +Example: +``` +@resource_sharer ThRing Spring begin + variables = x, v + params = k + update = (state, params) -> (x = state.v, v = -params.k * state.x) +end +``` +""" +macro resource_sharer(theory, name, body) + args = Expr[] + + for line in body.args + @match line begin + _ :: LineNumberNode => nothing + Expr(:(=), arg, val) => + push!(args, Expr(:kw, arg, Expr(:quote, val))) + end + end + + esc(:($name = $ResourceSharer($theory.Meta.theory; $(args...)))) +end + +function parse_namespace(theory::GAT, expr::Expr0) + vs = OrderedDict{DtryVar, AlgType}() + argexprs = @match expr begin + Expr(:tuple, args...) => args + _ => [expr] + end + for arg in argexprs + (vname, typeexpr) = @match arg begin + vname::Symbol => (vname, :default) + Expr(:(.), _, _) => (arg, :default) + Expr(:(::), vname, type) => (vname, type) + end + v = parse_var(vname) + type = fromexpr(theory, typeexpr, AlgType) + vs[v] = type + end + Dtry(vs) +end + +function ResourceSharer(theory::GAT; variables::Expr0, params::Expr0, update::Expr0) + variable_trie = parse_namespace(theory, variables) + variables = map(Variable, variable_trie) do type + Variable(true, type) + end + param_type = tcompose(parse_namespace(theory, params)) + state_type = tcompose(variable_trie) + + update = begin + (argexpr, bodyexpr) = @match update begin + Expr(:(->), args, body) => (args, last(body.args)) + end + statename, paramname = @match argexpr begin + Expr(:tuple, s, p) => (s, p) + end + args = TypeScope(statename => state_type, paramname => param_type) + body = fromexpr(GATContext(theory, args), bodyexpr, AlgTerm) + AlgMethod(args, body, "", LID.([1,2]), state_type) + end + + ResourceSharer(variables, param_type, update) +end + +function Base.show(io::IO, r::ResourceSharer; theory) + println(io, "ResourceSharer:") + print(io, "variables = ", map(v -> string(toexpr(theory, v.type)), String, r.variables)) + println(io, "params = ", toexpr(theory, r.params)) + println(io, "update = ") + show(io, r.update; theory) +end + +""" +If C is a cartesian category, and t1 and t2 are families of objects in C with a function +of finite sets dom(t2) -> dom(t1), then we can use "copy and delete" to form a morphism in C +from tcompose(t1) to tcompose(t2). + +This function implements that operation with the category of AlgTypes. + +Arguments +- `t1` is a family of AlgTypes +- `t2` is a family of pairs consisting of an AlgType and a key in `t1` + +Produces an AlgMethod going from tcompose(t1) to tcompose(first.(t2)) +""" +function pullback(t1::Dtry{AlgType}, t2::Dtry{Tuple{AlgType, DtryVar}}; argname=:x) + ty1 = tcompose(t1) + ty2 = tcompose(map(first, t2)) + ctx = TypeScope(argname => ty1) + x = AlgTerm(ident(ctx; name=argname)) + body = tcompose( + map(AlgTerm, t2) do (_, k) + x[k] + end + ) + AlgMethod(ctx, body, "", [LID(1)], ty2) +end + +""" +If C is a monoidal category with a supply of monoids, and t1 and t2 are families of objects +in C with a function of finite sets dom(t2) -> dom(t1), then we can use the monoid +structure to form a morphism in C from tcompose(t2) to tcompose(t1). + +This function implements that operation with the category of AlgTypes, with a user-supplied +monoid operation. + +Arguments +- `t1` is a family of AlgTypes +- `t2` is a family of pairs consisting of an AlgType and a key in `t1` + +Produces an AlgMethod going from tcompose(first.(t2)) to tcompose(t1) +""" +function pushforward( + t1::Dtry{AlgType}, + t2::Dtry{Tuple{AlgType, DtryVar}}, + mcompose::AlgClosure, + mzero::AlgClosure; + argname=:x +) + preimages = Dict{DtryVar, Vector{DtryVar}}() + traversewithkey(t2) do k, (_, v) + if haskey(preimages, v) + push!(preimages[v], k) + else + preimages[v] = [k] + end + end + ty1 = tcompose(t1) + ty2 = tcompose(map(first, t2)) + ctx = TypeScope(argname => ty2) + x = AlgTerm(ident(ctx; name=argname)) + body = tcompose( + mapwithkey(AlgTerm, t1) do k, _ + foldl( + (term, v) -> first(values(mcompose.methods))(term, x[v]), + preimages[k]; + init=mzero() + ) + end + ) + AlgMethod(ctx, body, "", [LID(1)], ty1) +end + +function oapply(r::Rhizome, sharers::Dtry{ResourceSharer}) + new_variables = filter(v -> !v.exposed, flatten( + map(sharers) do sharer + sharer.variables + end + )) + variables = if !isnothing(new_variables) + merge(r.junctions, new_variables) + else + r.junctions + end + state = map(v -> v.type, variables) + state_type = tcompose(state) + + # full_state : Dtry{Tuple{AlgType, DtryVar}} + # This is a trie mapping all the variables in all of the systems + # to their type, and to the variable in the reduced system that they + # map to + full_state = flatten( + mapwithkey(Dtry{Tuple{AlgType, DtryVar}}, zip(r.boxes, sharers)) do b, (interface, sharer) + mapwithkey(Tuple{AlgType, DtryVar}, sharer.variables) do k, v + if v.exposed + (v.type, interface[k].junction) + else + (v.type, b * k) + end + end + end + ) + + params = tcompose(map(sharer -> sharer.params, sharers)) + + # (full_state, params) -> full_state + orig_update = tcompose(map(sharer -> sharer.update, sharers), [:state, :params]) + + # copy :: state -> full_state + copy = pullback(state, full_state; argname=:state) + + # add :: full_state -> state + add = pushforward(state, full_state, r.mcompose, r.mzero) + + update_ctx = TypeScope(:state => state_type, :params => params) + statevar, paramsvar = AlgTerm.(idents(update_ctx; name=[:state, :params])) + + update_body = add(orig_update(copy(statevar), paramsvar)) + + update = AlgMethod(update_ctx, update_body, "", LID.([1, 2]), state_type) + + ResourceSharer( + variables, + params, + update + ) +end + +end diff --git a/src/nonstdlib/module.jl b/src/nonstdlib/module.jl index be91a91f..5615b1af 100644 --- a/src/nonstdlib/module.jl +++ b/src/nonstdlib/module.jl @@ -4,11 +4,13 @@ using Reexport include("theories/module.jl") include("models/module.jl") +include("dynamics/ResourceSharers.jl") # include("theorymaps/module.jl") # include("derivedmodels/module.jl") @reexport using .NonStdTheories @reexport using .NonStdModels +@reexport using .ResourceSharers # @reexport using .StdTheoryMaps # @reexport using .StdDerivedModels diff --git a/src/syntax/GATContexts.jl b/src/syntax/GATContexts.jl index 060ec958..b8da5614 100644 --- a/src/syntax/GATContexts.jl +++ b/src/syntax/GATContexts.jl @@ -46,23 +46,4 @@ function Base.show(io::IO, p::GATContext) end end -struct SymbolicFunction - theory::GAT - dom::TypeScope - codom::TypeScope - substitution::Vector{AlgTerm} -end - -""" -``` -@symbolic ThRing function v(a, b, c) - (a*b, c, b) -end -``` -""" -macro symbolic(head, body) - fun = parse_function(body) -end - - end # module diff --git a/src/syntax/GATs.jl b/src/syntax/GATs.jl index 19326a16..c05bb345 100644 --- a/src/syntax/GATs.jl +++ b/src/syntax/GATs.jl @@ -15,6 +15,8 @@ using ..Scopes import ..ExprInterop: fromexpr, toexpr import ..Scopes: retag, rename, reident +using ...Util.Dtrys +using AbstractTrees import AlgebraicInterfaces: equations @@ -27,5 +29,6 @@ include("gats/judgments.jl") include("gats/gat.jl") include("gats/exprinterop.jl") include("gats/algorithms.jl") +include("gats/closures.jl") end diff --git a/src/syntax/Scopes.jl b/src/syntax/Scopes.jl index 16e17c64..5d33c80e 100644 --- a/src/syntax/Scopes.jl +++ b/src/syntax/Scopes.jl @@ -175,7 +175,7 @@ end reident(r::Dict{Ident}, x) = x -# XXX we need to make sure we match on just tag and name +# XXX we need to make sure we match on just tag and name function reident(r::Dict{Ident}, x::Ident) haskey(r, x) ? r[x] : x end @@ -726,6 +726,7 @@ end Base.getindex(c::Context, x::Ident) = getbinding(getscope(c, x), x) +getvalue(c::Context, lid::LID) = getvalue(c[lid]) getvalue(c::Context, x::Ident) = getvalue(c[x]) getvalue(c::Context, name::Symbol) = getvalue(c[ident(c; name)]) diff --git a/src/syntax/TheoryInterface.jl b/src/syntax/TheoryInterface.jl index 8eb50f97..2ac0c90c 100644 --- a/src/syntax/TheoryInterface.jl +++ b/src/syntax/TheoryInterface.jl @@ -156,7 +156,8 @@ function theory_impl(head, body, __module__) push!(modulelines, Expr(:toplevel, :(module Meta struct T end - + $(constructor_module(theory)) + @doc ($(Markdown.MD)((@doc $(__module__).$doctarget), $docstr)) const theory = $theory @@ -216,10 +217,44 @@ function invoke_term(theory_module, types, name, args; model=nothing) end end - """ - +Produce a module with an AlgClosure for each declaration in the theory. """ +function constructor_module(theory::GAT) + closures = Dict{Ident, AlgClosure}() + + for segment in allscopes(theory) + for (x, binding) in zip(getidents(segment), getbindings(segment)) + judgment = getvalue(binding) + if judgment isa AlgDeclaration + closures[x] = AlgClosure(theory) + elseif judgment isa AlgTermConstructor + add_method!( + closures[judgment.declaration], + AlgMethod( + judgment.localcontext, + AlgTerm( + judgment.declaration, + x, + AlgTerm.(idents(judgment.localcontext; lid=judgment.args)) + ), + "", + judgment.args, + judgment.type, + ) + ) + end + end + end + + Expr( + :module, + true, + :Constructors, + Expr(:block, [:(const $(nameof(x)) = $f) for (x, f) in closures]...) + ) +end + function mk_struct(s::AlgStruct, mod) fields = map(argsof(s)) do b Expr(:(::), nameof(b), nameof(AlgSort(getvalue(b)))) diff --git a/src/syntax/TheoryMaps.jl b/src/syntax/TheoryMaps.jl index fae0f915..bd2dcfc9 100644 --- a/src/syntax/TheoryMaps.jl +++ b/src/syntax/TheoryMaps.jl @@ -159,7 +159,7 @@ function infer_type(ctx::Context, t::AlgTerm) tc = getvalue(ctx[head]) typed_terms = bind_localctx(ctx, t) typ = bodyof(tc.type) - args = substitute_term.(argsof(typ), Ref(typed_terms)) + args = Vector{AlgTerm}(substitute_term.(argsof(typ), Ref(typed_terms))) AlgType(headof(typ), methodof(typ), args) end end diff --git a/src/syntax/gats/algorithms.jl b/src/syntax/gats/algorithms.jl index b19ee1ac..2ce6d933 100644 --- a/src/syntax/gats/algorithms.jl +++ b/src/syntax/gats/algorithms.jl @@ -5,6 +5,9 @@ Throw an error if a the head of an AlgTerm (which refers to a term constructor) has arguments of the wrong sort. Returns the sort of the term. """ function sortcheck(ctx::Context, t::AlgTerm)::AbstractAlgSort + if isconstant(t) || isannot(t) + return AlgSort(t.body.type) + end t_sub = substitute_funs(ctx, t) if t_sub != t return sortcheck(ctx, t_sub) @@ -22,9 +25,8 @@ function sortcheck(ctx::Context, t::AlgTerm)::AbstractAlgSort type = ctx[t.body] |> getvalue AlgSort(type) elseif isdot(t) + # This looks like it will infinitely recur... AlgSort(ctx, t) - elseif isconstant(t) - AlgSort(t.body.type) end end @@ -183,9 +185,21 @@ function substitute_term(ma::MethodApp{AlgTerm}, subst::Dict{Ident, AlgTerm}) end function substitute_term(ad::AlgDot, subst::Dict{Ident, AlgTerm}) - AlgDot(ad.head, substitute_term(ad.body, subst)) + if istuple(ad.body) + substitute_term(ad.body.body.fields[ad.head], subst) + else + AlgDot(ad.head, substitute_term(ad.body, subst)) + end end +function substitute_term(annot::AlgAnnot, subst::Dict{Ident, AlgTerm}) + # todo: should also substitute in type + AlgAnnot(substitute_term(annot.term, subst), annot.type) +end + +function substitute_term(tup::AlgNamedTuple{AlgTerm}, subst::Dict{Ident, AlgTerm}) + AlgNamedTuple{AlgTerm}(OrderedDict{Symbol, AlgTerm}(n => substitute_term(t, subst) for (n, t) in tup.fields)) +end """Replace all functions with their desugared expressions""" function substitute_funs(ctx::Context, t::AlgTerm) @@ -203,5 +217,22 @@ function substitute_funs(ctx::Context, t::AlgTerm) t elseif isdot(t) AlgTerm(AlgDot(headof(b), substitute_funs(ctx, bodyof(b)))) + elseif isannot(t) + AlgTerm(AlgAnnot(substitute_funs(ctx, t.body.term), t.body.type)) end -end \ No newline at end of file +end + +Base.map(f, t::AlgTerm) = AlgTerm(map(f, bodyof(t))) + +Base.map(f, b::MethodApp{AlgTerm}) = MethodApp{AlgTerm}(headof(b), methodof(b), map(f, argsof(b))) + +Base.map(f, b::AlgDot) = AlgDot(b.head, f(b.body)) + +Base.map(f, b::AlgAnnot) = AlgAnnot(f(b.term), b.type) + +Base.map(f, b::AlgNamedTuple{AlgTerm}) = + AlgNamedTuple{AlgTerm}(OrderedDict{Symbol, AlgTerm}(n => f(v) for (n, v) in b.fields)) + +Base.map(f, b::Ident) = b + +Base.map(f, b::Constant) = b diff --git a/src/syntax/gats/ast.jl b/src/syntax/gats/ast.jl index 8f8c98af..2feb3b93 100644 --- a/src/syntax/gats/ast.jl +++ b/src/syntax/gats/ast.jl @@ -1,6 +1,14 @@ # GAT ASTs ########## +@struct_hash_equal struct AlgNamedTuple{T} + fields::OrderedDict{Symbol, T} +end + +function Base.map(f, t::AlgNamedTuple) + newfields = OrderedDict((x, f(v)) for (x,v) in t.fields) + AlgNamedTuple(newfields) +end # AlgSorts #--------- @@ -12,42 +20,52 @@ abstract type AbstractAlgSort end A *sort*, which is essentially a type constructor without arguments """ @struct_hash_equal struct AlgSort <: AbstractAlgSort - head::Ident - method::Ident + body::Union{Tuple{Ident, Ident}, AlgNamedTuple{AlgSort}} end -function reident(reps::Dict{Ident}, a::AlgSort) - newhead = reident(reps, headof(a)) - newmethod = retag(Dict(a.head.tag => newhead.tag), methodof(a)) - AlgSort(newhead, newmethod) +AlgSort(head::Ident, method::Ident) = AlgSort((head, method)) + +iseq(::AlgSort) = false +istuple(s::AlgSort) = s.body isa AlgNamedTuple + +function reident(reps::Dict{Ident}, s::AlgSort) + if istuple(s) + AlgSort(map(s -> reident(reps, s), s.body)) + else + newhead = reident(reps, headof(s)) + newmethod = retag(Dict(headof(s).tag => newhead.tag), methodof(s)) + AlgSort(newhead, newmethod) + end end + +headof(a::AlgSort) = a.body[1] +methodof(a::AlgSort) = a.body[2] + + """ `AlgSort` A sort for equality judgments of terms for a particular sort """ @struct_hash_equal struct AlgEqSort <: AbstractAlgSort - head::Ident - method::Ident + sort::AlgSort end +AlgEqSort(head::Ident, method::Ident) = AlgEqSort(AlgSort(head, method)) +headof(s::AlgEqSort) = headof(s.sort) +methodof(s::AlgEqSort) = methodof(s.sort) + iseq(::AlgEqSort) = true -iseq(::AlgSort) = false -headof(a::AbstractAlgSort) = a.head -methodof(a::AbstractAlgSort) = a.method -Base.nameof(sort::AbstractAlgSort) = nameof(sort.head) +Base.nameof(sort::AbstractAlgSort) = nameof(headof(sort)) -getdecl(s::AbstractAlgSort) = s.head +getdecl(s::AbstractAlgSort) = headof(s) function reident(reps::Dict{Ident}, a::AlgEqSort) - newhead = reident(reps, headof(a)) - newmethod = retag(Dict(a.head.tag => newhead.tag), methodof(a)) - AlgEqSort(newhead, newmethod) + AlgEqSort(reident(reps, a.sort)) end - """ We need this to resolve a mutual reference loop; the only subtype is Constant """ @@ -106,6 +124,7 @@ abstract type AlgAST end bodyof(t::AlgAST) = t.body +abstract type AbstractAlgAnnot end """ `AlgTerm` @@ -113,7 +132,14 @@ bodyof(t::AlgAST) = t.body One syntax tree to rule all the terms. """ @struct_hash_equal struct AlgTerm <: AlgAST - body::Union{Ident, MethodApp{AlgTerm}, AbstractConstant, AbstractDot} + body::Union{ + Ident, + MethodApp{AlgTerm}, + AbstractConstant, + AbstractDot, + AbstractAlgAnnot, + AlgNamedTuple{AlgTerm} + } end @@ -127,10 +153,18 @@ function AlgTerm(fun::Ident, method::Ident) AlgTerm(MethodApp{AlgTerm}(fun, method, EMPTY_ARGS)) end +AlgTerm(t::AlgTerm) = t + +function AlgTerm(tup::NamedTuple) + AlgTerm(AlgNamedTuple{AlgTerm}(OrderedDict{Symbol, AlgTerm}(n => AlgTerm(v) for (n, v) in pairs(tup)))) +end + isvariable(t::AlgTerm) = t.body isa Ident isapp(t::AlgTerm) = t.body isa MethodApp +istuple(t::AlgTerm) = t.body isa AlgNamedTuple + isdot(t::AlgAST) = t.body isa AlgDot isconstant(t::AlgTerm) = t.body isa AbstractConstant @@ -142,6 +176,18 @@ retag(reps::Dict{ScopeTag, ScopeTag}, t::AlgTerm) = AlgTerm(retag(reps, t.body)) reident(reps::Dict{Ident}, t::AlgTerm) = AlgTerm(reident(reps, t.body)) +function tcompose(t::AbstractDtry{AlgTerm}) + @match t begin + Dtrys.Leaf(v) => v + Dtrys.Node(bs) => + AlgTerm(AlgNamedTuple{AlgTerm}(OrderedDict{Symbol, AlgTerm}( + (n, tcompose(v)) for (n, v) in bs + ))) + Dtrys.Empty() => + AlgTerm(AlgNamedTuple{AlgTerm}(OrderedDict{Symbol, AlgTerm}())) + end +end + function AlgSort(c::Context, t::AlgTerm) t_sub = substitute_funs(c, t) if t_sub != t @@ -154,8 +200,15 @@ function AlgSort(c::Context, t::AlgTerm) value = getvalue(binding) AlgSort(value.type) elseif isdot(t) - algstruct = c[AlgSort(c, bodyof(bodyof(t))).method] |> getvalue - AlgSort(getvalue(algstruct.fields[headof(bodyof(t))])) + parentsort = AlgSort(c, bodyof(bodyof(t))) + if istuple(parentsort) + parentsort.body.fields[headof(bodyof(t))] + else + algstruct = c[methodof(AlgSort(c, bodyof(bodyof(t))))] |> getvalue + AlgSort(getvalue(algstruct.fields[headof(bodyof(t))])) + end + elseif isannot(t) + AlgSort(t.body.type) else # variable binding = c[t.body] AlgSort(getvalue(binding)) @@ -185,7 +238,7 @@ reident(reps::Dict{Ident}, eq::Eq) = Eq(reident.(Ref(reps), eq.equands)) One syntax tree to rule all the types. """ @struct_hash_equal struct AlgType <: AlgAST - body::Union{MethodApp{AlgTerm}, Eq} + body::Union{MethodApp{AlgTerm}, Eq, AlgNamedTuple{AlgType}} end function AlgType(fun::Ident, method::Ident) @@ -200,6 +253,8 @@ isapp(t::AlgType) = t.body isa MethodApp iseq(t::AlgType) = t.body isa Eq +istuple(t::AlgType) = t.body isa AlgNamedTuple + isconstant(t::AlgType) = false AlgType(head::Ident, method::Ident, args::Vector{AlgTerm}) = @@ -218,10 +273,23 @@ function reident(reps::Dict{Ident}, t::AlgType) AlgType(reident(reps, t.body)) end -AlgSort(t::AlgType) = if iseq(t) - AlgEqSort(t.body.sort.head, t.body.sort.method) -else - AlgSort(t.body.head, t.body.method) +function AlgSort(t::AlgType) + if iseq(t) + AlgEqSort(headof(t.body.sort), methodof(t.body.sort)) + elseif istuple(t) + AlgSort(AlgNamedTuple{AlgSort}(OrderedDict{Symbol, AlgSort}(k => AlgSort(v) for (k, v) in t.body.fields))) + else + AlgSort(headof(t.body), methodof(t.body)) + end +end + +function tcompose(t::AbstractDtry{AlgType}) + @match t begin + Dtrys.Node(bs) => + AlgType(AlgNamedTuple(OrderedDict(k => tcompose(v) for (k,v) in AbstractTrees.children(t)))) + Dtrys.Leaf(v) => v + Dtrys.Empty() => AlgType(AlgNamedTuple(OrderedDict{Symbol, AlgType}())) + end end @@ -244,18 +312,51 @@ A Julia value in an algebraic context. Type checked elsewhere. type::AlgType end +""" +An explicitly type-annotated value. +""" +@struct_hash_equal struct AlgAnnot <: AbstractAlgAnnot + term::AlgTerm + type::AlgType +end + +isannot(t::AlgTerm) = t.body isa AbstractAlgAnnot """ Accessing a name from a term of tuple type """ @struct_hash_equal struct AlgDot <: AbstractDot - head::Ident + head::Symbol body::AlgTerm + function AlgDot(head::Symbol, body::AlgTerm) + if istuple(body) + body.body.fields[head] + else + new(head, body) + end + end end headof(a::AlgDot) = a.head bodyof(a::AlgDot) = a.body +function Base.getindex(a::AlgTerm, v::DtryVar) + @match v begin + Dtrys.Root() => a + Dtrys.Nested((n, v′)) => getindex(AlgTerm(AlgDot(n, a)), v′) + end +end + +function Base.getproperty(a::AlgTerm, n::Symbol) + if n == :body + # this is a hack: we should instead replace everywhere we use t.body + # with `getbody(t)` or something like it + getfield(a, :body) + else + AlgTerm(AlgDot(n, a)) + end +end + # Type Contexts ############### diff --git a/src/syntax/gats/closures.jl b/src/syntax/gats/closures.jl new file mode 100644 index 00000000..366cc40b --- /dev/null +++ b/src/syntax/gats/closures.jl @@ -0,0 +1,241 @@ +export AlgMethod, AlgClosure, add_method!, @algebraic + +using ...Util.MetaUtils + +""" +A method of an AlgClosure. Can also be used standalone. +""" +struct AlgMethod + context::TypeScope + body::AlgTerm + docstring::String + args::Vector{LID} + ret::Union{AlgType, Nothing} + function AlgMethod( + context::TypeScope, + body::Any, + docstring::String="", + args::Vector{LID}=LID.(1:length(context)), + ret::Union{AlgType, Nothing}=nothing, + ) + new(context, AlgTerm(body), docstring, args, ret) + end +end + +""" +Currently does not do any sort/typechecking. AlgClosure does sort checking, +so when called via AlgClosure this works as normal. +""" +function (m::AlgMethod)(argvals::Any...) + substitution = Dict{Ident, AlgTerm}(map(zip(getidents(m.context), [argvals...])) do (arg, val) + if val isa AlgTerm + arg => val + else + arg => AlgTerm(Constant(val, getvalue(m.context, arg))) + end + end) + substitute_term(m.body, substitution) +end + +""" +This implements the Dtry-algebra structure on the multicategory of types and +AlgMethods. +""" +function tcompose(ms::Dtry{AlgMethod}, argnames::Vector{Symbol}) + # First check that all methods have argument contexts of the same + # length/variable names + # argnames = ... + # For now pass in argnames + + # Then create a new typescope with the same variable names, but with types + # given by tcompose of the argument types of the ms + contexts = map(m -> m.context, TypeScope, ms) + context = TypeScope( + Scope([x => tcompose(map(ctx -> getvalue(ctx, LID(i)), AlgType, contexts)) for (i,x) in enumerate(argnames)]...) + ) + + # Then for each method, create a new body by applying it to the variables + # in the new scope with the appropriate AlgDots added + bodies = mapwithkey(AlgTerm, ms) do k, m + m([AlgTerm(x)[k] for x in getidents(context)]...) + end + + # Finally, compose all of the bodies into an expression creating an + # AlgNamedTuple + body = Dtrys.fold( + AlgTerm(AlgNamedTuple(OrderedDict{Symbol, AlgTerm}())), + x -> x, + d -> AlgTerm(AlgNamedTuple{AlgTerm}(d)), + bodies + ) + + ret = try + tcompose(map(m -> m.ret, ms)) + catch _ + nothing + end + + AlgMethod( + context, + body, + "", + LID.(eachindex(argnames)), + ret + ) +end + +""" +A standalone, anonymous symbolic function. May have multiple methods. +""" +struct AlgClosure + theory::GAT + methods::Dict{AlgSorts, AlgMethod} + function AlgClosure( + theory::GAT, + methods::Dict{AlgSorts, AlgMethod} = Dict{AlgSorts, AlgMethod}() + ) + new(theory, methods) + end +end + +function add_method!(f::AlgClosure, m::AlgMethod) + sorts = [AlgSort(getvalue(m.context, i)) for i in m.args] + if haskey(f.methods, sorts) + error("attempted to overload a pre-existing sort signature") + end + f.methods[sorts] = m +end + +function tcompose(fs::Dtry{AlgClosure}, argnames::Vector{Symbol}) + ms = map(f -> only(values(f.methods)), fs) + m = tcompose(ms, argnames) + f = AlgClosure(first(fs).theory) + add_method!(f, m) + f +end + +function (f::AlgClosure)(argvals::Any...) + if length(f.methods) > 1 && any(!(x isa AlgTerm) for x in argvals) + error("cannot infer type of non-AlgTerm value $x") + end + sorts = AlgSort[sortcheck.(Ref(f.theory), argvals)...] + if !haskey(f.methods, sorts) + error("no method with argument sorts $sorts found") + end + m = f.methods[sorts] + if m.args != LID.(1:length(m.context)) + error("context inference not yet supported") + end + m(argvals...) +end + +function Base.show(io::IO, f::AlgClosure) + m = only(values(f.methods)) + fndef = Expr(:(->), Expr(:tuple, toexpr(f.theory, m.context).args...), toexpr(GATContext(f.theory, m.context), m.body)) + println(io, "AlgClosure in theory $(nameof(f.theory)) with definition:") + print(io, fndef) +end + +function Base.show(io::IO, m::AlgMethod; theory) + fndef = Expr(:(->), Expr(:tuple, toexpr(theory, m.context).args...), toexpr(GATContext(theory, m.context), m.body)) + print(io, fndef) +end + +function strip_annot(t::AlgTerm) + if isannot(t) + strip_annot(t.body.term) + else + map(strip_annot, t) + end +end + +""" +This constructs an algebraic closure. + +Use: + +```julia +@algebraic ThRing f(x, y) + x^2 + 2*x*y + y^2 +end +``` + +This expands in the following way. + +First, we create an outer let binding that binds all declarations in the theory to +the term-creating functions defined in the theory. So this would look something like: + +``` +let + + = ThRing.Meta.Constructors.+ + * = ThRing.Meta.Constructors.* + ... +end +``` + +Then, we bind the arguments to the function + +``` +let + ... + x = AlgTerm(\$(ident(args; name=:x))) + y = AlgTerm(\$(ident(args; name=:y))) + ... +end +``` + +Here `args` is a TypeScope that we generate by parsing the arguments to `fn`. + +Then, we *evaluate* the body at runtime, and create the algebraic function. +This means that arbitrary Julia code can go in the body, including calling +other algebraic functions that happen to be in scope. + +So, for instance, something like + +``` +double = true + +@algebraic ThRing f(x, y) + if double + (x + y) + (x + y) + else + x + y + end +end +``` + +would produce the same result as + +``` +@algebraic ThRing f(x, y) + (x + y) + (x + y) +end +``` +""" +macro algebraic(theorymodule, fn) + theory = macroexpand(__module__, :($theorymodule.Meta.@theory)) + fn = parse_function(fn, :default) + scope = fromexpr(theory, Expr(:vect, fn.args...), TypeScope) + esc( + Expr( + :(=), + fn.name, + Expr( + :let, + Expr( + :block, + [:($n = $theorymodule.Meta.Constructors.$n) for n in nameof.(declarations(theory))]..., + [:($(nameof(x)) = $(AlgTerm(AlgAnnot(AlgTerm(x), getvalue(scope, x))))) for x in getidents(scope)]..., + :(__body = $strip_annot($AlgTerm($(fn.impl)))), + :(__f = $AlgClosure($theory)), + :(__m = $AlgMethod($scope, __body)) + ), + Expr( + :block, + :($add_method!(__f, __m)), + :__f + ) + ) + ) + ) +end diff --git a/src/syntax/gats/exprinterop.jl b/src/syntax/gats/exprinterop.jl index c49ce7da..0896f9b2 100644 --- a/src/syntax/gats/exprinterop.jl +++ b/src/syntax/gats/exprinterop.jl @@ -19,8 +19,8 @@ function toexpr(c::Context, m::MethodApp) Expr(:call, toexpr(c, m.head), toexpr.(Ref(c), m.args)...) end -function toexpr(c::Context, m::AlgDot) - Expr(:., toexpr(c, m.body), QuoteNode(m.head)) +function toexpr(c::Context, m::AlgDot; kw...) + Expr(:., toexpr(c, m.body; kw...), QuoteNode(m.head)) end function fromexpr(c::GATContext, e, ::Type{AlgTerm}) @@ -36,11 +36,22 @@ function fromexpr(c::GATContext, e, ::Type{AlgTerm}) end Expr(:., body, QuoteNode(head)) => begin t = fromexpr(c, body, AlgTerm) - algstruct = c[AlgSort(c, t).method] |> getvalue - AlgTerm(AlgDot(ident(algstruct.fields; name=head), t))# , str)) - end + AlgTerm(AlgDot(head, t)) + end Expr(:call, head::Symbol, argexprs...) => AlgTerm(parse_methodapp(c, head, argexprs)) Expr(:(::), val, type) => AlgTerm(Constant(val, fromexpr(c, type, AlgType))) + Expr(:tuple, kvs...) => AlgTerm( + AlgNamedTuple{AlgTerm}( + OrderedDict{Symbol, AlgTerm}( + map(kvs) do kv + @match kv begin + Expr(:(=), k, v) => (k => fromexpr(c, v, AlgTerm)) + _ => error("expected key-value pairs inside tuple") + end + end + ) + ) + ) e::Expr => error("could not parse AlgTerm from $e") constant::Constant => AlgTerm(constant) end @@ -57,6 +68,13 @@ function fromexpr(p::GATContext, e, ::Type{AlgType})::AlgType AlgType(parse_methodapp(p, head, args)) Expr(:call, :(==), lhs, rhs) => AlgType(p, fromexpr(p, lhs, AlgTerm), fromexpr(p, rhs, AlgTerm)) + Expr(:tuple, args...) => begin + fields = OrderedDict{Symbol, AlgType}() + for arg in args + parse_binding_expr!(p, b -> (fields[nameof(b)] = getvalue(b)), arg) + end + AlgType(AlgNamedTuple{AlgType}(fields)) + end _ => error("could not parse AlgType from $e") end end @@ -70,6 +88,8 @@ function toexpr(c::Context, type::AlgType) end elseif iseq(type) Expr(:call, :(==), toexpr.(Ref(c), type.body.equands)...) + elseif istuple(type) + Expr(:tuple, [Expr(:(::), k, toexpr(c, v)) for (k, v) in type.body.fields]...) end end @@ -87,6 +107,14 @@ end toexpr(c::Context, constant::Constant; kw...) = Expr(:(::), constant.value, toexpr(c, constant.type; kw...)) +toexpr(c::Context, annot::AlgAnnot; kw...) = + Expr(:(::), toexpr(c, annot.term; kw...), toexpr(c, annot.type; kw...)) + +# toexpr(c::Context, annot::AlgAnnot; kw...) = toexpr(c, annot.term; kw...) + +toexpr(c::Context, t::AlgNamedTuple{AlgTerm}; kw...) = + Expr(:tuple, [Expr(:(=), k, toexpr(c, v; kw...)) for (k, v) in t.fields]...) + function fromexpr(c::GATContext, e, ::Type{InCtx{T}}; kw...) where T (termexpr, localcontext) = @match e begin Expr(:call, :(⊣), binding, tscope) => (binding, fromexpr(c, tscope, TypeScope)) diff --git a/src/syntax/gats/gat.jl b/src/syntax/gats/gat.jl index eb6841a6..a65d7a0a 100644 --- a/src/syntax/gats/gat.jl +++ b/src/syntax/gats/gat.jl @@ -379,3 +379,5 @@ end """Get all structs in a theory""" structs(t::GAT) = AlgStruct[getvalue(t[methodof(s)]) for s in struct_sorts(t)] + +declarations(t::GAT) = keys(t.resolvers) diff --git a/src/util/Dtrys.jl b/src/util/Dtrys.jl new file mode 100644 index 00000000..02375371 --- /dev/null +++ b/src/util/Dtrys.jl @@ -0,0 +1,550 @@ +module Dtrys +export Dtry, NonEmptyDtry, AbstractDtry, PACKAGE_ROOT, ■, DtryVar, + filtermap, mapwithkey, traversewithkey + +using AbstractTrees +using OrderedCollections +using MLStyle +using StructEquality + +""" +An internal node of a [`Dtry`](@ref). Should not be used outside of this module. + +Cannot be empty. +""" +@struct_hash_equal struct Node_{X} + branches::OrderedDict{Symbol, X} + function Node_{X}(branches::OrderedDict{Symbol, X}) where {X} + if length(branches) == 0 + error( + """ + Attempted to make a Node with no branches. This is an error because + tries can only be empty at the top level. + """ + ) + end + new{X}(branches) + end + function Node_(branches::OrderedDict{Symbol, X}) where {X} + Node_{X}(branches) + end +end + +""" +A leaf node of a [`Dtry`](@ref). Should not be used outside of this module. +""" +@struct_hash_equal struct Leaf_{A} + value::A +end + +abstract type AbstractDtry{A} end + +function Base.:(==)(t1::AbstractDtry, t2::AbstractDtry) + content(t1) == content(t2) +end + +""" +A non-empty trie. + +See the docs for [`Dtry`](@ref) for general information about tries; these +docs are specific to the reasoning behind having a non-empty variant. + +We use non-empty tries because we don't want to worry about the difference +between the empty tuple `()` and a named tuple `(a::())`. In either one, the +set of valid paths is the same. Thus, we only allow a trie to be empty at +the toplevel, and subtries must be non-empty. So the self-similar recursion +happens for non-empty tries, while [`Dtry`](@ref) is just a wrapper around +`Union{NonEmptyDtry{A}, Nothing}`. +""" +struct NonEmptyDtry{A} <: AbstractDtry{A} + content::Union{Leaf_{A}, Node_{NonEmptyDtry{A}}} +end + +content(t::NonEmptyDtry) = getfield(t, :content) + +@active Node(t) begin + inner = content(t) + if inner isa Node_ + Some(inner.branches) + end +end + +@active Leaf(t) begin + inner = content(t) + if inner isa Leaf_ + Some(inner.value) + end +end + +""" +A possibly-empty [trie][1]. + +We use `trie` in a slightly idiosyncratic way here. + +1. Branches are indexed by `Symbol`s rather than `Char`s +2. We only store values at leaf nodes rather than internal nodes. + +One way of slickly defining NonEmptyDtry is that it is the free monad on the +polynomial + +p = ∑_{U ⋐ Symbol, U ≠ ∅} y^U + +Then Dtry = NonEmptyDtry + 1. + +[1]: https://en.wikipedia.org/wiki/Dtry +""" +struct Dtry{A} <: AbstractDtry{A} + content::Union{Nothing, NonEmptyDtry{A}} +end + +function content(p::Dtry) + unwrapped = getfield(p, :content) + if !isnothing(unwrapped) + content(unwrapped) + end +end + +@active Empty(t) begin + isnothing(content(t)) +end + +@active NonEmpty(t) begin + if t isa NonEmptyDtry + Some(t) + elseif t isa Dtry + c = getfield(t, :content) + if !isnothing(c) + Some(c) + end + end +end + +""" +Construct a new Dtry node. +""" +function node(d::OrderedDict{Symbol, <:AbstractDtry{A}}) where {A} + nonempties = OrderedDict{Symbol, NonEmptyDtry{A}}() + for (k, t) in pairs(d) + @match t begin + NonEmpty(net) => begin + nonempties[k] = net + end + Empty() => nothing + end + end + if length(nonempties) > 0 + Dtry{A}(NonEmptyDtry{A}(Node_{NonEmptyDtry{A}}(nonempties))) + else + Dtry{A}() + end +end + +node(ps::Pair{Symbol, T}...) where {A, T<:AbstractDtry{A}} = node(OrderedDict{Symbol, T}(ps...)) + +node(g::Base.Generator) = node(OrderedDict(g)) + +leaf(x::A) where {A} = Dtry{A}(NonEmptyDtry{A}(Leaf_{A}(x))) + +Dtry{A}() where {A} = Dtry{A}(nothing) + +struct DtryIndexError <: Exception + trie::Dtry + key::Symbol +end + +struct DtryDerefError <: Exception + trie::Dtry +end + +function Base.getproperty(t::AbstractDtry{A}, n::Symbol) where {A} + @match t begin + Node(bs) => + if haskey(bs, n) + bs[n] + else + throw(DtryIndexError(t, n)) + end + _ => throw(DtryIndexError(t, n)) + end +end + +function Base.filter(f, t::AbstractDtry{A}) where {A} + @match t begin + Leaf(v) => f(v) ? t : Dtry{A}() + Node(bs) => begin + bs′ = OrderedDict{Symbol, NonEmptyDtry{A}}() + for (n, s) in bs + @match filter(f, s) begin + NonEmpty(net) => (bs′[n] = net) + Empty() => nothing + end + end + node(bs′) + end + Empty() => t + end +end + +""" + filtermap(f, return_type::Type, t::AbstractDtry) + +Map the function `f : eltype(t) -> Union{Some{return_type}, Nothing}` over the +trie `t` to produce a trie of type `return_type`, filtering out the elements of +`t` on which `f` returns `nothing`. We pass `return_type` explicitly so that in +the case `t` is the empty trie this doesn't return `Dtry{Any}`. +""" +function filtermap(f, return_type::Type, t::AbstractDtry) + @match t begin + Leaf(v) => begin + @match f(v) begin + Some(v′) => leaf(v′) + nothing => Dtry{return_type}() + end + end + Node(bs) => begin + bs′ = OrderedDict{Symbol, NonEmptyDtry{return_type}}() + for (n, s) in bs + @match filtermap(f, return_type, s) begin + NonEmpty(net) => (bs′[n] = net) + Empty() => nothing + end + end + node(bs′) + end + Empty() => t + end +end + +""" + zipwith(f, t1::AbstractDtry, t2::AbstractDtry) + +Produces a new trie whose leaf node at a path `p` is given by `f(t1[p], t2[p])`. + +Throws an error if `t1` and `t2` are not of the same shape: i.e. they don't +have the exact same set of paths. +""" +function zipwith(f, t1::AbstractDtry{A1}, t2::AbstractDtry{A2}) where {A1, A2} + @match (t1, t2) begin + (Leaf(v1), Leaf(v2)) => leaf(f(v1, v2)) + (Node(bs1), Node(bs2)) => begin + keys(bs1) == keys(bs2) || error("cannot zip two tries not of the same shape") + node(OrderedDict(n => zipwith(f, s1, s2) for ((n, s1), (_, s2)) in zip(bs1, bs2))) + end + (Empty(), Empty()) => Dtry{Core.Compiler.return_type(f, Tuple{A1, A2})}() + _ => error("cannot zip two tries not of the same shape") + end +end + +""" + zip(t1, t2) + +Produces a new trie whose leaf node at a path `p` is given by `(t1[p], t2[p])`. + +Throws an error if `t1` and `t2` are not of the same shape: i.e. they don't +have the exact same set of paths. +""" +Base.zip(t1::AbstractDtry, t2::AbstractDtry) = zipwith((a,b) -> (a,b), t1, t2) + +Base.getindex(p::Dtry, n::Symbol) = getproperty(p, n) + +function Base.getindex(t::AbstractDtry{A})::A where {A} + @match t begin + Leaf(v) => v + _ => throw(DtryDerefError(t)) + end +end + +function Base.first(t::AbstractDtry) + @match t begin + Leaf(v) => v + Node(bs) => first(first(values(bs))) + Empty() => error("cannot take the first value of an empty trie") + end +end + +function Base.hasproperty(t::AbstractDtry, n::Symbol) + @match t begin + Node(bs) => haskey(bs, n) + _ => false + end +end + +function Base.propertynames(t::AbstractDtry) + @match t begin + Node(bs) => keys(bs) + _ => Symbol[] + end +end + +Base.keys(t) = Base.propertynames(t) +Base.valtype(t::AbstractDtry{A}) where {A} = A +Base.valtype(::Type{<:AbstractDtry{A}}) where {A} = A +Base.eltype(t::AbstractDtry{A}) where {A} = A +Base.eltype(::Type{<:AbstractDtry{A}}) where {A} = A + +""" + map(f, return_type::Type, t::AbstractDtry) + +Produce a new trie of the same shape as `t` where the value at a path `p` is +given by `f(t[p])`. We pass in the return type explicitly so that in the case +that `t` is empty we don't get `Dtry{Any}`. + +There is a variant defined later where `return_type` is not passed in, and it +tries to use type inference from the Julia compiler to infer `return_type`: use +of this should be discouraged. +""" +function Base.map(f, return_type::Type, t::AbstractDtry) + @match t begin + Leaf(v) => leaf(f(v)) + Node(bs) => node( + OrderedDict{Symbol, Dtry{return_type}}( + (n => map(f, return_type, t′)) for (n, t′) in bs + ) + ) + Empty() => Dtry{return_type}() + end +end + +""" + map(f, t::AbstractDtry) + +Variant of `map(f, return_type, t::AbstractDtry)` which attempts to infer the +return type of `f`. +""" +function Base.map(f, t::AbstractDtry{A}) where {A} + B = Core.Compiler.return_type(f, Tuple{A}) + map(f, B, t) +end + +""" + flatten(t::AbstractDtry{Dtry{A}}) + +The monad operation for Dtrys. Works on NonEmptyDtrys and Dtrys. +""" +function flatten(t::AbstractDtry{Dtry{A}}) where {A} + @match t begin + Leaf(v) => v + # Note that if flatten(v) is empty, the `node` constructor will + # automatically remove it from the built trie. + Node(bs) => node(n => flatten(v) for (n, v) in bs) + Empty() => Dtry{A}() + end +end + +struct DtryVar + content::Union{Nothing, Tuple{Symbol, DtryVar}} +end + +PACKAGE_ROOT = DtryVar(nothing) +■ = PACKAGE_ROOT + +content(v::DtryVar) = getfield(v, :content) + +@active Root(v) begin + isnothing(content(v)) +end + +@active Nested(v) begin + inner = content(v) + if !isnothing(inner) + Some(inner) + end +end + +function Base.getproperty(v::DtryVar, n::Symbol) + @match v begin + Root() => DtryVar((n, v)) + Nested((n′, v′)) => DtryVar((n′, getproperty(v′, n))) + end +end + +function Base.:(*)(v1::DtryVar, v2::DtryVar) + @match v1 begin + Root() => v2 + Nested((n, v1′)) => DtryVar((n, v1′ * v2)) + end +end + +""" + mapwithkey(f, return_type::Type, t::AbstractDtry) + +Constructs a new trie with the same shape as `t` where the value at the path +`p` is `f(p, t[p])`. +""" +function mapwithkey(f, return_type::Type, t::AbstractDtry; prefix=PACKAGE_ROOT) + @match t begin + Leaf(v) => leaf(f(prefix, v)) + Node(bs) => + node([k => mapwithkey(f, return_type, v; prefix=getproperty(prefix, k)) for (k, v) in bs]...) + Empty() => Dtry{return_type}() + end +end + +""" + traversewithkey(f, t::AbstractDtry; prefix=PACKAGE_ROOT) + +Similar to [`mapwithkey`](@ref) but just evaluates `f` for its side effects +instead of constructing a new trie. +""" +function traversewithkey(f, t::AbstractDtry; prefix=PACKAGE_ROOT) + @match t begin + Leaf(v) => (f(prefix, v); nothing) + Node(bs) => begin + for (k, v) in bs + traversewithkey(f, v; prefix=getproperty(prefix, k)) + end + end + Empty() => nothing + end +end + +""" + fold(emptycase::A, leafcase, nodecase, t::AbstractDtry)::A + +Fold over `t` to produce a single value. + +Args: +- `emptycase::A` +- `leafcase::eltype(t) -> A` +- `nodecase::OrderedDict{Symbol, A} -> A` +""" +function fold(emptycase::A, leafcase, nodecase, t::AbstractDtry)::A where {A} + @match t begin + Empty() => emptycase + Leaf(v) => leafcase(v) + Node(bs) => nodecase(OrderedDict(k => fold(emptycase, leafcase, nodecase, v) for (k,v) in bs)) + end +end + +""" + all(f, t::AbstractDtry) + +Checks if `f` returns `true` when applied to all of the elements of `t`. + +Args: +- `f::eltype(t) -> Bool` +""" +function Base.all(f, t::AbstractDtry) + fold(true, f, d -> all(values(d)), t) +end + +# precondition: the union of the keys in t1 and t2 is prefix-free +function Base.merge(t1::AbstractDtry{A}, t2::AbstractDtry{A}) where {A} + @match (t1, t2) begin + (Leaf(_), _) || (_, Leaf(_)) => + error("cannot merge tries with overlapping keys") + (Empty(), _) => t2 + (_, Empty()) => t1 + (Node(b1), Node(b2)) => begin + b = OrderedDict{Symbol, NonEmptyDtry{A}}() + for (n, t) in b1 + if haskey(b2, n) + @match merge(t, b2[n]) begin + NonEmpty(t′) => (b[n] = t′) + _ => nothing + end + else + b[n] = t + end + end + for (n, t) in b2 + if !haskey(b1, n) + b[n] = t + end + end + node(b) + end + end +end + +""" +Make a trie out of a dict from trie keys to values. + +Fails if the keys in the dict are not prefix-free. +""" +function Dtry(d::OrderedDict{DtryVar, A}) where {A} + branches = OrderedDict{Symbol, OrderedDict{DtryVar, A}}() + for (v, x) in d + @match v begin + Root() => + if length(d) == 1 + return leaf(x) + else + error("attempted trie conversion failed because keys were not prefix-free") + end + Nested((n, v′)) => + if haskey(branches, n) + branches[n][v′] = x + else + branches[n] = OrderedDict(v′ => x) + end + end + end + node(n => Dtry(d′) for (n, d′) in branches) +end + +Dtry(pairs::Pair{DtryVar, A}...) where {A} = Dtry(OrderedDict{DtryVar, A}(pairs...)) + +function Base.show(io::IO, v::DtryVar) + print(io, "■") + while true + @match v begin + Root() => break + Nested((n, v′)) => begin + print(io, ".", n) + v = v′ + end + end + end +end + +function Base.haskey(t::AbstractDtry, v::DtryVar) + @match (t, v) begin + (Leaf(_), Root) => true + (Node(_), Nested((n, v))) => hasproperty(t, n) && haskey(getproperty(t, n), v) + _ => false + end +end + +struct DtryVarNotFound <: Exception + p::Dtry + v::DtryVar +end + +function Base.getindex(t::AbstractDtry, v::DtryVar) + @match (t, v) begin + (Leaf(v), Root) => v + (Node(_), Nested((n, v′))) => + if hasproperty(t, n) + getproperty(t, n)[v′] + else + throw(DtryVarNotFound(t, v)) + end + _ => throw(DtryVarNotFound(t, v)) + end +end + +function AbstractTrees.children(t::AbstractDtry) + @match t begin + Leaf(_) => () + Node(bs) => bs + end +end + +function AbstractTrees.printnode(io::IO, t::AbstractDtry{A}; kw...) where {A} + @match t begin + Leaf(v) => print(io, v) + Node(bs) => print(io, nameof(typeof(t)), "{$A}") + Empty() => print(io, "{}") + end +end + +function Base.show(io::IO, t::AbstractDtry{A}) where {A} + @match t begin + Leaf(v) => print(io, "leaf(", v, ")::$(nameof(typeof(t))){$A}") + Node(_) => print_tree(io, t) + Empty() => print(io, "Dtry{$A}()") + end +end + +end diff --git a/src/util/MetaUtils.jl b/src/util/MetaUtils.jl index a3793808..41d6b59a 100644 --- a/src/util/MetaUtils.jl +++ b/src/util/MetaUtils.jl @@ -80,7 +80,7 @@ end """ Parse Julia function definition into standardized form. """ -function parse_function(expr::Expr)::JuliaFunction +function parse_function(expr::Expr, default_type=:Any)::JuliaFunction doc, expr = parse_docstring(expr) fun_expr, impl = @match expr begin Expr(:(=), args...) => args @@ -106,7 +106,7 @@ function parse_function(expr::Expr)::JuliaFunction args = map(args) do arg @match arg begin Expr(:(::), x, T) => arg - x::Symbol => :($x::Any) + x::Symbol => :($x::$(default_type)) Expr(:(::), T) => Expr(:(::), gensym(), T) _ => throw(ParseError("Ill-formed argument expression $arg")) end diff --git a/src/util/module.jl b/src/util/module.jl index 2caa99f0..92bd76dc 100644 --- a/src/util/module.jl +++ b/src/util/module.jl @@ -2,10 +2,12 @@ module Util using Reexport +include("Dtrys.jl") include("MetaUtils.jl") include("HashColor.jl") include("Eithers.jl") +@reexport using .Dtrys @reexport using .MetaUtils @reexport using .HashColor @reexport using .Eithers diff --git a/test/Project.toml b/test/Project.toml index 52861d46..0599b75c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,12 @@ [deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +GATlab = "f0ffcf3b-d13a-433e-917c-cc44ccf5ead2" +MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/nonstdlib/ResourceSharers.jl b/test/nonstdlib/ResourceSharers.jl new file mode 100644 index 00000000..d5a6fcb8 --- /dev/null +++ b/test/nonstdlib/ResourceSharers.jl @@ -0,0 +1,87 @@ +module TestResourceSharers + +using Test + +using GATlab.Util.Dtrys +using GATlab.Util.Dtrys: node, leaf +using GATlab.NonStdlib.ResourceSharers +using GATlab.NonStdlib.ResourceSharers: ocompose, oapply +using GATlab +using GATlab.Syntax.GATs: tcompose +using ComponentArrays + +import Base: +, *, - + +@theory ThRing begin + default::TYPE + zero::default + one::default + ((x::default) + (y::default))::default + ((x::default) * (y::default))::default + (-(x::default))::default +end + +@rhizome ThRing rtop(a, b) begin + X(a) + Y(a, c.x = b) +end + +@rhizome ThRing rX(a) begin + A(x = a, y = a) +end + +@rhizome ThRing rY(a, c.x) begin + A(x = a, y = a) + B(t = c.x) +end + +@test sprint(show, rtop) isa String + +r = ocompose(rtop, Dtry(■.X => rX, ■.Y => rY)) + +@resource_sharer ThRing Spring begin + variables = x, v + params = k + update = (state, params) -> (x = state.v, v = -params.k * state.x) +end; + +@resource_sharer ThRing Gravity begin + variables = v + params = g + update = (state, params) -> (v = - params.g,) +end; + +@rhizome ThRing SpringGravity(x, v) begin + spring(x, v) + gravity(v) +end + +@test sprint((io, r) -> show(io, r; theory=ThRing.Meta.theory), Gravity) isa String + +s = oapply(SpringGravity, Dtry(■.spring => Spring, ■.gravity => Gravity)); + +body = toexpr(GATContext(ThRing.Meta.theory, s.update.context), s.update.body) + +zero() = 0.0 + +eval( + :(update(state, params) = ComponentArray(;$(body.args...))) +) + +update((x = 0.0, v = 1.0), (spring = (k = 1.0,), gravity = (g = 9.8,))) + +init = ComponentArray(x = 0.0, v = 1.0) +params = ComponentArray(spring = (k = 1.0,), gravity = (g = 9.8,)) + +function euler(init, params, v, dt, steps) + values = Vector{typeof(init)}(undef, steps+1) + values[1] = init + for i in 1:steps + values[i+1] = values[i] .+ (dt .* v(values[i], params)) + end + values +end + +traj = euler(init, params, update, 0.1, 100); + +end diff --git a/test/nonstdlib/tests.jl b/test/nonstdlib/tests.jl index 73f34d72..b290293d 100644 --- a/test/nonstdlib/tests.jl +++ b/test/nonstdlib/tests.jl @@ -1,5 +1,6 @@ module TestNonStdModels include("Pushouts.jl") +include("ResourceSharers.jl") end diff --git a/test/stdlib/Arithmetic.jl b/test/stdlib/Arithmetic.jl index 74550953..99dd3ca7 100644 --- a/test/stdlib/Arithmetic.jl +++ b/test/stdlib/Arithmetic.jl @@ -50,7 +50,8 @@ using .ThCategory end # Ring of integers -#--------------------- +# -------------------- + using .ThRing import .ThRing: zero, one, -, +, * diff --git a/test/syntax/GATs.jl b/test/syntax/GATs.jl index c5901e74..00b930b7 100644 --- a/test/syntax/GATs.jl +++ b/test/syntax/GATs.jl @@ -55,7 +55,6 @@ seg_expr = quote id_span(x) := Span(x, id(x),id(x)) ⊣ [x::Ob] end - thcat = fromexpr(GAT(:ThCat), seg_expr, GAT; current_module=[:Foo, :Bar]) O, H, i = idents(thcat; name=[:Ob, :Hom, :id]) @@ -92,8 +91,8 @@ HomS = AlgSort(HomT) @test rename(gettag(scope), Dict(:A=>:Z), HomT) isa AlgType @test retag(Dict(gettag(scope)=>newscopetag()), HomT) isa AlgType @test reident(Dict(A=>ident(scope; name=:B)), HomS) isa AlgSort -@test reident(Dict(A=>ident(scope; name=:B)), AlgEqSort(HomS.head, HomS.method)) == - AlgEqSort(HomS.head, HomS.method) +@test reident(Dict(A=>ident(scope; name=:B)), AlgEqSort(HomS)) == + AlgEqSort(HomS) @test sortcheck(c, AlgTerm(A)) == ObS @@ -122,7 +121,7 @@ iida = AlgTerm(i, im, [AlgTerm(i, im, [AlgTerm(A)])]) # Good type and bad type haa = HomT -haia = AlgType(HomS.head, HomS.method, [ATerm, ida]) +haia = AlgType(headof(HomS), methodof(HomS), [ATerm, ida]) @test sortcheck(c, haa) @test_throws Exception sortcheck(c, haia) @@ -158,6 +157,14 @@ end id_span(x) := Span(x, id(x),id(x)) ⊣ [x::Ob] end +# Dtrys + +tuplescope = fromexpr(ThMonoid.Meta.theory, :([x::(a::(s,t),b)]), TypeScope) + +@algebraic ThRing function f(x, y) + x * y + x * x +end + @test Base.isempty(GAT(:_EMPTY)) end # module diff --git a/test/util/Dtrys.jl b/test/util/Dtrys.jl new file mode 100644 index 00000000..984028da --- /dev/null +++ b/test/util/Dtrys.jl @@ -0,0 +1,123 @@ +module TestDtrys + +using GATlab.Util.Dtrys +import .Dtrys: node, leaf, Node, Leaf, Empty, NonEmpty, zipwith, flatten, fold +using Test +using OrderedCollections + +using MLStyle + +@test_throws ErrorException Dtrys.Node_(OrderedDict{Symbol, Int}()) + +t1 = node(:a => leaf(1), :b => node(:a => leaf(2), :c => leaf(3))) + +@test t1.a isa AbstractDtry +@test t1.a == t1[:a] +@test_throws Dtrys.DtryDerefError t1[] +@test_throws Dtrys.DtryIndexError t1.z +@test t1.a[] == 1 +@test t1.b.a isa NonEmptyDtry +@test t1.b.a[] == 2 + +@test sprint(show, t1.a) == "leaf(1)::NonEmptyDtry{Int64}" +@test sprint(show, t1.b) == "NonEmptyDtry{Int64}\n├─ :a ⇒ 2\n└─ :c ⇒ 3\n" +@test sprint(show, Dtry{Int}()) == "Dtry{Int64}()" + +@test ■ == PACKAGE_ROOT +@test ■.a isa DtryVar +@test ■.a.b isa DtryVar +@test_throws Dtrys.DtryVarNotFound t1[■] + +@test haskey(t1, ■.a) +@test t1[■.a] == 1 + +@test t1[■.b.c] == 3 + +@test sprint(show, ■.a) == "■.a" + +@test map(x -> x + 1, leaf(2)) == leaf(3) +@test map(x -> x + 1, Dtry{Int}()) == Dtry{Int}() + +@test filter(x -> x % 2 == 0, t1) == node(:b => node(:a => leaf(2))) +@test filter(_ -> false, Dtry{Int}()) == Dtry{Int}() + +function int_sqrt(x) + try + Some(Int(sqrt(x))) + catch e + nothing + end +end + +@test filtermap(int_sqrt, Int, t1) == node(:a => leaf(1)) +@test filtermap(int_sqrt, Int, Dtry{Int}()) == Dtry{Int}() + +@test t1 == @match t1 begin + NonEmpty(net1) => net1 +end + +@test zipwith(+, t1, t1) == node(:a => leaf(2), :b => node(:a => leaf(4), :c => leaf(6))) +# TODO: fix this, zipwith should take an argument for the return type +@test zipwith(+, Dtry{Int}(), Dtry{Int}()) == Dtry{Int}() +@test_throws ErrorException zipwith(+, Dtry{Int}(), t1) +@test zip(t1, t1) == node(:a => leaf((1,1)), :b => node(:a => leaf((2,2)), :c => leaf((3,3)))) + +@test first(t1) == 1 +@test_throws ErrorException first(Dtry{Int}()) + +@test hasproperty(t1, :a) +@test !hasproperty(t1, :z) + +@test [propertynames(t1)...] == [:a, :b] +@test keys(t1) == propertynames(t1) + +@test valtype(t1) == Int +@test valtype(typeof(t1)) == Int +@test eltype(t1) == Int + +@test flatten(Dtry{Dtry{Int}}()) == Dtry{Int}() +@test flatten(leaf(t1)) == t1 +@test flatten(leaf(leaf(1))) == leaf(1) +@test flatten(node(:f => leaf(leaf(1)), :g => leaf(t1))) == + node( + :f => leaf(1) + , :g => node( + :a => leaf(1) + , :b => node( + :a => leaf(2) + , :c => leaf(3) + ) + ) + ) + +@test mapwithkey((k, _) -> k, DtryVar, t1) == node(:a => leaf(■.a), :b => node(:a => leaf(■.b.a), :c => leaf(■.b.c))) +@test mapwithkey((k, _) -> k, DtryVar, Dtry{Int}()) == Dtry{DtryVar}() + +t1_keys = DtryVar[] + +traversewithkey((k, _) -> push!(t1_keys, k), t1) + +@test t1_keys == [■.a, ■.b.a, ■.b.c] + +b = Ref(true) + +traversewithkey((_, _) -> b[] = false, Dtry{Int}()) + +@test b[] + +@test fold(0, identity, d -> sum(values(d)), t1) == 6 +@test fold(0, identity, d -> sum(values(d)), Dtry{Int}()) == 0 + +@test all(iseven, t1) == false +@test all(iseven, filter(iseven, t1)) + +@test merge(t1, node(:b => node(:z => leaf(4)))) == node(:a => leaf(1), :b => node(:a => leaf(2), :c => leaf(3), :z => leaf(4))) + +@test Dtry(■.a => 1, ■.b.a => 2, ■.b.c => 3) == t1 +@test_throws ErrorException Dtry(■.a => 1, ■.a.b => 2) + +t2 = node(:a => Dtry{Int}()) + +@test t2 == Dtry{Int}() + +end diff --git a/test/util/MetaUtils.jl b/test/util/MetaUtils.jl index 3b416102..fed8ed47 100644 --- a/test/util/MetaUtils.jl +++ b/test/util/MetaUtils.jl @@ -30,7 +30,7 @@ end # Function parsing @test_throws ParseError parse_fun(:(f(x,y))) @test (parse_fun(:(function f(x,y) x end)) == - JuliaFunction(:f, [:(x::Any), :(y::Any)], [], [], nothing, strip_all(quote x end))) + JuliaFunction(:f, [:(x::Any), :(y::Any)], [], [], nothing, strip_all(quote x end))) @test parse_fun(strip_all(quote """ My docstring @@ -40,11 +40,11 @@ end).args[1]) == JuliaFunction(:f, [:(x::Any), :(y::Any)], [], [], nothing, stri # TODO RHS has LNN between quote and x @test (parse_fun(:(function f(x::Int,y::Int)::Int x end)) == - JuliaFunction(:f, [:(x::Int),:(y::Int)], [], [], :Int, strip_all(quote x end))) + JuliaFunction(:f, [:(x::Int),:(y::Int)], [], [], :Int, strip_all(quote x end))) # TODO RHS has LNN between quote and x @test (parse_fun(:(f(x,y) = x)) == - JuliaFunction(:f, [:(x::Any), :(y::Any)], [], [], nothing, strip_all(quote x end))) + JuliaFunction(:f, [:(x::Any), :(y::Any)], [], [], nothing, strip_all(quote x end))) sig = JuliaFunctionSig(:f, [:Int,:Int]) @test parse_function_sig(parse_fun(:(function f(x::Int,y::Int)::Int end))) == sig diff --git a/test/util/tests.jl b/test/util/tests.jl index 57ec9415..ef12c46b 100644 --- a/test/util/tests.jl +++ b/test/util/tests.jl @@ -6,4 +6,8 @@ using Test include("MetaUtils.jl") end +@testset "Dtrys" begin + include("Dtrys.jl") +end + end