Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dagger Instance #147

Closed
wants to merge 13 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@
*.jl.cov
*.jl.*.cov
*.jl.mem

# Experiments
experiments/*/Manifest.toml
12 changes: 12 additions & 0 deletions experiments/DaggerAMC/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
name = "DaggerAMC"
uuid = "85a14c48-69db-4cd9-97af-587947037424"
authors = ["bosonbaas <[email protected]>"]
version = "0.1.0"

[deps]
AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f"
Catlab = "134e5e36-593f-5add-ad60-77f754baafbe"
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
153 changes: 153 additions & 0 deletions experiments/DaggerAMC/src/DaggerAMC.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
module DaggerAMC

import Base: +
using AutoHashEquals

export DagDom, ThunkArr, copy,
dom, codom, compose, id, oplus, mzero,
zero, ⊕, braid
# To implement
# mcopy, delete, plus, zero,

using Dagger
using LinearAlgebra
#using Catlab.Doctrines.AdditiveMonoidal
using Catlab, Catlab.Doctrines, Catlab.Programs
import Catlab.Doctrines:
Ob, Hom, dom, codom, compose, ⋅, ∘, id, oplus, ⊕, mzero, braid,
dagger, dunit, dcounit, mcopy, Δ, delete, ◊, mmerge, ∇, create, □,
plus, zero, coplus, cozero, meet, top, join, bottom


@auto_hash_equals struct DagDom
N::Int
end

# This structure was created to keep track of dom and codom information.
# This information can be updated efficiently, and keeping it here keeps
# LinearFunctions from having to think the thunk each time the dom or codom
# is queried

struct ThunkArr
input::Array{Tuple{Int64,Int64},1}
output::Array{Int64,1}
thunks::Array{Thunk,1}
end

input_nodes(f::ThunkArr) = begin
input_n = Set{Int64}()
for port in f.input
push!(input_n,port[1])
end
return input_n
end

copy(A::ThunkArr) = begin
n_thunks = [Thunk(x.f, x.inputs...) for x in A.thunks]
n_input = A.input
n_output = A.output
ThunkArr(n_input, n_output, n_thunks)
end

ThunkArr(A::AbstractArray) = begin
id_thunks = [delayed(identity)(A[x]) for x in 1:length(A)[1]]
id_input = []
id_output = Array(1:length(A))
ThunkArr(id_input, id_output, id_thunks)
end

ThunkArr(A) = begin
ThunkArr([A])
end

@instance AdditiveSymmetricMonoidalCategory(DagDom, ThunkArr) begin
zero(V::DagDom) = DagDom(0)
mzero(::Type{DagDom}) = DagDom(0)
dom(f::ThunkArr) = size(f.input)[1]
codom(f::ThunkArr) = size(f.output)[1]

compose(f::ThunkArr,g::ThunkArr) = begin
add_ind = (x,n) -> x+n
cf = f
cg = g
n_output = add_ind.(cf.output,size(cg.thunks)[1])
n_input = cg.input
# f_inputs stores what thunks will be passed in from g
f_inputs = Dict(x => Array{Thunk}(undef, length(cf.thunks[x].inputs)) for x in input_nodes(cf))
# Fill out the values that will be passed in from g
for port_num in 1:length(cf.input)
port = cf.input[port_num]
g_node = cg.thunks[cg.output[port_num]]
f_inputs[port[1]][port[2]] = g_node
end
for (key,g_in) in f_inputs
cf.thunks[key].inputs = Tuple(g_in)
end

n_thunks = vcat(cg.thunks,cf.thunks)
ThunkArr(n_input, n_output, n_thunks)
end

oplus(V::DagDom, W::DagDom) = DagDom(V.N + W.N)

# Make copies of thunks to keep from variable interference
oplus(f::ThunkArr, g::ThunkArr) = begin
add_tup = (x,n) -> (x[1]+n,x[2])
add_ind = (x,n) -> x+n
cf = f
cg = g
n_thunks = vcat(cf.thunks, cg.thunks)
n_input = vcat(cf.input, add_tup.(cg.input,size(cf.thunks)[1]))
n_output = vcat(cf.output, add_ind.(cg.output,size(cf.thunks)[1]))
ThunkArr(n_input, n_output, n_thunks)
end

id(V::DagDom) = begin
add_port = x -> (x,1)
id_thunks = [delayed(identity)(1) for x in 1:V.N]
id_input = add_port.(Array(1:V.N))
id_output = Array(1:V.N)
ThunkArr(id_input, id_output, id_thunks)
end

braid(V::DagDom, W::DagDom) = begin
vw_id = id(V.N + W.N)
vw_id.output = vcat(vw_id.output[V.N+1:end],vw_id.output[1:V.N])
return vw_id
end

#adjoint(f::MatrixThunk) = MatrixThunk(delayed(adjoint)(f.thunk), f.codom, f.dom)
#+(f::MatrixThunk, g::MatrixThunk) = MatrixThunk(delayed(+)(f.thunk, g.thunk), f.dom, f.codom)

#compose(f::MatrixThunk, g::MatrixThunk) =
# MatrixThunk(delayed(*)(g.thunk,f.thunk), g.dom, f.codom)
#id(V::DagDom) = MatrixThunk(LMs.UniformScalingMap(1, V.N))

#oplus(V::DagDom, W::DagDom) = DagDom(V.N + W.N)
#oplus(f::MatrixThunk, g::MatrixThunk) =
# MatrixThunk(delayed((f,g)->LMs.BlockDiagonalMap(f,g))(f.thunk, g.thunk),
# f.dom+g.dom, f.codom+g.codom)
#
#mzero(::Type{DagDom}) = DagDom(0)
#braid(V::DagDom, W::DagDom) =
#MatrixThunk(LinearMap(braid_lm(V.N), braid_lm(W.N), W.N+V.N, V.N+W.N))

#mcopy(V::DagDom) = MatrixThunk(LinearMap(mcopy_lm, plus_lm, 2*V.N, V.N))
#delete(V::DagDom) = MatrixThunk(LinearMap(delete_lm, zero_lm(V.N), 0, V.N))
#plus(V::DagDom) = MatrixThunk(LinearMap(plus_lm, mcopy_lm, V.N, 2*V.N))

#plus(f::MatrixThunk, g::MatrixThunk) = f+g
#scalar(V::DagDom, c::Number) = MatrixThunk(LMs.UniformScalingMap(c, V.N))
#antipode(V::DagDom) = scalar(V, -1)
end

#braid_lm(n::Int) = x::AbstractVector -> vcat(x[n+1:end], x[1:n])
#mcopy_lm(x::AbstractVector) = vcat(x, x)
#delete_lm(x::AbstractVector) = eltype(x)[]
#plus_lm(x::AbstractVector) = begin
# n = length(x) ÷ 2
# x[1:n] + x[n+1:end]
#end
#zero_lm(n::Int) = x::AbstractVector -> zeros(eltype(x), n)

end
12 changes: 12 additions & 0 deletions experiments/DaggerInstanceGLA/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
name = "DaggerInstanceGLA"
uuid = "eaa498f4-19f8-4d9b-a7d5-989fa76892d7"
authors = ["bosonbaas <[email protected]>"]
version = "0.1.0"

[deps]
AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f"
Catlab = "134e5e36-593f-5add-ad60-77f754baafbe"
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
82 changes: 82 additions & 0 deletions experiments/DaggerInstanceGLA/src/DaggerInstanceGLA.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
module DaggerInstanceGLA

import Base: +
using AutoHashEquals

export DagDom, LinearMap, MatrixThunk,
dom, codom, adjoint, +, compose, id, oplus, mzero,
braid, mcopy, delete, plus, zero, scalar, antipode,
using Dagger
using LinearAlgebra
using Catlab.LinearAlgebra.GraphicalLinearAlgebra
using Catlab, Catlab.Doctrines
import Catlab.Doctrines:
Ob, Hom, dom, codom, compose, ⋅, ∘, id, oplus, ⊕, mzero, braid,
dagger, dunit, dcounit, mcopy, Δ, delete, ◊, mmerge, ∇, create, □,
plus, zero, coplus, cozero, meet, top, join, bottom
using LinearMaps
import LinearMaps: adjoint
const LMs = LinearMaps


@auto_hash_equals struct DagDom
N::Int
end

# This structure was created to keep track of dom and codom information.
# This information can be updated efficiently, and keeping it here keeps
# LinearFunctions from having to think the thunk each time the dom or codom
# is queried

struct MatrixThunk
thunk::Thunk
dom::Int
codom::Int
end

MatrixThunk(A::LinearMap) = begin
MatrixThunk(delayed(identity)(A), size(A,2), size(A,1))
end

@instance LinearFunctions(DagDom, MatrixThunk) begin

adjoint(f::MatrixThunk) = MatrixThunk(delayed(adjoint)(f.thunk), f.codom, f.dom)
+(f::MatrixThunk, g::MatrixThunk) = MatrixThunk(delayed(+)(f.thunk, g.thunk), f.dom, f.codom)

dom(f::MatrixThunk) = f.dom
codom(f::MatrixThunk) = f.codom

compose(f::MatrixThunk, g::MatrixThunk) =
MatrixThunk(delayed(*)(g.thunk,f.thunk), g.dom, f.codom)
id(V::DagDom) = MatrixThunk(LMs.UniformScalingMap(1, V.N))

oplus(V::DagDom, W::DagDom) = DagDom(V.N + W.N)
oplus(f::MatrixThunk, g::MatrixThunk) =
MatrixThunk(delayed((f,g)->LMs.BlockDiagonalMap(f,g))(f.thunk, g.thunk),
f.dom+g.dom, f.codom+g.codom)

mzero(::Type{DagDom}) = DagDom(0)
braid(V::DagDom, W::DagDom) =
MatrixThunk(LinearMap(braid_lm(V.N), braid_lm(W.N), W.N+V.N, V.N+W.N))

mcopy(V::DagDom) = MatrixThunk(LinearMap(mcopy_lm, plus_lm, 2*V.N, V.N))
delete(V::DagDom) = MatrixThunk(LinearMap(delete_lm, zero_lm(V.N), 0, V.N))
plus(V::DagDom) = MatrixThunk(LinearMap(plus_lm, mcopy_lm, V.N, 2*V.N))
zero(V::DagDom) = MatrixThunk(LinearMap(zero_lm(V.N), delete_lm, V.N, 0))

plus(f::MatrixThunk, g::MatrixThunk) = f+g
scalar(V::DagDom, c::Number) = MatrixThunk(LMs.UniformScalingMap(c, V.N))
antipode(V::DagDom) = scalar(V, -1)
end

braid_lm(n::Int) = x::AbstractVector -> vcat(x[n+1:end], x[1:n])
mcopy_lm(x::AbstractVector) = vcat(x, x)
delete_lm(x::AbstractVector) = eltype(x)[]
plus_lm(x::AbstractVector) = begin
n = length(x) ÷ 2
x[1:n] + x[n+1:end]
end
zero_lm(n::Int) = x::AbstractVector -> zeros(eltype(x), n)

end
31 changes: 31 additions & 0 deletions experiments/DaggerInstanceGLA/test/DaggerInstanceGLA.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using Catlab.LinearAlgebra.GraphicalLinearAlgebra,
DaggerInstanceGLA, Catlab, Catlab.Doctrines
using Test

V, W = Ob(FreeLinearFunctions, :V, :W)
f, g, h = Hom(:f, V, W), Hom(:g, V, W), Hom(:h, W, W)

fmap = LinearMap([1 1 1 1; 2 2 2 3; 3 3 3 3.0])
gmap = LinearMap([3 1 1 1; -1 -1 -1 -1; 0 2 0 0.0])
hmap = LinearMap([1 4 -1; -1 1 0; 1 1 -1.0])

fop = MatrixThunk(fmap)
gop = MatrixThunk(gmap)
hop = MatrixThunk(hmap)

d = Dict(f=>fop, g=>gop, h => hop, V => DagDom(4), W=>DagDom(3))
F(ex) = functor((DagDom, MatrixThunk), ex, generators=d)

dmap = Dict(f=>fmap, g=>gmap, h => hmap, V => LinearMapDom(4), W=>LinearMapDom(3))
Fmap(ex) = functor((LinearMapDom, LinearMap), ex, generators=dmap)

M = plus(f,g)⋅h
N = adjoint(g)⋅f⋅mcopy(W)+h⋅mcopy(W)
O = (adjoint(g)⊕f)+(h⊕(g⋅adjoint(f)))

expressions = [M,N,O]

for expression in expressions
@test Matrix(collect(F(expression).thunk)) ==
Matrix(Fmap(expression))
end
5 changes: 5 additions & 0 deletions experiments/DaggerInstanceGLA/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
using Test

@testset "DaggerInstanceGLA" begin
include("DaggerInstanceGLA.jl")
end