diff --git a/src/StockFlow.jl b/src/StockFlow.jl index 3665ad51..cae235ed 100644 --- a/src/StockFlow.jl +++ b/src/StockFlow.jl @@ -9,7 +9,9 @@ funcDynam, flowVariableIndex, funcFlow, funcFlows, funcSV, funcSVs, TransitionMa vectorfield, funcFlowsRaw, funcFlowRaw, inflowsAll, outflowsAll,instock,outstock, stockssv, stocksv, svsv, svsstock, vsstock, vssv, svsstockAllF, vsstockAllF, vssvAllF, StockAndFlowUntyped, StockAndFlowFUntyped, StockAndFlowStructureUntyped, StockAndFlowStructureFUntyped, StockAndFlowUntyped0, Open, snames, fnames, svnames, vnames, object_shift_right, foot, leg, lsnames, OpenStockAndFlow, OpenStockAndFlowOb, fv, fvs, nlvv, nlpv, vtgt, vsrc, vpsrc, vptgt, pname, pnames, make_v_expr, -vop, lvvposition, lvtgtposition, lsvvposition, lpvvposition, recreate_stratified, set_snames!, set_fnames!, set_svnames!, set_vnames!, set_pnames!, set_sname!, set_fname!, set_svname!, set_vname!, set_pname! +vop, lvvposition, lvtgtposition, lsvvposition, lpvvposition, recreate_stratified, set_snames!, set_fnames!, set_svnames!, set_vnames!, set_pnames!, set_sname!, set_fname!, set_svname!, set_vname!, set_pname!, +get_lss, get_lssv, get_lsvsv, get_lsvv, get_lvs, get_lvv, get_is, get_ifn, get_os, get_ofn, get_lpvp, get_lpvv, get_lvsrc, get_lvtgt, get_links + using Catlab using Catlab.CategoricalAlgebra @@ -252,6 +254,23 @@ nlvv(p::AbstractStockAndFlowStructureF) = nparts(p,:LVV) #links from dynamic var nlpv(p::AbstractStockAndFlowStructureF) = nparts(p,:LPV) #links from dynamic variable to dynamic varibale np(p::AbstractStockAndFlowStructureF) = nparts(p,:P) #parameters + +get_lss(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lss].m)) +get_lssv(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lssv].m)) +get_lsvsv(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lsvsv].m)) +get_lsvv(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lsvv].m)) +get_lvs(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lvs].m)) +get_lvv(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lvv].m)) +get_is(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:is].m)) +get_ifn(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:ifn].m)) +get_os(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:os].m)) +get_ofn(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:ofn].m)) +get_lpvp(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lpvp].m)) +get_lpvv(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lpvv].m)) +get_lvsrc(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lvsrc].m)) +get_lvtgt(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lvtgt].m)) + + #EXAMPLE: #sir_StockAndFlow=StockAndFlow(((:S, 990)=>(:birth,(:inf,:deathS),(:v_inf,:v_deathS),:N), (:I, 10)=>(:inf,(:rec,:deathI),(:v_rec,:v_deathI,:v_fractionNonS),:N),(:R, 0)=>(:rec,:deathR,(:v_deathR,:v_fractionNonS),:N)), # (:birth=>:v_birth,:inf=>:v_inf,:rec=>:v_rec,:deathS=>:v_deathS,:deathI=>:v_deathI,:deathR=>:v_deathR), diff --git a/src/Syntax.jl b/src/Syntax.jl index ec5aa14f..60b3b1b0 100644 --- a/src/Syntax.jl +++ b/src/Syntax.jl @@ -103,11 +103,13 @@ end ``` """ module Syntax -export @stock_and_flow, @foot, @feet +export @stock_and_flow, @foot, @feet, infer_links using ..StockFlow using MLStyle +import Base: ==, Iterators.flatmap + """ stock_and_flow(block :: Expr) @@ -1030,6 +1032,214 @@ function match_foot_format(footblock::Expr) end end +############################################# + +""" + infer_particular_link!(sfsrc, sftgt, f1, f2, map1, map2, destination_vector, posf=nothing) + +infer_particular_link!(sfsrc, sftgt, get_lvs, get_lvv, stockmaps, dyvarmaps, lvmaps, get_lvvposition) # LV + +If we're mapping the same value to multiple positions, it doesn't matter which one goes where. +We have a few options, on how we want to distribute mappings. Way it's done here, always goes to the last position. + +""" +function infer_particular_link!(sfsrc, sftgt, f1, f2, map1, map2, destination_vector) + + hom1′_mappings = f1(sftgt) + hom2′_mappings = f2(sftgt) + tgt::Dict{Tuple{Int, Int}, Int} = Dict((hom1′, hom2′) => i for (i, (hom1′, hom2′)) in enumerate(zip(hom1′_mappings, hom2′_mappings))) # ISSUE: If there are two matches, second one overwrites the first. + # SOLUTION: Who cares. Just map to the last. + for (i, (hom1, hom2)) in enumerate(zip(f1(sfsrc), f2(sfsrc))) + mapped_index1 = map1[hom1] + mapped_index2 = map2[hom2] + + linkmap = tgt[(mapped_index1, mapped_index2)] + + + + destination_vector[i] = linkmap # updated + end + return destination_vector end + + +""" + infer_links(sfsrc :: StockAndFlowF, sftgt :: StockAndFlowF, NecMaps :: Dict{Symbol, Vector{Int64}}) + +Infer LS, I, O, LV, LSV, LVV, LPV mappings for an ACSetTransformation. +Returns dictionary of Symbols to lists of indices, corresponding to an ACSetTransformation argument. +If there exist no such mappings (eg, no LVV), that pairing will not be included in the returned dictionary. + +If A <- C -> B, and we have A -> A' and B -> B' and a unique C' such that A' <- C' -> B', we can assume C -> C'. + +:S => [2,4,1,3], :F => [1,2,4,3], ... + +NecMaps must contain keys S, F, SV, P, V, each pointing to a (possibly empty) array of indices +""" +function infer_links(sfsrc :: StockAndFlowF, sftgt :: StockAndFlowF, NecMaps :: Dict{Symbol, Vector{Int64}}) + + + stockmaps = NecMaps[:S] + flowmaps = NecMaps[:F] + summaps = NecMaps[:SV] + parammaps = NecMaps[:P] + dyvarmaps = NecMaps[:V] + + lsmaps = zeros(Int, nls(sfsrc)) + imaps = zeros(Int, ni(sfsrc)) + omaps = zeros(Int, no(sfsrc)) + lvmaps = zeros(Int, nlv(sfsrc)) + lsvmaps = zeros(Int, nlsv(sfsrc)) + lvvmaps = zeros(Int, nlvv(sfsrc)) + lpvmaps = zeros(Int, nlpv(sfsrc)) + # After the following calls, there should be no zeroes. + + + infer_particular_link!(sfsrc, sftgt, get_lss, get_lssv, stockmaps, summaps, lsmaps) # LS + infer_particular_link!(sfsrc, sftgt, get_ifn, get_is, flowmaps, stockmaps, imaps) # I + infer_particular_link!(sfsrc, sftgt, get_ofn, get_os, flowmaps, stockmaps, omaps) # O + infer_particular_link!(sfsrc, sftgt, get_lvs, get_lvv, stockmaps, dyvarmaps, lvmaps) # LV + infer_particular_link!(sfsrc, sftgt, get_lsvsv, get_lsvv, summaps, dyvarmaps, lsvmaps) # LSV + infer_particular_link!(sfsrc, sftgt, get_lvsrc, get_lvtgt, dyvarmaps, dyvarmaps, lvvmaps) # LVV + infer_particular_link!(sfsrc, sftgt, get_lpvp, get_lpvv, parammaps, dyvarmaps, lpvmaps) # LPV + + return Dict(:LS => lsmaps, :LSV => lsvmaps, :LV => lvmaps, :I => imaps, :O => omaps, :LPV => lpvmaps, :LVV => lvvmaps) + + +end + + +struct DSLArgument + key::Symbol + value::Symbol + flags::Set{Symbol} # At present, the only flag that exists is ~ + DSLArgument(kv::Pair{Union{Expr, Symbol}, Symbol}) = begin # this constructor seemed to fail... need to figure out why. Maybe it can't call other constructors. + key, flags = unwrap_expression(first(kv)) + new(key, second(kv), flags) + end + DSLArgument(k::Union{Expr, Symbol}, v::Symbol) = begin + key, flags = unwrap_expression(k) + new(key, v, flags) + end + DSLArgument(k::Symbol, v::Symbol, f::Set{Symbol}) = new(k, v, f) +end + +==(a::DSLArgument, b::DSLArgument) = a.key == b.key && a.value == b.value && a.flags == b.flags + + +function unwrap_expression(x::Union{Symbol, Expr}, flags::Set{Symbol}=Set{Symbol}())::Tuple{Symbol, Set{Symbol}} # No mutable default arguments. + if typeof(x) == Symbol + return (x, flags) + else + return unwrap_expression(x.args[2], push!(flags, x.args[1])) + end +end + + +""" +S₁ => I₁ +S₂ => I₂ +S₁ => S₂ + +⊢ + +I₁ => I₂ + +Determine what index an element e maps to based upon what f we have in the mapping such that e -> f +""" +function connect_by_value(; src::Dict{T,U}, mapping::Dict{T,T}, tgt::Dict{T,U})::Dict{U, U} where {T, U} + @assert allunique(values(src)) + + @assert all(x -> x ∈ keys(mapping), keys(src)) + @assert all(x -> x ∈ keys(tgt), values(mapping)) + + return Dict(src[key] => tgt[value] for (key, value) in mapping) + +end + + +""" +Filter a vector for all elements with substr as a substring. +""" +function substring_matches(v::Vector, substr::String)::Vector + return filter(x -> occursin(substr, string(x)), v) +end + + +""" +Takes a symbol 'key', applys flags, finds matches in s, and returns a vector of matching keys. +Currently, there are two options: no flags, in which case [key] is returned, or ~ is the only flag, in which case all Symbols with matching substrings are returned. +""" +function apply_flags(key::Symbol, flags::Set{Symbol}, s::Vector{Symbol})::Vector{Symbol} # Could make this a generator? + + if isempty(flags) + @assert (key ∈ s) "$s does not contain key $key ! Did you forget to prefix ~?" + return [key] # potentially inefficient + elseif :~ ∈ flags + + matches = collect(substring_matches(s, string(key))) + + new_flags = copy(flags) # copy isn't necessary, probably + pop!(new_flags, :~) + + return collect(flatmap(x -> apply_flags(x, new_flags, s), matches)) # this is just in case we add additional flags. As is, the recursion is unnecessary. + else + error("Unknown flag found! $(flags)") + end +end + +""" + substitute_symbols(s::Dict{Symbol, Int}, t::Dict{Symbol, Int}, m::Vector{DSLArgument} ; use_flags::Bool=true)::Dict{Int, Int} + +Convert Dict(SymA => IntA), Dict(SymB => IntB), Dict(SymA => SymB) into Dict{IntA => IntB} +Using original sf defintions, and the user defined mappings, transform user defined symbol mappings to index mappings. +""" +function substitute_symbols(s::Dict{Symbol, Int}, t::Dict{Symbol, Int}, m::Vector{DSLArgument} ; use_flags::Bool=true)::Dict{Int, Int} + if !use_flags + mapping = Dict(arg.key => arg.value for arg in m) + return connect_by_value(src=s, mapping=mapping, tgt=t) + else + master_dict::Dict{Int, Int} = Dict() + for statement in m + key_matches = apply_flags(statement.key, statement.flags, collect(keys(s))) # Vector of Symbol + if isempty(key_matches) + println("WARNING! No matches on $(statement.key) with flags $(statement.flags)") + else + mergewith!((x...) -> first(x), master_dict, Dict(s[match] => t[statement.value] for match ∈ key_matches)) + end + end + return master_dict + end +end + + +""" +Convert a vector of unique elements to a dictionary with each element pointing to their original index. +""" +function invert_vector(v::Vector{K})::Dict{K, Int} where {K} # Elements of v must be hashable + new_dict = Dict(val => i for (i, val) ∈ enumerate(v)) + @assert length(new_dict) == length(v) "Nonunique key in vector v: $v" + return new_dict +end + + +""" +Takes any arguments and returns nothing. +Used so we can maintain equality when making ACSetTransformations. +""" +NothingFunction(x...)::Nothing = nothing; + + + + + +include("syntax/Stratification.jl") + +end + + + + + diff --git a/src/syntax/Stratification.jl b/src/syntax/Stratification.jl new file mode 100755 index 00000000..4235b2eb --- /dev/null +++ b/src/syntax/Stratification.jl @@ -0,0 +1,357 @@ +module Stratification +export sfstratify, @stratify + +using ...StockFlow +using ..Syntax +using MLStyle +import Base: get +using Catlab.CategoricalAlgebra +import ..Syntax: infer_links, substitute_symbols, DSLArgument, NothingFunction, invert_vector + + +""" + interpret_stratification_notation(mapping_pair::Expr)::Tuple{Vector{DSLArgument}, Vector{DSLArgument}} +Take an expression of the form a1, ..., => t <= s1, ..., where every element is a symbol, and return a 2-tuple of form ((a1 => t, a2 => t, ...), (s1 => t, ...)) +""" +function interpret_stratification_notation(mapping_pair::Expr)::Tuple{Vector{DSLArgument}, Vector{DSLArgument}} + @match mapping_pair begin + + + :($s => $t <= $a) => return ([DSLArgument(s,t)], [DSLArgument(a,t)]) + :($s => $t <= $a, $(atail...)) => ([DSLArgument(s,t)], [DSLArgument(a,t) ; [DSLArgument(as,t) for as in atail] ])#return (Dict(unwrap_key_expression(s, t)), push!(Dict(unwrap_key_expression(as, t) for as in atail), unwrap_key_expression(a, t))) + :($(shead...), $s => $t <= $a) => ([[DSLArgument(ss, t) for ss in shead] ; DSLArgument(s, t)], [DSLArgument(a, t)])#return (push!(Dict(unwrap_key_expression(ss, t) for ss in shead), unwrap_key_expression(s, t)), Dict(unwrap_key_expression(a, t))) + + if mapping_pair.head == :tuple end => begin + middle_index = findfirst(x -> typeof(x) == Expr && length(x.args) == 3, mapping_pair.args) # still isn't specific enough + if isnothing(middle_index) + error("Malformed line $mapping_pair, could not find center.") + end + @match mapping_pair.args[middle_index] begin + :($stail => $t <= $ahead) => begin + sdict = [[DSLArgument(ss, t) for ss in mapping_pair.args[1:middle_index-1]] ; DSLArgument(stail, t)] + adict = [DSLArgument(ahead, t) ; [DSLArgument(as, t) for as in mapping_pair.args[middle_index+1:end]]] + return (sdict, adict) + end + _ => "Unknown format found for match; middle three values formatted incorrectly." + end + end + _ => error("Unknown line format found in stratification notation.") + end +end + + + +""" +Gets mapping information from each line and updates dictionaries. If a symbol already has a mapping and another is found, keep the first, or throw an error if strict_matches = true. +""" +function read_stratification_line_and_update_dictionaries!(line::Expr, strata_names::Dict{Symbol, Int}, type_names::Dict{Symbol, Int}, aggregate_names::Dict{Symbol, Int}, strata_mappings::Dict{Int, Int}, aggregate_mappings::Dict{Int, Int} ; strict_matches = false, use_flags = true) + current_strata_symbol_dict, current_aggregate_symbol_dict = interpret_stratification_notation(line) + + current_strata_dict = substitute_symbols(strata_names, type_names, current_strata_symbol_dict ; use_flags=use_flags) + current_aggregate_dict = substitute_symbols(aggregate_names, type_names, current_aggregate_symbol_dict ; use_flags=use_flags) + + if strict_matches + @assert (all(x -> x ∉ keys(strata_mappings), keys(current_strata_dict))) "Attempt to overwrite a mapping in strata!" + # check that we're not overwriting a value which has already been assigned + merge!(strata_mappings, current_strata_dict) # accumulate dictionary keys + + + @assert (all(x -> x ∉ keys(aggregate_mappings), keys(current_aggregate_dict))) "Attempt to overwrite a mapping in aggregate!" + merge!(aggregate_mappings, current_aggregate_dict) + + else + mergewith!((x, y) -> x, strata_mappings, current_strata_dict) # alternatively, can use: only ∘ first + mergewith!((x, y) -> x, aggregate_mappings, current_aggregate_dict) + end + +end + +""" +Print all symbols such that the corresponding int is 0, representing an unmapped object. +""" +function print_unmapped(mappings::Vector{Pair{Vector{Int}, Vector{Symbol}}}, name="STOCKFLOW") + for (ints, dicts) in mappings + for (i, val) in enumerate(ints) + if val == 0 + println("UNMAPPED IN $(name):") + println(dicts[i]) + end + end + end +end + +""" +Iterates over each line in a stratification syntax block and updates the appropriate dictionaries. +""" +function iterate_over_stratification_lines!(block, strata_names, type_names, aggregate_names, strata_mappings, aggregate_mappings; strict_matches=false, use_flags=true) + current_phase = (_, _) -> () + for statement in block.args + @match statement begin + QuoteNode(:stocks) => begin + current_phase = s -> read_stratification_line_and_update_dictionaries!(s, strata_names[1], type_names[1], aggregate_names[1], strata_mappings[1], aggregate_mappings[1]; strict_matches=strict_matches, use_flags=use_flags) + end + QuoteNode(:sums) => begin + current_phase = sv -> read_stratification_line_and_update_dictionaries!(sv, strata_names[2], type_names[2], aggregate_names[2], strata_mappings[2], aggregate_mappings[2]; strict_matches=strict_matches, use_flags=use_flags) + end + QuoteNode(:dynamic_variables) => begin + current_phase = v -> read_stratification_line_and_update_dictionaries!(v, strata_names[3], type_names[3], aggregate_names[3], strata_mappings[3], aggregate_mappings[3]; strict_matches=strict_matches, use_flags=use_flags) + end + QuoteNode(:flows) => begin + current_phase = f -> read_stratification_line_and_update_dictionaries!(f, strata_names[4], type_names[4], aggregate_names[4], strata_mappings[4], aggregate_mappings[4]; strict_matches=strict_matches, use_flags=use_flags) + end + QuoteNode(:parameters) => begin + current_phase = p -> read_stratification_line_and_update_dictionaries!(p, strata_names[5], type_names[5], aggregate_names[5], strata_mappings[5], aggregate_mappings[5]; strict_matches=strict_matches, use_flags=use_flags) + end + QuoteNode(kw) => + error("Unknown block type for stratify syntax: " * String(kw)) + _ => current_phase(statement) + end + end +end + +""" +Apply default mappings, infer mapping if there's only a single option, and convert from Dict{Int, Int} to Vector{Int} +""" +function complete_mappings(strata_all_index_mappings::Vector{Dict{Int, Int}}, aggregate_all_index_mappings::Vector{Dict{Int, Int}}, sfstrata::AbstractStockAndFlowF, sftype::AbstractStockAndFlowF, sfaggregate::AbstractStockAndFlowF; strict_mappings = false) + # get the default value, if it has been assigned. Use 0 if it hasn't. + default_index_strata_stock = get(strata_all_index_mappings[1], -1, 0) + default_index_strata_sum = get(strata_all_index_mappings[2], -1, 0) + default_index_strata_dyvar = get(strata_all_index_mappings[3], -1, 0) + default_index_strata_flow = get(strata_all_index_mappings[4], -1, 0) + default_index_strata_param = get(strata_all_index_mappings[5], -1, 0) + + default_index_aggregate_stock = get(aggregate_all_index_mappings[1], -1, 0) + default_index_aggregate_sum = get(aggregate_all_index_mappings[2], -1, 0) + default_index_aggregate_dyvar = get(aggregate_all_index_mappings[3], -1, 0) + default_index_aggregate_flow = get(aggregate_all_index_mappings[4], -1, 0) + default_index_aggregate_param = get(aggregate_all_index_mappings[5], -1, 0) + + + # STEP 3 + if !strict_mappings + one_type_stock = length(snames(sftype)) == 1 ? 1 : 0 # if there is only one stock, it needs to have index 1 + one_type_flow = length(fnames(sftype)) == 1 ? 1 : 0 + one_type_dyvar = length(vnames(sftype)) == 1 ? 1 : 0 + one_type_param = length(pnames(sftype)) == 1 ? 1 : 0 + one_type_sum = length(svnames(sftype)) == 1 ? 1 : 0 + else + one_type_stock = one_type_flow = one_type_dyvar = one_type_param = one_type_sum = 0 + end + + # Convert back to vectors. If you find a zero, check if there's a default and use that. If there isn't a default, check if there's only one option and use that. + # Otherwise, there's an unassigned value which can't be inferred. + # Taking max because it's less verbose than ternary and accomplishes the same thing: + # - If both default_index and one_type are mapped, they must be mapped to the same thing, because one_type being mapped implies there's only one option. + # - If only one_type is mapped, then it will be positive, and default_infex will be 0 + # - If only default_index is mapped, it will be positive and one_type will be 0 + + strata_stock_mappings::Vector{Int} = [get(strata_all_index_mappings[1], i, max(default_index_strata_stock, one_type_stock)) for i in 1:ns(sfstrata)] + strata_sum_mappings::Vector{Int} = [get(strata_all_index_mappings[2], i, max(default_index_strata_sum, one_type_sum)) for i in 1:nsv(sfstrata)] + strata_dyvar_mappings::Vector{Int} = [get(strata_all_index_mappings[3], i, max(default_index_strata_dyvar, one_type_dyvar)) for i in 1:nvb(sfstrata)] + strata_flow_mappings::Vector{Int} = [get(strata_all_index_mappings[4], i, max(default_index_strata_flow, one_type_flow)) for i in 1:nf(sfstrata)] + strata_param_mappings::Vector{Int} = [get(strata_all_index_mappings[5], i, max(default_index_strata_param, one_type_param)) for i in 1:np(sfstrata)] + + aggregate_stock_mappings::Vector{Int} = [get(aggregate_all_index_mappings[1], i, max(default_index_aggregate_stock, one_type_stock)) for i in 1:ns(sfaggregate)] + aggregate_sum_mappings::Vector{Int} = [get(aggregate_all_index_mappings[2], i, max(default_index_aggregate_sum, one_type_sum)) for i in 1:nsv(sfaggregate)] + aggregate_dyvar_mappings::Vector{Int} = [get(aggregate_all_index_mappings[3], i, max(default_index_aggregate_dyvar, one_type_dyvar)) for i in 1:nvb(sfaggregate)] + aggregate_flow_mappings::Vector{Int} = [get(aggregate_all_index_mappings[4], i, max(default_index_aggregate_flow, one_type_flow)) for i in 1:nf(sfaggregate)] + aggregate_param_mappings::Vector{Int} = [get(aggregate_all_index_mappings[5], i, max(default_index_aggregate_param, one_type_param)) for i in 1:np(sfaggregate)] + + + return ((strata_stock_mappings, strata_sum_mappings, strata_dyvar_mappings, strata_flow_mappings, strata_param_mappings), (aggregate_stock_mappings, aggregate_sum_mappings, aggregate_dyvar_mappings, aggregate_flow_mappings, aggregate_param_mappings)) +end + + +""" + sfstratify(strata, type, aggregate, block ; kwargs) + + 1. Grab all names from strata, type and aggregate, and create dictionaries which map them to their indices + 2. Iterate over each line in the block + 2a. Split each line into a dictionary which maps all strata to that type and all aggregate to that type + 2b. Convert from two Symbol => Symbol dictionaries to two Int => Int dictionaries, using the dictionaries from step 1 + 2bα. If applicable, for symbols with ~ as a prefix, find all symbols with matching substrings in the symbol dictionaries, and map all those + 2c. Accumulate respective dictionaries (optionally, only allow first match vs throw an error (strict_matches = false vs true)) + 3. Create an array of 0s for stocks, flows, parameters, dyvars and sums for strata and aggregate. Insert into arrays all values from the two Int => Int dictionaries + 3a. If strict_mappings = false, if there only exists one option in type to map to, and it hasn't been explicitly specified, add it. If strict_mappings = true and it hasn't been specified, throw an error. + 4. Do a once-over of arrays and ensure there aren't any zeroes (unmapped values) remaining (helps with debugging when you screw up stratifying) + 5. Deal with attributes (create a copy of type sf with attributes mapped to nothing) + 6. Infer LS, LSV, etc. + 7. Construct strata -> type and aggregate -> type ACSetTransformations + 8. Return pullback (with flattened attributes) +""" +function sfstratify(strata::AbstractStockAndFlowStructureF, type::AbstractStockAndFlowStructureF, aggregate::AbstractStockAndFlowStructureF, block::Expr ; strict_mappings = false, strict_matches = false, temp_strat_default = :_, use_temp_strat_default = true, use_flags = true, return_homs = false) + + Base.remove_linenums!(block) + + # STEP 1 + + # invert_vector: Vector{K} -> Dict{K, Int} where int is original index and all K (symbols, in this case) are unique. + strata_snames::Dict{Symbol, Int} = invert_vector(snames(strata)) + strata_svnames::Dict{Symbol, Int} = invert_vector(svnames(strata)) + strata_vnames::Dict{Symbol, Int} = invert_vector(vnames(strata)) + strata_fnames::Dict{Symbol, Int} = invert_vector(fnames(strata)) + strata_pnames::Dict{Symbol, Int} = invert_vector(pnames(strata)) + + type_snames::Dict{Symbol, Int} = invert_vector(snames(type)) + type_svnames::Dict{Symbol, Int} = invert_vector(svnames(type)) + type_vnames::Dict{Symbol, Int} = invert_vector(vnames(type)) + type_fnames::Dict{Symbol, Int} = invert_vector(fnames(type)) + type_pnames::Dict{Symbol, Int} = invert_vector(pnames(type)) + + aggregate_snames::Dict{Symbol, Int} = invert_vector(snames(aggregate)) + aggregate_svnames::Dict{Symbol, Int} = invert_vector(svnames(aggregate)) + aggregate_vnames::Dict{Symbol, Int} = invert_vector(vnames(aggregate)) + aggregate_fnames::Dict{Symbol, Int} = invert_vector(fnames(aggregate)) + aggregate_pnames::Dict{Symbol, Int} = invert_vector(pnames(aggregate)) + + + strata_stock_mappings_dict::Dict{Int, Int} = Dict() + strata_flow_mappings_dict::Dict{Int, Int} = Dict() + strata_dyvar_mappings_dict::Dict{Int, Int} = Dict() + strata_param_mappings_dict::Dict{Int, Int} = Dict() + strata_sum_mappings_dict::Dict{Int, Int} = Dict() + + aggregate_stock_mappings_dict::Dict{Int, Int} = Dict() + aggregate_flow_mappings_dict::Dict{Int, Int} = Dict() + aggregate_dyvar_mappings_dict::Dict{Int, Int} = Dict() + aggregate_param_mappings_dict::Dict{Int, Int} = Dict() + aggregate_sum_mappings_dict::Dict{Int, Int} = Dict() + + + strata_all_name_mappings::Vector{Dict{Symbol, Int}} = [strata_snames, strata_svnames, strata_vnames, strata_fnames, strata_pnames] + type_all_name_mappings::Vector{Dict{Symbol, Int}} = [type_snames, type_svnames, type_vnames, type_fnames, type_pnames] + aggregate_all_name_mappings::Vector{Dict{Symbol, Int}} = [aggregate_snames, aggregate_svnames, aggregate_vnames, aggregate_fnames, aggregate_pnames] + + strata_all_index_mappings::Vector{Dict{Int, Int}} = [strata_stock_mappings_dict, strata_sum_mappings_dict, strata_dyvar_mappings_dict, strata_flow_mappings_dict, strata_param_mappings_dict] + aggregate_all_index_mappings::Vector{Dict{Int, Int}} = [aggregate_stock_mappings_dict, aggregate_sum_mappings_dict, aggregate_dyvar_mappings_dict, aggregate_flow_mappings_dict, aggregate_param_mappings_dict] + + + if use_temp_strat_default + + strata_all_names::Vector{Vector{Symbol}} = [snames(strata), svnames(strata), vnames(strata), fnames(strata), pnames(strata)] + aggregate_all_names::Vector{Vector{Symbol}} = [snames(aggregate), svnames(aggregate), vnames(aggregate), fnames(aggregate), pnames(aggregate)] + + @assert all(x -> temp_strat_default ∉ keys(x), strata_all_names) "Strata contains $temp_strat_default ! Please change temp_strat_default to a different symbol or rename offending object." + @assert all(x -> temp_strat_default ∉ keys(x), aggregate_all_names) "Aggregate contains $temp_strat_default ! Please change temp_strat_default to a different symbol or rename offending object." + + map(x -> (push!(x, (temp_strat_default => -1))), strata_all_name_mappings) + map(x -> (push!(x, (temp_strat_default => -1))), aggregate_all_name_mappings) + end + + # STEP 2 + iterate_over_stratification_lines!(block, strata_all_name_mappings, type_all_name_mappings, aggregate_all_name_mappings, strata_all_index_mappings, aggregate_all_index_mappings ; strict_matches=strict_matches, use_flags=use_flags) + + + strata_mappings, aggregate_mappings = complete_mappings(strata_all_index_mappings, aggregate_all_index_mappings, strata, type, aggregate ; strict_mappings=strict_mappings) + + strata_stock_mappings, strata_sum_mappings, strata_dyvar_mappings, strata_flow_mappings, strata_param_mappings = strata_mappings + aggregate_stock_mappings, aggregate_sum_mappings, aggregate_dyvar_mappings, aggregate_flow_mappings, aggregate_param_mappings = aggregate_mappings + + + all_mappings = [strata_stock_mappings..., strata_flow_mappings..., strata_dyvar_mappings..., strata_param_mappings..., strata_sum_mappings..., aggregate_stock_mappings..., aggregate_flow_mappings..., aggregate_dyvar_mappings..., aggregate_param_mappings..., aggregate_sum_mappings...] + + + # STEP 4 + + # This bit makes debugging when making a stratification easier. Tells you exactly which ones you forgot to map. + + #unmapped: + if !(all(x -> x != 0, all_mappings)) + strata_mappings_to_names::Vector{Pair{Vector{Int}, Vector{Symbol}}} = [strata_stock_mappings => snames(strata), strata_flow_mappings => fnames(strata), strata_dyvar_mappings => vnames(strata), strata_param_mappings => pnames(strata), strata_sum_mappings => svnames(strata)] + aggregate_mappings_to_names::Vector{Pair{Vector{Int}, Vector{Symbol}}} = [aggregate_stock_mappings => snames(aggregate), aggregate_flow_mappings => fnames(aggregate), aggregate_dyvar_mappings => vnames(aggregate), aggregate_param_mappings => pnames(aggregate), aggregate_sum_mappings => svnames(aggregate)] + print_unmapped(strata_mappings_to_names, "STRATA") + print_unmapped(aggregate_mappings_to_names, "AGGREGATE") + error("There is an unmapped value!") + end + + + # STEP 5 + # NothingFunction(x...) = nothing; + no_attribute_type = map(type, Name=NothingFunction, Op=NothingFunction, Position=NothingFunction) + + # STEP 6/7 + # This is where we pull out the magic to infer links. + # + # A <- C -> B + # || || + # v v + # A'<- C'-> B' + # + # implies + # + # A <- C -> B + # || || || + # v v v + # A'<- C'-> B' + # + + strata_necmaps = Dict(:S => strata_stock_mappings, :F => strata_flow_mappings, :V => strata_dyvar_mappings, :P => strata_param_mappings, :SV => strata_sum_mappings) + strata_inferred_links = infer_links(strata, type, strata_necmaps) + strata_to_type = ACSetTransformation(strata, no_attribute_type; strata_necmaps..., strata_inferred_links..., Op = NothingFunction, Position = NothingFunction, Name = NothingFunction) + + + aggregate_necmaps = Dict(:S => aggregate_stock_mappings, :F => aggregate_flow_mappings, :V => aggregate_dyvar_mappings, :P => aggregate_param_mappings, :SV => aggregate_sum_mappings) + aggregate_inferred_links = infer_links(aggregate, type, aggregate_necmaps) + aggregate_to_type = ACSetTransformation(aggregate, no_attribute_type; aggregate_necmaps..., aggregate_inferred_links..., Op = NothingFunction, Position = NothingFunction, Name =NothingFunction) + + + + # STEP 8 + pullback_model = pullback(strata_to_type, aggregate_to_type) |> apex |> rebuildStratifiedModelByFlattenSymbols; + + if return_homs + return pullback_model, strata_to_type, aggregate_to_type + else + return pullback_model + end + +end + +""" + stratify(strata, type, aggregate, block) +Take three stockflows and a block describing where the first and third map on to the second, and get a new stratified stockflow. +Left side are strata objects, middle are type, right are aggregate. Each strata and aggregate object is associated with one type object. +The resultant stockflow contains objects which are the product of strata and aggregate objects which map to the same type object. +Use _ to match all objects in that category, ~ as a prefix to match all objects with the following string as a substring. Objects always go with their first match. +If the type model has a single object in a category, the mapping to it is automatically assumed. In the below example, we wouldn't need to specify :stocks or :sums. + +```julia + +@stratify WeightModel l_type ageWeightModel begin + :stocks + _ => pop <= _ + + :flows + ~Death => f_death <= ~Death + ~id => f_aging <= ~aging + ~Becoming => f_fstOrder <= ~id + _ => f_birth <= f_NB + + + :dynamic_variables + v_NewBorn => v_birth <= v_NB + ~Death => v_death <= ~Death + ~id => v_aging <= v_agingCA, v_agingAS + v_BecomingOverWeight, v_BecomingObese => v_fstOrder <= v_idC, v_idA, v_idS + + :parameters + μ => μ <= μ + δw, δo => δ <= δC, δA, δS + rw, ro => rFstOrder <= r + rage => rage <= rageCA, rageAS + + :sums + N => N <= N + +end +``` +""" +macro stratify(strata, type, aggregate, block) + escaped_block = Expr(:quote, block) + quote + sfstratify($(esc(strata)), $(esc(type)), $(esc(aggregate)), $(esc(escaped_block))) + end +end + + +end \ No newline at end of file diff --git a/test/Stratification.jl b/test/Stratification.jl new file mode 100755 index 00000000..5a74915b --- /dev/null +++ b/test/Stratification.jl @@ -0,0 +1,403 @@ +using StockFlow.Syntax.Stratification + +using StockFlow.Syntax.Stratification: interpret_stratification_notation +using StockFlow.Syntax: NothingFunction, DSLArgument, unwrap_expression, substitute_symbols + +using Catlab.WiringDiagrams +using Catlab.ACSets +using Catlab.CategoricalAlgebra + + + + +@testset "Pullback computed in standard way is equal to DSL pullbacks" begin + + + l_type = @stock_and_flow begin + :stocks + pop + + :parameters + μ + δ + rFstOrder + rage + + :dynamic_variables + v_aging = pop * rage + v_fstOrder = pop * rFstOrder + v_birth = N * μ + v_death = pop * δ + + :flows + pop => f_aging(v_aging) => pop + pop => f_fstOrder(v_fstOrder) => pop + CLOUD => f_birth(v_birth) => pop + pop => f_death(v_death) => CLOUD + + :sums + N = [pop] + + end; + l_type_noatts = map(l_type, Name=NothingFunction, Op=NothingFunction, Position=NothingFunction); + + + WeightModel = @stock_and_flow begin + :stocks + NormalWeight + OverWeight + Obese + + :parameters + μ + δw + rw + ro + δo + rage + + :dynamic_variables + v_NewBorn = N * μ + v_DeathNormalWeight = NormalWeight * δw + v_BecomingOverWeight = NormalWeight * rw + v_DeathOverWeight = OverWeight * δw + v_BecomingObese = OverWeight * ro + v_DeathObese = Obese * δo + v_idNW = NormalWeight * rage + v_idOW = OverWeight * rage + v_idOb = Obese * rage + + :flows + CLOUD => f_NewBorn(v_NewBorn) => NormalWeight + NormalWeight => f_DeathNormalWeight(v_DeathNormalWeight) => ClOUD + NormalWeight => f_BecomingOverWeight(v_BecomingOverWeight) => OverWeight + OverWeight => f_DeathOverWeight(v_DeathOverWeight) => CLOUD + + OverWeight => f_BecomingObese(v_BecomingObese) => Obese + Obese => f_DeathObese(v_DeathObese) => CLOUD + NormalWeight => f_idNW(v_idNW) => NormalWeight + OverWeight => f_idOW(v_idOW) => OverWeight + Obese => f_idOb(v_idOb) => Obese + + :sums + N = [NormalWeight, OverWeight, Obese] + + end; + + + ageWeightModel = @stock_and_flow begin + :stocks + Child + Adult + Senior + + :parameters + μ + δC + δA + δS + rageCA + rageAS + r + + :dynamic_variables + v_NB = N * μ + v_DeathC = Child * δC + v_idC = Child * r + v_agingCA = Child * rageCA + v_DeathA = Adult * δA + v_idA = Adult * r + v_agingAS = Adult * rageAS + v_DeathS = Senior * δS + v_idS = Senior * r + + :flows + CLOUD => f_NB(v_NB) => Child + Child => f_idC(v_idC) => Child + Child => f_DeathC(v_DeathC) => CLOUD + Child => f_agingCA(v_agingCA) => Adult + Adult => f_idA(v_idA) => Adult + Adult => f_DeathA(v_DeathA) => CLOUD + Adult => f_agingAS(v_agingAS) => Senior + Senior => f_idS(v_idS) => Senior + Senior => f_DeathS(v_DeathS) => CLOUD + + :sums + N = [Child, Adult, Senior] + + end; + + begin + s, = parts(l_type, :S) + N, = parts(l_type, :SV) + lsn, = parts(l_type, :LS) + f_aging, f_fstorder, f_birth, f_death = parts(l_type, :F) + i_aging, i_fstorder, i_birth = parts(l_type, :I) + o_aging, o_fstorder, o_death = parts(l_type, :O) + v_aging, v_fstorder, v_birth, v_death = parts(l_type, :V) + lv_aging1, lv_fstorder1, lv_death1 = parts(l_type, :LV) + lsv_birth1, = parts(l_type, :LSV) + p_μ, p_δ, p_rfstOrder, p_rage = parts(l_type, :P) + lpv_aging2, lpv_fstorder2, lpv_birth2, lpv_death2 = parts(l_type, :LPV) + end; + + typed_WeightModel=ACSetTransformation(WeightModel, l_type_noatts, + S = [s,s,s], + SV = [N], + LS = [lsn,lsn,lsn], + F = [f_birth, f_death, f_fstorder, f_death, f_fstorder, f_death, f_aging, f_aging, f_aging], + I = [i_birth, i_aging, i_fstorder, i_aging, i_fstorder, i_aging], + O = [o_death, o_fstorder, o_aging, o_death, o_fstorder, o_aging, o_death, o_aging], + V = [v_birth, v_death, v_fstorder, v_death, v_fstorder, v_death, v_aging, v_aging, v_aging], + LV = [lv_death1, lv_fstorder1, lv_death1, lv_fstorder1, lv_death1, lv_aging1, lv_aging1, lv_aging1], + LSV = [lsv_birth1], + P = [p_μ, p_δ, p_rfstOrder, p_rfstOrder, p_δ, p_rage], + LPV = [lpv_birth2, lpv_death2, lpv_fstorder2, lpv_death2, lpv_fstorder2, lpv_death2, lpv_aging2, lpv_aging2, lpv_aging2], + Name=NothingFunction, Op=NothingFunction, Position=NothingFunction + ); + @assert is_natural(typed_WeightModel); + + + + typed_ageWeightModel=ACSetTransformation(ageWeightModel, l_type_noatts, + S = [s,s,s], + SV = [N], + LS = [lsn,lsn,lsn], + F = [f_birth, f_fstorder, f_death, f_aging, f_fstorder, f_death, f_aging, f_fstorder, f_death], + I = [i_birth, i_fstorder, i_aging, i_fstorder, i_aging, i_fstorder], + O = [o_fstorder, o_death, o_aging, o_fstorder, o_death, o_aging, o_fstorder, o_death], + V = [v_birth, v_death, v_fstorder, v_aging, v_death, v_fstorder, v_aging, v_death, v_fstorder], + LV = [lv_death1, lv_fstorder1, lv_aging1, lv_death1, lv_fstorder1, lv_aging1, lv_death1, lv_fstorder1], + LSV = [lsv_birth1], + P = [p_μ, p_δ, p_δ, p_δ, p_rage, p_rage, p_rfstOrder], + LPV = [lpv_birth2, lpv_death2, lpv_fstorder2, lpv_aging2, lpv_death2, lpv_fstorder2, lpv_aging2, lpv_death2, lpv_fstorder2], + Name =NothingFunction, Op=NothingFunction, Position=NothingFunction + ); + @assert is_natural(typed_ageWeightModel); + + aged_weight = pullback(typed_WeightModel, typed_ageWeightModel) |> apex |> rebuildStratifiedModelByFlattenSymbols; + + # ######################################### + + age_weight_2 = @stratify WeightModel l_type ageWeightModel begin + :stocks + NormalWeight, OverWeight, Obese => pop <= Child, Adult, Senior + + :flows + f_NewBorn => f_birth <= f_NB + f_DeathNormalWeight, f_DeathOverWeight, f_DeathObese => f_death <= f_DeathC, f_DeathA, f_DeathS + f_idNW, f_idOW, f_idOb => f_aging <= f_agingCA, f_agingAS + f_BecomingOverWeight, f_BecomingObese => f_fstOrder <= f_idC, f_idA, f_idS + + :dynamic_variables + v_NewBorn => v_birth <= v_NB + v_DeathNormalWeight, v_DeathOverWeight, v_DeathObese => v_death <= v_DeathC, v_DeathA, v_DeathS + v_idNW, v_idOW, v_idOb => v_aging <= v_agingCA, v_agingAS + v_BecomingOverWeight, v_BecomingObese => v_fstOrder <= v_idC, v_idA, v_idS + + :parameters + μ => μ <= μ + δw, δo => δ <= δC, δA, δS + rw, ro => rFstOrder <= r + rage => rage <= rageCA, rageAS + + :sums + N => N <= N + + end + ######################################### + + age_weight_3 = @stratify WeightModel l_type ageWeightModel begin + + :flows + f_NewBorn => f_birth <= f_NB + ~Death => f_death <= ~Death + ~id => f_aging <= ~aging + ~Becoming => f_fstOrder <= ~id + + :dynamic_variables + v_NewBorn => v_birth <= v_NB + ~Death => v_death <= ~Death + ~id => v_aging <= ~aging + ~Becoming => v_fstOrder <= ~id + + :parameters + μ => μ <= μ + ~δ => δ <= ~δ + rage => rage <= rageCA, rageAS + _ => rFstOrder <= _ + + end + + age_weight_4 = @stratify WeightModel l_type ageWeightModel begin + + :flows + ~NO_MATCHES => f_birth <= ~NO_MATCHES + f_NewBorn => f_birth <= f_NB + ~Death => f_death <= ~Death + ~id => f_aging <= ~aging + ~Becoming => f_fstOrder <= ~id + ~Becoming => f_aging <= ~id # Everything already matched; ignored + _ => f_aging <= _ # also ignored + + :dynamic_variables + v_NewBorn => v_birth <= v_NB + ~Death => v_death <= ~Death + ~id => v_aging <= ~aging + _ => v_fstOrder <= _ + + :parameters + μ => μ <= μ + ~δ => δ <= ~δ + rage => rage <= rageCA, rageAS + _ => rFstOrder <= _ + + end + + + + + @test aged_weight == age_weight_2 + @test aged_weight == age_weight_3 + @test aged_weight == age_weight_4 +end + +@testset "Ensuring interpret_stratification_notation correctly reads lines" begin # This should be all valid cases. There's always going to be at least one value on both sides. + # Note the orders. The lists produced go left to right. A1, A2 => B <= C1, C2 results in [A1 => B, A2 => B], [C1 => B. C2 => B] + + + @test interpret_stratification_notation(:(A => B <= C)) == ([DSLArgument(:A, :B, Set{Symbol}())], [DSLArgument(:C, :B, Set{Symbol}())]) + + @test interpret_stratification_notation(:(A1, A2 => B <= C)) == ( + [DSLArgument(:A1, :B, Set{Symbol}()), DSLArgument(:A2, :B, Set{Symbol}())], + [DSLArgument(:C, :B, Set{Symbol}())] + ) + @test interpret_stratification_notation(:(A => B <= C1, C2)) == ( + [DSLArgument(:A, :B, Set{Symbol}())], + [DSLArgument(:C1, :B, Set{Symbol}()), DSLArgument(:C2, :B, Set{Symbol}())], + ) + @test interpret_stratification_notation(:(_ => B <= _)) == ( + [DSLArgument(:_, :B, Set{Symbol}())], + [DSLArgument(:_, :B, Set{Symbol}())], + ) + @test interpret_stratification_notation(:(~A => B <= ~C)) == ( + [DSLArgument(:A, :B, Set{Symbol}([:~]))], + [DSLArgument(:C, :B, Set{Symbol}([:~]))], + ) + @test interpret_stratification_notation(:(~A1, A2 => B <= ~C)) == ( + [DSLArgument(:A1, :B, Set{Symbol}([:~])), DSLArgument(:A2, :B, Set{Symbol}())], + [DSLArgument(:C, :B, Set{Symbol}([:~]))], + ) + + @test interpret_stratification_notation(:(~_ => B <= ~_, C)) == ( # Weird case. Matches everything with _ as a substring. + [DSLArgument(:_, :B, Set{Symbol}([:~]))], + [DSLArgument(:_, :B, Set{Symbol}([:~])), DSLArgument(:C, :B, Set{Symbol}())] + ) + +end + + +@testset "Unwrapping expressions works correctly" begin + @test unwrap_expression(:S) == (:S, Set{Symbol}()) + @test unwrap_expression(:(~S)) == (:S, Set{Symbol}([:~])) + @test unwrap_expression(:(~_)) == (:_, Set{Symbol}([:~])) +end + + +# function substitute_symbols(s::Dict{Symbol, Int}, t::Dict{Symbol, Int}, m::Vector{DSLArgument} ; use_flags::Bool=true)::Dict{Int, Int} + +@testset "Testing substituting symbols" begin # underscore matching occurs at the very end, after this step. + s1 = Dict(:A => 1) + t1 = Dict(:B => 2) + m1₁ = [DSLArgument(:A, :B, Set{Symbol}())] + m1₂ = [DSLArgument(:A, :B, Set{Symbol}([:~]))] + + @test substitute_symbols(s1, t1, m1₁) == Dict(1 => 2) # A=>B -> 1=>2 + @test substitute_symbols(s1, t1, m1₂) == Dict(1 => 2) # A=>B -> 1=>2 + @test substitute_symbols(s1, t1, m1₂, use_flags=false) == Dict(1 => 2) # A=>B -> 1=>2 + + s2 = Dict(:A1 => 10, :A2 => 20, :A3 => 30) # Unfortunately, cannot do substring matches starting with numbers, since it would require a symbol starting with a numbre. Might need to add something for this... + t2 = Dict(:B1 => 1, :B2 => 2) + m2₁ = [DSLArgument(:A, :B1, Set{Symbol}([:~]))] + + @test substitute_symbols(s2, t2, m2₁) == Dict(10 => 1, 20 => 1, 30 => 1) #~A=>B -> 10=>1, 20=>1, 30=>1 + # @test substitute_symbols(s2, t2, m2₁, use_flags=false) == Dict() # deliberately throws an error + + s3 = Dict{Symbol, Int}() + t3 = Dict{Symbol, Int}() + m3 = Vector{DSLArgument}() + + @test substitute_symbols(s3, t3, m3) == Dict() + @test substitute_symbols(s3, t3, m3, use_flags=false) == Dict() + + s4 = Dict(:A1 => 1, :A2 => 2, :AB3 => 3, :AB4 => 4, :A5 => 5) + t4 = Dict(:B1 => 1, :B2 => 2, :B3 => 3) + m4 = [DSLArgument(:A1, :B1, Set{Symbol}()), DSLArgument(:B, :B2, Set{Symbol}([:~])), DSLArgument(:A, :B3, Set{Symbol}([:~]))] + + # always goes with first match. A1 is taken, B matches AB3 and AB4, then A matches A2 and A5 + @test substitute_symbols(s4, t4, m4) == Dict(1 => 1, 3 => 2, 4 => 2, 2 => 3, 5 => 3) +end + + +@testset "nondefault flags work as expected" begin + A_ = (@stock_and_flow begin + :stocks + A + _ + end) + + X_ = (@stock_and_flow begin + :stocks + X + _ + end) + + B_ = (@stock_and_flow begin + :stocks + B + _ + end) + + strat_AXB = (quote # Note, we use a quote when calling the function, begin when calling the macro. + :stocks + _ => _ <= _ + A => X <= B + ~A => X <= ~B # everything is already assigned, so does nothing (or throws error if strict_matches is true) + end) + + + sfA = (@stock_and_flow begin; :stocks; A; end;) + + @test (sfstratify(A_, X_, B_, strat_AXB, use_temp_strat_default=false) + == (@stock_and_flow begin + :stocks + AB + __ + end)) + + # doesn't show up anywhere, so doesn't affect anything. Could also set it to something untypable in the DSL, like Symbol("") + @test (sfstratify(A_, X_, B_, strat_AXB, temp_strat_default=:ABABABABA) + == (@stock_and_flow begin + :stocks + AB + __ + end)) + + @test_throws AssertionError (sfstratify(A_, X_, B_, strat_AXB, strict_matches=true)) # A matches against A and ~A, which is disallowed with this flag. + + @test_throws ErrorException (sfstratify(sfA,sfA,sfA,(quote end) ; strict_mappings=true)) # strict_mappings=false wouldn't throw an error, and would infer strata and aggregate need to map to the only stock. + + + nothing_sfA = map(sfA, Position=NothingFunction, Op=NothingFunction, Name=NothingFunction) + + @test (sfstratify(sfA,sfA,sfA,(quote end), return_homs=true) == ( + (@stock_and_flow begin + :stocks + AA + end), + ACSetTransformation(sfA, nothing_sfA ; S=[1], F=Vector{Int}(),V =Vector{Int}(),SV=Vector{Int}(),P=Vector{Int}(),LS=Vector{Int}(),I=Vector{Int}(),O=Vector{Int}(),LV=Vector{Int}(),LSV=Vector{Int}(),LVV=Vector{Int}(),LPV=Vector{Int}(), Position=NothingFunction, Op=NothingFunction, Name=NothingFunction), # strata -> type + ACSetTransformation(sfA, nothing_sfA ; S=[1], F=Vector{Int}(),V =Vector{Int}(),SV=Vector{Int}(),P=Vector{Int}(),LS=Vector{Int}(),I=Vector{Int}(),O=Vector{Int}(),LV=Vector{Int}(),LSV=Vector{Int}(),LVV=Vector{Int}(),LPV=Vector{Int}(), Position=NothingFunction, Op=NothingFunction, Name=NothingFunction) # aggregate -> type + )) # the empty lists are necessary for equality, but it'd still be an equivalent homomorphism if you didn't specify them. + +end + diff --git a/test/Syntax.jl b/test/Syntax.jl old mode 100644 new mode 100755 index 36df5324..bbcbf76a --- a/test/Syntax.jl +++ b/test/Syntax.jl @@ -2,7 +2,11 @@ using Base: is_unary_and_binary_operator using Test using StockFlow using StockFlow.Syntax -using StockFlow.Syntax: is_binop_or_unary, sum_variables, infix_expression_to_binops, fnone_value_or_vector, extract_function_name_and_args_expr, is_recursive_dyvar, create_foot +using StockFlow.Syntax: is_binop_or_unary, sum_variables, infix_expression_to_binops, fnone_value_or_vector, extract_function_name_and_args_expr, is_recursive_dyvar, create_foot, apply_flags, substitute_symbols + +@testset "Stratification DSL" begin + include("Stratification.jl") +end @testset "is_binop_or_unary recognises binops" begin @test is_binop_or_unary(:(a + b)) @@ -337,5 +341,209 @@ end @test_throws Exception @eval @feet begin A => B; 1 => 2; end end +########################### + +@testset "infer_links works as expected" begin + # No prior mappings means no inferred mappings + @test (infer_links(StockAndFlowF(), StockAndFlowF(), Dict{Symbol, Vector{Int64}}(:S => [], :F => [], :SV => [], :P => [], :V => [])) + == Dict(:LS => [], :LSV => [], :LV => [], :I => [], :O => [], :LPV => [], :LVV => [])) + + # S: 1 -> 1 and SV: 1 -> 1 implies LS: 1 -> 1 + @test (infer_links( + (@stock_and_flow begin; :stocks; A; :sums; NA = [A]; end), + (@stock_and_flow begin; :stocks; B; :sums; NB = [B]; end), + Dict{Symbol, Vector{Int64}}(:S => [1], :F => [], :SV => [1], :P => [], :V => [])) + == Dict(:LS => [1], :LSV => [], :LV => [], :I => [], :O => [], :LPV => [], :LVV => [])) + + # annoying exanmple, required me to add code to disambiguate using position + # that is, vA = A + A, vA -> vB, A -> implies that the As in the vA definition map to the Bs in the vB definition + # But both As link to the same stock and dynamic variable so just looking at those isn't enough to figure out what it maps to. + # There will exist cases where it's impossible to tell - eg, when there exist multiple duplicate links, and some positions don't match up. + + # It does not currently look at the operator. You could therefore map vA = A + A -> vB = B * B + # I can see this being useful, actually, specifically when mapping between + and -, * and /, etc. Probably logs and powers too. + # Just need to be aware that it won't say it's invalid. + @test (infer_links( + (@stock_and_flow begin; :stocks; A; :dynamic_variables; vA = A + A; end), + (@stock_and_flow begin; :stocks; B; :dynamic_variables; vB = B + B; end), + Dict{Symbol, Vector{Int64}}(:S => [1], :F => [], :SV => [], :P => [], :V => [1])) + == Dict(:LS => [], :LSV => [], :LV => [2,2], :I => [], :O => [], :LPV => [], :LVV => [])) # If duplicate values, always map to end. + + @test (infer_links( + (@stock_and_flow begin; :stocks; A; :parameters; pA; :dynamic_variables; vA = A + pA; end), + (@stock_and_flow begin; :stocks; B; :parameters; pB; :dynamic_variables; vB = pB + B; end), + Dict{Symbol, Vector{Int64}}(:S => [1], :F => [], :SV => [], :P => [1], :V => [1])) + == Dict(:LS => [], :LSV => [], :LV => [1], :I => [], :O => [], :LPV => [1], :LVV => [])) + + @test (infer_links( + (@stock_and_flow begin + :stocks + S + I + R + + :parameters + p_inf + p_rec + + + :flows + S => f_StoI(p_inf * S) => I + I => f_ItoR(I * p_rec) => R + + :sums + N = [S,I,R] + NI = [I] + NS = [S,I,R] + end), + (@stock_and_flow begin + :stocks + pop + + :parameters + p_generic + + + :flows + pop => f_generic(p_generic * pop) => pop + + :sums + N = [pop] + NI = [pop] + NS = [pop] + end), + + Dict{Symbol, Vector{Int64}}(:S => [1,1,1], :F => [1,1], :SV => [1,2,3], :P => [1,1], :V => [1,1])) + == Dict(:LS => [1,3,1,2,3,1,3], :LSV => [], :LV => [1,1], :I => [1,1], :O => [1,1], :LPV => [1,1], :LVV => [])) + + +end + + +@testset "Applying flags can correctly find substring matches" begin + @test apply_flags(:f_, Set([:~]), Vector{Symbol}()) == [] + @test apply_flags(:f_, Set([:~]), [:f_death, :f_birth]) == [:f_death, :f_birth] + @test apply_flags(:NOMATCH, Set([:~]), [:f_death, :f_birth]) == [] + @test apply_flags(:f_birth, Set([:~]), [:f_death, :f_birth]) == [:f_birth] + @test apply_flags(:f_birth, Set{Symbol}(), [:f_death, :f_birth]) == [:f_birth] + + # Note, apply_flags is specifically meant to work on vectors without duplicates; the vector which is input are the keys of a dictionary. + # Regardless, the following will hold: + @test apply_flags(:f_birth, Set{Symbol}(), [:f_death, :f_birth, :f_birth, :f_birth]) == [:f_birth] + @test apply_flags(:f_birth, Set{Symbol}([:~]), [:f_death, :f_birth, :f_birth, :f_birth]) == [:f_birth, :f_birth, :f_birth] +end + + +@testset "substitute_symbols will correctly associate values of the two provided dictionaries based on user defined mappings" begin + # substitute_symbols(s::Dict{Symbol, Int}, t::Dict{Symbol, Int}, m::Vector{DSLArgument} ; use_flags::Bool=true)::Dict{Int, Int} + + + # Note, these dictionaries represent a vector where all the entries are unique, and the values are the original indices. + # So, both keys and values should be unique. + # For stratification, first dictionary is strata or aggregate, second is type, and the vector of DSLArgument are the user-defined maps. + # For homomorphism, first argument is src, second is dest, vector are user-defined maps. + @test substitute_symbols(Dict{Symbol, Int}(), Dict{Symbol, Int}(), Vector{DSLArgument}()) == Dict{Int, Int}() + @test substitute_symbols(Dict{Symbol, Int}(), Dict(:B => 2), Vector{DSLArgument}()) == Dict{Int, Int}() + + @test substitute_symbols(Dict(:A => 1), Dict(:B => 1), [DSLArgument(:A, :B, Set{Symbol}())]) == Dict(1 => 1) + @test substitute_symbols(Dict(:A1 => 1, :A2 => 2), Dict(:B => 1), [DSLArgument(:A1, :B, Set{Symbol}()), DSLArgument(:A2, :B, Set{Symbol}())]) == Dict(1 => 1, 2 => 1) + @test substitute_symbols(Dict(:A1 => 1), Dict(:B1 => 1, :B2 => 2), [DSLArgument(:A1, :B2, Set{Symbol}())]) == Dict(1 => 2) + + + @test substitute_symbols(Dict(:A1 => 1, :A2 => 2), Dict(:B1 => 1, :B2 => 2), [DSLArgument(:A, :B2, Set{Symbol}([:~]))]) == Dict(1 => 2, 2 => 2) + + # 1:100 + # 1:50 + # All multiples x of 14 below 100 go to x % 10 + 1 + @test (substitute_symbols(Dict(Symbol(i) => i for i ∈ 1:100), Dict(Symbol(-i) => i for i ∈ 1:50), [DSLArgument(Symbol(i), Symbol(-((i%10) + 1)), Set{Symbol}()) for i ∈ 1:100 if i % 14 == 0]) + == Dict(14 => 5, 28 => 9, 42 => 3, 56 => 7, 70 => 1, 84 => 5, 98 => 9)) + + # Captures everything with a 7 as a digit + @test (substitute_symbols(Dict(Symbol(i) => i for i ∈ 1:100), Dict(Symbol(-i) => i for i ∈ 1:50), [DSLArgument(Symbol(7), Symbol(-1), Set{Symbol}([:~]))]) + == Dict(7 => 1, 17 => 1, 27 => 1, 37 => 1, 47 => 1, 57 => 1, 67 => 1, 70 => 1, 71 => 1, 72 => 1, 73 => 1, 74 => 1, 75 => 1, 76 => 1, 77 => 1, 78 => 1, 79 => 1, 87 => 1, 97 => 1)) + + @test substitute_symbols(Dict(Symbol("~") => 1), Dict(:R => 1), [DSLArgument(Symbol("~"), :R, Set([:~]))], ; use_flags = false) == Dict(1 => 1) # Note, the Set([:~]) is ignored because use_flags is false + +end + + +@testset "non-natural transformations fail infer_links" begin + + # Map both dynamic variables to the same + # Obviously, this will fail, as the new dynamic variable needs a LVV and one LV, but instead has two LV + @test_throws KeyError (infer_links( + (@stock_and_flow begin + :stocks + A + + :dynamic_variables + v1 = A + A + v2 = v1 + A + end), + (@stock_and_flow begin + :stocks + A + + :dynamic_variables + v1 = A + A + end), + Dict{Symbol, Vector{Int64}}(:S => [1], :V => [1,1]))) + + + + # Mapping it all to I + + # This one fails when trying to figure out the inflow. Stock maps to 2, and flow maps to 2, + # But inflows on the target have (1,2) and (2,3) + + # This also wouldn't work if we tried mapping flow to 1 instead. Outflows expect 1,1 or 2,2, + # so it fails on (2,1). + @test_throws KeyError (infer_links( + (@stock_and_flow begin + :stocks + pop + + :parameters + p_generic + + + :flows + pop => f_generic(p_generic * pop) => pop + + :sums + N = [pop] + NI = [pop] + NS = [pop] + end), + (@stock_and_flow begin + :stocks + S + I + R + + :parameters + p_inf + p_rec + + + :flows + S => f_StoI(p_inf * S) => I + I => f_ItoR(I * p_rec) => R + + :sums + N = [S,I,R] + NI = [I] + NS = [S,I,R] + end), + Dict{Symbol, Vector{Int64}}(:S => [2], :F => [2], :SV => [1,2,3], :P => [2], :V => [2]))) + +end + +@testset "Applying flags throws on invalid inputs" begin + @test_throws ErrorException apply_flags(:f_, Set([:+]), [:f_death, :f_birth]) # fails because :+ is not a defined operation + @test_throws ErrorException apply_flags(:f_birth, Set([:~, :+]), [:f_death, :f_birth]) # also fails for same reason + @test_throws AssertionError apply_flags(:NOMATCH, Set{Symbol}(), Vector{Symbol}()) # fails because it's not looking for substrings, and :NOMATCH isn't in the list of options. + @test_throws AssertionError apply_flags(:NOMATCH, Set{Symbol}(), [:nomatch]) # same reason +end \ No newline at end of file