Skip to content

Commit

Permalink
Add network calibration functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
ptiede committed Jul 18, 2024
1 parent 99890ed commit 1cf8303
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 9 deletions.
1 change: 1 addition & 0 deletions examples/intermediate/StokesIImaging/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Comrade = "99d987ce-9a1e-4df8-bc0b-1ea019aa547b"
DisplayAs = "0b91fe84-8a4c-11e9-3e1d-67c38462b6d6"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Expand Down
35 changes: 35 additions & 0 deletions playground/network_calibration.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using Comrade
using Enzyme
using Optimization
using OptimizationOptimisers
using AdvancedHMC
using Distributions, DistributionsAD
using CairoMakie
using Plots
using Pyehtim

function network_calibration(obs::EHTObservationTable{<:Comrade.EHTVisibilityAmplitudeDatum},
zbl_flux::Real,
netcal_bl::NTuple{2, Symbol}...;
gamp_σ = 0.3)

obsnc = Comrade.prepare_netcal_data(obs, netcal_bl...)
skym = Comrade.NetworkCalSkyModel(zbl_flux, netcal_bl)

netcal_prior = (
AA = IIDSitePrior(IntegSeg(), Normal(0.0, 0.1)),
AX = IIDSitePrior(IntegSeg(), Normal(0.0, gamp_σ)),
SW = IIDSitePrior(IntegSeg(), Normal(0.0, gamp_σ)),
MM = IIDSitePrior(IntegSeg(), Normal(0.0, gamp_σ)),
)
intprior = (
lg = ArrayPrior(IIDSitePrior(IntegSeg(), Normal(0.0, 0.001));
netcal_prior...),
)

J = SingleStokesGain(x->@inline exp(x.lg))
intm = InstrumentModel(J, intprior)
# return obsnc
post = VLBIPosterior(skym, intm, obsnc)
return post, obsnc
end
1 change: 1 addition & 0 deletions src/Comrade.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ include("visualizations/visualizations.jl")
include("dirty_image.jl")
include("mrf_image.jl")
include("rules.jl")
include("network_cal.jl")



Expand Down
10 changes: 5 additions & 5 deletions src/instrument/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ intout(vis::AbstractArray{<:CoherencyMatrix{A,B,T}}) where {A,B,T<:Complex} = si

function apply_instrument(vis, J::ObservedInstrumentModel, x)
vout = intout(vis)
_apply_instrument!(vout, vis, J, x.instrument)
_apply_instrument!(parent(vout), parent(vis), J, x.instrument)
return vout
end

Expand All @@ -227,10 +227,10 @@ end


function _apply_instrument!(vout, vis, J::ObservedInstrumentModel, xint)
for i in eachindex(vout, vis)
vout[i] = apply_jones(vis[i], i, J, xint)
end
# vout .= apply_jones.(vis, eachindex(vis), Ref(J), Ref(x))
# for i in eachindex(vout, vis)
# vout[i] = apply_jones(vis[i], i, J, xint)
# end
vout .= apply_jones.(vis, eachindex(vis), Ref(J), Ref(xint))
return nothing
end

Expand Down
71 changes: 71 additions & 0 deletions src/network_cal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""
NetworkCalibrationSkyModel(zbl_flux, netcal_bl)
Constructs a SkyModel that represents what is done for network calibration.
Network calibration requires a special model since there is no actual image used. Instead
we assume that the sky has some total flux given by `zbl_flux` and the rest of the
amplitudes are the actual model parameters.
!!! note
By default we will assume that the amplitudes are a flat
prior from [0, zbl_flux] to be maximally permissive.
!!! note
We need a special skymodel for network calibration since the model is not an image but
rather we directly fit the visibility amplitudes for non-intrasite baselines.
# Arguments
- `zbl_flux` : The apriori measured total flux of the object.
- `netcal_bl` : The baselines that are considered to be co-located for network calibration.
"""
Base.@kwdef struct NetworkCalSkyModel{Z<:Real, B} <: AbstractSkyModel
zbl_flux::Z
netcal_bl::B
end

# From LogExpFunctions.jl
@inline _logistic_bounds(::Float16) = (Float16(-16.64), Float16(7.625))
@inline _logistic_bounds(::Float32) = (-103.27893f0, 16.635532f0)
@inline _logistic_bounds(::Float64) = (-744.4400719213812, 36.7368005696771)

@inline function elogistic(x::Union{Float16, Float32, Float64})
e = @inline exp(x)
lower, upper = _logistic_bounds(x)
return x < lower ? zero(x) : x > upper ? one(x) : e / (one(x) + e)
end

function set_array(m::NetworkCalSkyModel, array::AbstractArrayConfiguration)
dtbl = datatable(array)
sites = dtbl.sites

netcalset = m.netcal_bl
intrainds = findall(x->Set(x)Set.(netcalset), sites)
fixvals = fill(0.0, length(intrainds))
ampinds = setdiff(eachindex(sites), intrainds)
dists = Distributions.MvNormal(Diagonal(fill(1.78^2, length(ampinds))))

d = PartiallyConditionedDist(dists, ampinds, intrainds, fixvals)
skypr = d
f = let zblflux=m.zbl_flux, intrainds=intrainds
x->(y = 2 .*zblflux.*elogistic.(x); y[intrainds] .= zblflux; y)
end
g = imagepixels(μas2rad(100.0), μas2rad(100.0), 256, 256)
return ObservedSkyModel(m, FourierDualDomain(g, array, NFFTAlg()), f), skypr
end

function idealvisibilities(m::ObservedSkyModel{<:NetworkCalSkyModel}, x)
return m.metadata(x.sky)
end

function skymodel(m::ObservedSkyModel{<:NetworkCalSkyModel}, x)
return m.metadata(x)
end

function prepare_netcal_data(obs::EHTObservationTable{<:EHTVisibilityAmplitudeDatum}, netcal_bl...)
S = Set(Iterators.flatten(netcal_bl))
array = arrayconfig(obs)
inds = findall(x->(x[1]S || x[2]S), datatable(array).sites)
# We find all baselines that are connected to our network calibration baselines
return obs[inds]
end
2 changes: 1 addition & 1 deletion src/observations/datums.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ build_datum(F::Type{<:AbstractVisibilityDatum}, m, e, b) = F(m, e, b)
abstract type AbstractSinglePolDatum{P,S} <: AbstractVisibilityDatum{S} end
abstract type ClosureProducts{P,T} <: AbstractSinglePolDatum{P,T} end

VLBISkyModels.polarization(p::AbstractSinglePolDatum{Pol}) where {Pol} = Pol
VLBISkyModels.polarization(::AbstractSinglePolDatum{Pol}) where {Pol} = Pol


abstract type AbstractBaselineDatum end
Expand Down
11 changes: 8 additions & 3 deletions src/posterior/vlbiposterior.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ post = VLBIPosterior(skym, intmodel, dlcamp, dcphase)
```
"""
function VLBIPosterior(
skymodel::SkyModel,
skymodel::AbstractSkyModel,
instrumentmodel::AbstractInstrumentModel,
dataproducts::EHTObservationTable...;
)
Expand All @@ -76,7 +76,7 @@ function VLBIPosterior(
typeof(sky), typeof(int)}(dataproducts, ls, total_prior, sky, int)
end

VLBIPosterior(skymodel::SkyModel, dataproducts::EHTObservationTable...) =
VLBIPosterior(skymodel::AbstractSkyModel, dataproducts::EHTObservationTable...) =
VLBIPosterior(skymodel, IdealInstrumentModel(), dataproducts...)

function combine_prior(skyprior, instrumentmodelprior)
Expand All @@ -93,9 +93,14 @@ end


function combine_prior(::Tuple{}, instrumentmodel)
return NamedDist((instrument=skymodel.instrument,))
return NamedDist((;instrument=instrumentmodel,))
end

function combine_prior(::NamedTuple{}, instrumentmodel)
return NamedDist((;instrument=instrumentmodel,))
end


function Base.show(io::IO, mime::MIME"text/plain", post::VLBIPosterior)
printstyled(io, "VLBIPosterior"; bold=true, color=:light_magenta)
println(io)
Expand Down

0 comments on commit 1cf8303

Please sign in to comment.