diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index 9529578..0d0de25 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -14,7 +14,7 @@ Collage, collate, oapply, unique_by, unique_by!, OpenSummationDecapodeOb, OpenSummationDecapode, Open, default_composition_diagram, ## acset SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, SummationDecapode, -contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, expand_operators, infer_state_names, infer_terminal_names, recognize_types, +contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, is_expanded, expand_operators, infer_state_names, infer_terminal_names, recognize_types, resolve_overloads!, replace_names!, apply_inference_rule_op1!, apply_inference_rule_op2!, transfer_parents!, transfer_children!, @@ -26,7 +26,9 @@ Plus, AppCirc1, Var, Tan, App1, App2, ## visualization to_graphviz_property_graph, typename, draw_composition, ## rewrite -average_rewrite +average_rewrite, +## openoperators +transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s! using Catlab using Catlab.Theories @@ -56,6 +58,7 @@ include("visualization.jl") include("rewrite.jl") include("pretty.jl") include("colanguage.jl") +include("openoperators.jl") include("deca/Deca.jl") include("learn/Learn.jl") diff --git a/src/acset.jl b/src/acset.jl index 19eb515..2dc7860 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -185,6 +185,16 @@ function recognize_types(d::AbstractNamedDecapode) error("Types $unrecognized_types are not recognized. CHECK: $types") end +""" is_expanded(d::AbstractNamedDecapode) + +Check that no unary operator is a composition of unary operators. +""" +is_expanded(d::AbstractNamedDecapode) = !any(x -> x isa AbstractVector, d[:op1]) + +""" function expand_operators(d::AbstractNamedDecapode) + +If any unary operator is a composition, expand it out using intermediate variables. +""" function expand_operators(d::AbstractNamedDecapode) #e = SummationDecapode{Symbol, Symbol, Symbol}() e = SummationDecapode{Any, Any, Symbol}() diff --git a/src/openoperators.jl b/src/openoperators.jl new file mode 100644 index 0000000..b929463 --- /dev/null +++ b/src/openoperators.jl @@ -0,0 +1,247 @@ +# Opening up Op1s +# -------------- + +# Validate whether LHS represents a valid op1. +function validate_op1_match(d::SummationDecapode, LHS::SummationDecapode) + if nparts(LHS, :Op1) != 1 + error("Only single operator replacement is supported for now, but found Op1s: $(LHS[:op1])") + end +end + +# Validate whether RHS represents a valid replacement for an op1. +function validate_op1_replacement(d::SummationDecapode, LHS::Symbol, RHS::SummationDecapode) + if length(infer_states(RHS)) != 1 || length(infer_terminals(RHS)) != 1 + error("The replacement for $(LHS) must have a single input and a single output, but found inputs: $(RHS[infer_states(RHS), :name]) and outputs $(RHS[infer_terminals(RHS), :name])") + end +end + +""" function replace_op1!(d::SummationDecapode, LHS::Symbol, RHS::SummationDecapode) + +Given a Decapode, d, replace at most one instance of the left-hand-side unary operator with those of the right-hand-side. + +Return the index of the replaced operator, 0 if no match was found. + +See also: [`replace_all_op1s!`](@ref) +""" +function replace_op1!(d::SummationDecapode, LHS::Symbol, RHS::SummationDecapode) + validate_op1_replacement(d, LHS, RHS) + isempty(incident(d, LHS, :op1)) && return 0 + + # Identify the "matched" operation. + LHS_op1 = first(incident(d, LHS, :op1)) + LHS_input = d[LHS_op1, :src] + LHS_output = d[LHS_op1, :tgt] + + # Add in the "replace" operation(s). + added_vars = copy_parts!(d, RHS).Var + RHS_input = only(intersect(infer_states(d), added_vars)) + RHS_output = only(intersect(infer_terminals(d), added_vars)) + + # Transfer LHS_input's pointers to RHS_input. + transfer_parents!(d, LHS_input, RHS_input) + transfer_children!(d, LHS_input, RHS_input) + d[RHS_input, :name] = d[LHS_input, :name] + d[RHS_input, :type] = d[LHS_input, :type] + + # Transfer LHS_output's pointers to RHS_output. + transfer_parents!(d, LHS_output, RHS_output) + transfer_children!(d, LHS_output, RHS_output) + d[RHS_output, :name] = d[LHS_output, :name] + d[RHS_output, :type] = d[LHS_output, :type] + + # Remove the replaced match and variables. + rem_parts!(d, :Var, sort!([LHS_input, LHS_output])) + rem_part!(d, :Op1, LHS_op1) + LHS_op1 +end + +""" function replace_op1!(d::SummationDecapode, LHS::SummationDecapode, RHS::SummationDecapode) + +Given a Decapode, d, replace at most one instance of the left-hand-side unary operator with those of the right-hand-side. + +Return the index of the replaced unary operator, 0 if no match was found. +See also: [`replace_op2!`](@ref), [`replace_all_op1s!`](@ref) +""" +function replace_op1!(d::SummationDecapode, LHS::SummationDecapode, RHS::SummationDecapode) + validate_op1_match(d, LHS) + replace_op1!(d, only(LHS[:op1]), RHS) +end + +""" function replace_op1!(d::SummationDecapode, LHS::Symbol, RHS::Symbol) + +Given a Decapode, d, replace at most one instance of the left-hand-side unary operator with that of the right-hand-side. + +Return the index of the replaced unary operator, 0 if no match was found. +See also: [`replace_op2!`](@ref), [`replace_all_op1s!`](@ref) +""" +function replace_op1!(d::SummationDecapode, LHS::Symbol, RHS::Symbol) + isempty(incident(d, LHS, :op1)) && return 0 + LHS_op1 = first(incident(d, LHS, :op1)) + d[LHS_op1, :op1] = RHS + LHS_op1 +end + +""" function replace_all_op1s!(d::SummationDecapode, LHS::Union{Symbol, SummationDecapode}, RHS::Union{Symbol, SummationDecapode}) + +Given a Decapode, d, replace all instances of the left-hand-side unary operator with those of the right-hand-side. + +Return true if any replacements were made, otherwise false. + +See also: [`replace_op1!`](@ref), [`replace_all_op2s!`](@ref) +""" +function replace_all_op1s!(d::SummationDecapode, LHS::Union{Symbol, SummationDecapode}, RHS::Union{Symbol, SummationDecapode}) + any_replaced = false + while replace_op1!(d,LHS,RHS) != 0 + any_replaced = true + end + any_replaced +end + +# Opening up Op2s +# -------------- + +# Validate whether LHS represents a valid op2. +function validate_op2_match(d::SummationDecapode, LHS::SummationDecapode) + if nparts(LHS, :Op2) != 1 + error("Only single operator replacement is supported for now, but found Op2s: $(LHS[:op2])") + end +end + +# Validate whether RHS represents a valid replacement for an op2. +function validate_op2_replacement(d::SummationDecapode, LHS::Symbol, RHS::SummationDecapode, proj1::Int, proj2::Int) + if length(infer_states(RHS)) != 2 || length(infer_terminals(RHS)) != 1 + error("The replacement for $(LHS) must have two inputs and a single output, but found inputs: $(RHS[infer_states(RHS), :name]) and outputs $(RHS[infer_terminals(RHS), :name])") + end + if !issetequal(infer_states(RHS), [proj1, proj2]) + error("The projections of the RHS of this replacement are not state variables. The projections are $(RHS[[proj1,proj2], :op2]) but the state variables are $(RHS[infer_states(RHS), :op2]).") + end +end + +""" function replace_op2!(d::SummationDecapode, LHS::Symbol, RHS::SummationDecapode, proj1::Int, proj2::Int) + +Given a Decapode, d, replace at most one instance of the left-hand-side binary operator with those of the right-hand-side. + +proj1 and proj2 are the indices of the intended proj1 and proj2 in RHS. + +Return the index of the replaced operator, 0 if no match was found. + +See also: [`replace_op1!`](@ref), [`replace_all_op2s!`](@ref) +""" +function replace_op2!(d::SummationDecapode, LHS::Symbol, RHS::SummationDecapode, proj1::Int, proj2::Int) + validate_op2_replacement(d, LHS, RHS, proj1, proj2) + isempty(incident(d, LHS, :op2)) && return 0 + + # Identify the "matched" operation. + LHS_op2 = first(incident(d, LHS, :op2)) + LHS_proj1, LHS_proj2 = d[LHS_op2, :proj1], d[LHS_op2, :proj2] + LHS_output = d[LHS_op2, :res] + + # Add in the "replace" operation(s). + added_vars = copy_parts!(d, RHS).Var + RHS_proj1, RHS_proj2 = intersect(infer_states(d), added_vars) + + # Preserve the order of proj1 and proj2. + if d[RHS_proj1, :name] != RHS[proj1, :name] + RHS_proj1, RHS_proj2 = RHS_proj2, RHS_proj1 + end + RHS_output = only(intersect(infer_terminals(d), added_vars)) + + # Transfer LHS_proj1's pointers to RHS_proj1. + transfer_parents!(d, LHS_proj1, RHS_proj1) + transfer_children!(d, LHS_proj1, RHS_proj1) + d[RHS_proj1, :name] = d[LHS_proj1, :name] + d[RHS_proj1, :type] = d[LHS_proj1, :type] + + # Transfer LHS_proj2's pointers to RHS_proj2. + transfer_parents!(d, LHS_proj2, RHS_proj2) + transfer_children!(d, LHS_proj2, RHS_proj2) + d[RHS_proj2, :name] = d[LHS_proj2, :name] + d[RHS_proj2, :type] = d[LHS_proj2, :type] + + # Transfer LHS_output's pointers to RHS_output. + transfer_parents!(d, LHS_output, RHS_output) + transfer_children!(d, LHS_output, RHS_output) + d[RHS_output, :name] = d[LHS_output, :name] + d[RHS_output, :type] = d[LHS_output, :type] + + # Remove the replaced match and variables. + rem_parts!(d, :Var, sort!([LHS_proj1, LHS_proj2, LHS_output])) + rem_part!(d, :Op2, LHS_op2) + LHS_op2 +end + +""" function replace_op2!(d::SummationDecapode, LHS::SummationDecapode, RHS::SummationDecapode, proj1::Int, proj2::Int) + +Given a Decapode, d, replace at most one instance of the left-hand-side binary operator with those of the right-hand-side. + +proj1 and proj2 are the indices of the intended proj1 and proj2 in RHS. + +Return the index of the replaced binary operator, 0 if no match was found. +See also: [`replace_op1!`](@ref), [`replace_all_op2s!`](@ref) +""" +function replace_op2!(d::SummationDecapode, LHS::SummationDecapode, RHS::SummationDecapode, proj1::Int, proj2::Int) + validate_op2_match(d, LHS) + replace_op2!(d, only(LHS[:op2]), RHS, proj1, proj2) +end + +""" function replace_op2!(d::SummationDecapode, LHS::Symbol, RHS::Symbol) + +Given a Decapode, d, replace at most one instance of the left-hand-side binary operator with that of the right-hand-side. + +Return the index of the replaced binary operator, 0 if no match was found. +See also: [`replace_op1!`](@ref), [`replace_all_op2s!`](@ref) +""" +function replace_op2!(d::SummationDecapode, LHS::Symbol, RHS::Symbol) + isempty(incident(d, LHS, :op2)) && return 0 + LHS_op2 = first(incident(d, LHS, :op2)) + d[LHS_op2, :op2] = RHS + LHS_op2 +end + +# Ignoring proj1 and proj2 keeps replace_all_op2s! simple. +replace_op2!(d::SummationDecapode, LHS::Symbol, RHS::Symbol, proj1, proj2) = + replace_op2!(d, LHS, RHS) + +""" function replace_all_op2s!(d::SummationDecapode, LHS::Union{Symbol, SummationDecapode}, RHS::Union{Symbol, SummationDecapode}, proj1::Int, proj2::Int) + +Given a Decapode, d, replace all instances of the left-hand-side binary operator with those of the right-hand-side. + +proj1 and proj2 are the indices of the intended proj1 and proj2 in RHS. + +Return true if any replacements were made, otherwise false. + +See also: [`replace_op2!`](@ref), [`replace_all_op1s!`](@ref) +""" +function replace_all_op2s!(d::SummationDecapode, LHS::Union{Symbol, SummationDecapode}, RHS::Union{Symbol, SummationDecapode}, proj1::Int, proj2::Int) + any_replaced = false + while replace_op2!(d,LHS,RHS, proj1, proj2) != 0 + any_replaced = true + end + any_replaced +end + +""" function replace_all_op2s!(d::SummationDecapode, LHS::Union{Symbol, SummationDecapode}, RHS::Union{Symbol, SummationDecapode}) + +Given a Decapode, d, replace all instances of the left-hand-side binary operator with those of the right-hand-side. + +Search for distinguished variables "p1" and "p2" to serve as the proj1 and proj2 from RHS. + +Return true if any replacements were made, otherwise false. + +See also: [`replace_op2!`](@ref), [`replace_all_op1s!`](@ref) +""" +function replace_all_op2s!(d::SummationDecapode, LHS::Union{Symbol, SummationDecapode}, RHS::Union{Symbol, SummationDecapode}) + p1s = incident(RHS, :p1, :name) + p2s = incident(RHS, :p2, :name) + if length(p1s) != 1 || length(p2s) != 1 + error("proj1 and proj2 to use were not given, but unique distinguished variables p1 and p2 were not found. Found p1: $(p1s) and p2: $(p2s).") + end + proj1 = only(p1s) + proj2 = only(p2s) + any_replaced = false + while replace_op2!(d,LHS,RHS, proj1, proj2) != 0 + any_replaced = true + end + any_replaced +end + diff --git a/test/Project.toml b/test/Project.toml index c36beb4..97b4819 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Catlab = "134e5e36-593f-5add-ad60-77f754baafbe" CombinatorialSpaces = "b1c52339-7909-45ad-8b6a-6e388f7c67f2" diff --git a/test/openoperators.jl b/test/openoperators.jl new file mode 100644 index 0000000..1c9245c --- /dev/null +++ b/test/openoperators.jl @@ -0,0 +1,241 @@ +using ACSets +using DiagrammaticEquations +using Test + +@testset "Open Operators" begin + # Error handling + # -------------- + + # Test erroneous left-hand side. + LHS = @decapode begin + y == Δ(X) + ∇(Z) + end + RHS = @decapode begin + y == -1*∘(d,⋆,d,⋆)(X) + end + Heat = @decapode begin + ∂ₜ(C) == Δ(C) + end + @test_throws "Only single operator replacement is supported for now, but found Op1s: [:Δ, :∇]" replace_op1!(Heat, LHS, RHS) + + # Test erroneous right-hand side. + LHS = @decapode begin + y == Δ(X) + end + RHS = @decapode begin + y == -1*∘(d,⋆,d,⋆)(X) + 10*Z + end + Heat = @decapode begin + ∂ₜ(C) == Δ(C) + end + @test_throws "The replacement for Δ must have a single input and a single output, but found inputs: [:X, :Z] and outputs [:y]" replace_op1!(Heat, LHS, RHS) + + # Transfering variables + # --------------------- + + # Test transfering variables on the Brusselator. + Brusselator = @decapode begin + (U, V)::Form0 + U2V::Form0 + (U̇, V̇)::Form0 + (α)::Constant + F::Parameter + U2V == (U .* U) .* V + U̇ == 1 + U2V - (4.4 * U) + (α * Δ(U)) + F + V̇ == (3.4 * U) - U2V + (α * Δ(V)) + ∂ₜ(U) == U̇ + ∂ₜ(V) == V̇ + end + U_idx = only(incident(Brusselator, :U, :name)) + V_idx = only(incident(Brusselator, :V, :name)) + transfer_children!(Brusselator, U_idx, V_idx) + @test Brusselator == @decapode begin + (U, V)::Form0 + U2V::Form0 + (U̇, V̇)::Form0 + (α)::Constant + F::Parameter + U2V == (V .* V) .* V + U̇ == 1 + U2V - (4.4 * V) + (α * Δ(V)) + F + V̇ == (3.4 * V) - U2V + (α * Δ(V)) + ∂ₜ(V) == U̇ + ∂ₜ(V) == V̇ + end + + # Opening Op1s + # ------------ + + # Test expanding the Heat equation. + LHS = @decapode begin + y == Δ(X) + end + RHS = @decapode begin + y == -1*∘(d,⋆,d,⋆)(X) + end + Heat = @decapode begin + ∂ₜ(C) == Δ(C) + end + replace_op1!(Heat, LHS, RHS) + @test Heat == @acset SummationDecapode{Any,Any,Symbol} begin + Var=4 + type=[:Literal, :infer, :infer, :infer] + name=[Symbol("-1"), Symbol("•1"), :Ċ, :C] + TVar=1 + incl=[3] + Op1=2 + src=[4,4] + tgt=[3,2] + op1=[:∂ₜ, [:d, :⋆, :d, :⋆]] + Op2=1 + proj1=[1] + proj2=[2] + res=[3] + op2=[:*] + end + + # Test expanding a vector calculus equivalent. + LHS = @decapode begin + y == div(X) + end + RHS = @decapode begin + y == ∘(⋆,d,⋆)(X) + end + Divergence = @decapode begin + C::Form0 + V::Form1 + ∂ₜ(C) == C*div(V) + end + replace_op1!(Divergence, LHS, RHS) + @test Divergence == @acset SummationDecapode{Any,Any,Symbol} begin + Var = 4 + TVar = 1 + Op1 = 2 + Op2 = 1 + src = [1, 4] + tgt = [3, 2] + proj1 = [1] + proj2 = [2] + res = [3] + incl = [3] + op1 = Any[:∂ₜ, [:⋆, :d, :⋆]] + op2 = [:*] + type = [:Form0, :infer, :infer, :Form1] + name = [:C, Symbol("•2"), :Ċ, :V] + end + + # Test expanding the vector laplacian. + LHS = @decapode begin + y == Δ(X) + end + RHS = @decapode begin + y == ∘(d,⋆,d,⋆)(X) + ∘(⋆,d,⋆,d)(X) + end + Heat = @decapode begin + ∂ₜ(C) == -1*Δ(V) + end + replace_op1!(Heat, LHS, RHS) + @test Heat == @acset SummationDecapode{Any,Any,Symbol} begin + Var = 7 + TVar = 1 + Op1 = 3 + Op2 = 1 + Σ = 1 + Summand = 2 + src = [2, 3, 3] + tgt = [1, 5, 7] + proj1 = [4] + proj2 = [6] + res = [1] + incl = [1] + summand = [7, 5] + summation = [1, 1] + sum = [6] + op1 = Any[:∂ₜ, [:⋆, :d, :⋆, :d], [:d, :⋆, :d, :⋆]] + op2 = [:*] + type = [:infer, :infer, :infer, :Literal, :infer, :infer, :infer] + name = [:Ċ, :C, :V, Symbol("-1"), Symbol("•2"), Symbol("•2"), Symbol("•1")] + end + + # Test expanding multiple op1s. + LHS = @decapode begin + y == Δ(X) + end + RHS = @decapode begin + y == ∘(d,⋆,d,⋆)(X) + end + LapLap = @decapode begin + ∂ₜ(C) == Δ(Δ(C)) + end + replace_all_op1s!(LapLap, LHS, RHS) + @test LapLap == @acset SummationDecapode{Any,Any,Symbol} begin + Var = 3 + TVar = 1 + Op1 = 3 + src = [3, 3, 2] + tgt = [1, 2, 1] + incl = [1] + op1 = Any[:∂ₜ, [:d, :⋆, :d, :⋆], [:d, :⋆, :d, :⋆]] + type = [:infer, :infer, :infer] + name = [:Ċ, Symbol("•2"), :C] + end + + # Test expanding the Brusselator. + Brusselator = @decapode begin + (U, V)::Form0 + U2V::Form0 + (U̇, V̇)::Form0 + (α)::Constant + F::Parameter + U2V == (U .* U) .* V + U̇ == 1 + U2V - (4.4 * U) + (α * Δ(U)) + F + V̇ == (3.4 * U) - U2V + (α * Δ(V)) + ∂ₜ(U) == U̇ + ∂ₜ(V) == V̇ + end + LHS = @decapode begin + y == Δ(X) + end + RHS = @decapode begin + y == -1*∘(d,⋆,d,⋆)(X) + end + replace_all_op1s!(Brusselator, LHS, RHS) + @test Brusselator[:op1] == Any[[:d,:⋆,:d,:⋆], [:d,:⋆,:d,:⋆], :∂ₜ, :∂ₜ] + + # Opening Op2s + # ------------ + + # Test expanding the interior product. + Interior = @decapode begin + ∂ₜ(C) == ι(A, B) + end + LHS = @decapode begin + y == ι(X, Z) + end + RHS = @decapode begin + y == -1*⋆((⋆p1) ∧ p2) + end + replace_all_op2s!(Interior, LHS, RHS) + @test Interior == @acset SummationDecapode{Any,Any,Symbol} begin + Var = 8 + TVar = 1 + Op1 = 3 + Op2 = 2 + src = [2, 4, 1] + tgt = [5, 3, 8] + proj1 = [7, 3] + proj2 = [8, 6] + res = [5, 1] + incl = [5] + op1 = [:∂ₜ, :⋆, :⋆] + op2 = [:*, :∧] + type = [:infer, :infer, :infer, :infer, :infer, :infer, :Literal, :infer] + name = [Symbol("•2"), :C, Symbol("•3"), :A, :Ċ, :B, Symbol("-1"), Symbol("•1")] + end + Interior = @decapode begin + ∂ₜ(C) == ι(A, B) + end + @test replace_all_op2s!(copy(Interior), LHS, RHS) == + replace_all_op2s!(copy(Interior), LHS, RHS, + only(incident(RHS, :p1, :name)), only(incident(RHS, :p2, :name))) +end + diff --git a/test/runtests.jl b/test/runtests.jl index 1e64509..0cb0f30 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,3 +35,7 @@ end @testset "SummationDecapode Deconstruction" begin include("colanguage.jl") end + +@testset "Open Operators" begin + include("openoperators.jl") +end