Skip to content

Commit

Permalink
Use musical isomorphism inference from DiagX main
Browse files Browse the repository at this point in the history
  • Loading branch information
lukem12345 authored and GeorgeR227 committed Nov 1, 2024
1 parent dc75301 commit 8cbca40
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 60 deletions.
12 changes: 6 additions & 6 deletions docs/src/bsh/budyko_sellers_halfar.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ We have defined the [Halfar ice model](../ice_dynamics/ice_dynamics.md) in other
``` @example DEC
halfar_eq2 = @decapode begin
h::Form0
Γ::Form1
Γ::Form0
n::Constant
ḣ == ∂ₜ(h)
ḣ == ∘(⋆, d, ⋆)(Γ * d(h) * avg₀₁(mag(♯(d(h)))^(n-1)) * avg₀₁(h^(n+2)))
ḣ == Γ * ∘(⋆, d, ⋆)(d(h) * avg₀₁(mag(♯(d(h)))^(n-1)) * avg₀₁(h^(n+2)))
end
glens_law = @decapode begin
Γ::Form1
A::Form1
Γ::Form0
A::Form0
(ρ,g,n)::Constant
Γ == (2/(n+2))*A*(ρ*g)^n
Expand Down Expand Up @@ -140,9 +140,9 @@ We need to specify physically what it means for these two terms to interact. We
``` @example DEC
warming = @decapode begin
Tₛ::Form0
A::Form1
A::Form0
A == avg₀₁(5.8282*10^(-0.236 * Tₛ)*1.65e7)
A == 5.8282*10^(-0.236 * Tₛ)*1.65e7
end
Expand Down
8 changes: 4 additions & 4 deletions docs/src/cism/cism.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ Halfar's equation looks a little disjoint. It seems that the front most terms ar
# translated into the exterior calculus.
halfar_eq2 = @decapode begin
h::Form0
Γ::Form1
Γ::Form0
n::Constant
ḣ == ∂ₜ(h)
ḣ == ∘(⋆, d, ⋆)(Γ * d(h) * avg₀₁(mag(♯(d(h)))^(n-1)) * avg₀₁(h^(n+2)))
ḣ == Γ * ∘(⋆, d, ⋆)(d(h) (mag(♯(d(h)))^(n-1)) (h^(n+2)))
end
to_graphviz(halfar_eq2)
Expand All @@ -72,7 +72,7 @@ Here, we recognize that Gamma is in fact what glaciologists call "Glen's Flow La
# assumptions made in glacier theory, their experimental foundations and
# consequences. (1958)
glens_law = @decapode begin
Γ::Form1
Γ::Form0
(A,ρ,g,n)::Constant
Γ == (2/(n+2))*A*(ρ*g)^n
Expand Down Expand Up @@ -154,7 +154,7 @@ g = 9.8101
alpha = 1/9
beta = 1/18
flwa = 1e-16
A = fill(1e-16, ne(sd))
A = 1e-16
Gamma = 2.0/(n+2) * flwa * (ρ * g)^n
t0 = (beta/Gamma) * (7.0/4.0)^3 * (R₀^4 / H^7)
Expand Down
12 changes: 4 additions & 8 deletions docs/src/grigoriev/grigoriev.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ The coordinates of a vertex are stored in `sd[:point]`. Let's use our interpolat
n = 3
ρ = 910
g = 9.8101
A = fill(1e-16, ne(sd))
A = 1e-16
h₀ = map(sd[:point]) do (x,y,_)
tif_val = ice_interp(x,y)
Expand All @@ -118,15 +118,15 @@ For exposition on this Halfar Decapode, see our [Glacial Flow](../ice_dynamics/i
``` @example DEC
halfar_eq2 = @decapode begin
h::Form0
Γ::Form1
Γ::Form0
n::Constant
ḣ == ∂ₜ(h)
ḣ == ∘(⋆, d, ⋆)(Γ * d(h) * avg₀₁(mag(♯(d(h)))^(n-1)) * avg₀₁(h^(n+2)))
ḣ == Γ * ∘(⋆, d, ⋆)(d(h) (mag(♯(d(h)))^(n-1)) h^(n+2))
end
glens_law = @decapode begin
Γ::Form1
Γ::Form0
(A,ρ,g,n)::Constant
Γ == (2/(n+2))*A*(ρ*g)^n
Expand All @@ -151,10 +151,6 @@ to_graphviz(ice_dynamics)
function generate(sd, my_symbol; hodge=GeometricHodge())
op = @match my_symbol begin
:mag => x -> norm.(x)
:♯ => begin
sharp_mat = ♯_mat(sd, AltPPSharp())
x -> sharp_mat * x
end
x => error("Unmatched operator $my_symbol")
end
return op
Expand Down
6 changes: 3 additions & 3 deletions docs/src/halmo/halmo.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ Halfar's equation and Glen's law are composed like so:
```@example DEC_halmo
halfar_eq2 = @decapode begin
h::Form0
Γ::Form1
Γ::Form0
n::Constant
∂ₜ(h) == ∘(⋆, d, ⋆)(Γ * d(h) * avg₀₁(mag(♯(d(h)))^(n-1)) * avg₀₁(h^(n+2)))
∂ₜ(h) == Γ * ∘(⋆, d, ⋆)(d(h) (mag(♯(d(h)))^(n-1)) (h^(n+2)))
end
glens_law = @decapode begin
Γ::Form1
Γ::Form0
(A,ρ,g,n)::Constant
Γ == (2/(n+2))*A*(ρ*g)^n
Expand Down
11 changes: 4 additions & 7 deletions docs/src/ice_dynamics/ice_dynamics.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,17 @@ We'll change the term out front to Γ so we can demonstrate composition in a mom
In the exterior calculus, we could write the above equations like so:

```math
\partial_t(h) = \circ(\star, d, \star)(\Gamma\quad d(h)\quad \text{avg}_{01}|d(h)^\sharp|^{n-1} \quad \text{avg}_{01}(h^{n+2})).
\partial_t(h) = \Gamma\quad \circ(\star, d, \star)(d(h)\quad \wedge \quad|d(h)^\sharp|^{n-1} \quad \wedge \quad (h^{n+2})).
```

`avg` here is an operator that performs the midpoint rule, setting the value at an edge to be the average of the values at its two vertices.

``` @example DEC
halfar_eq2 = @decapode begin
h::Form0
Γ::Form1
Γ::Form0
n::Constant
ḣ == ∂ₜ(h)
ḣ == ∘(⋆, d, ⋆)(Γ * d(h) * avg₀₁(mag(♯(d(h)))^(n-1)) * avg₀₁(h^(n+2)))
ḣ == Γ * ∘(⋆, d, ⋆)(d(h) (mag(♯(d(h)))^(n-1)) (h^(n+2)))
end
to_graphviz(halfar_eq2)
Expand All @@ -65,8 +63,7 @@ And here, a formulation of Glen's law from J.W. Glen's 1958 ["The flow law of ic

``` @example DEC
glens_law = @decapode begin
#Γ::Form0
Γ::Form1
Γ::Form0
(A,ρ,g,n)::Constant
Γ == (2/(n+2))*A*(ρ*g)^n
Expand Down
44 changes: 27 additions & 17 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,30 @@ using Krylov
using LinearAlgebra
using SparseArrays

# Define mappings for default DEC operations that are not optimizable.
# --------------------------------------------------------------------
function default_dec_generate(sd::HasDeltaSet, my_symbol::Symbol, hodge::DiscreteHodge=GeometricHodge())

op = @match my_symbol begin

:plus => (+)
:(-) || :neg => x -> -1 .* x
:ln => (x -> log.(x))
# Musical Isomorphisms
:♯ᵖᵖ => dec_♯_p(sd)
:♯ᵈᵈ => dec_♯_d(sd)
:♭ᵈᵖ => dec_♭(sd)

_ => error("Unmatched operator $my_symbol")
end

return (args...) -> op(args...)
end

function default_dec_cu_generate() end;

# Define mappings for default DEC operations that are optimizable.
# ----------------------------------------------------------------
function default_dec_cu_matrix_generate() end;

function default_dec_matrix_generate(sd::HasDeltaSet, my_symbol::Symbol, hodge::DiscreteHodge)
Expand Down Expand Up @@ -58,10 +82,10 @@ function default_dec_matrix_generate(sd::HasDeltaSet, my_symbol::Symbol, hodge::
:Δᵈ₁ => Δᵈ(Val{1},sd)

# Musical Isomorphisms
:♯ => dec_♯_p(sd)
:♯ => dec_♯_d(sd)
:♯ᵖᵖ => dec_♯_p(sd)
:♯ᵈᵈ => dec_♯_d(sd)

:♭ => dec_♭(sd)
:♭ᵈᵖ => dec_♭(sd)

# Averaging Operator
:avg₀₁ => dec_avg₀₁(sd)
Expand Down Expand Up @@ -146,20 +170,6 @@ function dec_avg₀₁(sd::HasDeltaSet)
(avg_mat, x -> avg_mat * x)
end

function default_dec_generate(sd::HasDeltaSet, my_symbol::Symbol, hodge::DiscreteHodge=GeometricHodge())

op = @match my_symbol begin

:plus => (+)
:(-) || :neg => x -> -1 .* x
:ln => (x -> log.(x))

_ => error("Unmatched operator $my_symbol")
end

return (args...) -> op(args...)
end

function open_operators(d::SummationDecapode; dimension::Int=2)
e = deepcopy(d)
open_operators!(e, dimension=dimension)
Expand Down
58 changes: 43 additions & 15 deletions src/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,12 @@ Base.showerror(io::IO, e::InvalidCodeTargetException) = print(io, "Provided code
This creates the symbol to function linking for the simulation output. Those run through the `default_dec` backend
expect both an in-place and an out-of-place variant in that order. User defined operations only support out-of-place.
"""
function compile_env(d::SummationDecapode, dec_matrices::Vector{Symbol}, con_dec_operators::Set{Symbol}, code_target::AbstractGenerationTarget)
function compile_env(d::SummationDecapode, dec_matrices::Vector{Symbol}, con_dec_operators::Set{Symbol}, nonoptimizable_operators::Set{Symbol}, code_target::AbstractGenerationTarget)
defined_ops = deepcopy(con_dec_operators)

defs = quote end

# These are optimizable default DEC functions.
for op in dec_matrices
op in defined_ops && continue

Expand All @@ -253,6 +254,25 @@ function compile_env(d::SummationDecapode, dec_matrices::Vector{Symbol}, con_dec
push!(defined_ops, op)
end

# These are nonoptimizable default DEC functions.
for op in nonoptimizable_operators
op in defined_ops && continue

quote_op = QuoteNode(op)

# TODO: Add support for user-defined code targets
default_generation = @match code_target begin
::CPUBackend => :default_dec_generate
::CUDABackend => :default_dec_cu_generate
_ => throw(InvalidCodeTargetException(code_target))
end

def = :($op = $(default_generation)(mesh, $quote_op, hodge))
push!(defs.args, def)

push!(defined_ops, op)
end

# Add in user-defined operations
for op in vcat(d[:op1], d[:op2])
if op == DerivOp || op in defined_ops || op in ARITHMETIC_OPS
Expand Down Expand Up @@ -371,15 +391,15 @@ const PROMOTE_ARITHMETIC_MAP = Dict(:(+) => :.+,
:.= => :.=)

"""
compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Vector{AllocVecCall}, optimizable_dec_operators::Set{Symbol}, dimension::Int, stateeltype::DataType, code_target::AbstractGenerationTarget, preallocate::Bool)
compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Vector{AllocVecCall}, matrix_optimizable_dec_operators::Set{Symbol}, dimension::Int, stateeltype::DataType, code_target::AbstractGenerationTarget, preallocate::Bool)
Function that compiles the computation body. `d` is the input Decapode, `inputs` is a vector of state variables and literals,
`alloc_vec` should be empty when passed in, `optimizable_dec_operators` is a collection of all DEC operator symbols that can use special
`alloc_vec` should be empty when passed in, `matrix_optimizable_dec_operators` is a collection of all DEC operator symbols that can use special
in-place methods, `dimension` is the dimension of the problem (usually 1 or 2), `stateeltype` is the type of the state elements
(usually Float32 or Float64), `code_target` determines what architecture the code is compiled for (either CPU or CUDA), and `preallocate`
which is set to `true` by default and determines if intermediate results can be preallocated..
"""
function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Vector{AllocVecCall}, optimizable_dec_operators::Set{Symbol}, dimension::Int, stateeltype::DataType, code_target::AbstractGenerationTarget, preallocate::Bool)
function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Vector{AllocVecCall}, matrix_optimizable_dec_operators::Set{Symbol}, dimension::Int, stateeltype::DataType, code_target::AbstractGenerationTarget, preallocate::Bool)
# Get the Vars of the inputs (probably state Vars).
visited_Var = falses(nparts(d, :Var))

Expand Down Expand Up @@ -416,7 +436,7 @@ function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Ve

# TODO: Check to see if this is a DEC operator
if preallocate && is_form(d, t)
if operator in optimizable_dec_operators
if operator in matrix_optimizable_dec_operators
equality = PROMOTE_ARITHMETIC_MAP[equality]
operator = add_stub(GENSIM_INPLACE_STUB, operator)
push!(alloc_vectors, AllocVecCall(tname, d[t, :type], dimension, stateeltype, code_target))
Expand Down Expand Up @@ -462,7 +482,7 @@ function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Ve
equality = PROMOTE_ARITHMETIC_MAP[equality]
push!(alloc_vectors, AllocVecCall(rname, d[r, :type], dimension, stateeltype, code_target))
end
elseif operator in optimizable_dec_operators
elseif operator in matrix_optimizable_dec_operators
operator = add_stub(GENSIM_INPLACE_STUB, operator)
equality = PROMOTE_ARITHMETIC_MAP[equality]
push!(alloc_vectors, AllocVecCall(rname, d[r, :type], dimension, stateeltype, code_target))
Expand Down Expand Up @@ -597,13 +617,13 @@ function infer_overload_compiler!(d::SummationDecapode, dimension::Int)
end

"""
init_dec_matrices!(d::SummationDecapode, dec_matrices::Vector{Symbol}, optimizable_dec_operators::Set{Symbol})
init_dec_matrices!(d::SummationDecapode, dec_matrices::Vector{Symbol}, matrix_optimizable_dec_operators::Set{Symbol})
Collects all DEC operators that are concrete matrices.
"""
function init_dec_matrices!(d::SummationDecapode, dec_matrices::Vector{Symbol}, optimizable_dec_operators::Set{Symbol})
function init_dec_matrices!(d::SummationDecapode, dec_matrices::Vector{Symbol}, matrix_optimizable_dec_operators::Set{Symbol})
for op_name in vcat(d[:op1], d[:op2])
if op_name in optimizable_dec_operators
if op_name in matrix_optimizable_dec_operators
push!(dec_matrices, op_name)
end
end
Expand Down Expand Up @@ -733,25 +753,33 @@ function gensim(user_d::SummationDecapode, input_vars::Vector{Symbol}; dimension
infer_overload_compiler!(gen_d, dimension)

# This will generate all of the fundemental DEC operators present
optimizable_dec_operators = Set([:₀, :₁, :₂, :₀⁻¹, :₂⁻¹,
matrix_optimizable_dec_operators = Set([:₀, :₁, :₂, :₀⁻¹, :₂⁻¹,
:d₀, :d₁, :dual_d₀, :d̃₀, :dual_d₁, :d̃₁,
:avg₀₁])
extra_dec_operators = Set([:₁⁻¹, :₀₁, :₁₀, :₁₁, :₀₂, :₂₀])
nonmatrix_optimizable_dec_operators = Set([:₁⁻¹, :₀₁, :₁₀, :₁₁, :₀₂, :₂₀])

nonoptimizable_cpu_operators = Set([:♯ᵖᵖ, :♯ᵈᵈ, :♭ᵈᵖ])
nonoptimizable_cuda_operators = Set{Symbol}()
nonoptimizable_operators = @match code_target begin
::CPUBackend => nonoptimizable_cpu_operators
::CUDABackend => nonoptimizable_cuda_operators
_ => throw(InvalidCodeTargetException(code_target))
end

init_dec_matrices!(gen_d, dec_matrices, union(optimizable_dec_operators, extra_dec_operators))
init_dec_matrices!(gen_d, dec_matrices, union(matrix_optimizable_dec_operators, nonmatrix_optimizable_dec_operators))

# This contracts matrices together into a single matrix
contracted_dec_operators = Set{Symbol}()
contract_operators!(gen_d, white_list = optimizable_dec_operators)
contract_operators!(gen_d, white_list = matrix_optimizable_dec_operators)
cont_defs = link_contract_operators(gen_d, contracted_dec_operators, stateeltype, code_target)

union!(optimizable_dec_operators, contracted_dec_operators, extra_dec_operators)
optimizable_dec_operators = union(matrix_optimizable_dec_operators, contracted_dec_operators, nonmatrix_optimizable_dec_operators)

# Compilation of the simulation
equations = compile(gen_d, input_vars, alloc_vectors, optimizable_dec_operators, dimension, stateeltype, code_target, preallocate)
data = post_process_vector_allocs(alloc_vectors, code_target)

func_defs = compile_env(gen_d, dec_matrices, contracted_dec_operators, code_target)
func_defs = compile_env(gen_d, dec_matrices, contracted_dec_operators, nonoptimizable_operators, code_target)
vect_defs = compile_var(alloc_vectors)

quote
Expand Down

0 comments on commit 8cbca40

Please sign in to comment.