From 0ab1ef13f23c23d81daad22a5f435a78e85241bc Mon Sep 17 00:00:00 2001 From: AlgebraicJulia Bot <129184742+algebraicjuliabot@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:00:37 -0400 Subject: [PATCH 01/39] Set version to 0.1.7 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b40b43a..69f41e4 100644 --- a/Project.toml +++ b/Project.toml @@ -2,7 +2,7 @@ name = "DiagrammaticEquations" uuid = "6f00c28b-6bed-4403-80fa-30e0dc12f317" license = "MIT" authors = ["James Fairbanks", "Andrew Baas", "Evan Patterson", "Luke Morris", "George Rauta"] -version = "0.1.6" +version = "0.1.7" [deps] ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8" From e46fdec5f322e5370f6c1daa077e601557c2e3bb Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 23 Aug 2024 13:45:35 -0400 Subject: [PATCH 02/39] Added more exports (#44) Added `apex` and `@relation`, `to_graphviz` from Catlab Co-authored-by: James --- src/DiagrammaticEquations.jl | 5 ++++- test/composition.jl | 5 +---- test/language.jl | 6 ------ 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index be9cace..9729e31 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -2,6 +2,8 @@ """ module DiagrammaticEquations +using Catlab + export DerivOp, append_dot, normalize_unicode, infer_states, infer_types!, # Deca @@ -12,6 +14,7 @@ recursive_delete_parents, spacename, varname, unicode!, vec_to_dec!, Collage, collate, ## composition oapply, unique_by, unique_by!, OpenSummationDecapodeOb, OpenSummationDecapode, Open, default_composition_diagram, +apex, @relation, # Re-exported from Catlab ## acset SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, SummationDecapode, contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, is_expanded, expand_operators, infer_state_names, infer_terminal_names, recognize_types, @@ -25,12 +28,12 @@ unique_lits!, Plus, AppCirc1, Var, Tan, App1, App2, ## visualization to_graphviz_property_graph, typename, draw_composition, +to_graphviz, # Re-exported from Catlab ## rewrite average_rewrite, ## openoperators transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s! -using Catlab using Catlab.Theories import Catlab.Theories: otimes, oplus, compose, ⊗, ⊕, ⋅, associate, associate_unit, Ob, Hom, dom, codom using Catlab.Programs diff --git a/test/composition.jl b/test/composition.jl index f0c5c02..408cf37 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -2,15 +2,12 @@ using Test using DiagrammaticEquations using DiagrammaticEquations.Deca using Catlab -using Catlab.WiringDiagrams -using Catlab.Programs -using Catlab.CategoricalAlgebra # import DiagrammaticEquations: OpenSummationDecapode, Open, oapply, oapply_rename # @testset "Composition" begin # Simplest possible decapode relation. -Trivial = @decapode begin +Trivial = @decapode begin H::Form0{X} end diff --git a/test/language.jl b/test/language.jl index 3e3ab4e..c0c8680 100644 --- a/test/language.jl +++ b/test/language.jl @@ -1,11 +1,5 @@ using Test using Catlab -using Catlab.Theories -using Catlab.CategoricalAlgebra -using Catlab.WiringDiagrams -using Catlab.WiringDiagrams.DirectedWiringDiagrams -using Catlab.Graphics -using Catlab.Programs using LinearAlgebra using MLStyle using Base.Iterators From 87a9c5c84219b32ffe5b4e904649aa35e22d8bcc Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Thu, 22 Aug 2024 14:35:57 -0400 Subject: [PATCH 03/39] Add type rules for vectorfields --- src/acset.jl | 40 +++++++++++------ src/deca/deca_acset.jl | 26 +++++++---- test/language.jl | 97 ++++++++++++++++++++++++++++++++++++------ 3 files changed, 129 insertions(+), 34 deletions(-) diff --git a/src/acset.jl b/src/acset.jl index 9665afd..5366459 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -158,13 +158,16 @@ end # A collection of DecaType getters # TODO: This should be replaced by using a type hierarchy const ALL_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, - :Literal, :Parameter, :Constant, :infer] + :PVF, :DVF, + :Literal, :Parameter, :Constant, :infer] const FORM_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2] const PRIMALFORM_TYPES = [:Form0, :Form1, :Form2] const DUALFORM_TYPES = [:DualForm0, :DualForm1, :DualForm2] -const NONFORM_TYPES = [:Constant, :Parameter, :Literal, :infer] +const VECTORFIELD_TYPES = [:PVF, :DVF] + +const NON_EC_TYPES = [:Constant, :Parameter, :Literal, :infer] const USER_TYPES = [:Constant, :Parameter] const NUMBER_TYPES = [:Literal] const INFER_TYPES = [:infer] @@ -427,12 +430,12 @@ function safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol) end """ - filterfor_forms(types::AbstractVector{Symbol}) + filterfor_ec_types(types::AbstractVector{Symbol}) -Return any form type symbols. +Return any form or vector-field type symbols. """ -function filterfor_forms(types::AbstractVector{Symbol}) - conditions = x -> !(x in NONFORM_TYPES) +function filterfor_ec_types(types::AbstractVector{Symbol}) + conditions = x -> !(x in NON_EC_TYPES) filter(conditions, types) end @@ -447,16 +450,16 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int) types = d[idxs, :type] all(t != :infer for t in types) && return applied # We need not infer - forms = unique(filterfor_forms(types)) + ec_types = unique(filterfor_ec_types(types)) - form = @match length(forms) begin + ec_type = @match length(ec_types) begin 0 => return applied # We can not infer - 1 => only(forms) - _ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $forms") + 1 => only(ec_types) + _ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $ec_types") end for idx in idxs - applied |= safe_modifytype!(d, idx, form) + applied |= safe_modifytype!(d, idx, ec_type) end return applied @@ -489,11 +492,24 @@ function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule) score_res = (rule.res_type == type_res) check_op = (d[op2_id, :op2] in rule.op_names) - if(check_op && (score_proj1 + score_proj2 + score_res == 2)) + if check_op && (score_proj1 + score_proj2 + score_res == 2) mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type) mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type) mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type) return mod_proj1 || mod_proj2 || mod_res + # Special logic for exponentiation: + elseif d[op2_id, :op2] == :^ && + (type_proj1 == :Form0 && (type_proj2 == :infer || type_res == :infer)) || + (type_res == :Form0 && (type_proj1 == :infer || type_proj2 == :infer)) + mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], :Form0) + mod_res = safe_modifytype!(d, d[op2_id, :res], :Form0) + return mod_proj1 || mod_res + elseif d[op2_id, :op2] == :^ && + (type_proj1 == :DualForm0 && (type_proj2 == :infer || type_res == :infer)) || + (type_res == :DualForm0 && (type_proj1 == :infer || type_proj2 == :infer)) + mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], :DualForm0) + mod_res = safe_modifytype!(d, d[op2_id, :res], :DualForm0) + return mod_proj1 || mod_res end return false diff --git a/src/deca/deca_acset.jl b/src/deca/deca_acset.jl index 55520e3..c5fa67a 100644 --- a/src/deca/deca_acset.jl +++ b/src/deca/deca_acset.jl @@ -32,10 +32,17 @@ op1_inf_rules_1D = [ # Rules for the averaging operator (src_type = :Form0, tgt_type = :Form1, op_names = [:avg₀₁, :avg_01]), + + # Rules for ♯. + (src_type = :Form1, tgt_type = :PVF, op_names = [:♯, :♯ᵖᵖ]), + (src_type = :DualForm1, tgt_type = :DVF, op_names = [:♯, :♯ᵈᵈ]), + + # Rules for ♭. + (src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]), # Rules for magnitude/ norm - (src_type = :Form0, tgt_type = :Form0, op_names = [:mag, :norm]), - (src_type = :Form1, tgt_type = :Form1, op_names = [:mag, :norm])] + (src_type = :PVF, tgt_type = :Form0, op_names = [:mag, :norm]), + (src_type = :DVF, tgt_type = :DualForm0, op_names = [:mag, :norm])] op2_inf_rules_1D = [ # Rules for ∧₀₀, ∧₁₀, ∧₀₁ @@ -133,13 +140,16 @@ op1_inf_rules_2D = [ (src_type = :DualForm1, tgt_type = :DualForm1, op_names = [:neg, :(-)]), (src_type = :DualForm2, tgt_type = :DualForm2, op_names = [:neg, :(-)]), + # Rules for ♯. + (src_type = :Form1, tgt_type = :PVF, op_names = [:♯, :♯ᵖᵖ]), + (src_type = :DualForm1, tgt_type = :DVF, op_names = [:♯, :♯ᵈᵈ]), + + # Rules for ♭. + (src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]), + # Rules for magnitude/ norm - (src_type = :Form0, tgt_type = :Form0, op_names = [:norm, :mag]), - (src_type = :Form1, tgt_type = :Form1, op_names = [:norm, :mag]), - (src_type = :Form2, tgt_type = :Form2, op_names = [:norm, :mag]), - (src_type = :DualForm0, tgt_type = :DualForm0, op_names = [:norm, :mag]), - (src_type = :DualForm1, tgt_type = :DualForm1, op_names = [:norm, :mag]), - (src_type = :DualForm2, tgt_type = :DualForm2, op_names = [:norm, :mag])] + (src_type = :PVF, tgt_type = :Form0, op_names = [:norm, :mag]), + (src_type = :DVF, tgt_type = :DualForm0, op_names = [:norm, :mag])] op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # Rules for ∧₁₁, ∧₂₀, ∧₀₂ diff --git a/test/language.jl b/test/language.jl index c0c8680..ec1ff3c 100644 --- a/test/language.jl +++ b/test/language.jl @@ -350,13 +350,14 @@ end @test issetequal([:V,:X,:k], infer_state_names(oscillator)) end -import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_TYPES, - NONFORM_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES +import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, + DUALFORM_TYPES, VECTORFIELD_TYPES, NON_EC_TYPES, USER_TYPES, NUMBER_TYPES, + INFER_TYPES, NONINFERABLE_TYPES @testset "Type Retrival" begin type_groups = [ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_TYPES, - NONFORM_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES] + NON_EC_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES] # No repeated types @@ -368,12 +369,12 @@ import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_ no_overlaps(types_1, types_2) = isempty(intersect(types_1, types_2)) # Collections of these types should be the same - @test equal_types(ALL_TYPES, vcat(FORM_TYPES, NONFORM_TYPES)) - @test equal_types(FORM_TYPES, vcat(PRIMALFORM_TYPES, DUALFORM_TYPES)) - @test equal_types(NONINFERABLE_TYPES, vcat(USER_TYPES, NUMBER_TYPES)) + @test equal_types(ALL_TYPES, FORM_TYPES ∪ VECTORFIELD_TYPES ∪ NON_EC_TYPES) + @test equal_types(FORM_TYPES, PRIMALFORM_TYPES ∪ DUALFORM_TYPES) + @test equal_types(NONINFERABLE_TYPES, USER_TYPES ∪ NUMBER_TYPES) # Proper seperation of types - @test no_overlaps(FORM_TYPES, NONFORM_TYPES) + @test no_overlaps(FORM_TYPES ∪ VECTORFIELD_TYPES, NON_EC_TYPES) @test no_overlaps(PRIMALFORM_TYPES, DUALFORM_TYPES) @test no_overlaps(NONINFERABLE_TYPES, FORM_TYPES) @test INFER_TYPES == [:infer] @@ -394,9 +395,9 @@ end import DiagrammaticEquations: safe_modifytype @testset "Safe Type Modification" begin - all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :infer] + all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :PVF, :DVF, :infer] bad_sources = [:Literal, :Constant, :Parameter] - good_sources = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :infer] + good_sources = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :PVF, :DVF, :infer] for tgt in all_types for src in bad_sources @@ -419,13 +420,13 @@ import DiagrammaticEquations: safe_modifytype end end -import DiagrammaticEquations: filterfor_forms +import DiagrammaticEquations: filterfor_ec_types @testset "Form Type Retrieval" begin - all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :infer] - @test filterfor_forms(all_types) == [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2] - @test isempty(filterfor_forms(Symbol[])) - @test isempty(filterfor_forms([:Literal, :Constant, :Parameter, :infer])) + all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :PVF, :DVF, :infer] + @test filterfor_ec_types(all_types) == [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :PVF, :DVF] + @test isempty(filterfor_ec_types(Symbol[])) + @test isempty(filterfor_ec_types([:Literal, :Constant, :Parameter, :infer])) end @testset "Type Inference" begin @@ -825,6 +826,38 @@ end @test_throws "Type mismatch in summation" infer_types!(d) end + # Test #25: Infer between flattened and sharpened vector fields. + let + d = @decapode begin + A::Form1 + B::DualForm1 + C::PVF + D::DVF + + A == ♭(E) + B == ♭(F) + C == ♯(G) + D == ♯(H) + + I::Form1 + J::DualForm1 + K::PVF + L::DVF + + M == ♯(I) + N == ♯(J) + O == ♭(K) + P == ♭(L) + end + infer_types!(d) + + # TODO: Update this as more sharps and flats are released. + names_types_expected = Set([(:A, :Form1), (:B, :DualForm1), (:C, :PVF), (:D, :DVF), + (:E, :DVF), (:F, :infer), (:G, :Form1), (:H, :DualForm1), + (:I, :Form1), (:J, :DualForm1), (:K, :PVF), (:L, :DVF), + (:M, :PVF), (:N, :DVF), (:O, :infer), (:P, :Form1)]) + @test test_nametype_equality(d, names_types_expected) + end end @testset "Overloading Resolution" begin @@ -1042,6 +1075,42 @@ end op2s_hx = HeatXfer[:op2] op2s_expected_hx = [:*, :/, :/, :L₀, :/, :L₁, :*, :/, :*, :i₁, :/, :*, :*, :L₀] @test op2s_hx == op2s_expected_hx + + # Infer types and resolve overloads for the Halfar equation. + let + d = @decapode begin + h::Form0 + Γ::Form1 + n::Constant + + ∂ₜ(h) == ∘(⋆, d, ⋆)(Γ * d(h) ∧ (mag(♯(d(h)))^(n-1)) ∧ (h^(n+2))) + end + d = expand_operators(d) + infer_types!(d) + resolve_overloads!(d) + @test d == @acset SummationDecapode{Any, Any, Symbol} begin + Var = 19 + TVar = 1 + Op1 = 8 + Op2 = 6 + Σ = 1 + Summand = 2 + src = [1, 1, 1, 13, 12, 6, 18, 19] + tgt = [4, 9, 13, 12, 11, 18, 19, 4] + proj1 = [2, 3, 11, 8, 1, 7] + proj2 = [9, 15, 14, 10, 5, 16] + res = [8, 14, 10, 7, 16, 6] + incl = [4] + summand = [3, 17] + summation = [1, 1] + sum = [5] + op1 = [:∂ₜ, :d₀, :d₀, :♯, :mag, :⋆₁, :dual_d₁, :⋆₀⁻¹] + op2 = [:*, :-, :^, :∧₁₀, :^, :∧₁₀] + type = [:Form0, :Form1, :Constant, :Form0, :infer, :Form1, :Form1, :Form1, :Form1, :Form0, :Form0, :PVF, :Form1, :infer, :Literal, :Form0, :Literal, :DualForm1, :DualForm2] + name = [:h, :Γ, :n, :ḣ, :sum_1, Symbol("•2"), Symbol("•3"), Symbol("•4"), Symbol("•5"), Symbol("•6"), Symbol("•7"), Symbol("•8"), Symbol("•9"), Symbol("•10"), Symbol("1"), Symbol("•11"), Symbol("2"), Symbol("•_6_1"), Symbol("•_6_2")] + end + end + end @testset "Compilation Transformation" begin From d921c7038b0384bd0990762124c7a91da03b7c98 Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Thu, 22 Aug 2024 15:38:45 -0400 Subject: [PATCH 04/39] Add musical overload resolution --- src/deca/deca_acset.jl | 5 +++++ test/language.jl | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/deca/deca_acset.jl b/src/deca/deca_acset.jl index c5fa67a..b557553 100644 --- a/src/deca/deca_acset.jl +++ b/src/deca/deca_acset.jl @@ -253,6 +253,11 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ (src_type = :Form1, tgt_type = :Form0, resolved_name = :δ₁, op = :δ), (src_type = :Form2, tgt_type = :Form1, resolved_name = :δ₂, op = :codif), (src_type = :Form1, tgt_type = :Form0, resolved_name = :δ₁, op = :codif), + # Rules for ♯. + (src_type = :Form1, tgt_type = :PVF, resolved_name = :♯ᵖᵖ, op = :♯), + (src_type = :DualForm1, tgt_type = :DVF, resolved_name = :♯ᵈᵈ, op = :♯), + # Rules for ♭. + (src_type = :DVF, tgt_type = :Form1, resolved_name = :♭ᵈᵖ, op = :♭), # Rules for ∇². # TODO: Call this :nabla2 in ASCII? (src_type = :Form0, tgt_type = :Form0, resolved_name = :∇²₀, op = :∇²), diff --git a/test/language.jl b/test/language.jl index ec1ff3c..25f817e 100644 --- a/test/language.jl +++ b/test/language.jl @@ -1104,7 +1104,7 @@ end summand = [3, 17] summation = [1, 1] sum = [5] - op1 = [:∂ₜ, :d₀, :d₀, :♯, :mag, :⋆₁, :dual_d₁, :⋆₀⁻¹] + op1 = [:∂ₜ, :d₀, :d₀, :♯ᵖᵖ, :mag, :⋆₁, :dual_d₁, :⋆₀⁻¹] op2 = [:*, :-, :^, :∧₁₀, :^, :∧₁₀] type = [:Form0, :Form1, :Constant, :Form0, :infer, :Form1, :Form1, :Form1, :Form1, :Form0, :Form0, :PVF, :Form1, :infer, :Literal, :Form0, :Literal, :DualForm1, :DualForm2] name = [:h, :Γ, :n, :ḣ, :sum_1, Symbol("•2"), Symbol("•3"), Symbol("•4"), Symbol("•5"), Symbol("•6"), Symbol("•7"), Symbol("•8"), Symbol("•9"), Symbol("•10"), Symbol("1"), Symbol("•11"), Symbol("2"), Symbol("•_6_1"), Symbol("•_6_2")] From 804ca95ad2d12912d69526b15ead5ddfefe3a426 Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Thu, 22 Aug 2024 17:45:52 -0400 Subject: [PATCH 05/39] Take advantage of :infer in type rules --- src/acset.jl | 32 ++++++-------------------------- src/deca/deca_acset.jl | 6 +++++- 2 files changed, 11 insertions(+), 27 deletions(-) diff --git a/src/acset.jl b/src/acset.jl index 5366459..34d1d3c 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -466,13 +466,10 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int) end function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule) - type_src = d[d[op1_id, :src], :type] - type_tgt = d[d[op1_id, :tgt], :type] + score_src = (rule.src_type == d[d[op1_id, :src], :type]) + score_tgt = (rule.tgt_type == d[d[op1_id, :tgt], :type]) - score_src = (rule.src_type == type_src) - score_tgt = (rule.tgt_type == type_tgt) check_op = (d[op1_id, :op1] in rule.op_names) - if(check_op && (score_src + score_tgt == 1)) mod_src = safe_modifytype!(d, d[op1_id, :src], rule.src_type) mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], rule.tgt_type) @@ -483,33 +480,16 @@ function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule) end function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule) - type_proj1 = d[d[op2_id, :proj1], :type] - type_proj2 = d[d[op2_id, :proj2], :type] - type_res = d[d[op2_id, :res], :type] + score_proj1 = (rule.proj1_type == d[d[op2_id, :proj1], :type]) + score_proj2 = (rule.proj2_type == d[d[op2_id, :proj2], :type]) + score_res = (rule.res_type == d[d[op2_id, :res], :type]) - score_proj1 = (rule.proj1_type == type_proj1) - score_proj2 = (rule.proj2_type == type_proj2) - score_res = (rule.res_type == type_res) check_op = (d[op2_id, :op2] in rule.op_names) - if check_op && (score_proj1 + score_proj2 + score_res == 2) mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type) mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type) - mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type) + mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type) return mod_proj1 || mod_proj2 || mod_res - # Special logic for exponentiation: - elseif d[op2_id, :op2] == :^ && - (type_proj1 == :Form0 && (type_proj2 == :infer || type_res == :infer)) || - (type_res == :Form0 && (type_proj1 == :infer || type_proj2 == :infer)) - mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], :Form0) - mod_res = safe_modifytype!(d, d[op2_id, :res], :Form0) - return mod_proj1 || mod_res - elseif d[op2_id, :op2] == :^ && - (type_proj1 == :DualForm0 && (type_proj2 == :infer || type_res == :infer)) || - (type_res == :DualForm0 && (type_proj1 == :infer || type_proj2 == :infer)) - mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], :DualForm0) - mod_res = safe_modifytype!(d, d[op2_id, :res], :DualForm0) - return mod_proj1 || mod_res end return false diff --git a/src/deca/deca_acset.jl b/src/deca/deca_acset.jl index b557553..4334c4b 100644 --- a/src/deca/deca_acset.jl +++ b/src/deca/deca_acset.jl @@ -90,7 +90,11 @@ op2_inf_rules_1D = [ (proj1_type = :Constant, proj2_type = :DualForm0, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Constant, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :DualForm0, proj2_type = :Constant, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :DualForm1, proj2_type = :Constant, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^])] + (proj1_type = :DualForm1, proj2_type = :Constant, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), + + # These rules contain infer: + (proj1_type = :Form0, proj2_type = :infer, res_type = :Form0, op_names = [:^]), + (proj1_type = :DualForm0, proj2_type = :infer, res_type = :DualForm0, op_names = [:^])] """ These are the default rules used to do type inference in the 2D exterior calculus. From 3e916c485b8f183e8c8efee6a7fdde8df9b5567f Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 23 Aug 2024 16:34:08 -0400 Subject: [PATCH 06/39] Initial attempt at rewriting Converts ACSet to a series of Symbolic terms that can be rewritten with a provided rewriter --- src/DiagrammaticEquations.jl | 2 + src/acset2symbolic.jl | 60 ++++++++++++++++++++++++++++++ src/graph_traversal.jl | 72 ++++++++++++++++++++++++++++++++++++ test/graph_traversal.jl | 64 ++++++++++++++++++++++++++++++++ test/runtests.jl | 3 ++ 5 files changed, 201 insertions(+) create mode 100644 src/acset2symbolic.jl create mode 100644 src/graph_traversal.jl create mode 100644 test/graph_traversal.jl diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index 9729e31..4ce554a 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -65,6 +65,8 @@ include("pretty.jl") include("colanguage.jl") include("openoperators.jl") include("symbolictheoryutils.jl") +include("graph_traversal.jl") +include("acset2symbolic.jl") include("deca/Deca.jl") include("learn/Learn.jl") include("SymbolicUtilsInterop.jl") diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl new file mode 100644 index 0000000..a2b24e6 --- /dev/null +++ b/src/acset2symbolic.jl @@ -0,0 +1,60 @@ +using DiagrammaticEquations +using SymbolicUtils +using SymbolicUtils.Rewriters +using SymbolicUtils.Code +using MLStyle + +const DECA_EQUALITY_SYMBOL = (==) + +to_symbolics(d::SummationDecapode, node::TraversalNode) = to_symbolics(d, node.index, Val(node.name)) + +function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op1}) + input_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :src], :name]) + output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :tgt], :name]) + op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1]) + + rhs = SymbolicUtils.Term{Number}(op_sym, [input_sym]) + SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) +end + +function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op2}) + input1_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :proj1], :name]) + input2_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :proj2], :name]) + output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :res], :name]) + op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2]) + + rhs = SymbolicUtils.Term{Number}(op_sym, [input1_sym, input2_sym]) + SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) +end + +#XXX: Always converting + -> .+ here since summation doesn't store the style of addition +# function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Σ}) +# Expr(EQUALITY_SYMBOL, c.output, Expr(:call, Expr(:., :+), c.inputs...)) +# end + +function extract_symexprs(d::SummationDecapode) + topo_list = topological_sort_edges(d) + sym_list = [] + for node in topo_list + retrieve_name(d, node) != DerivOp || continue + push!(sym_list, to_symbolics(d, node)) + end + sym_list +end + +function apply_rewrites(d::SummationDecapode, rewriter) + + rewritten_list = [] + for sym in extract_symexprs(d) + res_sym = rewriter(sym) + rewritten_sym = isnothing(res_sym) ? sym : res_sym + push!(rewritten_list, rewritten_sym) + end + + rewritten_list +end + +# TODO: We need a way to get information like the d and ⋆ even when not in the ACSet +# @syms Δ(x) d(x) ⋆(x) +# lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x)))) +# rewriter = Postwalk(RestartedChain([lap_0_rule])) diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl new file mode 100644 index 0000000..b9551c8 --- /dev/null +++ b/src/graph_traversal.jl @@ -0,0 +1,72 @@ +using DiagrammaticEquations +using ACSets + +export TraversalNode, topological_sort_edges, number_of_ops, retrieve_name + +struct TraversalNode{T} + index::Int + name::T +end + +function topological_sort_edges(d::SummationDecapode) + visited_Var = falses(nparts(d, :Var)) + visited_Var[start_nodes(d)] .= true + + # TODO: Collect these visited arrays into one structure indexed by :Op1, :Op2, and :Σ + visited_1 = falses(nparts(d, :Op1)) + visited_2 = falses(nparts(d, :Op2)) + visited_Σ = falses(nparts(d, :Σ)) + + # FIXME: this is a quadratic implementation of topological_sort inlined in here. + op_order = TraversalNode{Symbol}[] + + for _ in 1:number_of_ops(d) + for op in parts(d, :Op1) + if !visited_1[op] && visited_Var[d[op, :src]] + + visited_1[op] = true + visited_Var[d[op, :tgt]] = true + + push!(op_order, TraversalNode(op, :Op1)) + end + end + + for op in parts(d, :Op2) + if !visited_2[op] && visited_Var[d[op, :proj1]] && visited_Var[d[op, :proj2]] + visited_2[op] = true + visited_Var[d[op, :res]] = true + push!(op_order, TraversalNode(op, :Op2)) + end + end + + for op in parts(d, :Σ) + args = subpart(d, incident(d, op, :summation), :summand) + if !visited_Σ[op] && all(visited_Var[args]) + visited_Σ[op] = true + visited_Var[d[op, :sum]] = true + push!(op_order, TraversalNode(op, :Σ)) + end + end + end + + @assert length(op_order) == number_of_ops(d) + + op_order +end + +function number_of_ops(d::SummationDecapode) + return nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) +end + +function start_nodes(d::SummationDecapode) + return vcat(infer_states(d), incident(d, :Literal, :type)) +end + +function retrieve_name(d::SummationDecapode, tsr::TraversalNode) + @match tsr.name begin + :Op1 => d[tsr.index, :op1] + :Op2 => d[tsr.index, :op2] + :Σ => :+ + _ => error("$(tsr.name) is not a valid table for names") + end +end diff --git a/test/graph_traversal.jl b/test/graph_traversal.jl new file mode 100644 index 0000000..7a5d6ac --- /dev/null +++ b/test/graph_traversal.jl @@ -0,0 +1,64 @@ +using DiagrammaticEquations +using ACSets +using MLStyle +using Test + +function is_correct_length(d::SummationDecapode, result) + return length(result) == number_of_ops(d) +end + +@testset "Topological Sort on Edges" begin + no_edge = @decapode begin + F == S + end + @test isempty(topological_sort_edges(no_edge)) + + one_op1_deca = @decapode begin + F == f(S) + end + result = topological_sort_edges(one_op1_deca) + @test is_correct_length(one_op1_deca, result) + @test retrieve_name(one_op1_deca, only(result)) == :f + + multi_op1_deca = @decapode begin + F == c(b(a(S))) + end + result = topological_sort_edges(multi_op1_deca) + @test is_correct_length(multi_op1_deca, result) + for (edge, test_name) in zip(result, [:a, :b, :c]) + @test retrieve_name(multi_op1_deca, edge) == test_name + end + + cyclic = @decapode begin + B == g(A) + A == f(B) + end + @test_throws AssertionError topological_sort_edges(cyclic) + + just_op2 = @decapode begin + C == A * B + end + result = topological_sort_edges(just_op2) + @test is_correct_length(just_op2, result) + @test retrieve_name(just_op2, only(result)) == :* + + just_simple_sum = @decapode begin + C == A + B + end + result = topological_sort_edges(just_simple_sum) + @test is_correct_length(just_simple_sum, result) + @test retrieve_name(just_simple_sum, only(result)) == :+ + + just_multi_sum = @decapode begin + F == A + B + C + D + E + end + result = topological_sort_edges(just_multi_sum) + @test is_correct_length(just_multi_sum, result) + @test retrieve_name(just_multi_sum, only(result)) == :+ + + op_combo = @decapode begin + F == h(d(A) + f(g(B) * C) + D) + end + result = topological_sort_edges(op_combo) + @test is_correct_length(op_combo, result) +end diff --git a/test/runtests.jl b/test/runtests.jl index dd92531..972d72c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -39,3 +39,6 @@ end end include("aqua.jl") +@testset "Symbolic Rewriting" begin + include("graph_traversal.jl") +end From 5fbe4b482e0bae6359b32844b7f491851583fbae Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 30 Aug 2024 14:39:43 -0400 Subject: [PATCH 07/39] Added proof of concept Added a short script showcasing how rewriting could be done with the `Sort` types and a reference ACSet. --- src/sym_rewrite.jl | 60 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 src/sym_rewrite.jl diff --git a/src/sym_rewrite.jl b/src/sym_rewrite.jl new file mode 100644 index 0000000..005630c --- /dev/null +++ b/src/sym_rewrite.jl @@ -0,0 +1,60 @@ +using DiagrammaticEquations +using SymbolicUtils +using MLStyle + +test_space = Space(:X, 2) + +test_type = Form(0, false, test_space) + +Heat = @decapode begin + C::Form0 + D::Constant + ∂ₜ(C) == (D+2)*Δ(C) +end + +infer_types!(Heat) +resolve_overloads!(Heat) + +function oldtype_to_new(old::Symbol, space::Space = Space(:I, 2))::Sort + @match old begin + :Form0 => PrimalForm(0, space) + :Form1 => PrimalForm(1, space) + :Form2 => PrimalForm(2, space) + + :DualForm0 => DualForm(0, space) + :DualForm1 => DualForm(1, space) + :DualForm2 => DualForm(2, space) + + :Constant => Scalar() + :Parameter => Scalar() + end +end + +function isform_zero(x) + getmetadata(x, Sort).dim == 0 +end + +function isform_two(x) + getmetadata(x, Sort).dim == 2 +end + +@syms Δ(x) d(x) ⋆(x) + +@syms C D Ċ sum_1 + +C = setmetadata(C, Sort, oldtype_to_new(Heat[1, :type])) + +lap_0_rule = @rule Δ(~x::(isform_zero)) => ⋆(d(⋆(d(~x)))) +lap_2_rule = @rule Δ(~x::(isform_two)) => d(⋆(d(⋆(~x)))) + +test_eq = Δ(C) +lap_0_rule(test_eq) +lap_2_rule(test_eq) === nothing + +rewriter = SymbolicUtils.Postwalk(SymbolicUtils.Chain([lap_0_rule])) +test_eq_long = (D + 2) * Δ(C) +rewriter(test_eq_long) + +# Δ₀(x) = ⋆₀⁻¹(dual_d₁(⋆₁(d₀(x)))) +# @syms Δ₀(x) d₀(x) ⋆₁(x) dual_d₁(x) ⋆₀⁻¹(x) +# lap_0_rule = @rule Δ₀(~x::(isform_zero)) => ⋆₀⁻¹(dual_d₁(⋆₁(d₀(~x)))) From 33db81338fb39e338a39c7ce7132457de811b284 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Thu, 5 Sep 2024 15:09:27 -0400 Subject: [PATCH 08/39] Added ability to do through-op rewrites This now supports the ability for ACSet intermediate expressions to be merged into one single expression upon which rewriting rules (like dd=0) may be performed. --- src/acset2symbolic.jl | 70 ++++++++++++++++++++++++++++++++++++++++++ src/graph_traversal.jl | 2 +- src/sym_rewrite.jl | 58 +++++++++++++++++----------------- 3 files changed, 101 insertions(+), 29 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index a2b24e6..0b123b6 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -4,6 +4,8 @@ using SymbolicUtils.Rewriters using SymbolicUtils.Code using MLStyle +export extract_symexprs, apply_rewrites, merge_equations + const DECA_EQUALITY_SYMBOL = (==) to_symbolics(d::SummationDecapode, node::TraversalNode) = to_symbolics(d, node.index, Val(node.name)) @@ -13,6 +15,9 @@ function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op1}) output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :tgt], :name]) op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1]) + input_sym = setmetadata(input_sym, Sort, oldtype_to_new(d[d[op_index, :src], :type])) + output_sym = setmetadata(output_sym, Sort, oldtype_to_new(d[d[op_index, :tgt], :type])) + rhs = SymbolicUtils.Term{Number}(op_sym, [input_sym]) SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) end @@ -23,6 +28,10 @@ function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op2}) output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :res], :name]) op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2]) + input1_sym = setmetadata(input1_sym, Sort, oldtype_to_new(d[d[op_index, :proj1], :type])) + input2_sym = setmetadata(input2_sym, Sort, oldtype_to_new(d[d[op_index, :proj2], :type])) + output_sym = setmetadata(output_sym, Sort, oldtype_to_new(d[d[op_index, :res], :type])) + rhs = SymbolicUtils.Term{Number}(op_sym, [input1_sym, input2_sym]) SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) end @@ -32,6 +41,21 @@ end # Expr(EQUALITY_SYMBOL, c.output, Expr(:call, Expr(:., :+), c.inputs...)) # end +function oldtype_to_new(old::Symbol, space::Space = Space(:I, 2))::Sort + @match old begin + :Form0 => PrimalForm(0, space) + :Form1 => PrimalForm(1, space) + :Form2 => PrimalForm(2, space) + + :DualForm0 => DualForm(0, space) + :DualForm1 => DualForm(1, space) + :DualForm2 => DualForm(2, space) + + :Constant => Scalar() + :Parameter => Scalar() + end +end + function extract_symexprs(d::SummationDecapode) topo_list = topological_sort_edges(d) sym_list = [] @@ -54,6 +78,52 @@ function apply_rewrites(d::SummationDecapode, rewriter) rewritten_list end +function merge_equations(d::SummationDecapode, rewritten_syms) + + lookup = Dict() + + final_list = [] + + for node in start_nodes(d) + sym = SymbolicUtils.Sym{Number}(d[node, :name]) + sym = setmetadata(sym, Sort, oldtype_to_new(d[node, :type])) + push!(lookup, (sym => sym)) + end + + final_nodes = infer_terminal_names(d) + + for expr in rewritten_syms + lhs = expr_lhs(expr) + rhs = expr_rhs(expr) + + recursive_descent(rhs, lookup) + + push!(lookup, (lhs => rhs)) + + if lhs.name in final_nodes + push!(final_list, expr) + end + end + + final_list +end + +expr_lhs(expr) = expr.arguments[1] +expr_rhs(expr) = expr.arguments[2] + +function recursive_descent(expr, lookup) + # @show expr + for (i, arg) in enumerate(expr.arguments) + # @show arg + if arg in keys(lookup) + expr.arguments[i] = lookup[arg] + else + recursive_descent(arg, lookup) + end + end + return expr +end + # TODO: We need a way to get information like the d and ⋆ even when not in the ACSet # @syms Δ(x) d(x) ⋆(x) # lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x)))) diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index b9551c8..1fd78e2 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -1,7 +1,7 @@ using DiagrammaticEquations using ACSets -export TraversalNode, topological_sort_edges, number_of_ops, retrieve_name +export TraversalNode, topological_sort_edges, number_of_ops, retrieve_name, start_nodes struct TraversalNode{T} index::Int diff --git a/src/sym_rewrite.jl b/src/sym_rewrite.jl index 005630c..0787091 100644 --- a/src/sym_rewrite.jl +++ b/src/sym_rewrite.jl @@ -9,52 +9,54 @@ test_type = Form(0, false, test_space) Heat = @decapode begin C::Form0 D::Constant - ∂ₜ(C) == (D+2)*Δ(C) + ∂ₜ(C) == D*Δ(d(C)) end infer_types!(Heat) resolve_overloads!(Heat) -function oldtype_to_new(old::Symbol, space::Space = Space(:I, 2))::Sort - @match old begin - :Form0 => PrimalForm(0, space) - :Form1 => PrimalForm(1, space) - :Form2 => PrimalForm(2, space) - - :DualForm0 => DualForm(0, space) - :DualForm1 => DualForm(1, space) - :DualForm2 => DualForm(2, space) - - :Constant => Scalar() - :Parameter => Scalar() - end -end - function isform_zero(x) getmetadata(x, Sort).dim == 0 end +function isform_one(x) + getmetadata(x, Sort).dim == 1 +end + function isform_two(x) getmetadata(x, Sort).dim == 2 end -@syms Δ(x) d(x) ⋆(x) +@syms Δ(x) d(x) ⋆(x) Δ₀(x) Δ₁(x) Δ₂(x) d₀(x) -@syms C D Ċ sum_1 +lap_0_convert = @rule Δ₀(~x) => Δ(~x) +lap_1_convert = @rule Δ₁(~x) => Δ(~x) +lap_2_convert = @rule Δ₂(~x) => Δ(~x) -C = setmetadata(C, Sort, oldtype_to_new(Heat[1, :type])) +d_0_convert = @rule d₀(~x) => d(~x) + +overloaders = [lap_0_convert, lap_1_convert, lap_2_convert, d_0_convert] lap_0_rule = @rule Δ(~x::(isform_zero)) => ⋆(d(⋆(d(~x)))) +lap_1_rule = @rule Δ(~x::(isform_one)) => d(⋆(d(⋆(~x)))) + ⋆(d(⋆(d(~x)))) lap_2_rule = @rule Δ(~x::(isform_two)) => d(⋆(d(⋆(~x)))) -test_eq = Δ(C) -lap_0_rule(test_eq) -lap_2_rule(test_eq) === nothing +openers = [lap_0_rule, lap_1_rule, lap_2_rule] + +heat_exprs = extract_symexprs(Heat) + +rewriter = SymbolicUtils.Postwalk( + SymbolicUtils.Fixpoint(SymbolicUtils.Chain(vcat(overloaders, openers)))) + +res_exprs = apply_rewrites(Heat, rewriter) + +merge_exprs = merge_equations(Heat, res_exprs) + +optm_dd_0 = @rule d(d(~x)) => 0 +star_0 = @rule ⋆(0) => 0 +d_0 = @rule d(0) => 0 -rewriter = SymbolicUtils.Postwalk(SymbolicUtils.Chain([lap_0_rule])) -test_eq_long = (D + 2) * Δ(C) -rewriter(test_eq_long) +optm_rewriter = SymbolicUtils.Postwalk( + SymbolicUtils.Fixpoint(SymbolicUtils.Chain([optm_dd_0, star_0, d_0]))) -# Δ₀(x) = ⋆₀⁻¹(dual_d₁(⋆₁(d₀(x)))) -# @syms Δ₀(x) d₀(x) ⋆₁(x) dual_d₁(x) ⋆₀⁻¹(x) -# lap_0_rule = @rule Δ₀(~x::(isform_zero)) => ⋆₀⁻¹(dual_d₁(⋆₁(d₀(~x)))) +optm_rewriter(merge_exprs[1]) From c9c6aef7fa792889b09ad889040a69fc66d45154 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Thu, 5 Sep 2024 15:23:46 -0400 Subject: [PATCH 09/39] Added Space import --- src/DiagrammaticEquations.jl | 4 +++- src/acset2symbolic.jl | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index 4ce554a..0ef26a9 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -66,10 +66,12 @@ include("colanguage.jl") include("openoperators.jl") include("symbolictheoryutils.jl") include("graph_traversal.jl") -include("acset2symbolic.jl") include("deca/Deca.jl") include("learn/Learn.jl") include("SymbolicUtilsInterop.jl") +include("ThDEC.jl") +include("decasymbolic.jl") +include("acset2symbolic.jl") @reexport using .Deca @reexport using .SymbolicUtilsInterop diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index 0b123b6..9f91cac 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -4,6 +4,8 @@ using SymbolicUtils.Rewriters using SymbolicUtils.Code using MLStyle +import DiagrammaticEquations.ThDEC: Space + export extract_symexprs, apply_rewrites, merge_equations const DECA_EQUALITY_SYMBOL = (==) From 80455457778636c3f6ae0055b695a0a82dd5c9dc Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 13 Sep 2024 14:37:09 -0400 Subject: [PATCH 10/39] Completed full pipeline Can take ACSets to Symbolics back to ACSets --- src/acset2symbolic.jl | 23 +++++++++++++++++++++++ src/sym_rewrite.jl | 18 +++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index 9f91cac..63c3ff2 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -126,6 +126,29 @@ function recursive_descent(expr, lookup) return expr end +function to_acset(og_d, sym_exprs) + final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) + map(x -> x.args[1] = :(==), final_exprs) + + deca_block = quote end + + states = infer_states(og_d) + terminals = infer_terminals(og_d) + + deca_type_gen = idx -> :($(og_d[idx, :name])::$(og_d[idx, :type])) + + append!(deca_block.args, map(deca_type_gen, vcat(states, terminals))) + + append!(deca_block.args, final_exprs) + + d = SummationDecapode(parse_decapode(deca_block)) + + infer_types!(d) + resolve_overloads!(d) + + d +end + # TODO: We need a way to get information like the d and ⋆ even when not in the ACSet # @syms Δ(x) d(x) ⋆(x) # lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x)))) diff --git a/src/sym_rewrite.jl b/src/sym_rewrite.jl index 0787091..6dff12a 100644 --- a/src/sym_rewrite.jl +++ b/src/sym_rewrite.jl @@ -59,4 +59,20 @@ d_0 = @rule d(0) => 0 optm_rewriter = SymbolicUtils.Postwalk( SymbolicUtils.Fixpoint(SymbolicUtils.Chain([optm_dd_0, star_0, d_0]))) -optm_rewriter(merge_exprs[1]) +res_merge_exprs = map(optm_rewriter, merge_exprs) + +final_exprs = SymbolicUtils.Code.toexpr.(res_merge_exprs) +map(x -> x.args[1] = :(==), final_exprs) + +to_decapode = quote + C::Form0 + D::Constant + Ċ::Form0 +end + +append!(to_decapode.args, final_exprs) + +test = parse_decapode(to_decapode) +deca_test = SummationDecapode(test) +infer_types!(deca_test) +resolve_overloads!(deca_test) From 8097521e075d70cbdee233028171d5757e3b4c5f Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 13 Sep 2024 14:59:24 -0400 Subject: [PATCH 11/39] Remove metadata usage This needs to switch to use the new type system --- src/DiagrammaticEquations.jl | 2 -- src/acset2symbolic.jl | 45 +++++++++++++++--------------------- src/sym_rewrite.jl | 36 ++++------------------------- 3 files changed, 23 insertions(+), 60 deletions(-) diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index 0ef26a9..72cada6 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -69,8 +69,6 @@ include("graph_traversal.jl") include("deca/Deca.jl") include("learn/Learn.jl") include("SymbolicUtilsInterop.jl") -include("ThDEC.jl") -include("decasymbolic.jl") include("acset2symbolic.jl") @reexport using .Deca diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index 63c3ff2..f84b3f2 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -4,9 +4,7 @@ using SymbolicUtils.Rewriters using SymbolicUtils.Code using MLStyle -import DiagrammaticEquations.ThDEC: Space - -export extract_symexprs, apply_rewrites, merge_equations +export extract_symexprs, apply_rewrites, merge_equations, to_acset const DECA_EQUALITY_SYMBOL = (==) @@ -17,8 +15,8 @@ function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op1}) output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :tgt], :name]) op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1]) - input_sym = setmetadata(input_sym, Sort, oldtype_to_new(d[d[op_index, :src], :type])) - output_sym = setmetadata(output_sym, Sort, oldtype_to_new(d[d[op_index, :tgt], :type])) + # input_sym = setmetadata(input_sym, Sort, oldtype_to_new(d[d[op_index, :src], :type])) + # output_sym = setmetadata(output_sym, Sort, oldtype_to_new(d[d[op_index, :tgt], :type])) rhs = SymbolicUtils.Term{Number}(op_sym, [input_sym]) SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) @@ -30,9 +28,9 @@ function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op2}) output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :res], :name]) op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2]) - input1_sym = setmetadata(input1_sym, Sort, oldtype_to_new(d[d[op_index, :proj1], :type])) - input2_sym = setmetadata(input2_sym, Sort, oldtype_to_new(d[d[op_index, :proj2], :type])) - output_sym = setmetadata(output_sym, Sort, oldtype_to_new(d[d[op_index, :res], :type])) + # input1_sym = setmetadata(input1_sym, Sort, oldtype_to_new(d[d[op_index, :proj1], :type])) + # input2_sym = setmetadata(input2_sym, Sort, oldtype_to_new(d[d[op_index, :proj2], :type])) + # output_sym = setmetadata(output_sym, Sort, oldtype_to_new(d[d[op_index, :res], :type])) rhs = SymbolicUtils.Term{Number}(op_sym, [input1_sym, input2_sym]) SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) @@ -43,20 +41,20 @@ end # Expr(EQUALITY_SYMBOL, c.output, Expr(:call, Expr(:., :+), c.inputs...)) # end -function oldtype_to_new(old::Symbol, space::Space = Space(:I, 2))::Sort - @match old begin - :Form0 => PrimalForm(0, space) - :Form1 => PrimalForm(1, space) - :Form2 => PrimalForm(2, space) +# function oldtype_to_new(old::Symbol, space::Space = Space(:I, 2))::Sort +# @match old begin +# :Form0 => PrimalForm(0, space) +# :Form1 => PrimalForm(1, space) +# :Form2 => PrimalForm(2, space) - :DualForm0 => DualForm(0, space) - :DualForm1 => DualForm(1, space) - :DualForm2 => DualForm(2, space) +# :DualForm0 => DualForm(0, space) +# :DualForm1 => DualForm(1, space) +# :DualForm2 => DualForm(2, space) - :Constant => Scalar() - :Parameter => Scalar() - end -end +# :Constant => Scalar() +# :Parameter => Scalar() +# end +# end function extract_symexprs(d::SummationDecapode) topo_list = topological_sort_edges(d) @@ -88,7 +86,7 @@ function merge_equations(d::SummationDecapode, rewritten_syms) for node in start_nodes(d) sym = SymbolicUtils.Sym{Number}(d[node, :name]) - sym = setmetadata(sym, Sort, oldtype_to_new(d[node, :type])) + # sym = setmetadata(sym, Sort, oldtype_to_new(d[node, :type])) push!(lookup, (sym => sym)) end @@ -148,8 +146,3 @@ function to_acset(og_d, sym_exprs) d end - -# TODO: We need a way to get information like the d and ⋆ even when not in the ACSet -# @syms Δ(x) d(x) ⋆(x) -# lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x)))) -# rewriter = Postwalk(RestartedChain([lap_0_rule])) diff --git a/src/sym_rewrite.jl b/src/sym_rewrite.jl index 6dff12a..de21b89 100644 --- a/src/sym_rewrite.jl +++ b/src/sym_rewrite.jl @@ -2,10 +2,6 @@ using DiagrammaticEquations using SymbolicUtils using MLStyle -test_space = Space(:X, 2) - -test_type = Form(0, false, test_space) - Heat = @decapode begin C::Form0 D::Constant @@ -15,18 +11,6 @@ end infer_types!(Heat) resolve_overloads!(Heat) -function isform_zero(x) - getmetadata(x, Sort).dim == 0 -end - -function isform_one(x) - getmetadata(x, Sort).dim == 1 -end - -function isform_two(x) - getmetadata(x, Sort).dim == 2 -end - @syms Δ(x) d(x) ⋆(x) Δ₀(x) Δ₁(x) Δ₂(x) d₀(x) lap_0_convert = @rule Δ₀(~x) => Δ(~x) @@ -37,9 +21,9 @@ d_0_convert = @rule d₀(~x) => d(~x) overloaders = [lap_0_convert, lap_1_convert, lap_2_convert, d_0_convert] -lap_0_rule = @rule Δ(~x::(isform_zero)) => ⋆(d(⋆(d(~x)))) -lap_1_rule = @rule Δ(~x::(isform_one)) => d(⋆(d(⋆(~x)))) + ⋆(d(⋆(d(~x)))) -lap_2_rule = @rule Δ(~x::(isform_two)) => d(⋆(d(⋆(~x)))) +lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x)))) +lap_1_rule = @rule Δ(~x) => d(⋆(d(⋆(~x)))) + ⋆(d(⋆(d(~x)))) +lap_2_rule = @rule Δ(~x) => d(⋆(d(⋆(~x)))) openers = [lap_0_rule, lap_1_rule, lap_2_rule] @@ -61,18 +45,6 @@ optm_rewriter = SymbolicUtils.Postwalk( res_merge_exprs = map(optm_rewriter, merge_exprs) -final_exprs = SymbolicUtils.Code.toexpr.(res_merge_exprs) -map(x -> x.args[1] = :(==), final_exprs) - -to_decapode = quote - C::Form0 - D::Constant - Ċ::Form0 -end - -append!(to_decapode.args, final_exprs) - -test = parse_decapode(to_decapode) -deca_test = SummationDecapode(test) +deca_test = to_acset(Heat, res_merge_exprs) infer_types!(deca_test) resolve_overloads!(deca_test) From e0ff9a873d866945ee30415db558e39368fb16c0 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Mon, 16 Sep 2024 14:03:16 -0400 Subject: [PATCH 12/39] Added DECQuantity types Also switched to using SymbolicsUtils' `substitute`. Still needs tests and code needs to be cleaned up. --- src/acset2symbolic.jl | 97 ++++++++++++++++++------------------------ src/graph_traversal.jl | 2 +- 2 files changed, 43 insertions(+), 56 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index f84b3f2..802d182 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -1,4 +1,5 @@ using DiagrammaticEquations +using ACSets using SymbolicUtils using SymbolicUtils.Rewriters using SymbolicUtils.Code @@ -8,30 +9,37 @@ export extract_symexprs, apply_rewrites, merge_equations, to_acset const DECA_EQUALITY_SYMBOL = (==) -to_symbolics(d::SummationDecapode, node::TraversalNode) = to_symbolics(d, node.index, Val(node.name)) +to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, node::TraversalNode) = to_symbolics(d, symvar_lookup, node.index, Val(node.name)) -function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op1}) - input_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :src], :name]) - output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :tgt], :name]) - op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1]) +function symbolics_lookup(d::SummationDecapode) + lookup = Dict{Symbol, SymbolicUtils.BasicSymbolic}() + for i in parts(d, :Var) + push!(lookup, d[i, :name] => decavar_to_symbolics(d, i)) + end + lookup +end + +function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I) + var = d[index, :name] + new_type = symtype(Deca.DECQuantity, d[index, :type], space) + SymbolicUtils.Sym{new_type}(var) +end - # input_sym = setmetadata(input_sym, Sort, oldtype_to_new(d[d[op_index, :src], :type])) - # output_sym = setmetadata(output_sym, Sort, oldtype_to_new(d[d[op_index, :tgt], :type])) +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Op1}) + input_sym = symvar_lookup[d[d[op_index, :src], :name]] + output_sym = symvar_lookup[d[d[op_index, :tgt], :name]] + op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1]) rhs = SymbolicUtils.Term{Number}(op_sym, [input_sym]) SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) end -function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op2}) - input1_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :proj1], :name]) - input2_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :proj2], :name]) - output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :res], :name]) +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Op2}) + input1_sym = symvar_lookup[d[d[op_index, :proj1], :name]] + input2_sym = symvar_lookup[d[d[op_index, :proj2], :name]] + output_sym = symvar_lookup[d[d[op_index, :res], :name]] op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2]) - # input1_sym = setmetadata(input1_sym, Sort, oldtype_to_new(d[d[op_index, :proj1], :type])) - # input2_sym = setmetadata(input2_sym, Sort, oldtype_to_new(d[d[op_index, :proj2], :type])) - # output_sym = setmetadata(output_sym, Sort, oldtype_to_new(d[d[op_index, :res], :type])) - rhs = SymbolicUtils.Term{Number}(op_sym, [input1_sym, input2_sym]) SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) end @@ -41,27 +49,23 @@ end # Expr(EQUALITY_SYMBOL, c.output, Expr(:call, Expr(:., :+), c.inputs...)) # end -# function oldtype_to_new(old::Symbol, space::Space = Space(:I, 2))::Sort -# @match old begin -# :Form0 => PrimalForm(0, space) -# :Form1 => PrimalForm(1, space) -# :Form2 => PrimalForm(2, space) +function symbolic_rewriting(old_d::SummationDecapode) + d = deepcopy(old_d) + + infer_types!(d) + resolve_overloads!(d) -# :DualForm0 => DualForm(0, space) -# :DualForm1 => DualForm(1, space) -# :DualForm2 => DualForm(2, space) + symvar_lookup = symbolics_lookup(d) -# :Constant => Scalar() -# :Parameter => Scalar() -# end -# end + symexprs = extract_symexprs(d, symvar_lookup) +end -function extract_symexprs(d::SummationDecapode) +function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}) topo_list = topological_sort_edges(d) sym_list = [] for node in topo_list retrieve_name(d, node) != DerivOp || continue - push!(sym_list, to_symbolics(d, node)) + push!(sym_list, to_symbolics(d, symvar_lookup, node)) end sym_list end @@ -78,52 +82,35 @@ function apply_rewrites(d::SummationDecapode, rewriter) rewritten_list end -function merge_equations(d::SummationDecapode, rewritten_syms) +function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, rewritten_syms) - lookup = Dict() + eqn_lookup = Dict() final_list = [] for node in start_nodes(d) - sym = SymbolicUtils.Sym{Number}(d[node, :name]) - # sym = setmetadata(sym, Sort, oldtype_to_new(d[node, :type])) - push!(lookup, (sym => sym)) + sym = symvar_lookup[d[node, :name]] + push!(eqn_lookup, (sym => sym)) end final_nodes = infer_terminal_names(d) for expr in rewritten_syms - lhs = expr_lhs(expr) - rhs = expr_rhs(expr) - recursive_descent(rhs, lookup) + merged_eqn = SymbolicUtils.substitute(expr, eqn_lookup) + lhs = merged_eqn.arguments[1] + rhs = merged_eqn.arguments[2] - push!(lookup, (lhs => rhs)) + push!(eqn_lookup, (lhs => rhs)) if lhs.name in final_nodes - push!(final_list, expr) + push!(final_list, merged_eqn) end end final_list end -expr_lhs(expr) = expr.arguments[1] -expr_rhs(expr) = expr.arguments[2] - -function recursive_descent(expr, lookup) - # @show expr - for (i, arg) in enumerate(expr.arguments) - # @show arg - if arg in keys(lookup) - expr.arguments[i] = lookup[arg] - else - recursive_descent(arg, lookup) - end - end - return expr -end - function to_acset(og_d, sym_exprs) final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) map(x -> x.args[1] = :(==), final_exprs) diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index 1fd78e2..43b8860 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -67,6 +67,6 @@ function retrieve_name(d::SummationDecapode, tsr::TraversalNode) :Op1 => d[tsr.index, :op1] :Op2 => d[tsr.index, :op2] :Σ => :+ - _ => error("$(tsr.name) is not a valid table for names") + _ => error("$(tsr.name) is a table without names") end end From b9b4146ebf677c5c6a8cd503380463ae2d4f8ed3 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Wed, 18 Sep 2024 13:42:32 -0400 Subject: [PATCH 13/39] Completed pipeline again Addition now works as well but rewriting seems to be janky, unrelated to this pipeline specifically I believe. --- src/acset2symbolic.jl | 55 +++++++++++++++++++++++++++---------- src/deca/ThDEC.jl | 2 ++ src/sym_rewrite.jl | 64 +++++++++++++++++++++++++++++-------------- 3 files changed, 87 insertions(+), 34 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index 802d182..a813a3b 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -5,7 +5,7 @@ using SymbolicUtils.Rewriters using SymbolicUtils.Code using MLStyle -export extract_symexprs, apply_rewrites, merge_equations, to_acset +export extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting, symbolics_lookup const DECA_EQUALITY_SYMBOL = (==) @@ -21,14 +21,19 @@ end function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I) var = d[index, :name] - new_type = symtype(Deca.DECQuantity, d[index, :type], space) + new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[index, :type], space) + @info new_type SymbolicUtils.Sym{new_type}(var) end function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Op1}) input_sym = symvar_lookup[d[d[op_index, :src], :name]] output_sym = symvar_lookup[d[d[op_index, :tgt], :name]] - op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1]) + # op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1]) + + op_sym = getfield(@__MODULE__, d[op_index, :op1]) + + @info typeof(op_sym) rhs = SymbolicUtils.Term{Number}(op_sym, [input_sym]) SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) @@ -38,26 +43,31 @@ function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, Symbolic input1_sym = symvar_lookup[d[d[op_index, :proj1], :name]] input2_sym = symvar_lookup[d[d[op_index, :proj2], :name]] output_sym = symvar_lookup[d[d[op_index, :res], :name]] - op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2]) + # op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2]) + + op_sym = getfield(@__MODULE__, d[op_index, :op2]) rhs = SymbolicUtils.Term{Number}(op_sym, [input1_sym, input2_sym]) SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) end #XXX: Always converting + -> .+ here since summation doesn't store the style of addition -# function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Σ}) -# Expr(EQUALITY_SYMBOL, c.output, Expr(:call, Expr(:., :+), c.inputs...)) -# end +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Σ}) + syms_array = [symvar_lookup[var] for var in d[d[incident(d, op_index, :summation), :summand], :name]] + output_sym = symvar_lookup[d[d[op_index, :sum], :name]] + + rhs = SymbolicUtils.Term{Number}(+, syms_array) + SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) +end function symbolic_rewriting(old_d::SummationDecapode) d = deepcopy(old_d) infer_types!(d) - resolve_overloads!(d) + # resolve_overloads!(d) symvar_lookup = symbolics_lookup(d) - - symexprs = extract_symexprs(d, symvar_lookup) + merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup)) end function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}) @@ -70,10 +80,10 @@ function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, Symb sym_list end -function apply_rewrites(d::SummationDecapode, rewriter) +function apply_rewrites(symexprs, rewriter) rewritten_list = [] - for sym in extract_symexprs(d) + for sym in symexprs res_sym = rewriter(sym) rewritten_sym = isnothing(res_sym) ? sym : res_sym push!(rewritten_list, rewritten_sym) @@ -113,7 +123,18 @@ end function to_acset(og_d, sym_exprs) final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) - map(x -> x.args[1] = :(==), final_exprs) + + recursive_descent = @λ begin + e::Expr => begin + if e.head == :call + @show nameof(e.args[1]) + e.args[1] = nameof(e.args[1]) + map(recursive_descent, e.args[2:end]) + end + end + sym => nothing + end + map(recursive_descent, final_exprs) deca_block = quote end @@ -124,12 +145,18 @@ function to_acset(og_d, sym_exprs) append!(deca_block.args, map(deca_type_gen, vcat(states, terminals))) + for op1 in parts(og_d, :Op1) + if og_d[op1, :op1] == DerivOp + push!(deca_block.args, :($(og_d[og_d[op1, :tgt], :name]) == $DerivOp($(og_d[og_d[op1, :src], :name])))) + end + end + append!(deca_block.args, final_exprs) d = SummationDecapode(parse_decapode(deca_block)) infer_types!(d) - resolve_overloads!(d) + # resolve_overloads!(d) d end diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 383274e..84186f8 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -178,6 +178,8 @@ end @rule Δ(~x::isForm1) => ★(d(★(d(~x)))) + d(★(d(★(~x)))) end +@alias (Δ₀, Δ₁, Δ₂) => Δ + @operator +(S1, S2)::DECQuantity begin @match (S1, S2) begin (PatScalar(_), PatScalar(_)) => Scalar diff --git a/src/sym_rewrite.jl b/src/sym_rewrite.jl index de21b89..a4b6f2c 100644 --- a/src/sym_rewrite.jl +++ b/src/sym_rewrite.jl @@ -4,47 +4,71 @@ using MLStyle Heat = @decapode begin C::Form0 + G::Form1 D::Constant - ∂ₜ(C) == D*Δ(d(C)) + ∂ₜ(G) == D*Δ(d(C)) end infer_types!(Heat) -resolve_overloads!(Heat) -@syms Δ(x) d(x) ⋆(x) Δ₀(x) Δ₁(x) Δ₂(x) d₀(x) +Brusselator = @decapode begin + (U, V)::Form0 + U2V::Form0 + (U̇, V̇)::Form0 -lap_0_convert = @rule Δ₀(~x) => Δ(~x) -lap_1_convert = @rule Δ₁(~x) => Δ(~x) -lap_2_convert = @rule Δ₂(~x) => Δ(~x) + (α)::Constant + F::Parameter -d_0_convert = @rule d₀(~x) => d(~x) + U2V == (U .* U) .* V -overloaders = [lap_0_convert, lap_1_convert, lap_2_convert, d_0_convert] + U̇ == 1 + U2V - (4.4 * U) + (α * Δ(U)) + F + V̇ == (3.4 * U) - U2V + (α * Δ(V)) + ∂ₜ(U) == U̇ + ∂ₜ(V) == V̇ +end +infer_types!(Brusselator) + +Phytodynamics = @decapode begin + (n,w)::Form0 + m::Constant + ∂ₜ(n) == w + m*n + Δ(n) +end +infer_types!(Phytodynamics) +test = to_acset(Phytodynamics, symbolic_rewriting(Phytodynamics)) + +# resolve_overloads!(Heat) + +# lap_0_convert = @rule Δ₀(~x) => Δ(~x) +# lap_1_convert = @rule Δ₁(~x) => Δ(~x) +# lap_2_convert = @rule Δ₂(~x) => Δ(~x) + +# d_0_convert = @rule d₀(~x) => d(~x) + +# overloaders = [lap_0_convert, lap_1_convert, lap_2_convert, d_0_convert] -lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x)))) -lap_1_rule = @rule Δ(~x) => d(⋆(d(⋆(~x)))) + ⋆(d(⋆(d(~x)))) -lap_2_rule = @rule Δ(~x) => d(⋆(d(⋆(~x)))) +# lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x)))) +# lap_1_rule = @rule Δ(~x) => d(⋆(d(⋆(~x)))) + ⋆(d(⋆(d(~x)))) +# lap_2_rule = @rule Δ(~x) => d(⋆(d(⋆(~x)))) -openers = [lap_0_rule, lap_1_rule, lap_2_rule] +# openers = [lap_0_rule, lap_1_rule, lap_2_rule] -heat_exprs = extract_symexprs(Heat) +r = rules(Δ, Val(1)) -rewriter = SymbolicUtils.Postwalk( - SymbolicUtils.Fixpoint(SymbolicUtils.Chain(vcat(overloaders, openers)))) +heat_exprs = symbolic_rewriting(Heat) -res_exprs = apply_rewrites(Heat, rewriter) +rewriter = SymbolicUtils.Fixpoint(SymbolicUtils.Prewalk(test_rule)) -merge_exprs = merge_equations(Heat, res_exprs) +res_exprs = apply_rewrites(heat_exprs, rewriter) optm_dd_0 = @rule d(d(~x)) => 0 -star_0 = @rule ⋆(0) => 0 +star_0 = @rule ★(0) => 0 d_0 = @rule d(0) => 0 optm_rewriter = SymbolicUtils.Postwalk( SymbolicUtils.Fixpoint(SymbolicUtils.Chain([optm_dd_0, star_0, d_0]))) -res_merge_exprs = map(optm_rewriter, merge_exprs) +res_merge_exprs = map(optm_rewriter, res_exprs) -deca_test = to_acset(Heat, res_merge_exprs) +deca_test = to_acset(Heat, res_exprs) infer_types!(deca_test) resolve_overloads!(deca_test) From 87f65fe6de1eb18e6e8d89a60259ba58e38a2a19 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 18 Sep 2024 16:23:46 -0400 Subject: [PATCH 14/39] fixed bug where type-checking subtraction uses +(S1,S2), which is obsolete --- src/deca/ThDEC.jl | 4 +++- test/decasymbolic.jl | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 84186f8..00813ff 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -195,7 +195,9 @@ end end end -@operator -(S1, S2)::DECQuantity begin +(S1, S2) end +@operator -(S1, S2)::DECQuantity begin + promote_symtype(+, S1, S2) +end @operator *(S1, S2)::DECQuantity begin @match (S1, S2) begin diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index 0b5b7e9..bd82387 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -13,7 +13,6 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} ϕ, ψ = @syms ϕ::PrimalVF{:X, 2} ψ::DualVF{:X, 2} # TODO would be nice to pass the space globally to avoid duplication - @testset "Term Construction" begin @test symtype(a) == Scalar @@ -44,6 +43,8 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} @test promote_symtype(+, a, b) == Scalar @test promote_symtype(∧, u, u) == PrimalForm{0, :X, 2} @test promote_symtype(∧, u, ω) == PrimalForm{1, :X, 2} + @test promote_symtype(-, a) == Scalar + @test promote_symtype(-, u, u) == PrimalForm{0, :X, 2} # test composition @test promote_symtype(d ∘ d, u) == PrimalForm{2, :X, 2} From 90b1adc6f67c085a359df66e610e80dc167868f8 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 19 Sep 2024 17:29:29 -0400 Subject: [PATCH 15/39] George and I debugged rewriting. Incorrect type passed to resulting term meant typed rewriting would fail --- src/acset2symbolic.jl | 13 ++++++---- src/sym_rewrite.jl | 56 +++++++++++++++++++++++++++++++++++++++---- 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index a813a3b..0bf24fd 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -33,12 +33,12 @@ function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, Symbolic op_sym = getfield(@__MODULE__, d[op_index, :op1]) - @info typeof(op_sym) - - rhs = SymbolicUtils.Term{Number}(op_sym, [input_sym]) + S = promote_symtype(op_sym, input_sym) + rhs = SymbolicUtils.Term{S}(op_sym, [input_sym]) SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) end +# TODO add promote_symtype as Op1 function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Op2}) input1_sym = symvar_lookup[d[d[op_index, :proj1], :name]] input2_sym = symvar_lookup[d[d[op_index, :proj2], :name]] @@ -47,7 +47,8 @@ function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, Symbolic op_sym = getfield(@__MODULE__, d[op_index, :op2]) - rhs = SymbolicUtils.Term{Number}(op_sym, [input1_sym, input2_sym]) + S = promote_symtype(op_sym, input1_sym, input2_sym) + rhs = SymbolicUtils.Term{S}(op_sym, [input1_sym, input2_sym]) SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) end @@ -56,7 +57,9 @@ function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, Symbolic syms_array = [symvar_lookup[var] for var in d[d[incident(d, op_index, :summation), :summand], :name]] output_sym = symvar_lookup[d[d[op_index, :sum], :name]] - rhs = SymbolicUtils.Term{Number}(+, syms_array) + # TODO pls test + S = promote_symtype(+, syms_array...) + rhs = SymbolicUtils.Term{S}(+, syms_array) SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) end diff --git a/src/sym_rewrite.jl b/src/sym_rewrite.jl index a4b6f2c..73f4cd9 100644 --- a/src/sym_rewrite.jl +++ b/src/sym_rewrite.jl @@ -1,5 +1,7 @@ +using Test using DiagrammaticEquations using SymbolicUtils +using SymbolicUtils: Fixpoint, Prewalk, Postwalk, Chain, symtype, promote_symtype using MLStyle Heat = @decapode begin @@ -7,8 +9,7 @@ Heat = @decapode begin G::Form1 D::Constant ∂ₜ(G) == D*Δ(d(C)) -end - +end; infer_types!(Heat) Brusselator = @decapode begin @@ -52,13 +53,60 @@ test = to_acset(Phytodynamics, symbolic_rewriting(Phytodynamics)) # openers = [lap_0_rule, lap_1_rule, lap_2_rule] -r = rules(Δ, Val(1)) +# it seems that type-instability or improper type promotion is happening. expressions derived from this have BasicSymbolic{Number} type, which means we can't conditionally rewrite on forms. heat_exprs = symbolic_rewriting(Heat) +sub = heat_exprs[1].arguments[2].arguments[2] + +a, b = @syms a::Scalar b::Scalar +u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} + +r = rules(Δ, Val(1)) + +# rule without predication works +R = @rule Δ(~x) => ★(d(★(d(~x)))) +rwR = Fixpoint(Prewalk(Chain([R]))) + +R(Δ(d(u))) + +# since promote_symtype(d(u)) returns Any while promote_symtype(d, u). I wonder +# if `d(u)` is not subjected to `symtype` + +Rp = @rule Δ(~x::isForm1) => "Success" +Rp(Δ(v)) # works +Rp(Δ(d(u))) # works + +Rp1 = @rule Δ(~x::isForm1) => ★(d(★(d(~x)))) + +Rp1(Δ(v)) # works +Rp1(Δ(d(u))) # works +rwRp1 = Fixpoint(Prewalk(Chain([Rp1]))) +rwRp1(Δ(d(u))) + +rwr = Fixpoint(Prewalk(Chain(r))) +rwr(heat_exprs[1]) # THIS WORKS! + +rwr(Δ(d(u))) # rwr +rwr(heat_exprs[1].arguments[2]) + +r[2](Δ(d(u))) # works + + +# rwR(heat_exprs[1]) +# rwR(sub) + +# tilde? +R1 = @rule Δ(~~x::(x->isForm1(x))) => ★(d(★(d(~x)))) + +@macroexpand @rule Δ(~x::isForm1) => "Success" + +# pulling out the subexpression +rewriter = SymbolicUtils.Fixpoint(SymbolicUtils.Prewalk(SymbolicUtils.Chain(r))) + -rewriter = SymbolicUtils.Fixpoint(SymbolicUtils.Prewalk(test_rule)) res_exprs = apply_rewrites(heat_exprs, rewriter) +sub_exprs = apply_rewrites([sub], rewriter) optm_dd_0 = @rule d(d(~x)) => 0 star_0 = @rule ★(0) => 0 From d4427b13760c9d7a6a67b8d1f95bee4ae7552f1a Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 20 Sep 2024 16:46:43 -0400 Subject: [PATCH 16/39] Cleaning up pipeline This black boxes the intermediate symbolic expressions to the user. The user will simply submit a rewriter that will then be applied --- src/SymbolicUtilsInterop.jl | 1 + src/acset2symbolic.jl | 87 +++++++++++++++++++------------------ src/graph_traversal.jl | 8 ++-- src/sym_rewrite.jl | 29 ++++--------- 4 files changed, 58 insertions(+), 67 deletions(-) diff --git a/src/SymbolicUtilsInterop.jl b/src/SymbolicUtilsInterop.jl index 61502a8..28bde3f 100644 --- a/src/SymbolicUtilsInterop.jl +++ b/src/SymbolicUtilsInterop.jl @@ -14,6 +14,7 @@ struct SymbolicEquation{E} lhs::E rhs::E end +export SymbolicEquation Base.show(io::IO, e::SymbolicEquation) = begin print(io, e.lhs); print(io, " == "); print(io, e.rhs) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index 0bf24fd..e5edd1c 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -5,14 +5,16 @@ using SymbolicUtils.Rewriters using SymbolicUtils.Code using MLStyle -export extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting, symbolics_lookup +import SymbolicUtils: BasicSymbolic, Symbolic + +export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting, symbolics_lookup const DECA_EQUALITY_SYMBOL = (==) -to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, node::TraversalNode) = to_symbolics(d, symvar_lookup, node.index, Val(node.name)) +to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, node::TraversalNode) = to_symbolics(d, symvar_lookup, node.index, Val(node.name)) function symbolics_lookup(d::SummationDecapode) - lookup = Dict{Symbol, SymbolicUtils.BasicSymbolic}() + lookup = Dict{Symbol, BasicSymbolic}() for i in parts(d, :Var) push!(lookup, d[i, :name] => decavar_to_symbolics(d, i)) end @@ -22,80 +24,69 @@ end function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I) var = d[index, :name] new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[index, :type], space) - @info new_type + SymbolicUtils.Sym{new_type}(var) end -function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Op1}) +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Op1}) input_sym = symvar_lookup[d[d[op_index, :src], :name]] output_sym = symvar_lookup[d[d[op_index, :tgt], :name]] - # op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1]) op_sym = getfield(@__MODULE__, d[op_index, :op1]) S = promote_symtype(op_sym, input_sym) rhs = SymbolicUtils.Term{S}(op_sym, [input_sym]) - SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) + SymbolicEquation{Symbolic}(output_sym, rhs) end -# TODO add promote_symtype as Op1 -function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Op2}) +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Op2}) input1_sym = symvar_lookup[d[d[op_index, :proj1], :name]] input2_sym = symvar_lookup[d[d[op_index, :proj2], :name]] output_sym = symvar_lookup[d[d[op_index, :res], :name]] - # op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2]) op_sym = getfield(@__MODULE__, d[op_index, :op2]) S = promote_symtype(op_sym, input1_sym, input2_sym) rhs = SymbolicUtils.Term{S}(op_sym, [input1_sym, input2_sym]) - SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) + SymbolicEquation{Symbolic}(output_sym, rhs) end #XXX: Always converting + -> .+ here since summation doesn't store the style of addition -function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Σ}) +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Σ}) syms_array = [symvar_lookup[var] for var in d[d[incident(d, op_index, :summation), :summand], :name]] output_sym = symvar_lookup[d[d[op_index, :sum], :name]] # TODO pls test S = promote_symtype(+, syms_array...) rhs = SymbolicUtils.Term{S}(+, syms_array) - SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) + SymbolicEquation{Symbolic}(output_sym,rhs) end -function symbolic_rewriting(old_d::SummationDecapode) +function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing) d = deepcopy(old_d) infer_types!(d) - # resolve_overloads!(d) symvar_lookup = symbolics_lookup(d) - merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup)) -end + eqns = merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup)) -function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}) - topo_list = topological_sort_edges(d) - sym_list = [] - for node in topo_list - retrieve_name(d, node) != DerivOp || continue - push!(sym_list, to_symbolics(d, symvar_lookup, node)) + if !isnothing(rewriter) + eqns = map(rewriter, eqns) end - sym_list -end -function apply_rewrites(symexprs, rewriter) + to_acset(d, eqns) +end - rewritten_list = [] - for sym in symexprs - res_sym = rewriter(sym) - rewritten_sym = isnothing(res_sym) ? sym : res_sym - push!(rewritten_list, rewritten_sym) +function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}) + sym_list = SymbolicEquation{Symbolic}[] + for node in topological_sort_edges(d) + retrieve_name(d, node) != DerivOp || continue # This is not part of ThDEC + push!(sym_list, to_symbolics(d, symvar_lookup, node)) end - - rewritten_list + sym_list end -function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, rewritten_syms) +function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, symexpr_list::Vector{SymbolicEquation{Symbolic}}) eqn_lookup = Dict() @@ -108,22 +99,35 @@ function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, Symbo final_nodes = infer_terminal_names(d) - for expr in rewritten_syms + for expr in symexpr_list - merged_eqn = SymbolicUtils.substitute(expr, eqn_lookup) - lhs = merged_eqn.arguments[1] - rhs = merged_eqn.arguments[2] + merged_rhs = SymbolicUtils.substitute(expr.rhs, eqn_lookup) - push!(eqn_lookup, (lhs => rhs)) + push!(eqn_lookup, (expr.lhs => merged_rhs)) - if lhs.name in final_nodes - push!(final_list, merged_eqn) + if expr.lhs.name in final_nodes + push!(final_list, formed_deca_eqn(expr.lhs, merged_rhs)) end end final_list end +formed_deca_eqn(lhs, rhs) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [lhs, rhs]) + +function apply_rewrites(symexprs, rewriter) + + rewritten_list = [] + for sym in symexprs + res_sym = rewriter(sym) + rewritten_sym = isnothing(res_sym) ? sym : res_sym + push!(rewritten_list, rewritten_sym) + end + + rewritten_list +end + + function to_acset(og_d, sym_exprs) final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) @@ -159,7 +163,6 @@ function to_acset(og_d, sym_exprs) d = SummationDecapode(parse_decapode(deca_block)) infer_types!(d) - # resolve_overloads!(d) d end diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index 43b8860..f2875b2 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -1,7 +1,7 @@ using DiagrammaticEquations using ACSets -export TraversalNode, topological_sort_edges, number_of_ops, retrieve_name, start_nodes +export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes struct TraversalNode{T} index::Int @@ -20,7 +20,7 @@ function topological_sort_edges(d::SummationDecapode) # FIXME: this is a quadratic implementation of topological_sort inlined in here. op_order = TraversalNode{Symbol}[] - for _ in 1:number_of_ops(d) + for _ in 1:n_ops(d) for op in parts(d, :Op1) if !visited_1[op] && visited_Var[d[op, :src]] @@ -49,12 +49,12 @@ function topological_sort_edges(d::SummationDecapode) end end - @assert length(op_order) == number_of_ops(d) + @assert length(op_order) == n_ops(d) op_order end -function number_of_ops(d::SummationDecapode) +function n_ops(d::SummationDecapode) return nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) end diff --git a/src/sym_rewrite.jl b/src/sym_rewrite.jl index 73f4cd9..cec55de 100644 --- a/src/sym_rewrite.jl +++ b/src/sym_rewrite.jl @@ -8,9 +8,15 @@ Heat = @decapode begin C::Form0 G::Form1 D::Constant - ∂ₜ(G) == D*Δ(d(C)) + ∂ₜ(G) == D*Δ(C) end; infer_types!(Heat) +test_heat_same = symbolic_rewriting(Heat) + +r = rules(Δ, Val(1)) + +rwr = Fixpoint(Prewalk(Chain(r))) +test_heat_open = symbolic_rewriting(Heat, rwr) Brusselator = @decapode begin (U, V)::Form0 @@ -35,24 +41,7 @@ Phytodynamics = @decapode begin ∂ₜ(n) == w + m*n + Δ(n) end infer_types!(Phytodynamics) -test = to_acset(Phytodynamics, symbolic_rewriting(Phytodynamics)) - -# resolve_overloads!(Heat) - -# lap_0_convert = @rule Δ₀(~x) => Δ(~x) -# lap_1_convert = @rule Δ₁(~x) => Δ(~x) -# lap_2_convert = @rule Δ₂(~x) => Δ(~x) - -# d_0_convert = @rule d₀(~x) => d(~x) - -# overloaders = [lap_0_convert, lap_1_convert, lap_2_convert, d_0_convert] - -# lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x)))) -# lap_1_rule = @rule Δ(~x) => d(⋆(d(⋆(~x)))) + ⋆(d(⋆(d(~x)))) -# lap_2_rule = @rule Δ(~x) => d(⋆(d(⋆(~x)))) - -# openers = [lap_0_rule, lap_1_rule, lap_2_rule] - +test_phy = symbolic_rewriting(Phytodynamics) # it seems that type-instability or improper type promotion is happening. expressions derived from this have BasicSymbolic{Number} type, which means we can't conditionally rewrite on forms. heat_exprs = symbolic_rewriting(Heat) @@ -103,8 +92,6 @@ R1 = @rule Δ(~~x::(x->isForm1(x))) => ★(d(★(d(~x)))) # pulling out the subexpression rewriter = SymbolicUtils.Fixpoint(SymbolicUtils.Prewalk(SymbolicUtils.Chain(r))) - - res_exprs = apply_rewrites(heat_exprs, rewriter) sub_exprs = apply_rewrites([sub], rewriter) From 6a3877f07dce56711ae9a77d3551b1e586f6b349 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 20 Sep 2024 17:04:55 -0400 Subject: [PATCH 17/39] Fixed order of inclusions --- src/DiagrammaticEquations.jl | 3 ++- test/graph_traversal.jl | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index 72cada6..558a839 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -69,9 +69,10 @@ include("graph_traversal.jl") include("deca/Deca.jl") include("learn/Learn.jl") include("SymbolicUtilsInterop.jl") -include("acset2symbolic.jl") @reexport using .Deca @reexport using .SymbolicUtilsInterop +include("acset2symbolic.jl") + end diff --git a/test/graph_traversal.jl b/test/graph_traversal.jl index 7a5d6ac..b09fd24 100644 --- a/test/graph_traversal.jl +++ b/test/graph_traversal.jl @@ -4,7 +4,7 @@ using MLStyle using Test function is_correct_length(d::SummationDecapode, result) - return length(result) == number_of_ops(d) + return length(result) == n_ops(d) end @testset "Topological Sort on Edges" begin From ea2d8c0be52cc8d4c93de68e2ba195477fcce0a6 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 23 Sep 2024 12:52:56 -0400 Subject: [PATCH 18/39] adding support for Parameters and Constants --- src/deca/ThDEC.jl | 18 +++++++++++++----- test/decasymbolic.jl | 11 ++++++++++- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 00813ff..57aab47 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -24,8 +24,12 @@ export DECQuantity # this ensures symtype doesn't recurse endlessly SymbolicUtils.symtype(::Type{S}) where S<:DECQuantity = S -struct Scalar <: DECQuantity end -export Scalar +abstract type AbstractScalar <: DECQuantity end + +struct Scalar <: AbstractScalar end +struct Parameter <: AbstractScalar end +struct ConstScalar <: AbstractScalar end +export Scalar, Parameter, ConstScalar struct FormParams dim::Int @@ -107,7 +111,7 @@ end export PatFormDim @active PatScalar(T) begin - if T <: Scalar + if T <: AbstractScalar Some(T) end end @@ -225,7 +229,9 @@ abstract type SortError <: Exception end # struct WedgeDimError <: SortError end -Base.nameof(s::Scalar) = :Constant +Base.nameof(s::ConstScalar) = :ConstScalar +Base.nameof(s::Parameter) = :Parameter +Base.nameof(s::Scalar) = :Scalar function Base.nameof(f::Form; with_dim_parameter=false) dual = isdual(f) ? "Dual" : "" @@ -269,7 +275,9 @@ end function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol) @match qty begin - :Scalar || :Constant => Scalar + :Scalar => Scalar + :ConstScalar => ConstScalar + :Parameter => Parameter :Form0 => PrimalForm{0, space, 1} :Form1 => PrimalForm{1, space, 1} :Form2 => PrimalForm{2, space, 1} diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index bd82387..3a0cb92 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -7,6 +7,7 @@ using SymbolicUtils: symtype, promote_symtype, Symbolic using MLStyle # load up some variable variables and expressions +c, t = @syms c::ConstScalar t::Parameter a, b = @syms a::Scalar b::Scalar u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} ω, η = @syms ω::PrimalForm{1, :X, 2} η::DualForm{2, :X, 2} @@ -14,7 +15,9 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} # TODO would be nice to pass the space globally to avoid duplication @testset "Term Construction" begin - + + @test symtype(c) == ConstScalar + @test symtype(t) == Parameter @test symtype(a) == Scalar @test symtype(u) == PrimalForm{0, :X, 2} @test symtype(ω) == PrimalForm{1, :X, 2} @@ -22,6 +25,10 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} @test symtype(ϕ) == PrimalVF{:X, 2} @test symtype(ψ) == DualVF{:X, 2} + @test symtype(c + t) == Scalar + @test symtype(t + t) == Scalar + @test symtype(c + c) == Scalar + @test symtype(u ∧ ω) == PrimalForm{1, :X, 2} @test symtype(ω ∧ ω) == PrimalForm{2, :X, 2} # @test_throws ThDEC.SortError ThDEC.♯(u) @@ -30,6 +37,8 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} # test unary operator conversion to decaexpr @test Term(1) == Lit(Symbol("1")) @test Term(a) == Var(:a) + @test Term(c) == Var(:c) + @test Term(t) == Var(:t) @test Term(∂ₜ(u)) == Tan(Var(:u)) @test Term(★(ω)) == App1(:★₁, Var(:ω)) From 9bb62697c9b87b7c39affd125c058c1aae552804 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Mon, 23 Sep 2024 14:18:18 -0400 Subject: [PATCH 19/39] Added tests for acset2symbolic --- src/acset2symbolic.jl | 3 +- src/deca/ThDEC.jl | 2 +- src/deca/deca_acset.jl | 51 +++++++------ src/sym_rewrite.jl | 20 ----- test/Project.toml | 1 + test/acset2symbolic.jl | 165 +++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 5 +- 7 files changed, 198 insertions(+), 49 deletions(-) create mode 100644 test/acset2symbolic.jl diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index e5edd1c..078d08e 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -129,12 +129,13 @@ end function to_acset(og_d, sym_exprs) + + #TODO: This step is breaking up summations final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) recursive_descent = @λ begin e::Expr => begin if e.head == :call - @show nameof(e.args[1]) e.args[1] = nameof(e.args[1]) map(recursive_descent, e.args[2:end]) end diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 57aab47..518b4f9 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -276,7 +276,7 @@ end function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol) @match qty begin :Scalar => Scalar - :ConstScalar => ConstScalar + :Constant => ConstScalar :Parameter => Parameter :Form0 => PrimalForm{0, space, 1} :Form1 => PrimalForm{1, space, 1} diff --git a/src/deca/deca_acset.jl b/src/deca/deca_acset.jl index 4334c4b..e0f5580 100644 --- a/src/deca/deca_acset.jl +++ b/src/deca/deca_acset.jl @@ -5,7 +5,7 @@ using ..DiagrammaticEquations These are the default rules used to do type inference in the 1D exterior calculus. """ op1_inf_rules_1D = [ - # Rules for ∂ₜ + # Rules for ∂ₜ (src_type = :Form0, tgt_type = :Form0, op_names = [:∂ₜ,:dt]), (src_type = :Form1, tgt_type = :Form1, op_names = [:∂ₜ,:dt]), @@ -14,10 +14,10 @@ op1_inf_rules_1D = [ (src_type = :DualForm0, tgt_type = :DualForm1, op_names = [:d, :dual_d₀, :d̃₀]), # Rules for ⋆ - (src_type = :Form0, tgt_type = :DualForm1, op_names = [:⋆, :⋆₀, :star]), - (src_type = :Form1, tgt_type = :DualForm0, op_names = [:⋆, :⋆₁, :star]), - (src_type = :DualForm1, tgt_type = :Form0, op_names = [:⋆, :⋆₀⁻¹, :star_inv]), - (src_type = :DualForm0, tgt_type = :Form1, op_names = [:⋆, :⋆₁⁻¹, :star_inv]), + (src_type = :Form0, tgt_type = :DualForm1, op_names = [:★, :⋆, :⋆₀, :star]), + (src_type = :Form1, tgt_type = :DualForm0, op_names = [:★, :⋆, :⋆₁, :star]), + (src_type = :DualForm1, tgt_type = :Form0, op_names = [:★, :⋆, :⋆₀⁻¹, :star_inv]), + (src_type = :DualForm0, tgt_type = :Form1, op_names = [:★, :⋆, :⋆₁⁻¹, :star_inv]), # Rules for Δ (src_type = :Form0, tgt_type = :Form0, op_names = [:Δ, :Δ₀, :lapl]), @@ -29,7 +29,7 @@ op1_inf_rules_1D = [ # Rules for negation (src_type = :Form0, tgt_type = :Form0, op_names = [:neg, :(-)]), (src_type = :Form1, tgt_type = :Form1, op_names = [:neg, :(-)]), - + # Rules for the averaging operator (src_type = :Form0, tgt_type = :Form1, op_names = [:avg₀₁, :avg_01]), @@ -39,7 +39,7 @@ op1_inf_rules_1D = [ # Rules for ♭. (src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]), - + # Rules for magnitude/ norm (src_type = :PVF, tgt_type = :Form0, op_names = [:mag, :norm]), (src_type = :DVF, tgt_type = :DualForm0, op_names = [:mag, :norm])] @@ -52,7 +52,7 @@ op2_inf_rules_1D = [ # Rules for L₀, L₁ (proj1_type = :Form1, proj2_type = :DualForm0, res_type = :DualForm0, op_names = [:L, :L₀]), - (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:L, :L₁]), + (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:L, :L₁]), # Rules for i₁ (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm0, op_names = [:i, :i₁]), @@ -69,10 +69,10 @@ op2_inf_rules_1D = [ (proj1_type = :Form0, proj2_type = :Parameter, res_type = :Form0, op_names = [:/, :./, :*, :.*]), (proj1_type = :Form1, proj2_type = :Parameter, res_type = :Form1, op_names = [:/, :./, :*, :.*]), (proj1_type = :Form2, proj2_type = :Parameter, res_type = :Form2, op_names = [:/, :./, :*, :.*]),=# - + (proj1_type = :Form0, proj2_type = :Literal, res_type = :Form0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Form1, proj2_type = :Literal, res_type = :Form1, op_names = [:/, :./, :*, :.*, :^, :.^]), - + (proj1_type = :DualForm0, proj2_type = :Literal, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :DualForm1, proj2_type = :Literal, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), @@ -86,7 +86,7 @@ op2_inf_rules_1D = [ (proj1_type = :Constant, proj2_type = :Form1, res_type = :Form1, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Form0, proj2_type = :Constant, res_type = :Form0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Form1, proj2_type = :Constant, res_type = :Form1, op_names = [:/, :./, :*, :.*, :^, :.^]), - + (proj1_type = :Constant, proj2_type = :DualForm0, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Constant, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :DualForm0, proj2_type = :Constant, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), @@ -112,13 +112,13 @@ op1_inf_rules_2D = [ (src_type = :DualForm1, tgt_type = :DualForm2, op_names = [:d, :dual_d₁, :d̃₁]), # Rules for ⋆ - (src_type = :Form0, tgt_type = :DualForm2, op_names = [:⋆, :⋆₀, :star]), - (src_type = :Form1, tgt_type = :DualForm1, op_names = [:⋆, :⋆₁, :star]), - (src_type = :Form2, tgt_type = :DualForm0, op_names = [:⋆, :⋆₂, :star]), + (src_type = :Form0, tgt_type = :DualForm2, op_names = [:★, :⋆, :⋆₀, :star]), + (src_type = :Form1, tgt_type = :DualForm1, op_names = [:★, :⋆, :⋆₁, :star]), + (src_type = :Form2, tgt_type = :DualForm0, op_names = [:★, :⋆, :⋆₂, :star]), - (src_type = :DualForm2, tgt_type = :Form0, op_names = [:⋆, :⋆₀⁻¹, :star_inv]), - (src_type = :DualForm1, tgt_type = :Form1, op_names = [:⋆, :⋆₁⁻¹, :star_inv]), - (src_type = :DualForm0, tgt_type = :Form2, op_names = [:⋆, :⋆₂⁻¹, :star_inv]), + (src_type = :DualForm2, tgt_type = :Form0, op_names = [:★, :⋆, :⋆₀⁻¹, :star_inv]), + (src_type = :DualForm1, tgt_type = :Form1, op_names = [:★, :⋆, :⋆₁⁻¹, :star_inv]), + (src_type = :DualForm0, tgt_type = :Form2, op_names = [:★, :⋆, :⋆₂⁻¹, :star_inv]), # Rules for Δ (src_type = :Form0, tgt_type = :Form0, op_names = [:Δ, :Δ₀, :lapl]), @@ -154,7 +154,7 @@ op1_inf_rules_2D = [ # Rules for magnitude/ norm (src_type = :PVF, tgt_type = :Form0, op_names = [:norm, :mag]), (src_type = :DVF, tgt_type = :DualForm0, op_names = [:norm, :mag])] - + op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # Rules for ∧₁₁, ∧₂₀, ∧₀₂ (proj1_type = :Form1, proj2_type = :Form1, res_type = :Form2, op_names = [:∧, :∧₁₁, :wedge]), @@ -162,7 +162,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ (proj1_type = :Form0, proj2_type = :Form2, res_type = :Form2, op_names = [:∧, :∧₀₂, :wedge]), # Rules for L₂ - (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm2, op_names = [:L, :L₂]), + (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm2, op_names = [:L, :L₂]), # Rules for i₁ (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm1, op_names = [:i, :i₂]), @@ -173,7 +173,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # Rules for ι (proj1_type = :DualForm1, proj2_type = :DualForm1, res_type = :DualForm0, op_names = [:ι₁₁]), (proj1_type = :DualForm1, proj2_type = :DualForm2, res_type = :DualForm1, op_names = [:ι₁₂]), - + # Rules for subtraction (proj1_type = :Form0, proj2_type = :Form0, res_type = :Form0, op_names = [:-, :.-]), (proj1_type = :Form1, proj2_type = :Form1, res_type = :Form1, op_names = [:-, :.-]), @@ -212,7 +212,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # Rules for Δ (src_type = :Form0, tgt_type = :Form0, resolved_name = :Δ₀, op = :Δ), (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₁, op = :Δ)] - + # We merge 1D and 2D rules since it seems op2 rules are metric-free. If # this assumption is false, this needs to change. op2_res_rules_1D = [ @@ -228,8 +228,8 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm1, resolved_name = :L₁, op = :L), # Rules for i. (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm0, resolved_name = :i₁, op = :i)] - - + + """ These are the default rules used to do function resolution in the 2D exterior calculus. """ @@ -274,7 +274,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ (src_type = :Form0, tgt_type = :Form0, resolved_name = :Δ₀, op = :lapl), (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₁, op = :lapl), (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₂, op = :lapl)] - + # We merge 1D and 2D rules directly here since it seems op2 rules # are metric-free. If this assumption is false, this needs to change. op2_res_rules_2D = vcat(op2_res_rules_1D, [ @@ -290,7 +290,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm2, resolved_name = :L₂ᵈ, op = :L), # Rules for i. (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm1, resolved_name = :i₂, op = :i)]) - + # TODO: When SummationDecapodes are annotated with the degree of their space, # use dispatch to choose the correct set of rules. infer_types!(d::SummationDecapode) = @@ -358,4 +358,3 @@ Resolve function overloads based on types of src and tgt. """ resolve_overloads!(d::SummationDecapode) = resolve_overloads!(d, op1_res_rules_2D, op2_res_rules_2D) - diff --git a/src/sym_rewrite.jl b/src/sym_rewrite.jl index cec55de..e63903c 100644 --- a/src/sym_rewrite.jl +++ b/src/sym_rewrite.jl @@ -4,19 +4,6 @@ using SymbolicUtils using SymbolicUtils: Fixpoint, Prewalk, Postwalk, Chain, symtype, promote_symtype using MLStyle -Heat = @decapode begin - C::Form0 - G::Form1 - D::Constant - ∂ₜ(G) == D*Δ(C) -end; -infer_types!(Heat) -test_heat_same = symbolic_rewriting(Heat) - -r = rules(Δ, Val(1)) - -rwr = Fixpoint(Prewalk(Chain(r))) -test_heat_open = symbolic_rewriting(Heat, rwr) Brusselator = @decapode begin (U, V)::Form0 @@ -35,13 +22,6 @@ Brusselator = @decapode begin end infer_types!(Brusselator) -Phytodynamics = @decapode begin - (n,w)::Form0 - m::Constant - ∂ₜ(n) == w + m*n + Δ(n) -end -infer_types!(Phytodynamics) -test_phy = symbolic_rewriting(Phytodynamics) # it seems that type-instability or improper type promotion is happening. expressions derived from this have BasicSymbolic{Number} type, which means we can't conditionally rewrite on forms. heat_exprs = symbolic_rewriting(Heat) diff --git a/test/Project.toml b/test/Project.toml index 97b4819..4f3a02a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,4 +6,5 @@ CombinatorialSpaces = "b1c52339-7909-45ad-8b6a-6e388f7c67f2" DiagrammaticEquations = "6f00c28b-6bed-4403-80fa-30e0dc12f317" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/acset2symbolic.jl b/test/acset2symbolic.jl new file mode 100644 index 0000000..f13cb79 --- /dev/null +++ b/test/acset2symbolic.jl @@ -0,0 +1,165 @@ +using Test +using DiagrammaticEquations +using SymbolicUtils: Fixpoint, Prewalk, Postwalk, Chain, @rule +using Catlab + +@testset "Basic Roundtrip" begin + op1_only = @decapode begin + A::Form0 + B::Form1 + B == d(A) + end + + @test op1_only == symbolic_rewriting(op1_only) + + + op2_only = @decapode begin + A::Constant + B::Form0 + C::Form0 + + C == A * B + end + + @test op2_only == symbolic_rewriting(op2_only) + + sum_only = @decapode begin + A::Form2 + B::Form2 + C::Form2 + + C == A + B + end + + @test sum_only == symbolic_rewriting(sum_only) + + + multi_sum = @decapode begin + A::Form2 + B::Form2 + C::Form2 + D::Form2 + + D == (A + B) + C + end + infer_types!(multi_sum) + + # TODO: This is correct but the symbolics is splitting up the sum + @test multi_sum == symbolic_rewriting(multi_sum) + + + all_ops = @decapode begin + A::Constant + B::Form0 + C::Form1 + D::Form1 + E::Form1 + F::Form1 + + C == d(B) + D == A * C + F == D + E + end + + # This loses the intermediate names C and D + all_ops_res = symbolic_rewriting(all_ops) + all_ops_res[5, :name] = :D + all_ops_res[6, :name] = :C + @test is_isomorphic(all_ops,all_ops_res) +end + +function expr_rewriter(rules::Vector) + return Fixpoint(Prewalk(Fixpoint(Chain(rules)))) +end + +@testset "Basic Rewriting" begin + op1s = @decapode begin + A::Form0 + B::Form2 + C::Form2 + + C == B + B + d(d(A)) + end + + dd_0 = @rule d(d(~x)) => 0 + + op1s_rewritten = symbolic_rewriting(op1s, expr_rewriter([dd_0])) + + op1s_equiv = @decapode begin + A::Form0 + B::Form2 + C::Form2 + + C == 2 * B + end + + @test op1s_equiv == op1s_rewritten + + + op2s = @decapode begin + A::Form0 + B::Form0 + C::Form0 + D::Form0 + + + D == ∧(∧(A, B), C) + end + + wdg_assoc = @rule ∧(∧(~x, ~y), ~z) => ∧(~x, ∧(~y, ~z)) + + op2s_rewritten = symbolic_rewriting(op2s, expr_rewriter([wdg_assoc])) + + op2s_equiv = @decapode begin + A::Form0 + B::Form0 + C::Form0 + D::Form0 + + + D == ∧(A, ∧(B, C)) + end + infer_types!(op2s_equiv) + + @test op2s_equiv == op2s_rewritten + +end + +@testset "Heat" begin + Heat = @decapode begin + C::Form0 + G::Form0 + D::Constant + ∂ₜ(G) == D*Δ(C) + end + infer_types!(Heat) + + # Same up to re-naming + Heat[5, :name] = Symbol("•1") + @test Heat == symbolic_rewriting(Heat) + + Heat_open = @decapode begin + C::Form0 + G::Form0 + D::Constant + ∂ₜ(G) == D*★(d(★(d(C)))) + end + infer_types!(Heat_open) + + Heat_open[8, :name] = Symbol("•1") + Heat_open[5, :name] = Symbol("•2") + Heat_open[6, :name] = Symbol("•3") + Heat_open[7, :name] = Symbol("•4") + + @test is_isomorphic(Heat_open, symbolic_rewriting(Heat, expr_rewriter(rules(Δ, Val(1))))) +end + +@testset "Phytodynamics" begin + Phytodynamics = @decapode begin + (n,w)::Form0 + m::Constant + ∂ₜ(n) == w + m*n + Δ(n) + end + infer_types!(Phytodynamics) + test_phy = symbolic_rewriting(Phytodynamics) +end diff --git a/test/runtests.jl b/test/runtests.jl index 972d72c..4a64095 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,7 +38,10 @@ end include("openoperators.jl") end -include("aqua.jl") @testset "Symbolic Rewriting" begin include("graph_traversal.jl") + include("acset2symbolic.jl") end + + +include("aqua.jl") From 8661e2fdc03c2037c3176cbf8ca20162d2288a4c Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 23 Sep 2024 14:19:19 -0400 Subject: [PATCH 20/39] etc --- src/deca/ThDEC.jl | 9 +++++---- test/decasymbolic.jl | 3 +++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 57aab47..91ba946 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -29,7 +29,8 @@ abstract type AbstractScalar <: DECQuantity end struct Scalar <: AbstractScalar end struct Parameter <: AbstractScalar end struct ConstScalar <: AbstractScalar end -export Scalar, Parameter, ConstScalar +struct Literal <: AbstractScalar end +export Scalar, Parameter, ConstScalar, Literal struct FormParams dim::Int @@ -227,8 +228,7 @@ end abstract type SortError <: Exception end -# struct WedgeDimError <: SortError end - +Base.nameof(s::Literal) = :Literal Base.nameof(s::ConstScalar) = :ConstScalar Base.nameof(s::Parameter) = :Parameter Base.nameof(s::Scalar) = :Scalar @@ -276,8 +276,9 @@ end function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol) @match qty begin :Scalar => Scalar - :ConstScalar => ConstScalar + :Constant => ConstScalar :Parameter => Parameter + :Literal => Literal :Form0 => PrimalForm{0, space, 1} :Form1 => PrimalForm{1, space, 1} :Form2 => PrimalForm{2, space, 1} diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index 3a0cb92..b3ff661 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -7,6 +7,7 @@ using SymbolicUtils: symtype, promote_symtype, Symbolic using MLStyle # load up some variable variables and expressions +ℓ, = @syms ℓ::Literal c, t = @syms c::ConstScalar t::Parameter a, b = @syms a::Scalar b::Scalar u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} @@ -16,9 +17,11 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} @testset "Term Construction" begin + @test symtype(ℓ) == Literal @test symtype(c) == ConstScalar @test symtype(t) == Parameter @test symtype(a) == Scalar + @test symtype(u) == PrimalForm{0, :X, 2} @test symtype(ω) == PrimalForm{1, :X, 2} @test symtype(η) == DualForm{2, :X, 2} From 2228d7a2c591a82efd173295e0b5ae34173764b9 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 23 Sep 2024 15:18:35 -0400 Subject: [PATCH 21/39] Literals testing --- src/SymbolicUtilsInterop.jl | 6 ++++-- src/acset.jl | 1 + test/acset2symbolic.jl | 11 +++++++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/SymbolicUtilsInterop.jl b/src/SymbolicUtilsInterop.jl index 28bde3f..9d4f6e0 100644 --- a/src/SymbolicUtilsInterop.jl +++ b/src/SymbolicUtilsInterop.jl @@ -1,6 +1,8 @@ module SymbolicUtilsInterop +using ACSets using ..DiagrammaticEquations: AbstractDecapode, Quantity +using ..DiagrammaticEquations: recognize_types, fill_names!, make_sum_mult_unique! import ..DiagrammaticEquations: eval_eq!, SummationDecapode using ..decapodes using ..Deca @@ -123,7 +125,7 @@ function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__) end function eval_eq!(eq::SymbolicEquation, d::AbstractDecapode, syms::Dict{Symbol, Int}, deletions::Vector{Int}) - eval_eq!(Equation(Term(eq.lhs), Term(eq.rhs)), d, syms, deletions) + eval_eq!(Eq(Term(eq.lhs), Term(eq.rhs)), d, syms, deletions) end """ function SummationDecapode(e::SymbolicContext) """ @@ -133,7 +135,7 @@ function SummationDecapode(e::SymbolicContext) foreach(e.vars) do var # convert Sort(var)::PrimalForm0 --> :Form0 - var_id = add_part!(d, :Var, name=var.name, type=nameof(Sort(var))) + var_id = add_part!(d, :Var, name=var.name, type=nameof(symtype(var))) symbol_table[var.name] = var_id end diff --git a/src/acset.jl b/src/acset.jl index 34d1d3c..4cfdaac 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -187,6 +187,7 @@ function recognize_types(d::AbstractNamedDecapode) isempty(unrecognized_types) || error("Types $unrecognized_types are not recognized. CHECK: $types") end +export recognize_types """ is_expanded(d::AbstractNamedDecapode) diff --git a/test/acset2symbolic.jl b/test/acset2symbolic.jl index f13cb79..e81ed95 100644 --- a/test/acset2symbolic.jl +++ b/test/acset2symbolic.jl @@ -163,3 +163,14 @@ end infer_types!(Phytodynamics) test_phy = symbolic_rewriting(Phytodynamics) end + +@testset "Literals" begin + Heat = parse_decapode(quote + C::Form0 + G::Form0 + ∂ₜ(G) == 3*Δ(C) + end) + context = SymbolicContext(Heat) + SummationDecapode(context) + +end From bc9ab00df0b96a7d2ed99e9428d1f3efa2c491d0 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 26 Sep 2024 13:50:26 -0400 Subject: [PATCH 22/39] parameters test passing after some debugging. --- src/acset2symbolic.jl | 13 +++++++++++-- src/symbolictheoryutils.jl | 2 +- test/acset2symbolic.jl | 32 ++++++++++++++++++++++++++++++-- 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index 078d08e..ee1898f 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -56,7 +56,6 @@ function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSym syms_array = [symvar_lookup[var] for var in d[d[incident(d, op_index, :summation), :summand], :name]] output_sym = symvar_lookup[d[d[op_index, :sum], :name]] - # TODO pls test S = promote_symtype(+, syms_array...) rhs = SymbolicUtils.Term{S}(+, syms_array) SymbolicEquation{Symbolic}(output_sym,rhs) @@ -101,6 +100,14 @@ function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, Basic for expr in symexpr_list + # XXX SymbolicUtils.substitute swaps the order of multiplication. + # example: @decapode begin + # u::Form0 + # G::Form0 + # κ::Constant + # ∂ₜ(G) == κ*★(d(★(d(u)))) + # end + # will have the kappa*var term rewritten to var*kappa merged_rhs = SymbolicUtils.substitute(expr.rhs, eqn_lookup) push!(eqn_lookup, (expr.lhs => merged_rhs)) @@ -127,7 +134,9 @@ function apply_rewrites(symexprs, rewriter) rewritten_list end - +""" +og_d = original reference decapode which provides type information, state and terminal information +""" function to_acset(og_d, sym_exprs) #TODO: This step is breaking up summations diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index 0c03a0a..f1693f3 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -73,7 +73,7 @@ end """ macro operator(head, body) - # parse body + # parse head ph = @λ begin Expr(:call, foo, Expr(:(::), vars..., theory)) => (foo, vars, theory) Expr(:(::), Expr(:call, foo, vars...), theory) => (foo, vars, theory) diff --git a/test/acset2symbolic.jl b/test/acset2symbolic.jl index e81ed95..89de387 100644 --- a/test/acset2symbolic.jl +++ b/test/acset2symbolic.jl @@ -3,6 +3,8 @@ using DiagrammaticEquations using SymbolicUtils: Fixpoint, Prewalk, Postwalk, Chain, @rule using Catlab +(≃) = is_isomorphic + @testset "Basic Roundtrip" begin op1_only = @decapode begin A::Form0 @@ -65,7 +67,7 @@ using Catlab all_ops_res = symbolic_rewriting(all_ops) all_ops_res[5, :name] = :D all_ops_res[6, :name] = :C - @test is_isomorphic(all_ops,all_ops_res) + @test all_ops ≃ all_ops_res end function expr_rewriter(rules::Vector) @@ -151,7 +153,7 @@ end Heat_open[6, :name] = Symbol("•3") Heat_open[7, :name] = Symbol("•4") - @test is_isomorphic(Heat_open, symbolic_rewriting(Heat, expr_rewriter(rules(Δ, Val(1))))) + @test Heat_open ≃ symbolic_rewriting(Heat, expr_rewriter(rules(Δ, Val(1)))) end @testset "Phytodynamics" begin @@ -174,3 +176,29 @@ end SummationDecapode(context) end + +@testset "Parameters" begin + + Heat = @decapode begin + u::Form0 + G::Form0 + κ::Parameter + ∂ₜ(G) == Δ(u)*κ + end + infer_types!(Heat) + + Heat_open = @decapode begin + u::Form0 + G::Form0 + κ::Parameter + ∂ₜ(G) == ★(d(★(d(u))))*κ + end + infer_types!(Heat_open) + + Heat_open[7, :name] = Symbol("•4") + Heat_open[8, :name] = Symbol("•1") + + z = symbolic_rewriting(Heat, expr_rewriter(rules(Δ, Val(1)))) + @test Heat_open ≃ z + +end From 2b3198f010a1174dcfdf7add9f7bde5e7c01ec2b Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 26 Sep 2024 20:05:28 -0400 Subject: [PATCH 23/39] supporting Infer, better Base.nameof, better tests --- src/SymbolicUtilsInterop.jl | 24 ++++++------ src/acset2symbolic.jl | 2 +- src/deca/ThDEC.jl | 40 +++++++++++++++---- src/symbolictheoryutils.jl | 2 + test/acset2symbolic.jl | 14 +++++++ test/decasymbolic.jl | 78 +++++++++++++++++++------------------ test/runtests.jl | 4 ++ 7 files changed, 107 insertions(+), 57 deletions(-) diff --git a/src/SymbolicUtilsInterop.jl b/src/SymbolicUtilsInterop.jl index 9d4f6e0..2e50048 100644 --- a/src/SymbolicUtilsInterop.jl +++ b/src/SymbolicUtilsInterop.jl @@ -1,7 +1,7 @@ module SymbolicUtilsInterop using ACSets -using ..DiagrammaticEquations: AbstractDecapode, Quantity +using ..DiagrammaticEquations: AbstractDecapode, Quantity, DerivOp using ..DiagrammaticEquations: recognize_types, fill_names!, make_sum_mult_unique! import ..DiagrammaticEquations: eval_eq!, SummationDecapode using ..decapodes @@ -51,7 +51,7 @@ function decapodes.Term(t::SymbolicUtils.BasicSymbolic) decapodes.Plus(termargs) elseif op == * decapodes.Mult(termargs) - elseif op == ∂ₜ + elseif op ∈ [DerivOp, ∂ₜ] decapodes.Tan(only(termargs)) elseif length(args) == 1 decapodes.App1(nameof(op, symtype.(args)...), termargs...) @@ -85,9 +85,9 @@ Example: SymbolicUtils.BasicSymbolic(context, Term(a)) ``` """ -function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapodes.Term, __module__=@__MODULE__) +function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapodes.Term) # user must import symbols into scope - ! = (f -> getfield(__module__, f)) + ! = (f -> getfield(@__MODULE__, f)) @match t begin Var(name) => SymbolicUtils.Sym{context[name]}(name) Lit(v) => Meta.parse(string(v)) @@ -98,17 +98,17 @@ function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapode # see test/language.jl (f, x) -> (!(f))(x), fs; - init=BasicSymbolic(context, arg, __module__) + init=BasicSymbolic(context, arg) ) - App1(f, x) => (!(f))(BasicSymbolic(context, x, __module__)) - App2(f, x, y) => (!(f))(BasicSymbolic(context, x, __module__), BasicSymbolic(context, y, __module__)) - Plus(xs) => +(BasicSymbolic.(Ref(context), xs, Ref(__module__))...) - Mult(xs) => *(BasicSymbolic.(Ref(context), xs, Ref(__module__))...) - Tan(x) => ∂ₜ(BasicSymbolic(context, x, __module__)) + App1(f, x) => (!(f))(BasicSymbolic(context, x)) + App2(f, x, y) => (!(f))(BasicSymbolic(context, x), BasicSymbolic(context, y)) + Plus(xs) => +(BasicSymbolic.(Ref(context), xs)...) + Mult(xs) => *(BasicSymbolic.(Ref(context), xs)...) + Tan(x) => (!(DerivOp))(BasicSymbolic(context, x)) end end -function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__) +function SymbolicContext(d::decapodes.DecaExpr) # associates each var to its sort... context = map(d.context) do j j.var => symtype(Deca.DECQuantity, j.dim, j.space) @@ -119,7 +119,7 @@ function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__) end context = Dict{Symbol,DataType}(context) eqs = map(d.equations) do eq - SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs], Ref(__module__))...) + SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs])...) end SymbolicContext(vars, eqs) end diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index ee1898f..667fbdc 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -79,7 +79,7 @@ end function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}) sym_list = SymbolicEquation{Symbolic}[] for node in topological_sort_edges(d) - retrieve_name(d, node) != DerivOp || continue # This is not part of ThDEC + # retrieve_name(d, node) != DerivOp || continue # This is not part of ThDEC push!(sym_list, to_symbolics(d, symvar_lookup, node)) end sym_list diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 91ba946..e71701b 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -26,11 +26,14 @@ SymbolicUtils.symtype(::Type{S}) where S<:DECQuantity = S abstract type AbstractScalar <: DECQuantity end +struct InferredType <: DECQuantity end +export InferredType + struct Scalar <: AbstractScalar end struct Parameter <: AbstractScalar end -struct ConstScalar <: AbstractScalar end +struct Const <: AbstractScalar end struct Literal <: AbstractScalar end -export Scalar, Parameter, ConstScalar, Literal +export Scalar, Parameter, Const, Literal struct FormParams dim::Int @@ -90,6 +93,18 @@ Base.nameof(u::Type{<:DualForm}) = Symbol("DualForm"*"$(dim(u))") # ACTIVE PATTERNS +@active PatInferredType(T) begin + if T <: InferredType + Some(InferredType) + end +end + +@active PatInferredTypes(T) begin + if any(S->S<:InferredType, T) + Some(InferredType) + end +end + @active PatForm(T) begin if T <: Form Some(T) @@ -158,6 +173,7 @@ export isDualForm, isForm0, isForm1, isForm2 @operator d(S)::DECQuantity begin @match S begin + PatInferredType(_) => InferredType PatFormParams([i,d,s,n]) => Form{i+1,d,s,n} _ => throw(ExteriorDerivativeError(S)) end @@ -167,6 +183,7 @@ end @operator ★(S)::DECQuantity begin @match S begin + PatInferredType(_) => InferredType PatFormParams([i,d,s,n]) => Form{n-i,d,s,n} _ => throw(HodgeStarError(S)) end @@ -176,6 +193,7 @@ end @operator Δ(S)::DECQuantity begin @match S begin + PatInferredType(_) => InferredType PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) _ => throw(LaplacianError(S)) end @@ -185,8 +203,11 @@ end @alias (Δ₀, Δ₁, Δ₂) => Δ +# Base.show(io::IO, + @operator +(S1, S2)::DECQuantity begin @match (S1, S2) begin + PatInferredTypes(_) => InferredType (PatScalar(_), PatScalar(_)) => Scalar (PatScalar(_), PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar(_)) => S1 # commutativity (PatFormParams([i1,d1,s1,n1]), PatFormParams([i2,d2,s2,n2])) => begin @@ -206,6 +227,7 @@ end @operator *(S1, S2)::DECQuantity begin @match (S1, S2) begin + PatInferredTypes(_) => InferredType (PatScalar(_), PatScalar(_)) => Scalar (PatScalar(_), PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar(_)) => Form{i,d,s,n} _ => throw(BinaryOpError("multiply", S1, S2)) @@ -214,6 +236,7 @@ end @operator ∧(S1, S2)::DECQuantity begin @match (S1, S2) begin + PatInferredTypes(_) => InferredType (PatFormParams([i1,d1,s1,n1]), PatFormParams([i2,d2,s2,n2])) => begin (d1 == d2) && (s1 == s2) && (n1 == n2) || throw(WedgeOpError(S1, S2)) if i1 + i2 <= n1 @@ -228,10 +251,10 @@ end abstract type SortError <: Exception end -Base.nameof(s::Literal) = :Literal -Base.nameof(s::ConstScalar) = :ConstScalar -Base.nameof(s::Parameter) = :Parameter -Base.nameof(s::Scalar) = :Scalar +Base.nameof(s::Union{Literal,Type{Literal}}) = :Literal +Base.nameof(s::Union{Const, Type{Const}}) = :Constant +Base.nameof(s::Union{Parameter, Type{Parameter}}) = :Parameter +Base.nameof(s::Union{Scalar, Type{Scalar}}) = :Scalar function Base.nameof(f::Form; with_dim_parameter=false) dual = isdual(f) ? "Dual" : "" @@ -264,6 +287,8 @@ function Base.nameof(::typeof(∧), s1, s2) Symbol("∧$(as_sub(dim(s1)))$(as_sub(dim(s2)))") end +Base.nameof(::typeof(∂ₜ), s) = Symbol("∂ₜ($(nameof(s)))") + Base.nameof(::typeof(d), s) = Symbol("d$(as_sub(dim(s)))") Base.nameof(::typeof(Δ), s) = :Δ @@ -276,7 +301,7 @@ end function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol) @match qty begin :Scalar => Scalar - :Constant => ConstScalar + :Constant => Const :Parameter => Parameter :Literal => Literal :Form0 => PrimalForm{0, space, 1} @@ -285,6 +310,7 @@ function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol) :DualForm0 => DualForm{0, space, 1} :DualForm1 => DualForm{1, space, 1} :DualForm2 => DualForm{2, space, 1} + :Infer => InferredType _ => error("Received $qty") end end diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index f1693f3..38f4fcc 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -108,6 +108,8 @@ macro operator(head, body) SymbolicUtils.Term{s}($f, Any[$(argnames...)]) end export $f + + Base.show(io::IO, ::typeof($f)) = print(io, $f) end) # if there are rewriting rules, add a method which accepts the function symbol and its arity (to prevent shadowing on operators like `-`) diff --git a/test/acset2symbolic.jl b/test/acset2symbolic.jl index 89de387..c928901 100644 --- a/test/acset2symbolic.jl +++ b/test/acset2symbolic.jl @@ -202,3 +202,17 @@ end @test Heat_open ≃ z end + +x=@decapode begin + u::Form0 + ∂ₜ(u) == u +end +symbolic_rewriting(x) +# if the `for op1 in parts(og_d, :Op1)...` block is removed, this is annihilated because x has no terminals + +x=@decapode begin + u::Form0 + v::Form0 + ∂ₜ(v) == u +end +symbolic_rewriting(x) # fine diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index b3ff661..722157c 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -7,18 +7,20 @@ using SymbolicUtils: symtype, promote_symtype, Symbolic using MLStyle # load up some variable variables and expressions -ℓ, = @syms ℓ::Literal -c, t = @syms c::ConstScalar t::Parameter -a, b = @syms a::Scalar b::Scalar -u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} -ω, η = @syms ω::PrimalForm{1, :X, 2} η::DualForm{2, :X, 2} -ϕ, ψ = @syms ϕ::PrimalVF{:X, 2} ψ::DualVF{:X, 2} +👻, = @syms 👻::InferredType +ℓ, = @syms ℓ::Literal +c, t = @syms c::Const t::Parameter +a, b = @syms a::Scalar b::Scalar +u, du = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} +ω, η = @syms ω::PrimalForm{1, :X, 2} η::DualForm{2, :X, 2} +ϕ, ψ = @syms ϕ::PrimalVF{:X, 2} ψ::DualVF{:X, 2} # TODO would be nice to pass the space globally to avoid duplication @testset "Term Construction" begin + @test symtype(👻) == InferredType @test symtype(ℓ) == Literal - @test symtype(c) == ConstScalar + @test symtype(c) == Const @test symtype(t) == Parameter @test symtype(a) == Scalar @@ -31,9 +33,12 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} @test symtype(c + t) == Scalar @test symtype(t + t) == Scalar @test symtype(c + c) == Scalar + @test symtype(t + 👻) == InferredType @test symtype(u ∧ ω) == PrimalForm{1, :X, 2} @test symtype(ω ∧ ω) == PrimalForm{2, :X, 2} + @test symtype(u ∧ 👻) == InferredType + # @test_throws ThDEC.SortError ThDEC.♯(u) @test symtype(Δ(u) + Δ(u)) == PrimalForm{0, :X, 2} @@ -112,40 +117,39 @@ end @testset "Conversion" begin - context = Dict(:a => Scalar(),:b => Scalar() - ,:u => PrimalForm(0, X),:du => PrimalForm(1, X)) - js = [Judgement(:u, :Form0, :X) - ,Judgement(:∂ₜu, :Form0, :X) - ,Judgement(:Δu, :Form0, :X)] - eqs = [Eq(Var(:∂ₜu) - , AppCirc1([:⋆₂⁻¹, :d₁, :⋆₁, :d₀], Var(:u))) - , Eq(Tan(Var(:u)), Var(:∂ₜu))] - heat_eq = DecaExpr(js, eqs) + Exp = @decapode begin + u::Form0 + v::Form0 + ∂ₜ(v) == u + end + context = SymbolicContext(Term(Exp)) + Exp′ = SummationDecapode(DecaExpr(context)) - symb_heat_eq = DecaSymbolic(lookup, heat_eq) - deca_expr = DecaExpr(symb_heat_eq) + # does roundtripping work + @test Exp == Exp′ -end - -@testset "Moving between DecaExpr and DecaSymbolic" begin - - @test js == deca_expr.context - - # eqs in the left has AppCirc1[vector, term] - # deca_expr.equations on the right has nested App1 - # expected behavior is that nested AppCirc1 is preserved - @test_broken eqs == deca_expr.equations - # use expand_operators to get rid of parentheses - # infer_types and resolve_overloads - -end + Heat = @decapode begin + u::Form0 + v::Form0 + κ::Constant + ∂ₜ(v) == Δ(u)*κ + end + infer_types!(Heat) + context = SymbolicContext(Term(Heat)) + Heat′ = SummationDecapode(DecaExpr(context)) -# convert both into ACSets then is_iso them -@testset "" begin + @test Heat == Heat′ - Σ = DiagrammaticEquations.SummationDecapode(deca_expr) - Δ = DiagrammaticEquations.SummationDecapode(symb_heat_eq) - @test Σ == Δ + TumorInvasion = @decapode begin + (C,fC)::Form0 + (Dif,Kd,Cmax)::Constant + ∂ₜ(C) == Dif * Δ(C) + fC - Kd * C + end + infer_types!(TumorInvasion) + context = SymbolicContext(Term(TumorInvasion)) + TumorInvasion′ = SummationDecapode(DecaExpr(context)) + # new terms introduced + @test_broken TumorInvasion == TumorInvasion′ end diff --git a/test/runtests.jl b/test/runtests.jl index 4a64095..ec5047b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,6 +38,10 @@ end include("openoperators.jl") end +@testset "ThDEC Symbolics" begin + include("decasymbolic.jl") +end + @testset "Symbolic Rewriting" begin include("graph_traversal.jl") include("acset2symbolic.jl") From 31ad602145ba5f9d4a901dc3ac9935c4704317a1 Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Thu, 26 Sep 2024 22:26:28 -0400 Subject: [PATCH 24/39] Clean out-of-order vector constructions --- src/acset2symbolic.jl | 124 +++++++++++------------------------------ src/graph_traversal.jl | 23 +++++++- 2 files changed, 54 insertions(+), 93 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index ee1898f..8df2d91 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -11,92 +11,51 @@ export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_a const DECA_EQUALITY_SYMBOL = (==) -to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, node::TraversalNode) = to_symbolics(d, symvar_lookup, node.index, Val(node.name)) - function symbolics_lookup(d::SummationDecapode) - lookup = Dict{Symbol, BasicSymbolic}() - for i in parts(d, :Var) - push!(lookup, d[i, :name] => decavar_to_symbolics(d, i)) - end - lookup + Dict{Symbol, BasicSymbolic}(map(parts(d, :Var)) do i + (d[i, :name], decavar_to_symbolics(d, i)) + end) end function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I) - var = d[index, :name] new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[index, :type], space) - - SymbolicUtils.Sym{new_type}(var) -end - -function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Op1}) - input_sym = symvar_lookup[d[d[op_index, :src], :name]] - output_sym = symvar_lookup[d[d[op_index, :tgt], :name]] - - op_sym = getfield(@__MODULE__, d[op_index, :op1]) - - S = promote_symtype(op_sym, input_sym) - rhs = SymbolicUtils.Term{S}(op_sym, [input_sym]) - SymbolicEquation{Symbolic}(output_sym, rhs) + SymbolicUtils.Sym{new_type}(d[index, :name]) end -function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Op2}) - input1_sym = symvar_lookup[d[d[op_index, :proj1], :name]] - input2_sym = symvar_lookup[d[d[op_index, :proj2], :name]] - output_sym = symvar_lookup[d[d[op_index, :res], :name]] +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, op_type::Symbol) + input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_index,Val(op_type)), :name]) + output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_index,Val(op_type)), :name]) - op_sym = getfield(@__MODULE__, d[op_index, :op2]) + op_sym = getfield(@__MODULE__, edge_function(d,op_index,Val(op_type))) - S = promote_symtype(op_sym, input1_sym, input2_sym) - rhs = SymbolicUtils.Term{S}(op_sym, [input1_sym, input2_sym]) + S = promote_symtype(op_sym, input_syms...) + rhs = SymbolicUtils.Term{S}(op_sym, input_syms) SymbolicEquation{Symbolic}(output_sym, rhs) end -#XXX: Always converting + -> .+ here since summation doesn't store the style of addition -function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Σ}) - syms_array = [symvar_lookup[var] for var in d[d[incident(d, op_index, :summation), :summand], :name]] - output_sym = symvar_lookup[d[d[op_index, :sum], :name]] - - S = promote_symtype(+, syms_array...) - rhs = SymbolicUtils.Term{S}(+, syms_array) - SymbolicEquation{Symbolic}(output_sym,rhs) -end - function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing) - d = deepcopy(old_d) - - infer_types!(d) + d = infer_types!(deepcopy(old_d)) symvar_lookup = symbolics_lookup(d) eqns = merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup)) - if !isnothing(rewriter) - eqns = map(rewriter, eqns) - end - - to_acset(d, eqns) + to_acset(d, isnothing(rewriter) ? eqns : map(rewriter, eqns)) end function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}) - sym_list = SymbolicEquation{Symbolic}[] - for node in topological_sort_edges(d) - retrieve_name(d, node) != DerivOp || continue # This is not part of ThDEC - push!(sym_list, to_symbolics(d, symvar_lookup, node)) + non_tangents = filter(x -> retrieve_name(d, x) != DerivOp, topological_sort_edges(d)) + map(non_tangents) do node + to_symbolics(d, symvar_lookup, node.index, node.name) end - sym_list end function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, symexpr_list::Vector{SymbolicEquation{Symbolic}}) - - eqn_lookup = Dict() - final_list = [] - for node in start_nodes(d) + eqn_lookup = Dict{Any,Any}(map(start_nodes(d)) do node sym = symvar_lookup[d[node, :name]] - push!(eqn_lookup, (sym => sym)) - end - - final_nodes = infer_terminal_names(d) + (sym, sym) + end) for expr in symexpr_list @@ -112,7 +71,7 @@ function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, Basic push!(eqn_lookup, (expr.lhs => merged_rhs)) - if expr.lhs.name in final_nodes + if expr.lhs.name in infer_terminal_names(d) push!(final_list, formed_deca_eqn(expr.lhs, merged_rhs)) end end @@ -123,22 +82,13 @@ end formed_deca_eqn(lhs, rhs) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [lhs, rhs]) function apply_rewrites(symexprs, rewriter) - - rewritten_list = [] - for sym in symexprs + map(symexprs) do sym res_sym = rewriter(sym) - rewritten_sym = isnothing(res_sym) ? sym : res_sym - push!(rewritten_list, rewritten_sym) + isnothing(res_sym) ? sym : res_sym end - - rewritten_list end -""" -og_d = original reference decapode which provides type information, state and terminal information -""" -function to_acset(og_d, sym_exprs) - +function to_acset(d::SummationDecapode, sym_exprs) #TODO: This step is breaking up summations final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) @@ -151,28 +101,18 @@ function to_acset(og_d, sym_exprs) end sym => nothing end - map(recursive_descent, final_exprs) - - deca_block = quote end - - states = infer_states(og_d) - terminals = infer_terminals(og_d) + foreach(recursive_descent, final_exprs) - deca_type_gen = idx -> :($(og_d[idx, :name])::$(og_d[idx, :type])) - - append!(deca_block.args, map(deca_type_gen, vcat(states, terminals))) - - for op1 in parts(og_d, :Op1) - if og_d[op1, :op1] == DerivOp - push!(deca_block.args, :($(og_d[og_d[op1, :tgt], :name]) == $DerivOp($(og_d[og_d[op1, :src], :name])))) - end + states_terminals = map([infer_states(d)..., infer_terminals(d)...]) do idx + :($(d[idx, :name])::$(d[idx, :type])) end - append!(deca_block.args, final_exprs) - - d = SummationDecapode(parse_decapode(deca_block)) - - infer_types!(d) + tangents = map(incident(d, DerivOp, :op1)) do op1 + :($(d[d[op1, :tgt], :name]) == $DerivOp($(d[d[op1, :src], :name]))) + end - d + deca_block = quote end + deca_block.args = [states_terminals..., tangents..., final_exprs...] + infer_types!(SummationDecapode(parse_decapode(deca_block))) end + diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index f2875b2..cc47435 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -1,13 +1,34 @@ using DiagrammaticEquations using ACSets -export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes +export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes, edge_inputs, edge_outputs, edge_function struct TraversalNode{T} index::Int name::T end +edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Op1}) = + [d[idx,:src]] +edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Op2}) = + [d[idx,:proj1], d[idx,:proj2]] +edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Σ}) = + d[incident(d, idx, :summation), :summand] + +edge_output(d::SummationDecapode, idx::Int, ::Val{:Op1}) = + d[idx,:tgt] +edge_output(d::SummationDecapode, idx::Int, ::Val{:Op2}) = + d[idx,:res] +edge_output(d::SummationDecapode, idx::Int, ::Val{:Σ}) = + d[idx, :sum] + +edge_function(d::SummationDecapode, idx::Int, ::Val{:Op1}) = + d[idx,:op1] +edge_function(d::SummationDecapode, idx::Int, ::Val{:Op2}) = + d[idx,:op2] +edge_function(d::SummationDecapode, idx::Int, ::Val{:Σ}) = + :+ + function topological_sort_edges(d::SummationDecapode) visited_Var = falses(nparts(d, :Var)) visited_Var[start_nodes(d)] .= true From d408c26d3747c354545e90b5d53f184603aff1bd Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Thu, 26 Sep 2024 23:24:27 -0400 Subject: [PATCH 25/39] Convert to symbolics inside merge_equations --- src/acset2symbolic.jl | 77 ++++++++++++++++++------------------------ src/graph_traversal.jl | 2 +- 2 files changed, 34 insertions(+), 45 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index 8df2d91..94afff2 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -1,11 +1,9 @@ using DiagrammaticEquations using ACSets +using MLStyle using SymbolicUtils using SymbolicUtils.Rewriters -using SymbolicUtils.Code -using MLStyle - -import SymbolicUtils: BasicSymbolic, Symbolic +using SymbolicUtils: BasicSymbolic, Symbolic export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting, symbolics_lookup @@ -17,16 +15,16 @@ function symbolics_lookup(d::SummationDecapode) end) end -function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I) - new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[index, :type], space) - SymbolicUtils.Sym{new_type}(d[index, :name]) +function decavar_to_symbolics(d::SummationDecapode, idx::Int; space = :I) + new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[idx, :type], space) + SymbolicUtils.Sym{new_type}(d[idx, :name]) end -function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, op_type::Symbol) - input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_index,Val(op_type)), :name]) - output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_index,Val(op_type)), :name]) +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_idx::Int, op_type::Symbol) + input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_idx,Val(op_type)), :name]) + output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_idx,Val(op_type)), :name]) - op_sym = getfield(@__MODULE__, edge_function(d,op_index,Val(op_type))) + op_sym = getfield(@__MODULE__, edge_function(d,op_idx,Val(op_type))) S = promote_symtype(op_sym, input_syms...) rhs = SymbolicUtils.Term{S}(op_sym, input_syms) @@ -35,10 +33,7 @@ end function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing) d = infer_types!(deepcopy(old_d)) - - symvar_lookup = symbolics_lookup(d) - eqns = merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup)) - + eqns = merge_equations(d) to_acset(d, isnothing(rewriter) ? eqns : map(rewriter, eqns)) end @@ -49,34 +44,29 @@ function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, Basi end end -function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, symexpr_list::Vector{SymbolicEquation{Symbolic}}) - final_list = [] +# XXX SymbolicUtils.substitute swaps the order of multiplication. +# example: @decapode begin +# u::Form0 +# G::Form0 +# κ::Constant +# ∂ₜ(G) == κ*★(d(★(d(u)))) +# end +# will have the kappa*var term rewritten to var*kappa +function merge_equations(d::SummationDecapode) + symvar_lookup = symbolics_lookup(d) + symexpr_list = extract_symexprs(d, symvar_lookup) eqn_lookup = Dict{Any,Any}(map(start_nodes(d)) do node sym = symvar_lookup[d[node, :name]] (sym, sym) end) - - for expr in symexpr_list - - # XXX SymbolicUtils.substitute swaps the order of multiplication. - # example: @decapode begin - # u::Form0 - # G::Form0 - # κ::Constant - # ∂ₜ(G) == κ*★(d(★(d(u)))) - # end - # will have the kappa*var term rewritten to var*kappa + foreach(symexpr_list) do expr merged_rhs = SymbolicUtils.substitute(expr.rhs, eqn_lookup) - push!(eqn_lookup, (expr.lhs => merged_rhs)) - - if expr.lhs.name in infer_terminal_names(d) - push!(final_list, formed_deca_eqn(expr.lhs, merged_rhs)) - end end - final_list + terminals = filter(x -> x.lhs.name in infer_terminal_names(d), symexpr_list) + map(x -> formed_deca_eqn(x.lhs, eqn_lookup[x.lhs]), terminals) end formed_deca_eqn(lhs, rhs) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [lhs, rhs]) @@ -89,9 +79,16 @@ function apply_rewrites(symexprs, rewriter) end function to_acset(d::SummationDecapode, sym_exprs) + outer_types = map([infer_states(d)..., infer_terminals(d)...]) do idx + :($(d[idx, :name])::$(d[idx, :type])) + end + + tangents = map(incident(d, DerivOp, :op1)) do op1 + :($(d[d[op1, :tgt], :name]) == $DerivOp($(d[d[op1, :src], :name]))) + end + #TODO: This step is breaking up summations final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) - recursive_descent = @λ begin e::Expr => begin if e.head == :call @@ -103,16 +100,8 @@ function to_acset(d::SummationDecapode, sym_exprs) end foreach(recursive_descent, final_exprs) - states_terminals = map([infer_states(d)..., infer_terminals(d)...]) do idx - :($(d[idx, :name])::$(d[idx, :type])) - end - - tangents = map(incident(d, DerivOp, :op1)) do op1 - :($(d[d[op1, :tgt], :name]) == $DerivOp($(d[d[op1, :src], :name]))) - end - deca_block = quote end - deca_block.args = [states_terminals..., tangents..., final_exprs...] + deca_block.args = [outer_types..., tangents..., final_exprs...] infer_types!(SummationDecapode(parse_decapode(deca_block))) end diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index cc47435..6266240 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -1,7 +1,7 @@ using DiagrammaticEquations using ACSets -export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes, edge_inputs, edge_outputs, edge_function +export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes, edge_inputs, edge_output, edge_function struct TraversalNode{T} index::Int From 3cd624e4c917a0cf01e4841bdbd9a5ee87388302 Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Thu, 26 Sep 2024 23:54:52 -0400 Subject: [PATCH 26/39] Reduce cases of topological sort --- src/graph_traversal.jl | 56 ++++++++++++++---------------------------- 1 file changed, 18 insertions(+), 38 deletions(-) diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index 6266240..e41048e 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -29,59 +29,38 @@ edge_function(d::SummationDecapode, idx::Int, ::Val{:Op2}) = edge_function(d::SummationDecapode, idx::Int, ::Val{:Σ}) = :+ +#XXX: This topological sort is O(n^2). function topological_sort_edges(d::SummationDecapode) visited_Var = falses(nparts(d, :Var)) visited_Var[start_nodes(d)] .= true + visited = Dict(:Op1 => falses(nparts(d, :Op1)), + :Op2 => falses(nparts(d, :Op2)), :Σ => falses(nparts(d, :Σ))) - # TODO: Collect these visited arrays into one structure indexed by :Op1, :Op2, and :Σ - visited_1 = falses(nparts(d, :Op1)) - visited_2 = falses(nparts(d, :Op2)) - visited_Σ = falses(nparts(d, :Σ)) - - # FIXME: this is a quadratic implementation of topological_sort inlined in here. op_order = TraversalNode{Symbol}[] - for _ in 1:n_ops(d) - for op in parts(d, :Op1) - if !visited_1[op] && visited_Var[d[op, :src]] - - visited_1[op] = true - visited_Var[d[op, :tgt]] = true - - push!(op_order, TraversalNode(op, :Op1)) - end - end - - for op in parts(d, :Op2) - if !visited_2[op] && visited_Var[d[op, :proj1]] && visited_Var[d[op, :proj2]] - visited_2[op] = true - visited_Var[d[op, :res]] = true - push!(op_order, TraversalNode(op, :Op2)) - end + function visit(op, op_type) + if !visited[op_type][op] && all(visited_Var[edge_inputs(d,op,Val(op_type))]) + visited[op_type][op] = true + visited_Var[edge_output(d,op,Val(op_type))] = true + push!(op_order, TraversalNode(op, op_type)) end + end - for op in parts(d, :Σ) - args = subpart(d, incident(d, op, :summation), :summand) - if !visited_Σ[op] && all(visited_Var[args]) - visited_Σ[op] = true - visited_Var[d[op, :sum]] = true - push!(op_order, TraversalNode(op, :Σ)) - end - end + for _ in 1:n_ops(d) + visit.(parts(d,:Op1), :Op1) + visit.(parts(d,:Op2), :Op2) + visit.(parts(d,:Σ), :Σ) end @assert length(op_order) == n_ops(d) - op_order end -function n_ops(d::SummationDecapode) - return nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) -end +n_ops(d::SummationDecapode) = + nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) -function start_nodes(d::SummationDecapode) - return vcat(infer_states(d), incident(d, :Literal, :type)) -end +start_nodes(d::SummationDecapode) = + vcat(infer_states(d), incident(d, :Literal, :type)) function retrieve_name(d::SummationDecapode, tsr::TraversalNode) @match tsr.name begin @@ -91,3 +70,4 @@ function retrieve_name(d::SummationDecapode, tsr::TraversalNode) _ => error("$(tsr.name) is a table without names") end end + From 67079cbfd4c5af0b189d74e20cb82e35864a084a Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Fri, 27 Sep 2024 11:32:49 -0400 Subject: [PATCH 27/39] Reify via recursive function, not lambda case --- src/acset2symbolic.jl | 49 +++++++++++++++++-------------------------- 1 file changed, 19 insertions(+), 30 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index 94afff2..cfef6fa 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -1,11 +1,9 @@ using DiagrammaticEquations using ACSets -using MLStyle using SymbolicUtils -using SymbolicUtils.Rewriters using SymbolicUtils: BasicSymbolic, Symbolic -export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting, symbolics_lookup +export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting const DECA_EQUALITY_SYMBOL = (==) @@ -23,12 +21,10 @@ end function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_idx::Int, op_type::Symbol) input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_idx,Val(op_type)), :name]) output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_idx,Val(op_type)), :name]) - op_sym = getfield(@__MODULE__, edge_function(d,op_idx,Val(op_type))) S = promote_symtype(op_sym, input_syms...) - rhs = SymbolicUtils.Term{S}(op_sym, input_syms) - SymbolicEquation{Symbolic}(output_sym, rhs) + SymbolicEquation{Symbolic}(output_sym, SymbolicUtils.Term{S}(op_sym, input_syms)) end function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing) @@ -38,34 +34,30 @@ function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing) end function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}) - non_tangents = filter(x -> retrieve_name(d, x) != DerivOp, topological_sort_edges(d)) + non_tangents = Iterators.filter(x -> retrieve_name(d, x) != DerivOp, topological_sort_edges(d)) map(non_tangents) do node to_symbolics(d, symvar_lookup, node.index, node.name) end end # XXX SymbolicUtils.substitute swaps the order of multiplication. -# example: @decapode begin -# u::Form0 -# G::Form0 -# κ::Constant -# ∂ₜ(G) == κ*★(d(★(d(u)))) +# e.g. @decapode begin +# ∂ₜ(G) == κ*u # end -# will have the kappa*var term rewritten to var*kappa +# will have the κ*u term rewritten to u*κ function merge_equations(d::SummationDecapode) symvar_lookup = symbolics_lookup(d) symexpr_list = extract_symexprs(d, symvar_lookup) - eqn_lookup = Dict{Any,Any}(map(start_nodes(d)) do node - sym = symvar_lookup[d[node, :name]] + eqn_lookup = Dict{Any,Any}(map(start_nodes(d)) do i + sym = symvar_lookup[d[i, :name]] (sym, sym) end) - foreach(symexpr_list) do expr - merged_rhs = SymbolicUtils.substitute(expr.rhs, eqn_lookup) - push!(eqn_lookup, (expr.lhs => merged_rhs)) + foreach(symexpr_list) do x + push!(eqn_lookup, (x.lhs => SymbolicUtils.substitute(x.rhs, eqn_lookup))) end - terminals = filter(x -> x.lhs.name in infer_terminal_names(d), symexpr_list) + terminals = Iterators.filter(x -> x.lhs.name in infer_terminal_names(d), symexpr_list) map(x -> formed_deca_eqn(x.lhs, eqn_lookup[x.lhs]), terminals) end @@ -79,8 +71,8 @@ function apply_rewrites(symexprs, rewriter) end function to_acset(d::SummationDecapode, sym_exprs) - outer_types = map([infer_states(d)..., infer_terminals(d)...]) do idx - :($(d[idx, :name])::$(d[idx, :type])) + outer_types = map([infer_states(d)..., infer_terminals(d)...]) do i + :($(d[i, :name])::$(d[i, :type])) end tangents = map(incident(d, DerivOp, :op1)) do op1 @@ -89,19 +81,16 @@ function to_acset(d::SummationDecapode, sym_exprs) #TODO: This step is breaking up summations final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) - recursive_descent = @λ begin - e::Expr => begin - if e.head == :call - e.args[1] = nameof(e.args[1]) - map(recursive_descent, e.args[2:end]) - end + reify!(exprs) = foreach(exprs) do x + if typeof(x)==Expr && x.head == :call + x.args[1] = nameof(x.args[1]) + reify!(x.args[2:end]) end - sym => nothing end - foreach(recursive_descent, final_exprs) + reify!(final_exprs) deca_block = quote end deca_block.args = [outer_types..., tangents..., final_exprs...] - infer_types!(SummationDecapode(parse_decapode(deca_block))) + ∘(infer_types!, SummationDecapode, parse_decapode)(deca_block) end From 5b84cc8a93097832e2a0fa492bba913f17904bba Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Sat, 28 Sep 2024 10:49:14 -0400 Subject: [PATCH 28/39] Further improvement of acset2symbolics Remove special DerivOp handling, fixed bug where multiple equations with the same variable result were being dropped, more tests to cover these cases and further clean up. --- src/acset2symbolic.jl | 40 ++++++++++++++++++++-------------------- src/sym_rewrite.jl | 2 ++ test/acset2symbolic.jl | 30 ++++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 20 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index c383ca2..b2da836 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -3,6 +3,7 @@ using ACSets using SymbolicUtils using SymbolicUtils: BasicSymbolic, Symbolic +# TODO: Expose only the symbolic_rewriting function export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting const DECA_EQUALITY_SYMBOL = (==) @@ -30,13 +31,13 @@ end function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing) d = infer_types!(deepcopy(old_d)) eqns = merge_equations(d) - to_acset(d, isnothing(rewriter) ? eqns : map(rewriter, eqns)) + to_acset(d, apply_rewrites(eqns, rewriter)) end +apply_rewrites(eqns, rewriter) = isnothing(rewriter) ? eqns : map(rewriter, eqns) + function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}) - sym_list = SymbolicEquation{Symbolic}[] - non_tangents = Iterators.filter(x -> retrieve_name(d, x) != DerivOp, topological_sort_edges(d)) - map(non_tangents) do node + map(topological_sort_edges(d)) do node to_symbolics(d, symvar_lookup, node.index, node.name) end end @@ -50,26 +51,22 @@ function merge_equations(d::SummationDecapode) symvar_lookup = symbolics_lookup(d) symexpr_list = extract_symexprs(d, symvar_lookup) - eqn_lookup = Dict{Any,Any}(map(start_nodes(d)) do i - sym = symvar_lookup[d[i, :name]] - (sym, sym) - end) + eqn_lookup = Dict() + + terminal_vars = infer_terminal_names(d) + terminal_eqns = SymbolicEquation{Symbolic}[] + foreach(symexpr_list) do x push!(eqn_lookup, (x.lhs => SymbolicUtils.substitute(x.rhs, eqn_lookup))) + if x.lhs.name in terminal_vars + push!(terminal_eqns, SymbolicEquation{Symbolic}(x.lhs, eqn_lookup[x.lhs])) + end end - terminals = Iterators.filter(x -> x.lhs.name in infer_terminal_names(d), symexpr_list) - map(x -> formed_deca_eqn(x.lhs, eqn_lookup[x.lhs]), terminals) + formed_deca_eqn.(terminal_eqns) end -formed_deca_eqn(lhs, rhs) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [lhs, rhs]) - -function apply_rewrites(symexprs, rewriter) - map(symexprs) do sym - res_sym = rewriter(sym) - isnothing(res_sym) ? sym : res_sym - end -end +formed_deca_eqn(symeqn::SymbolicEquation{Symbolic}) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [symeqn.lhs, symeqn.rhs]) function to_acset(d::SummationDecapode, sym_exprs) outer_types = map([infer_states(d)..., infer_terminals(d)...]) do i @@ -90,8 +87,11 @@ function to_acset(d::SummationDecapode, sym_exprs) end reify!(final_exprs) - deca_block = quote end - deca_block.args = [outer_types..., tangents..., final_exprs...] + deca_block = quote + $(outer_types...) + $(final_exprs...) + end + ∘(infer_types!, SummationDecapode, parse_decapode)(deca_block) end diff --git a/src/sym_rewrite.jl b/src/sym_rewrite.jl index e63903c..cc6e4b4 100644 --- a/src/sym_rewrite.jl +++ b/src/sym_rewrite.jl @@ -1,3 +1,5 @@ +# TODO: Delete this file + using Test using DiagrammaticEquations using SymbolicUtils diff --git a/test/acset2symbolic.jl b/test/acset2symbolic.jl index 89de387..f31efbb 100644 --- a/test/acset2symbolic.jl +++ b/test/acset2symbolic.jl @@ -68,6 +68,36 @@ using Catlab all_ops_res[5, :name] = :D all_ops_res[6, :name] = :C @test all_ops ≃ all_ops_res + + with_deriv = @decapode begin + A::Form0 + Ȧ::Form0 + + ∂ₜ(A) == Ȧ + Ȧ == Δ(A) + end + + @test with_deriv == symbolic_rewriting(with_deriv) + + repeated_vars = @decapode begin + A::Form0 + B::Form1 + C::Form1 + + C == d(A) + C == Δ(B) + C == d(A) + end + + @test repeated_vars == symbolic_rewriting(repeated_vars) + + # TODO: This is broken because of the terminals issue in #77 + self_changing = @decapode begin + A::Form0 + A == ∂ₜ(A) + end + + @test_broken repeated_vars == symbolic_rewriting(self_changing) end function expr_rewriter(rules::Vector) From 35f7b8e6a38443a7ef6ead0c72819ade5c0f19d1 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Sat, 28 Sep 2024 11:09:29 -0400 Subject: [PATCH 29/39] Remove extraneous tangents --- src/acset2symbolic.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index b2da836..13119c2 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -73,10 +73,6 @@ function to_acset(d::SummationDecapode, sym_exprs) :($(d[i, :name])::$(d[i, :type])) end - tangents = map(incident(d, DerivOp, :op1)) do op1 - :($(d[d[op1, :tgt], :name]) == $DerivOp($(d[d[op1, :src], :name]))) - end - #TODO: This step is breaking up summations final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) reify!(exprs) = foreach(exprs) do x From da0f81ad9dd84d84133a0d20f1a38bdaccf3f7ad Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Sat, 28 Sep 2024 12:39:19 -0400 Subject: [PATCH 30/39] Remove redundant helper functions --- src/acset2symbolic.jl | 66 +++++++++++++++++------------------------- test/acset2symbolic.jl | 3 +- test/runtests.jl | 7 ++--- 3 files changed, 30 insertions(+), 46 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index 13119c2..1628589 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -3,10 +3,10 @@ using ACSets using SymbolicUtils using SymbolicUtils: BasicSymbolic, Symbolic -# TODO: Expose only the symbolic_rewriting function -export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting +export symbolic_rewriting -const DECA_EQUALITY_SYMBOL = (==) +const EQUALITY = (==) +const SymEqSym = SymbolicEquation{Symbolic} function symbolics_lookup(d::SummationDecapode) Dict{Symbol, BasicSymbolic}(map(parts(d, :Var)) do i @@ -14,9 +14,9 @@ function symbolics_lookup(d::SummationDecapode) end) end -function decavar_to_symbolics(d::SummationDecapode, idx::Int; space = :I) - new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[idx, :type], space) - SymbolicUtils.Sym{new_type}(d[idx, :name]) +function decavar_to_symbolics(d::SummationDecapode, var_idx::Int; space = :I) + new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[var_idx, :type], space) + SymbolicUtils.Sym{new_type}(d[var_idx, :name]) end function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_idx::Int, op_type::Symbol) @@ -25,49 +25,39 @@ function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSym op_sym = getfield(@__MODULE__, edge_function(d,op_idx,Val(op_type))) S = promote_symtype(op_sym, input_syms...) - SymbolicEquation{Symbolic}(output_sym, SymbolicUtils.Term{S}(op_sym, input_syms)) + SymEqSym(output_sym, SymbolicUtils.Term{S}(op_sym, input_syms)) end -function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing) - d = infer_types!(deepcopy(old_d)) - eqns = merge_equations(d) - to_acset(d, apply_rewrites(eqns, rewriter)) +function to_symbolics(d::SummationDecapode) + symvar_lookup = symbolics_lookup(d) + map(e -> to_symbolics(d, symvar_lookup, e.index, e.name), topological_sort_edges(d)) end -apply_rewrites(eqns, rewriter) = isnothing(rewriter) ? eqns : map(rewriter, eqns) - -function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}) - map(topological_sort_edges(d)) do node - to_symbolics(d, symvar_lookup, node.index, node.name) - end +function symbolic_rewriting(d::SummationDecapode, rewriter=identity) + d′ = infer_types!(deepcopy(d)) + eqns = merge_equations(d′) + to_acset(d′, map(rewriter, eqns)) end # XXX SymbolicUtils.substitute swaps the order of multiplication. -# e.g. @decapode begin -# ∂ₜ(G) == κ*u -# end -# will have the κ*u term rewritten to u*κ +# e.g. ∂ₜ(G) == κ*u becomes ∂ₜ(G) == u*κ function merge_equations(d::SummationDecapode) - symvar_lookup = symbolics_lookup(d) - symexpr_list = extract_symexprs(d, symvar_lookup) - + symexprs = to_symbolics(d) eqn_lookup = Dict() + terminal_eqns = SymEqSym[] - terminal_vars = infer_terminal_names(d) - terminal_eqns = SymbolicEquation{Symbolic}[] - - foreach(symexpr_list) do x + foreach(symexprs) do x push!(eqn_lookup, (x.lhs => SymbolicUtils.substitute(x.rhs, eqn_lookup))) - if x.lhs.name in terminal_vars - push!(terminal_eqns, SymbolicEquation{Symbolic}(x.lhs, eqn_lookup[x.lhs])) + if x.lhs.name in infer_terminal_names(d) + push!(terminal_eqns, SymEqSym(x.lhs, eqn_lookup[x.lhs])) end end - formed_deca_eqn.(terminal_eqns) + map(terminal_eqns) do eqn + SymbolicUtils.Term{Number}(EQUALITY, [eqn.lhs, eqn.rhs]) + end end -formed_deca_eqn(symeqn::SymbolicEquation{Symbolic}) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [symeqn.lhs, symeqn.rhs]) - function to_acset(d::SummationDecapode, sym_exprs) outer_types = map([infer_states(d)..., infer_terminals(d)...]) do i :($(d[i, :name])::$(d[i, :type])) @@ -76,18 +66,14 @@ function to_acset(d::SummationDecapode, sym_exprs) #TODO: This step is breaking up summations final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) reify!(exprs) = foreach(exprs) do x - if typeof(x)==Expr && x.head == :call + if typeof(x) == Expr && x.head == :call x.args[1] = nameof(x.args[1]) reify!(x.args[2:end]) end end reify!(final_exprs) - deca_block = quote - $(outer_types...) - $(final_exprs...) - end - + deca_block = quote end + deca_block.args = [outer_types..., final_exprs...] ∘(infer_types!, SummationDecapode, parse_decapode)(deca_block) end - diff --git a/test/acset2symbolic.jl b/test/acset2symbolic.jl index f31efbb..e391fc0 100644 --- a/test/acset2symbolic.jl +++ b/test/acset2symbolic.jl @@ -93,8 +93,7 @@ using Catlab # TODO: This is broken because of the terminals issue in #77 self_changing = @decapode begin - A::Form0 - A == ∂ₜ(A) + c_exp == ∂ₜ(c_exp) end @test_broken repeated_vars == symbolic_rewriting(self_changing) diff --git a/test/runtests.jl b/test/runtests.jl index ec5047b..c68bded 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,14 +38,13 @@ end include("openoperators.jl") end -@testset "ThDEC Symbolics" begin - include("decasymbolic.jl") -end - @testset "Symbolic Rewriting" begin include("graph_traversal.jl") include("acset2symbolic.jl") end +@testset "ThDEC Symbolics" begin + include("decasymbolic.jl") +end include("aqua.jl") From fb4927c1952d446f8362461564598ec4400d6cca Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Sat, 28 Sep 2024 14:13:13 -0400 Subject: [PATCH 31/39] Pass indexed names and types directly --- src/acset2symbolic.jl | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index 1628589..5ce020a 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -9,14 +9,14 @@ const EQUALITY = (==) const SymEqSym = SymbolicEquation{Symbolic} function symbolics_lookup(d::SummationDecapode) - Dict{Symbol, BasicSymbolic}(map(parts(d, :Var)) do i - (d[i, :name], decavar_to_symbolics(d, i)) + Dict{Symbol, BasicSymbolic}(map(d[:name],d[:type]) do name,type + (name, decavar_to_symbolics(d, name, type)) end) end -function decavar_to_symbolics(d::SummationDecapode, var_idx::Int; space = :I) - new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[var_idx, :type], space) - SymbolicUtils.Sym{new_type}(d[var_idx, :name]) +function decavar_to_symbolics(d::SummationDecapode, var_name::Symbol, var_type::Symbol; space = :I) + new_type = SymbolicUtils.symtype(Deca.DECQuantity, var_type, space) + SymbolicUtils.Sym{new_type}(var_name) end function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_idx::Int, op_type::Symbol) @@ -42,14 +42,13 @@ end # XXX SymbolicUtils.substitute swaps the order of multiplication. # e.g. ∂ₜ(G) == κ*u becomes ∂ₜ(G) == u*κ function merge_equations(d::SummationDecapode) - symexprs = to_symbolics(d) - eqn_lookup = Dict() - terminal_eqns = SymEqSym[] + eqn_lookup, terminal_eqns = Dict(), SymEqSym[] - foreach(symexprs) do x - push!(eqn_lookup, (x.lhs => SymbolicUtils.substitute(x.rhs, eqn_lookup))) + foreach(to_symbolics(d)) do x + sub = SymbolicUtils.substitute(x.rhs, eqn_lookup) + push!(eqn_lookup, (x.lhs => sub)) if x.lhs.name in infer_terminal_names(d) - push!(terminal_eqns, SymEqSym(x.lhs, eqn_lookup[x.lhs])) + push!(terminal_eqns, SymEqSym(x.lhs, sub)) end end From 2d158e8f80573e2a65648f73b39810422c7bdb44 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Sat, 28 Sep 2024 15:08:49 -0400 Subject: [PATCH 32/39] Removed extraneous d arg --- src/acset2symbolic.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index 5ce020a..18d6326 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -10,11 +10,11 @@ const SymEqSym = SymbolicEquation{Symbolic} function symbolics_lookup(d::SummationDecapode) Dict{Symbol, BasicSymbolic}(map(d[:name],d[:type]) do name,type - (name, decavar_to_symbolics(d, name, type)) + (name, decavar_to_symbolics(name, type)) end) end -function decavar_to_symbolics(d::SummationDecapode, var_name::Symbol, var_type::Symbol; space = :I) +function decavar_to_symbolics(var_name::Symbol, var_type::Symbol; space = :I) new_type = SymbolicUtils.symtype(Deca.DECQuantity, var_type, space) SymbolicUtils.Sym{new_type}(var_name) end From 0b32babe331d0b39101c4f5781e67dcf67fe81ef Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 30 Sep 2024 15:54:28 -0400 Subject: [PATCH 33/39] fixing work on tumor invasion --- src/acset2symbolic.jl | 3 +-- src/symbolictheoryutils.jl | 1 + test/decasymbolic.jl | 5 ++++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index c383ca2..e0ad0e3 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -35,8 +35,7 @@ end function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}) sym_list = SymbolicEquation{Symbolic}[] - non_tangents = Iterators.filter(x -> retrieve_name(d, x) != DerivOp, topological_sort_edges(d)) - map(non_tangents) do node + map(topological_sort_edges(d)) do node to_symbolics(d, symvar_lookup, node.index, node.name) end end diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index 38f4fcc..81b980e 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -157,6 +157,7 @@ macro alias(body) $rep(s...) end export $alias + Base.nameof(::typeof($alias), s) = Symbol("$alias") end)) end diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index 722157c..f355071 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -143,7 +143,7 @@ end TumorInvasion = @decapode begin (C,fC)::Form0 (Dif,Kd,Cmax)::Constant - ∂ₜ(C) == Dif * Δ(C) + fC - Kd * C + ∂ₜ(C) == Dif * Δ(C) + fC - C * Kd end infer_types!(TumorInvasion) context = SymbolicContext(Term(TumorInvasion)) @@ -151,5 +151,8 @@ end # new terms introduced @test_broken TumorInvasion == TumorInvasion′ + # TI' has (11, Literal, -1) and (12, infer, mult_1) + # Op1 (2, 1, 4, 7) should be (2, 4, 1, 7) + # Sum is (1, 6), (2, 10) end From fe21de425fd2555ec6f25a9c511d47a2860956d4 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 30 Sep 2024 20:07:11 -0400 Subject: [PATCH 34/39] macros which create export stmts will fail inside @testset due to JuliaLang issue #51325 --- src/SymbolicUtilsInterop.jl | 3 ++ src/symbolictheoryutils.jl | 10 +++--- test/decasymbolic.jl | 70 ++++++++++++++++++------------------- 3 files changed, 43 insertions(+), 40 deletions(-) diff --git a/src/SymbolicUtilsInterop.jl b/src/SymbolicUtilsInterop.jl index 2e50048..c95fb54 100644 --- a/src/SymbolicUtilsInterop.jl +++ b/src/SymbolicUtilsInterop.jl @@ -62,6 +62,9 @@ function decapodes.Term(t::SymbolicUtils.BasicSymbolic) end end end +# TODO subtraction is not parsed as such. e.g., +# a, b = @syms a::Scalar b::Scalar +# Term(a-b) = Plus(Term[Var(:a), Mult(Term[Lit(Symbol("-1")), Var(:b)])) decapodes.Term(x::Real) = decapodes.Lit(Symbol(x)) diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index 81b980e..a0466e7 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -107,12 +107,14 @@ macro operator(head, body) s = promote_symtype($f, $(argnames...)) SymbolicUtils.Term{s}($f, Any[$(argnames...)]) end + export $f Base.show(io::IO, ::typeof($f)) = print(io, $f) end) - # if there are rewriting rules, add a method which accepts the function symbol and its arity (to prevent shadowing on operators like `-`) + # if there are rewriting rules, add a method which accepts the function symbol and its arity + # (to prevent shadowing on operators like `-`) if !isempty(rulecalls) push!(result.args, quote function rules(::typeof($f), ::Val{$arity}) @@ -152,16 +154,16 @@ macro alias(body) result = quote end foreach(aliases) do alias push!(result.args, - esc(quote + quote function $alias(s...) $rep(s...) end export $alias Base.nameof(::typeof($alias), s) = Symbol("$alias") - end)) + end) end - result + return esc(result) end export @alias diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index f355071..01b76ce 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -68,50 +68,47 @@ u, du = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} end -@testset "Operator definition" begin +# this is not nabla but "bizarro Δ" +del_expand_0, del_expand_1 = @operator ∇(S)::DECQuantity begin + @match S begin + PatScalar(_) => error("Argument of type $S is invalid") + PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) + end + @rule ∇(~x::isForm0) => ★(d(★(d(~x)))) + @rule ∇(~x::isForm1) => ★(d(★(d(~x)))) + d(★(d(★(~x)))) +end; + +# we will test is new operator +(r0, r1, r2) = @operator ρ(S)::DECQuantity begin + S <: Form ? Scalar : Form + @rule ρ(~x::isForm0) => 0 + @rule ρ(~x::isForm1) => 1 + @rule ρ(~x::isForm2) => 2 +end - # this is not nabla but "bizarro Δ" - del_expand_0, del_expand_1 = - @operator ∇(S)::DECQuantity begin - @match S begin - PatScalar(_) => error("Argument of type $S is invalid") - PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) - end - @rule ∇(~x::isForm0) => ★(d(★(d(~x)))) - @rule ∇(~x::isForm1) => ★(d(★(d(~x)))) + d(★(d(★(~x)))) - end; +R, = @operator φ(S1, S2, S3)::DECQuantity begin + let T1=S1, T2=S2, T3=S3 + Scalar + end + @rule φ(2(~x::isForm0), 2(~y::isForm0), 2(~z::isForm0)) => 2*φ(~x,~y,~z) +end +@alias (φ′,) => φ + +@testset "Operator definition" begin + + # ∇ @test_throws Exception ∇(b) @test symtype(∇(u)) == PrimalForm{0, :X ,2} @test promote_symtype(∇, u) == PrimalForm{0, :X, 2} - @test isequal(del_expand_0(∇(u)), ★(d(★(d(u))))) - - # we will test is new operator - (r0, r1, r2) = @operator ρ(S)::DECQuantity begin - if S <: Form - Scalar - else - Form - end - @rule ρ(~x::isForm0) => 0 - @rule ρ(~x::isForm1) => 1 - @rule ρ(~x::isForm2) => 2 - end - + + # ρ @test symtype(ρ(u)) == Scalar - R, = @operator φ(S1, S2, S3)::DECQuantity begin - let T1=S1, T2=S2, T3=S3 - Scalar - end - @rule φ(2(~x::isForm0), 2(~y::isForm0), 2(~z::isForm0)) => 2*φ(~x,~y,~z) - end - - # TODO we need to alias rewriting rules - @alias (φ′,) => φ - + # R @test isequal(R(φ(2u,2u,2u)), R(φ′(2u,2u,2u))) + # TODO we need to alias rewriting rules end @@ -149,7 +146,8 @@ end context = SymbolicContext(Term(TumorInvasion)) TumorInvasion′ = SummationDecapode(DecaExpr(context)) - # new terms introduced + # new terms introduced because Symbolics converts subtraction expressions + # e.g., a - b => +(a, -b) @test_broken TumorInvasion == TumorInvasion′ # TI' has (11, Literal, -1) and (12, infer, mult_1) # Op1 (2, 1, 4, 7) should be (2, 4, 1, 7) From ffa7c8c409fa5709aca6ff70ef0d5665b941263c Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 30 Sep 2024 20:35:48 -0400 Subject: [PATCH 35/39] removed ghost emoji and added convenience function for rules. aqua's failing persistent tasks. --- src/symbolictheoryutils.jl | 2 ++ test/decasymbolic.jl | 10 ++++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index a0466e7..63dafc9 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -120,6 +120,8 @@ macro operator(head, body) function rules(::typeof($f), ::Val{$arity}) [($(rulecalls...))] end + + rules(::typeof($f)) = rules($f, Val{1}) end) end diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index 01b76ce..e9bb0df 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -6,8 +6,10 @@ using SymbolicUtils using SymbolicUtils: symtype, promote_symtype, Symbolic using MLStyle +import DiagrammaticEquations: rules + # load up some variable variables and expressions -👻, = @syms 👻::InferredType +ϐ, = @syms ϐ::InferredType # \varbeta ℓ, = @syms ℓ::Literal c, t = @syms c::Const t::Parameter a, b = @syms a::Scalar b::Scalar @@ -18,7 +20,7 @@ u, du = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} @testset "Term Construction" begin - @test symtype(👻) == InferredType + @test symtype(ϐ) == InferredType @test symtype(ℓ) == Literal @test symtype(c) == Const @test symtype(t) == Parameter @@ -33,11 +35,11 @@ u, du = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} @test symtype(c + t) == Scalar @test symtype(t + t) == Scalar @test symtype(c + c) == Scalar - @test symtype(t + 👻) == InferredType + @test symtype(t + ϐ) == InferredType @test symtype(u ∧ ω) == PrimalForm{1, :X, 2} @test symtype(ω ∧ ω) == PrimalForm{2, :X, 2} - @test symtype(u ∧ 👻) == InferredType + @test symtype(u ∧ ϐ) == InferredType # @test_throws ThDEC.SortError ThDEC.♯(u) @test symtype(Δ(u) + Δ(u)) == PrimalForm{0, :X, 2} From 97e8b2d3efd79f1ee40a36b6366c53b4bdf40f1e Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Wed, 2 Oct 2024 14:01:14 -0400 Subject: [PATCH 36/39] Added more tests for acset2symbolics --- src/acset2symbolic.jl | 7 ++++-- test/acset2symbolic.jl | 55 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index a43c511..7e13b4c 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -14,7 +14,7 @@ function symbolics_lookup(d::SummationDecapode) end) end -function decavar_to_symbolics(var_name::Symbol, var_type::Symbol; space = :I) +function decavar_to_symbolics(var_name::Symbol, var_type::Symbol, space = :I) new_type = SymbolicUtils.symtype(Deca.DECQuantity, var_type, space) SymbolicUtils.Sym{new_type}(var_name) end @@ -65,7 +65,9 @@ function merge_equations(d::SummationDecapode) end function to_acset(d::SummationDecapode, sym_exprs) - outer_types = map([infer_states(d)..., infer_terminals(d)...]) do i + literals = incident(d, :Literal, :type) + + outer_types = map([infer_states(d)..., infer_terminals(d)..., literals...]) do i :($(d[i, :name])::$(d[i, :type])) end @@ -81,5 +83,6 @@ function to_acset(d::SummationDecapode, sym_exprs) deca_block = quote end deca_block.args = [outer_types..., final_exprs...] + ∘(infer_types!, SummationDecapode, parse_decapode)(deca_block) end diff --git a/test/acset2symbolic.jl b/test/acset2symbolic.jl index e391fc0..685e06d 100644 --- a/test/acset2symbolic.jl +++ b/test/acset2symbolic.jl @@ -97,6 +97,35 @@ using Catlab end @test_broken repeated_vars == symbolic_rewriting(self_changing) + + literal = @decapode begin + A::Form0 + B::Form0 + + B == A * 2 + end + + @test literal == symbolic_rewriting(literal) + + parameter = @decapode begin + A::Form0 + P::Parameter + B::Form0 + + B == A * P + end + + @test parameter == symbolic_rewriting(parameter) + + constant = @decapode begin + A::Form0 + C::Constant + B::Form0 + + B == A * C + end + + @test constant == symbolic_rewriting(constant) end function expr_rewriter(rules::Vector) @@ -154,6 +183,30 @@ end @test op2s_equiv == op2s_rewritten + + distr_d = @decapode begin + A::Form0 + B::Form0 + C::Form0 + + C == d(∧(A, B)) + end + infer_types!(distr_d) + + leibniz = @rule d(∧(~x, ~y)) => ∧(d(~x), ~y) + ∧(~x, d(~y)) + + distr_d_rewritten = symbolic_rewriting(distr_d, expr_rewriter([leibniz])) + + distr_d_res = @decapode begin + A::Form0 + B::Form0 + C::Form0 + + C == ∧(d(A), B) + ∧(A, d(B)) + end + infer_types!(distr_d_res) + + @test distr_d_res == distr_d_rewritten end @testset "Heat" begin @@ -215,7 +268,7 @@ end ∂ₜ(G) == Δ(u)*κ end infer_types!(Heat) - + Heat_open = @decapode begin u::Form0 G::Form0 From 25493226f96171b71dafda8ed0238d136dbd1bc2 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Wed, 2 Oct 2024 14:26:22 -0400 Subject: [PATCH 37/39] Fixed persistence issue Also set default form dim to 2 and allowed it to vary. --- src/deca/ThDEC.jl | 19 +++++++++---------- src/symbolictheoryutils.jl | 19 +++++++++---------- test/acset2symbolic.jl | 8 ++++---- 3 files changed, 22 insertions(+), 24 deletions(-) diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index e71701b..c174175 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -203,13 +203,12 @@ end @alias (Δ₀, Δ₁, Δ₂) => Δ -# Base.show(io::IO, - +# TODO: Determine what we need to do for .+ @operator +(S1, S2)::DECQuantity begin @match (S1, S2) begin PatInferredTypes(_) => InferredType (PatScalar(_), PatScalar(_)) => Scalar - (PatScalar(_), PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar(_)) => S1 # commutativity + (PatScalar(_), PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar(_)) => Form{i, d, s, n} # commutativity (PatFormParams([i1,d1,s1,n1]), PatFormParams([i2,d2,s2,n2])) => begin if (i1 == i2) && (d1 == d2) && (s1 == s2) && (n1 == n2) Form{i1, d1, s1, n1} @@ -298,18 +297,18 @@ function Base.nameof(::typeof(★), s) Symbol("★$(as_sub(isdual(s) ? dim(space(s)) - dim(s) : dim(s)))$(inv)") end -function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol) +function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol, dim::Int = 2) @match qty begin :Scalar => Scalar :Constant => Const :Parameter => Parameter :Literal => Literal - :Form0 => PrimalForm{0, space, 1} - :Form1 => PrimalForm{1, space, 1} - :Form2 => PrimalForm{2, space, 1} - :DualForm0 => DualForm{0, space, 1} - :DualForm1 => DualForm{1, space, 1} - :DualForm2 => DualForm{2, space, 1} + :Form0 => PrimalForm{0, space, dim} + :Form1 => PrimalForm{1, space, dim} + :Form2 => PrimalForm{2, space, dim} + :DualForm0 => DualForm{0, space, dim} + :DualForm1 => DualForm{1, space, dim} + :DualForm2 => DualForm{2, space, dim} :Infer => InferredType _ => error("Received $qty") end diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index 63dafc9..7746559 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -43,7 +43,7 @@ Creates an operator `foo` with arguments which are types in a given Theory. This (@rule expr1) ... (@rule exprN) -end +end ``` builds ``` @@ -81,7 +81,7 @@ macro operator(head, body) end (f, types, Theory) = ph(head) - # Passing types to functions requires that we type the signature with ::Type{T}. + # Passing types to functions requires that we type the signature with ::Type{T}. # This means that the user would have to write `my_op(::Type{T1}, ::Type{T2}, ...)` # As a convenience to the user, we allow them to specify the signature using just the types themselves: # `my_op(T1, T2, ...)` @@ -89,15 +89,15 @@ macro operator(head, body) sort_constraints = [:($S<:$Theory) for S in types] arity = length(sort_types) - # Parse the body for @rule calls. + # Parse the body for @rule calls. block, rulecalls = @match Base.remove_linenums!(body) begin Expr(:block, block, rules...) => (block, rules) s => nothing end - + # initialize the result result = quote end - + # construct the function on basic symbolics argnames = [gensym(:x) for _ in 1:arity] argclaus = [:($a::Symbolic) for a in argnames] @@ -110,17 +110,17 @@ macro operator(head, body) export $f - Base.show(io::IO, ::typeof($f)) = print(io, $f) + # Base.show(io::IO, ::typeof($f)) = print(io, $f) end) - # if there are rewriting rules, add a method which accepts the function symbol and its arity + # if there are rewriting rules, add a method which accepts the function symbol and its arity # (to prevent shadowing on operators like `-`) if !isempty(rulecalls) push!(result.args, quote function rules(::typeof($f), ::Val{$arity}) [($(rulecalls...))] end - + rules(::typeof($f)) = rules($f, Val{1}) end) end @@ -158,7 +158,7 @@ macro alias(body) push!(result.args, quote function $alias(s...) - $rep(s...) + $rep(s...) end export $alias @@ -171,4 +171,3 @@ export @alias alias(x) = error("$x has no aliases") export alias - diff --git a/test/acset2symbolic.jl b/test/acset2symbolic.jl index 685e06d..0212a7e 100644 --- a/test/acset2symbolic.jl +++ b/test/acset2symbolic.jl @@ -186,8 +186,8 @@ end distr_d = @decapode begin A::Form0 - B::Form0 - C::Form0 + B::Form1 + C::Form2 C == d(∧(A, B)) end @@ -199,8 +199,8 @@ end distr_d_res = @decapode begin A::Form0 - B::Form0 - C::Form0 + B::Form1 + C::Form2 C == ∧(d(A), B) + ∧(A, d(B)) end From 1fae01cd0dbfa4fd7c4a581c63858275e05eba26 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Wed, 2 Oct 2024 14:33:41 -0400 Subject: [PATCH 38/39] Final touches --- src/deca/ThDEC.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index c174175..4fa4695 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -199,6 +199,7 @@ end end @rule Δ(~x::isForm0) => ★(d(★(d(~x)))) @rule Δ(~x::isForm1) => ★(d(★(d(~x)))) + d(★(d(★(~x)))) + @rule Δ(~x::isForm2) => d(★(d(★(~x)))) end @alias (Δ₀, Δ₁, Δ₂) => Δ @@ -297,6 +298,7 @@ function Base.nameof(::typeof(★), s) Symbol("★$(as_sub(isdual(s) ? dim(space(s)) - dim(s) : dim(s)))$(inv)") end +# TODO: Check that form type is no larger than the ambient dimension function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol, dim::Int = 2) @match qty begin :Scalar => Scalar From 6125c1edc2bc0899fa418aa33cd173516c1988a3 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Wed, 2 Oct 2024 14:43:09 -0400 Subject: [PATCH 39/39] Remove unused fuctionality --- src/acset2symbolic.jl | 7 ---- src/sym_rewrite.jl | 91 ------------------------------------------- 2 files changed, 98 deletions(-) delete mode 100644 src/sym_rewrite.jl diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index 7e13b4c..330fa76 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -33,13 +33,6 @@ function to_symbolics(d::SummationDecapode) map(e -> to_symbolics(d, symvar_lookup, e.index, e.name), topological_sort_edges(d)) end -function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}) - sym_list = SymbolicEquation{Symbolic}[] - map(topological_sort_edges(d)) do node - to_symbolics(d, symvar_lookup, node.index, node.name) - end -end - function symbolic_rewriting(d::SummationDecapode, rewriter=identity) d′ = infer_types!(deepcopy(d)) eqns = merge_equations(d′) diff --git a/src/sym_rewrite.jl b/src/sym_rewrite.jl deleted file mode 100644 index cc6e4b4..0000000 --- a/src/sym_rewrite.jl +++ /dev/null @@ -1,91 +0,0 @@ -# TODO: Delete this file - -using Test -using DiagrammaticEquations -using SymbolicUtils -using SymbolicUtils: Fixpoint, Prewalk, Postwalk, Chain, symtype, promote_symtype -using MLStyle - - -Brusselator = @decapode begin - (U, V)::Form0 - U2V::Form0 - (U̇, V̇)::Form0 - - (α)::Constant - F::Parameter - - U2V == (U .* U) .* V - - U̇ == 1 + U2V - (4.4 * U) + (α * Δ(U)) + F - V̇ == (3.4 * U) - U2V + (α * Δ(V)) - ∂ₜ(U) == U̇ - ∂ₜ(V) == V̇ -end -infer_types!(Brusselator) - - -# it seems that type-instability or improper type promotion is happening. expressions derived from this have BasicSymbolic{Number} type, which means we can't conditionally rewrite on forms. -heat_exprs = symbolic_rewriting(Heat) -sub = heat_exprs[1].arguments[2].arguments[2] - -a, b = @syms a::Scalar b::Scalar -u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} - -r = rules(Δ, Val(1)) - -# rule without predication works -R = @rule Δ(~x) => ★(d(★(d(~x)))) -rwR = Fixpoint(Prewalk(Chain([R]))) - -R(Δ(d(u))) - -# since promote_symtype(d(u)) returns Any while promote_symtype(d, u). I wonder -# if `d(u)` is not subjected to `symtype` - -Rp = @rule Δ(~x::isForm1) => "Success" -Rp(Δ(v)) # works -Rp(Δ(d(u))) # works - -Rp1 = @rule Δ(~x::isForm1) => ★(d(★(d(~x)))) - -Rp1(Δ(v)) # works -Rp1(Δ(d(u))) # works -rwRp1 = Fixpoint(Prewalk(Chain([Rp1]))) -rwRp1(Δ(d(u))) - -rwr = Fixpoint(Prewalk(Chain(r))) -rwr(heat_exprs[1]) # THIS WORKS! - -rwr(Δ(d(u))) # rwr -rwr(heat_exprs[1].arguments[2]) - -r[2](Δ(d(u))) # works - - -# rwR(heat_exprs[1]) -# rwR(sub) - -# tilde? -R1 = @rule Δ(~~x::(x->isForm1(x))) => ★(d(★(d(~x)))) - -@macroexpand @rule Δ(~x::isForm1) => "Success" - -# pulling out the subexpression -rewriter = SymbolicUtils.Fixpoint(SymbolicUtils.Prewalk(SymbolicUtils.Chain(r))) - -res_exprs = apply_rewrites(heat_exprs, rewriter) -sub_exprs = apply_rewrites([sub], rewriter) - -optm_dd_0 = @rule d(d(~x)) => 0 -star_0 = @rule ★(0) => 0 -d_0 = @rule d(0) => 0 - -optm_rewriter = SymbolicUtils.Postwalk( - SymbolicUtils.Fixpoint(SymbolicUtils.Chain([optm_dd_0, star_0, d_0]))) - -res_merge_exprs = map(optm_rewriter, res_exprs) - -deca_test = to_acset(Heat, res_exprs) -infer_types!(deca_test) -resolve_overloads!(deca_test)