Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ptiede committed Sep 21, 2023
1 parent fd37563 commit b1e3aa9
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 37 deletions.
2 changes: 2 additions & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Comrade = "99d987ce-9a1e-4df8-bc0b-1ea019aa547b"
ComradeAHMC = "a4336a5c-78bc-4363-8a90-ce3fa9d3abe4"
ComradeBase = "6d8c423b-a35f-4ef1-850c-862fe21f82c4"
Expand All @@ -30,6 +31,7 @@ OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Expand Down
41 changes: 10 additions & 31 deletions playground/enzyme_dft_vis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,60 +38,39 @@ dvis = extract_table(obs, ComplexVisibilities())


function sky(θ, metadata)
(;c, grid, cache) = metadata
img = IntensityMap(reshape(c, size(grid)), grid)
c = θ
(;grid, cache) = metadata
img = IntensityMap(c, grid)
m = ContinuousImage(img, cache)
return m
end

function instrument(θ, metadata)
(; lgamp,) = θ
(; gcache,) = metadata
## Now form our instrument model
gvis = exp.(lgamp)
jgamp = jonesStokes(gvis, gcache)
return JonesModel(jgamp)
end

npix = 12
fovx = μas2rad(80.0)
fovy = μas2rad(80.0)

grid = imagepixels(fovx, fovy, npix, npix)
buffer = IntensityMap(zeros(npix, npix), grid)
cache = create_cache(DFTAlg(dvis), buffer, DeltaPulse())
cache = create_cache(NFFTAlg(dvis), buffer, DeltaPulse())


gcache = jonescache(dvis, ScanSeg())
gcachep = jonescache(dvis, ScanSeg(); autoref=SEFDReference((complex(1.0))))

using VLBIImagePriors
instrumentmetadata = (;gcache, gcachep)

using Distributions
using DistributionsAD
distamp = station_tuple(dvis, Normal(0.0, 0.1); LM = Normal(1.0))

distphase = station_tuple(dvis, DiagonalVonMises(0.0, inv^2)))


prior = NamedDist(
lgamp = CalPrior(distamp, gcache),
# gphase = CalPrior(distphase, gcachep),
)
prior = ImageUniform(npix, npix)

skymetadata = (;c=rand(npix, npix), grid, cache)
instrumentmetadata = (;gcache, gcachep)
lklhd = RadioLikelihood(sky, instrument, dvis; skymeta=skymetadata, instrumentmeta=instrumentmetadata)
# instrumentmetadata = (;gcache, gcachep)
lklhd = RadioLikelihood(sky, dvis; skymeta=skymetadata)
post = Posterior(lklhd, prior)

tpost = asflat(post)
ndim = dimension(tpost)

using Enzyme
Enzyme.API.printall!(false)
Enzyme.API.runtimeActivity!(true)
# Enzyme.API.printall!(false)
x0 = randn(rng, ndim)
dx0 = zero(x0)
lt=logdensityof(tpost)
= logdensityof(tpost, x0)
autodiff(Reverse, logdensityof, Duplicated(tpost, deepcopy(tpost)), Duplicated(x0, dx0))
@time autodiff(Reverse, logdensityof, Const(tpost), Duplicated(x0, fill!(dx0, 0.0)))
6 changes: 3 additions & 3 deletions playground/enzyme_geom_vis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ prior = NamedDist(

# Putting it all together we form our likelihood and posterior objects for optimization and
# sampling.
lklhd = RadioLikelihood(sky, instrument, dvis; instrumentmeta=metadata)
lklhd = RadioLikelihood(sky, instrument, dvis; instrumentmeta = metadata)
post = Posterior(lklhd, prior)

# ## Reconstructing the Image and Instrument Effects
Expand All @@ -122,10 +122,10 @@ ndim = dimension(tpost)
# inference packages use this interface as well.
using Zygote
using Enzyme
Enzyme.API.runtimeActivity!(false)
Enzyme.API.runtimeActivity!(true)

x0 = randn(rng, ndim)
= logdensityof(tpost)
gz, = Zygote.gradient(ℓ, x0)
dx0 = zero(x0)
autodiff(Reverse, logdensityof, (Const(tpost)), Duplicated(x0, dx0))
autodiff(Reverse, logdensityof, Duplicated(tpost, deepcopy(tpost)), Duplicated(x0, dx0))
21 changes: 19 additions & 2 deletions src/calibration/caltable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,15 @@ end
function fill_gmat!(gmat, v::FixedSeg, lookup, i, allstations, alltimes, gains)
t = findfirst(t->(t==alltimes[i]), unique(alltimes))
c = lookup[allstations[i]]
gmat[t,c] = v.value
gmat[t,c] .= v.value
end

function stations(g::JonesCache)
s1 = g.schema.sites
if !(g.references isa AbstractVector{<:NoReference})
return sort(unique(vcat(s1, getproperty.(g.references, :site))))
end
return sort(unique(s1))
end


Expand All @@ -270,7 +278,8 @@ ct[1, :]
function caltable(g::JonesCache, gains::AbstractVector, f=identity)
@argcheck length(g.schema.times) == length(gains)

stations = sort(unique(g.schema.sites))
stations = Comrade.stations(g)
println(stations)
times = unique(g.schema.times)
gmat = Matrix{Union{eltype(gains), Missing}}(missing, length(times), length(stations))
gmat .= 0.0
Expand All @@ -284,6 +293,14 @@ function caltable(g::JonesCache, gains::AbstractVector, f=identity)
fill_gmat!(gmat, seg, lookup, i, allstations, alltimes, gains)
end

for i in eachindex(g.references)
s = g.references[i]
if !(s isa NoReference)
c = lookup[s.site]
gmat[i, c] = s.scheme.value
end
end

replace!(x->(x==0 ? missing : x), gmat)
gmat .= f.(gmat)
return CalTable(stations, lookup, times, gmat)
Expand Down
2 changes: 1 addition & 1 deletion src/distributions/radiolikelihood.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ phase(vis::AbstractArray{<:Complex}) = angle.(vis)



function DensityInterface.logdensityof(d::RadioLikelihood, θ::NamedTuple)
function DensityInterface.logdensityof(d::RadioLikelihood, θ)
ac = d.positions
m = vlbimodel(d, θ)
# Convert because of conventions
Expand Down

0 comments on commit b1e3aa9

Please sign in to comment.