Skip to content

Commit

Permalink
Ok Enzyme is now fast!
Browse files Browse the repository at this point in the history
  • Loading branch information
ptiede committed Jul 26, 2024
1 parent 2f7fd64 commit fe98c02
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 41 deletions.
2 changes: 2 additions & 0 deletions examples/intermediate/PolarizedImaging/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
VLBIImagePriors = "b1ba175b-8447-452c-b961-7db2d6f7a029"
VLBISkyModels = "d6343c73-7174-4e0f-bb64-562643efbeca"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
12 changes: 8 additions & 4 deletions examples/intermediate/PolarizedImaging/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ function sky(θ, metadata)
(;c, σ, p, p0, pσ, angparams) = θ
(;ftot, grid) = metadata
## Build the stokes I model
rast = ftot*to_simplex(CenteredLR(), σ.*c.params)
rast = to_simplex(CenteredLR(), σ.*c.params)
rast .= ftot.*rast
## The total polarization fraction is modeled in logit space so we transform it back
pim = logistic.(p0 .+ pσ.*p.params)
## Build our IntensityMap
Expand Down Expand Up @@ -232,11 +233,12 @@ skym = SkyModel(sky, skyprior, grid; metadata=skymeta)
# the gain matrix is a diagonal 2x2 matrix the function must return a 2-element tuple.
# The first element of the tuple is the gain for the first polarization feed (R) and the
# second is the gain for the second polarization feed (L).
G = JonesG() do x
function fgain(x)
    ## Pull out the four gain parameters read from `x`.
    (; lgR, gpR, lgrat, gprat) = x
    ## First-feed gain from `lgR` and `gpR`.
    gR = exp(lgR + 1im*gpR)
    ## The second-feed gain is the first gain times the factor built from `lgrat` and `gprat`.
    ratio = exp(lgrat + 1im*gprat)
    return gR, gR*ratio
end
G = JonesG(fgain)
# Note that we are using the Julia `do` syntax here to define an anonymous function. This
# could've also been written as
# ```julia
Expand All @@ -251,12 +253,14 @@ end
# d2 1
# Therefore, there are 2 free parameters for the JonesD our parameterization function
# must return a 2-element tuple. For d-terms we will use a re-im parameterization.
D = JonesD() do x
function fdterms(x)
    ## Each d-term is assembled from a pair of fields of `x`:
    ## (dRx, dRy) for the first entry and (dLx, dLy) for the second.
    return complex(x.dRx, x.dRy), complex(x.dLx, x.dLy)
end

D = JonesD(fdterms)

# Finally we define our response Jones matrix. This matrix is a basis transform matrix
# plus the feed rotation angle for each station. These are typically set by the telescope
# so there are no free parameters, so no parameterization is necessary.
Expand All @@ -273,7 +277,7 @@ J = JonesSandwich(splat(*), G, D, R)
intprior = (
lgR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))),
gpR = ArrayPrior(IIDSitePrior(ScanSeg(), DiagonalVonMises(0.0, inv(π^2))); refant=SEFDReference(0.0), phase=false),
lgrat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1)), phase=true),
lgrat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1)), phase=false),
gprat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1)); refant = SingleReference(:AA, 0.0)),
dRx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))),
dRy = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))),
Expand Down
2 changes: 1 addition & 1 deletion src/Comrade.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import Distributions as Dists
using DocStringExtensions
using ChainRulesCore
using Enzyme
Enzyme.API.runtimeActivity!(true)
# Enzyme.API.runtimeActivity!(true)
using FillArrays: Fill
using ForwardDiff
using IntervalSets
Expand Down
16 changes: 9 additions & 7 deletions src/instrument/instrument_transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,19 @@ end
function _instrument_transform_with(flag::TV.LogJacFlag, m::MarkovInstrumentTransform, x, index)
(;inner_transform, site_map) = m
y, ℓ, index = TV.transform_with(flag, inner_transform, x, index)
site_sum!(y, site_map)
return y, ℓ, index
yout = site_sum(y, site_map)
return yout, ℓ, index
end

function site_sum!(y, site_map::SiteLookup)
map(site_map.lookup) do site
ys = @view y[site]
@inline function site_sum(y, site_map::SiteLookup)
yout = similar(y)
for site in site_map.lookup
ys = @inbounds @view y[site]
# y should never alias so we should be fine here.
cumsum!(ys, (ys))
youts = @inbounds @view yout[site]
cumsum!(youts, ys)
end
return nothing
return yout
end

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_instrument_transform_with), flag, m::MarkovInstrumentTransform, x, index)
Expand Down
10 changes: 5 additions & 5 deletions src/instrument/jonesmatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ end
struct JonesG{F} <: AbstractJonesMatrix
param_map::F
end
construct_jones(::JonesG, x::NTuple{2, T}, index, site) where {T} = Diagonal(SVector{2, T}(x))
construct_jones(::JonesG, x::NTuple{2, T}, index, site) where {T} = SMatrix{2, 2, T, 4}(x[1], zero(T), zero(T), x[2])


"""
Expand Down Expand Up @@ -87,7 +87,7 @@ end
struct JonesD{F} <: AbstractJonesMatrix
param_map::F
end
construct_jones(::JonesD, x::NTuple{2, T}, index, site) where {T} = SMatrix{2, 2, T, 4}(1, x[2], x[1], 1)
Base.@propagate_inbounds construct_jones(::JonesD, x::NTuple{2, T}, index, site) where {T} = SMatrix{2, 2, T, 4}(1, x[2], x[1], 1)


"""
Expand All @@ -112,7 +112,7 @@ end
struct GenericJones{F} <: AbstractJonesMatrix
param_map::F
end
construct_jones(::GenericJones, x::NTuple{4, T}, index, site) where {T} = SMatrix{2, 2, T, 4}(x[1], x[2], x[3], x[4])
Base.@propagate_inbounds construct_jones(::GenericJones, x::NTuple{4, T}, index, site) where {T} = SMatrix{2, 2, T, 4}(x[1], x[2], x[3], x[4])

"""
JonesF(;add_fr=true)
Expand All @@ -128,7 +128,7 @@ struct JonesF{M} <: AbstractJonesMatrix
matrices::M
end
JonesF() = JonesF(nothing)
construct_jones(J::JonesF, x, index, ::Val{M}) where {M} = J.matrices[index][M]
Base.@propagate_inbounds construct_jones(J::JonesF, x, index, ::Val{M}) where {M} = J.matrices[index][M]
param_map(::JonesF, x) = x
function preallocate_jones(::JonesF, array::AbstractArrayConfiguration, ref)
field_rotations = build_feedrotation(array)
Expand All @@ -149,7 +149,7 @@ Base.@kwdef struct JonesR{M} <: AbstractJonesMatrix
matrices::M = nothing
add_fr::Bool = true
end
construct_jones(J::JonesR, x, index, ::Val{M}) where {M} = J.matrices[M][index]
Base.@propagate_inbounds construct_jones(J::JonesR, x, index, ::Val{M}) where {M} = @inbounds J.matrices[M][index]
param_map(::JonesR, x) = x

function preallocate_jones(J::JonesR, array::AbstractArrayConfiguration, ref)
Expand Down
51 changes: 35 additions & 16 deletions src/instrument/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ end
# Site lookup is const so we add a method so we can signal
# to Enzyme that it is not differentiable.
sitelookup(x::ObservedInstrumentModel) = x.bsitelookup
instrument(x::ObservedInstrumentModel) = x.instrument
refbasis(x::ObservedInstrumentModel) = x.refbasis
Enzyme.EnzymeRules.inactive(::typeof(sitelookup), args...) = nothing
Enzyme.EnzymeRules.inactive(::typeof(instrument), args...) = nothing
Enzyme.EnzymeRules.inactive(::typeof(refbasis), args...) = nothing

function Base.show(io::IO, mime::MIME"text/plain", m::ObservedInstrumentModel)
printstyled(io, "ObservedInstrumentModel"; bold=true, color=:light_cyan)
Expand Down Expand Up @@ -212,31 +216,46 @@ intout(vis::AbstractArray{<:StokesParams{T}}) where {T<:Complex} = similar(vis,
intout(vis::AbstractArray{T}) where {T<:Complex} = similar(vis, T)
intout(vis::AbstractArray{<:CoherencyMatrix{A,B,T}}) where {A,B,T<:Complex} = similar(vis, SMatrix{2,2, T, 4})

intout(vis::StructArray{<:StokesParams{T}}) where {T<:Complex} = StructArray{SMatrix{2,2, T, 4}}((vis.I, vis.Q, vis.U, vis.V))

function apply_instrument(vis, J::ObservedInstrumentModel, x)
vout = vis#intout(vis)
_apply_instrument!(baseimage(vout), baseimage(vis), J, x.instrument)
@inline function apply_instrument(vis, J::ObservedInstrumentModel, x)
# vout = intout(parent(vis))
vis .= apply_jones.(vis, eachindex(vis), Ref(J), Ref(x.instrument))
vout = intout(parent(vis))
return vout
end

function apply_instrument(vis, J::ObservedInstrumentModel{<:Union{JonesR, JonesF}}, x)
vout = vis#intout(vis)
# function apply_instrument(vis, J::ObservedInstrumentModel, x)
# xint = x.instrument
# vout = map(Array(vis), eachindex(vis)) do v, i
# return apply_jones(v, i, J, xint)
# end
# # vout = apply_jones.(vis, eachindex(vis), Ref(J), Ref(x.instrument))
# return UnstructuredMap(StructArray(vout), axisdims(vis))
# end


@inline function apply_instrument(vis, J::ObservedInstrumentModel{<:Union{JonesR, JonesF}}, x)
vout = intout(parent(vis))
_apply_instrument!(baseimage(vout), baseimage(vis), J, (;))
return vout
return UnstructuredMap(vout, axisdims(vis))
end

Enzyme.EnzymeRules.inactive(::typeof(Base.Ref), ::ObservedInstrumentModel) = nothing

function _apply_instrument!(vout, vis, J::ObservedInstrumentModel, xint)
# @inbounds for i in eachindex(vout, vis)
# vout[i] = apply_jones(vis[i], i, J, xint)
# end
vout .= apply_jones.(vis, eachindex(vis), Ref(J), Ref(xint))
return nothing
end
# @inline function _apply_instrument!(vout, vis, J::ObservedInstrumentModel, xint)
# # @inbounds for i in eachindex(vout, vis)
# # v = apply_jones(vis[i], i, J, xint)
# # vout[i] = v
# # end
# vout .= apply_jones.(vis, eachindex(vis), Ref(J), Ref(xint))
# return nothing
# end

# For every entry of `bsitemaps`, pick element `index` out of its `indices_1` field.
@inline function get_indices(bsitemaps, index, ::Val{1})
    return map(sitemap -> sitemap.indices_1[index], bsitemaps)
end
# Same lookup, but reading each entry's `indices_2` field.
@inline function get_indices(bsitemaps, index, ::Val{2})
    return map(sitemap -> sitemap.indices_2[index], bsitemaps)
end
@inline get_params(x::NamedTuple{N}, indices::NamedTuple{N}) where {N} = NamedTuple{N}(map((xx, ii)->getindex(xx, ii), x, indices))
@inline get_params(x::NamedTuple{N}, indices::NamedTuple{N}) where {N} = NamedTuple{N}(map(getindex, values(x), values(indices)))
# @inline get_params(x::NamedTuple{N}, indices::NamedTuple{N}) where {N} = NamedTuple{N}(ntuple(i->getindex(x[i], indices[i]), Val(length(N))))

# We need this because Enzyme seems to crash when generating code for this
# TODO try to find MWE and post to Enzyme.jl
Expand All @@ -245,14 +264,14 @@ Enzyme.EnzymeRules.inactive(::typeof(get_indices), args...) = nothing
@inline function build_jones(index::Int, J::ObservedInstrumentModel, x, ::Val{N}) where N
indices = get_indices(sitelookup(J), index, Val(N))
params = get_params(x, indices)
return jonesmatrix(J.instrument, params, index, Val(N))
return jonesmatrix(instrument(J), params, index, Val(N))
end


@inline function apply_jones(v, index::Int, J::ObservedInstrumentModel, x)
j1 = build_jones(index, J, x, Val(1))
j2 = build_jones(index, J, x, Val(2))
vout = _apply_jones(v, j1, j2, J.refbasis)
vout = _apply_jones(v, j1, j2, refbasis(J))
return vout
end

Expand Down
8 changes: 4 additions & 4 deletions src/instrument/site_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ EnzymeRules.inactive(::(typeof(Base.size)), ::SiteArray) = nothing
Base.parent(a::SiteArray) = getfield(a, :data)
Base.size(a::SiteArray) = size(parent(a))
Base.IndexStyle(::Type{<:SiteArray{T, N, A}}) where {T, N, A} = Base.IndexStyle(A)
Base.@propagate_inbounds Base.getindex(a::SiteArray, i::Integer) = getindex(parent(a), i)
Base.@propagate_inbounds Base.getindex(a::SiteArray, I::Vararg{Int, N}) where {N} = getindex(parent(a), I...)
Base.setindex!(m::SiteArray, v, i::Int) = setindex!(parent(m), v, i)
Base.setindex!(m::SiteArray, v, i::Vararg{Int, N}) where {N} = setindex!(parent(m), v, i...)
Base.@propagate_inbounds Base.getindex(a::SiteArray{T}, i::Integer) where {T} = @inbounds(getindex(parent(a), i))::T
Base.@propagate_inbounds Base.getindex(a::SiteArray, I::Vararg{Integer, N}) where {N} = getindex(parent(a), I...)
Base.setindex!(m::SiteArray, v, i::Integer) = setindex!(parent(m), v, i)
Base.setindex!(m::SiteArray, v, i::Vararg{Integer, N}) where {N} = setindex!(parent(m), v, i...)
# Generic indexing fallback: apply the same indices `I...` to the parent data and to the
# `times`, `frequencies` and `sites` fields, then wrap the results in a new `SiteArray`.
Base.@propagate_inbounds function Base.getindex(m::SiteArray, I...)
    data = getindex(parent(m), I...)
    times = getindex(m.times, I...)
    frequencies = getindex(m.frequencies, I...)
    sites = getindex(m.sites, I...)
    return SiteArray(data, times, frequencies, sites)
end
Expand Down
7 changes: 4 additions & 3 deletions src/posterior/abstract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Returns the instrument model of the posterior `d`.
"""
instrumentmodel(d::AbstractVLBIPosterior) = getfield(d, :instrumentmodel)
HypercubeTransform.dimension(d::AbstractVLBIPosterior) = length(d.prior)
Enzyme.EnzymeRules.inactive(::typeof(instrumentmodel), args...) = nothing

@noinline logprior_ref(d, x) = logprior(d, x[])

Expand Down Expand Up @@ -114,7 +115,7 @@ Computes the forward model visibilities of the posterior `d` with parameters `θ
Note these are the complex visiblities or the coherency matrices, not the actual
data products observed.
"""
function forward_model(d::AbstractVLBIPosterior, θ)
@inline function forward_model(d::AbstractVLBIPosterior, θ)
vis = idealvisibilities(skymodel(d), θ)
return apply_instrument(vis, instrumentmodel(d), θ)
end
Expand All @@ -124,7 +125,7 @@ end
Computes the log-likelihood of the posterior `d` with parameters `θ`.
"""
function loglikelihood(d::AbstractVLBIPosterior, θ)
@inline function loglikelihood(d::AbstractVLBIPosterior, θ)
vis = forward_model(d, θ)
# Convert because of conventions
return logdensityofvis(d.lklhds, vis)
Expand Down Expand Up @@ -152,7 +153,7 @@ end



function logdensityofvis(lklhds, vis::AbstractArray)
@inline function logdensityofvis(lklhds, vis::AbstractArray)
fl = Base.Fix2(logdensityof, vis)
ls = map(fl, lklhds)
return sum(ls)
Expand Down
2 changes: 1 addition & 1 deletion src/posterior/likelihood.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ struct ConditionedLikelihood{F, O}
kernel::F
obs::O
end
DensityInterface.logdensityof(d::ConditionedLikelihood, μ) = logdensityof(d.kernel(μ), d.obs)
@inline DensityInterface.logdensityof(d::ConditionedLikelihood, μ) = logdensityof(@inline(d.kernel(μ)), d.obs)


"""
Expand Down

0 comments on commit fe98c02

Please sign in to comment.