Skip to content

Commit

Permalink
Merge branch 'nstratify_good' into foot_error_messages
Browse files Browse the repository at this point in the history
  • Loading branch information
neonWhiteout committed Sep 27, 2023
2 parents 0a29a83 + 16cde76 commit 6719a9a
Show file tree
Hide file tree
Showing 6 changed files with 1,462 additions and 4 deletions.
18 changes: 17 additions & 1 deletion src/StockFlow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ funcDynam, flowVariableIndex, funcFlow, funcFlows, funcSV, funcSVs, TransitionMa
vectorfield, funcFlowsRaw, funcFlowRaw, inflowsAll, outflowsAll,instock,outstock, stockssv, stocksv, svsv, svsstock,
vsstock, vssv, svsstockAllF, vsstockAllF, vssvAllF, StockAndFlowUntyped, StockAndFlowFUntyped, StockAndFlowStructureUntyped, StockAndFlowStructureFUntyped, StockAndFlowUntyped0, Open, snames, fnames, svnames, vnames,
object_shift_right, foot, leg, lsnames, OpenStockAndFlow, OpenStockAndFlowOb, fv, fvs, nlvv, nlpv, vtgt, vsrc, vpsrc, vptgt, pname, pnames, make_v_expr,
vop, lvvposition, lvtgtposition, lsvvposition, lpvvposition, recreate_stratified, set_snames!, set_fnames!, set_svnames!, set_vnames!, set_pnames!, set_sname!, set_fname!, set_svname!, set_vname!, set_pname!
vop, lvvposition, lvtgtposition, lsvvposition, lpvvposition, recreate_stratified, set_snames!, set_fnames!, set_svnames!, set_vnames!, set_pnames!, set_sname!, set_fname!, set_svname!, set_vname!, set_pname!,
get_lss, get_lssv, get_lsvsv, get_lsvv, get_lvs, get_lvv, get_is, get_ifn, get_os, get_ofn, get_lpvp, get_lpvv, get_lvsrc, get_lvtgt, get_links


using Catlab
using Catlab.CategoricalAlgebra
Expand Down Expand Up @@ -268,6 +270,20 @@ nlpv(p::AbstractStockAndFlowStructureF) = nparts(p,:LPV) #links from dynamic var
np(p::AbstractStockAndFlowStructureF) = nparts(p,:P) #parameters


get_lss(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lss].m))
get_lssv(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lssv].m))
get_lsvsv(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lsvsv].m))
get_lsvv(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lsvv].m))
get_lvs(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lvs].m))
get_lvv(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lvv].m))
get_is(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:is].m))
get_ifn(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:ifn].m))
get_os(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:os].m))
get_ofn(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:ofn].m))
get_lpvp(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lpvp].m))
get_lpvv(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lpvv].m))
get_lvsrc(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lvsrc].m))
get_lvtgt(sf::AbstractStockAndFlowF) = collect(values(sf.subparts[:lvtgt].m))


#EXAMPLE:
Expand Down
212 changes: 211 additions & 1 deletion src/Syntax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,13 @@ end
```
"""
module Syntax
export @stock_and_flow, @foot, @feet
export @stock_and_flow, @foot, @feet, infer_links

using ..StockFlow
using MLStyle

import Base: ==, Iterators.flatmap

"""
stock_and_flow(block :: Expr)
Expand Down Expand Up @@ -1049,7 +1051,215 @@ function match_foot_format(footblock::Expr)
end
end

#############################################

"""
infer_particular_link!(sfsrc, sftgt, f1, f2, map1, map2, destination_vector, posf=nothing)
infer_particular_link!(sfsrc, sftgt, get_lvs, get_lvv, stockmaps, dyvarmaps, lvmaps, get_lvvposition) # LV
If we're mapping the same value to multiple positions, it doesn't matter which one goes where.
We have a few options, on how we want to distribute mappings. Way it's done here, always goes to the last position.
"""
function infer_particular_link!(sfsrc, sftgt, f1, f2, map1, map2, destination_vector)

hom1′_mappings = f1(sftgt)
hom2′_mappings = f2(sftgt)

tgt::Dict{Tuple{Int, Int}, Int} = Dict((hom1′, hom2′) => i for (i, (hom1′, hom2′)) in enumerate(zip(hom1′_mappings, hom2′_mappings))) # ISSUE: If there are two matches, second one overwrites the first.
# SOLUTION: Who cares. Just map to the last.
for (i, (hom1, hom2)) in enumerate(zip(f1(sfsrc), f2(sfsrc)))
mapped_index1 = map1[hom1]
mapped_index2 = map2[hom2]

linkmap = tgt[(mapped_index1, mapped_index2)]



destination_vector[i] = linkmap # updated
end
return destination_vector

end


"""
infer_links(sfsrc :: StockAndFlowF, sftgt :: StockAndFlowF, NecMaps :: Dict{Symbol, Vector{Int64}})
Infer LS, I, O, LV, LSV, LVV, LPV mappings for an ACSetTransformation.
Returns dictionary of Symbols to lists of indices, corresponding to an ACSetTransformation argument.
If there exist no such mappings (eg, no LVV), that pairing will not be included in the returned dictionary.
If A <- C -> B, and we have A -> A' and B -> B' and a unique C' such that A' <- C' -> B', we can assume C -> C'.
:S => [2,4,1,3], :F => [1,2,4,3], ...
NecMaps must contain keys S, F, SV, P, V, each pointing to a (possibly empty) array of indices
"""
function infer_links(sfsrc :: StockAndFlowF, sftgt :: StockAndFlowF, NecMaps :: Dict{Symbol, Vector{Int64}})


stockmaps = NecMaps[:S]
flowmaps = NecMaps[:F]
summaps = NecMaps[:SV]
parammaps = NecMaps[:P]
dyvarmaps = NecMaps[:V]

lsmaps = zeros(Int, nls(sfsrc))
imaps = zeros(Int, ni(sfsrc))
omaps = zeros(Int, no(sfsrc))
lvmaps = zeros(Int, nlv(sfsrc))
lsvmaps = zeros(Int, nlsv(sfsrc))
lvvmaps = zeros(Int, nlvv(sfsrc))
lpvmaps = zeros(Int, nlpv(sfsrc))
# After the following calls, there should be no zeroes.


infer_particular_link!(sfsrc, sftgt, get_lss, get_lssv, stockmaps, summaps, lsmaps) # LS
infer_particular_link!(sfsrc, sftgt, get_ifn, get_is, flowmaps, stockmaps, imaps) # I
infer_particular_link!(sfsrc, sftgt, get_ofn, get_os, flowmaps, stockmaps, omaps) # O
infer_particular_link!(sfsrc, sftgt, get_lvs, get_lvv, stockmaps, dyvarmaps, lvmaps) # LV
infer_particular_link!(sfsrc, sftgt, get_lsvsv, get_lsvv, summaps, dyvarmaps, lsvmaps) # LSV
infer_particular_link!(sfsrc, sftgt, get_lvsrc, get_lvtgt, dyvarmaps, dyvarmaps, lvvmaps) # LVV
infer_particular_link!(sfsrc, sftgt, get_lpvp, get_lpvv, parammaps, dyvarmaps, lpvmaps) # LPV

return Dict(:LS => lsmaps, :LSV => lsvmaps, :LV => lvmaps, :I => imaps, :O => omaps, :LPV => lpvmaps, :LVV => lvvmaps)

include("syntax/Composition.jl")

end


struct DSLArgument
key::Symbol
value::Symbol
flags::Set{Symbol} # At present, the only flag that exists is ~
DSLArgument(kv::Pair{Union{Expr, Symbol}, Symbol}) = begin # this constructor seemed to fail... need to figure out why. Maybe it can't call other constructors.
key, flags = unwrap_expression(first(kv))
new(key, second(kv), flags)
end
DSLArgument(k::Union{Expr, Symbol}, v::Symbol) = begin
key, flags = unwrap_expression(k)
new(key, v, flags)
end
DSLArgument(k::Symbol, v::Symbol, f::Set{Symbol}) = new(k, v, f)
end

==(a::DSLArgument, b::DSLArgument) = a.key == b.key && a.value == b.value && a.flags == b.flags


function unwrap_expression(x::Union{Symbol, Expr}, flags::Set{Symbol}=Set{Symbol}())::Tuple{Symbol, Set{Symbol}} # No mutable default arguments.
if typeof(x) == Symbol
return (x, flags)
else
return unwrap_expression(x.args[2], push!(flags, x.args[1]))
end
end


"""
S₁ => I₁
S₂ => I₂
S₁ => S₂
I₁ => I₂
Determine what index an element e maps to based upon what f we have in the mapping such that e -> f
"""
function connect_by_value(; src::Dict{T,U}, mapping::Dict{T,T}, tgt::Dict{T,U})::Dict{U, U} where {T, U}
@assert allunique(values(src))

@assert all(x -> x keys(mapping), keys(src))
@assert all(x -> x keys(tgt), values(mapping))

return Dict(src[key] => tgt[value] for (key, value) in mapping)

end


"""
Filter a vector for all elements with substr as a substring.
"""
function substring_matches(v::Vector, substr::String)::Vector
return filter(x -> occursin(substr, string(x)), v)
end


"""
Takes a symbol 'key', applys flags, finds matches in s, and returns a vector of matching keys.
Currently, there are two options: no flags, in which case [key] is returned, or ~ is the only flag, in which case all Symbols with matching substrings are returned.
"""
function apply_flags(key::Symbol, flags::Set{Symbol}, s::Vector{Symbol})::Vector{Symbol} # Could make this a generator?

if isempty(flags)
@assert (key s) "$s does not contain key $key ! Did you forget to prefix ~?"
return [key] # potentially inefficient
elseif :~ flags

matches = collect(substring_matches(s, string(key)))

new_flags = copy(flags) # copy isn't necessary, probably
pop!(new_flags, :~)

return collect(flatmap(x -> apply_flags(x, new_flags, s), matches)) # this is just in case we add additional flags. As is, the recursion is unnecessary.
else
error("Unknown flag found! $(flags)")
end
end

"""
substitute_symbols(s::Dict{Symbol, Int}, t::Dict{Symbol, Int}, m::Vector{DSLArgument} ; use_flags::Bool=true)::Dict{Int, Int}
Convert Dict(SymA => IntA), Dict(SymB => IntB), Dict(SymA => SymB) into Dict{IntA => IntB}
Using original sf defintions, and the user defined mappings, transform user defined symbol mappings to index mappings.
"""
function substitute_symbols(s::Dict{Symbol, Int}, t::Dict{Symbol, Int}, m::Vector{DSLArgument} ; use_flags::Bool=true)::Dict{Int, Int}
if !use_flags
mapping = Dict(arg.key => arg.value for arg in m)
return connect_by_value(src=s, mapping=mapping, tgt=t)
else
master_dict::Dict{Int, Int} = Dict()
for statement in m
key_matches = apply_flags(statement.key, statement.flags, collect(keys(s))) # Vector of Symbol
if isempty(key_matches)
println("WARNING! No matches on $(statement.key) with flags $(statement.flags)")
else
mergewith!((x...) -> first(x), master_dict, Dict(s[match] => t[statement.value] for match key_matches))
end
end
return master_dict
end
end


"""
Convert a vector of unique elements to a dictionary with each element pointing to their original index.
"""
function invert_vector(v::Vector{K})::Dict{K, Int} where {K} # Elements of v must be hashable
new_dict = Dict(val => i for (i, val) enumerate(v))
@assert length(new_dict) == length(v) "Nonunique key in vector v: $v"
return new_dict
end


"""
Takes any arguments and returns nothing.
Used so we can maintain equality when making ACSetTransformations.
"""
NothingFunction(x...)::Nothing = nothing;




include("syntax/Composition.jl")
include("syntax/Stratification.jl")

end





2 changes: 1 addition & 1 deletion src/SystemStructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ extracVAndAttrStructureAndFlatten(p::AbstractStockAndFlowF) = begin
if nvb(p)>0
for v in 1:nvb(p)
vn = flattenTupleNames(vname(p,v))
v_op = foldr(==,vop(p,v)) ? vop(p,v)[1] : error("operators $(vop(p,v)) in the stratified model's auxiliary variable: $(join(vname(p,v))) should be the same!")
v_op = allequal(vop(p,v)) ? vop(p,v)[1] : error("operators $(vop(p,v)) in the stratified model's auxiliary variable: $(join(vname(p,v))) should be the same!")
vnp = vn=>(args(p,v)=>v_op)
vs = vcat(vs,vnp)
end
Expand Down
Loading

0 comments on commit 6719a9a

Please sign in to comment.