Skip to content

Commit

Permalink
Use fast sum for Enzyme reasons
Browse files Browse the repository at this point in the history
  • Loading branch information
ptiede committed Sep 12, 2024
1 parent adbb75c commit ffa89f0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 8 deletions.
19 changes: 19 additions & 0 deletions src/instrument/jonesmatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,22 @@ function preallocate_jones(J::JonesSandwich, array::AbstractArrayConfiguration,
m2 = map(x->preallocate_jones(x, array, refbasis), J.matrices)
return JonesSandwich(J.jones_map, m2)
end


function forward_jones(v::AbstractJonesMatrix, xs::NamedTuple{N}) where {N}
sm = broadest_sitemap(xs)
bl = map(x->(x,x), sm.sites)
bmaps = map(x->_construct_baselinemap(getproperty.(sm.times, :t0), sm.frequencies, bl, x).indices_1, xs)
vs = map(eachindex(sm.times)) do index
indices = map(x->getindex(x, index), bmaps)
params = NamedTuple{N}(map(getindex, values(xs), values(indices)))
return jonesmatrix(v, params, indices, index)
end
return SiteArray(StructArray(vs), sm)
end

function broadest_sitemap(xs::NamedTuple)
v = values(xs)
return SiteLookup(argmax(x->length(x.times), v))
end

25 changes: 17 additions & 8 deletions src/mrf_image.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,45 @@ Apply multiplicative fluctuations to an image `mimg` with fluctuations `δ`.
The function `f` is applied to the fluctuations and then the the transfored δ are multiplicatively applied
to the image.
"""
function apply_fluctuations(f, mimg::IntensityMap, δ::AbstractArray)
@inline function apply_fluctuations(f, mimg::IntensityMap, δ::AbstractArray)
return IntensityMap(_apply_fluctuations(f, baseimage(mimg), δ), axisdims(mimg))
end

function apply_fluctuations(f, m::AbstractModel, g::AbstractRectiGrid, δ::AbstractArray)
@inline function apply_fluctuations(f, m::AbstractModel, g::AbstractRectiGrid, δ::AbstractArray)
return apply_fluctuations(f, intensitymap(m, g), δ)
end

function apply_fluctuations(t::VLBIImagePriors.LogRatioTransform, m::AbstractModel, g::AbstractRectiGrid, δ::AbstractArray)
@inline function apply_fluctuations(t::VLBIImagePriors.LogRatioTransform, m::AbstractModel, g::AbstractRectiGrid, δ::AbstractArray)
mimg = baseimage(intensitymap(m, g))
return apply_fluctuations(t, IntensityMap(mimg./sum(mimg), g), δ)
end



function apply_fluctuations(mimg::IntensityMap, δ::AbstractArray)
@inline function apply_fluctuations(mimg::IntensityMap, δ::AbstractArray)
return apply_fluctuations(identity, mimg, δ)
end

function _apply_fluctuations(f, mimg::AbstractArray, δ::AbstractArray)
@inline function _apply_fluctuations(f, mimg::AbstractArray, δ::AbstractArray)
return mimg.*f.(δ)
end

_checknorm(m::AbstractArray) = isapprox(sum(m), 1, atol=1e-6)
@noinline _checknorm(m::AbstractArray) = isapprox(sum(m), 1, atol=1e-6)
Enzyme.EnzymeRules.inactive(::typeof(_checknorm), args...) = nothing

function _fastsum(x)
tot = zero(eltype(x))
@simd for i in eachindex(x)
tot += x[i]
end
return tot
end


function _apply_fluctuations(t::VLBIImagePriors.LogRatioTransform, mimg::AbstractArray, δ::AbstractArray)
@argcheck _checknorm(mimg) "Mean image must have unit flux when using log-ratio transformations in apply_fluctuations"
r = to_simplex(t, δ)
r = to_simplex(t, baseimage(δ))
r .= r.*baseimage(mimg)
r .= r./sum(r)
r .= r./_fastsum(r)
return r
end

0 comments on commit ffa89f0

Please sign in to comment.