From b78cc0f7ebe114f8166b8b5d0723b65da545add5 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sun, 15 Dec 2024 20:07:03 -0500 Subject: [PATCH 01/34] Fix up site array --- Project.toml | 2 +- src/instrument/site_array.jl | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 97870ef8..18bbefc1 100644 --- a/Project.toml +++ b/Project.toml @@ -74,7 +74,7 @@ DimensionalData = "0.27, 0.28, 0.29" Distributions = "0.25" DocStringExtensions = "0.8, 0.9" Dynesty = "0.4" -Enzyme = "0.13.0 - 0.13.11" +Enzyme = "0.13" EnzymeCore = "0.8" FillArrays = "1" ForwardDiff = "0.9, 0.10" diff --git a/src/instrument/site_array.jl b/src/instrument/site_array.jl index 0e53ef3e..92ec521a 100644 --- a/src/instrument/site_array.jl +++ b/src/instrument/site_array.jl @@ -136,17 +136,17 @@ end function Base.getindex(arr::SiteArray; F=Base.Colon(), S=Base.Colon(), T=Base.Colon()) T2 = _maybe_all(times(arr), T) F2 = _maybe_all(frequencies(arr), F) - S2 = S isa Base.Colon ? unique(S) : S + S2 = S isa Base.Colon ? unique(sites(arr)) : S return select_region(arr, S2, T2, F2) end -function select_region(arr::SiteArray, S::Symbol, T::Union{IntegrationTime, AbstractInterval}, F::Union{FrequencyChannel, AbstractInterval}) +function select_region(arr::SiteArray, S::Symbol, T::Union{Real, AbstractInterval}, F::Union{Real, AbstractInterval}) select_region(arr, (S,), T, F) end -function select_region(arr::SiteArray, site, T::Union{IntegrationTime, AbstractInterval}, F::Union{FrequencyChannel, AbstractInterval}) - inds = findall(i->((Comrade.sites(arr)[i] ∈ site)&&(Comrade.times(arr)[i] ∈ T)), eachindex(arr)) +function select_region(arr::SiteArray, site, Ti::Union{Real, AbstractInterval}, Fr::Union{Real, AbstractInterval}) + inds = findall(i->((Comrade.sites(arr)[i] ∈ site)&&(Ti ∈ Comrade.times(arr)[i])&&(Fr ∈ Comrade.frequencies(arr)[i])), eachindex(arr)) nd = view(parent(arr), inds) return SiteArray(nd, view(times(arr), inds), view(frequencies(arr), inds), view(sites(arr), inds)) end From 116d83d90701ec75ab80e1a234e6bc31914b038f Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sun, 15 Dec 2024 20:08:35 -0500 Subject: [PATCH 02/34] Small adjustment to gradient --- ext/ComradeEnzymeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ComradeEnzymeExt.jl b/ext/ComradeEnzymeExt.jl index 0eca7e49..9ae85532 100644 --- a/ext/ComradeEnzymeExt.jl +++ b/ext/ComradeEnzymeExt.jl @@ -12,7 +12,7 @@ function LogDensityProblems.logdensity_and_gradient(d::Comrade.TransformedVLBIPo mode = Enzyme.EnzymeCore.WithPrimal(Comrade.admode(d)) dx = zero(x) y = fetch(schedule( - Task(32*1024*2014) do + Task(32*1024*1024) do (_, y) = autodiff(mode, Comrade.logdensityof, Active, Const(d), Duplicated(x, dx)) return y end From 8baa5c091c7e64a4d07d41e93ccc5d8284977323 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sun, 15 Dec 2024 20:07:03 -0500 Subject: [PATCH 03/34] Fix up site array --- src/instrument/site_array.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/instrument/site_array.jl b/src/instrument/site_array.jl index 0e53ef3e..92ec521a 100644 --- a/src/instrument/site_array.jl +++ b/src/instrument/site_array.jl @@ -136,17 +136,17 @@ end function Base.getindex(arr::SiteArray; F=Base.Colon(), S=Base.Colon(), T=Base.Colon()) T2 = _maybe_all(times(arr), T) F2 = _maybe_all(frequencies(arr), F) - S2 = S isa Base.Colon ? unique(S) : S + S2 = S isa Base.Colon ? unique(sites(arr)) : S return select_region(arr, S2, T2, F2) end -function select_region(arr::SiteArray, S::Symbol, T::Union{IntegrationTime, AbstractInterval}, F::Union{FrequencyChannel, AbstractInterval}) +function select_region(arr::SiteArray, S::Symbol, T::Union{Real, AbstractInterval}, F::Union{Real, AbstractInterval}) select_region(arr, (S,), T, F) end -function select_region(arr::SiteArray, site, T::Union{IntegrationTime, AbstractInterval}, F::Union{FrequencyChannel, AbstractInterval}) - inds = findall(i->((Comrade.sites(arr)[i] ∈ site)&&(Comrade.times(arr)[i] ∈ T)), eachindex(arr)) +function select_region(arr::SiteArray, site, Ti::Union{Real, AbstractInterval}, Fr::Union{Real, AbstractInterval}) + inds = findall(i->((Comrade.sites(arr)[i] ∈ site)&&(Ti ∈ Comrade.times(arr)[i])&&(Fr ∈ Comrade.frequencies(arr)[i])), eachindex(arr)) nd = view(parent(arr), inds) return SiteArray(nd, view(times(arr), inds), view(frequencies(arr), inds), view(sites(arr), inds)) end From d71f3c28f9d492dfbcab5e690369ecbc3b2eea9a Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sun, 15 Dec 2024 20:08:35 -0500 Subject: [PATCH 04/34] Small adjustment to gradient --- ext/ComradeEnzymeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ComradeEnzymeExt.jl b/ext/ComradeEnzymeExt.jl index 0eca7e49..9ae85532 100644 --- a/ext/ComradeEnzymeExt.jl +++ b/ext/ComradeEnzymeExt.jl @@ -12,7 +12,7 @@ function LogDensityProblems.logdensity_and_gradient(d::Comrade.TransformedVLBIPo mode = Enzyme.EnzymeCore.WithPrimal(Comrade.admode(d)) dx = zero(x) y = fetch(schedule( - Task(32*1024*2014) do + Task(32*1024*1024) do (_, y) = autodiff(mode, Comrade.logdensityof, Active, Const(d), Duplicated(x, dx)) return y end From 5d4d2b06a8f17839597b71377afd4ba1b2cc9a1f Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sun, 15 Dec 2024 23:13:39 -0500 Subject: [PATCH 05/34] Add frequency segmentation and make SiteArray indexing work --- src/instrument/instrument.jl | 16 ++++++++- src/instrument/model.jl | 4 +-- src/instrument/priors/array_priors.jl | 21 ++++++++---- src/instrument/priors/segmentation.jl | 30 +++++++++++++---- src/instrument/site_array.jl | 48 +++++++++++---------------- 5 files changed, 75 insertions(+), 44 deletions(-) diff --git a/src/instrument/instrument.jl b/src/instrument/instrument.jl index ff111549..cec0b463 100644 --- a/src/instrument/instrument.jl +++ b/src/instrument/instrument.jl @@ -10,12 +10,15 @@ struct IntegrationTime{T} dt::T end -Base.in(t::Number, ts::IntegrationTime) = (ts.t0 - ts.dt/2) ≤ t < (ts.t0 + ts.dt) +Base.in(t::Number, ts::IntegrationTime) = (ts.t0 - ts.dt/2) ≤ t < (ts.t0 + ts.dt/2) Base.isless(t::IntegrationTime, ts::IntegrationTime) = t.t0 < ts.t0 Base.isless(s::Number, t::IntegrationTime) = s < (t.t0 - t.dt/2) Base.isless(t::IntegrationTime, s::Number) = (t.t0 + t.dt/2) < s mjd(ts::IntegrationTime) = ts.mjd +Base.in(t::IntegrationTime, ::Base.Colon) = true +_center(ts::IntegrationTime) = ts.t0 +_region(ts::IntegrationTime) = ts.dt struct FrequencyChannel{T, I<:Integer} central::T @@ -24,6 +27,17 @@ struct FrequencyChannel{T, I<:Integer} end Base.in(f::Number, fs::FrequencyChannel) = (fs.central-fs.bandwidth/2) ≤ f < (fs.central+fs.bandwidth/2) channel(fs::FrequencyChannel) = fs.channel +Base.isless(t::FrequencyChannel, ts::FrequencyChannel) = _center(t) < _center(ts) +Base.isless(s::Number, t::FrequencyChannel) = s < (_center(t) - _region(t)/2) +Base.isless(t::FrequencyChannel, s::Number) = (_center(t) + _region(t)/2) < s +Base.in(x, t::FrequencyChannel) = in(t, x) +Base.in(t::FrequencyChannel, ::Base.Colon) = true + + + +_center(fs::FrequencyChannel) = fs.central +_region(fs::FrequencyChannel) = fs.bandwidth + include("site_array.jl") diff --git a/src/instrument/model.jl b/src/instrument/model.jl index 423492e8..abf331f1 100644 --- a/src/instrument/model.jl +++ b/src/instrument/model.jl @@ -195,8 +195,8 @@ function _construct_baselinemap(T, F, bl, x::SiteArray) t = T[i] f = F[i] s1, s2 = bl[i] - i1 = findall(x->(t∈x[1])&&(x[2]==s1), tsf) - i2 = findall(x->(t∈x[1])&&(x[2]==s2), tsf) + i1 = findall(x->(t∈x[1])&&(x[2]==s1)&&(f∈x[3]), tsf) + i2 = findall(x->(t∈x[1])&&(x[2]==s2)&&(f∈x[3]), tsf) length(i1) > 1 && throw(AssertionError("Multiple indices found for $t, $((s1)) in SiteArray")) length(i2) > 1 && throw(AssertionError("Multiple indices found for $t, $((s2)) in SiteArray")) (isnothing(i1) | isempty(i1)) && throw(AssertionError("$t, $f, $((s1)) not found in SiteArray")) diff --git a/src/instrument/priors/array_priors.jl b/src/instrument/priors/array_priors.jl index 910592ca..a7220e9c 100644 --- a/src/instrument/priors/array_priors.jl +++ b/src/instrument/priors/array_priors.jl @@ -71,29 +71,38 @@ function build_sitemap(d::ArrayPrior, array) # Ok to so this we are going to construct the schema first over sites. # At the end we may re-order depending on the schema ordering we want # to use. - tlists = map(keys(sites_prior)) do s + lists = map(keys(sites_prior)) do s seg = segmentation(sites_prior[s]) # get all the indices where this site is present inds_s = findall(x->((x[1]==s)||x[2]==s), array[:sites]) # Get all the unique times ts = unique(T[inds_s]) + fs = unique(F[inds_s]) # Now makes the acceptable time stamps given the segmentation tstamp = timestamps(seg, array) + fchan = freqchannels(SpectralWindow(), array) # Now we find commonalities times = eltype(tstamp)[] - for t in tstamp - if any(x->x∈t, ts) && (!(t.t0 ∈ times)) + freqs = eltype(fchan)[] + for t in tstamp, f in fchan + if any(x->x∈t, ts) && any(x->x∈f, fs) && ((!(t.t0 ∈ times)) || (!(f.central ∈ freqs))) push!(times, t) + push!(freqs, f) end end - return times + return times, freqs end + tlists = first.(lists) + flists = last.(lists) # construct the schema slist = mapreduce((t,s)->fill(s, length(t)), vcat, tlists, keys(sites_prior)) tlist = reduce(vcat, tlists) + flist = reduce(vcat, flists) + tlistre = similar(tlist) slistre = similar(slist) + flistre = similar(flist) # Now rearrange so we have time site ordering (sites are the fastest changing) tuni = sort(unique(getproperty.(tlist, :t0))) ind0 = 1 @@ -101,10 +110,10 @@ function build_sitemap(d::ArrayPrior, array) ind = findall(x->x.t0==t, tlist) tlistre[ind0:ind0+length(ind)-1] .= tlist[ind] slistre[ind0:ind0+length(ind)-1] .= slist[ind] + flistre[ind0:ind0+length(ind)-1] .= flist[ind] ind0 += length(ind) end - freqs = Fill(F[1], length(tlistre)) - return SiteLookup(tlistre, freqs, slistre) + return SiteLookup(tlistre, flistre, slistre) end function ObservedArrayPrior(d::ArrayPrior, array::EHTArrayConfiguration) diff --git a/src/instrument/priors/segmentation.jl b/src/instrument/priors/segmentation.jl index 3b6eb7d8..16ff3aeb 100644 --- a/src/instrument/priors/segmentation.jl +++ b/src/instrument/priors/segmentation.jl @@ -11,13 +11,17 @@ the following functions: """ abstract type Segmentation end +abstract type TimeSegmentation <: Segmentation end + +abstract type FrequencySegmentation <: Segmentation end + # Track is for quantities that remain static across an entire observation """ $(TYPEDEF) Data segmentation such that the quantity is constant over a `track`, i.e., the observation "night". """ -struct TrackSeg <: Segmentation end +struct TrackSeg <: TimeSegmentation end # Scan is for quantities that are constant across a scan """ @@ -25,7 +29,7 @@ struct TrackSeg <: Segmentation end Data segmentation such that the quantity is constant over a `scan`. """ -struct ScanSeg <: Segmentation end +struct ScanSeg <: TimeSegmentation end # Integration is for quantities that change every integration time @@ -35,10 +39,10 @@ struct ScanSeg <: Segmentation end Data segmentation such that the quantity is constant over the time stamps in the data. If the data is scan-averaged before then `IntegSeg` will be identical to `ScanSeg`. """ -struct IntegSeg <: Segmentation end +struct IntegSeg <: TimeSegmentation end """ - timestamps(seg::Segmentation, array::AbstractArrayConfiguration) + timestamps(seg::TimeSegmentation, array::AbstractArrayConfiguration) Return the time stamps or really a vector of integration time regions for a given segmentation scheme `seg` and array configuration `array`. @@ -48,10 +52,11 @@ function timestamps end function timestamps(::ScanSeg, array) st = array.scans mjd = array.mjd - # Shift the central time to the middle of the scan dt = (st.stop .- st.start) + dt[end] = dt[end]+0.5 t0 = st.start .+ dt./2 + return IntegrationTime.(mjd, t0, dt) end @@ -60,7 +65,6 @@ end function timestamps(::IntegSeg, array) ts = unique(array[:Ti]) - st = array.scans mjd = array.mjd # TODO build in the dt into the data format @@ -74,10 +78,10 @@ function timestamps(::IntegSeg, array) end function timestamps(::TrackSeg, array) - st = array.scans mjd = array.mjd tstart, tend = extrema(array[:Ti]) + tend = tend + 1 dt = tend - tstart if iszero(dt) dt = 1/3600 @@ -85,3 +89,15 @@ function timestamps(::TrackSeg, array) # TODO build in the dt into the data format return (IntegrationTime(mjd, (tend-tstart)/2 + tstart, dt),) end + + +struct SpectralWindow <: FrequencySegmentation end + + +function freqchannels(::SpectralWindow, array) + Fr = unique(array[:Fr]) + if length(Fr) > 1 + Fr[end] = nextfloat(Fr[end]) + end + return FrequencyChannel.(Fr, array.bandwidth, 1:length(Fr)) +end diff --git a/src/instrument/site_array.jl b/src/instrument/site_array.jl index 92ec521a..3a49a758 100644 --- a/src/instrument/site_array.jl +++ b/src/instrument/site_array.jl @@ -16,7 +16,7 @@ which will grab the first 10 time and frequency points for the ALMA site. Otherwise indexing into the array will return an element whose time, frequency, and site are the element of the `times`, `frequencies`, and `sites` arrays respectively. """ -struct SiteArray{T, N, A<:AbstractArray{T,N}, Ti<:AbstractArray{<:IntegrationTime, N}, Fr<:AbstractArray{<:Number, N}, Sy<:AbstractArray{<:Any, N}} <: AbstractArray{T, N} +struct SiteArray{T, N, A<:AbstractArray{T,N}, Ti<:AbstractArray{<:IntegrationTime, N}, Fr<:AbstractArray{<:FrequencyChannel, N}, Sy<:AbstractArray{<:Any, N}} <: AbstractArray{T, N} data::A times::Ti frequencies::Fr @@ -111,47 +111,36 @@ function site(arr::SiteArray, p) end -function time(arr::SiteArray, a::Union{AbstractInterval, IntegrationTime}) +function time(arr::SiteArray, a::Union{AbstractInterval, Real}) inds = findall(in(a), times(arr)) nd = view(parent(arr), inds) return SiteArray(nd, view(times(arr), inds), view(frequencies(arr), inds), view(sites(arr), inds)) end -function frequency(arr::SiteArray, a::Union{AbstractInterval, FrequencyChannel}) +function frequency(arr::SiteArray, a::Union{AbstractInterval, Real}) inds = findall(in(a), times(arr)) nd = view(parent(arr), inds) return SiteArray(nd, view(times(arr), inds), view(frequencies(arr), inds), view(sites(arr), inds)) end +_equalorin(x::T, y::T) where {T} = x == y +_equalorin(x::Real, y) = x ∈ y +_equalorin(x, y::Real) = y ∈ x +_equalorin(x, y) = y ∈ x +_equalorin(x, ::typeof(Base.Colon())) = true +const Indexable = Union{Integer, AbstractArray{<:Integer}, BitArray} -@inline function _maybe_all(arr, X) - if X isa Base.Colon - ext = extrema(arr) - return ClosedInterval(ext[1], ext[2]) - else - return X - end -end - -function Base.getindex(arr::SiteArray; F=Base.Colon(), S=Base.Colon(), T=Base.Colon()) - T2 = _maybe_all(times(arr), T) - F2 = _maybe_all(frequencies(arr), F) - S2 = S isa Base.Colon ? unique(sites(arr)) : S - return select_region(arr, S2, T2, F2) -end - -function select_region(arr::SiteArray, S::Symbol, T::Union{Real, AbstractInterval}, F::Union{Real, AbstractInterval}) - select_region(arr, (S,), T, F) -end - - -function select_region(arr::SiteArray, site, Ti::Union{Real, AbstractInterval}, Fr::Union{Real, AbstractInterval}) - inds = findall(i->((Comrade.sites(arr)[i] ∈ site)&&(Ti ∈ Comrade.times(arr)[i])&&(Fr ∈ Comrade.frequencies(arr)[i])), eachindex(arr)) +function Base.getindex(arr::SiteArray; Fr=Base.Colon(), S=Base.Colon(), Ti=Base.Colon()) + Fr2 = isa(Fr, Indexable) ? unique(arr.frequencies)[Fr] : Fr + S2 = isa(S, Indexable) ? unique(arr.sites)[S] : S + Ti2 = isa(Ti, Indexable) ? unique(arr.times)[Ti] : Ti + inds = findall(i->(_equalorin(S2, Comrade.sites(arr)[i])&&_equalorin(Ti2, Comrade.times(arr)[i])&&_equalorin(Fr2, Comrade.frequencies(arr)[i])), eachindex(arr)) nd = view(parent(arr), inds) return SiteArray(nd, view(times(arr), inds), view(frequencies(arr), inds), view(sites(arr), inds)) end -struct SiteLookup{L<:NamedTuple, N, Ti<:AbstractArray{<:IntegrationTime, N}, Fr<:AbstractArray{<:Number, N}, Sy<:AbstractArray{<:Any, N}} + +struct SiteLookup{L<:NamedTuple, N, Ti<:AbstractArray{<:IntegrationTime, N}, Fr<:AbstractArray{<:FrequencyChannel, N}, Sy<:AbstractArray{<:Any, N}} lookup::L times::Ti frequencies::Fr @@ -204,7 +193,10 @@ function SiteArray(a::AbstractArray, map::SiteLookup) return SiteArray(a, map.times, map.frequencies, map.sites) end -function SiteArray(data::SiteArray{T, N}, times::AbstractArray{<:IntegrationTime, N}, frequencies::AbstractArray{<:Number, N}, sites::AbstractArray{<:Number, N}) where {T, N} +function SiteArray(data::SiteArray{T, N}, + times::AbstractArray{<:IntegrationTime, N}, + frequencies::AbstractArray{<:FrequencyChannel, N}, + sites::AbstractArray{<:Number, N}) where {T, N} return data end From 7d5946b76ef9ea938de3849240cafb32fdf75684 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sun, 15 Dec 2024 23:38:36 -0500 Subject: [PATCH 06/34] Fix indexing again --- src/instrument/site_array.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/instrument/site_array.jl b/src/instrument/site_array.jl index 3a49a758..ce79d851 100644 --- a/src/instrument/site_array.jl +++ b/src/instrument/site_array.jl @@ -128,6 +128,7 @@ _equalorin(x::Real, y) = x ∈ y _equalorin(x, y::Real) = y ∈ x _equalorin(x, y) = y ∈ x _equalorin(x, ::typeof(Base.Colon())) = true +_equalorin(::typeof(Base.Colon()), x) = true const Indexable = Union{Integer, AbstractArray{<:Integer}, BitArray} function Base.getindex(arr::SiteArray; Fr=Base.Colon(), S=Base.Colon(), Ti=Base.Colon()) From c5b108543aa64dda51433520ddd85cd69d7bdef1 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Mon, 16 Dec 2024 10:20:11 -0500 Subject: [PATCH 07/34] Small bug fix to prevent SO --- src/instrument/instrument.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/instrument/instrument.jl b/src/instrument/instrument.jl index cec0b463..6ac95678 100644 --- a/src/instrument/instrument.jl +++ b/src/instrument/instrument.jl @@ -10,12 +10,11 @@ struct IntegrationTime{T} dt::T end +mjd(ts::IntegrationTime) = ts.mjd Base.in(t::Number, ts::IntegrationTime) = (ts.t0 - ts.dt/2) ≤ t < (ts.t0 + ts.dt/2) Base.isless(t::IntegrationTime, ts::IntegrationTime) = t.t0 < ts.t0 Base.isless(s::Number, t::IntegrationTime) = s < (t.t0 - t.dt/2) Base.isless(t::IntegrationTime, s::Number) = (t.t0 + t.dt/2) < s -mjd(ts::IntegrationTime) = ts.mjd -Base.in(t::IntegrationTime, ::Base.Colon) = true _center(ts::IntegrationTime) = ts.t0 _region(ts::IntegrationTime) = ts.dt @@ -25,13 +24,11 @@ struct FrequencyChannel{T, I<:Integer} bandwidth::T channel::I end -Base.in(f::Number, fs::FrequencyChannel) = (fs.central-fs.bandwidth/2) ≤ f < (fs.central+fs.bandwidth/2) channel(fs::FrequencyChannel) = fs.channel +Base.in(f::Number, fs::FrequencyChannel) = (fs.central-fs.bandwidth/2) ≤ f < (fs.central+fs.bandwidth/2) Base.isless(t::FrequencyChannel, ts::FrequencyChannel) = _center(t) < _center(ts) Base.isless(s::Number, t::FrequencyChannel) = s < (_center(t) - _region(t)/2) Base.isless(t::FrequencyChannel, s::Number) = (_center(t) + _region(t)/2) < s -Base.in(x, t::FrequencyChannel) = in(t, x) -Base.in(t::FrequencyChannel, ::Base.Colon) = true From eab1d251c5c7ea9c9fe81eb8a6d75abbc589cabc Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Tue, 17 Dec 2024 10:56:06 -0500 Subject: [PATCH 08/34] Fix forward_jones to work with frequency dependent instrument terms --- src/instrument/instrument.jl | 2 ++ src/instrument/jonesmatrices.jl | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/instrument/instrument.jl b/src/instrument/instrument.jl index 6ac95678..ebddc7d9 100644 --- a/src/instrument/instrument.jl +++ b/src/instrument/instrument.jl @@ -11,6 +11,7 @@ struct IntegrationTime{T} end mjd(ts::IntegrationTime) = ts.mjd +interval(ts::IntegrationTime) = (ts.t0 - ts.dt/2)..(ts.t0 + ts.dt/2) Base.in(t::Number, ts::IntegrationTime) = (ts.t0 - ts.dt/2) ≤ t < (ts.t0 + ts.dt/2) Base.isless(t::IntegrationTime, ts::IntegrationTime) = t.t0 < ts.t0 Base.isless(s::Number, t::IntegrationTime) = s < (t.t0 - t.dt/2) @@ -25,6 +26,7 @@ struct FrequencyChannel{T, I<:Integer} channel::I end channel(fs::FrequencyChannel) = fs.channel +interval(fs::FrequencyChannel) = (fs.central - fs.bandwidth/2)..(fs.central + fs.bandwidth/2) Base.in(f::Number, fs::FrequencyChannel) = (fs.central-fs.bandwidth/2) ≤ f < (fs.central+fs.bandwidth/2) Base.isless(t::FrequencyChannel, ts::FrequencyChannel) = _center(t) < _center(ts) Base.isless(s::Number, t::FrequencyChannel) = s < (_center(t) - _region(t)/2) diff --git a/src/instrument/jonesmatrices.jl b/src/instrument/jonesmatrices.jl index 2042b4c6..274e7a11 100644 --- a/src/instrument/jonesmatrices.jl +++ b/src/instrument/jonesmatrices.jl @@ -225,7 +225,7 @@ of `xs, and whose elements are the jones matrices for the specific parameters. function forward_jones(v::AbstractJonesMatrix, xs::NamedTuple{N}) where {N} sm = broadest_sitemap(xs) bl = map(x->(x,x), sm.sites) - bmaps = map(x->_construct_baselinemap(getproperty.(sm.times, :t0), sm.frequencies, bl, x).indices_1, xs) + bmaps = map(x->_construct_baselinemap(getproperty.(sm.times, :t0), getproperty.(sm.frequencies, :central), bl, x).indices_1, xs) vs = map(eachindex(sm.times)) do index indices = map(x->getindex(x, index), bmaps) params = NamedTuple{N}(map(getindex, values(xs), values(indices))) From 2d08fe23169bb1b3f2bb6ed11b97d2015c05b3f9 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Tue, 17 Dec 2024 11:13:55 -0500 Subject: [PATCH 09/34] Update to Pyehtim 0.2 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 501004be..09af5ad1 100644 --- a/Project.toml +++ b/Project.toml @@ -89,7 +89,7 @@ ParameterHandling = "0.4, 0.5" Pigeons = "0.3, 0.4" PolarizedTypes = "0.1" PrettyTables = "1, 2" -Pyehtim = "0.1" +Pyehtim = "0.2" RecipesBase = "1" Reexport = "1" SpecialFunctions = "0.10, 1, 2" From 8ecc3ad475430b2863c192ae057ed07a2ffe5590 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Tue, 17 Dec 2024 11:56:38 -0500 Subject: [PATCH 10/34] Fix Pyehtim extension --- ext/ComradePyehtimExt.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/ext/ComradePyehtimExt.jl b/ext/ComradePyehtimExt.jl index e357ddf1..b3064968 100644 --- a/ext/ComradePyehtimExt.jl +++ b/ext/ComradePyehtimExt.jl @@ -14,8 +14,8 @@ function build_arrayconfig(obs) source = get_source(obsc) bw = get_bw(obsc) angles = get_fr_angles(obsc) - tarr = Pyehtim.get_arraytable(obsc) - scans = get_scantable(obsc) + tarr = StructArray(Pyehtim.get_arraytable(obsc)) + scans = StructArray(get_scantable(obsc)) bw = get_bw(obsc) elevation = StructArray(angles[1]) parallactic = StructArray(angles[2]) @@ -140,7 +140,6 @@ Returns an EHTObservationTable with visibility amplitude data """ function Comrade.extract_amp(obsc; pol=:I, debias=false, kwargs...) obs = obsc.copy() - obs.add_scans() obs.reorder_tarr_snr() obs.add_amp(;debias, kwargs...) config = build_arrayconfig(obs) @@ -163,7 +162,6 @@ Returns an EHTObservationTable with complex visibility data """ function Comrade.extract_vis(obsc; pol=:I, kwargs...) obs = obsc.copy() - obs.add_scans() obs.reorder_tarr_snr() config = build_arrayconfig(obs) vis, viserr = getvisfield(obs) @@ -181,7 +179,6 @@ Returns an EHTObservationTable with coherency matrices """ function Comrade.extract_coherency(obsc; kwargs...) obs = obsc.copy() - obs.add_scans() obs.reorder_tarr_snr() config = build_arrayconfig(obs) vis, viserr = getcoherency(obs) From 0e4800224474621add55fe93c36d9d5823ae0650 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Tue, 17 Dec 2024 23:02:03 -0500 Subject: [PATCH 11/34] Allow Comrade to work with new Pyehtim version --- ext/ComradePyehtimExt.jl | 6 ++++++ test/Project.toml | 1 - test/test_util.jl | 1 + 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/ext/ComradePyehtimExt.jl b/ext/ComradePyehtimExt.jl index b3064968..90063c9b 100644 --- a/ext/ComradePyehtimExt.jl +++ b/ext/ComradePyehtimExt.jl @@ -15,6 +15,12 @@ function build_arrayconfig(obs) bw = get_bw(obsc) angles = get_fr_angles(obsc) tarr = StructArray(Pyehtim.get_arraytable(obsc)) + + # This is because sometimes eht-imaging sets the scans to nothing + # and sometimes it fills it with junk + if length(obsc.scans) <= 1 + obsc.add_scans() + end scans = StructArray(get_scantable(obsc)) bw = get_bw(obsc) elevation = StructArray(angles[1]) diff --git a/test/Project.toml b/test/Project.toml index b4519898..4259d1c3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -37,4 +37,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff" VLBIImagePriors = "b1ba175b-8447-452c-b961-7db2d6f7a029" VLBISkyModels = "d6343c73-7174-4e0f-bb64-562643efbeca" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/test_util.jl b/test/test_util.jl index cf60e055..ae916459 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -8,6 +8,7 @@ function load_data() joinpath(@__DIR__, "../examples/Data/array.txt"), polrep="circ" ) + obspol.add_scans() m = ehtim.model.Model() m = m.add_gauss(1.0, μas2rad(40.0), μas2rad(20.0), π/3, 0.0, 0.0) From bb877ac9bd1ff43b14c5282d1136cd6520625da7 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Wed, 18 Dec 2024 11:15:24 -0500 Subject: [PATCH 12/34] Make sitelookup aware of difference frequencies --- src/instrument/instrument_transforms.jl | 25 +++++++++--- src/instrument/priors/array_priors.jl | 15 ++++--- src/instrument/priors/refant.jl | 11 +++-- src/instrument/site_array.jl | 41 ++++++++++++++++--- test/Core/models.jl | 54 +++++++++++++++++++++++++ 5 files changed, 124 insertions(+), 22 deletions(-) diff --git a/src/instrument/instrument_transforms.jl b/src/instrument/instrument_transforms.jl index 1e02167a..580cccf4 100644 --- a/src/instrument/instrument_transforms.jl +++ b/src/instrument/instrument_transforms.jl @@ -5,9 +5,12 @@ inner_transform(t::AbstractInstrumentTransform) = t.inner_transform function TV.transform_with(flag::TV.LogJacFlag, m::AbstractInstrumentTransform, x, index) y, ℓ, index = _instrument_transform_with(flag, m, x, index) sm = m.site_map - return SiteArray(y, sm.times, sm.frequencies, sm.sites), ℓ, index + return SiteArray(y, sm), ℓ, index end +EnzymeRules.inactive_type(::Type{AbstractArray{<:IntegrationTime}}) = true +EnzymeRules.inactive_type(::Type{AbstractArray{<:FrequencyChannel}}) = true + function TV.inverse_at!(x::AbstractArray, index, t::AbstractInstrumentTransform, y::SiteArray) itrf = inner_transform(t) return TV.inverse_at!(x, index, itrf, parent(y)) @@ -50,15 +53,25 @@ end return yout, ℓ, index end +EnzymeRules.inactive_type(::Type{<:SiteLookup}) = true + @inline function site_sum(y, site_map::SiteLookup) - yout = similar(y) - @inbounds for site in site_map.lookup + # yout = similar(y) + vals = values(lookup(site_map)) + @inbounds for site in vals + # i0 = site[begin] + # yout[i0] = y[i0] + # acc = zero(eltype(y)) + # for idx in site + # acc += y[idx] + # yout[idx] = acc + # end ys = @view y[site] # y should never alias so we should be fine here. - youts = @view yout[site] - cumsum!(youts, ys) + # youts = @view yout[site] + cumsum!(ys, ys) end - return yout + return y end # function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_instrument_transform_with), flag, m::MarkovInstrumentTransform, x, index) diff --git a/src/instrument/priors/array_priors.jl b/src/instrument/priors/array_priors.jl index a7220e9c..f5b36aa5 100644 --- a/src/instrument/priors/array_priors.jl +++ b/src/instrument/priors/array_priors.jl @@ -76,21 +76,20 @@ function build_sitemap(d::ArrayPrior, array) # get all the indices where this site is present inds_s = findall(x->((x[1]==s)||x[2]==s), array[:sites]) # Get all the unique times - ts = unique(T[inds_s]) - fs = unique(F[inds_s]) + ts = T[inds_s] + fs = F[inds_s] + tfs = zip(ts, fs) # Now makes the acceptable time stamps given the segmentation tstamp = timestamps(seg, array) fchan = freqchannels(SpectralWindow(), array) # Now we find commonalities - times = eltype(tstamp)[] - freqs = eltype(fchan)[] + tf = Tuple{eltype(tstamp), eltype(fchan)}[] for t in tstamp, f in fchan - if any(x->x∈t, ts) && any(x->x∈f, fs) && ((!(t.t0 ∈ times)) || (!(f.central ∈ freqs))) - push!(times, t) - push!(freqs, f) + if any(x->(x[1]∈t && x[2]∈f), tfs) && ((!((t.t0, f.central) ∈ tf))) + push!(tf, (t, f)) end end - return times, freqs + return first.(tf), last.(tf) end tlists = first.(lists) flists = last.(lists) diff --git a/src/instrument/priors/refant.jl b/src/instrument/priors/refant.jl index 8585f579..235e6de2 100644 --- a/src/instrument/priors/refant.jl +++ b/src/instrument/priors/refant.jl @@ -43,14 +43,19 @@ end function reference_indices(array::AbstractArrayConfiguration, st::SiteLookup, r::SEFDReference) tarr = array.tarr t = unique(st.times) + f = unique(st.frequencies) sefd = NamedTuple{Tuple(tarr.sites)}(Tuple(tarr.SEFD1 .+ tarr.SEFD2)) - fixedinds = map(eachindex(t)) do i - inds = findall(==(t[i]), st.times) + fixedinds = Int[] + for i in eachindex(t), j in eachindex(f) + inds = findall(x->((st.times[x]==t[i])&&(st.frequencies[x]==f[j])), eachindex(st.times)) + if isempty(inds) + continue + end sites = Tuple(st.sites[inds]) @assert length(sites) <= length(sefd) "Error in reference site generation. Too many sites" sp = select(sefd, sites) _, ind = findmin(values(sp)) - return inds[ind] + push!(fixedinds, inds[ind]) end return fixedinds, Fill(r.value, length(fixedinds)) end diff --git a/src/instrument/site_array.jl b/src/instrument/site_array.jl index ce79d851..23634326 100644 --- a/src/instrument/site_array.jl +++ b/src/instrument/site_array.jl @@ -148,8 +148,18 @@ struct SiteLookup{L<:NamedTuple, N, Ti<:AbstractArray{<:IntegrationTime, N}, Fr< sites::Sy end +times(s::SiteLookup) = s.times +sites(s::SiteLookup) = s.sites +frequencies(s::SiteLookup) = s.frequencies +lookup(s::SiteLookup) = s.lookup + +EnzymeRules.inactive(::typeof(times), ::SiteLookup) = nothing +EnzymeRules.inactive(::typeof(frequencies), ::SiteLookup) = nothing +EnzymeRules.inactive(::typeof(sites), ::SiteLookup) = nothing +EnzymeRules.inactive(::typeof(lookup), ::SiteLookup) = nothing + function sitemap!(f, out::AbstractArray, gains::AbstractArray, slook::SiteLookup) - map(slook.lookup) do site + map(lookup(slook)) do site ysite = @view gains[site] outsite = @view out[site] outsite .= f.(ysite) @@ -163,9 +173,10 @@ function sitemap(f, gains::AbstractArray{T}, slook::SiteLookup) where {T} end function sitemap!(::typeof(cumsum), out::AbstractArray, gains::AbstractArray, slook::SiteLookup) - map(slook.lookup) do site + map(lookup(slook)) do site ys = @view gains[site] cumsum!(ys, ys) + nothing end return out end @@ -180,8 +191,28 @@ function SiteLookup(s::SiteArray) end function SiteLookup(times::AbstractVector, frequencies::AbstractArray, sites::AbstractArray) - slist = Tuple(sort(unique(sites))) - return SiteLookup(NamedTuple{slist}(map(p->findall(==(p), sites), slist)), times, frequencies, sites) + slist = sort(unique(sites)) + flist = sort(unique(frequencies)) + if length(flist) == 1 + return SiteLookup(NamedTuple{Tuple(slist)}(map(p->findall(==(p), sites), slist)), times, frequencies, sites) + else + # Find sites first + sflist = Symbol[] + inds = Vector{Int}[] + for s in slist + sinds = findall(==(s), sites) + for (i,f) in enumerate(flist) + finds = findall(==(f), @view(frequencies[sinds])) + if !isempty(finds) + ss = Symbol(s, i) + push!(sflist, ss) + push!(inds, finds) + end + end + end + lookup = NamedTuple{Tuple(sflist)}(Tuple(inds)) + return SiteLookup(lookup, times, frequencies, sites) + end end """ @@ -191,7 +222,7 @@ Construct a site array with the entries `arr` and the site ordering implied by `sitelookup`. """ function SiteArray(a::AbstractArray, map::SiteLookup) - return SiteArray(a, map.times, map.frequencies, map.sites) + return SiteArray(a, times(map), frequencies(map), sites(map)) end function SiteArray(data::SiteArray{T, N}, diff --git a/test/Core/models.jl b/test/Core/models.jl index e7f7a80b..c4a8f3a8 100644 --- a/test/Core/models.jl +++ b/test/Core/models.jl @@ -408,6 +408,60 @@ end end + + @testset "Multifrequency" begin + dcoh2 = deepcopy(dcoh) + dcoh2.config[:Fr][200:end] .= 345e9 + vis = CoherencyMatrix.(Comrade.measurement(dcoh2), Ref(CirBasis())) + + G = JonesG() do x + gR = exp(x.lgR + 1im*x.gpR) + gL = gR*exp(x.lgrat + 1im*x.gprat) + return gR, gL + end + + D = JonesD() do x + dR = complex(x.dRx, x.dRy) + dL = complex(x.dLx, x.dLy) + return dR, dL + end + + R = JonesR(;add_fr=true) + J = JonesSandwich(*, G, D, R) + intprior = ( + lgR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))), + gpR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, inv(π ^2))); phase=true, refant=SEFDReference(0.0)), + lgrat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1)), phase=false), + gprat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))), + dRx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), + dRy = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), + dLx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), + dLy = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), + ) + intm = InstrumentModel(J, intprior) + ointm, printm = Comrade.set_array(intm, arrayconfig(dcoh)) + + @testset "ObservedArrayPrior" begin + @inferred logpdf(printm, rand(printm)) + x = rand(printm) + @test logpdf(printm, x) ≈ logpdf(printm2, x) + @test asflat(printm) isa TV.AbstractTransform + p = rand(printm) + t = asflat(printm) + pout = TV.transform(t, TV.inverse(t, p)) + dp = ntequal(p, pout) + @test dp.lgR + @test dp.lgrat + @test dp.gprat + @test dp.dRx + @test dp.dRy + @test dp.dLx + @test dp.dLy + end + + + end + @testset "Integration" begin _,dvis, amp, lcamp, cphase, dcoh = load_data() ts = Comrade.timestamps(ScanSeg(), arrayconfig(dvis)) From d6d2744dc5f85cb6e28995c52eb93405dca3a545 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Wed, 18 Dec 2024 17:10:07 -0500 Subject: [PATCH 13/34] Fix dependencies in examples --- examples/advanced/HybridImaging/Project.toml | 2 +- examples/beginner/GeometricModeling/Project.toml | 2 +- examples/intermediate/ClosureImaging/Project.toml | 2 +- examples/intermediate/StokesIImaging/Project.toml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/advanced/HybridImaging/Project.toml b/examples/advanced/HybridImaging/Project.toml index 11b0baee..3ea32949 100644 --- a/examples/advanced/HybridImaging/Project.toml +++ b/examples/advanced/HybridImaging/Project.toml @@ -18,7 +18,7 @@ CairoMakie = "0.12" Distributions = "0.25" Optimization = "4" Plots = "1" -Pyehtim = "0.1" +Pyehtim = "0.2" StableRNGs = "1" StatsBase = "0.34" VLBIImagePriors = "0.9" diff --git a/examples/beginner/GeometricModeling/Project.toml b/examples/beginner/GeometricModeling/Project.toml index 7e7be586..8879f150 100644 --- a/examples/beginner/GeometricModeling/Project.toml +++ b/examples/beginner/GeometricModeling/Project.toml @@ -18,6 +18,6 @@ Distributions = "0.25" Optimization = "4" Pigeons = "0.4" Plots = "1" -Pyehtim = "0.1" +Pyehtim = "0.2" StableRNGs = "1" VLBIImagePriors = "0.9" diff --git a/examples/intermediate/ClosureImaging/Project.toml b/examples/intermediate/ClosureImaging/Project.toml index 918f8c78..7d161925 100644 --- a/examples/intermediate/ClosureImaging/Project.toml +++ b/examples/intermediate/ClosureImaging/Project.toml @@ -24,6 +24,6 @@ Distributions = "0.25" Optimization = "4" Pkg = "1" Plots = "1" -Pyehtim = "0.1" +Pyehtim = "0.2" StableRNGs = "1" VLBIImagePriors = "0.9" diff --git a/examples/intermediate/StokesIImaging/Project.toml b/examples/intermediate/StokesIImaging/Project.toml index a36a695a..548458a4 100644 --- a/examples/intermediate/StokesIImaging/Project.toml +++ b/examples/intermediate/StokesIImaging/Project.toml @@ -21,6 +21,6 @@ Distributions = "0.25" Optimization = "4" Pkg = "1" Plots = "1" -Pyehtim = "0.1" +Pyehtim = "0.2" StableRNGs = "1" VLBIImagePriors = "0.9" From c2786b7e3c0ab3c5cc540f270827e1130e798f26 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Wed, 18 Dec 2024 17:52:14 -0500 Subject: [PATCH 14/34] Test multifrequency --- test/Core/models.jl | 206 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 192 insertions(+), 14 deletions(-) diff --git a/test/Core/models.jl b/test/Core/models.jl index c4a8f3a8..ce05571a 100644 --- a/test/Core/models.jl +++ b/test/Core/models.jl @@ -409,11 +409,9 @@ end end - @testset "Multifrequency" begin - dcoh2 = deepcopy(dcoh) - dcoh2.config[:Fr][200:end] .= 345e9 - vis = CoherencyMatrix.(Comrade.measurement(dcoh2), Ref(CirBasis())) - + @testset "Coherencies Multifrequency" begin + dcoh.config[:Fr][200:end] .= 345e9 + vis = CoherencyMatrix.(Comrade.measurement(dcoh), Ref(CirBasis())) G = JonesG() do x gR = exp(x.lgR + 1im*x.gpR) gL = gR*exp(x.lgrat + 1im*x.gprat) @@ -427,22 +425,54 @@ end end R = JonesR(;add_fr=true) + J = JonesSandwich(*, G, D, R) + J2 = JonesSandwich(G, D, R) do g, d, r + return g*d*r + end + + + F = JonesF() + + JG = GenericJones(x->(x.lg, x.lg, x.lg, x.lg)) + + intprior = ( - lgR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))), - gpR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, inv(π ^2))); phase=true, refant=SEFDReference(0.0)), - lgrat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1)), phase=false), - gprat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))), - dRx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), - dRy = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), - dLx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), - dLy = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), - ) + lgR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))), + gpR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, inv(π ^2))); phase=true, refant=SEFDReference(0.0)), + lgrat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1)), phase=false), + gprat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))), + dRx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), + dRy = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), + dLx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), + dLy = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), + ) + + intm = InstrumentModel(J, intprior) + intm2 = InstrumentModel(J2, intprior) + intjg = InstrumentModel(JG, (;lg = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))))) + show(IOBuffer(), MIME"text/plain"(), intm) + + + ointm, printm = Comrade.set_array(intm, arrayconfig(dcoh)) + ointm2, printm2 = Comrade.set_array(intm2, arrayconfig(dcoh)) + ointjg, printjg = Comrade.set_array(intjg, arrayconfig(dcoh)) + + x = rand(printjg) + fj = forward_jones(JG, x) + @test fj[1][1] == x.lg[1] + + + Fpre = Comrade.preallocate_jones(F, arrayconfig(dcoh), CirBasis()) + Rpre = Comrade.preallocate_jones(JonesR(;add_fr=true), arrayconfig(dcoh), CirBasis()) + @test Fpre.matrices[1] ≈ Rpre.matrices[1] + @test Fpre.matrices[2] ≈ Rpre.matrices[2] @testset "ObservedArrayPrior" begin @inferred logpdf(printm, rand(printm)) + @inferred logpdf(printm2, rand(printm2)) x = rand(printm) @test logpdf(printm, x) ≈ logpdf(printm2, x) @test asflat(printm) isa TV.AbstractTransform @@ -459,6 +489,154 @@ end @test dp.dLy end + pintm, _ = Comrade.set_array(InstrumentModel(JonesR(;add_fr=true)), arrayconfig(dcoh)) + + + x = rand(printm) + x.lgR .= 0 + x.lgrat .= 0 + x.gpR .= 0 + x.gprat .= 0 + x.dRx .= 0 + x.dRy .= 0 + x.dLx .= 0 + x.dLy .= 0 + + vout = Comrade.apply_instrument(vis, ointm, (;instrument=x)) + vper = Comrade.apply_instrument(vis, pintm, (;instrument=NamedTuple())) + @test vout ≈ vper + + # test_rrule(Comrade.apply_instrument, vis, ointm⊢NoTangent(), (;instrument=x)) + + # # Now check that everything is being applied right + for s in sites(dcoh) + x.lgR .= 0 + x.lgrat .= 0 + x.gpR .= 0 + x.gprat .= 0 + x.dRx .= 0 + x.dRy .= 0 + x.dLx .= 0 + x.dLy .= 0 + + + inds1 = findall(x->(x[1]==s), dcoh[:baseline].sites) + inds2 = findall(x->(x[2]==s), dcoh[:baseline].sites) + ninds = findall(x->(x[1]!=s && x[2]!=s), dcoh[:baseline].sites) + + # Now amp-offsets + x.lgR .= 0 + x.lgrat .= 0 + x.gpR .= 0 + x.gprat .= 0 + x.dRx .= 0 + x.dRy .= 0 + x.dLx .= 0 + x.dLy .= 0 + + xlgRs = x.lgR[S=s] + xlgRs .= log(2) + xlgrat = x.lgrat[S=s] + xlgrat .= -log(2) + vout = Comrade.apply_instrument(vis, ointm, (;instrument=x)) + G = SMatrix{2,2}(2.0, 0.0, 0.0, 1.0) + @test vout[inds1] ≈ Ref(G) .*vper[inds1] + @test vout[inds2] ≈ vper[inds2] .* Ref(G) + @test vout[ninds] ≈ vper[ninds] + + # Now phases + x.lgR .= 0 + x.lgrat .= 0 + x.gpR .= 0 + x.gprat .= 0 + x.dRx .= 0 + x.dRy .= 0 + x.dLx .= 0 + x.dLy .= 0 + + xgpRs = x.gpR[S=s] + xgpRs .= π/3 + xgprat = x.gprat[S=s] + xgprat .= -π/3 + vout = Comrade.apply_instrument(vis, ointm, (;instrument=x)) + G = SMatrix{2,2}(exp(1im*π/3), 0.0, 0.0, exp(1im*0.0)) + @test vout[inds1] ≈ Ref(G) .*vper[inds1] + @test vout[inds2] ≈ vper[inds2] .* Ref(adjoint(G)) + @test vout[ninds] ≈ vper[ninds] + + + # Now dterms + x.lgR .= 0 + x.lgrat .= 0 + x.gpR .= 0 + x.gprat .= 0 + x.dRx .= 0 + x.dRy .= 0 + x.dLx .= 0 + x.dLy .= 0 + + xdRxs = x.dRx[S=s] + xdRxs .= 0.1 + xdRys = x.dRy[S=s] + xdRys .= 0.2 + xdLxs = x.dLx[S=s] + xdLxs .= 0.3 + xdLys = x.dLy[S=s] + xdLys .= 0.4 + + vout = Comrade.apply_instrument(vis, ointm, (;instrument=x)) + D = SMatrix{2,2}(1.0, 0.3 + 0.4im, 0.1 + 0.2im, 1.0) + @test vout[inds1] ≈ Ref(D) .*vper[inds1] + @test vout[inds2] ≈ vper[inds2] .* Ref(adjoint(D)) + @test vout[ninds] ≈ vper[ninds] + end + + @testset "caltable test" begin + c1 = caltable(x.lgR) + @test Tables.istable(typeof(c1)) + @test Tables.rowaccess(typeof(c1)) + @test Tables.rows(c1) === c1 + @test Tables.columnaccess(c1) + clmns = Tables.columns(c1) + @test clmns[1] == Comrade.scantimes(c1) + @test Bool(prod(skipmissing(Tables.matrix(clmns)[:,begin+1:end]) .== skipmissing(Comrade.gmat(c1)))) + @test c1.time == Comrade.scantimes(c1) + @test c1.time == Tables.getcolumn(c1, 1) + @test maximum(abs, skipmissing(c1.AA) .- skipmissing(Tables.getcolumn(c1, :AA))) ≈ 0 + @test maximum(abs, skipmissing(c1.AA) .- skipmissing(Tables.getcolumn(c1, 2))) ≈ 0 + @test Tables.columnnames(c1) == [:time, sort(sites(amp))...] + + c1row = Tables.getrow(c1, 30) + @test eltype(c1) == typeof(c1row) + @test c1row.time == c1.time[30] + @test c1row.AA == c1.AA[30] + @test Tables.getcolumn(c1row, :AA) == c1.AA[30] + @test Tables.getcolumn(c1row, :time) == c1.time[30] + @test Tables.getcolumn(c1row, 2) == c1.AA[30] + @test Tables.getcolumn(c1row, 1) == c1.time[30] + @test propertynames(c1) == propertynames(c1row) == [:time, sort(sites(amp))...] + + Tables.schema(c1) isa Tables.Schema + Tables.getcolumn(c1, Float64, 1, :test) + Tables.getcolumn(c1, Float64, 2, :test) + + c1[1, :AA] + c1[!, :AA] + c1[:, :AA] + @test length(c1) == length(c1.AA) + @test c1[1 ,:] isa Comrade.CalTableRow + @test length(Tables.getrow(c1, 1:5)) == 5 + + plot(c1) + plot(c1, datagains=true) + plot(c1, sites=(:AA,)) + + show(c1) + end + + end + + end From dedacd20381f36f19586fd9a01243b59d5453267 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Wed, 18 Dec 2024 17:52:27 -0500 Subject: [PATCH 15/34] Start making CalTable frequency friendly --- src/instrument/caltable.jl | 54 ++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/src/instrument/caltable.jl b/src/instrument/caltable.jl index c9810dbc..a957dc61 100644 --- a/src/instrument/caltable.jl +++ b/src/instrument/caltable.jl @@ -4,13 +4,14 @@ export caltable $(TYPEDEF) A Tabes of calibration quantities. The columns of the table are the telescope sites codes. -The rows are the calibration quantities at a specific time stamp. This user should not +The rows are the calibration quantities at a specific time and frequency. This user should not use this struct directly. Instead that should call [`caltable`](@ref). """ -struct CalTable{T,G<:AbstractVecOrMat} +struct CalTable{T,F,G<:AbstractVecOrMat} names::Vector{Symbol} lookup::Dict{Symbol,Int} times::T + freqs::F gmat::G end @@ -25,40 +26,45 @@ Tables.columnaccess(::Type{<:CalTable}) = true Return the sites in the calibration table """ sites(g::CalTable) = getfield(g, :names) -scantimes(g::CalTable) = getfield(g, :times) +times(g::CalTable) = getfield(g, :Tis) +frequencies(g::CalTable) = getfield(g, :freqs) lookup(g::CalTable) = getfield(g, :lookup) gmat(g::CalTable) = getfield(g, :gmat) -function Tables.schema(g::CalTable{T,G}) where {T,G} - nms = [:time] +function Tables.schema(g::CalTable{T,F,G}) where {T,F,G} + nms = [:Ti, :Fr] append!(nms, sites(g)) - types = Type[eltype(T)] + types = Type[eltype(T), eltype(F)] append!(types, fill(eltype(G), size(gmat(g),2))) return Tables.Schema(nms, types) end -Tables.columns(g::CalTable) = Tables.table([scantimes(g) gmat(g)]; header=Tables.columnnames(g)) +Tables.columns(g::CalTable) = Tables.table([times(g), frequencies(g), gmat(g)]; header=Tables.columnnames(g)) function Tables.getcolumn(g::CalTable, ::Type{T}, col::Int, nm::Symbol) where {T} - (col == 1 || nm == :time) && return scantimes(g) - gmat(g)[:, col-1] + (col == 1 || nm == :Ti) && return times(g) + (col == 2 || nm == :Fr) && return frequencies(g) + gmat(g)[:, col-2] end function Tables.getcolumn(g::CalTable, nm::Symbol) - nm == :time && return scantimes(g) + nm == :Ti && return times(g) + nm == :Fr && return frequencies(g) return gmat(g)[:, lookup(g)[nm]] end function Tables.getcolumn(g::CalTable, i::Int) - i==1 && return scantimes(g) - return gmat(g)[:, i-1] + i==1 && return times(g) + i==2 && return frequencies(g) + return gmat(g)[:, i-2] end function viewcolumn(gt::CalTable, nm::Symbol) - nm == :time && return scantimes(gt) + nm == :Ti && return times(gt) + nm == :Fr && return frequencies(gt) return @view gmat(gt)[:, lookup(gt)[nm]] end -Tables.columnnames(g::CalTable) = [:time, sites(g)...] +Tables.columnnames(g::CalTable) = [:Ti, :Fr, sites(g)...] Tables.rowaccess(::Type{<:CalTable}) = true Tables.rows(g::CalTable) = g @@ -77,7 +83,8 @@ function Base.getproperty(g::CalTable, nm::Symbol) end function Base.getindex(gt::CalTable, i::Int, nm::Symbol) - nm == :time && return gt.times[i] + nm == :Ti && return times(gt)[i] + nm == :Fr && return frequencies(gt)[i] return Tables.getcolumn(gt, nm)[i] end @@ -117,18 +124,21 @@ end function Tables.getcolumn(g::CalTableRow, ::Type, col::Int, nm::Symbol) - (col == 1 || nm == :time) && return scantimes(getfield(g, :source))[getfield(g, :row)] + (col == 1 || nm == :Ti) && return times(getfield(g, :source))[getfield(g, :row)] + (col == 2 || nm == :Fr) && return frequencies(getfield(g, :source))[getfield(g, :row)] gmat(getfield(g, :source))[getfield(g, :row), col-1] end function Tables.getcolumn(g::CalTableRow, i::Int) - (i==1) && return scantimes(getfield(g, :source))[getfield(g, :row)] + (i==1) && return times(getfield(g, :source))[getfield(g, :row)] + (i==2) && return frequencies(getfield(g, :source))[getfield(g, :row)] gmat(getfield(g, :source))[getfield(g, :row), i-1] end function Tables.getcolumn(g::CalTableRow, nm::Symbol) src = getfield(g, :source) - nm == :time && return scantimes(src)[getfield(g, :row)] + nm == :Ti && return times(src)[getfield(g, :row)] + nm == :Fr && return frequencies(src)[getfield(g, :row)] return gmat(src)[getfield(g, :row), lookup(src)[nm]] end @@ -152,9 +162,9 @@ end #else # ylims --> inv.(lims)[end:-1:begin] #end - t = getproperty.(gt[:time], :t0) + t = getproperty.(gt[:Ti], :t0) xlims --> (t[begin], t[end] + 0.01*abs(t[end])) - for (i,s) in enumerate(sites) + for (i,s) in enumerate(sites), f in unique(gt[:Fr]) @series begin seriestype := :scatter subplot := i @@ -166,7 +176,7 @@ end T = nonmissingtype(eltype(gt[s])) ind = Base.:!.(ismissing.(gt[s])) - #x := gt[:time][ind] + #x := gt[:Ti][ind] if !datagains yy = gt[s][ind] else @@ -179,7 +189,7 @@ end end end -Tables.columnnames(g::CalTableRow) = [:time, sites(getfield(g, :source))...] +Tables.columnnames(g::CalTableRow) = [:Ti, sites(getfield(g, :source))...] using PrettyTables From 3fc98d0021bc3667e24e8b815be4aff60f572719 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Wed, 18 Dec 2024 22:55:49 -0500 Subject: [PATCH 16/34] Caltable now plots gains with frequency labels --- src/instrument/caltable.jl | 71 +++++++++++++++++++++++++++----------- 1 file changed, 50 insertions(+), 21 deletions(-) diff --git a/src/instrument/caltable.jl b/src/instrument/caltable.jl index a957dc61..4f008fd2 100644 --- a/src/instrument/caltable.jl +++ b/src/instrument/caltable.jl @@ -26,7 +26,7 @@ Tables.columnaccess(::Type{<:CalTable}) = true Return the sites in the calibration table """ sites(g::CalTable) = getfield(g, :names) -times(g::CalTable) = getfield(g, :Tis) +times(g::CalTable) = getfield(g, :times) frequencies(g::CalTable) = getfield(g, :freqs) lookup(g::CalTable) = getfield(g, :lookup) gmat(g::CalTable) = getfield(g, :gmat) @@ -39,7 +39,7 @@ function Tables.schema(g::CalTable{T,F,G}) where {T,F,G} end -Tables.columns(g::CalTable) = Tables.table([times(g), frequencies(g), gmat(g)]; header=Tables.columnnames(g)) +Tables.columns(g::CalTable) = Tables.table(hcat(times(g), frequencies(g), gmat(g)); header=Tables.columnnames(g)) function Tables.getcolumn(g::CalTable, ::Type{T}, col::Int, nm::Symbol) where {T} (col == 1 || nm == :Ti) && return times(g) (col == 2 || nm == :Fr) && return frequencies(g) @@ -95,11 +95,21 @@ end function Base.getindex(gt::CalTable, I::AbstractUnitRange, nm::Symbol) getproperty(gt, nm)[I] end +function Base.getindex(gt::CalTable, I::AbstractVector{Int}, nm::Symbol) + getproperty(gt, nm)[I] +end + function Base.view(gt::CalTable, I::AbstractUnitRange, nm::Symbol) @view getproperty(gt, nm)[I] end +function Base.view(gt::CalTable, I::AbstractVector{Int}, nm::Symbol) + @view getproperty(gt, nm)[I] +end + + + function Base.getindex(gt::CalTable, ::Colon, nm::Symbol) Tables.getcolumn(gt, nm) end @@ -164,27 +174,36 @@ end #end t = getproperty.(gt[:Ti], :t0) xlims --> (t[begin], t[end] + 0.01*abs(t[end])) - for (i,s) in enumerate(sites), f in unique(gt[:Fr]) + for (i,s) in enumerate(sites) @series begin + T = nonmissingtype(eltype(gt[s])) + @info T + tt = Vector{eltype(t)}[] + yy = Vector{T}[] seriestype := :scatter subplot := i - label --> :none - + title := String(s) if i == length(sites) xguide --> "Time (UTC)" end - - T = nonmissingtype(eltype(gt[s])) - ind = Base.:!.(ismissing.(gt[s])) - #x := gt[:Ti][ind] - if !datagains - yy = gt[s][ind] + labels = ["$(f.central/1e9) GHz" for f in unique(gt[:Fr])] + if i == length(sites) + label --> reshape(labels, 1, :) else - yy = inv.(gt[s])[ind] + label --> nothing end - - title --> string(s) - t[ind], T.(yy) + for f in unique(gt[:Fr]) + ind = Base.:!.(ismissing.(gt[s])) + find = findall(==(f), gt[:Fr][ind]) + #x := gt[:Ti][ind] + push!(tt, t[ind][find]) + if !datagains + push!(yy, T.(gt[s][ind][find])) + else + push!(yy, T.(inv.(gt[s][ind][find]))) + end + end + tt, yy end end end @@ -196,11 +215,17 @@ using PrettyTables function Base.show(io::IO, ct::CalTable, ) pretty_table(io, Tables.columns(ct); header=Tables.columnnames(ct), - vlines=[1], - # formatters = (v,i,j)->round(v, digits=3) + vlines=[1,2], + formatters = _ctab_formatter ) end +function _ctab_formatter(v, i, j) + j == 1 && return "$(round(v.t0, digits=2)) hr" + j == 2 && return "$(round(v.central, digits=2)/1e9) GHz" + return round(v, digits=3) +end + """ caltable(s::SiteArray) @@ -209,18 +234,22 @@ Creates a calibration table from a site array """ function caltable(sarr::SiteArray) sites = sort(unique(Comrade.sites(sarr))) - time = unique(times(sarr)) + tf = collect(Iterators.product(unique(times(sarr))|>sort, unique(frequencies(sarr))|>sort)) + time = vec(first.(tf)) + freq = vec(last.(tf)) gmat = Matrix{Union{eltype(sarr), Missing}}(missing, length(time), length(sites)) gmat .= missing lookup = Dict(sites[i]=>i for i in eachindex(sites)) for (j, s) in enumerate(sites) cterms = site(sarr, s) - for (i, t) in enumerate(time) - ind = findfirst(==(t), times(cterms)) + for (i, (t,f)) in enumerate(tf) + ti = times(cterms) + fi = frequencies(cterms) + ind = findfirst(i->((ti[i]==t)&&fi[i]==f), eachindex(ti, fi)) if !isnothing(ind) gmat[i, j] = cterms[ind] end end end - return CalTable(sites, lookup, time, gmat) + return CalTable(sites, lookup, time, freq, gmat) end From 4d52d6f5c2e9939a6000046910d7e150d4014c61 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Thu, 19 Dec 2024 08:02:43 -0500 Subject: [PATCH 17/34] small bug fix in caltable and tests --- src/instrument/caltable.jl | 3 +-- test/Core/models.jl | 4 ---- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/instrument/caltable.jl b/src/instrument/caltable.jl index 4f008fd2..8810288c 100644 --- a/src/instrument/caltable.jl +++ b/src/instrument/caltable.jl @@ -177,7 +177,6 @@ end for (i,s) in enumerate(sites) @series begin T = nonmissingtype(eltype(gt[s])) - @info T tt = Vector{eltype(t)}[] yy = Vector{T}[] seriestype := :scatter @@ -190,7 +189,7 @@ end if i == length(sites) label --> reshape(labels, 1, :) else - label --> nothing + label := nothing end for f in unique(gt[:Fr]) ind = Base.:!.(ismissing.(gt[s])) diff --git a/test/Core/models.jl b/test/Core/models.jl index ce05571a..b7bab138 100644 --- a/test/Core/models.jl +++ b/test/Core/models.jl @@ -636,10 +636,6 @@ end end - - - end - @testset "Integration" begin _,dvis, amp, lcamp, cphase, dcoh = load_data() ts = Comrade.timestamps(ScanSeg(), arrayconfig(dvis)) From 2fa66a69831caf2e50ba3401c5f62e377abf4516 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Thu, 19 Dec 2024 08:14:53 -0500 Subject: [PATCH 18/34] Fix CalTable tests --- src/instrument/caltable.jl | 3 +- test/Core/models.jl | 129 +++++++++++++++---------------------- 2 files changed, 53 insertions(+), 79 deletions(-) diff --git a/src/instrument/caltable.jl b/src/instrument/caltable.jl index 8810288c..92aa4b8e 100644 --- a/src/instrument/caltable.jl +++ b/src/instrument/caltable.jl @@ -74,6 +74,8 @@ struct CalTableRow{T,G} <: Tables.AbstractRow source::CalTable{T,G} end +Tables.columnnames(g::CalTableRow) = [:Ti, :Fr, sites(getfield(g, :source))...] + function Base.propertynames(g::CalTable) return Tables.columnnames(g) end @@ -207,7 +209,6 @@ end end end -Tables.columnnames(g::CalTableRow) = [:Ti, sites(getfield(g, :source))...] using PrettyTables diff --git a/test/Core/models.jl b/test/Core/models.jl index b7bab138..533b424d 100644 --- a/test/Core/models.jl +++ b/test/Core/models.jl @@ -21,6 +21,55 @@ _ntequal(x::T, y::T) where {T<:Tuple} = map(_ntequal, x, y) _ntequal(x, y) = x ≈ y +function test_caltable(c1) + @test Tables.istable(typeof(c1)) + @test Tables.rowaccess(typeof(c1)) + @test Tables.rows(c1) === c1 + @test Tables.columnaccess(c1) + clmns = Tables.columns(c1) + @test clmns[1] == Comrade.times(c1) + @test clmns[2] == Comrade.frequencies(c1) + @test Bool(prod(skipmissing(Tables.matrix(clmns)[:,begin+2:end]) .== skipmissing(Comrade.gmat(c1)))) + @test c1.Ti == Comrade.times(c1) + @test c1.Ti == Tables.getcolumn(c1, 1) + @test c1.Fr == Comrade.frequencies(c1) + @test c1.Fr == Tables.getcolumn(c1, 2) + + @test maximum(abs, skipmissing(c1.AA) .- skipmissing(Tables.getcolumn(c1, :AA))) ≈ 0 + @test maximum(abs, skipmissing(c1.AA) .- skipmissing(Tables.getcolumn(c1, 3))) ≈ 0 + @test Tables.columnnames(c1) == [:Ti, :Fr, sort(sites(amp))...] + + c1row = Tables.getrow(c1, 30) + @test eltype(c1) == typeof(c1row) + @test c1row.Ti == c1.Ti[30] + @test c1row.AA == c1.AA[30] + @test c1row.Fr == c1.Fr[30] + @test Tables.getcolumn(c1row, :AA) == c1.AA[30] + @test Tables.getcolumn(c1row, :Ti) == c1.Ti[30] + @test Tables.getcolumn(c1row, :Fr) == c1.Fr[30] + @test Tables.getcolumn(c1row, 3) == c1.AA[30] + @test Tables.getcolumn(c1row, 2) == c1.Fr[30] + @test Tables.getcolumn(c1row, 1) == c1.Ti[30] + @test propertynames(c1) == propertynames(c1row) == [:Ti, :Fr, sort(sites(amp))...] + + Tables.schema(c1) isa Tables.Schema + Tables.getcolumn(c1, Float64, 1, :test) + Tables.getcolumn(c1, Float64, 2, :test) + + c1[1, :AA] + c1[!, :AA] + c1[:, :AA] + @test length(c1) == length(c1.AA) + @test c1[1 ,:] isa Comrade.CalTableRow + @test length(Tables.getrow(c1, 1:5)) == 5 + + plot(c1) + plot(c1, datagains=true) + plot(c1, sites=(:AA,)) + + show(c1) +end + @testset "SkyModel" begin f = test_model @@ -365,45 +414,7 @@ end @testset "caltable test" begin c1 = caltable(x.lgR) - @test Tables.istable(typeof(c1)) - @test Tables.rowaccess(typeof(c1)) - @test Tables.rows(c1) === c1 - @test Tables.columnaccess(c1) - clmns = Tables.columns(c1) - @test clmns[1] == Comrade.scantimes(c1) - @test Bool(prod(skipmissing(Tables.matrix(clmns)[:,begin+1:end]) .== skipmissing(Comrade.gmat(c1)))) - @test c1.time == Comrade.scantimes(c1) - @test c1.time == Tables.getcolumn(c1, 1) - @test maximum(abs, skipmissing(c1.AA) .- skipmissing(Tables.getcolumn(c1, :AA))) ≈ 0 - @test maximum(abs, skipmissing(c1.AA) .- skipmissing(Tables.getcolumn(c1, 2))) ≈ 0 - @test Tables.columnnames(c1) == [:time, sort(sites(amp))...] - - c1row = Tables.getrow(c1, 30) - @test eltype(c1) == typeof(c1row) - @test c1row.time == c1.time[30] - @test c1row.AA == c1.AA[30] - @test Tables.getcolumn(c1row, :AA) == c1.AA[30] - @test Tables.getcolumn(c1row, :time) == c1.time[30] - @test Tables.getcolumn(c1row, 2) == c1.AA[30] - @test Tables.getcolumn(c1row, 1) == c1.time[30] - @test propertynames(c1) == propertynames(c1row) == [:time, sort(sites(amp))...] - - Tables.schema(c1) isa Tables.Schema - Tables.getcolumn(c1, Float64, 1, :test) - Tables.getcolumn(c1, Float64, 2, :test) - - c1[1, :AA] - c1[!, :AA] - c1[:, :AA] - @test length(c1) == length(c1.AA) - @test c1[1 ,:] isa Comrade.CalTableRow - @test length(Tables.getrow(c1, 1:5)) == 5 - - plot(c1) - plot(c1, datagains=true) - plot(c1, sites=(:AA,)) - - show(c1) + test_caltable(c1) end end @@ -593,45 +604,7 @@ end @testset "caltable test" begin c1 = caltable(x.lgR) - @test Tables.istable(typeof(c1)) - @test Tables.rowaccess(typeof(c1)) - @test Tables.rows(c1) === c1 - @test Tables.columnaccess(c1) - clmns = Tables.columns(c1) - @test clmns[1] == Comrade.scantimes(c1) - @test Bool(prod(skipmissing(Tables.matrix(clmns)[:,begin+1:end]) .== skipmissing(Comrade.gmat(c1)))) - @test c1.time == Comrade.scantimes(c1) - @test c1.time == Tables.getcolumn(c1, 1) - @test maximum(abs, skipmissing(c1.AA) .- skipmissing(Tables.getcolumn(c1, :AA))) ≈ 0 - @test maximum(abs, skipmissing(c1.AA) .- skipmissing(Tables.getcolumn(c1, 2))) ≈ 0 - @test Tables.columnnames(c1) == [:time, sort(sites(amp))...] - - c1row = Tables.getrow(c1, 30) - @test eltype(c1) == typeof(c1row) - @test c1row.time == c1.time[30] - @test c1row.AA == c1.AA[30] - @test Tables.getcolumn(c1row, :AA) == c1.AA[30] - @test Tables.getcolumn(c1row, :time) == c1.time[30] - @test Tables.getcolumn(c1row, 2) == c1.AA[30] - @test Tables.getcolumn(c1row, 1) == c1.time[30] - @test propertynames(c1) == propertynames(c1row) == [:time, sort(sites(amp))...] - - Tables.schema(c1) isa Tables.Schema - Tables.getcolumn(c1, Float64, 1, :test) - Tables.getcolumn(c1, Float64, 2, :test) - - c1[1, :AA] - c1[!, :AA] - c1[:, :AA] - @test length(c1) == length(c1.AA) - @test c1[1 ,:] isa Comrade.CalTableRow - @test length(Tables.getrow(c1, 1:5)) == 5 - - plot(c1) - plot(c1, datagains=true) - plot(c1, sites=(:AA,)) - - show(c1) + test_caltable(c1) end end From 295ba3259399d4537b63987f196688141ebeda2c Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Thu, 19 Dec 2024 08:31:54 -0500 Subject: [PATCH 19/34] Fix bug in caltable tests --- test/Core/models.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/Core/models.jl b/test/Core/models.jl index 533b424d..424dbde0 100644 --- a/test/Core/models.jl +++ b/test/Core/models.jl @@ -21,7 +21,7 @@ _ntequal(x::T, y::T) where {T<:Tuple} = map(_ntequal, x, y) _ntequal(x, y) = x ≈ y -function test_caltable(c1) +function test_caltable(c1, sites) @test Tables.istable(typeof(c1)) @test Tables.rowaccess(typeof(c1)) @test Tables.rows(c1) === c1 @@ -37,7 +37,7 @@ function test_caltable(c1) @test maximum(abs, skipmissing(c1.AA) .- skipmissing(Tables.getcolumn(c1, :AA))) ≈ 0 @test maximum(abs, skipmissing(c1.AA) .- skipmissing(Tables.getcolumn(c1, 3))) ≈ 0 - @test Tables.columnnames(c1) == [:Ti, :Fr, sort(sites(amp))...] + @test Tables.columnnames(c1) == [:Ti, :Fr, sites...] c1row = Tables.getrow(c1, 30) @test eltype(c1) == typeof(c1row) @@ -50,7 +50,7 @@ function test_caltable(c1) @test Tables.getcolumn(c1row, 3) == c1.AA[30] @test Tables.getcolumn(c1row, 2) == c1.Fr[30] @test Tables.getcolumn(c1row, 1) == c1.Ti[30] - @test propertynames(c1) == propertynames(c1row) == [:Ti, :Fr, sort(sites(amp))...] + @test propertynames(c1) == propertynames(c1row) == [:Ti, :Fr, sites...] Tables.schema(c1) isa Tables.Schema Tables.getcolumn(c1, Float64, 1, :test) @@ -414,7 +414,7 @@ end @testset "caltable test" begin c1 = caltable(x.lgR) - test_caltable(c1) + test_caltable(c1, sort(sites(amp))) end end @@ -604,7 +604,7 @@ end @testset "caltable test" begin c1 = caltable(x.lgR) - test_caltable(c1) + test_caltable(c1, sort(sites(amp))) end end From 1443391a85f5e133b1e554e07d9a53598e39e51c Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Thu, 19 Dec 2024 09:30:00 -0500 Subject: [PATCH 20/34] Fix really dumb scan times issue --- src/instrument/instrument.jl | 2 ++ src/instrument/priors/array_priors.jl | 24 +++++++------ src/instrument/priors/segmentation.jl | 1 - test/Core/models.jl | 51 ++++++++++++++++++++------- 4 files changed, 55 insertions(+), 23 deletions(-) diff --git a/src/instrument/instrument.jl b/src/instrument/instrument.jl index ebddc7d9..008e4f6f 100644 --- a/src/instrument/instrument.jl +++ b/src/instrument/instrument.jl @@ -16,6 +16,7 @@ Base.in(t::Number, ts::IntegrationTime) = (ts.t0 - ts.dt/2) ≤ t < (ts.t0 + ts. Base.isless(t::IntegrationTime, ts::IntegrationTime) = t.t0 < ts.t0 Base.isless(s::Number, t::IntegrationTime) = s < (t.t0 - t.dt/2) Base.isless(t::IntegrationTime, s::Number) = (t.t0 + t.dt/2) < s +Base.Broadcast.broadcastable(ts::IntegrationTime) = Ref(ts) _center(ts::IntegrationTime) = ts.t0 _region(ts::IntegrationTime) = ts.dt @@ -31,6 +32,7 @@ Base.in(f::Number, fs::FrequencyChannel) = (fs.central-fs.bandwidth/2) ≤ f < ( Base.isless(t::FrequencyChannel, ts::FrequencyChannel) = _center(t) < _center(ts) Base.isless(s::Number, t::FrequencyChannel) = s < (_center(t) - _region(t)/2) Base.isless(t::FrequencyChannel, s::Number) = (_center(t) + _region(t)/2) < s +Base.Broadcast.broadcastable(fs::FrequencyChannel) = Ref(fs) diff --git a/src/instrument/priors/array_priors.jl b/src/instrument/priors/array_priors.jl index f5b36aa5..fe8c7279 100644 --- a/src/instrument/priors/array_priors.jl +++ b/src/instrument/priors/array_priors.jl @@ -84,8 +84,8 @@ function build_sitemap(d::ArrayPrior, array) fchan = freqchannels(SpectralWindow(), array) # Now we find commonalities tf = Tuple{eltype(tstamp), eltype(fchan)}[] - for t in tstamp, f in fchan - if any(x->(x[1]∈t && x[2]∈f), tfs) && ((!((t.t0, f.central) ∈ tf))) + for f in fchan, t in tstamp + if any(x->(x[1]∈t && x[2]∈f), tfs) && ((!((t, f) ∈ tf))) push!(tf, (t, f)) end end @@ -102,15 +102,19 @@ function build_sitemap(d::ArrayPrior, array) tlistre = similar(tlist) slistre = similar(slist) flistre = similar(flist) - # Now rearrange so we have time site ordering (sites are the fastest changing) - tuni = sort(unique(getproperty.(tlist, :t0))) + # Now rearrange so we have frquency, time, site ordering (sites are the fastest changing) + tuni = sort(unique((tlist))) + funi = sort(unique((flist))) ind0 = 1 - for t in tuni - ind = findall(x->x.t0==t, tlist) - tlistre[ind0:ind0+length(ind)-1] .= tlist[ind] - slistre[ind0:ind0+length(ind)-1] .= slist[ind] - flistre[ind0:ind0+length(ind)-1] .= flist[ind] - ind0 += length(ind) + for f in funi, t in tuni + ind = (f .== flist) .& (t .== tlist) + vtlist = @view tlist[ind] + vslist = @view slist[ind] + vflist = @view flist[ind] + tlistre[ind0:ind0+length(vtlist)-1] .= vtlist + slistre[ind0:ind0+length(vtlist)-1] .= vslist + flistre[ind0:ind0+length(vtlist)-1] .= vflist + ind0 += length(vtlist) end return SiteLookup(tlistre, flistre, slistre) end diff --git a/src/instrument/priors/segmentation.jl b/src/instrument/priors/segmentation.jl index 16ff3aeb..da6af450 100644 --- a/src/instrument/priors/segmentation.jl +++ b/src/instrument/priors/segmentation.jl @@ -54,7 +54,6 @@ function timestamps(::ScanSeg, array) mjd = array.mjd # Shift the central time to the middle of the scan dt = (st.stop .- st.start) - dt[end] = dt[end]+0.5 t0 = st.start .+ dt./2 return IntegrationTime.(mjd, t0, dt) diff --git a/test/Core/models.jl b/test/Core/models.jl index 424dbde0..ff9fb684 100644 --- a/test/Core/models.jl +++ b/test/Core/models.jl @@ -20,6 +20,31 @@ _ntequal(x::T, y::T) where {T<:NamedTuple} = ntequal(values(x), values(y)) _ntequal(x::T, y::T) where {T<:Tuple} = map(_ntequal, x, y) _ntequal(x, y) = x ≈ y +function build_mfvis(vistuple...) + configs = arrayconfig.(vistuple) + vis = vistuple[1] + newdatatables = Comrade.StructArray(reduce(vcat, Comrade.datatable.(configs))) + newscans = reduce(vcat, getfield.(configs,:scans)) + newconfig = Comrade.EHTArrayConfiguration(vis.config.bandwidth, + vis.config.tarr, + newscans, + vis.config.mjd, + vis.config.ra, + vis.config.dec, + vis.config.source, + :UTC, + newdatatables) + newmeasurement = reduce(vcat, Comrade.measurement.(vistuple)) + newnoise = reduce(vcat, Comrade.noise.(vistuple)) + + return Comrade.EHTObservationTable{Comrade.datumtype(vis)}(newmeasurement,newnoise,newconfig) +end + +vistuple = (vis8,vis12) +mfvis = build_mfvis(vistuple) +νlist = [ν8, ν12] +mfgrid = mfimagepixels(fovx, fovy, npix, npix, νlist) + function test_caltable(c1, sites) @test Tables.istable(typeof(c1)) @@ -421,8 +446,10 @@ end @testset "Coherencies Multifrequency" begin - dcoh.config[:Fr][200:end] .= 345e9 - vis = CoherencyMatrix.(Comrade.measurement(dcoh), Ref(CirBasis())) + dcoh2 = deepcopy(dcoh) + dcoh2.config[:Fr] .= 345e9 + dcohmf = build_mfvis(dcoh, dcoh2) + vis = CoherencyMatrix.(Comrade.measurement(dcohmf), Ref(CirBasis())) G = JonesG() do x gR = exp(x.lgR + 1im*x.gpR) gL = gR*exp(x.lgrat + 1im*x.gprat) @@ -467,17 +494,17 @@ end - ointm, printm = Comrade.set_array(intm, arrayconfig(dcoh)) - ointm2, printm2 = Comrade.set_array(intm2, arrayconfig(dcoh)) - ointjg, printjg = Comrade.set_array(intjg, arrayconfig(dcoh)) + ointm, printm = Comrade.set_array(intm, arrayconfig(dcohmf)) + ointm2, printm2 = Comrade.set_array(intm2, arrayconfig(dcohmf)) + ointjg, printjg = Comrade.set_array(intjg, arrayconfig(dcohmf)) x = rand(printjg) fj = forward_jones(JG, x) @test fj[1][1] == x.lg[1] - Fpre = Comrade.preallocate_jones(F, arrayconfig(dcoh), CirBasis()) - Rpre = Comrade.preallocate_jones(JonesR(;add_fr=true), arrayconfig(dcoh), CirBasis()) + Fpre = Comrade.preallocate_jones(F, arrayconfig(dcohmf), CirBasis()) + Rpre = Comrade.preallocate_jones(JonesR(;add_fr=true), arrayconfig(dcohmf), CirBasis()) @test Fpre.matrices[1] ≈ Rpre.matrices[1] @test Fpre.matrices[2] ≈ Rpre.matrices[2] @@ -500,7 +527,7 @@ end @test dp.dLy end - pintm, _ = Comrade.set_array(InstrumentModel(JonesR(;add_fr=true)), arrayconfig(dcoh)) + pintm, _ = Comrade.set_array(InstrumentModel(JonesR(;add_fr=true)), arrayconfig(dcohmf)) x = rand(printm) @@ -520,7 +547,7 @@ end # test_rrule(Comrade.apply_instrument, vis, ointm⊢NoTangent(), (;instrument=x)) # # Now check that everything is being applied right - for s in sites(dcoh) + for s in sites(dcohmf) x.lgR .= 0 x.lgrat .= 0 x.gpR .= 0 @@ -531,9 +558,9 @@ end x.dLy .= 0 - inds1 = findall(x->(x[1]==s), dcoh[:baseline].sites) - inds2 = findall(x->(x[2]==s), dcoh[:baseline].sites) - ninds = findall(x->(x[1]!=s && x[2]!=s), dcoh[:baseline].sites) + inds1 = findall(x->(x[1]==s), dcohmf[:baseline].sites) + inds2 = findall(x->(x[2]==s), dcohmf[:baseline].sites) + ninds = findall(x->(x[1]!=s && x[2]!=s), dcohmf[:baseline].sites) # Now amp-offsets x.lgR .= 0 From eca1e0cece2930379fe9d4afd71ac819019f1e83 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Thu, 19 Dec 2024 10:40:04 -0500 Subject: [PATCH 21/34] Silly copy pasta mistake --- test/Core/models.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/Core/models.jl b/test/Core/models.jl index ff9fb684..d5f59141 100644 --- a/test/Core/models.jl +++ b/test/Core/models.jl @@ -40,11 +40,6 @@ function build_mfvis(vistuple...) return Comrade.EHTObservationTable{Comrade.datumtype(vis)}(newmeasurement,newnoise,newconfig) end -vistuple = (vis8,vis12) -mfvis = build_mfvis(vistuple) -νlist = [ν8, ν12] -mfgrid = mfimagepixels(fovx, fovy, npix, npix, νlist) - function test_caltable(c1, sites) @test Tables.istable(typeof(c1)) From b4688affe740ba6d24e2a85dc57c707db92b9ee9 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Thu, 19 Dec 2024 10:40:42 -0500 Subject: [PATCH 22/34] Fix this --- test/Core/models.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Core/models.jl b/test/Core/models.jl index d5f59141..e6274be5 100644 --- a/test/Core/models.jl +++ b/test/Core/models.jl @@ -24,7 +24,7 @@ function build_mfvis(vistuple...) configs = arrayconfig.(vistuple) vis = vistuple[1] newdatatables = Comrade.StructArray(reduce(vcat, Comrade.datatable.(configs))) - newscans = reduce(vcat, getfield.(configs,:scans)) + newscans = reduce(vcat, configs.scans) newconfig = Comrade.EHTArrayConfiguration(vis.config.bandwidth, vis.config.tarr, newscans, From 39fa84f9e3e449cb8de3de8caf3b67681bb6bc60 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sat, 21 Dec 2024 13:22:29 -0500 Subject: [PATCH 23/34] finish up tests --- test/Core/core.jl | 6 +- test/Core/models.jl | 204 +++++++++++++++----------------------------- test/runtests.jl | 10 +-- 3 files changed, 77 insertions(+), 143 deletions(-) diff --git a/test/Core/core.jl b/test/Core/core.jl index 7523c531..9054ebe6 100644 --- a/test/Core/core.jl +++ b/test/Core/core.jl @@ -9,8 +9,8 @@ using VLBIImagePriors include(joinpath(@__DIR__, "../test_util.jl")) -include(joinpath(@__DIR__, "observation.jl")) -include(joinpath(@__DIR__, "partially_fixed.jl")) +# include(joinpath(@__DIR__, "observation.jl")) +# include(joinpath(@__DIR__, "partially_fixed.jl")) include(joinpath(@__DIR__, "models.jl")) -include(joinpath(@__DIR__, "bayes.jl")) +# include(joinpath(@__DIR__, "bayes.jl")) # include(joinpath(@__DIR__, "rules.jl")) diff --git a/test/Core/models.jl b/test/Core/models.jl index e6274be5..e2f23d7e 100644 --- a/test/Core/models.jl +++ b/test/Core/models.jl @@ -24,7 +24,7 @@ function build_mfvis(vistuple...) configs = arrayconfig.(vistuple) vis = vistuple[1] newdatatables = Comrade.StructArray(reduce(vcat, Comrade.datatable.(configs))) - newscans = reduce(vcat, configs.scans) + newscans = reduce(vcat, getproperty.(configs, :scans)) newconfig = Comrade.EHTArrayConfiguration(vis.config.bandwidth, vis.config.tarr, newscans, @@ -174,8 +174,8 @@ end @test Comrade.SiteArray(x.lg, sl) == x.lg - Comrade.time(x.lg, 5.0..6.0) - Comrade.frequency(x.lg, 1.0..400.0) + @inferred Comrade.time(x.lg, 5.0..6.0) + @inferred Comrade.frequency(x.lg, 1.0..400.0) # ps = ProjectTo(x.lg) # @test ps(x.lg) == x.lg @@ -444,7 +444,8 @@ end dcoh2 = deepcopy(dcoh) dcoh2.config[:Fr] .= 345e9 dcohmf = build_mfvis(dcoh, dcoh2) - vis = CoherencyMatrix.(Comrade.measurement(dcohmf), Ref(CirBasis())) + vissi = CoherencyMatrix.(Comrade.measurement(dcoh), Ref(CirBasis())) + vismf = CoherencyMatrix.(Comrade.measurement(dcohmf), Ref(CirBasis())) G = JonesG() do x gR = exp(x.lgR + 1im*x.gpR) gL = gR*exp(x.lgrat + 1im*x.gprat) @@ -460,132 +461,79 @@ end R = JonesR(;add_fr=true) J = JonesSandwich(*, G, D, R) - J2 = JonesSandwich(G, D, R) do g, d, r - return g*d*r - end - - F = JonesF() - JG = GenericJones(x->(x.lg, x.lg, x.lg, x.lg)) - - intprior = ( - lgR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))), - gpR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, inv(π ^2))); phase=true, refant=SEFDReference(0.0)), - lgrat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1)), phase=false), - gprat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))), - dRx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), - dRy = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), - dLx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), - dLy = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), + lgR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))), + gpR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, inv(π ^2))); phase=true, refant=SEFDReference(0.0)), + lgrat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1)), phase=false), + gprat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))), + dRx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), + dRy = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), + dLx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), + dLy = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), ) intm = InstrumentModel(J, intprior) - intm2 = InstrumentModel(J2, intprior) - intjg = InstrumentModel(JG, (;lg = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))))) show(IOBuffer(), MIME"text/plain"(), intm) + ointsi, printsi = Comrade.set_array(intm, arrayconfig(dcoh)) + ointmf, printmf = Comrade.set_array(intm, arrayconfig(dcohmf)) - ointm, printm = Comrade.set_array(intm, arrayconfig(dcohmf)) - ointm2, printm2 = Comrade.set_array(intm2, arrayconfig(dcohmf)) - ointjg, printjg = Comrade.set_array(intjg, arrayconfig(dcohmf)) - - x = rand(printjg) - fj = forward_jones(JG, x) - @test fj[1][1] == x.lg[1] - - - Fpre = Comrade.preallocate_jones(F, arrayconfig(dcohmf), CirBasis()) - Rpre = Comrade.preallocate_jones(JonesR(;add_fr=true), arrayconfig(dcohmf), CirBasis()) - @test Fpre.matrices[1] ≈ Rpre.matrices[1] - @test Fpre.matrices[2] ≈ Rpre.matrices[2] + Rsi = Comrade.preallocate_jones(F, arrayconfig(dcoh), CirBasis()) + Rmf = Comrade.preallocate_jones(R, arrayconfig(dcohmf), CirBasis()) + # Check that the copied matrices are identical + @test Rsi.matrices[1] ≈ Rmf.matrices[1][1:length(Rsi.matrices[1])] + @test Rsi.matrices[1] ≈ Rmf.matrices[1][length(Rsi.matrices[1])+1:end] + @test Rsi.matrices[2] ≈ Rmf.matrices[2][1:length(Rsi.matrices[1])] + @test Rsi.matrices[2] ≈ Rmf.matrices[2][length(Rsi.matrices[1])+1:end] - @testset "ObservedArrayPrior" begin - @inferred logpdf(printm, rand(printm)) - @inferred logpdf(printm2, rand(printm2)) - x = rand(printm) - @test logpdf(printm, x) ≈ logpdf(printm2, x) - @test asflat(printm) isa TV.AbstractTransform - p = rand(printm) - t = asflat(printm) - pout = TV.transform(t, TV.inverse(t, p)) - dp = ntequal(p, pout) - @test dp.lgR - @test dp.lgrat - @test dp.gprat - @test dp.dRx - @test dp.dRy - @test dp.dLx - @test dp.dLy + for p in propertynames(ointsi.bsitelookup) + L = length(ointsi.bsitelookup[p].indices_1) + @test ointsi.bsitelookup[p].indices_1 == ointmf.bsitelookup[p].indices_1[1:L] + @test ointsi.bsitelookup[p].indices_2 == ointmf.bsitelookup[p].indices_2[1:L] + @test 2*L == length(ointmf.bsitelookup[p].indices_1) end - pintm, _ = Comrade.set_array(InstrumentModel(JonesR(;add_fr=true)), arrayconfig(dcohmf)) - - - x = rand(printm) - x.lgR .= 0 - x.lgrat .= 0 - x.gpR .= 0 - x.gprat .= 0 - x.dRx .= 0 - x.dRy .= 0 - x.dLx .= 0 - x.dLy .= 0 - - vout = Comrade.apply_instrument(vis, ointm, (;instrument=x)) - vper = Comrade.apply_instrument(vis, pintm, (;instrument=NamedTuple())) - @test vout ≈ vper - - # test_rrule(Comrade.apply_instrument, vis, ointm⊢NoTangent(), (;instrument=x)) - - # # Now check that everything is being applied right - for s in sites(dcohmf) - x.lgR .= 0 - x.lgrat .= 0 - x.gpR .= 0 - x.gprat .= 0 - x.dRx .= 0 - x.dRy .= 0 - x.dLx .= 0 - x.dLy .= 0 - + pintmf, _ = Comrade.set_array(InstrumentModel(R), arrayconfig(dcohmf)) - inds1 = findall(x->(x[1]==s), dcohmf[:baseline].sites) - inds2 = findall(x->(x[2]==s), dcohmf[:baseline].sites) - ninds = findall(x->(x[1]!=s && x[2]!=s), dcohmf[:baseline].sites) - - # Now amp-offsets - x.lgR .= 0 - x.lgrat .= 0 - x.gpR .= 0 - x.gprat .= 0 - x.dRx .= 0 - x.dRy .= 0 - x.dLx .= 0 - x.dLy .= 0 + xsi = rand(printsi) + xmf = rand(printmf) - xlgRs = x.lgR[S=s] - xlgRs .= log(2) - xlgrat = x.lgrat[S=s] - xlgrat .= -log(2) - vout = Comrade.apply_instrument(vis, ointm, (;instrument=x)) - G = SMatrix{2,2}(2.0, 0.0, 0.0, 1.0) - @test vout[inds1] ≈ Ref(G) .*vper[inds1] - @test vout[inds2] ≈ vper[inds2] .* Ref(G) - @test vout[ninds] ≈ vper[ninds] + for s in sites(dcoh) + map(x->fill!(x, 0.0), xsi) + map(x->fill!(x, 0.0), xmf) + + inds1si = findall(x->(x[1]==s), dcoh[:baseline].sites) + inds2si = findall(x->(x[2]==s), dcoh[:baseline].sites) + nindssi = findall(x->(x[1]!=s && x[2]!=s), dcoh[:baseline].sites) + + inds1mf = findall(x->(x[1]==s), dcohmf[:baseline].sites) + inds2mf = findall(x->(x[2]==s), dcohmf[:baseline].sites) + nindsmf = findall(x->(x[1]!=s && x[2]!=s), dcohmf[:baseline].sites) + + xsilgR = xsi.lgR[S=s] + xsilgR .= log(2) + xmflgR = xmf.lgR[S=s] + xmflgR[1:length(xsilgR)] .= xsilgR + xmflgR[length(xsilgR)+1:end] .= 2 .* xsilgR + + xsilgrat = xsi.lgrat[S=s] + xsilgrat .= -log(2) + xmflgrat = xmf.lgrat[S=s] + xmflgrat[1:length(xsilgrat)] .= xsilgrat + xmflgrat[length(xsilgrat)+1:end] .= 2 .* xsilgrat + vmf = Comrade.apply_instrument(vismf, ointmf, (;instrument=xmf)) + vsi = Comrade.apply_instrument(vissi, ointsi, (;instrument=xsi)) + Gmf = SMatrix{2,2}(2.0, 0.0, 0.0, 1.0) + @test vsi[inds1si] ≈ vmf[inds1si] + @test vsi[inds1si] ≈ vmf[inds1mf[length(inds1si)+1:end]] .* Ref(Gmf) # Now phases - x.lgR .= 0 - x.lgrat .= 0 - x.gpR .= 0 - x.gprat .= 0 - x.dRx .= 0 - x.dRy .= 0 - x.dLx .= 0 - x.dLy .= 0 + map(x->fill!(x, 0.0), xsi) + map(x->fill!(x, 0.0), xmf) xgpRs = x.gpR[S=s] xgpRs .= π/3 @@ -598,31 +546,17 @@ end @test vout[ninds] ≈ vper[ninds] - # Now dterms - x.lgR .= 0 - x.lgrat .= 0 - x.gpR .= 0 - x.gprat .= 0 - x.dRx .= 0 - x.dRy .= 0 - x.dLx .= 0 - x.dLy .= 0 + end - xdRxs = x.dRx[S=s] - xdRxs .= 0.1 - xdRys = x.dRy[S=s] - xdRys .= 0.2 - xdLxs = x.dLx[S=s] - xdLxs .= 0.3 - xdLys = x.dLy[S=s] - xdLys .= 0.4 - vout = Comrade.apply_instrument(vis, ointm, (;instrument=x)) - D = SMatrix{2,2}(1.0, 0.3 + 0.4im, 0.1 + 0.2im, 1.0) - @test vout[inds1] ≈ Ref(D) .*vper[inds1] - @test vout[inds2] ≈ vper[inds2] .* Ref(adjoint(D)) - @test vout[ninds] ≈ vper[ninds] - end + voutsi = Comrade.apply_instrument(vissi, ointsi, (;instrument=xsi)) + voutmf = Comrade.apply_instrument(vismf, ointmf, (;instrument=xmf)) + voutmf[1:length(voutsi)] ≈ voutsi + voutmf[length(voutsi)+1:end] ≈ voutsi.*exp(1 + 2*1im) + + # # Now check that everything is being applied right + + @testset "caltable test" begin c1 = caltable(x.lgR) diff --git a/test/runtests.jl b/test/runtests.jl index b7132c0d..a5adb06b 100755 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,9 +15,9 @@ include(joinpath(@__DIR__, "test_util.jl")) Pkg.develop(PackageSpec(url="https://github.com/ptiede/ComradeBase.jl")) @testset "Comrade.jl" begin include(joinpath(@__DIR__, "Core/core.jl")) - include(joinpath(@__DIR__, "ext/comradeahmc.jl")) - include(joinpath(@__DIR__, "ext/comradeoptimization.jl")) - include(joinpath(@__DIR__, "ext/comradepigeons.jl")) - include(joinpath(@__DIR__, "ext/comradedynesty.jl")) - include(joinpath(@__DIR__, "ext/comradenested.jl")) + # include(joinpath(@__DIR__, "ext/comradeahmc.jl")) + # include(joinpath(@__DIR__, "ext/comradeoptimization.jl")) + # include(joinpath(@__DIR__, "ext/comradepigeons.jl")) + # include(joinpath(@__DIR__, "ext/comradedynesty.jl")) + # include(joinpath(@__DIR__, "ext/comradenested.jl")) end From c5f434e39d9625e69b93f7b1d23bd039dfce20af Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sat, 21 Dec 2024 13:55:38 -0500 Subject: [PATCH 24/34] tests passing locally --- src/instrument/caltable.jl | 4 ++-- test/Core/core.jl | 8 ++++---- test/Core/models.jl | 42 +++++++++++++++++++------------------- test/runtests.jl | 10 ++++----- 4 files changed, 32 insertions(+), 32 deletions(-) diff --git a/src/instrument/caltable.jl b/src/instrument/caltable.jl index 92aa4b8e..21f5b69c 100644 --- a/src/instrument/caltable.jl +++ b/src/instrument/caltable.jl @@ -138,13 +138,13 @@ end function Tables.getcolumn(g::CalTableRow, ::Type, col::Int, nm::Symbol) (col == 1 || nm == :Ti) && return times(getfield(g, :source))[getfield(g, :row)] (col == 2 || nm == :Fr) && return frequencies(getfield(g, :source))[getfield(g, :row)] - gmat(getfield(g, :source))[getfield(g, :row), col-1] + gmat(getfield(g, :source))[getfield(g, :row), col-2] end function Tables.getcolumn(g::CalTableRow, i::Int) (i==1) && return times(getfield(g, :source))[getfield(g, :row)] (i==2) && return frequencies(getfield(g, :source))[getfield(g, :row)] - gmat(getfield(g, :source))[getfield(g, :row), i-1] + gmat(getfield(g, :source))[getfield(g, :row), i-2] end function Tables.getcolumn(g::CalTableRow, nm::Symbol) diff --git a/test/Core/core.jl b/test/Core/core.jl index 9054ebe6..1ef7e53e 100644 --- a/test/Core/core.jl +++ b/test/Core/core.jl @@ -9,8 +9,8 @@ using VLBIImagePriors include(joinpath(@__DIR__, "../test_util.jl")) -# include(joinpath(@__DIR__, "observation.jl")) -# include(joinpath(@__DIR__, "partially_fixed.jl")) +include(joinpath(@__DIR__, "observation.jl")) +include(joinpath(@__DIR__, "partially_fixed.jl")) include(joinpath(@__DIR__, "models.jl")) -# include(joinpath(@__DIR__, "bayes.jl")) -# include(joinpath(@__DIR__, "rules.jl")) +include(joinpath(@__DIR__, "bayes.jl")) +include(joinpath(@__DIR__, "rules.jl")) diff --git a/test/Core/models.jl b/test/Core/models.jl index e2f23d7e..ca5c9f87 100644 --- a/test/Core/models.jl +++ b/test/Core/models.jl @@ -524,42 +524,42 @@ end xsilgrat .= -log(2) xmflgrat = xmf.lgrat[S=s] xmflgrat[1:length(xsilgrat)] .= xsilgrat - xmflgrat[length(xsilgrat)+1:end] .= 2 .* xsilgrat + xmflgrat[length(xsilgrat)+1:end] .= xsilgrat vmf = Comrade.apply_instrument(vismf, ointmf, (;instrument=xmf)) vsi = Comrade.apply_instrument(vissi, ointsi, (;instrument=xsi)) - Gmf = SMatrix{2,2}(2.0, 0.0, 0.0, 1.0) + Gmf = SMatrix{2,2}(2.0, 0.0, 0.0, 2.0) @test vsi[inds1si] ≈ vmf[inds1si] - @test vsi[inds1si] ≈ vmf[inds1mf[length(inds1si)+1:end]] .* Ref(Gmf) + @test vsi[inds1si] .* Ref(Gmf) ≈ vmf[inds1mf[length(inds1si)+1:end]] # Now phases map(x->fill!(x, 0.0), xsi) map(x->fill!(x, 0.0), xmf) - xgpRs = x.gpR[S=s] - xgpRs .= π/3 - xgprat = x.gprat[S=s] - xgprat .= -π/3 - vout = Comrade.apply_instrument(vis, ointm, (;instrument=x)) - G = SMatrix{2,2}(exp(1im*π/3), 0.0, 0.0, exp(1im*0.0)) - @test vout[inds1] ≈ Ref(G) .*vper[inds1] - @test vout[inds2] ≈ vper[inds2] .* Ref(adjoint(G)) - @test vout[ninds] ≈ vper[ninds] + xsigpR = xsi.gpR[S=s] + xsigpR .= π/3 + xmfgpR = xmf.gpR[S=s] + xmfgpR[1:length(xsigpR)] .= xsigpR + xmfgpR[length(xsilgR)+1:end] .= 2 .* xsigpR + xsigprat = xsi.gprat[S=s] + xsigprat .= -π/6 + xmfgprat = xmf.gprat[S=s] + xmfgprat[1:length(xsigprat)] .= xsigprat + xmfgprat[length(xsigprat)+1:end] .= xsigprat - end - - - voutsi = Comrade.apply_instrument(vissi, ointsi, (;instrument=xsi)) - voutmf = Comrade.apply_instrument(vismf, ointmf, (;instrument=xmf)) - voutmf[1:length(voutsi)] ≈ voutsi - voutmf[length(voutsi)+1:end] ≈ voutsi.*exp(1 + 2*1im) + vmf = Comrade.apply_instrument(vismf, ointmf, (;instrument=xmf)) + vsi = Comrade.apply_instrument(vissi, ointsi, (;instrument=xsi)) + Gmf = SMatrix{2,2}(exp(1im*π/3), 0.0, 0.0, exp(1im*π/3)) + @test vsi[inds1si] ≈ vmf[inds1si] + @test vsi[inds1si] .* Ref(Gmf) ≈ vmf[inds1mf[length(inds1si)+1:end]] - # # Now check that everything is being applied right + end @testset "caltable test" begin - c1 = caltable(x.lgR) + xmf = rand(printmf) + c1 = caltable(xmf.lgR) test_caltable(c1, sort(sites(amp))) end diff --git a/test/runtests.jl b/test/runtests.jl index a5adb06b..b7132c0d 100755 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,9 +15,9 @@ include(joinpath(@__DIR__, "test_util.jl")) Pkg.develop(PackageSpec(url="https://github.com/ptiede/ComradeBase.jl")) @testset "Comrade.jl" begin include(joinpath(@__DIR__, "Core/core.jl")) - # include(joinpath(@__DIR__, "ext/comradeahmc.jl")) - # include(joinpath(@__DIR__, "ext/comradeoptimization.jl")) - # include(joinpath(@__DIR__, "ext/comradepigeons.jl")) - # include(joinpath(@__DIR__, "ext/comradedynesty.jl")) - # include(joinpath(@__DIR__, "ext/comradenested.jl")) + include(joinpath(@__DIR__, "ext/comradeahmc.jl")) + include(joinpath(@__DIR__, "ext/comradeoptimization.jl")) + include(joinpath(@__DIR__, "ext/comradepigeons.jl")) + include(joinpath(@__DIR__, "ext/comradedynesty.jl")) + include(joinpath(@__DIR__, "ext/comradenested.jl")) end From c35e4795c76849392f802579fd164683231681a4 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sat, 21 Dec 2024 14:11:29 -0500 Subject: [PATCH 25/34] rule was deleted --- test/Core/core.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Core/core.jl b/test/Core/core.jl index 1ef7e53e..8b6bd6b0 100644 --- a/test/Core/core.jl +++ b/test/Core/core.jl @@ -13,4 +13,3 @@ include(joinpath(@__DIR__, "observation.jl")) include(joinpath(@__DIR__, "partially_fixed.jl")) include(joinpath(@__DIR__, "models.jl")) include(joinpath(@__DIR__, "bayes.jl")) -include(joinpath(@__DIR__, "rules.jl")) From 08b56d053342133e3d5ed8e3d3c5ac3daba2f997 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sat, 21 Dec 2024 15:49:09 -0500 Subject: [PATCH 26/34] Update geometric model --- examples/beginner/GeometricModeling/main.jl | 8 ++++---- examples/beginner/LoadingData/Project.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/beginner/GeometricModeling/main.jl b/examples/beginner/GeometricModeling/main.jl index df880cf7..d3195cf4 100644 --- a/examples/beginner/GeometricModeling/main.jl +++ b/examples/beginner/GeometricModeling/main.jl @@ -81,7 +81,7 @@ prior = ( ξτ= Uniform(0.0, π), f = Uniform(0.0, 1.0), σG = Uniform(μas2rad(1.0), μas2rad(100.0)), - τG = Uniform(0.0, 1.0), + τG = Exponential(1.0), ξG = Uniform(0.0, 1π), xG = Uniform(-μas2rad(80.0), μas2rad(80.0)), yG = Uniform(-μas2rad(80.0), μas2rad(80.0)) @@ -146,8 +146,8 @@ fpost = asflat(post) p = prior_sample(rng, post) # and then transform it to transformed space using T -logdensityof(cpost, Comrade.TV.inverse(cpost, p)) -logdensityof(fpost, Comrade.TV.inverse(fpost, p)) +logdensityof(cpost, Comrade.inverse(cpost, p)) +logdensityof(fpost, Comrade.inverse(fpost, p)) # note that the log densit is not the same since the transformation has causes a jacobian to ensure volume is preserved. @@ -185,7 +185,7 @@ DisplayAs.Text(DisplayAs.PNG(fig)) # parallel tempering sampler that enables global exploration of the posterior. For smaller dimension # problems (< 100) we recommend using this sampler, especially if you have access to > 1 core. using Pigeons -pt = pigeons(target=cpost, explorer=SliceSampler(), record=[traces, round_trip, log_sum_ratio], n_chains=16, n_rounds=8) +pt = pigeons(target=cpost, explorer=SliceSampler(), record=[traces, round_trip, log_sum_ratio], n_chains=16, n_rounds=10) # That's it! To finish it up we can then plot some simple visual fit diagnostics. diff --git a/examples/beginner/LoadingData/Project.toml b/examples/beginner/LoadingData/Project.toml index 5910585b..8168f932 100644 --- a/examples/beginner/LoadingData/Project.toml +++ b/examples/beginner/LoadingData/Project.toml @@ -6,4 +6,4 @@ Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468" [compat] Plots = "1" -Pyehtim = "0.1" +Pyehtim = "0.2" From 27439fe118781789f0e6c96fe4002075583f7253 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sat, 21 Dec 2024 17:24:51 -0500 Subject: [PATCH 27/34] Small updates to tther examples --- .../intermediate/PolarizedImaging/main.jl | 8 ++-- examples/intermediate/StokesIImaging/main.jl | 6 +-- ext/ComradeOptimizationExt.jl | 2 +- src/instrument/caltable.jl | 41 +++++++++---------- 4 files changed, 27 insertions(+), 30 deletions(-) diff --git a/examples/intermediate/PolarizedImaging/main.jl b/examples/intermediate/PolarizedImaging/main.jl index a633bb2f..da47f065 100644 --- a/examples/intermediate/PolarizedImaging/main.jl +++ b/examples/intermediate/PolarizedImaging/main.jl @@ -366,16 +366,16 @@ gamp_ratio = caltable(exp.(xopt.instrument.lgrat)) # expected since gain ratios are typically stable over the course of an observation and the constant # offset was removed in the EHT calibration process. gphaseR = caltable(xopt.instrument.gpR) -p = Plots.plot(gphaseR, layout=(3,3), size=(650,500)); -Plots.plot!(p, gphase_ratio, layout=(3,3), size=(650,500)); +p = Plots.plot(gphaseR, layout=(3,3), size=(650,500), label="R Gain Phase"); +Plots.plot!(p, gphase_ratio, layout=(3,3), size=(650,500), label="Gain Phase Ratio"); p |> DisplayAs.PNG |> DisplayAs.Text #- # Moving to the amplitudes we see largely stable gain amplitudes on the right circular polarization except for LMT which is # known and due to pointing issues during the 2017 observation. Again the gain ratios are stable and close to unity. Typically # we expect that apriori calibration should make the gain ratios close to unity. gampr = caltable(exp.(xopt.instrument.lgR)) -p = Plots.plot(gampr, layout=(3,3), size=(650,500)) -Plots.plot!(p, gamp_ratio, layout=(3,3), size=(650,500)) +p = Plots.plot(gampr, layout=(3,3), size=(650,500), label="R Gain Amp."); +Plots.plot!(p, gamp_ratio, layout=(3,3), size=(650,500), label="Gain Amp. Ratio") p |> DisplayAs.PNG |> DisplayAs.Text #- diff --git a/examples/intermediate/StokesIImaging/main.jl b/examples/intermediate/StokesIImaging/main.jl index 4933b901..48804e80 100644 --- a/examples/intermediate/StokesIImaging/main.jl +++ b/examples/intermediate/StokesIImaging/main.jl @@ -131,15 +131,13 @@ skym = SkyModel(sky, prior, grid; metadata=skymeta) # - Gain phases which are more difficult to constrain and can shift rapidly. G = SingleStokesGain() do x - lg = x.lgμ + x.lgσ*x.lgz + lg = x.lg gp = x.gp return exp(lg + 1im*gp) end intpr = ( - lgμ = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2)); LM = IIDSitePrior(TrackSeg(), Normal(0.0, 1.0))), - lgσ = ArrayPrior(IIDSitePrior(TrackSeg(), Exponential(0.1))), - lgz = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 1.0))), + lg = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.2)); LM = IIDSitePrior(ScanSeg(), Normal(0.0, 1.0))), gp= ArrayPrior(IIDSitePrior(ScanSeg(), DiagonalVonMises(0.0, inv(π^2))); refant=SEFDReference(0.0), phase=true) ) intmodel = InstrumentModel(G, intpr) diff --git a/ext/ComradeOptimizationExt.jl b/ext/ComradeOptimizationExt.jl index aed367c1..c0f4091f 100644 --- a/ext/ComradeOptimizationExt.jl +++ b/ext/ComradeOptimizationExt.jl @@ -14,7 +14,7 @@ function Optimization.OptimizationFunction(post::Comrade.TransformedVLBIPosterio return SciMLBase.OptimizationFunction(ℓ, args...; kwargs...) else function grad(G, x, p) - (_, dx) = LogDensityProblems.logdensity_and_gradient(post, x) + (_, dx) = Comrade.LogDensityProblems.logdensity_and_gradient(post, x) dx .*= -1 G .= dx return G diff --git a/src/instrument/caltable.jl b/src/instrument/caltable.jl index 21f5b69c..44b9b140 100644 --- a/src/instrument/caltable.jl +++ b/src/instrument/caltable.jl @@ -177,34 +177,33 @@ end t = getproperty.(gt[:Ti], :t0) xlims --> (t[begin], t[end] + 0.01*abs(t[end])) for (i,s) in enumerate(sites) - @series begin - T = nonmissingtype(eltype(gt[s])) - tt = Vector{eltype(t)}[] - yy = Vector{T}[] - seriestype := :scatter - subplot := i - title := String(s) - if i == length(sites) - xguide --> "Time (UTC)" - end - labels = ["$(f.central/1e9) GHz" for f in unique(gt[:Fr])] - if i == length(sites) - label --> reshape(labels, 1, :) - else - label := nothing - end - for f in unique(gt[:Fr]) + T = nonmissingtype(eltype(gt[s])) + for (j,f) in enumerate(unique(gt[:Fr])) + @series begin + seriestype := :scatter + subplot := i + title := String(s) + if i == length(sites) + xguide --> "Time (UTC)" + end + label = "$(round(f.central/1e9, digits=1)) GHz" + if i == length(sites) + label --> label + else + label := nothing + end ind = Base.:!.(ismissing.(gt[s])) find = findall(==(f), gt[:Fr][ind]) #x := gt[:Ti][ind] - push!(tt, t[ind][find]) + x = t[ind][find] + if !datagains - push!(yy, T.(gt[s][ind][find])) + y = T.(gt[s][ind][find]) else - push!(yy, T.(inv.(gt[s][ind][find]))) + y = T.(inv.((gt[s][ind][find]))) end + x, y end - tt, yy end end end From cc0f8c6feb24a005602b282130255931c8553c3d Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sat, 21 Dec 2024 20:40:58 -0500 Subject: [PATCH 28/34] Remove the Pkg.develop in GeometricModeling since there is some strange Union{} bug --- docs/tutorials.jl | 2 +- examples/advanced/HybridImaging/main.jl | 24 +++++++++--------- examples/beginner/GeometricModeling/main.jl | 25 +++++++++++-------- examples/beginner/LoadingData/main.jl | 22 +++++++--------- examples/intermediate/ClosureImaging/main.jl | 21 ++++++++-------- .../intermediate/PolarizedImaging/main.jl | 17 +++++++------ examples/intermediate/StokesIImaging/main.jl | 18 ++++++------- 7 files changed, 64 insertions(+), 65 deletions(-) diff --git a/docs/tutorials.jl b/docs/tutorials.jl index c2dae2fe..0ac9c33a 100644 --- a/docs/tutorials.jl +++ b/docs/tutorials.jl @@ -9,8 +9,8 @@ OUTPUT = joinpath(@__DIR__, "src", "tutorials") TUTORIALS = [ "beginner/LoadingData/main.jl", - "beginner/GeometricModeling/main.jl", "intermediate/ClosureImaging/main.jl", + "beginner/GeometricModeling/main.jl", "intermediate/StokesIImaging/main.jl", "intermediate/PolarizedImaging/main.jl", "advanced/HybridImaging/main.jl", diff --git a/examples/advanced/HybridImaging/main.jl b/examples/advanced/HybridImaging/main.jl index ceb15692..e7ac1e4b 100644 --- a/examples/advanced/HybridImaging/main.jl +++ b/examples/advanced/HybridImaging/main.jl @@ -1,3 +1,15 @@ +import Pkg #hide +__DIR = @__DIR__ #hide +pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide +Pkg.activate(__DIR; io=pkg_io) #hide +Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide +Pkg.instantiate(; io=pkg_io) #hide +Pkg.precompile(; io=pkg_io) #hide +close(pkg_io) #hide + +ENV["GKSwstype"] = "nul" #hide + + # # Hybrid Imaging of a Black Hole # In this tutorial, we will use **hybrid imaging** to analyze the 2017 EHT data. @@ -16,18 +28,6 @@ # This is the approach we will take in this tutorial to analyze the April 6 2017 EHT data # of M87. -import Pkg #hide -__DIR = @__DIR__ #hide -pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide -Pkg.activate(__DIR; io=pkg_io) #hide -Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide -Pkg.instantiate(; io=pkg_io) #hide -Pkg.precompile(; io=pkg_io) #hide -close(pkg_io) #hide - -ENV["GKSwstype"] = "nul" #hide - - # ## Loading the Data # To get started we will load Comrade diff --git a/examples/beginner/GeometricModeling/main.jl b/examples/beginner/GeometricModeling/main.jl index d3195cf4..64b02036 100644 --- a/examples/beginner/GeometricModeling/main.jl +++ b/examples/beginner/GeometricModeling/main.jl @@ -1,3 +1,13 @@ +import Pkg; #hide +__DIR = @__DIR__; #hide +pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide +Pkg.activate(__DIR; io=pkg_io) #hide +## Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide +Pkg.instantiate(; io=pkg_io) #hide +Pkg.precompile(; io=pkg_io) #hide +close(pkg_io) #hide + + # # Geometric Modeling of EHT Data # `Comrade` has been designed to work with the EHT and ngEHT. @@ -9,22 +19,15 @@ # In this tutorial, we will construct a similar model and fit it to the data in under # 50 lines of code (sans comments). To start, we load Comrade and some other packages we need. -import Pkg #hide -__DIR = @__DIR__ #hide -pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide -Pkg.activate(__DIR; io=pkg_io) #hide -Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide -Pkg.instantiate(; io=pkg_io) #hide -Pkg.precompile(; io=pkg_io) #hide -close(pkg_io) #hide -ENV["GKSwstype"] = "nul" #hide - +# To get started we load Comrade. +#- using Comrade +# Currently we use eht-imaging for data management, however this will soon be replaced +# by a pure Julia solution. using Pyehtim - # For reproducibility we use a stable random number genreator using StableRNGs rng = StableRNG(42) diff --git a/examples/beginner/LoadingData/main.jl b/examples/beginner/LoadingData/main.jl index a5236533..c064e438 100644 --- a/examples/beginner/LoadingData/main.jl +++ b/examples/beginner/LoadingData/main.jl @@ -1,14 +1,3 @@ -# # Loading Data into Comrade - -# The VLBI field does not have a standardized data format, and the EHT uses a -# particular uvfits format similar to the optical interferometry oifits format. -# As a result, we reuse the excellent `eht-imaging` package to load data into `Comrade`. - -# Once the data is loaded, we then convert the data into the tabular format `Comrade` -# expects. Note that this may change to a Julia package as the Julia radio -# astronomy group grows. - -# To get started, we will load `Comrade` and `Plots` to enable visualizations of the data import Pkg #hide __DIR = @__DIR__ #hide pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide @@ -18,11 +7,18 @@ Pkg.instantiate(; io=pkg_io) #hide Pkg.precompile(; io=pkg_io) #hide close(pkg_io) #hide -ENV["GKSwstype"] = "nul" #hide +# # Loading Data into Comrade +# The VLBI field does not have a standardized data format, and the EHT uses a +# particular uvfits format similar to the optical interferometry oifits format. +# As a result, we reuse the excellent `eht-imaging` package to load data into `Comrade`. -using Comrade +# Once the data is loaded, we then convert the data into the tabular format `Comrade` +# expects. Note that this may change to a Julia package as the Julia radio +# astronomy group grows. +# To get started, we will load `Comrade` and `Plots` to enable visualizations of the data +using Comrade using Plots # We also load Pyehtim since it loads eht-imaging into Julia using PythonCall and exports diff --git a/examples/intermediate/ClosureImaging/main.jl b/examples/intermediate/ClosureImaging/main.jl index 44c0d5b0..07070e94 100644 --- a/examples/intermediate/ClosureImaging/main.jl +++ b/examples/intermediate/ClosureImaging/main.jl @@ -1,3 +1,13 @@ +import Pkg #hide +__DIR = @__DIR__ #hide +pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide +Pkg.activate(__DIR; io=pkg_io) #hide +Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide +Pkg.instantiate(; io=pkg_io) #hide +Pkg.precompile(; io=pkg_io) #hide +close(pkg_io) #hide +ENV["GKSwstype"] = "nul"; #hide + # # Imaging a Black Hole using only Closure Quantities # In this tutorial, we will create a preliminary reconstruction of the 2017 M87 data on April 6 @@ -21,17 +31,6 @@ # # In this tutorial, we will do closure-only modeling of M87 to produce a posterior of images of M87. -import Pkg #hide -__DIR = @__DIR__ #hide -pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide -Pkg.activate(__DIR; io=pkg_io) #hide -Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide -Pkg.instantiate(; io=pkg_io) #hide -Pkg.precompile(; io=pkg_io) #hide -close(pkg_io) #hide -ENV["GKSwstype"] = "nul"; #hide - - # To get started, we will load Comrade using Comrade diff --git a/examples/intermediate/PolarizedImaging/main.jl b/examples/intermediate/PolarizedImaging/main.jl index da47f065..dfc47a53 100644 --- a/examples/intermediate/PolarizedImaging/main.jl +++ b/examples/intermediate/PolarizedImaging/main.jl @@ -1,3 +1,12 @@ +import Pkg #hide +__DIR = @__DIR__ #hide +pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide +Pkg.activate(__DIR; io=pkg_io) #hide +Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide +Pkg.instantiate(; io=pkg_io) #hide +Pkg.precompile(; io=pkg_io) #hide +close(pkg_io) #hide + # # Polarized Image and Instrumental Modeling # In this tutorial, we will analyze a simulated simple polarized dataset to demonstrate @@ -85,14 +94,6 @@ # In the rest of the tutorial, we are going to solve for all of these instrument model terms in # while re-creating the polarized image from the first [`EHT results on M87`](https://iopscience.iop.org/article/10.3847/2041-8213/abe71d). -import Pkg #hide -__DIR = @__DIR__ #hide -pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide -Pkg.activate(__DIR; io=pkg_io) #hide -Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide -Pkg.instantiate(; io=pkg_io) #hide -Pkg.precompile(; io=pkg_io) #hide -close(pkg_io) #hide # ## Load the Data diff --git a/examples/intermediate/StokesIImaging/main.jl b/examples/intermediate/StokesIImaging/main.jl index 48804e80..7f9d1b5a 100644 --- a/examples/intermediate/StokesIImaging/main.jl +++ b/examples/intermediate/StokesIImaging/main.jl @@ -1,12 +1,3 @@ -# # Stokes I Simultaneous Image and Instrument Modeling - -# In this tutorial, we will create a preliminary reconstruction of the 2017 M87 data on April 6 -# by simultaneously creating an image and model for the instrument. By instrument model, we -# mean something akin to self-calibration in traditional VLBI imaging terminology. However, -# unlike traditional self-cal, we will solve for the gains each time we update the image -# self-consistently. This allows us to model the correlations between gains and the image. - -# To get started we load Comrade. import Pkg #hide __DIR = @__DIR__ #hide pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide @@ -17,6 +8,15 @@ Pkg.precompile(; io=pkg_io) #hide close(pkg_io) #hide +# # Stokes I Simultaneous Image and Instrument Modeling + +# In this tutorial, we will create a preliminary reconstruction of the 2017 M87 data on April 6 +# by simultaneously creating an image and model for the instrument. By instrument model, we +# mean something akin to self-calibration in traditional VLBI imaging terminology. However, +# unlike traditional self-cal, we will solve for the gains each time we update the image +# self-consistently. This allows us to model the correlations between gains and the image. + +# To get started we load Comrade. using Comrade From 092699153f410b58029afb366b9ef68be47f67ff Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sat, 21 Dec 2024 23:20:00 -0500 Subject: [PATCH 29/34] Remove CondaPkg.toml --- docs/CondaPkg.toml | 16 ---------------- 1 file changed, 16 deletions(-) delete mode 100644 docs/CondaPkg.toml diff --git a/docs/CondaPkg.toml b/docs/CondaPkg.toml deleted file mode 100644 index a44d2b67..00000000 --- a/docs/CondaPkg.toml +++ /dev/null @@ -1,16 +0,0 @@ - -[deps] -python = ">=3.6,<=3.10" -# astropy = "" -# ephem = "" -# future = "" -# h5py = "" -# ipython = "" -# matplotlib = "" -# networkx = "" -numpy = "<=1.23" -pandas = "<2" -# scipy = "" -# scikit-image = "" -[pip.deps] -ehtim = "" From 4a27ccdc1735af52bde329369b4143113918d6af Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sun, 22 Dec 2024 00:17:22 -0500 Subject: [PATCH 30/34] Improve coverage --- src/instrument/instrument.jl | 4 +-- src/instrument/instrument_transforms.jl | 15 +++++----- test/Core/models.jl | 39 +++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 9 deletions(-) diff --git a/src/instrument/instrument.jl b/src/instrument/instrument.jl index 008e4f6f..10e889e6 100644 --- a/src/instrument/instrument.jl +++ b/src/instrument/instrument.jl @@ -4,8 +4,8 @@ import Distributions using Statistics using PrettyTables -struct IntegrationTime{T} - mjd::Int +struct IntegrationTime{I<:Integer, T} + mjd::I t0::T dt::T end diff --git a/src/instrument/instrument_transforms.jl b/src/instrument/instrument_transforms.jl index 580cccf4..e1e2dcc9 100644 --- a/src/instrument/instrument_transforms.jl +++ b/src/instrument/instrument_transforms.jl @@ -1,10 +1,11 @@ abstract type AbstractInstrumentTransform <: TV.VectorTransform end -site_map(t::AbstractInstrumentTransform) = t.site_map +site_map(t::AbstractInstrumentTransform) = t +EnzymeRules.inactive(::typeof(site_map), args...) = nothing inner_transform(t::AbstractInstrumentTransform) = t.inner_transform function TV.transform_with(flag::TV.LogJacFlag, m::AbstractInstrumentTransform, x, index) y, ℓ, index = _instrument_transform_with(flag, m, x, index) - sm = m.site_map + sm = site_map(m) return SiteArray(y, sm), ℓ, index end @@ -97,11 +98,11 @@ function site_diff!(y, site_map::SiteLookup) return nothing end -function simplediff(x::AbstractVector) - y = zero(x) - simplediff!(y, x) - return y -end +# function simplediff(x::AbstractVector) +# y = zero(x) +# simplediff!(y, x) +# return y +# end function simplediff!(y::AbstractVector, x::AbstractVector) y[begin] = x[begin] diff --git a/test/Core/models.jl b/test/Core/models.jl index ca5c9f87..eaf0b84c 100644 --- a/test/Core/models.jl +++ b/test/Core/models.jl @@ -41,6 +41,13 @@ function build_mfvis(vistuple...) end +function isequalmissing(x, y) + xm = x |> ismissing |> collect + ym = y |> ismissing |> collect + return xm == ym +end + + function test_caltable(c1, sites) @test Tables.istable(typeof(c1)) @test Tables.rowaccess(typeof(c1)) @@ -54,6 +61,7 @@ function test_caltable(c1, sites) @test c1.Ti == Tables.getcolumn(c1, 1) @test c1.Fr == Comrade.frequencies(c1) @test c1.Fr == Tables.getcolumn(c1, 2) + @test isequalmissing(c1.AA, Tables.getcolumn(c1, 3)) @test maximum(abs, skipmissing(c1.AA) .- skipmissing(Tables.getcolumn(c1, :AA))) ≈ 0 @test maximum(abs, skipmissing(c1.AA) .- skipmissing(Tables.getcolumn(c1, 3))) ≈ 0 @@ -71,6 +79,13 @@ function test_caltable(c1, sites) @test Tables.getcolumn(c1row, 2) == c1.Fr[30] @test Tables.getcolumn(c1row, 1) == c1.Ti[30] @test propertynames(c1) == propertynames(c1row) == [:Ti, :Fr, sites...] + @test Tables.getcolumn(c1row, Float64, 1, :Ti) == c1.Ti[30] + @test Tables.getcolumn(c1row, Float64, 2, :Fr) == c1.Fr[30] + @test Tables.getcolumn(c1row, Float64, 3, :AA) == c1.AA[30] + @test isequalmissing(c1[1:10, :AA], c1.AA[1:10]) + @test isequalmissing(c1[[1,2], :AA], c1.AA[[1,2]]) + @test isequalmissing(@view(c1[1:10, :AA]), @view(c1.AA[1:10])) + @test isequalmissing(@view(c1[[1,2], :AA]), @view(c1.AA[[1,2]])) Tables.schema(c1) isa Tables.Schema Tables.getcolumn(c1, Float64, 1, :test) @@ -177,6 +192,11 @@ end @inferred Comrade.time(x.lg, 5.0..6.0) @inferred Comrade.frequency(x.lg, 1.0..400.0) + @test x.lg ≈ SiteArray(x.lg, x.lg.times, x.lg.frequencies, x.lg.sites) + @inferred x.lg[1,1,1] + x.lg[1,1,1,1] = 1.0 + @test x.lg[1] ≈ 1.0 + # ps = ProjectTo(x.lg) # @test ps(x.lg) == x.lg # @test ps(NoTangent()) isa NoTangent @@ -573,5 +593,24 @@ end @test length(tt) < length(ts) ≤ length(ti) end + @testset "IntegrationTime" begin + ti = Comrade.IntegrationTime(10, 5.0, 0.1) + @test Comrade.mjd(ti) == ti.mjd + @test ti.t0 ∈ Comrade.interval(ti) + @test Comrade._center(ti) == ti.t0 + @test Comrade._region(ti) == 0.1 + end + + @testset "FrequencyChannel" begin + fc = Comrade.FrequencyChannel(230e9, 8e9, 1) + @test Comrade.channel(fc) == 1 + @test fc.central ∈ Comrade.interval(fc) + @test Comrade._center(fc) == fc.central + @test Comrade._region(fc) == 8e9 + @test 86e9 < fc + @test fc < 345e9 + end + + end From e0691b7afdafda56da1da42669bebae23b230151 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sun, 22 Dec 2024 00:20:55 -0500 Subject: [PATCH 31/34] Try again --- examples/beginner/GeometricModeling/main.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/beginner/GeometricModeling/main.jl b/examples/beginner/GeometricModeling/main.jl index 64b02036..f9632a3b 100644 --- a/examples/beginner/GeometricModeling/main.jl +++ b/examples/beginner/GeometricModeling/main.jl @@ -2,9 +2,10 @@ import Pkg; #hide __DIR = @__DIR__; #hide pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide Pkg.activate(__DIR; io=pkg_io) #hide -## Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide Pkg.instantiate(; io=pkg_io) #hide Pkg.precompile(; io=pkg_io) #hide +Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide +Pkg.precompile(; io=pkg_io) #hide close(pkg_io) #hide From b62493861c27640130289a67f3dfe30af94782e1 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sun, 22 Dec 2024 01:08:40 -0500 Subject: [PATCH 32/34] Make docs directly in literate --- docs/tutorials.jl | 2 +- examples/beginner/GeometricModeling/main.jl | 4 ++-- src/instrument/instrument_transforms.jl | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/tutorials.jl b/docs/tutorials.jl index 0ac9c33a..881cf931 100644 --- a/docs/tutorials.jl +++ b/docs/tutorials.jl @@ -26,7 +26,7 @@ withenv("JULIA_DEBUG"=>"Literate") do jl_expr = "using Literate;"* "preprocess(path, str) = replace(str, \"__DIR = @__DIR__\" => \"__DIR = \\\"\$(dirname(path))\\\"\");"* "Literate.markdown(\"$(p_)\", \"$(joinpath(OUTPUT, d))\";"* - "name=\"$name\", execute=false, flavor=Literate.DocumenterFlavor(),"* + "name=\"$name\", execute=true, flavor=Literate.DocumenterFlavor(),"* "preprocess=Base.Fix1(preprocess, \"$(p_)\"))" cm = `julia --project=$(@__DIR__) -e $(jl_expr)` run(cm) diff --git a/examples/beginner/GeometricModeling/main.jl b/examples/beginner/GeometricModeling/main.jl index f9632a3b..dda40f2f 100644 --- a/examples/beginner/GeometricModeling/main.jl +++ b/examples/beginner/GeometricModeling/main.jl @@ -2,9 +2,9 @@ import Pkg; #hide __DIR = @__DIR__; #hide pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide Pkg.activate(__DIR; io=pkg_io) #hide +Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide Pkg.instantiate(; io=pkg_io) #hide Pkg.precompile(; io=pkg_io) #hide -Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide Pkg.precompile(; io=pkg_io) #hide close(pkg_io) #hide @@ -189,7 +189,7 @@ DisplayAs.Text(DisplayAs.PNG(fig)) # parallel tempering sampler that enables global exploration of the posterior. For smaller dimension # problems (< 100) we recommend using this sampler, especially if you have access to > 1 core. using Pigeons -pt = pigeons(target=cpost, explorer=SliceSampler(), record=[traces, round_trip, log_sum_ratio], n_chains=16, n_rounds=10) +pt = pigeons(target=cpost, explorer=SliceSampler(), record=[traces, round_trip, log_sum_ratio], n_chains=16, n_rounds=8) # That's it! To finish it up we can then plot some simple visual fit diagnostics. diff --git a/src/instrument/instrument_transforms.jl b/src/instrument/instrument_transforms.jl index e1e2dcc9..3932f232 100644 --- a/src/instrument/instrument_transforms.jl +++ b/src/instrument/instrument_transforms.jl @@ -1,5 +1,5 @@ abstract type AbstractInstrumentTransform <: TV.VectorTransform end -site_map(t::AbstractInstrumentTransform) = t +site_map(t::AbstractInstrumentTransform) = t.site_map EnzymeRules.inactive(::typeof(site_map), args...) = nothing inner_transform(t::AbstractInstrumentTransform) = t.inner_transform From fe398fba6a6afb0bf4fa03ed76886fcd832908af Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sun, 22 Dec 2024 01:44:08 -0500 Subject: [PATCH 33/34] try tutorials again --- docs/tutorials.jl | 2 +- examples/beginner/GeometricModeling/main.jl | 5 +++++ src/instrument/site_array.jl | 2 +- test/Core/models.jl | 2 +- 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/tutorials.jl b/docs/tutorials.jl index 881cf931..b977a27e 100644 --- a/docs/tutorials.jl +++ b/docs/tutorials.jl @@ -9,8 +9,8 @@ OUTPUT = joinpath(@__DIR__, "src", "tutorials") TUTORIALS = [ "beginner/LoadingData/main.jl", - "intermediate/ClosureImaging/main.jl", "beginner/GeometricModeling/main.jl", + "intermediate/ClosureImaging/main.jl", "intermediate/StokesIImaging/main.jl", "intermediate/PolarizedImaging/main.jl", "advanced/HybridImaging/main.jl", diff --git a/examples/beginner/GeometricModeling/main.jl b/examples/beginner/GeometricModeling/main.jl index dda40f2f..39a2e6bd 100644 --- a/examples/beginner/GeometricModeling/main.jl +++ b/examples/beginner/GeometricModeling/main.jl @@ -242,3 +242,8 @@ DisplayAs.Text(DisplayAs.PNG(p)) # and the model, divided by the data's error: p = residual(post, chain[end]); DisplayAs.Text(DisplayAs.PNG(p)) + +post = nothing #hide +tpost = nothing #hide +cpost = nothing #hide +GC.gc() #hide diff --git a/src/instrument/site_array.jl b/src/instrument/site_array.jl index 23634326..8059fba5 100644 --- a/src/instrument/site_array.jl +++ b/src/instrument/site_array.jl @@ -228,7 +228,7 @@ end function SiteArray(data::SiteArray{T, N}, times::AbstractArray{<:IntegrationTime, N}, frequencies::AbstractArray{<:FrequencyChannel, N}, - sites::AbstractArray{<:Number, N}) where {T, N} + sites::AbstractArray{<:Any, N}) where {T, N} return data end diff --git a/test/Core/models.jl b/test/Core/models.jl index eaf0b84c..93e6628a 100644 --- a/test/Core/models.jl +++ b/test/Core/models.jl @@ -302,7 +302,7 @@ end lgR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))), gpR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, inv(π ^2))); phase=true, refant=SEFDReference(0.0)), lgrat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1)), phase=false), - gprat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))), + gprat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1)), refant=SingleReference(:AA, 0.0)), dRx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), dRy = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), dLx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), From 7e37869fe719909304d3c332990c82d4c95eaf61 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sun, 22 Dec 2024 02:45:14 -0500 Subject: [PATCH 34/34] make analytic models take up less memory --- examples/beginner/GeometricModeling/main.jl | 7 +------ src/skymodels/models.jl | 21 +++++++++++++++++++-- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/examples/beginner/GeometricModeling/main.jl b/examples/beginner/GeometricModeling/main.jl index 39a2e6bd..c99b2945 100644 --- a/examples/beginner/GeometricModeling/main.jl +++ b/examples/beginner/GeometricModeling/main.jl @@ -1,11 +1,10 @@ import Pkg; #hide -__DIR = @__DIR__; #hide +__DIR = @__DIR__ #hide pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide Pkg.activate(__DIR; io=pkg_io) #hide Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide Pkg.instantiate(; io=pkg_io) #hide Pkg.precompile(; io=pkg_io) #hide -Pkg.precompile(; io=pkg_io) #hide close(pkg_io) #hide @@ -243,7 +242,3 @@ DisplayAs.Text(DisplayAs.PNG(p)) p = residual(post, chain[end]); DisplayAs.Text(DisplayAs.PNG(p)) -post = nothing #hide -tpost = nothing #hide -cpost = nothing #hide -GC.gc() #hide diff --git a/src/skymodels/models.jl b/src/skymodels/models.jl index 8859d495..abdf880f 100755 --- a/src/skymodels/models.jl +++ b/src/skymodels/models.jl @@ -72,7 +72,7 @@ function VLBISkyModels.FourierDualDomain(grid::AbstractRectiGrid, array::Abstrac return FourierDualDomain(grid, domain(array; executor), alg) end -struct ObservedSkyModel{F, G<:VLBISkyModels.AbstractFourierDualDomain, M} <: AbstractSkyModel +struct ObservedSkyModel{F, G<:VLBISkyModels.AbstractDomain, M} <: AbstractSkyModel f::F grid::G metadata::M @@ -82,6 +82,14 @@ function domain(m::AbstractSkyModel; kwargs...) return getfield(m, :grid) end +# If we are using a analytic model then we don't need to plan the FT and we +# can save some memory by not storing the plans. +struct AnalyticAlg <: FourierTransform end +struct AnalyticPlan <: VLBISkyModels.AbstractPlan end +VLBISkyModels.getplan(::AnalyticPlan) = nothing +VLBISkyModels.getphases(::AnalyticPlan) = nothing +VLBISkyModels.create_plans(::AnalyticAlg, imgdomain, visdomain) = (AnalyticPlan(), AnalyticPlan()) + """ ObservedSkyModel(sky::AbstractSkyModel, array::AbstractArrayConfiguration) @@ -92,9 +100,18 @@ pass that to a [`VLBIPosterior`](@ref) object instead. """ function ObservedSkyModel(m::SkyModel, arr::AbstractArrayConfiguration) - return ObservedSkyModel(m.f, FourierDualDomain(m.grid, arr, m.algorithm), m.metadata) + x = rand(NamedDist(m.prior)) + ms = m.f(x, m.metadata) + # if analytic don't bother planning the FT + if ComradeBase.visanalytic(typeof(ms)) === ComradeBase.IsAnalytic() + g = FourierDualDomain(m.grid, arr, AnalyticAlg()) + else + g = FourierDualDomain(m.grid, arr, m.algorithm) + end + return ObservedSkyModel(m.f, g, m.metadata) end + function set_array(m::AbstractSkyModel, array::AbstractArrayConfiguration) return ObservedSkyModel(m, array), m.prior end