Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SciMLStructure Interface #233

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
Expand All @@ -36,10 +37,10 @@ DecapodesCUDAExt = "CUDA"
[compat]
ACSets = "0.2"
Artifacts = "1"
CUDA = "5.2"
Catlab = "0.15, 0.16"
CombinatorialSpaces = "0.6.3"
ComponentArrays = "0.15"
CUDA = "5.2"
DataStructures = "0.18.13"
DiagrammaticEquations = "0.1"
Distributions = "0.25"
Expand Down
11 changes: 6 additions & 5 deletions src/Decapodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,18 @@ using PreallocationTools
using DiagrammaticEquations
using DiagrammaticEquations.Deca

export
findname, flat_op,
gensim, evalsim, closest_point, findnode, compile, compile_env, PhysicsState, default_dec_matrix_generate, default_dec_cu_matrix_generate, default_dec_generate, VectorForm,
CartesianPoint, SpherePoint, r, theta, phi, TangentBasis, θhat, ϕhat,
CPUTarget, CUDATarget
export findname, flat_op,
gensim, evalsim, closest_point, findnode, compile, compile_env, PhysicsState, default_dec_matrix_generate, default_dec_cu_matrix_generate, default_dec_generate, VectorForm,
CartesianPoint, SpherePoint, r, theta, phi, TangentBasis, θhat, ϕhat,
CPUTarget, CUDATarget,
TunableWrapper, ConstantsWrapper, CachesWrapper, DiscreteWrapper, canonicalize

append_dot(s::Symbol) = Symbol(string(s)*'\U0307')

include("coordinates.jl")
include("operators.jl")
include("simulation.jl")
include("scimlstructure.jl")

# documentation
include("canon/Canon.jl")
Expand Down
118 changes: 118 additions & 0 deletions src/scimlstructure.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
module StructuredNamedTuples

using SciMLStructures: Tunable, Constants, Caches, Discrete
import SciMLStructures: canonicalize, isscimlstructure

struct TunableWrapper
vals
end
struct ConstantsWrapper
vals
end
struct CachesWrapper
vals
end
struct DiscreteWrapper
vals
end

isscimlstructure(::NamedTuple) = true

function canonicalize(::Tunable, p::NamedTuple)
values = map(x -> x.vals, filter(Tuple(p)) do v
typeof(v) == TunableWrapper
end)
values = map(x -> x.vals, Tuple(keys_values))
lens = length.(values)
# TODO: Use metaprogramming instead of hardcoding AbstractVector.
function repack(new_values::AbstractVector)
lens_idx = 1
buf_idx = 1
map(keys(p), p) do k,v
if typeof(v) == TunableWrapper
new_value = new_values[buf_idx:buf_idx+lens[lens_idx]-1]
buf_idx += lens[lens_idx]
lens_idx += 1
k,TunableWrapper(new_value)
else
k,v
end
end |> NamedTuple
end
return reduce(vcat, values), repack, true
end

# TODO: Use metaprogramming instead of copying-and-pasting
function canonicalize(::Constants, p::NamedTuple)
keys_values = filter(keys(p), p) do k,v
typeof(v) == ConstantsWrapper
end
values = map(x -> x.vals, Tuple(keys_values))
lens = length.(values)
# TODO: Use metaprogramming instead of hardcoding AbstractVector.
function repack(new_values::AbstractVector)
lens_idx = 1
buf_idx = 1
map(keys(p), p) do k,v
if typeof(v) == ConstantsWrapper
new_value = new_values[buf_idx:buf_idx+lens[lens_idx]-1]
buf_idx += lens[lens_idx]
lens_idx += 1
k,ConstantsWrapper(new_value)
else
k,v
end
end |> NamedTuple
end
return reduce(vcat, values), repack, true
end

function canonicalize(::Caches, p::NamedTuple)
keys_values = filter(keys(p), p) do k,v
typeof(v) == CachesWrapper
end
values = map(x -> x.vals, Tuple(keys_values))
lens = length.(values)
# TODO: Use metaprogramming instead of hardcoding AbstractVector.
function repack(new_values::AbstractVector)
lens_idx = 1
buf_idx = 1
map(keys(p), p) do k,v
if typeof(v) == CachesWrapper
new_value = new_values[buf_idx:buf_idx+lens[lens_idx]-1]
buf_idx += lens[lens_idx]
lens_idx += 1
k,CachesWrapper(new_value)
else
k,v
end
end |> NamedTuple
end
return reduce(vcat, values), repack, true
end

function canonicalize(::Discrete, p::NamedTuple)
keys_values = filter(keys(p), p) do k,v
typeof(v) == DiscreteWrapper
end
values = map(x -> x.vals, Tuple(keys_values))
lens = length.(values)
# TODO: Use metaprogramming instead of hardcoding AbstractVector.
function repack(new_values::AbstractVector)
lens_idx = 1
buf_idx = 1
map(keys(p), p) do k,v
if typeof(v) == DiscreteWrapper
new_value = new_values[buf_idx:buf_idx+lens[lens_idx]-1]
buf_idx += lens[lens_idx]
lens_idx += 1
k,DiscreteWrapper(new_value)
else
k,v
end
end |> NamedTuple
end
return reduce(vcat, values), repack, true
end

end
Loading