diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index ef5f2d6..5d9b4dc 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -32,7 +32,7 @@ to_graphviz, # Re-exported from Catlab ## rewrite average_rewrite, ## openoperators -transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s! +transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s!, AbstractSDRewriteRule, Op1SDRule, Op2SDRule, apply_rule!, rewrite! using Catlab.Theories import Catlab.Theories: otimes, oplus, compose, ⊗, ⊕, ⋅, associate, associate_unit, Ob, Hom, dom, codom diff --git a/src/deca/Deca.jl b/src/deca/Deca.jl index 8202704..731174b 100644 --- a/src/deca/Deca.jl +++ b/src/deca/Deca.jl @@ -4,9 +4,9 @@ using DataStructures using ..DiagrammaticEquations using Catlab -import ..infer_types!, ..resolve_overloads! +import ..infer_types!, ..resolve_overloads!, ..rewrite! -export normalize_unicode, varname, infer_types!, resolve_overloads!, typename, spacename, recursive_delete_parents, recursive_delete_parents!, unicode!, op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, vec_to_dec! +export normalize_unicode, varname, infer_types!, resolve_overloads!, rewrite!, typename, spacename, recursive_delete_parents, recursive_delete_parents!, unicode!, op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, vec_to_dec! include("deca_acset.jl") include("deca_visualization.jl") diff --git a/src/deca/deca_acset.jl b/src/deca/deca_acset.jl index 4334c4b..e4d112f 100644 --- a/src/deca/deca_acset.jl +++ b/src/deca/deca_acset.jl @@ -359,3 +359,71 @@ Resolve function overloads based on types of src and tgt. resolve_overloads!(d::SummationDecapode) = resolve_overloads!(d, op1_res_rules_2D, op2_res_rules_2D) +# Default Rewrite Rules +# --------------------- + +rewrite_rules_2D = Vector{AbstractSDRewriteRule}([ + Op1SDRule( + :Δ₀, + @decapode begin + y == ∘(d,δ)(X) + end), + + Op1SDRule( + :Δ₁, + @decapode begin + (X,y)::Form1 + y == ∘(d,δ)(X) + ∘(δ,d)(X) + end), + + Op1SDRule( + :Δ₂, + @decapode begin + y == ∘(δ,d)(X) + end), + + Op1SDRule( + :δ, + @decapode begin + y == ∘(⋆,d,⋆)(X) + end), + + Op1SDRule( + :δ₁, + @decapode begin + y == ∘(⋆,d,⋆)(X) + end), + + Op1SDRule( + :δ₂, + @decapode begin + y == ∘(⋆,d,⋆)(X) + end), + + Op2SDRule( + :ι₁, + @decapode begin + y == -1*⋆((⋆p1) ∧ p2) + end), + + Op2SDRule( + :L₀, + @decapode begin + y == ι(p1, d(p2)) + end), + + Op2SDRule( + :L₁, + @decapode begin + y == ι(p1, d(p2)) + d(ι(p1, p2)) + end), + + Op2SDRule( + :L₂, + @decapode begin + y == d(ι(p1, p2)) + end)]) + +rewrite!(d::SummationDecapode) = + rewrite!(d, rewrite_rules_2D) + diff --git a/src/openoperators.jl b/src/openoperators.jl index b929463..839ecf7 100644 --- a/src/openoperators.jl +++ b/src/openoperators.jl @@ -1,3 +1,26 @@ +abstract type AbstractSDRewriteRule end + +struct Op1SDRule <: AbstractSDRewriteRule + LHS::Union{Symbol, SummationDecapode} + RHS::Union{Symbol, SummationDecapode} +end + +struct Op2SDRule <: AbstractSDRewriteRule + LHS::Union{Symbol, SummationDecapode} + RHS::Union{Symbol, SummationDecapode} + proj1::Int + proj2::Int +end + +function Op2SDRule(LHS::Union{Symbol, SummationDecapode}, RHS::Union{Symbol, SummationDecapode}) + p1s = incident(RHS, :p1, :name) + p2s = incident(RHS, :p2, :name) + if length(p1s) != 1 || length(p2s) != 1 + error("proj1 and proj2 to use were not given, but unique distinguished variables p1 and p2 were not found. Found p1: $(p1s) and p2: $(p2s).") + end + Op2SDRule(LHS, RHS, only(p1s), only(p2s)) +end + # Opening up Op1s # -------------- @@ -8,6 +31,9 @@ function validate_op1_match(d::SummationDecapode, LHS::SummationDecapode) end end +validate_op1_match(d::SummationDecapode, r::Op1SDRule) = + validate_op1_match(d, r.LHS) + # Validate whether RHS represents a valid replacement for an op1. function validate_op1_replacement(d::SummationDecapode, LHS::Symbol, RHS::SummationDecapode) if length(infer_states(RHS)) != 1 || length(infer_terminals(RHS)) != 1 @@ -15,6 +41,9 @@ function validate_op1_replacement(d::SummationDecapode, LHS::Symbol, RHS::Summat end end +validate_op1_replacement(d::SummationDecapode, r::Op1SDRule) = + validate_op1_replacement(d, r.LHS, r.RHS) + """ function replace_op1!(d::SummationDecapode, LHS::Symbol, RHS::SummationDecapode) Given a Decapode, d, replace at most one instance of the left-hand-side unary operator with those of the right-hand-side. @@ -81,6 +110,19 @@ function replace_op1!(d::SummationDecapode, LHS::Symbol, RHS::Symbol) LHS_op1 end +""" function replace_op1!(d::SummationDecapode, r::Op1SDRule) + +Given a Decapode, d, replace at most one instance of the left-hand-side unary operator with those of the right-hand-side. + +Return the index of the replaced unary operator, 0 if no match was found. +See also: [`replace_op2!`](@ref), [`replace_all_op1s!`](@ref) +""" +replace_op1!(d::SummationDecapode, r::Op1SDRule) = + replace_op1!(d, r.LHS, r.RHS) + +apply_rule!(d::SummationDecapode, r::Op1SDRule) = + replace_op1!(d, r) + """ function replace_all_op1s!(d::SummationDecapode, LHS::Union{Symbol, SummationDecapode}, RHS::Union{Symbol, SummationDecapode}) Given a Decapode, d, replace all instances of the left-hand-side unary operator with those of the right-hand-side. @@ -97,6 +139,17 @@ function replace_all_op1s!(d::SummationDecapode, LHS::Union{Symbol, SummationDec any_replaced end +""" function replace_all_op1s!(d::SummationDecapode, r::Op1SDRule) + +Given a Decapode, d, replace all instances of the left-hand-side unary operator with those of the right-hand-side. + +Return true if any replacements were made, otherwise false. + +See also: [`replace_op1!`](@ref), [`replace_all_op2s!`](@ref) +""" +replace_all_op1s!(d::SummationDecapode, r::Op1SDRule) = + replace_all_op1s!(d, r.LHS, r.RHS) + # Opening up Op2s # -------------- @@ -107,6 +160,9 @@ function validate_op2_match(d::SummationDecapode, LHS::SummationDecapode) end end +validate_op2_match(d::SummationDecapode, r::Op2SDRule) = + validate_op2_match(d, r.LHS) + # Validate whether RHS represents a valid replacement for an op2. function validate_op2_replacement(d::SummationDecapode, LHS::Symbol, RHS::SummationDecapode, proj1::Int, proj2::Int) if length(infer_states(RHS)) != 2 || length(infer_terminals(RHS)) != 1 @@ -117,6 +173,9 @@ function validate_op2_replacement(d::SummationDecapode, LHS::Symbol, RHS::Summat end end +validate_op2_replacement(d::SummationDecapode, r::Op2SDRule) = + validate_op2_replacement(d, r.LHS, r.RHS, r.proj1, r.proj2) + """ function replace_op2!(d::SummationDecapode, LHS::Symbol, RHS::SummationDecapode, proj1::Int, proj2::Int) Given a Decapode, d, replace at most one instance of the left-hand-side binary operator with those of the right-hand-side. @@ -202,6 +261,12 @@ end replace_op2!(d::SummationDecapode, LHS::Symbol, RHS::Symbol, proj1, proj2) = replace_op2!(d, LHS, RHS) +replace_op2!(d::SummationDecapode, r::Op2SDRule) = + replace_op2!(d, r.LHS, r.RHS, r.proj1, r.proj2) + +apply_rule!(d::SummationDecapode, r::Op2SDRule) = + replace_op2!(d, r) + """ function replace_all_op2s!(d::SummationDecapode, LHS::Union{Symbol, SummationDecapode}, RHS::Union{Symbol, SummationDecapode}, proj1::Int, proj2::Int) Given a Decapode, d, replace all instances of the left-hand-side binary operator with those of the right-hand-side. @@ -245,3 +310,9 @@ function replace_all_op2s!(d::SummationDecapode, LHS::Union{Symbol, SummationDec any_replaced end +replace_all_op2s!(d::SummationDecapode, r::Op2SDRule) = + replace_all_op2s!(d, r.LHS, r.RHS, r.proj1, r.proj2) + +rewrite!(d::SummationDecapode, rules::AbstractVector{AbstractSDRewriteRule}) = + foreach(r -> apply_rule!(d,r), rules) + diff --git a/test/core.jl b/test/core.jl deleted file mode 100644 index 7f5983d..0000000 --- a/test/core.jl +++ /dev/null @@ -1 +0,0 @@ -@test "Hello, World!" == "Hello, World!" diff --git a/test/openoperators.jl b/test/openoperators.jl index 1c9245c..c51a337 100644 --- a/test/openoperators.jl +++ b/test/openoperators.jl @@ -1,5 +1,6 @@ using ACSets using DiagrammaticEquations +using DiagrammaticEquations.Deca using Test @testset "Open Operators" begin @@ -239,3 +240,98 @@ using Test only(incident(RHS, :p1, :name)), only(incident(RHS, :p2, :name))) end +@testset "Rewrite Rules" begin + # Explicit Rule Application + # ------------------------- + + # Test expanding the Heat equation. + rule = Op1SDRule( + @decapode begin + (X,y,Z)::Form0 + y == Δ(X) + end + , + @decapode begin + (X,y)::Form0 + y == -1*∘(d,⋆,d,⋆)(X) + end) + Heat = @decapode begin + C::Form0 + ∂ₜ(C) == Δ(C) + end + apply_rule!(Heat, rule) + @test Heat == @acset SummationDecapode{Any,Any,Symbol} begin + Var=4 + type=[:infer, :Literal, :Form0, :infer] + name=[Symbol("•1"), Symbol("-1"), :C, :Ċ] + TVar=1 + incl=[4] + Op1=2 + src=[3,3] + tgt=[4,1] + op1=[:∂ₜ, [:d, :⋆, :d, :⋆]] + Op2=1 + proj1=[2] + proj2=[1] + res=[4] + op2=[:*] + end + + # Test expanding the vector laplacian. + rule = Op1SDRule( + @decapode begin + (X,y)::Form1 + y == Δ(X) + end + , + @decapode begin + (X,y)::Form1 + y == ∘(d,⋆,d,⋆)(X) + ∘(⋆,d,⋆,d)(X) + end) + VectorHeat = @decapode begin + V::Form1 + ∂ₜ(V) == -1*Δ(V) + end + apply_rule!(VectorHeat, rule) + @test VectorHeat == @acset SummationDecapode{Any,Any,Symbol} begin + Var = 6 + TVar = 1 + Op1 = 3 + Op2 = 1 + Σ = 1 + Summand = 2 + src = [5, 5, 5] + tgt = [2, 3, 1] + proj1 = [4] + proj2 = [6] + res = [2] + incl = [2] + summand = [1, 3] + summation = [1, 1] + sum = [6] + op1 = Any[:∂ₜ, [:⋆, :d, :⋆, :d], [:d, :⋆, :d, :⋆]] + op2 = [:*] + type = [:infer, :infer, :infer, :Literal, :Form1, :infer] + name = [Symbol("•1"), :V̇, Symbol("•2"), Symbol("-1"), :V, Symbol("•2")] + end + + # Default Rules Application + # ------------------------- + + # Test expanding the Brusselator. + Brusselator = @decapode begin + (U, V)::Form0 + U2V::Form0 + (U̇, V̇)::Form0 + (α)::Constant + F::Parameter + U2V == (U .* U) .* V + U̇ == 1 + U2V - (4.4 * U) + (α * Δ(U)) + F + V̇ == (3.4 * U) - U2V + (α * Δ(V)) + ∂ₜ(U) == U̇ + ∂ₜ(V) == V̇ + end + rewrite!(Brusselator) + @test Brusselator[:op1] == Any[[:d,:⋆,:d,:⋆], [:d,:⋆,:d,:⋆], :∂ₜ, :∂ₜ] +end +