Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dagger Instance #147

Closed
wants to merge 13 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions experiments/DaggerInstanceGLA/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
name = "DaggerInstanceGLA"
uuid = "eaa498f4-19f8-4d9b-a7d5-989fa76892d7"
authors = ["bosonbaas <[email protected]>"]
version = "0.1.0"

[deps]
AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f"
Catlab = "134e5e36-593f-5add-ad60-77f754baafbe"
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
78 changes: 78 additions & 0 deletions experiments/DaggerInstanceGLA/src/DaggerInstanceGLA.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
module DaggerInstanceGLA

import Base: +
using AutoHashEquals

export DagDom, LinearMap, MatrixThunk, matrixToThunk,
dom, codom, adjoint, +, compose, id, oplus, mzero,
braid, mcopy, delete, plus, zero, scalar, antipode,
using Dagger
using LinearAlgebra
using Catlab.LinearAlgebra.GraphicalLinearAlgebra
using Catlab, Catlab.Doctrines
import ...Doctrines:
Ob, Hom, dom, codom, compose, ⋅, ∘, id, oplus, ⊕, mzero, braid,
dagger, dunit, dcounit, mcopy, Δ, delete, ◊, mmerge, ∇, create, □,
plus, zero, coplus, cozero, meet, top, join, bottom
using LinearMaps
import LinearMaps: adjoint
const LMs = LinearMaps


@auto_hash_equals struct DagDom
N::Int
end

# This structure was created to keep track of dom and codom information.
# This information can be updated efficiently, and keeping it here keeps
# LinearFunctions from having to think the thunk each time the dom or codom
# is queried

struct MatrixThunk
thunk::Thunk
dom::Int
codom::Int
end

matrixToThunk(A::LinearMap) = begin
epatters marked this conversation as resolved.
Show resolved Hide resolved
epatters marked this conversation as resolved.
Show resolved Hide resolved
delayed(x->x)(A)
end

@instance LinearFunctions(DagDom, MatrixThunk) begin

adjoint(f::MatrixThunk) = MatrixThunk(delayed(adjoint)(f.thunk), f.codom, f.dom)
+(f::MatrixThunk, g::MatrixThunk) = MatrixThunk(delayed(+)(f.thunk, g.thunk), f.dom, f.codom)

dom(f::MatrixThunk) = f.dom
codom(f::MatrixThunk) = f.codom

compose(f::MatrixThunk, g::MatrixThunk) = MatrixThunk(delayed(*)(g.thunk,f.thunk), g.dom, f.codom)
id(V::DagDom) = MatrixThunk(delayed(x->x)(LMs.UniformScalingMap(1, V.N)), V.N, V.N)
epatters marked this conversation as resolved.
Show resolved Hide resolved

oplus(V::DagDom, W::DagDom) = DagDom(V.N + W.N)
oplus(f::MatrixThunk, g::MatrixThunk) = MatrixThunk(delayed(LMs.BlockDiagonalMap)(f.thunk, g.thunk), f.dom+g.dom, f.codom+g.codom)
mzero(::Type{DagDom}) = DagDom(0)
braid(V::DagDom, W::DagDom) =
MatrixThunk(delayed(x->x)(LinearMap(braid_lm(V.N), braid_lm(W.N), W.N+V.N, V.N+W.N)),V.N+W.N, W.N+V.N)

mcopy(V::DagDom) = MatrixThunk(delayed(x->x)(LinearMap(mcopy_lm, plus_lm, 2*V.N, V.N)), V.N, 2*V.N)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these should be MatrixThunk(delayed(x->LinearMap(mcopy_lm, plus_lm, 2*V.N, V.N)*x), V.N, 2*V.N) etc.

Copy link
Member

@bosonbaas bosonbaas Mar 27, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I see what you're getting at with that, I had a hard time with this one (as well as braid and delete). I think as-is will cause a typeError, since delayed(x->LinearMap(mcopy_lm, plus_lm, 2*V.N, V.N)*x) is not of type Thunk.

From the way it's used in the LinearMaps instance, mcopy returns a morphism (LinearMap object) which is then applied to other morphisms through composition. If we're keeping with that usage (applying mcopy by composition), I think we'll want this to return a MatrixThunk which has an identity Thunk for the LinearMap mcopy, which can then be used as an argument in a compose call.

This acts to copy the internal LinearMap through composition.

delete(V::DagDom) = MatrixThunk(delayed(x->x)(LinearMap(delete_lm, zero_lm(V.N), 0, V.N)), V.N, 0)
plus(V::DagDom) = MatrixThunk(delayed(x->x)(LinearMap(plus_lm, mcopy_lm, V.N, 2*V.N)), 2*V.N, V.N)
zero(V::DagDom) = MatrixThunk(delayed(x->x)(LinearMap(zero_lm(V.N), delete_lm, V.N, 0)), 0, V.N)

plus(f::MatrixThunk, g::MatrixThunk) = f+g
scalar(V::DagDom, c::Number) = MatrixThunk(delayed(x->x)(LMs.UniformScalingMap(c, V.N)), V.N, V.N)
antipode(V::DagDom) = scalar(V, -1)
end

braid_lm(n::Int) = x::AbstractVector -> vcat(x[n+1:end], x[1:n])
mcopy_lm(x::AbstractVector) = vcat(x, x)
delete_lm(x::AbstractVector) = eltype(x)[]
plus_lm(x::AbstractVector) = begin
n = length(x) ÷ 2
x[1:n] + x[n+1:end]
end
zero_lm(n::Int) = x::AbstractVector -> zeros(eltype(x), n)

end
23 changes: 23 additions & 0 deletions experiments/DaggerInstanceGLA/src/test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using Catlab.LinearAlgebra.GraphicalLinearAlgebra, DaggerInstanceGLA, Catlab, Catlab.Doctrines

V, W = Ob(FreeLinearFunctions, :V, :W)
f, g, h = Hom(:f, V, W), Hom(:g, V, W), Hom(:h, W, W)

fop = MatrixThunk(matrixToThunk(LinearMap([1 1 1 1; 2 2 2 3; 3 3 3 3.0])), 4,3)
gop = MatrixThunk(matrixToThunk(LinearMap([3 1 1 1; -1 -1 -1 -1; 0 2 0 0.0])), 4,3)
hop = MatrixThunk(matrixToThunk(LinearMap([1 4 -1; -1 1 0; 1 1 -1.0])), 3,3)

#fop = LinearMap([1 1 1 1; 2 2 2 3; 3 3 3 3.0])
#gop = LinearMap([3 1 1 1; -1 -1 -1 -1; 0 2 0 0.0])
#hop = LinearMap([1 4 -1; -1 1 0; 1 1 -1.0])


d = Dict(f=>fop, g=>gop, h => hop, V => DagDom(4), W=>DagDom(3))
F(ex) = functor((DagDom, MatrixThunk), ex, generators=d)
#d = Dict(f=>fop, g=>gop, h => hop, V => LinearMapDom(4), W=>LinearMapDom(3))
#F(ex) = functor((LinearMapDom, LinearMap), ex, generators=d)

M = plus(f,g)⋅h

print(Matrix(collect(F(M).thunk)))
epatters marked this conversation as resolved.
Show resolved Hide resolved
print('\n')