Skip to content

Commit

Permalink
Complete type checking
Browse files Browse the repository at this point in the history
I've changing our type checking behavior to return all errors. If there is any error, an exception will be thrown.
  • Loading branch information
GeorgeR227 committed Oct 30, 2024
1 parent c389f1d commit c8c6d92
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 29 deletions.
82 changes: 58 additions & 24 deletions src/acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ using ACSets.InterTypes

@intertypes "decapodeacset.it" module decapodeacset end

import Base.show

using .decapodeacset
# TODO: Move this export to main file
export Operator, same_type_rules_op, arthimetic_operators, infer_resolve!, type_check
export DecaTypeExeception

# Transferring pointers
# --------------------
Expand Down Expand Up @@ -566,9 +569,57 @@ function apply_overloading_rule!(d::SummationDecapode, op_id, rule, edge_val)
return false
end

function apply_type_checking_rule(d::SummationDecapode, op_id, rule, edge_val)
name_present, type_diff = check_operator(d, op_id, rule, edge_val; check_name = true, ignore_infers = true, ignore_usertypes = true)
return name_present, type_diff == 0
struct DecaTypeError{T}
rule::Operator{T}
idx::Int
table::Symbol
end

Base.show(io::IO, type_error::DecaTypeError{T}) where T = println("Operator at index $(type_error.idx) in table $(type_error.table) is not correctly typed. Perhaps the operator was meant to be $(type_error.rule)?")

struct DecaTypeExeception{T} <: Exception
type_errors::Vector{DecaTypeError{T}}
end

function Base.show(io::IO, type_except::DecaTypeExeception{T}) where T
map(x -> Base.show(io, x), type_except.type_errors)
end

function run_typechecking(d::SummationDecapode, type_rules::AbstractVector{Operator{Symbol}})

type_errors = DecaTypeError{Symbol}[]

for table in [:Op1, :Op2]
for op_idx in parts(d, table)
type_error = run_typechecking_for_op(d, op_idx, type_rules, Val(table))
if type_error !== nothing
push!(type_errors, type_error)
end
end
end

return type_errors
end

function run_typechecking_for_op(d::SummationDecapode, op_id, type_rules, edge_val::Val{table}) where table

type_error = nothing
min_type_diff = Inf

for rule in type_rules
name_present, type_diff = check_operator(d, op_id, rule, edge_val; check_name = true, check_aliases = true, ignore_infers = true, ignore_usertypes = true)

if name_present
if type_diff == 0
return nothing
elseif type_diff < min_type_diff
min_type_diff = type_diff
type_error = DecaTypeError{Symbol}(rule, op_id, table)
end
end

end
return type_error
end

# TODO: Although the big-O complexity is the same, it might be more efficent on
Expand Down Expand Up @@ -640,29 +691,12 @@ function resolve_overloads!(d::SummationDecapode, resolve_rules::AbstractVector{
end

function type_check(d::SummationDecapode, type_rules::AbstractVector{Operator{Symbol}})
for table in [:Op1, :Op2]
for op_idx in parts(d, table)

check_passed = true
for rule in type_rules
rule_applies, rule_checked = apply_type_checking_rule(d, op_idx, rule, Val(table))
type_errors = run_typechecking(d, type_rules)

rule_applies || continue
check_passed = false
isempty(type_errors) && return true

rule_checked || continue
check_passed = true
break
end

check_passed || return false

end
end

# TODO: Add summation type checking

true
throw(DecaTypeExeception{Symbol}(type_errors))
return false
end

function infer_resolve!(d::SummationDecapode, operators::AbstractVector{Operator{Symbol}})
Expand Down
8 changes: 3 additions & 5 deletions src/deca/deca_acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ op1_operators = [
# Rules for Δ
Operator(:Form0, :Form0, :Δ₀, [, :∇², :lapl]),
Operator(:Form1, :Form1, :Δ₁, [, :∇², :lapl]),
Operator(:Form2, :Form2, :Δ₂, [, :∇², :lapl]), # TODO: Test with this
Operator(:Form2, :Form2, :Δ₂, [, :∇², :lapl]),

# Rules for Δᵈ
Operator(:DualForm0, :DualForm0, :Δᵈ₀, [, :∇², :lapl]),
Expand Down Expand Up @@ -100,7 +100,7 @@ op2_operators = [
arthimetic_operators(:.^, true)...,

# TODO: Only labelled as broadcasted since Decapodes converts all these
# to their broadcasted forms. They really should have differnt rules.
# to their broadcasted forms. They really should have different rules.
arthimetic_operators(:-, true)...,
arthimetic_operators(:/, true)...,
arthimetic_operators(:*, true)...,
Expand All @@ -114,6 +114,7 @@ op2_operators = [
]

# TODO: When SummationDecapodes are annotated with the degree of their space,
# use dispatch to choose the correct set of rules.
function default_operators(dim)
@assert 1 <= dim <= 2
metric_free = vcat(op1_operators, op2_operators)
Expand Down Expand Up @@ -190,6 +191,3 @@ function vec_to_dec!(d::SummationDecapode)

d
end

# TODO: When SummationDecapodes are annotated with the degree of their space,
# use dispatch to choose the correct set of rules.
73 changes: 73 additions & 0 deletions test/language.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,14 @@ end
only_type_dec_res = Set([(:A, :Form0), (:B, :Form0), (:C, :Form1)])
test_nametype_equality(only_type_dec, only_type_dec_res)
@test type_check(only_type_dec)

poorly_type_deca = @decapode begin
(A,B)::Form0

B == d(A)
end
resolve_overloads!(poorly_type_deca)
@test_throws DecaTypeExeception type_check(poorly_type_deca)
end

@testset "Type Inference and Overloading Resolution Integration" begin
Expand Down Expand Up @@ -1088,6 +1096,71 @@ end
end
end

# Momentum-formulation of Navier Stokes on sphere
DiffusionExprBody = quote
(T, Ṫ)::Form0{X}
ϕ::DualForm1{X}
k::Parameter{X}
# Fick's first law
ϕ == (k*d(T))
# Diffusion equation
== (d(ϕ))
end
Diffusion = SummationDecapode(parse_decapode(DiffusionExprBody))
AdvectionExprBody = quote
(M,V)::Form1{X} # M = ρV
(ρ, p, T, Ṫ)::Form0{X}
V == M/avg(ρ)
ρ == p / R₀(T)
== neg((L(V, (T))))
end
Advection = SummationDecapode(parse_decapode(AdvectionExprBody))
SuperpositionExprBody = quote
(T, Ṫ, Ṫ₁, Ṫₐ)::Form0{X}
== Ṫ₁ + Ṫₐ
∂ₜ(T) ==
end
Superposition = SummationDecapode(parse_decapode(SuperpositionExprBody))
compose_continuity = @relation () begin
diffusion(T, Ṫ₁)
advection(M, ρ, P, T, Ṫₐ)
superposition(T, Ṫ, Ṫ₁, Ṫₐ)
end
continuity_cospan = oapply(compose_continuity,
[Open(Diffusion, [:T, :Ṫ]),
Open(Advection, [:M, , :p, :T, :Ṫ]),
Open(Superposition, [:T, :Ṫ, :Ṫ₁, :Ṫₐ])])

continuity = apex(continuity_cospan)
NavierStokesExprBody = quote
(M, Ṁ, G, V)::Form1{X}
(T, ρ, p, ṗ)::Form0{X}
(two,three,kᵥ)::Parameter{X}
V == M/avg(ρ)
== neg(L(V, (V)))*avg(ρ) +
kᵥ*(Δ(V) + d(δ(V))/three) +
d(i(V, (V))/two)*avg(ρ) +
neg(d(p)) +
G*avg(ρ)
∂ₜ(M) ==
== neg((L(V, (p)))) # *Lie(3Form) = Div(*3Form x v) --> conservation of pressure
∂ₜ(p) ==
end
NavierStokes = SummationDecapode(parse_decapode(NavierStokesExprBody))
compose_heatXfer = @relation () begin
continuity(M, ρ, P, T)
navierstokes(M, ρ, P, T)
end
heatXfer_cospan = oapply(compose_heatXfer,
[Open(continuity, [:M, , :P, :T]),
Open(NavierStokes, [:M, , :p, :T])])
HeatXfer = apex(heatXfer_cospan)
infer_types!(HeatXfer)
resolve_overloads!(HeatXfer)

@test_throws DecaTypeExeception type_check(HeatXfer)
@test HeatXfer[12, :op2] == :*
@test HeatXfer[40, :type] == :DualForm1 && HeatXfer[39, :type] == :Form1
end

@testset "Compilation Transformation" begin
Expand Down

0 comments on commit c8c6d92

Please sign in to comment.