Skip to content

Commit

Permalink
feat: optimizers with vector functions
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcohen1 committed Jul 9, 2024
1 parent 8a408b8 commit baaadd0
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 368 deletions.
2 changes: 1 addition & 1 deletion src/FinSetAlgebras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ portmap(obj::Open{T}) where T = obj.m

# Helper function for when m is identity.
function Open{T}(o::T) where T
Open{T}(domain(o), o, id(domain(o)))
Open{T}(dom(o), o, id(dom(o)))
end

dom(obj::Open{T}) where T = dom(obj.m)
Expand Down
153 changes: 7 additions & 146 deletions src/Objectives.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module Objectives

export PrimalObjective, MinObj, gradient_flow,
SaddleObjective, DualComp, primal_solution, dual_objective, primal_objective, pullback_function
SaddleObjective, DualComp, primal_solution, dual_objective, primal_objective

using ..FinSetAlgebras
import ..FinSetAlgebras: hom_map, laxator
Expand All @@ -11,6 +11,8 @@ import Catlab: oapply, dom
using ForwardDiff
using Optim



# Primal Minimization Problems and Gradient Descent
###################################################

Expand Down Expand Up @@ -38,15 +40,15 @@ struct MinObj <: FinSetAlgebra{PrimalObjective} end
The morphism map is defined by ϕ ↦ (f ↦ f∘ϕ^*).
"""
hom_map(::MinObj, ϕ::FinFunction, p::PrimalObjective) =
PrimalObjective(codom(ϕ), x->p(test_pullback_function(ϕ, x)))
PrimalObjective(codom(ϕ), x->p(pullback_function(ϕ, x)))

""" laxator(::MinObj, Xs::Vector{PrimalObjective})
Takes the "disjoint union" of a collection of primal objectives.
"""
function laxator(::MinObj, Xs::Vector{PrimalObjective})
c = coproduct([dom(X) for X in Xs])
subproblems = [x -> X(test_pullback_function(l, x)) for (X,l) in zip(Xs, legs(c))]
subproblems = [x -> X(pullback_function(l, x)) for (X,l) in zip(Xs, legs(c))]
objective(x) = sum([sp(x) for sp in subproblems])
return PrimalObjective(apex(c), objective)
end
Expand Down Expand Up @@ -96,14 +98,14 @@ struct DualComp <: FinSetAlgebra{SaddleObjective} end
# Only "glue" along dual variables
hom_map(::DualComp, ϕ::FinFunction, p::SaddleObjective) =
SaddleObjective(p.primal_space, codom(ϕ),
(x,λ) -> p(x, test_pullback_function(ϕ, λ)))
(x,λ) -> p(x, pullback_function(ϕ, λ)))

# Laxate along both primal and dual variables
function laxator(::DualComp, Xs::Vector{SaddleObjective})
c1 = coproduct([X.primal_space for X in Xs])
c2 = coproduct([X.dual_space for X in Xs])
subproblems = [(x,λ) ->
X(test_pullback_function(l1, x), test_pullback_function(l2, λ)) for (X,l1,l2) in zip(Xs, legs(c1), legs(c2))]
X(pullback_function(l1, x), pullback_function(l2, λ)) for (X,l1,l2) in zip(Xs, legs(c1), legs(c2))]
objective(x,λ) = sum([sp(x,λ) for sp in subproblems])
return SaddleObjective(apex(c1), apex(c2), objective)
end
Expand All @@ -123,145 +125,4 @@ function gradient_flow(of::Open{SaddleObjective})
λ -> ForwardDiff.gradient(dual_objective(f, x(λ)), λ), of.m)
end






# New stuff

struct UnivTypedPrimalObjective
decision_space::FinSet
objective::Function # R^ds -> R NOTE: should be autodifferentiable
type::FinDomFunction # R^ds -> Z+, for our use case. One function to rule them all--this will stay the same across our finsets. FinDomFunction
end

(p::UnivTypedPrimalObjective)(x::Vector) = p.objective(x)
dom(p::UnivTypedPrimalObjective) = p.decision_space

""" MinObj
# Finset-algebra implementing composition of minimization problems by variable sharing.
# """
struct MinObj <: FinSetAlgebra{PrimalObjective} end

""" hom_map(::MinObj, ϕ::FinFunction, p::PrimalObjective)
The morphism map is defined by ϕ ↦ (f ↦ f∘ϕ^*).
"""
hom_map(::MinObj, ϕ::FinFunction, p::UnivTypedPrimalObjective) = # Another version, which wouldn't require a universal type function, would have you pass in a custom type function for your set M. This would require more work in the laxator to take the disjoint union of type functions.
all(p.type(x) == p.type(ϕ(x)) for x in dom(p)) ?
UnivTypedPrimalObjective(codom(ϕ), x -> p(test_pullback_function(ϕ, x)), p.type) :
error("The ϕ provided is not type-preserving.") # throw an error



""" laxator(::MinObj, Xs::Vector{PrimalObjective})
Takes the "disjoint union" of a collection of primal objectives.
"""
function laxator(::MinObj, Xs::Vector{UnivTypedPrimalObjective})
c = coproduct([dom(X) for X in Xs])
subproblems = [x -> X(test_pullback_function(l, x)) for (X, l) in zip(Xs, legs(c))]
objective(x) = sum([sp(x) for sp in subproblems])
return UnivTypedPrimalObjective(apex(c), objective, Xs[1].type) # Assuming all have the same type function
end






struct TypedPrimalObjective # Should we put restrictions on the constructor, i.e. check that dom(objective) = dom(type) = decision_space?
decision_space::FinSet
objective::Function # R^ds -> R NOTE: should be autodifferentiable. Make it a FinDomFunction?
type::FinDomFunction
end

(p::TypedPrimalObjective)(x::Vector) = p.objective(x)
dom(p::TypedPrimalObjective) = p.decision_space


hom_map(::MinObj, ϕ::FinFunction, σ::FinDomFunction, τ::FinDomFunction, p::TypedPrimalObjective) = # τ seems completely unnecessary to me
all(p.type(x) == σ(ϕ(x)) && p.type(x) == τ(x) for x in dom(p)) ?
UnivTypedPrimalObjective(codom(ϕ), x -> p(test_pullback_function(ϕ, x)), σ) : # Note: we didn't check but σ must be applicable across all of codom(ϕ)
nothing # throw an error



function laxator(::MinObj, Xs::Vector{TypedPrimalObjective})
combinedType = copair([X.type for X in Xs])
c = dom(combinedType) # c is the coproduct of the decision spaces
objective(x) = sum(X(test_pullback_function(l, x)) for (X, l) in zip(Xs, legs(c))) # Calculated the same as before (but simplified onto one line)
return UnivTypedPrimalObjective(apex(c), objective, combinedType)
end




# Pullback function for a given ϕ and vector v
# function pullback(ϕ::FinFunction, v::Vector)
# output = Vector{eltype(v)}(undef, length(dom(ϕ)))
# for i in 1:length(dom(ϕ))
# output[i] = v[ϕ(i)]
# end
# return output
# end

# Optional easier version: return Vector{eltype(v)}(v[ϕ(i)] for i in 1:length(dom(ϕ)))



# Curried version of the pullback function
function curried_pullback::FinFunction)
return function (v::Vector)
output = Vector{eltype(v)}(undef, length(dom(ϕ)))
for i in 1:length(dom(ϕ))
output[i] = v[ϕ(i)]
end
return output
end
end



function pullback_function::FinFunction, v::Vector)
return v[ϕ.(1:length(dom(ϕ)))] # Broadcasting with vector of indices
end




# Not the active one
function typed_pullback_matrix(f::FinFunction) # Modify code

# Track codomain indices
prefixes = Dict()
lastPrefix = 0

for v in codom(f)
prefixes[v] = lastPrefix
lastPrefix += length(v)
end
# lastPrefix now holds the sum of the sizes across all the output vectors

domLength = 0
result = []
for v in dom(f)
domLength += length(v)
for i in 1:length(v)
push!(result, prefixes[f(v)] + i)
end
end

sparse(1:domLength, result, ones(Int, domLength), domLength, lastPrefix)
end








end # module
Loading

0 comments on commit baaadd0

Please sign in to comment.