Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide Collection of Rewrite Rules #75

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ to_graphviz, # Re-exported from Catlab
## rewrite
average_rewrite,
## openoperators
transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s!
transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s!, AbstractSDRewriteRule, Op1SDRule, Op2SDRule, apply_rule!, rewrite!

using Catlab.Theories
import Catlab.Theories: otimes, oplus, compose, ⊗, ⊕, ⋅, associate, associate_unit, Ob, Hom, dom, codom
Expand Down
4 changes: 2 additions & 2 deletions src/deca/Deca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ using DataStructures
using ..DiagrammaticEquations
using Catlab

import ..infer_types!, ..resolve_overloads!
import ..infer_types!, ..resolve_overloads!, ..rewrite!

export normalize_unicode, varname, infer_types!, resolve_overloads!, typename, spacename, recursive_delete_parents, recursive_delete_parents!, unicode!, op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, vec_to_dec!
export normalize_unicode, varname, infer_types!, resolve_overloads!, rewrite!, typename, spacename, recursive_delete_parents, recursive_delete_parents!, unicode!, op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, vec_to_dec!

include("deca_acset.jl")
include("deca_visualization.jl")
Expand Down
68 changes: 68 additions & 0 deletions src/deca/deca_acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,71 @@ Resolve function overloads based on types of src and tgt.
resolve_overloads!(d::SummationDecapode) =
resolve_overloads!(d, op1_res_rules_2D, op2_res_rules_2D)

# Default Rewrite Rules
# ---------------------

rewrite_rules_2D = Vector{AbstractSDRewriteRule}([
Op1SDRule(
:Δ₀,
@decapode begin
y == ∘(d,δ)(X)
end),

Op1SDRule(
:Δ₁,
@decapode begin
(X,y)::Form1
y == ∘(d,δ)(X) + ∘(δ,d)(X)
end),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like overkill to use a whole decapode for this since it will always be a single LHS, RHS pair. Also the SymbolicUtils stuff will be able to use the @rule macro for defining rules.

Copy link
Member Author

@lukem12345 lukem12345 Sep 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to replace operations which take place directly on Decapodes. In this case, this Decapode replaces this hand-written sub-Decapode insertion: https://github.com/AlgebraicJulia/Decapodes.jl/blob/main/src/operators.jl#L320

function add_Lie_2D!(::Type{Val{1}}, d::SummationDecapode, proj1_Lie::Int, proj2_Lie::Int, res_Lie::Int)

  ## Outputs result of dual derivative Dual1 to Dual2
  dual_d_1_tgt = add_part!(d, :Var, type=:infer, name=nothing)
  add_part!(d, :Op1, src=proj2_Lie, tgt=dual_d_1_tgt, op1=:d)

  ## Takes interior product of Primal1 and Dual2 to Dual1
  inter_product_2_res = add_part!(d, :Var, type=:infer, name=nothing)
  add_Inter_Prod_2D!(Val{2}, d, dual_d_1_tgt, proj1_Lie, inter_product_2_res)

  ## Takes interior product of Primal1 and Dual1 to Dual0
  inter_product_1_res = add_part!(d, :Var, type=:infer, name=nothing)
  add_Inter_Prod_2D!(Val{1}, d, proj2_Lie, proj1_Lie, inter_product_1_res)

  ## Outputs result of dual derivative Dual0 to Dual1
  dual_d_0_tgt = add_part!(d, :Var, type=:infer, name=nothing)
  add_part!(d, :Op1, src=inter_product_1_res, tgt=dual_d_0_tgt, op1=:d)

  ## Outputs sum of both dual_d_0 and inter_product_2
  summation_tgt = add_part!(d, , sum=res_Lie)

  add_part!(d, :Summand, summand=inter_product_2_res, summation=summation_tgt)
  add_part!(d, :Summand, summand=dual_d_0_tgt, summation=summation


Op1SDRule(
:Δ₂,
@decapode begin
y == ∘(δ,d)(X)
end),

Op1SDRule(
:δ,
@decapode begin
y == ∘(⋆,d,⋆)(X)
end),

Op1SDRule(
:δ₁,
@decapode begin
y == ∘(⋆,d,⋆)(X)
end),

Op1SDRule(
:δ₂,
@decapode begin
y == ∘(⋆,d,⋆)(X)
end),

Op2SDRule(
:ι₁,
@decapode begin
y == -1*⋆((⋆p1) ∧ p2)
end),

Op2SDRule(
:L₀,
@decapode begin
y == ι(p1, d(p2))
end),

Op2SDRule(
:L₁,
@decapode begin
y == ι(p1, d(p2)) + d(ι(p1, p2))
end),

Op2SDRule(
:L₂,
@decapode begin
y == d(ι(p1, p2))
end)])

rewrite!(d::SummationDecapode) =
rewrite!(d, rewrite_rules_2D)

71 changes: 71 additions & 0 deletions src/openoperators.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,26 @@
abstract type AbstractSDRewriteRule end

struct Op1SDRule <: AbstractSDRewriteRule
LHS::Union{Symbol, SummationDecapode}
RHS::Union{Symbol, SummationDecapode}
end

struct Op2SDRule <: AbstractSDRewriteRule
LHS::Union{Symbol, SummationDecapode}
RHS::Union{Symbol, SummationDecapode}
proj1::Int
proj2::Int
end

function Op2SDRule(LHS::Union{Symbol, SummationDecapode}, RHS::Union{Symbol, SummationDecapode})
p1s = incident(RHS, :p1, :name)
p2s = incident(RHS, :p2, :name)
if length(p1s) != 1 || length(p2s) != 1
error("proj1 and proj2 to use were not given, but unique distinguished variables p1 and p2 were not found. Found p1: $(p1s) and p2: $(p2s).")
end
Op2SDRule(LHS, RHS, only(p1s), only(p2s))
end

# Opening up Op1s
# --------------

Expand All @@ -8,13 +31,19 @@ function validate_op1_match(d::SummationDecapode, LHS::SummationDecapode)
end
end

validate_op1_match(d::SummationDecapode, r::Op1SDRule) =
validate_op1_match(d, r.LHS)

# Validate whether RHS represents a valid replacement for an op1.
function validate_op1_replacement(d::SummationDecapode, LHS::Symbol, RHS::SummationDecapode)
if length(infer_states(RHS)) != 1 || length(infer_terminals(RHS)) != 1
error("The replacement for $(LHS) must have a single input and a single output, but found inputs: $(RHS[infer_states(RHS), :name]) and outputs $(RHS[infer_terminals(RHS), :name])")
end
end

validate_op1_replacement(d::SummationDecapode, r::Op1SDRule) =
validate_op1_replacement(d, r.LHS, r.RHS)

""" function replace_op1!(d::SummationDecapode, LHS::Symbol, RHS::SummationDecapode)

Given a Decapode, d, replace at most one instance of the left-hand-side unary operator with those of the right-hand-side.
Expand Down Expand Up @@ -81,6 +110,19 @@ function replace_op1!(d::SummationDecapode, LHS::Symbol, RHS::Symbol)
LHS_op1
end

""" function replace_op1!(d::SummationDecapode, r::Op1SDRule)

Given a Decapode, d, replace at most one instance of the left-hand-side unary operator with those of the right-hand-side.

Return the index of the replaced unary operator, 0 if no match was found.
See also: [`replace_op2!`](@ref), [`replace_all_op1s!`](@ref)
"""
replace_op1!(d::SummationDecapode, r::Op1SDRule) =
replace_op1!(d, r.LHS, r.RHS)

apply_rule!(d::SummationDecapode, r::Op1SDRule) =
replace_op1!(d, r)

""" function replace_all_op1s!(d::SummationDecapode, LHS::Union{Symbol, SummationDecapode}, RHS::Union{Symbol, SummationDecapode})

Given a Decapode, d, replace all instances of the left-hand-side unary operator with those of the right-hand-side.
Expand All @@ -97,6 +139,17 @@ function replace_all_op1s!(d::SummationDecapode, LHS::Union{Symbol, SummationDec
any_replaced
end

""" function replace_all_op1s!(d::SummationDecapode, r::Op1SDRule)

Given a Decapode, d, replace all instances of the left-hand-side unary operator with those of the right-hand-side.

Return true if any replacements were made, otherwise false.

See also: [`replace_op1!`](@ref), [`replace_all_op2s!`](@ref)
"""
replace_all_op1s!(d::SummationDecapode, r::Op1SDRule) =
replace_all_op1s!(d, r.LHS, r.RHS)

# Opening up Op2s
# --------------

Expand All @@ -107,6 +160,9 @@ function validate_op2_match(d::SummationDecapode, LHS::SummationDecapode)
end
end

validate_op2_match(d::SummationDecapode, r::Op2SDRule) =
validate_op2_match(d, r.LHS)

# Validate whether RHS represents a valid replacement for an op2.
function validate_op2_replacement(d::SummationDecapode, LHS::Symbol, RHS::SummationDecapode, proj1::Int, proj2::Int)
if length(infer_states(RHS)) != 2 || length(infer_terminals(RHS)) != 1
Expand All @@ -117,6 +173,9 @@ function validate_op2_replacement(d::SummationDecapode, LHS::Symbol, RHS::Summat
end
end

validate_op2_replacement(d::SummationDecapode, r::Op2SDRule) =
validate_op2_replacement(d, r.LHS, r.RHS, r.proj1, r.proj2)

""" function replace_op2!(d::SummationDecapode, LHS::Symbol, RHS::SummationDecapode, proj1::Int, proj2::Int)

Given a Decapode, d, replace at most one instance of the left-hand-side binary operator with those of the right-hand-side.
Expand Down Expand Up @@ -202,6 +261,12 @@ end
replace_op2!(d::SummationDecapode, LHS::Symbol, RHS::Symbol, proj1, proj2) =
replace_op2!(d, LHS, RHS)

replace_op2!(d::SummationDecapode, r::Op2SDRule) =
replace_op2!(d, r.LHS, r.RHS, r.proj1, r.proj2)

apply_rule!(d::SummationDecapode, r::Op2SDRule) =
replace_op2!(d, r)

""" function replace_all_op2s!(d::SummationDecapode, LHS::Union{Symbol, SummationDecapode}, RHS::Union{Symbol, SummationDecapode}, proj1::Int, proj2::Int)

Given a Decapode, d, replace all instances of the left-hand-side binary operator with those of the right-hand-side.
Expand Down Expand Up @@ -245,3 +310,9 @@ function replace_all_op2s!(d::SummationDecapode, LHS::Union{Symbol, SummationDec
any_replaced
end

replace_all_op2s!(d::SummationDecapode, r::Op2SDRule) =
replace_all_op2s!(d, r.LHS, r.RHS, r.proj1, r.proj2)

rewrite!(d::SummationDecapode, rules::AbstractVector{AbstractSDRewriteRule}) =
foreach(r -> apply_rule!(d,r), rules)

1 change: 0 additions & 1 deletion test/core.jl

This file was deleted.

96 changes: 96 additions & 0 deletions test/openoperators.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using ACSets
using DiagrammaticEquations
using DiagrammaticEquations.Deca
using Test

@testset "Open Operators" begin
Expand Down Expand Up @@ -239,3 +240,98 @@ using Test
only(incident(RHS, :p1, :name)), only(incident(RHS, :p2, :name)))
end

@testset "Rewrite Rules" begin
# Explicit Rule Application
# -------------------------

# Test expanding the Heat equation.
rule = Op1SDRule(
@decapode begin
(X,y,Z)::Form0
y == Δ(X)
end
,
@decapode begin
(X,y)::Form0
y == -1*∘(d,⋆,d,⋆)(X)
end)
Heat = @decapode begin
C::Form0
∂ₜ(C) == Δ(C)
end
apply_rule!(Heat, rule)
@test Heat == @acset SummationDecapode{Any,Any,Symbol} begin
Var=4
type=[:infer, :Literal, :Form0, :infer]
name=[Symbol("•1"), Symbol("-1"), :C, :Ċ]
TVar=1
incl=[4]
Op1=2
src=[3,3]
tgt=[4,1]
op1=[:∂ₜ, [:d, :⋆, :d, :⋆]]
Op2=1
proj1=[2]
proj2=[1]
res=[4]
op2=[:*]
end

# Test expanding the vector laplacian.
rule = Op1SDRule(
@decapode begin
(X,y)::Form1
y == Δ(X)
end
,
@decapode begin
(X,y)::Form1
y == ∘(d,⋆,d,⋆)(X) + ∘(⋆,d,⋆,d)(X)
end)
VectorHeat = @decapode begin
V::Form1
∂ₜ(V) == -1*Δ(V)
end
apply_rule!(VectorHeat, rule)
@test VectorHeat == @acset SummationDecapode{Any,Any,Symbol} begin
Var = 6
TVar = 1
Op1 = 3
Op2 = 1
Σ = 1
Summand = 2
src = [5, 5, 5]
tgt = [2, 3, 1]
proj1 = [4]
proj2 = [6]
res = [2]
incl = [2]
summand = [1, 3]
summation = [1, 1]
sum = [6]
op1 = Any[:∂ₜ, [:⋆, :d, :⋆, :d], [:d, :⋆, :d, :⋆]]
op2 = [:*]
type = [:infer, :infer, :infer, :Literal, :Form1, :infer]
name = [Symbol("•1"), :V̇, Symbol("•2"), Symbol("-1"), :V, Symbol("•2")]
end

# Default Rules Application
# -------------------------

# Test expanding the Brusselator.
Brusselator = @decapode begin
(U, V)::Form0
U2V::Form0
(U̇, V̇)::Form0
(α)::Constant
F::Parameter
U2V == (U .* U) .* V
U̇ == 1 + U2V - (4.4 * U) + (α * Δ(U)) + F
V̇ == (3.4 * U) - U2V + (α * Δ(V))
∂ₜ(U) == U̇
∂ₜ(V) == V̇
end
rewrite!(Brusselator)
@test Brusselator[:op1] == Any[[:d,:⋆,:d,:⋆], [:d,:⋆,:d,:⋆], :∂ₜ, :∂ₜ]
end

Loading