Skip to content

Commit

Permalink
Merge pull request #85 from neonWhiteout/Composition_DSL_good
Browse files Browse the repository at this point in the history
Created Composition DSL.
  • Loading branch information
Xiaoyan-Li authored Sep 27, 2023
2 parents 32661c5 + ab693d4 commit 78cc1f3
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/Syntax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1045,5 +1045,6 @@ function match_foot_format(footblock::Expr)
end


include("syntax/Composition.jl")

end
170 changes: 170 additions & 0 deletions src/syntax/Composition.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
module Composition
export sfcompose, @compose

using ...StockFlow
using ..Syntax
using Catlab.CategoricalAlgebra
using Catlab.WiringDiagrams

import ..Syntax: create_foot
import Catlab.Programs.RelationalPrograms: UntypedUnnamedRelationDiagram


RETURN_UWD = false

"""
Construct a uwd to compose your open stockflows
"""
function create_uwd(;
Box::Vector{Symbol} = Vector{Symbol}(), # stockflows
Port::Vector{Tuple{Int, Int}} = Vector{Tuple{Int, Int}}(), # stockflow => foot number, for each foot on stockflow
OuterPort::Vector{Int} = Vector{Int}(), # unique feet number (1:n)
Junction::Vector{Symbol} = Vector{Symbol}() # A symbol for each (unique) foot
)

uwd = UntypedUnnamedRelationDiagram{Symbol, Symbol}(0)
add_parts!(uwd, :Box, length(Box), name=Box)
add_parts!(uwd, :Junction, length(Junction), variable=Junction)
add_parts!(uwd, :Port, length(Port), box=map(first, Port), junction=map(last, Port))
add_parts!(uwd, :OuterPort, length(OuterPort), outer_junction=OuterPort)
return uwd
end

"""
Parse expression of form A ^ B => C, extract sf A and foot B => C
"""
function interpret_center_of_composition_statement(center::Expr)::Tuple{Symbol, Expr} # sf, foot defintion
@assert length(center.args) == 3 && center.args[1] == :(=>) && typeof(center.args[2]) == Expr "Invalid argument: expected A ^ B => C, A ^ () => C or A ^ B => (), got $center"
# third argument can be symbol or (), the latter of which is an Expr
center_caret_statement = center.args[2]
@assert length(center_caret_statement.args) == 3 && center_caret_statement.args[1] == :^ && typeof(center_caret_statement.args[2]) == Symbol "Invalid center argument: expected A ^ B or A ^ (), got $center"
# third argument here, too, can be symbol or ()
return (center_caret_statement.args[2], Expr(:call, :(=>), center_caret_statement.args[3], center.args[3]))
end

"""
Go line by line and associate stockflows and feet
"""
function interpret_composition_notation(mapping_pair::Expr)::Tuple{Vector{Symbol}, StockAndFlow0}

if mapping_pair.head == :call # (A ^ B => C) case (incl where B or C are ())
sf, foot_def = interpret_center_of_composition_statement(mapping_pair)
return [sf], create_foot(foot_def)
end

expr_args = mapping_pair.args
stockflows = collect(Base.Iterators.takewhile(x -> typeof(x) == Symbol, expr_args))
center_index = length(stockflows) + 1
@assert center_index <= length(expr_args) "A tuple is an invalid expression for composition syntax. Expected argument of form sf1, sf2, ... ^ stock1 => sum1, stock2 => sum2, ..."
center = expr_args[center_index]

foot_temp = Vector{Expr}()

sf, foot_def = interpret_center_of_composition_statement(center)
push!(foot_temp, foot_def)
push!(stockflows, sf)
append!(foot_temp, expr_args[center_index+1:end])

return (stockflows, create_foot(Expr(:tuple, foot_temp...)))
end


"""
sirv = sfcompose(sir, svi, quote
(sr, sv)
sr, sv ^ S => N, I => N
end)
Cannot use () => () as a foot,
the length of the first tuple must be the same as the number of stock flows given as argument,
and every foot can only be used once.
"""
function sfcompose(sfs::Vector{K}, block::Expr) where {K <: AbstractStockAndFlowF}#(sf1, sf2, ..., block)



Base.remove_linenums!(block)
sf_names = block.args[1].args

if length(sfs) == 0 # Composing 0 stock flows should give you an empty stock flow
return StockAndFlowF()
end

@assert length(sf_names) == length(sfs) "The number of symbols on the first line is not the same as the number of stock flow arguments provided. Stockflow #: $(length(sfs)) Symbol #: $(length(sf_names))"



@assert allunique(sf_names) "Not all choices of names for stock flows are unique!"


empty_foot = (@foot () => ())


# symbol representation of sf => (sf itself, sf's feet)
# Every sf has empty foot as first foot to get around being unable to create OpenStockAndFlowF without feet
sf_map::Dict{Symbol, Tuple{AbstractStockAndFlowF, Vector{StockAndFlow0}}} = Dict(sf_names[i] => (sfs[i], [empty_foot]) for i eachindex(sf_names)) # map the symbols to their corresponding stockflows

# all feet
feet_index_dict::Dict{StockAndFlow0, Int} = Dict(empty_foot => 1)
for statement in block.args[2:end]
stockflows, foot = interpret_composition_notation(statement)
# adding new foot to list
@assert (foot keys(feet_index_dict)) "Foot has already been used, or you are using an empty foot!"
push!(feet_index_dict, foot => length(feet_index_dict) + 1)
for stockflow in stockflows
# adding this foot to each stock flow to its left
push!(sf_map[stockflow][2], foot)
end
end

Box::Vector{Symbol} = sf_names


Port = Vector{Tuple{Int, Int}}()

for (k, v) sf_map # TODO: Just find a better way to do this.
for foot v[2]
push!(Port, (findfirst(x -> x == k, sf_names), feet_index_dict[foot]))
end
end

Junction::Vector{Symbol} = [gensym() for _ 1:length(feet_index_dict)]
OuterPort::Vector{Int} = collect(1:length(feet_index_dict))

uwd = create_uwd(Box=Box, Port=Port, Junction=Junction, OuterPort=OuterPort)

# I'd prefer this to be a vector, but oapply didn't like that
# I'd also prefer that I don't include the empty foot, but Open doesn't want to accept stockflows with no feet.
# open_stockflows::AbstractDict = Dict(sf_key => Open(sf_val, foot_dict[sf_val]...,) for (sf_key, sf_val) ∈ sf_map)

open_stockflows::AbstractDict = Dict(sf_key => Open(sf_val[1], sf_val[2]...) for (sf_key, sf_val) sf_map)

if RETURN_UWD # UWD might be a bit screwed up from the empty foot being first.
return apex(oapply(uwd, open_stockflows)), uwd
else
return apex(oapply(uwd, open_stockflows))
end

end


"""
Compose models.
"""
macro compose(args...)
if length(args) == 0
return :(MethodError("No arguments provided! Please provide some number of stockflows, then a quote block."))
end
escaped_block = Expr(:quote, args[end])
sfs = esc.(args[1:end-1])
quote
if length($sfs) == 0
sfcompose(Vector{StockAndFlowF}(), $escaped_block)
else
sfcompose([$(sfs...)], $escaped_block)
end
end
end


end
4 changes: 4 additions & 0 deletions test/Syntax.jl
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ using StockFlow
using StockFlow.Syntax
using StockFlow.Syntax: is_binop_or_unary, sum_variables, infix_expression_to_binops, fnone_value_or_vector, extract_function_name_and_args_expr, is_recursive_dyvar, create_foot

@testset "Composition DSL" begin
include("syntax/Composition.jl")
end

@testset "is_binop_or_unary recognises binops" begin
@test is_binop_or_unary(:(a + b))
@test is_binop_or_unary(:(f(a, b)))
Expand Down
109 changes: 109 additions & 0 deletions test/syntax/Composition.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
using Test
using StockFlow
using StockFlow.Syntax
using StockFlow.Syntax.Composition
import StockFlow.Syntax.Composition: interpret_composition_notation

@testset "Composition creates expected stock flows" begin
empty_sf = StockAndFlowF()


@test (@compose (begin # composing no stock flows returns an empty stock flow.
()
end)) == empty_sf

@test (@compose empty_sf begin
(sf,)
end) == empty_sf

@test (@compose (@stock_and_flow begin; :stocks; A; end;) (@stock_and_flow begin; :stocks; B; end;) (begin
(sf1, sf2)
end)) == (@stock_and_flow begin; :stocks; A; B; end;) # Combining without any composing

@test (@compose (@stock_and_flow begin; :stocks; A; end;) (@stock_and_flow begin; :stocks; A; end;) (begin
(sf1, sf2)
end)) == (@stock_and_flow begin; :stocks; A; A; end;)

@test (@compose (@stock_and_flow begin; :stocks; A; end;) (@stock_and_flow begin; :stocks; A; end;) (begin
(sf1, sf2)
sf1, sf2 ^ A => ()
end)) == (@stock_and_flow begin; :stocks; A; end;)

@test ((@compose (@stock_and_flow begin
:stocks
A
B

:dynamic_variables
v1 = A + B

:sums
N = [A,B]
end) (@stock_and_flow begin
:stocks
B
C

:dynamic_variables
v2 = B + C

:sums
N = [B,C]
end) (begin
(sfA, sfC)
sfA, sfC ^ B => N
end))
==
(@stock_and_flow begin
:stocks
A
B
C

:dynamic_variables
v1 = A + B
v2 = B + C

:sums
N = [A, B, C]
end))



end

@testset "interpret_composition_notation interprets arguments correctly" begin
# @test interpret_composition_notation(:(() ^ A => N)) == (Vector{Symbol}(), (@foot A => N))
@test interpret_composition_notation(:(sf ^ A => N)) == ([:sf], (@foot A => N))
@test interpret_composition_notation(:(sf1, sf2 ^ A => N)) == ([:sf1,:sf2], (@foot A => N))
@test interpret_composition_notation(:(sf1, sf2 ^ A => N, A => NI)) == ([:sf1,:sf2], (@foot A => N, A => NI))
@test interpret_composition_notation(:(sf1, sf2, sf3, sf4 ^ () => NI)) == ([:sf1, :sf2, :sf3, :sf4], (@foot () => NI))
@test interpret_composition_notation(:(sf1, sf2 ^ L => ())) == ([:sf1,:sf2], (@foot L => ()))

@test interpret_composition_notation(:(sf1, sf2 ^ () => ())) == ([:sf1,:sf2], (@foot () => ()))

end

@testset "invalid composition expressions fail" begin
@test_throws AssertionError interpret_composition_notation(:(B => C))
@test_throws AssertionError interpret_composition_notation(:(A, B, C))
@test_throws AssertionError interpret_composition_notation(:(A ^ B ^ C))
@test_throws AssertionError interpret_composition_notation(:(A => B => C))
@test_throws ErrorException interpret_composition_notation(:(A ^ B => C => D)) # caught by create_foot
end

@testset "invalid sfcompose calls fail" begin
@test_throws AssertionError sfcompose([(@stock_and_flow begin; :stocks; A; end;), (@stock_and_flow begin; :stocks; A; end;)], quote
(sf1, sf2)
sf1, sf2 ^ () => ()
end) # not allowed to map to empty
@test_throws AssertionError sfcompose([(@stock_and_flow begin; :stocks; A; end;), (@stock_and_flow begin; :stocks; A; end;)], quote
(sf1, sf2)
sf1 ^ A => ()
sf2 ^ A => ()
end) # not allowed to map to the same foot twice
@test_throws AssertionError sfcompose([(@stock_and_flow begin; :stocks; A; end;), (@stock_and_flow begin; :stocks; A; end;)], quote
(sf1,)
sf1 ^ A => ()
end) # incorrect number of symbols on the first line in the quote
end

0 comments on commit 78cc1f3

Please sign in to comment.