Skip to content

Commit

Permalink
Add type rules for vectorfields
Browse files Browse the repository at this point in the history
  • Loading branch information
lukem12345 authored and jpfairbanks committed Aug 23, 2024
1 parent b14f09d commit c9b8898
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 34 deletions.
40 changes: 28 additions & 12 deletions src/acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,16 @@ end
# A collection of DecaType getters
# TODO: This should be replaced by using a type hierarchy
const ALL_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2,
:Literal, :Parameter, :Constant, :infer]
:PVF, :DVF,
:Literal, :Parameter, :Constant, :infer]

const FORM_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2]
const PRIMALFORM_TYPES = [:Form0, :Form1, :Form2]
const DUALFORM_TYPES = [:DualForm0, :DualForm1, :DualForm2]

const NONFORM_TYPES = [:Constant, :Parameter, :Literal, :infer]
const VECTORFIELD_TYPES = [:PVF, :DVF]

const NON_EC_TYPES = [:Constant, :Parameter, :Literal, :infer]
const USER_TYPES = [:Constant, :Parameter]
const NUMBER_TYPES = [:Literal]
const INFER_TYPES = [:infer]
Expand Down Expand Up @@ -427,12 +430,12 @@ function safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol)
end

"""
filterfor_forms(types::AbstractVector{Symbol})
filterfor_ec_types(types::AbstractVector{Symbol})
Return any form type symbols.
Return any form or vector-field type symbols.
"""
function filterfor_forms(types::AbstractVector{Symbol})
conditions = x -> !(x in NONFORM_TYPES)
function filterfor_ec_types(types::AbstractVector{Symbol})
conditions = x -> !(x in NON_EC_TYPES)
filter(conditions, types)
end

Expand All @@ -447,16 +450,16 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int)
types = d[idxs, :type]
all(t != :infer for t in types) && return applied # We need not infer

forms = unique(filterfor_forms(types))
ec_types = unique(filterfor_ec_types(types))

form = @match length(forms) begin
ec_type = @match length(ec_types) begin
0 => return applied # We can not infer
1 => only(forms)
_ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $forms")
1 => only(ec_types)
_ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $ec_types")
end

for idx in idxs
applied |= safe_modifytype!(d, idx, form)
applied |= safe_modifytype!(d, idx, ec_type)
end

return applied
Expand Down Expand Up @@ -489,11 +492,24 @@ function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule)
score_res = (rule.res_type == type_res)
check_op = (d[op2_id, :op2] in rule.op_names)

if(check_op && (score_proj1 + score_proj2 + score_res == 2))
if check_op && (score_proj1 + score_proj2 + score_res == 2)
mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type)
mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type)
mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type)
return mod_proj1 || mod_proj2 || mod_res
# Special logic for exponentiation:
elseif d[op2_id, :op2] == :^ &&
(type_proj1 == :Form0 && (type_proj2 == :infer || type_res == :infer)) ||
(type_res == :Form0 && (type_proj1 == :infer || type_proj2 == :infer))
mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], :Form0)
mod_res = safe_modifytype!(d, d[op2_id, :res], :Form0)
return mod_proj1 || mod_res
elseif d[op2_id, :op2] == :^ &&
(type_proj1 == :DualForm0 && (type_proj2 == :infer || type_res == :infer)) ||
(type_res == :DualForm0 && (type_proj1 == :infer || type_proj2 == :infer))
mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], :DualForm0)
mod_res = safe_modifytype!(d, d[op2_id, :res], :DualForm0)
return mod_proj1 || mod_res
end

return false
Expand Down
26 changes: 18 additions & 8 deletions src/deca/deca_acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,17 @@ op1_inf_rules_1D = [

# Rules for the averaging operator
(src_type = :Form0, tgt_type = :Form1, op_names = [:avg₀₁, :avg_01]),

# Rules for ♯.
(src_type = :Form1, tgt_type = :PVF, op_names = [:♯, :♯ᵖᵖ]),
(src_type = :DualForm1, tgt_type = :DVF, op_names = [:♯, :♯ᵈᵈ]),

# Rules for ♭.
(src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]),

# Rules for magnitude/ norm
(src_type = :Form0, tgt_type = :Form0, op_names = [:mag, :norm]),
(src_type = :Form1, tgt_type = :Form1, op_names = [:mag, :norm])]
(src_type = :PVF, tgt_type = :Form0, op_names = [:mag, :norm]),
(src_type = :DVF, tgt_type = :DualForm0, op_names = [:mag, :norm])]

op2_inf_rules_1D = [
# Rules for ∧₀₀, ∧₁₀, ∧₀₁
Expand Down Expand Up @@ -133,13 +140,16 @@ op1_inf_rules_2D = [
(src_type = :DualForm1, tgt_type = :DualForm1, op_names = [:neg, :(-)]),
(src_type = :DualForm2, tgt_type = :DualForm2, op_names = [:neg, :(-)]),

# Rules for ♯.
(src_type = :Form1, tgt_type = :PVF, op_names = [:♯, :♯ᵖᵖ]),
(src_type = :DualForm1, tgt_type = :DVF, op_names = [:♯, :♯ᵈᵈ]),

# Rules for ♭.
(src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]),

# Rules for magnitude/ norm
(src_type = :Form0, tgt_type = :Form0, op_names = [:norm, :mag]),
(src_type = :Form1, tgt_type = :Form1, op_names = [:norm, :mag]),
(src_type = :Form2, tgt_type = :Form2, op_names = [:norm, :mag]),
(src_type = :DualForm0, tgt_type = :DualForm0, op_names = [:norm, :mag]),
(src_type = :DualForm1, tgt_type = :DualForm1, op_names = [:norm, :mag]),
(src_type = :DualForm2, tgt_type = :DualForm2, op_names = [:norm, :mag])]
(src_type = :PVF, tgt_type = :Form0, op_names = [:norm, :mag]),
(src_type = :DVF, tgt_type = :DualForm0, op_names = [:norm, :mag])]

op2_inf_rules_2D = vcat(op2_inf_rules_1D, [
# Rules for ∧₁₁, ∧₂₀, ∧₀₂
Expand Down
97 changes: 83 additions & 14 deletions test/language.jl
Original file line number Diff line number Diff line change
Expand Up @@ -350,13 +350,14 @@ end
@test issetequal([:V,:X,:k], infer_state_names(oscillator))
end

import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_TYPES,
NONFORM_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES
import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES,
DUALFORM_TYPES, VECTORFIELD_TYPES, NON_EC_TYPES, USER_TYPES, NUMBER_TYPES,
INFER_TYPES, NONINFERABLE_TYPES

@testset "Type Retrival" begin

type_groups = [ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_TYPES,
NONFORM_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES]
NON_EC_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES]


# No repeated types
Expand All @@ -368,12 +369,12 @@ import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_
no_overlaps(types_1, types_2) = isempty(intersect(types_1, types_2))

# Collections of these types should be the same
@test equal_types(ALL_TYPES, vcat(FORM_TYPES, NONFORM_TYPES))
@test equal_types(FORM_TYPES, vcat(PRIMALFORM_TYPES, DUALFORM_TYPES))
@test equal_types(NONINFERABLE_TYPES, vcat(USER_TYPES, NUMBER_TYPES))
@test equal_types(ALL_TYPES, FORM_TYPES VECTORFIELD_TYPES NON_EC_TYPES)
@test equal_types(FORM_TYPES, PRIMALFORM_TYPES DUALFORM_TYPES)
@test equal_types(NONINFERABLE_TYPES, USER_TYPES NUMBER_TYPES)

# Proper seperation of types
@test no_overlaps(FORM_TYPES, NONFORM_TYPES)
@test no_overlaps(FORM_TYPES VECTORFIELD_TYPES, NON_EC_TYPES)
@test no_overlaps(PRIMALFORM_TYPES, DUALFORM_TYPES)
@test no_overlaps(NONINFERABLE_TYPES, FORM_TYPES)
@test INFER_TYPES == [:infer]
Expand All @@ -394,9 +395,9 @@ end
import DiagrammaticEquations: safe_modifytype

@testset "Safe Type Modification" begin
all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :infer]
all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :PVF, :DVF, :infer]
bad_sources = [:Literal, :Constant, :Parameter]
good_sources = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :infer]
good_sources = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :PVF, :DVF, :infer]

for tgt in all_types
for src in bad_sources
Expand All @@ -419,13 +420,13 @@ import DiagrammaticEquations: safe_modifytype
end
end

import DiagrammaticEquations: filterfor_forms
import DiagrammaticEquations: filterfor_ec_types

@testset "Form Type Retrieval" begin
all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :infer]
@test filterfor_forms(all_types) == [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2]
@test isempty(filterfor_forms(Symbol[]))
@test isempty(filterfor_forms([:Literal, :Constant, :Parameter, :infer]))
all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :PVF, :DVF, :infer]
@test filterfor_ec_types(all_types) == [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :PVF, :DVF]
@test isempty(filterfor_ec_types(Symbol[]))
@test isempty(filterfor_ec_types([:Literal, :Constant, :Parameter, :infer]))
end

@testset "Type Inference" begin
Expand Down Expand Up @@ -825,6 +826,38 @@ end
@test_throws "Type mismatch in summation" infer_types!(d)
end

# Test #25: Infer between flattened and sharpened vector fields.
let
d = @decapode begin
A::Form1
B::DualForm1
C::PVF
D::DVF

A == (E)
B == (F)
C == (G)
D == (H)

I::Form1
J::DualForm1
K::PVF
L::DVF

M == (I)
N == (J)
O == (K)
P == (L)
end
infer_types!(d)

# TODO: Update this as more sharps and flats are released.
names_types_expected = Set([(:A, :Form1), (:B, :DualForm1), (:C, :PVF), (:D, :DVF),
(:E, :DVF), (:F, :infer), (:G, :Form1), (:H, :DualForm1),
(:I, :Form1), (:J, :DualForm1), (:K, :PVF), (:L, :DVF),
(:M, :PVF), (:N, :DVF), (:O, :infer), (:P, :Form1)])
@test test_nametype_equality(d, names_types_expected)
end
end

@testset "Overloading Resolution" begin
Expand Down Expand Up @@ -1042,6 +1075,42 @@ end
op2s_hx = HeatXfer[:op2]
op2s_expected_hx = [:*, :/, :/, :L₀, :/, :L₁, :*, :/, :*, :i₁, :/, :*, :*, :L₀]
@test op2s_hx == op2s_expected_hx

# Infer types and resolve overloads for the Halfar equation.
let
d = @decapode begin
h::Form0
Γ::Form1
n::Constant

∂ₜ(h) == (, d, )(Γ * d(h) (mag((d(h)))^(n-1)) (h^(n+2)))
end
d = expand_operators(d)
infer_types!(d)
resolve_overloads!(d)
@test d == @acset SummationDecapode{Any, Any, Symbol} begin
Var = 19
TVar = 1
Op1 = 8
Op2 = 6
Σ = 1
Summand = 2
src = [1, 1, 1, 13, 12, 6, 18, 19]
tgt = [4, 9, 13, 12, 11, 18, 19, 4]
proj1 = [2, 3, 11, 8, 1, 7]
proj2 = [9, 15, 14, 10, 5, 16]
res = [8, 14, 10, 7, 16, 6]
incl = [4]
summand = [3, 17]
summation = [1, 1]
sum = [5]
op1 = [:∂ₜ, :d₀, :d₀, :♯, :mag, :₁, :dual_d₁, :₀⁻¹]
op2 = [:*, :-, :^, :₁₀, :^, :₁₀]
type = [:Form0, :Form1, :Constant, :Form0, :infer, :Form1, :Form1, :Form1, :Form1, :Form0, :Form0, :PVF, :Form1, :infer, :Literal, :Form0, :Literal, :DualForm1, :DualForm2]
name = [:h, , :n, :ḣ, :sum_1, Symbol("•2"), Symbol("•3"), Symbol("•4"), Symbol("•5"), Symbol("•6"), Symbol("•7"), Symbol("•8"), Symbol("•9"), Symbol("•10"), Symbol("1"), Symbol("•11"), Symbol("2"), Symbol("•_6_1"), Symbol("•_6_2")]
end
end

end

@testset "Compilation Transformation" begin
Expand Down

0 comments on commit c9b8898

Please sign in to comment.