diff --git a/src/StockFlow.jl b/src/StockFlow.jl index 1cd477c2..ebe439e2 100644 --- a/src/StockFlow.jl +++ b/src/StockFlow.jl @@ -9,7 +9,9 @@ funcDynam, flowVariableIndex, funcFlow, funcFlows, funcSV, funcSVs, TransitionMa vectorfield, funcFlowsRaw, funcFlowRaw, inflowsAll, outflowsAll,instock,outstock, stockssv, stocksv, svsv, svsstock, vsstock, vssv, svsstockAllF, vsstockAllF, vssvAllF, StockAndFlowUntyped, StockAndFlowFUntyped, StockAndFlowStructureUntyped, StockAndFlowStructureFUntyped, StockAndFlowUntyped0, Open, snames, fnames, svnames, vnames, object_shift_right, foot, leg, lsnames, OpenStockAndFlow, OpenStockAndFlowOb, fv, fvs, nlvv, nlpv, vtgt, vsrc, vpsrc, vptgt, pname, pnames, make_v_expr, -vop, lvvposition, lvtgtposition, lsvvposition, lpvvposition, recreate_stratified, set_snames!, set_fnames!, set_svnames!, set_vnames!, set_pnames!, set_sname!, set_fname!, set_svname!, set_vname!, set_pname! +vop, lvvposition, lvtgtposition, lsvvposition, lpvvposition, recreate_stratified, set_snames!, set_fnames!, set_svnames!, set_vnames!, set_pnames!, set_sname!, set_fname!, set_svname!, set_vname!, set_pname!, +get_lss, get_lssv, get_lsvsv, get_lsvv, get_lvs, get_lvv, get_is, get_ifn, get_os, get_ofn, get_lpvp, get_lpvv, get_lvsrc, get_lvtgt, get_links + using Catlab using Catlab.CategoricalAlgebra @@ -268,6 +270,20 @@ nlpv(p::AbstractStockAndFlowStructureF) = nparts(p,:LPV) #links from dynamic var np(p::AbstractStockAndFlowStructureF) = nparts(p,:P) #parameters +get_lss(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lss].m)) +get_lssv(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lssv].m)) +get_lsvsv(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lsvsv].m)) +get_lsvv(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lsvv].m)) +get_lvs(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lvs].m)) +get_lvv(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lvv].m)) +get_is(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:is].m)) +get_ifn(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:ifn].m)) +get_os(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:os].m)) +get_ofn(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:ofn].m)) +get_lpvp(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lpvp].m)) +get_lpvv(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lpvv].m)) +get_lvsrc(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lvsrc].m)) +get_lvtgt(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lvtgt].m)) #EXAMPLE: diff --git a/src/Syntax.jl b/src/Syntax.jl index 1d94fabe..7b77f456 100644 --- a/src/Syntax.jl +++ b/src/Syntax.jl @@ -103,11 +103,13 @@ end ``` """ module Syntax -export @stock_and_flow, @foot, @feet +export @stock_and_flow, @foot, @feet, infer_links using ..StockFlow using MLStyle +import Base: ==, Iterators.flatmap + """ stock_and_flow(block :: Expr) @@ -1044,7 +1046,215 @@ function match_foot_format(footblock::Expr) end end +############################################# + +""" + infer_particular_link!(sfsrc, sftgt, f1, f2, map1, map2, destination_vector, posf=nothing) + +infer_particular_link!(sfsrc, sftgt, get_lvs, get_lvv, stockmaps, dyvarmaps, lvmaps, get_lvvposition) # LV + +If we're mapping the same value to multiple positions, it doesn't matter which one goes where. +We have a few options, on how we want to distribute mappings. Way it's done here, always goes to the last position. + +""" +function infer_particular_link!(sfsrc, sftgt, f1, f2, map1, map2, destination_vector) + + hom1′_mappings = f1(sftgt) + hom2′_mappings = f2(sftgt) + + tgt::Dict{Tuple{Int, Int}, Int} = Dict((hom1′, hom2′) => i for (i, (hom1′, hom2′)) in enumerate(zip(hom1′_mappings, hom2′_mappings))) # ISSUE: If there are two matches, second one overwrites the first. + # SOLUTION: Who cares. Just map to the last. + for (i, (hom1, hom2)) in enumerate(zip(f1(sfsrc), f2(sfsrc))) + mapped_index1 = map1[hom1] + mapped_index2 = map2[hom2] + + linkmap = tgt[(mapped_index1, mapped_index2)] + + + + destination_vector[i] = linkmap # updated + end + return destination_vector + +end + + +""" + infer_links(sfsrc :: StockAndFlowF, sftgt :: StockAndFlowF, NecMaps :: Dict{Symbol, Vector{Int64}}) + +Infer LS, I, O, LV, LSV, LVV, LPV mappings for an ACSetTransformation. +Returns dictionary of Symbols to lists of indices, corresponding to an ACSetTransformation argument. +If there exist no such mappings (eg, no LVV), that pairing will not be included in the returned dictionary. + +If A <- C -> B, and we have A -> A' and B -> B' and a unique C' such that A' <- C' -> B', we can assume C -> C'. + +:S => [2,4,1,3], :F => [1,2,4,3], ... + +NecMaps must contain keys S, F, SV, P, V, each pointing to a (possibly empty) array of indices +""" +function infer_links(sfsrc :: StockAndFlowF, sftgt :: StockAndFlowF, NecMaps :: Dict{Symbol, Vector{Int64}}) + + + stockmaps = NecMaps[:S] + flowmaps = NecMaps[:F] + summaps = NecMaps[:SV] + parammaps = NecMaps[:P] + dyvarmaps = NecMaps[:V] + + lsmaps = zeros(Int, nls(sfsrc)) + imaps = zeros(Int, ni(sfsrc)) + omaps = zeros(Int, no(sfsrc)) + lvmaps = zeros(Int, nlv(sfsrc)) + lsvmaps = zeros(Int, nlsv(sfsrc)) + lvvmaps = zeros(Int, nlvv(sfsrc)) + lpvmaps = zeros(Int, nlpv(sfsrc)) + # After the following calls, there should be no zeroes. + + + infer_particular_link!(sfsrc, sftgt, get_lss, get_lssv, stockmaps, summaps, lsmaps) # LS + infer_particular_link!(sfsrc, sftgt, get_ifn, get_is, flowmaps, stockmaps, imaps) # I + infer_particular_link!(sfsrc, sftgt, get_ofn, get_os, flowmaps, stockmaps, omaps) # O + infer_particular_link!(sfsrc, sftgt, get_lvs, get_lvv, stockmaps, dyvarmaps, lvmaps) # LV + infer_particular_link!(sfsrc, sftgt, get_lsvsv, get_lsvv, summaps, dyvarmaps, lsvmaps) # LSV + infer_particular_link!(sfsrc, sftgt, get_lvsrc, get_lvtgt, dyvarmaps, dyvarmaps, lvvmaps) # LVV + infer_particular_link!(sfsrc, sftgt, get_lpvp, get_lpvv, parammaps, dyvarmaps, lpvmaps) # LPV + + return Dict(:LS => lsmaps, :LSV => lsvmaps, :LV => lvmaps, :I => imaps, :O => omaps, :LPV => lpvmaps, :LVV => lvvmaps) include("syntax/Composition.jl") end + + +struct DSLArgument + key::Symbol + value::Symbol + flags::Set{Symbol} # At present, the only flag that exists is ~ + DSLArgument(kv::Pair{Union{Expr, Symbol}, Symbol}) = begin # this constructor seemed to fail... need to figure out why. Maybe it can't call other constructors. + key, flags = unwrap_expression(first(kv)) + new(key, second(kv), flags) + end + DSLArgument(k::Union{Expr, Symbol}, v::Symbol) = begin + key, flags = unwrap_expression(k) + new(key, v, flags) + end + DSLArgument(k::Symbol, v::Symbol, f::Set{Symbol}) = new(k, v, f) +end + +==(a::DSLArgument, b::DSLArgument) = a.key == b.key && a.value == b.value && a.flags == b.flags + + +function unwrap_expression(x::Union{Symbol, Expr}, flags::Set{Symbol}=Set{Symbol}())::Tuple{Symbol, Set{Symbol}} # No mutable default arguments. + if typeof(x) == Symbol + return (x, flags) + else + return unwrap_expression(x.args[2], push!(flags, x.args[1])) + end +end + + +""" +S₁ => I₁ +S₂ => I₂ +S₁ => S₂ + +⊢ + +I₁ => I₂ + +Determine what index an element e maps to based upon what f we have in the mapping such that e -> f +""" +function connect_by_value(; src::Dict{T,U}, mapping::Dict{T,T}, tgt::Dict{T,U})::Dict{U, U} where {T, U} + @assert allunique(values(src)) + + @assert all(x -> x ∈ keys(mapping), keys(src)) + @assert all(x -> x ∈ keys(tgt), values(mapping)) + + return Dict(src[key] => tgt[value] for (key, value) in mapping) + +end + + +""" +Filter a vector for all elements with substr as a substring. +""" +function substring_matches(v::Vector, substr::String)::Vector + return filter(x -> occursin(substr, string(x)), v) +end + + +""" +Takes a symbol 'key', applys flags, finds matches in s, and returns a vector of matching keys. +Currently, there are two options: no flags, in which case [key] is returned, or ~ is the only flag, in which case all Symbols with matching substrings are returned. +""" +function apply_flags(key::Symbol, flags::Set{Symbol}, s::Vector{Symbol})::Vector{Symbol} # Could make this a generator? + + if isempty(flags) + @assert (key ∈ s) "$s does not contain key $key ! Did you forget to prefix ~?" + return [key] # potentially inefficient + elseif :~ ∈ flags + + matches = collect(substring_matches(s, string(key))) + + new_flags = copy(flags) # copy isn't necessary, probably + pop!(new_flags, :~) + + return collect(flatmap(x -> apply_flags(x, new_flags, s), matches)) # this is just in case we add additional flags. As is, the recursion is unnecessary. + else + error("Unknown flag found! $(flags)") + end +end + +""" + substitute_symbols(s::Dict{Symbol, Int}, t::Dict{Symbol, Int}, m::Vector{DSLArgument} ; use_flags::Bool=true)::Dict{Int, Int} + +Convert Dict(SymA => IntA), Dict(SymB => IntB), Dict(SymA => SymB) into Dict{IntA => IntB} +Using original sf defintions, and the user defined mappings, transform user defined symbol mappings to index mappings. +""" +function substitute_symbols(s::Dict{Symbol, Int}, t::Dict{Symbol, Int}, m::Vector{DSLArgument} ; use_flags::Bool=true)::Dict{Int, Int} + if !use_flags + mapping = Dict(arg.key => arg.value for arg in m) + return connect_by_value(src=s, mapping=mapping, tgt=t) + else + master_dict::Dict{Int, Int} = Dict() + for statement in m + key_matches = apply_flags(statement.key, statement.flags, collect(keys(s))) # Vector of Symbol + if isempty(key_matches) + println("WARNING! No matches on $(statement.key) with flags $(statement.flags)") + else + mergewith!((x...) -> first(x), master_dict, Dict(s[match] => t[statement.value] for match ∈ key_matches)) + end + end + return master_dict + end +end + + +""" +Convert a vector of unique elements to a dictionary with each element pointing to their original index. +""" +function invert_vector(v::Vector{K})::Dict{K, Int} where {K} # Elements of v must be hashable + new_dict = Dict(val => i for (i, val) ∈ enumerate(v)) + @assert length(new_dict) == length(v) "Nonunique key in vector v: $v" + return new_dict +end + + +""" +Takes any arguments and returns nothing. +Used so we can maintain equality when making ACSetTransformations. +""" +NothingFunction(x...)::Nothing = nothing; + + + + +include("syntax/Composition.jl") +include("syntax/Stratification.jl") + +end + + + + + diff --git a/src/SystemStructure.jl b/src/SystemStructure.jl index 01f9a2d3..9e036f56 100644 --- a/src/SystemStructure.jl +++ b/src/SystemStructure.jl @@ -212,7 +212,7 @@ extracVAndAttrStructureAndFlatten(p::AbstractStockAndFlowF) = begin if nvb(p)>0 for v in 1:nvb(p) vn = flattenTupleNames(vname(p,v)) - v_op = foldr(==,vop(p,v)) ? vop(p,v)[1] : error("operators $(vop(p,v)) in the stratified model's auxiliary variable: $(join(vname(p,v))) should be the same!") + v_op = allequal(vop(p,v)) ? vop(p,v)[1] : error("operators $(vop(p,v)) in the stratified model's auxiliary variable: $(join(vname(p,v))) should be the same!") vnp = vn=>(args(p,v)=>v_op) vs = vcat(vs,vnp) end diff --git a/src/syntax/Stratification.jl b/src/syntax/Stratification.jl new file mode 100755 index 00000000..fad2b345 --- /dev/null +++ b/src/syntax/Stratification.jl @@ -0,0 +1,441 @@ +module Stratification +export sfstratify, @stratify, @n_stratify + +using ...StockFlow +using ..Syntax +using MLStyle +import Base: get +using Catlab.CategoricalAlgebra +import ..Syntax: infer_links, substitute_symbols, DSLArgument, NothingFunction, invert_vector + + + +struct SFNames + + sf::AbstractStockAndFlowF + + snames::Vector{Symbol} + svnames::Vector{Symbol} + vnames::Vector{Symbol} + fnames::Vector{Symbol} + pnames::Vector{Symbol} + + # name -> index + s::Dict{Symbol, Int} + sv::Dict{Symbol, Int} + v::Dict{Symbol, Int} + f::Dict{Symbol, Int} + p::Dict{Symbol, Int} + + # index -> new index + ms::Dict{Int, Int} + msv::Dict{Int, Int} + mv::Dict{Int, Int} + mf::Dict{Int, Int} + mp::Dict{Int, Int} + + # index -> new index, where the first index is the actual index of the vector, the second is the int at that location + mvs::Vector{Int} + mvsv::Vector{Int} + mvv::Vector{Int} + mvf::Vector{Int} + mvp::Vector{Int} + + + SFNames(sfarg::AbstractStockAndFlowF) = (new(sfarg, + snames(sfarg), svnames(sfarg), vnames(sfarg), fnames(sfarg), pnames(sfarg), + invert_vector(snames(sfarg)), invert_vector(svnames(sfarg)), invert_vector(vnames(sfarg)), invert_vector(fnames(sfarg)), invert_vector(pnames(sfarg)), + Dict{Int, Int}(), Dict{Int, Int}(), Dict{Int, Int}(), Dict{Int, Int}(), Dict{Int, Int}(), + Vector{Int}(),Vector{Int}(), Vector{Int}(), Vector{Int}(), Vector{Int}())) +end + +function get_mappings(sfn::SFNames)::NTuple{5, Dict{Int, Int}} + return sfn.ms, sfn.msv, sfn.mv, sfn.mf, sfn.mp +end + +function get_mapped_vectors(sfn::SFNames)::NTuple{5, Vector{Int}} + return sfn.mvs, sfn.mvsv, sfn.mvv, sfn.mvf, sfn.mvp +end + +function get_mappings_infer_links_format(sfn::SFNames)::Dict{Symbol, Vector{Int}} + Dict(:S => sfn.mvs, :SV => sfn.mvsv, :V => sfn.mvv, :F => sfn.mvf, :P => sfn.mvp) +end + +function all_unique_names(sfn::SFNames)::Bool # Unnecessary, this is checked in invert_vector + return allunique(sfn.snames) && allunique(sfn.svnames) && allunique(vnames) && allunique(fnames) && allunique(pnames) +end + +function no_temp_strat_default_in_names(sfn::SFNames, temp_strat_default)::Bool + return temp_strat_default ∉ keys(sfn.s) && temp_strat_default ∉ keys(sfn.sv) && temp_strat_default ∉ keys(sfn.v) && temp_strat_default ∉ keys(sfn.f) && temp_strat_default ∉ keys(sfn.p) +end + +function add_temp_strat_default!(sfn::SFNames, temp_strat_default) + push!(sfn.s, temp_strat_default => -1) + push!(sfn.sv, temp_strat_default => -1) + push!(sfn.v, temp_strat_default => -1) + push!(sfn.f, temp_strat_default => -1) + push!(sfn.p, temp_strat_default => -1) +end + +function is_all_mapped(sfn::SFNames)::Bool + return all(vec -> 0 ∉ vec, get_mapped_vectors(sfn)) +end + +function get_names(sfn::SFNames)::NTuple{5, Vector{Symbol}} + return sfn.snames, sfn.svnames, sfn.vnames, sfn.fnames, sfn.pnames +end + + + + +""" + interpret_stratification_standard_notation(mapping_pair::Expr)::Tuple{Vector{DSLArgument}, Vector{DSLArgument}} +Take an expression of the form a1, ..., => t <= s1, ..., where every element is a symbol, and return a 2-tuple of form ((a1 => t, a2 => t, ...), (s1 => t, ...)) +""" +function interpret_stratification_standard_notation(mapping_pair::Expr)::Vector{Vector{DSLArgument}} + @match mapping_pair begin + + + :($s => $t <= $a) => return [[DSLArgument(s,t)], [DSLArgument(a,t)]] + :($s => $t <= $a, $(atail...)) => [[DSLArgument(s,t)], [DSLArgument(a,t) ; [DSLArgument(as,t) for as in atail] ]] + :($(shead...), $s => $t <= $a) => [[[DSLArgument(ss, t) for ss in shead] ; DSLArgument(s, t)], [DSLArgument(a, t)]] + + if mapping_pair.head == :tuple end => begin + middle_index = findfirst(x -> typeof(x) == Expr && length(x.args) == 3, mapping_pair.args) # still isn't specific enough + if isnothing(middle_index) + error("Malformed line $mapping_pair, could not find center.") + end + @match mapping_pair.args[middle_index] begin + :($stail => $t <= $ahead) => begin + sdict = [[DSLArgument(ss, t) for ss in mapping_pair.args[1:middle_index-1]] ; DSLArgument(stail, t)] + adict = [DSLArgument(ahead, t) ; [DSLArgument(as, t) for as in mapping_pair.args[middle_index+1:end]]] + return [sdict, adict] + end + _ => "Unknown format found for match; middle three values formatted incorrectly." + end + end + _ => error("Unknown line format found in stratification notation.") + end +end + + + +function interpret_stratification_generalized_notation(mapping_pair::Expr)::Vector{Vector{DSLArgument}} + # asserts are covered before this function is called. + other = mapping_pair.args[2].args # needs to be a vector of tuples of symbols + type = mapping_pair.args[3] # needs to be a symbol + return [((typeof(tup) == Expr) && (tup.head == :tuple)) ? [DSLArgument(sym, type) for sym ∈ tup.args] : [DSLArgument(tup, type)] for tup ∈ other] +end + + + + + +""" +Gets mapping information from each line and updates dictionaries. If a symbol already has a mapping and another is found, keep the first, or throw an error if strict_matches = true. +""" +function read_stratification_line_and_update_dictionaries!(line::Expr, other_names::Vector{Dict{Symbol, Int}}, type_names::Dict{Symbol, Int}, other_mappings::Vector{Dict{Int, Int}} ; use_standard_stratification_syntax = true, strict_matches = false, use_flags = true) + if use_standard_stratification_syntax + interpret_stratification_notation_function = interpret_stratification_standard_notation + else + interpret_stratification_notation_function = interpret_stratification_generalized_notation + + # need to do this here, since we know the number of other_mappings at this point, but not in the interpret_stratification_notation + @assert length(line.args) == 3 + @assert typeof(line.args[3]) == Symbol + @assert line.args[1] == :(=>) + @assert length(line.args[2].args) == length(other_names) + @assert all(tup -> typeof(tup) == Symbol || tup.args[1] == :~ || tup.head == :tuple, line.args[2].args) # every element of the vector is an expression of tuple. + @assert all(tup -> typeof(tup) == Symbol || tup.args[1] == :~ || all(sym -> typeof(sym) <: Union{Symbol, Expr}, tup.args), line.args[2].args) # ensure all arguments in the tuples are expressions or symbols. + # In the future, if we have additional flags, may need to check for them as well. + # These asserts are a bit sloppy + end + + current_symbol_dict::Vector{Vector{DSLArgument}} = interpret_stratification_notation_function(line) + + current_mapping_dict::Vector{Dict{Int, Int}} = ((x, y) -> substitute_symbols(x,type_names, y; use_flags=use_flags)).(other_names, current_symbol_dict) + + ((cumulative_dict, new_dict) -> mergewith!((cv, nv) -> cv, cumulative_dict, new_dict)).(other_mappings, current_mapping_dict) + +end + +""" +Print all symbols such that the corresponding int is 0, representing an unmapped object. +""" +function print_unmapped(SFNames, name="STOCKFLOW") + for (indices, names) ∈ zip(SFNames.get_mapped_vectors, SFNames.get_names) + for (i, val) ∈ enumerate(indices) + if val == 0 + println("UNMAPPED IN $(name):") + println(names[i]) + end + end + end +end + +""" +Iterates over each line in a stratification syntax block and updates the appropriate dictionaries. +""" +function iterate_over_stratification_lines!(block, other_names::Vector{SFNames}, type_names::SFNames; use_standard_stratification_syntax=true, strict_matches=false, use_flags=true) + + current_phase = (_, _) -> () + for statement in block.args + @match statement begin + QuoteNode(:stocks) => begin + current_phase = s -> read_stratification_line_and_update_dictionaries!(s, (getfield.(other_names, :s))::Vector{Dict{Symbol, Int}}, type_names.s, (getfield.(other_names, :ms))::Vector{Dict{Int, Int}}; use_standard_stratification_syntax=use_standard_stratification_syntax, strict_matches=strict_matches, use_flags=use_flags) + end + QuoteNode(:sums) => begin + current_phase = sv -> read_stratification_line_and_update_dictionaries!(sv, (getfield.(other_names, :sv))::Vector{Dict{Symbol, Int}}, type_names.sv, (getfield.(other_names, :msv))::Vector{Dict{Int, Int}}; use_standard_stratification_syntax=use_standard_stratification_syntax, strict_matches=strict_matches, use_flags=use_flags) + end + QuoteNode(:dynamic_variables) => begin + current_phase = v -> read_stratification_line_and_update_dictionaries!(v, (getfield.(other_names, :v))::Vector{Dict{Symbol, Int}}, type_names.v, (getfield.(other_names, :mv))::Vector{Dict{Int, Int}}; use_standard_stratification_syntax=use_standard_stratification_syntax, strict_matches=strict_matches, use_flags=use_flags) + end + QuoteNode(:flows) => begin + current_phase = f -> read_stratification_line_and_update_dictionaries!(f, (getfield.(other_names, :f))::Vector{Dict{Symbol, Int}}, type_names.f, (getfield.(other_names, :mf))::Vector{Dict{Int, Int}}; use_standard_stratification_syntax=use_standard_stratification_syntax, strict_matches=strict_matches, use_flags=use_flags) + end + QuoteNode(:parameters) => begin + current_phase = p -> read_stratification_line_and_update_dictionaries!(p, (getfield.(other_names, :p))::Vector{Dict{Symbol, Int}}, type_names.p, (getfield.(other_names, :mp))::Vector{Dict{Int, Int}}; use_standard_stratification_syntax=use_standard_stratification_syntax, strict_matches=strict_matches, use_flags=use_flags) + end + QuoteNode(kw) => + error("Unknown block type for stratify syntax: " * String(kw)) + _ => current_phase(statement) + end + end +end + +""" +Apply default mappings, infer mapping if there's only a single option, and convert from Dict{Int, Int} to Vector{Int} +""" +function complete_mappings!(sfm::SFNames, sftype::SFNames; strict_mappings = false) + # get the default value, if it has been assigned. Use 0 if it hasn't. + all_index_mappings = get_mappings(sfm) + + default_index_stock = get(all_index_mappings[1], -1, 0) + default_index_sum = get(all_index_mappings[2], -1, 0) + default_index_dyvar = get(all_index_mappings[3], -1, 0) + default_index_flow = get(all_index_mappings[4], -1, 0) + default_index_param = get(all_index_mappings[5], -1, 0) + + + + # STEP 3 + if !strict_mappings + one_type_stock = length(sftype.snames) == 1 ? 1 : 0 # if there is only one stock, it needs to have index 1 + one_type_flow = length(sftype.fnames) == 1 ? 1 : 0 + one_type_dyvar = length(sftype.vnames) == 1 ? 1 : 0 + one_type_param = length(sftype.pnames) == 1 ? 1 : 0 + one_type_sum = length(sftype.svnames) == 1 ? 1 : 0 + else + one_type_stock = one_type_flow = one_type_dyvar = one_type_param = one_type_sum = 0 + end + + # Convert back to vectors. If you find a zero, check if there's a default and use that. If there isn't a default, check if there's only one option and use that. + # Otherwise, there's an unassigned value which can't be inferred. + # Taking max because it's less verbose than ternary and accomplishes the same thing: + # - If both default_index and one_type are mapped, they must be mapped to the same thing, because one_type being mapped implies there's only one option. + # - If only one_type is mapped, then it will be positive, and default_infex will be 0 + # - If only default_index is mapped, it will be positive and one_type will be 0 + + stock_mappings::Vector{Int} = [get(all_index_mappings[1], i, max(default_index_stock, one_type_stock)) for i ∈ eachindex(sfm.snames)] + sum_mappings::Vector{Int} = [get(all_index_mappings[2], i, max(default_index_sum, one_type_sum)) for i ∈ eachindex(sfm.svnames)] + dyvar_mappings::Vector{Int} = [get(all_index_mappings[3], i, max(default_index_dyvar, one_type_dyvar)) for i ∈ eachindex(sfm.vnames)] + flow_mappings::Vector{Int} = [get(all_index_mappings[4], i, max(default_index_flow, one_type_flow)) for i ∈ eachindex(sfm.fnames)] + param_mappings::Vector{Int} = [get(all_index_mappings[5], i, max(default_index_param, one_type_param)) for i ∈ eachindex(sfm.pnames)] + + append!(sfm.mvs, stock_mappings) + append!(sfm.mvsv, sum_mappings) + append!(sfm.mvv, dyvar_mappings) + append!(sfm.mvf, flow_mappings) + append!(sfm.mvp, param_mappings) + +end + + +""" + sfstratify(strata, type, aggregate, block ; kwargs) + + 1. Grab all names from strata, type and aggregate, and create dictionaries which map them to their indices + 2. Iterate over each line in the block + 2a. Split each line into a dictionary which maps all strata to that type and all aggregate to that type + 2b. Convert from two Symbol => Symbol dictionaries to two Int => Int dictionaries, using the dictionaries from step 1 + 2bα. If applicable, for symbols with ~ as a prefix, find all symbols with matching substrings in the symbol dictionaries, and map all those + 2c. Accumulate respective dictionaries (optionally, only allow first match vs throw an error (strict_matches = false vs true)) + 3. Create an array of 0s for stocks, flows, parameters, dyvars and sums for strata and aggregate. Insert into arrays all values from the two Int => Int dictionaries + 3a. If strict_mappings = false, if there only exists one option in type to map to, and it hasn't been explicitly specified, add it. If strict_mappings = true and it hasn't been specified, throw an error. + 4. Do a once-over of arrays and ensure there aren't any zeroes (unmapped values) remaining (helps with debugging when you screw up stratifying) + 5. Deal with attributes (create a copy of type sf with attributes mapped to nothing) + 6. Infer LS, LSV, etc. + 7. Construct strata -> type and aggregate -> type ACSetTransformations + 8. Return pullback (with flattened attributes) +""" +function sfstratify(others::Vector{K}, type::K, block::Expr ; use_standard_stratification_syntax = true, strict_mappings = false, strict_matches = false, temp_strat_default = :_, use_temp_strat_default = true, use_flags = true, return_homs = false) where {K <: AbstractStockAndFlowStructureF} + + Base.remove_linenums!(block) + + # STEP 1 + + + other_names::Vector{SFNames} = [SFNames(sf) for sf ∈ others] + type_names::SFNames = SFNames(type) # has some unnecessary fields. + + if use_temp_strat_default + # Applies function to every element in vector. + @assert all((sfn -> no_temp_strat_default_in_names(sfn, temp_strat_default)).(other_names)) && no_temp_strat_default_in_names(type_names, temp_strat_default) "A stockflow contains $(temp_strat_default) ! Please change temp_strat_default to a different symbol or rename offending object." + (sfn -> add_temp_strat_default!(sfn, temp_strat_default)).(other_names) + end + + + # STEP 2 + iterate_over_stratification_lines!(block, other_names, type_names ; use_standard_stratification_syntax=use_standard_stratification_syntax, strict_matches=strict_matches, use_flags=use_flags) + + + (sfn -> complete_mappings!(sfn, type_names ; strict_mappings=strict_mappings)).(other_names) + + # STEP 4 + + # This bit makes debugging when making a stratification easier. Tells you exactly which ones you forgot to map. + + #unmapped: + if !(all(is_all_mapped.(other_names))) + for i ∈ eachindex(other_names) + print_unmapped(other_names[i], "STOCKFLOW $i") + end + error("There is an unmapped value!") + end + + + # STEP 5 + # NothingFunction(x...) = nothing; + no_attribute_type = map(type, Name=NothingFunction, Op=NothingFunction, Position=NothingFunction) + + # STEP 6/7 + # This is where we pull out the magic to infer links. + # + # A <- C -> B + # || || + # v v + # A'<- C'-> B' + # + # implies + # + # A <- C -> B + # || || || + # v v v + # A'<- C'-> B' + # + + generate_all_mappings_function = m -> Dict(infer_links(m.sf, type, get_mappings_infer_links_format(m))..., get_mappings_infer_links_format(m)..., :Op => NothingFunction, :Position => NothingFunction, :Name => NothingFunction) + all_mappings = generate_all_mappings_function.(other_names) + + all_transformations = [ACSetTransformation(sfn.sf, no_attribute_type ; mappings...) for (sfn, mappings) ∈ zip(other_names, all_mappings)] + + # STEP 8 + + pullback_model = pullback(all_transformations) |> apex |> rebuildStratifiedModelByFlattenSymbols; + + if return_homs + return pullback_model, all_transformations + else + return pullback_model + end + +end + +""" + stratify(strata, type, aggregate, block) +Take three stockflows and a block describing where the first and third map on to the second, and get a new stratified stockflow. +Left side are strata objects, middle are type, right are aggregate. Each strata and aggregate object is associated with one type object. +The resultant stockflow contains objects which are the product of strata and aggregate objects which map to the same type object. +Use _ to match all objects in that category, ~ as a prefix to match all objects with the following string as a substring. Objects always go with their first match. +If the type model has a single object in a category, the mapping to it is automatically assumed. In the below example, we wouldn't need to specify :stocks or :sums. + +```julia + +@stratify WeightModel l_type ageWeightModel begin + :stocks + _ => pop <= _ + + :flows + ~Death => f_death <= ~Death + ~id => f_aging <= ~aging + ~Becoming => f_fstOrder <= ~id + _ => f_birth <= f_NB + + + :dynamic_variables + v_NewBorn => v_birth <= v_NB + ~Death => v_death <= ~Death + ~id => v_aging <= v_agingCA, v_agingAS + v_BecomingOverWeight, v_BecomingObese => v_fstOrder <= v_idC, v_idA, v_idS + + :parameters + μ => μ <= μ + δw, δo => δ <= δC, δA, δS + rw, ro => rFstOrder <= r + rage => rage <= rageCA, rageAS + + :sums + N => N <= N + +end +``` +""" +macro stratify(strata, type, aggregate, block) + escaped_block = Expr(:quote, block) + quote + sfstratify([$(esc(strata)), $(esc(aggregate))], $(esc(type)), $(esc(escaped_block))) + end +end + + +""" +Alternate syntax for stratification, allows for an arbitrary number of stockflows in a pullback. +Second last argument must be the type stockflow, last must be the block describing how the stratificaition is done. All arguments before that must be stockflows. + +```julia + +@n_stratify WeightModel ageWeightModel l_type begin + :stocks + [_, _] => pop + + :flows + [~Death, ~Death] => f_death + [~id, ~aging] => f_aging + [~Becoming, ~id] => f_fstOrder + [_, f_NB] => f_birth + + + :dynamic_variables + [v_NewBorn, v_NB] => v_birth + [~Death, ~Death] => v_death + [~id, (v_agingCA, v_agingAS)] => v_aging + [(v_BecomingOverWeight, v_BecomingObese), (v_idC, v_idA, v_idS)] => v_fstOrder + + :parameters + [μ, μ] => μ + [(δw, δo), (δC, δA, δS)] => δ + [(rw, ro), r] => rFstOrder + [rage, (rageCA, rageAS)] => rage + + :sums + [N,N] => N +end + +``` + +""" +macro n_stratify(args...) + if length(args) < 3 + return :(MethodError("Too few arguments provided! Please provide some number of stockflows, then the type stock flow, then a quote block.")) + else + escaped_block = Expr(:quote, args[end]) + other_sfs = esc.(args[1:end-2]) + type = (esc(args[end-1])) + quote + sfstratify([$(other_sfs...)], $type, $escaped_block ; use_standard_stratification_syntax = false) + end + end +end + + +end \ No newline at end of file diff --git a/test/Syntax.jl b/test/Syntax.jl index a46733e7..d0fb2d5b 100755 --- a/test/Syntax.jl +++ b/test/Syntax.jl @@ -2,7 +2,11 @@ using Base: is_unary_and_binary_operator using Test using StockFlow using StockFlow.Syntax -using StockFlow.Syntax: is_binop_or_unary, sum_variables, infix_expression_to_binops, fnone_value_or_vector, extract_function_name_and_args_expr, is_recursive_dyvar, create_foot +using StockFlow.Syntax: is_binop_or_unary, sum_variables, infix_expression_to_binops, fnone_value_or_vector, extract_function_name_and_args_expr, is_recursive_dyvar, create_foot, apply_flags, substitute_symbols + +@testset "Stratification DSL" begin + include("syntax/Stratification.jl") +end @testset "Composition DSL" begin include("syntax/Composition.jl") @@ -347,3 +351,211 @@ end @test_throws Exception @eval @feet begin A => B; =>(D,E,F) end @test_throws Exception @eval @feet begin A => B; 1 => 2; end end + +########################### + +@testset "infer_links works as expected" begin + # No prior mappings means no inferred mappings + @test (infer_links(StockAndFlowF(), StockAndFlowF(), Dict{Symbol, Vector{Int64}}(:S => [], :F => [], :SV => [], :P => [], :V => [])) + == Dict(:LS => [], :LSV => [], :LV => [], :I => [], :O => [], :LPV => [], :LVV => [])) + + # S: 1 -> 1 and SV: 1 -> 1 implies LS: 1 -> 1 + @test (infer_links( + (@stock_and_flow begin; :stocks; A; :sums; NA = [A]; end), + (@stock_and_flow begin; :stocks; B; :sums; NB = [B]; end), + Dict{Symbol, Vector{Int64}}(:S => [1], :F => [], :SV => [1], :P => [], :V => [])) + == Dict(:LS => [1], :LSV => [], :LV => [], :I => [], :O => [], :LPV => [], :LVV => [])) + + # annoying exanmple, required me to add code to disambiguate using position + # that is, vA = A + A, vA -> vB, A -> implies that the As in the vA definition map to the Bs in the vB definition + # But both As link to the same stock and dynamic variable so just looking at those isn't enough to figure out what it maps to. + # There will exist cases where it's impossible to tell - eg, when there exist multiple duplicate links, and some positions don't match up. + + # It does not currently look at the operator. You could therefore map vA = A + A -> vB = B * B + # I can see this being useful, actually, specifically when mapping between + and -, * and /, etc. Probably logs and powers too. + # Just need to be aware that it won't say it's invalid. + @test (infer_links( + (@stock_and_flow begin; :stocks; A; :dynamic_variables; vA = A + A; end), + (@stock_and_flow begin; :stocks; B; :dynamic_variables; vB = B + B; end), + Dict{Symbol, Vector{Int64}}(:S => [1], :F => [], :SV => [], :P => [], :V => [1])) + == Dict(:LS => [], :LSV => [], :LV => [2,2], :I => [], :O => [], :LPV => [], :LVV => [])) # If duplicate values, always map to end. + + @test (infer_links( + (@stock_and_flow begin; :stocks; A; :parameters; pA; :dynamic_variables; vA = A + pA; end), + (@stock_and_flow begin; :stocks; B; :parameters; pB; :dynamic_variables; vB = pB + B; end), + Dict{Symbol, Vector{Int64}}(:S => [1], :F => [], :SV => [], :P => [1], :V => [1])) + == Dict(:LS => [], :LSV => [], :LV => [1], :I => [], :O => [], :LPV => [1], :LVV => [])) + + @test (infer_links( + (@stock_and_flow begin + :stocks + S + I + R + + :parameters + p_inf + p_rec + + + :flows + S => f_StoI(p_inf * S) => I + I => f_ItoR(I * p_rec) => R + + :sums + N = [S,I,R] + NI = [I] + NS = [S,I,R] + end), + (@stock_and_flow begin + :stocks + pop + + :parameters + p_generic + + + :flows + pop => f_generic(p_generic * pop) => pop + + :sums + N = [pop] + NI = [pop] + NS = [pop] + end), + + Dict{Symbol, Vector{Int64}}(:S => [1,1,1], :F => [1,1], :SV => [1,2,3], :P => [1,1], :V => [1,1])) + == Dict(:LS => [1,3,1,2,3,1,3], :LSV => [], :LV => [1,1], :I => [1,1], :O => [1,1], :LPV => [1,1], :LVV => [])) + + +end + + +@testset "Applying flags can correctly find substring matches" begin + @test apply_flags(:f_, Set([:~]), Vector{Symbol}()) == [] + @test apply_flags(:f_, Set([:~]), [:f_death, :f_birth]) == [:f_death, :f_birth] + @test apply_flags(:NOMATCH, Set([:~]), [:f_death, :f_birth]) == [] + @test apply_flags(:f_birth, Set([:~]), [:f_death, :f_birth]) == [:f_birth] + @test apply_flags(:f_birth, Set{Symbol}(), [:f_death, :f_birth]) == [:f_birth] + + # Note, apply_flags is specifically meant to work on vectors without duplicates; the vector which is input are the keys of a dictionary. + # Regardless, the following will hold: + @test apply_flags(:f_birth, Set{Symbol}(), [:f_death, :f_birth, :f_birth, :f_birth]) == [:f_birth] + @test apply_flags(:f_birth, Set{Symbol}([:~]), [:f_death, :f_birth, :f_birth, :f_birth]) == [:f_birth, :f_birth, :f_birth] +end + + +@testset "substitute_symbols will correctly associate values of the two provided dictionaries based on user defined mappings" begin + # substitute_symbols(s::Dict{Symbol, Int}, t::Dict{Symbol, Int}, m::Vector{DSLArgument} ; use_flags::Bool=true)::Dict{Int, Int} + + + # Note, these dictionaries represent a vector where all the entries are unique, and the values are the original indices. + # So, both keys and values should be unique. + # For stratification, first dictionary is strata or aggregate, second is type, and the vector of DSLArgument are the user-defined maps. + # For homomorphism, first argument is src, second is dest, vector are user-defined maps. + @test substitute_symbols(Dict{Symbol, Int}(), Dict{Symbol, Int}(), Vector{DSLArgument}()) == Dict{Int, Int}() + @test substitute_symbols(Dict{Symbol, Int}(), Dict(:B => 2), Vector{DSLArgument}()) == Dict{Int, Int}() + + @test substitute_symbols(Dict(:A => 1), Dict(:B => 1), [DSLArgument(:A, :B, Set{Symbol}())]) == Dict(1 => 1) + @test substitute_symbols(Dict(:A1 => 1, :A2 => 2), Dict(:B => 1), [DSLArgument(:A1, :B, Set{Symbol}()), DSLArgument(:A2, :B, Set{Symbol}())]) == Dict(1 => 1, 2 => 1) + @test substitute_symbols(Dict(:A1 => 1), Dict(:B1 => 1, :B2 => 2), [DSLArgument(:A1, :B2, Set{Symbol}())]) == Dict(1 => 2) + + + @test substitute_symbols(Dict(:A1 => 1, :A2 => 2), Dict(:B1 => 1, :B2 => 2), [DSLArgument(:A, :B2, Set{Symbol}([:~]))]) == Dict(1 => 2, 2 => 2) + + # 1:100 + # 1:50 + # All multiples x of 14 below 100 go to x % 10 + 1 + @test (substitute_symbols(Dict(Symbol(i) => i for i ∈ 1:100), Dict(Symbol(-i) => i for i ∈ 1:50), [DSLArgument(Symbol(i), Symbol(-((i%10) + 1)), Set{Symbol}()) for i ∈ 1:100 if i % 14 == 0]) + == Dict(14 => 5, 28 => 9, 42 => 3, 56 => 7, 70 => 1, 84 => 5, 98 => 9)) + + # Captures everything with a 7 as a digit + @test (substitute_symbols(Dict(Symbol(i) => i for i ∈ 1:100), Dict(Symbol(-i) => i for i ∈ 1:50), [DSLArgument(Symbol(7), Symbol(-1), Set{Symbol}([:~]))]) + == Dict(7 => 1, 17 => 1, 27 => 1, 37 => 1, 47 => 1, 57 => 1, 67 => 1, 70 => 1, 71 => 1, 72 => 1, 73 => 1, 74 => 1, 75 => 1, 76 => 1, 77 => 1, 78 => 1, 79 => 1, 87 => 1, 97 => 1)) + + @test substitute_symbols(Dict(Symbol("~") => 1), Dict(:R => 1), [DSLArgument(Symbol("~"), :R, Set([:~]))], ; use_flags = false) == Dict(1 => 1) # Note, the Set([:~]) is ignored because use_flags is false + +end + + +@testset "non-natural transformations fail infer_links" begin + + # Map both dynamic variables to the same + # Obviously, this will fail, as the new dynamic variable needs a LVV and one LV, but instead has two LV + @test_throws KeyError (infer_links( + (@stock_and_flow begin + :stocks + A + + :dynamic_variables + v1 = A + A + v2 = v1 + A + end), + (@stock_and_flow begin + :stocks + A + + :dynamic_variables + v1 = A + A + end), + Dict{Symbol, Vector{Int64}}(:S => [1], :V => [1,1]))) + + + + # Mapping it all to I + + # This one fails when trying to figure out the inflow. Stock maps to 2, and flow maps to 2, + # But inflows on the target have (1,2) and (2,3) + + # This also wouldn't work if we tried mapping flow to 1 instead. Outflows expect 1,1 or 2,2, + # so it fails on (2,1). + @test_throws KeyError (infer_links( + (@stock_and_flow begin + :stocks + pop + + :parameters + p_generic + + + :flows + pop => f_generic(p_generic * pop) => pop + + :sums + N = [pop] + NI = [pop] + NS = [pop] + end), + (@stock_and_flow begin + :stocks + S + I + R + + :parameters + p_inf + p_rec + + + :flows + S => f_StoI(p_inf * S) => I + I => f_ItoR(I * p_rec) => R + + :sums + N = [S,I,R] + NI = [I] + NS = [S,I,R] + end), + Dict{Symbol, Vector{Int64}}(:S => [2], :F => [2], :SV => [1,2,3], :P => [2], :V => [2]))) + +end + + +@testset "Applying flags throws on invalid inputs" begin + @test_throws ErrorException apply_flags(:f_, Set([:+]), [:f_death, :f_birth]) # fails because :+ is not a defined operation + @test_throws ErrorException apply_flags(:f_birth, Set([:~, :+]), [:f_death, :f_birth]) # also fails for same reason + + @test_throws AssertionError apply_flags(:NOMATCH, Set{Symbol}(), Vector{Symbol}()) # fails because it's not looking for substrings, and :NOMATCH isn't in the list of options. + @test_throws AssertionError apply_flags(:NOMATCH, Set{Symbol}(), [:nomatch]) # same reason + +end \ No newline at end of file diff --git a/test/syntax/Stratification.jl b/test/syntax/Stratification.jl new file mode 100755 index 00000000..bfdea7d2 --- /dev/null +++ b/test/syntax/Stratification.jl @@ -0,0 +1,579 @@ +using StockFlow.Syntax.Stratification + +using StockFlow.Syntax.Stratification: interpret_stratification_standard_notation +using StockFlow.Syntax: NothingFunction, DSLArgument, unwrap_expression, substitute_symbols + +using Catlab.WiringDiagrams +using Catlab.ACSets +using Catlab.CategoricalAlgebra + + + + +@testset "Pullback computed in standard way is equal to DSL pullbacks" begin + + + l_type = @stock_and_flow begin + :stocks + pop + + :parameters + μ + δ + rFstOrder + rage + + :dynamic_variables + v_aging = pop * rage + v_fstOrder = pop * rFstOrder + v_birth = N * μ + v_death = pop * δ + + :flows + pop => f_aging(v_aging) => pop + pop => f_fstOrder(v_fstOrder) => pop + CLOUD => f_birth(v_birth) => pop + pop => f_death(v_death) => CLOUD + + :sums + N = [pop] + + end; + l_type_noatts = map(l_type, Name=NothingFunction, Op=NothingFunction, Position=NothingFunction); + + + WeightModel = @stock_and_flow begin + :stocks + NormalWeight + OverWeight + Obese + + :parameters + μ + δw + rw + ro + δo + rage + + :dynamic_variables + v_NewBorn = N * μ + v_DeathNormalWeight = NormalWeight * δw + v_BecomingOverWeight = NormalWeight * rw + v_DeathOverWeight = OverWeight * δw + v_BecomingObese = OverWeight * ro + v_DeathObese = Obese * δo + v_idNW = NormalWeight * rage + v_idOW = OverWeight * rage + v_idOb = Obese * rage + + :flows + CLOUD => f_NewBorn(v_NewBorn) => NormalWeight + NormalWeight => f_DeathNormalWeight(v_DeathNormalWeight) => CLOUD + NormalWeight => f_BecomingOverWeight(v_BecomingOverWeight) => OverWeight + OverWeight => f_DeathOverWeight(v_DeathOverWeight) => CLOUD + + OverWeight => f_BecomingObese(v_BecomingObese) => Obese + Obese => f_DeathObese(v_DeathObese) => CLOUD + NormalWeight => f_idNW(v_idNW) => NormalWeight + OverWeight => f_idOW(v_idOW) => OverWeight + Obese => f_idOb(v_idOb) => Obese + + :sums + N = [NormalWeight, OverWeight, Obese] + + end; + + + ageWeightModel = @stock_and_flow begin + :stocks + Child + Adult + Senior + + :parameters + μ + δC + δA + δS + rageCA + rageAS + r + + :dynamic_variables + v_NB = N * μ + v_DeathC = Child * δC + v_idC = Child * r + v_agingCA = Child * rageCA + v_DeathA = Adult * δA + v_idA = Adult * r + v_agingAS = Adult * rageAS + v_DeathS = Senior * δS + v_idS = Senior * r + + :flows + CLOUD => f_NB(v_NB) => Child + Child => f_idC(v_idC) => Child + Child => f_DeathC(v_DeathC) => CLOUD + Child => f_agingCA(v_agingCA) => Adult + Adult => f_idA(v_idA) => Adult + Adult => f_DeathA(v_DeathA) => CLOUD + Adult => f_agingAS(v_agingAS) => Senior + Senior => f_idS(v_idS) => Senior + Senior => f_DeathS(v_DeathS) => CLOUD + + :sums + N = [Child, Adult, Senior] + + end; + + begin + s, = parts(l_type, :S) + N, = parts(l_type, :SV) + lsn, = parts(l_type, :LS) + f_aging, f_fstorder, f_birth, f_death = parts(l_type, :F) + i_aging, i_fstorder, i_birth = parts(l_type, :I) + o_aging, o_fstorder, o_death = parts(l_type, :O) + v_aging, v_fstorder, v_birth, v_death = parts(l_type, :V) + lv_aging1, lv_fstorder1, lv_death1 = parts(l_type, :LV) + lsv_birth1, = parts(l_type, :LSV) + p_μ, p_δ, p_rfstOrder, p_rage = parts(l_type, :P) + lpv_aging2, lpv_fstorder2, lpv_birth2, lpv_death2 = parts(l_type, :LPV) + end; + + typed_WeightModel=ACSetTransformation(WeightModel, l_type_noatts, + S = [s,s,s], + SV = [N], + LS = [lsn,lsn,lsn], + F = [f_birth, f_death, f_fstorder, f_death, f_fstorder, f_death, f_aging, f_aging, f_aging], + I = [i_birth, i_aging, i_fstorder, i_aging, i_fstorder, i_aging], + O = [o_death, o_fstorder, o_aging, o_death, o_fstorder, o_aging, o_death, o_aging], + V = [v_birth, v_death, v_fstorder, v_death, v_fstorder, v_death, v_aging, v_aging, v_aging], + LV = [lv_death1, lv_fstorder1, lv_death1, lv_fstorder1, lv_death1, lv_aging1, lv_aging1, lv_aging1], + LSV = [lsv_birth1], + P = [p_μ, p_δ, p_rfstOrder, p_rfstOrder, p_δ, p_rage], + LPV = [lpv_birth2, lpv_death2, lpv_fstorder2, lpv_death2, lpv_fstorder2, lpv_death2, lpv_aging2, lpv_aging2, lpv_aging2], + Name=NothingFunction, Op=NothingFunction, Position=NothingFunction + ); + @assert is_natural(typed_WeightModel); + + + + typed_ageWeightModel=ACSetTransformation(ageWeightModel, l_type_noatts, + S = [s,s,s], + SV = [N], + LS = [lsn,lsn,lsn], + F = [f_birth, f_fstorder, f_death, f_aging, f_fstorder, f_death, f_aging, f_fstorder, f_death], + I = [i_birth, i_fstorder, i_aging, i_fstorder, i_aging, i_fstorder], + O = [o_fstorder, o_death, o_aging, o_fstorder, o_death, o_aging, o_fstorder, o_death], + V = [v_birth, v_death, v_fstorder, v_aging, v_death, v_fstorder, v_aging, v_death, v_fstorder], + LV = [lv_death1, lv_fstorder1, lv_aging1, lv_death1, lv_fstorder1, lv_aging1, lv_death1, lv_fstorder1], + LSV = [lsv_birth1], + P = [p_μ, p_δ, p_δ, p_δ, p_rage, p_rage, p_rfstOrder], + LPV = [lpv_birth2, lpv_death2, lpv_fstorder2, lpv_aging2, lpv_death2, lpv_fstorder2, lpv_aging2, lpv_death2, lpv_fstorder2], + Name =NothingFunction, Op=NothingFunction, Position=NothingFunction + ); + @assert is_natural(typed_ageWeightModel); + + aged_weight = pullback(typed_WeightModel, typed_ageWeightModel) |> apex |> rebuildStratifiedModelByFlattenSymbols; + + # ######################################### + + age_weight_2 = @stratify WeightModel l_type ageWeightModel begin + :stocks + NormalWeight, OverWeight, Obese => pop <= Child, Adult, Senior + + :flows + f_NewBorn => f_birth <= f_NB + f_DeathNormalWeight, f_DeathOverWeight, f_DeathObese => f_death <= f_DeathC, f_DeathA, f_DeathS + f_idNW, f_idOW, f_idOb => f_aging <= f_agingCA, f_agingAS + f_BecomingOverWeight, f_BecomingObese => f_fstOrder <= f_idC, f_idA, f_idS + + :dynamic_variables + v_NewBorn => v_birth <= v_NB + v_DeathNormalWeight, v_DeathOverWeight, v_DeathObese => v_death <= v_DeathC, v_DeathA, v_DeathS + v_idNW, v_idOW, v_idOb => v_aging <= v_agingCA, v_agingAS + v_BecomingOverWeight, v_BecomingObese => v_fstOrder <= v_idC, v_idA, v_idS + + :parameters + μ => μ <= μ + δw, δo => δ <= δC, δA, δS + rw, ro => rFstOrder <= r + rage => rage <= rageCA, rageAS + + :sums + N => N <= N + + end + ######################################### + + age_weight_3 = @stratify WeightModel l_type ageWeightModel begin + + :flows + f_NewBorn => f_birth <= f_NB + ~Death => f_death <= ~Death + ~id => f_aging <= ~aging + ~Becoming => f_fstOrder <= ~id + + :dynamic_variables + v_NewBorn => v_birth <= v_NB + ~Death => v_death <= ~Death + ~id => v_aging <= ~aging + ~Becoming => v_fstOrder <= ~id + + :parameters + μ => μ <= μ + ~δ => δ <= ~δ + rage => rage <= rageCA, rageAS + _ => rFstOrder <= _ + + end + + age_weight_4 = @stratify WeightModel l_type ageWeightModel begin + + :flows + ~NO_MATCHES => f_birth <= ~NO_MATCHES + f_NewBorn => f_birth <= f_NB + ~Death => f_death <= ~Death + ~id => f_aging <= ~aging + ~Becoming => f_fstOrder <= ~id + ~Becoming => f_aging <= ~id # Everything already matched; ignored + _ => f_aging <= _ # also ignored + + :dynamic_variables + v_NewBorn => v_birth <= v_NB + ~Death => v_death <= ~Death + ~id => v_aging <= ~aging + _ => v_fstOrder <= _ + + :parameters + μ => μ <= μ + ~δ => δ <= ~δ + rage => rage <= rageCA, rageAS + _ => rFstOrder <= _ + + end + + age_weight_5 = @n_stratify WeightModel ageWeightModel l_type begin + :stocks + [_, _] => pop + + :flows + [~Death, ~Death] => f_death + [~id, ~aging] => f_aging + [~Becoming, ~id] => f_fstOrder + [_, f_NB] => f_birth + + + :dynamic_variables + [v_NewBorn, v_NB] => v_birth + [~Death, ~Death] => v_death + [~id, (v_agingCA, v_agingAS)] => v_aging + [(v_BecomingOverWeight, v_BecomingObese), (v_idC, v_idA, v_idS)] => v_fstOrder + + :parameters + [μ, μ] => μ + [(δw, δo), (δC, δA, δS)] => δ + [(rw, ro), r] => rFstOrder + [rage, (rageCA, rageAS)] => rage + + :sums + [N,N] => N + end + + age_weight_6 = @n_stratify WeightModel ageWeightModel l_type begin + + :flows + [~Death, ~Death] => f_death + [~id, ~aging] => f_aging + [~Becoming, ~id] => f_fstOrder + [_, f_NB] => f_birth + + + :dynamic_variables + [v_NewBorn, v_NB] => v_birth + [~Death, ~Death] => v_death + [~id, (v_agingCA, v_agingAS)] => v_aging + [(v_BecomingOverWeight, v_BecomingObese), (v_idC, v_idA, v_idS)] => v_fstOrder + + :parameters + [μ, μ] => μ + [(δw, δo), (δC, δA, δS)] => δ + [(rw, ro), r] => rFstOrder + [rage, (rageCA, rageAS)] => rage + + end + + + + @test aged_weight == age_weight_2 + @test aged_weight == age_weight_3 + @test aged_weight == age_weight_4 + @test aged_weight == age_weight_5 + @test aged_weight == age_weight_6 + +end + +@testset "Ensuring interpret_stratification_standard_notation correctly reads lines" begin # This should be all valid cases. There's always going to be at least one value on both sides. + # Note the orders. The lists produced go left to right. A1, A2 => B <= C1, C2 results in [A1 => B, A2 => B], [C1 => B. C2 => B] + + + @test interpret_stratification_standard_notation(:(A => B <= C)) == [[DSLArgument(:A, :B, Set{Symbol}())], [DSLArgument(:C, :B, Set{Symbol}())]] + + @test interpret_stratification_standard_notation(:(A1, A2 => B <= C)) == [ + [DSLArgument(:A1, :B, Set{Symbol}()), DSLArgument(:A2, :B, Set{Symbol}())], + [DSLArgument(:C, :B, Set{Symbol}())] + ] + @test interpret_stratification_standard_notation(:(A => B <= C1, C2)) == [ + [DSLArgument(:A, :B, Set{Symbol}())], + [DSLArgument(:C1, :B, Set{Symbol}()), DSLArgument(:C2, :B, Set{Symbol}())], + ] + @test interpret_stratification_standard_notation(:(_ => B <= _)) == [ + [DSLArgument(:_, :B, Set{Symbol}())], + [DSLArgument(:_, :B, Set{Symbol}())], + ] + @test interpret_stratification_standard_notation(:(~A => B <= ~C)) == [ + [DSLArgument(:A, :B, Set{Symbol}([:~]))], + [DSLArgument(:C, :B, Set{Symbol}([:~]))], + ] + @test interpret_stratification_standard_notation(:(~A1, A2 => B <= ~C)) == [ + [DSLArgument(:A1, :B, Set{Symbol}([:~])), DSLArgument(:A2, :B, Set{Symbol}())], + [DSLArgument(:C, :B, Set{Symbol}([:~]))], + ] + + @test interpret_stratification_standard_notation(:(~_ => B <= ~_, C)) == [ # Weird case. Matches everything with _ as a substring. + [DSLArgument(:_, :B, Set{Symbol}([:~]))], + [DSLArgument(:_, :B, Set{Symbol}([:~])), DSLArgument(:C, :B, Set{Symbol}())] + ] + +end + + +@testset "Unwrapping expressions works correctly" begin + @test unwrap_expression(:S) == (:S, Set{Symbol}()) + @test unwrap_expression(:(~S)) == (:S, Set{Symbol}([:~])) + @test unwrap_expression(:(~_)) == (:_, Set{Symbol}([:~])) +end + + +# function substitute_symbols(s::Dict{Symbol, Int}, t::Dict{Symbol, Int}, m::Vector{DSLArgument} ; use_flags::Bool=true)::Dict{Int, Int} + +@testset "Testing substituting symbols" begin # underscore matching occurs at the very end, after this step. + s1 = Dict(:A => 1) + t1 = Dict(:B => 2) + m1₁ = [DSLArgument(:A, :B, Set{Symbol}())] + m1₂ = [DSLArgument(:A, :B, Set{Symbol}([:~]))] + + @test substitute_symbols(s1, t1, m1₁) == Dict(1 => 2) # A=>B -> 1=>2 + @test substitute_symbols(s1, t1, m1₂) == Dict(1 => 2) # A=>B -> 1=>2 + @test substitute_symbols(s1, t1, m1₂, use_flags=false) == Dict(1 => 2) # A=>B -> 1=>2 + + s2 = Dict(:A1 => 10, :A2 => 20, :A3 => 30) # Unfortunately, cannot do substring matches starting with numbers, since it would require a symbol starting with a number. Might need to add something for this... + t2 = Dict(:B1 => 1, :B2 => 2) + m2₁ = [DSLArgument(:A, :B1, Set{Symbol}([:~]))] + + @test substitute_symbols(s2, t2, m2₁) == Dict(10 => 1, 20 => 1, 30 => 1) #~A=>B -> 10=>1, 20=>1, 30=>1 + + s3 = Dict{Symbol, Int}() + t3 = Dict{Symbol, Int}() + m3 = Vector{DSLArgument}() + + @test substitute_symbols(s3, t3, m3) == Dict() + @test substitute_symbols(s3, t3, m3, use_flags=false) == Dict() + + s4 = Dict(:A1 => 1, :A2 => 2, :AB3 => 3, :AB4 => 4, :A5 => 5) + t4 = Dict(:B1 => 1, :B2 => 2, :B3 => 3) + m4 = [DSLArgument(:A1, :B1, Set{Symbol}()), DSLArgument(:B, :B2, Set{Symbol}([:~])), DSLArgument(:A, :B3, Set{Symbol}([:~]))] + + # always goes with first match. A1 is taken, B matches AB3 and AB4, then A matches A2 and A5 + @test substitute_symbols(s4, t4, m4) == Dict(1 => 1, 3 => 2, 4 => 2, 2 => 3, 5 => 3) +end + + +@testset "nondefault flags work as expected" begin + A_ = (@stock_and_flow begin + :stocks + A + _ + end) + + X_ = (@stock_and_flow begin + :stocks + X + _ + end) + + B_ = (@stock_and_flow begin + :stocks + B + _ + end) + + strat_AXB = (quote # Note, we use a quote when calling the function, begin when calling the macro. + :stocks + _ => _ <= _ + A => X <= B + ~A => X <= ~B # everything is already assigned, so does nothing (or throws error if strict_matches is true) + end) + + + sfA = (@stock_and_flow begin; :stocks; A; end;) + + @test (sfstratify([A_, B_], X_, strat_AXB, use_temp_strat_default=false) + == (@stock_and_flow begin + :stocks + AB + __ + end)) + + # doesn't show up anywhere, so doesn't affect anything. Could also set it to something untypable in the DSL, like Symbol("") + @test (sfstratify([A_, B_], X_, strat_AXB, temp_strat_default=:ABABABABA) + == (@stock_and_flow begin + :stocks + AB + __ + end)) + + @test_throws AssertionError (sfstratify([A_, B_], X_, strat_AXB, strict_matches=true)) # A matches against A and ~A, which is disallowed with this flag. + + @test_throws ErrorException (sfstratify([sfA,sfA],sfA,(quote end) ; strict_mappings=true)) # strict_mappings=false wouldn't throw an error, and would infer strata and aggregate need to map to the only stock. + + + nothing_sfA = map(sfA, Position=NothingFunction, Op=NothingFunction, Name=NothingFunction) + + @test (sfstratify([sfA,sfA],sfA,(quote end), return_homs=true) == ( + (@stock_and_flow begin + :stocks + AA + end), + [ACSetTransformation(sfA, nothing_sfA ; S=[1], F=Vector{Int}(),V =Vector{Int}(),SV=Vector{Int}(),P=Vector{Int}(),LS=Vector{Int}(),I=Vector{Int}(),O=Vector{Int}(),LV=Vector{Int}(),LSV=Vector{Int}(),LVV=Vector{Int}(),LPV=Vector{Int}(), Position=NothingFunction, Op=NothingFunction, Name=NothingFunction), # strata -> type + ACSetTransformation(sfA, nothing_sfA ; S=[1], F=Vector{Int}(),V =Vector{Int}(),SV=Vector{Int}(),P=Vector{Int}(),LS=Vector{Int}(),I=Vector{Int}(),O=Vector{Int}(),LV=Vector{Int}(),LSV=Vector{Int}(),LVV=Vector{Int}(),LPV=Vector{Int}(), Position=NothingFunction, Op=NothingFunction, Name=NothingFunction)] # aggregate -> type + )) # the empty lists are necessary for equality, but it'd still be an equivalent homomorphism if you didn't specify them. + +end + + + + +@testset "n_stratify works as expected" begin + + + l_type = @stock_and_flow begin + :stocks + pop + + :parameters + μ + δ + rFstOrder + rage + + :dynamic_variables + v_aging = pop * rage + v_fstOrder = pop * rFstOrder + v_birth = N * μ + v_death = pop * δ + + :flows + pop => f_aging(v_aging) => pop + pop => f_fstOrder(v_fstOrder) => pop + CLOUD => f_birth(v_birth) => pop + pop => f_death(v_death) => CLOUD + + :sums + N = [pop] + + end + + chain_ltype = @stock_and_flow begin + :stocks + poppoppop + + :parameters + μμμ + δδδ + rFstOrderrFstOrderrFstOrder + rageragerage + + :dynamic_variables + v_agingv_agingv_aging = poppoppop * rageragerage + v_fstOrderv_fstOrderv_fstOrder = poppoppop * rFstOrderrFstOrderrFstOrder + v_birthv_birthv_birth = NNN * μμμ + v_deathv_deathv_death = poppoppop * δδδ + + :flows + poppoppop => f_agingf_agingf_aging(v_agingv_agingv_aging) => poppoppop + poppoppop => f_fstOrderf_fstOrderf_fstOrder(v_fstOrderv_fstOrderv_fstOrder) => poppoppop + CLOUD => f_birthf_birthf_birth(v_birthv_birthv_birth) => poppoppop + poppoppop => f_deathf_deathf_death(v_deathv_deathv_death) => CLOUD + + :sums + NNN = [poppoppop] + end + + chain_ltype_nstratify = @n_stratify l_type l_type l_type l_type begin + + :stocks + [pop, ~pop, _] => pop + + :parameters + [μ, μ, μ] => μ + [δ, δ, δ] => δ + [rFstOrder, rFstOrder, rFstOrder] => rFstOrder + [rage, rage, rage] => rage + + :dynamic_variables + [v_aging, v_aging, v_aging] => v_aging + [v_fstOrder, v_fstOrder, v_fstOrder] => v_fstOrder + [v_birth, v_birth, v_birth] => v_birth + [v_death, v_death, v_death] => v_death + + :flows + [f_aging, f_aging, f_aging] => f_aging + [f_fstOrder, f_fstOrder, f_fstOrder] => f_fstOrder + [f_birth, f_birth, f_birth] => f_birth + [f_death, f_death, f_death] => f_death + + :sums + [N, N, N] => N + end + + + @test chain_ltype == chain_ltype_nstratify + + + ltype_nstratify = @n_stratify l_type l_type begin + + :stocks + [pop] => pop + + :parameters + [μ] => μ + [δ] => δ + [rFstOrder] => rFstOrder + [rage] => rage + + :dynamic_variables + [v_aging] => v_aging + [v_fstOrder] => v_fstOrder + [v_birth] => v_birth + [v_death] => v_death + + :flows + [f_aging] => f_aging + [f_fstOrder] => f_fstOrder + [f_birth] => f_birth + [f_death] => f_death + + :sums + [N] => N + end + + @test ltype_nstratify == l_type + + +end + + + +