diff --git a/Project.toml b/Project.toml index 501004be..09af5ad1 100644 --- a/Project.toml +++ b/Project.toml @@ -89,7 +89,7 @@ ParameterHandling = "0.4, 0.5" Pigeons = "0.3, 0.4" PolarizedTypes = "0.1" PrettyTables = "1, 2" -Pyehtim = "0.1" +Pyehtim = "0.2" RecipesBase = "1" Reexport = "1" SpecialFunctions = "0.10, 1, 2" diff --git a/docs/CondaPkg.toml b/docs/CondaPkg.toml deleted file mode 100644 index a44d2b67..00000000 --- a/docs/CondaPkg.toml +++ /dev/null @@ -1,16 +0,0 @@ - -[deps] -python = ">=3.6,<=3.10" -# astropy = "" -# ephem = "" -# future = "" -# h5py = "" -# ipython = "" -# matplotlib = "" -# networkx = "" -numpy = "<=1.23" -pandas = "<2" -# scipy = "" -# scikit-image = "" -[pip.deps] -ehtim = "" diff --git a/docs/tutorials.jl b/docs/tutorials.jl index c2dae2fe..b977a27e 100644 --- a/docs/tutorials.jl +++ b/docs/tutorials.jl @@ -26,7 +26,7 @@ withenv("JULIA_DEBUG"=>"Literate") do jl_expr = "using Literate;"* "preprocess(path, str) = replace(str, \"__DIR = @__DIR__\" => \"__DIR = \\\"\$(dirname(path))\\\"\");"* "Literate.markdown(\"$(p_)\", \"$(joinpath(OUTPUT, d))\";"* - "name=\"$name\", execute=false, flavor=Literate.DocumenterFlavor(),"* + "name=\"$name\", execute=true, flavor=Literate.DocumenterFlavor(),"* "preprocess=Base.Fix1(preprocess, \"$(p_)\"))" cm = `julia --project=$(@__DIR__) -e $(jl_expr)` run(cm) diff --git a/examples/advanced/HybridImaging/Project.toml b/examples/advanced/HybridImaging/Project.toml index 11b0baee..3ea32949 100644 --- a/examples/advanced/HybridImaging/Project.toml +++ b/examples/advanced/HybridImaging/Project.toml @@ -18,7 +18,7 @@ CairoMakie = "0.12" Distributions = "0.25" Optimization = "4" Plots = "1" -Pyehtim = "0.1" +Pyehtim = "0.2" StableRNGs = "1" StatsBase = "0.34" VLBIImagePriors = "0.9" diff --git a/examples/advanced/HybridImaging/main.jl b/examples/advanced/HybridImaging/main.jl index ceb15692..e7ac1e4b 100644 --- a/examples/advanced/HybridImaging/main.jl +++ b/examples/advanced/HybridImaging/main.jl @@ -1,3 
+1,15 @@ +import Pkg #hide +__DIR = @__DIR__ #hide +pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide +Pkg.activate(__DIR; io=pkg_io) #hide +Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide +Pkg.instantiate(; io=pkg_io) #hide +Pkg.precompile(; io=pkg_io) #hide +close(pkg_io) #hide + +ENV["GKSwstype"] = "nul" #hide + + # # Hybrid Imaging of a Black Hole # In this tutorial, we will use **hybrid imaging** to analyze the 2017 EHT data. @@ -16,18 +28,6 @@ # This is the approach we will take in this tutorial to analyze the April 6 2017 EHT data # of M87. -import Pkg #hide -__DIR = @__DIR__ #hide -pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide -Pkg.activate(__DIR; io=pkg_io) #hide -Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide -Pkg.instantiate(; io=pkg_io) #hide -Pkg.precompile(; io=pkg_io) #hide -close(pkg_io) #hide - -ENV["GKSwstype"] = "nul" #hide - - # ## Loading the Data # To get started we will load Comrade diff --git a/examples/beginner/GeometricModeling/Project.toml b/examples/beginner/GeometricModeling/Project.toml index 7e7be586..8879f150 100644 --- a/examples/beginner/GeometricModeling/Project.toml +++ b/examples/beginner/GeometricModeling/Project.toml @@ -18,6 +18,6 @@ Distributions = "0.25" Optimization = "4" Pigeons = "0.4" Plots = "1" -Pyehtim = "0.1" +Pyehtim = "0.2" StableRNGs = "1" VLBIImagePriors = "0.9" diff --git a/examples/beginner/GeometricModeling/main.jl b/examples/beginner/GeometricModeling/main.jl index df880cf7..c99b2945 100644 --- a/examples/beginner/GeometricModeling/main.jl +++ b/examples/beginner/GeometricModeling/main.jl @@ -1,3 +1,13 @@ +import Pkg; #hide +__DIR = @__DIR__ #hide +pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide +Pkg.activate(__DIR; io=pkg_io) #hide +Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide +Pkg.instantiate(; io=pkg_io) #hide +Pkg.precompile(; io=pkg_io) #hide +close(pkg_io) #hide + + # # Geometric Modeling of EHT Data # `Comrade` has 
been designed to work with the EHT and ngEHT. @@ -9,22 +19,15 @@ # In this tutorial, we will construct a similar model and fit it to the data in under # 50 lines of code (sans comments). To start, we load Comrade and some other packages we need. -import Pkg #hide -__DIR = @__DIR__ #hide -pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide -Pkg.activate(__DIR; io=pkg_io) #hide -Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide -Pkg.instantiate(; io=pkg_io) #hide -Pkg.precompile(; io=pkg_io) #hide -close(pkg_io) #hide -ENV["GKSwstype"] = "nul" #hide - +# To get started we load Comrade. +#- using Comrade +# Currently we use eht-imaging for data management, however this will soon be replaced +# by a pure Julia solution. using Pyehtim - # For reproducibility we use a stable random number genreator using StableRNGs rng = StableRNG(42) @@ -81,7 +84,7 @@ prior = ( ξτ= Uniform(0.0, π), f = Uniform(0.0, 1.0), σG = Uniform(μas2rad(1.0), μas2rad(100.0)), - τG = Uniform(0.0, 1.0), + τG = Exponential(1.0), ξG = Uniform(0.0, 1π), xG = Uniform(-μas2rad(80.0), μas2rad(80.0)), yG = Uniform(-μas2rad(80.0), μas2rad(80.0)) @@ -146,8 +149,8 @@ fpost = asflat(post) p = prior_sample(rng, post) # and then transform it to transformed space using T -logdensityof(cpost, Comrade.TV.inverse(cpost, p)) -logdensityof(fpost, Comrade.TV.inverse(fpost, p)) +logdensityof(cpost, Comrade.inverse(cpost, p)) +logdensityof(fpost, Comrade.inverse(fpost, p)) # note that the log densit is not the same since the transformation has causes a jacobian to ensure volume is preserved. 
@@ -238,3 +241,4 @@ DisplayAs.Text(DisplayAs.PNG(p)) # and the model, divided by the data's error: p = residual(post, chain[end]); DisplayAs.Text(DisplayAs.PNG(p)) + diff --git a/examples/beginner/LoadingData/Project.toml b/examples/beginner/LoadingData/Project.toml index 5910585b..8168f932 100644 --- a/examples/beginner/LoadingData/Project.toml +++ b/examples/beginner/LoadingData/Project.toml @@ -6,4 +6,4 @@ Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468" [compat] Plots = "1" -Pyehtim = "0.1" +Pyehtim = "0.2" diff --git a/examples/beginner/LoadingData/main.jl b/examples/beginner/LoadingData/main.jl index a5236533..c064e438 100644 --- a/examples/beginner/LoadingData/main.jl +++ b/examples/beginner/LoadingData/main.jl @@ -1,14 +1,3 @@ -# # Loading Data into Comrade - -# The VLBI field does not have a standardized data format, and the EHT uses a -# particular uvfits format similar to the optical interferometry oifits format. -# As a result, we reuse the excellent `eht-imaging` package to load data into `Comrade`. - -# Once the data is loaded, we then convert the data into the tabular format `Comrade` -# expects. Note that this may change to a Julia package as the Julia radio -# astronomy group grows. - -# To get started, we will load `Comrade` and `Plots` to enable visualizations of the data import Pkg #hide __DIR = @__DIR__ #hide pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide @@ -18,11 +7,18 @@ Pkg.instantiate(; io=pkg_io) #hide Pkg.precompile(; io=pkg_io) #hide close(pkg_io) #hide -ENV["GKSwstype"] = "nul" #hide +# # Loading Data into Comrade +# The VLBI field does not have a standardized data format, and the EHT uses a +# particular uvfits format similar to the optical interferometry oifits format. +# As a result, we reuse the excellent `eht-imaging` package to load data into `Comrade`. -using Comrade +# Once the data is loaded, we then convert the data into the tabular format `Comrade` +# expects. 
Note that this may change to a Julia package as the Julia radio +# astronomy group grows. +# To get started, we will load `Comrade` and `Plots` to enable visualizations of the data +using Comrade using Plots # We also load Pyehtim since it loads eht-imaging into Julia using PythonCall and exports diff --git a/examples/intermediate/ClosureImaging/Project.toml b/examples/intermediate/ClosureImaging/Project.toml index 918f8c78..7d161925 100644 --- a/examples/intermediate/ClosureImaging/Project.toml +++ b/examples/intermediate/ClosureImaging/Project.toml @@ -24,6 +24,6 @@ Distributions = "0.25" Optimization = "4" Pkg = "1" Plots = "1" -Pyehtim = "0.1" +Pyehtim = "0.2" StableRNGs = "1" VLBIImagePriors = "0.9" diff --git a/examples/intermediate/ClosureImaging/main.jl b/examples/intermediate/ClosureImaging/main.jl index 44c0d5b0..07070e94 100644 --- a/examples/intermediate/ClosureImaging/main.jl +++ b/examples/intermediate/ClosureImaging/main.jl @@ -1,3 +1,13 @@ +import Pkg #hide +__DIR = @__DIR__ #hide +pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide +Pkg.activate(__DIR; io=pkg_io) #hide +Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide +Pkg.instantiate(; io=pkg_io) #hide +Pkg.precompile(; io=pkg_io) #hide +close(pkg_io) #hide +ENV["GKSwstype"] = "nul"; #hide + # # Imaging a Black Hole using only Closure Quantities # In this tutorial, we will create a preliminary reconstruction of the 2017 M87 data on April 6 @@ -21,17 +31,6 @@ # # In this tutorial, we will do closure-only modeling of M87 to produce a posterior of images of M87. 
-import Pkg #hide -__DIR = @__DIR__ #hide -pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide -Pkg.activate(__DIR; io=pkg_io) #hide -Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide -Pkg.instantiate(; io=pkg_io) #hide -Pkg.precompile(; io=pkg_io) #hide -close(pkg_io) #hide -ENV["GKSwstype"] = "nul"; #hide - - # To get started, we will load Comrade using Comrade diff --git a/examples/intermediate/PolarizedImaging/main.jl b/examples/intermediate/PolarizedImaging/main.jl index a633bb2f..dfc47a53 100644 --- a/examples/intermediate/PolarizedImaging/main.jl +++ b/examples/intermediate/PolarizedImaging/main.jl @@ -1,3 +1,12 @@ +import Pkg #hide +__DIR = @__DIR__ #hide +pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide +Pkg.activate(__DIR; io=pkg_io) #hide +Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide +Pkg.instantiate(; io=pkg_io) #hide +Pkg.precompile(; io=pkg_io) #hide +close(pkg_io) #hide + # # Polarized Image and Instrumental Modeling # In this tutorial, we will analyze a simulated simple polarized dataset to demonstrate @@ -85,14 +94,6 @@ # In the rest of the tutorial, we are going to solve for all of these instrument model terms in # while re-creating the polarized image from the first [`EHT results on M87`](https://iopscience.iop.org/article/10.3847/2041-8213/abe71d). -import Pkg #hide -__DIR = @__DIR__ #hide -pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide -Pkg.activate(__DIR; io=pkg_io) #hide -Pkg.develop(; path=joinpath(__DIR, "..", "..", ".."), io=pkg_io) #hide -Pkg.instantiate(; io=pkg_io) #hide -Pkg.precompile(; io=pkg_io) #hide -close(pkg_io) #hide # ## Load the Data @@ -366,16 +367,16 @@ gamp_ratio = caltable(exp.(xopt.instrument.lgrat)) # expected since gain ratios are typically stable over the course of an observation and the constant # offset was removed in the EHT calibration process. 
gphaseR = caltable(xopt.instrument.gpR) -p = Plots.plot(gphaseR, layout=(3,3), size=(650,500)); -Plots.plot!(p, gphase_ratio, layout=(3,3), size=(650,500)); +p = Plots.plot(gphaseR, layout=(3,3), size=(650,500), label="R Gain Phase"); +Plots.plot!(p, gphase_ratio, layout=(3,3), size=(650,500), label="Gain Phase Ratio"); p |> DisplayAs.PNG |> DisplayAs.Text #- # Moving to the amplitudes we see largely stable gain amplitudes on the right circular polarization except for LMT which is # known and due to pointing issues during the 2017 observation. Again the gain ratios are stable and close to unity. Typically # we expect that apriori calibration should make the gain ratios close to unity. gampr = caltable(exp.(xopt.instrument.lgR)) -p = Plots.plot(gampr, layout=(3,3), size=(650,500)) -Plots.plot!(p, gamp_ratio, layout=(3,3), size=(650,500)) +p = Plots.plot(gampr, layout=(3,3), size=(650,500), label="R Gain Amp."); +Plots.plot!(p, gamp_ratio, layout=(3,3), size=(650,500), label="Gain Amp. Ratio") p |> DisplayAs.PNG |> DisplayAs.Text #- diff --git a/examples/intermediate/StokesIImaging/Project.toml b/examples/intermediate/StokesIImaging/Project.toml index a36a695a..548458a4 100644 --- a/examples/intermediate/StokesIImaging/Project.toml +++ b/examples/intermediate/StokesIImaging/Project.toml @@ -21,6 +21,6 @@ Distributions = "0.25" Optimization = "4" Pkg = "1" Plots = "1" -Pyehtim = "0.1" +Pyehtim = "0.2" StableRNGs = "1" VLBIImagePriors = "0.9" diff --git a/examples/intermediate/StokesIImaging/main.jl b/examples/intermediate/StokesIImaging/main.jl index 4933b901..7f9d1b5a 100644 --- a/examples/intermediate/StokesIImaging/main.jl +++ b/examples/intermediate/StokesIImaging/main.jl @@ -1,12 +1,3 @@ -# # Stokes I Simultaneous Image and Instrument Modeling - -# In this tutorial, we will create a preliminary reconstruction of the 2017 M87 data on April 6 -# by simultaneously creating an image and model for the instrument. 
By instrument model, we -# mean something akin to self-calibration in traditional VLBI imaging terminology. However, -# unlike traditional self-cal, we will solve for the gains each time we update the image -# self-consistently. This allows us to model the correlations between gains and the image. - -# To get started we load Comrade. import Pkg #hide __DIR = @__DIR__ #hide pkg_io = open(joinpath(__DIR, "pkg.log"), "w") #hide @@ -17,6 +8,15 @@ Pkg.precompile(; io=pkg_io) #hide close(pkg_io) #hide +# # Stokes I Simultaneous Image and Instrument Modeling + +# In this tutorial, we will create a preliminary reconstruction of the 2017 M87 data on April 6 +# by simultaneously creating an image and model for the instrument. By instrument model, we +# mean something akin to self-calibration in traditional VLBI imaging terminology. However, +# unlike traditional self-cal, we will solve for the gains each time we update the image +# self-consistently. This allows us to model the correlations between gains and the image. + +# To get started we load Comrade. using Comrade @@ -131,15 +131,13 @@ skym = SkyModel(sky, prior, grid; metadata=skymeta) # - Gain phases which are more difficult to constrain and can shift rapidly. 
G = SingleStokesGain() do x - lg = x.lgμ + x.lgσ*x.lgz + lg = x.lg gp = x.gp return exp(lg + 1im*gp) end intpr = ( - lgμ = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2)); LM = IIDSitePrior(TrackSeg(), Normal(0.0, 1.0))), - lgσ = ArrayPrior(IIDSitePrior(TrackSeg(), Exponential(0.1))), - lgz = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 1.0))), + lg = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.2)); LM = IIDSitePrior(ScanSeg(), Normal(0.0, 1.0))), gp= ArrayPrior(IIDSitePrior(ScanSeg(), DiagonalVonMises(0.0, inv(π^2))); refant=SEFDReference(0.0), phase=true) ) intmodel = InstrumentModel(G, intpr) diff --git a/ext/ComradeEnzymeExt.jl b/ext/ComradeEnzymeExt.jl index 0eca7e49..9ae85532 100644 --- a/ext/ComradeEnzymeExt.jl +++ b/ext/ComradeEnzymeExt.jl @@ -12,7 +12,7 @@ function LogDensityProblems.logdensity_and_gradient(d::Comrade.TransformedVLBIPo mode = Enzyme.EnzymeCore.WithPrimal(Comrade.admode(d)) dx = zero(x) y = fetch(schedule( - Task(32*1024*2014) do + Task(32*1024*1024) do (_, y) = autodiff(mode, Comrade.logdensityof, Active, Const(d), Duplicated(x, dx)) return y end diff --git a/ext/ComradeOptimizationExt.jl b/ext/ComradeOptimizationExt.jl index aed367c1..c0f4091f 100644 --- a/ext/ComradeOptimizationExt.jl +++ b/ext/ComradeOptimizationExt.jl @@ -14,7 +14,7 @@ function Optimization.OptimizationFunction(post::Comrade.TransformedVLBIPosterio return SciMLBase.OptimizationFunction(ℓ, args...; kwargs...) 
else function grad(G, x, p) - (_, dx) = LogDensityProblems.logdensity_and_gradient(post, x) + (_, dx) = Comrade.LogDensityProblems.logdensity_and_gradient(post, x) dx .*= -1 G .= dx return G diff --git a/ext/ComradePyehtimExt.jl b/ext/ComradePyehtimExt.jl index e357ddf1..90063c9b 100644 --- a/ext/ComradePyehtimExt.jl +++ b/ext/ComradePyehtimExt.jl @@ -14,8 +14,14 @@ function build_arrayconfig(obs) source = get_source(obsc) bw = get_bw(obsc) angles = get_fr_angles(obsc) - tarr = Pyehtim.get_arraytable(obsc) - scans = get_scantable(obsc) + tarr = StructArray(Pyehtim.get_arraytable(obsc)) + + # This is because sometimes eht-imaging sets the scans to nothing + # and sometimes it fills it with junk + if length(obsc.scans) <= 1 + obsc.add_scans() + end + scans = StructArray(get_scantable(obsc)) bw = get_bw(obsc) elevation = StructArray(angles[1]) parallactic = StructArray(angles[2]) @@ -140,7 +146,6 @@ Returns an EHTObservationTable with visibility amplitude data """ function Comrade.extract_amp(obsc; pol=:I, debias=false, kwargs...) obs = obsc.copy() - obs.add_scans() obs.reorder_tarr_snr() obs.add_amp(;debias, kwargs...) config = build_arrayconfig(obs) @@ -163,7 +168,6 @@ Returns an EHTObservationTable with complex visibility data """ function Comrade.extract_vis(obsc; pol=:I, kwargs...) obs = obsc.copy() - obs.add_scans() obs.reorder_tarr_snr() config = build_arrayconfig(obs) vis, viserr = getvisfield(obs) @@ -181,7 +185,6 @@ Returns an EHTObservationTable with coherency matrices """ function Comrade.extract_coherency(obsc; kwargs...) obs = obsc.copy() - obs.add_scans() obs.reorder_tarr_snr() config = build_arrayconfig(obs) vis, viserr = getcoherency(obs) diff --git a/src/instrument/caltable.jl b/src/instrument/caltable.jl index c9810dbc..44b9b140 100644 --- a/src/instrument/caltable.jl +++ b/src/instrument/caltable.jl @@ -4,13 +4,14 @@ export caltable $(TYPEDEF) A Tabes of calibration quantities. The columns of the table are the telescope sites codes. 
-The rows are the calibration quantities at a specific time stamp. This user should not +The rows are the calibration quantities at a specific time and frequency. This user should not use this struct directly. Instead that should call [`caltable`](@ref). """ -struct CalTable{T,G<:AbstractVecOrMat} +struct CalTable{T,F,G<:AbstractVecOrMat} names::Vector{Symbol} lookup::Dict{Symbol,Int} times::T + freqs::F gmat::G end @@ -25,40 +26,45 @@ Tables.columnaccess(::Type{<:CalTable}) = true Return the sites in the calibration table """ sites(g::CalTable) = getfield(g, :names) -scantimes(g::CalTable) = getfield(g, :times) +times(g::CalTable) = getfield(g, :times) +frequencies(g::CalTable) = getfield(g, :freqs) lookup(g::CalTable) = getfield(g, :lookup) gmat(g::CalTable) = getfield(g, :gmat) -function Tables.schema(g::CalTable{T,G}) where {T,G} - nms = [:time] +function Tables.schema(g::CalTable{T,F,G}) where {T,F,G} + nms = [:Ti, :Fr] append!(nms, sites(g)) - types = Type[eltype(T)] + types = Type[eltype(T), eltype(F)] append!(types, fill(eltype(G), size(gmat(g),2))) return Tables.Schema(nms, types) end -Tables.columns(g::CalTable) = Tables.table([scantimes(g) gmat(g)]; header=Tables.columnnames(g)) +Tables.columns(g::CalTable) = Tables.table(hcat(times(g), frequencies(g), gmat(g)); header=Tables.columnnames(g)) function Tables.getcolumn(g::CalTable, ::Type{T}, col::Int, nm::Symbol) where {T} - (col == 1 || nm == :time) && return scantimes(g) - gmat(g)[:, col-1] + (col == 1 || nm == :Ti) && return times(g) + (col == 2 || nm == :Fr) && return frequencies(g) + gmat(g)[:, col-2] end function Tables.getcolumn(g::CalTable, nm::Symbol) - nm == :time && return scantimes(g) + nm == :Ti && return times(g) + nm == :Fr && return frequencies(g) return gmat(g)[:, lookup(g)[nm]] end function Tables.getcolumn(g::CalTable, i::Int) - i==1 && return scantimes(g) - return gmat(g)[:, i-1] + i==1 && return times(g) + i==2 && return frequencies(g) + return gmat(g)[:, i-2] end function 
viewcolumn(gt::CalTable, nm::Symbol) - nm == :time && return scantimes(gt) + nm == :Ti && return times(gt) + nm == :Fr && return frequencies(gt) return @view gmat(gt)[:, lookup(gt)[nm]] end -Tables.columnnames(g::CalTable) = [:time, sites(g)...] +Tables.columnnames(g::CalTable) = [:Ti, :Fr, sites(g)...] Tables.rowaccess(::Type{<:CalTable}) = true Tables.rows(g::CalTable) = g @@ -68,6 +74,8 @@ struct CalTableRow{T,G} <: Tables.AbstractRow source::CalTable{T,G} end +Tables.columnnames(g::CalTableRow) = [:Ti, :Fr, sites(getfield(g, :source))...] + function Base.propertynames(g::CalTable) return Tables.columnnames(g) end @@ -77,7 +85,8 @@ function Base.getproperty(g::CalTable, nm::Symbol) end function Base.getindex(gt::CalTable, i::Int, nm::Symbol) - nm == :time && return gt.times[i] + nm == :Ti && return times(gt)[i] + nm == :Fr && return frequencies(gt)[i] return Tables.getcolumn(gt, nm)[i] end @@ -88,11 +97,21 @@ end function Base.getindex(gt::CalTable, I::AbstractUnitRange, nm::Symbol) getproperty(gt, nm)[I] end +function Base.getindex(gt::CalTable, I::AbstractVector{Int}, nm::Symbol) + getproperty(gt, nm)[I] +end + function Base.view(gt::CalTable, I::AbstractUnitRange, nm::Symbol) @view getproperty(gt, nm)[I] end +function Base.view(gt::CalTable, I::AbstractVector{Int}, nm::Symbol) + @view getproperty(gt, nm)[I] +end + + + function Base.getindex(gt::CalTable, ::Colon, nm::Symbol) Tables.getcolumn(gt, nm) end @@ -117,18 +136,21 @@ end function Tables.getcolumn(g::CalTableRow, ::Type, col::Int, nm::Symbol) - (col == 1 || nm == :time) && return scantimes(getfield(g, :source))[getfield(g, :row)] - gmat(getfield(g, :source))[getfield(g, :row), col-1] + (col == 1 || nm == :Ti) && return times(getfield(g, :source))[getfield(g, :row)] + (col == 2 || nm == :Fr) && return frequencies(getfield(g, :source))[getfield(g, :row)] + gmat(getfield(g, :source))[getfield(g, :row), col-2] end function Tables.getcolumn(g::CalTableRow, i::Int) - (i==1) && return scantimes(getfield(g, 
:source))[getfield(g, :row)] - gmat(getfield(g, :source))[getfield(g, :row), i-1] + (i==1) && return times(getfield(g, :source))[getfield(g, :row)] + (i==2) && return frequencies(getfield(g, :source))[getfield(g, :row)] + gmat(getfield(g, :source))[getfield(g, :row), i-2] end function Tables.getcolumn(g::CalTableRow, nm::Symbol) src = getfield(g, :source) - nm == :time && return scantimes(src)[getfield(g, :row)] + nm == :Ti && return times(src)[getfield(g, :row)] + nm == :Fr && return frequencies(src)[getfield(g, :row)] return gmat(src)[getfield(g, :row), lookup(src)[nm]] end @@ -152,45 +174,57 @@ end #else # ylims --> inv.(lims)[end:-1:begin] #end - t = getproperty.(gt[:time], :t0) + t = getproperty.(gt[:Ti], :t0) xlims --> (t[begin], t[end] + 0.01*abs(t[end])) for (i,s) in enumerate(sites) - @series begin - seriestype := :scatter - subplot := i - label --> :none - - if i == length(sites) - xguide --> "Time (UTC)" - end - - T = nonmissingtype(eltype(gt[s])) - ind = Base.:!.(ismissing.(gt[s])) - #x := gt[:time][ind] - if !datagains - yy = gt[s][ind] - else - yy = inv.(gt[s])[ind] + T = nonmissingtype(eltype(gt[s])) + for (j,f) in enumerate(unique(gt[:Fr])) + @series begin + seriestype := :scatter + subplot := i + title := String(s) + if i == length(sites) + xguide --> "Time (UTC)" + end + label = "$(round(f.central/1e9, digits=1)) GHz" + if i == length(sites) + label --> label + else + label := nothing + end + ind = Base.:!.(ismissing.(gt[s])) + find = findall(==(f), gt[:Fr][ind]) + #x := gt[:Ti][ind] + x = t[ind][find] + + if !datagains + y = T.(gt[s][ind][find]) + else + y = T.(inv.((gt[s][ind][find]))) + end + x, y end - - title --> string(s) - t[ind], T.(yy) end end end -Tables.columnnames(g::CalTableRow) = [:time, sites(getfield(g, :source))...] 
using PrettyTables function Base.show(io::IO, ct::CalTable, ) pretty_table(io, Tables.columns(ct); header=Tables.columnnames(ct), - vlines=[1], - # formatters = (v,i,j)->round(v, digits=3) + vlines=[1,2], + formatters = _ctab_formatter ) end +function _ctab_formatter(v, i, j) + j == 1 && return "$(round(v.t0, digits=2)) hr" + j == 2 && return "$(round(v.central/1e9, digits=2)) GHz" + return round(v, digits=3) +end + """ caltable(s::SiteArray) @@ -199,18 +233,22 @@ Creates a calibration table from a site array """ function caltable(sarr::SiteArray) sites = sort(unique(Comrade.sites(sarr))) - time = unique(times(sarr)) + tf = collect(Iterators.product(unique(times(sarr))|>sort, unique(frequencies(sarr))|>sort)) + time = vec(first.(tf)) + freq = vec(last.(tf)) gmat = Matrix{Union{eltype(sarr), Missing}}(missing, length(time), length(sites)) gmat .= missing lookup = Dict(sites[i]=>i for i in eachindex(sites)) for (j, s) in enumerate(sites) cterms = site(sarr, s) - for (i, t) in enumerate(time) - ind = findfirst(==(t), times(cterms)) + for (i, (t,f)) in enumerate(tf) + ti = times(cterms) + fi = frequencies(cterms) + ind = findfirst(i->((ti[i]==t)&&fi[i]==f), eachindex(ti, fi)) if !isnothing(ind) gmat[i, j] = cterms[ind] end end end - return CalTable(sites, lookup, time, gmat) + return CalTable(sites, lookup, time, freq, gmat) end diff --git a/src/instrument/instrument.jl b/src/instrument/instrument.jl index ff111549..10e889e6 100644 --- a/src/instrument/instrument.jl +++ b/src/instrument/instrument.jl @@ -4,26 +4,41 @@ import Distributions using Statistics using PrettyTables -struct IntegrationTime{T} - mjd::Int +struct IntegrationTime{I<:Integer, T} + mjd::I t0::T dt::T end -Base.in(t::Number, ts::IntegrationTime) = (ts.t0 - ts.dt/2) ≤ t < (ts.t0 + ts.dt) +mjd(ts::IntegrationTime) = ts.mjd +interval(ts::IntegrationTime) = (ts.t0 - ts.dt/2)..(ts.t0 + ts.dt/2) +Base.in(t::Number, ts::IntegrationTime) = (ts.t0 - ts.dt/2) ≤ t < (ts.t0 + ts.dt/2)
Base.isless(t::IntegrationTime, ts::IntegrationTime) = t.t0 < ts.t0 Base.isless(s::Number, t::IntegrationTime) = s < (t.t0 - t.dt/2) Base.isless(t::IntegrationTime, s::Number) = (t.t0 + t.dt/2) < s -mjd(ts::IntegrationTime) = ts.mjd +Base.Broadcast.broadcastable(ts::IntegrationTime) = Ref(ts) +_center(ts::IntegrationTime) = ts.t0 +_region(ts::IntegrationTime) = ts.dt struct FrequencyChannel{T, I<:Integer} central::T bandwidth::T channel::I end -Base.in(f::Number, fs::FrequencyChannel) = (fs.central-fs.bandwidth/2) ≤ f < (fs.central+fs.bandwidth/2) channel(fs::FrequencyChannel) = fs.channel +interval(fs::FrequencyChannel) = (fs.central - fs.bandwidth/2)..(fs.central + fs.bandwidth/2) +Base.in(f::Number, fs::FrequencyChannel) = (fs.central-fs.bandwidth/2) ≤ f < (fs.central+fs.bandwidth/2) +Base.isless(t::FrequencyChannel, ts::FrequencyChannel) = _center(t) < _center(ts) +Base.isless(s::Number, t::FrequencyChannel) = s < (_center(t) - _region(t)/2) +Base.isless(t::FrequencyChannel, s::Number) = (_center(t) + _region(t)/2) < s +Base.Broadcast.broadcastable(fs::FrequencyChannel) = Ref(fs) + + + +_center(fs::FrequencyChannel) = fs.central +_region(fs::FrequencyChannel) = fs.bandwidth + include("site_array.jl") diff --git a/src/instrument/instrument_transforms.jl b/src/instrument/instrument_transforms.jl index 1e02167a..3932f232 100644 --- a/src/instrument/instrument_transforms.jl +++ b/src/instrument/instrument_transforms.jl @@ -1,13 +1,17 @@ abstract type AbstractInstrumentTransform <: TV.VectorTransform end site_map(t::AbstractInstrumentTransform) = t.site_map +EnzymeRules.inactive(::typeof(site_map), args...) 
= nothing inner_transform(t::AbstractInstrumentTransform) = t.inner_transform function TV.transform_with(flag::TV.LogJacFlag, m::AbstractInstrumentTransform, x, index) y, ℓ, index = _instrument_transform_with(flag, m, x, index) - sm = m.site_map - return SiteArray(y, sm.times, sm.frequencies, sm.sites), ℓ, index + sm = site_map(m) + return SiteArray(y, sm), ℓ, index end +EnzymeRules.inactive_type(::Type{AbstractArray{<:IntegrationTime}}) = true +EnzymeRules.inactive_type(::Type{AbstractArray{<:FrequencyChannel}}) = true + function TV.inverse_at!(x::AbstractArray, index, t::AbstractInstrumentTransform, y::SiteArray) itrf = inner_transform(t) return TV.inverse_at!(x, index, itrf, parent(y)) @@ -50,15 +54,25 @@ end return yout, ℓ, index end +EnzymeRules.inactive_type(::Type{<:SiteLookup}) = true + @inline function site_sum(y, site_map::SiteLookup) - yout = similar(y) - @inbounds for site in site_map.lookup + # yout = similar(y) + vals = values(lookup(site_map)) + @inbounds for site in vals + # i0 = site[begin] + # yout[i0] = y[i0] + # acc = zero(eltype(y)) + # for idx in site + # acc += y[idx] + # yout[idx] = acc + # end ys = @view y[site] # y should never alias so we should be fine here. 
- youts = @view yout[site] - cumsum!(youts, ys) + # youts = @view yout[site] + cumsum!(ys, ys) end - return yout + return y end # function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_instrument_transform_with), flag, m::MarkovInstrumentTransform, x, index) @@ -84,11 +98,11 @@ function site_diff!(y, site_map::SiteLookup) return nothing end -function simplediff(x::AbstractVector) - y = zero(x) - simplediff!(y, x) - return y -end +# function simplediff(x::AbstractVector) +# y = zero(x) +# simplediff!(y, x) +# return y +# end function simplediff!(y::AbstractVector, x::AbstractVector) y[begin] = x[begin] diff --git a/src/instrument/jonesmatrices.jl b/src/instrument/jonesmatrices.jl index 2042b4c6..274e7a11 100644 --- a/src/instrument/jonesmatrices.jl +++ b/src/instrument/jonesmatrices.jl @@ -225,7 +225,7 @@ of `xs, and whose elements are the jones matrices for the specific parameters. function forward_jones(v::AbstractJonesMatrix, xs::NamedTuple{N}) where {N} sm = broadest_sitemap(xs) bl = map(x->(x,x), sm.sites) - bmaps = map(x->_construct_baselinemap(getproperty.(sm.times, :t0), sm.frequencies, bl, x).indices_1, xs) + bmaps = map(x->_construct_baselinemap(getproperty.(sm.times, :t0), getproperty.(sm.frequencies, :central), bl, x).indices_1, xs) vs = map(eachindex(sm.times)) do index indices = map(x->getindex(x, index), bmaps) params = NamedTuple{N}(map(getindex, values(xs), values(indices))) diff --git a/src/instrument/model.jl b/src/instrument/model.jl index 423492e8..abf331f1 100644 --- a/src/instrument/model.jl +++ b/src/instrument/model.jl @@ -195,8 +195,8 @@ function _construct_baselinemap(T, F, bl, x::SiteArray) t = T[i] f = F[i] s1, s2 = bl[i] - i1 = findall(x->(t∈x[1])&&(x[2]==s1), tsf) - i2 = findall(x->(t∈x[1])&&(x[2]==s2), tsf) + i1 = findall(x->(t∈x[1])&&(x[2]==s1)&&(f∈x[3]), tsf) + i2 = findall(x->(t∈x[1])&&(x[2]==s2)&&(f∈x[3]), tsf) length(i1) > 1 && throw(AssertionError("Multiple indices found for $t, $((s1)) in SiteArray")) 
length(i2) > 1 && throw(AssertionError("Multiple indices found for $t, $((s2)) in SiteArray")) (isnothing(i1) | isempty(i1)) && throw(AssertionError("$t, $f, $((s1)) not found in SiteArray")) diff --git a/src/instrument/priors/array_priors.jl b/src/instrument/priors/array_priors.jl index 910592ca..fe8c7279 100644 --- a/src/instrument/priors/array_priors.jl +++ b/src/instrument/priors/array_priors.jl @@ -71,40 +71,52 @@ function build_sitemap(d::ArrayPrior, array) # Ok to so this we are going to construct the schema first over sites. # At the end we may re-order depending on the schema ordering we want # to use. - tlists = map(keys(sites_prior)) do s + lists = map(keys(sites_prior)) do s seg = segmentation(sites_prior[s]) # get all the indices where this site is present inds_s = findall(x->((x[1]==s)||x[2]==s), array[:sites]) # Get all the unique times - ts = unique(T[inds_s]) + ts = T[inds_s] + fs = F[inds_s] + tfs = zip(ts, fs) # Now makes the acceptable time stamps given the segmentation tstamp = timestamps(seg, array) + fchan = freqchannels(SpectralWindow(), array) # Now we find commonalities - times = eltype(tstamp)[] - for t in tstamp - if any(x->x∈t, ts) && (!(t.t0 ∈ times)) - push!(times, t) + tf = Tuple{eltype(tstamp), eltype(fchan)}[] + for f in fchan, t in tstamp + if any(x->(x[1]∈t && x[2]∈f), tfs) && ((!((t, f) ∈ tf))) + push!(tf, (t, f)) end end - return times + return first.(tf), last.(tf) end + tlists = first.(lists) + flists = last.(lists) # construct the schema slist = mapreduce((t,s)->fill(s, length(t)), vcat, tlists, keys(sites_prior)) tlist = reduce(vcat, tlists) + flist = reduce(vcat, flists) + tlistre = similar(tlist) slistre = similar(slist) - # Now rearrange so we have time site ordering (sites are the fastest changing) - tuni = sort(unique(getproperty.(tlist, :t0))) + flistre = similar(flist) + # Now rearrange so we have frquency, time, site ordering (sites are the fastest changing) + tuni = sort(unique((tlist))) + funi = 
sort(unique((flist))) ind0 = 1 - for t in tuni - ind = findall(x->x.t0==t, tlist) - tlistre[ind0:ind0+length(ind)-1] .= tlist[ind] - slistre[ind0:ind0+length(ind)-1] .= slist[ind] - ind0 += length(ind) + for f in funi, t in tuni + ind = (f .== flist) .& (t .== tlist) + vtlist = @view tlist[ind] + vslist = @view slist[ind] + vflist = @view flist[ind] + tlistre[ind0:ind0+length(vtlist)-1] .= vtlist + slistre[ind0:ind0+length(vtlist)-1] .= vslist + flistre[ind0:ind0+length(vtlist)-1] .= vflist + ind0 += length(vtlist) end - freqs = Fill(F[1], length(tlistre)) - return SiteLookup(tlistre, freqs, slistre) + return SiteLookup(tlistre, flistre, slistre) end function ObservedArrayPrior(d::ArrayPrior, array::EHTArrayConfiguration) diff --git a/src/instrument/priors/refant.jl b/src/instrument/priors/refant.jl index 8585f579..235e6de2 100644 --- a/src/instrument/priors/refant.jl +++ b/src/instrument/priors/refant.jl @@ -43,14 +43,19 @@ end function reference_indices(array::AbstractArrayConfiguration, st::SiteLookup, r::SEFDReference) tarr = array.tarr t = unique(st.times) + f = unique(st.frequencies) sefd = NamedTuple{Tuple(tarr.sites)}(Tuple(tarr.SEFD1 .+ tarr.SEFD2)) - fixedinds = map(eachindex(t)) do i - inds = findall(==(t[i]), st.times) + fixedinds = Int[] + for i in eachindex(t), j in eachindex(f) + inds = findall(x->((st.times[x]==t[i])&&(st.frequencies[x]==f[j])), eachindex(st.times)) + if isempty(inds) + continue + end sites = Tuple(st.sites[inds]) @assert length(sites) <= length(sefd) "Error in reference site generation. 
Too many sites" sp = select(sefd, sites) _, ind = findmin(values(sp)) - return inds[ind] + push!(fixedinds, inds[ind]) end return fixedinds, Fill(r.value, length(fixedinds)) end diff --git a/src/instrument/priors/segmentation.jl b/src/instrument/priors/segmentation.jl index 3b6eb7d8..da6af450 100644 --- a/src/instrument/priors/segmentation.jl +++ b/src/instrument/priors/segmentation.jl @@ -11,13 +11,17 @@ the following functions: """ abstract type Segmentation end +abstract type TimeSegmentation <: Segmentation end + +abstract type FrequencySegmentation <: Segmentation end + # Track is for quantities that remain static across an entire observation """ $(TYPEDEF) Data segmentation such that the quantity is constant over a `track`, i.e., the observation "night". """ -struct TrackSeg <: Segmentation end +struct TrackSeg <: TimeSegmentation end # Scan is for quantities that are constant across a scan """ @@ -25,7 +29,7 @@ struct TrackSeg <: Segmentation end Data segmentation such that the quantity is constant over a `scan`. """ -struct ScanSeg <: Segmentation end +struct ScanSeg <: TimeSegmentation end # Integration is for quantities that change every integration time @@ -35,10 +39,10 @@ struct ScanSeg <: Segmentation end Data segmentation such that the quantity is constant over the time stamps in the data. If the data is scan-averaged before then `IntegSeg` will be identical to `ScanSeg`. """ -struct IntegSeg <: Segmentation end +struct IntegSeg <: TimeSegmentation end """ - timestamps(seg::Segmentation, array::AbstractArrayConfiguration) + timestamps(seg::TimeSegmentation, array::AbstractArrayConfiguration) Return the time stamps or really a vector of integration time regions for a given segmentation scheme `seg` and array configuration `array`. 
@@ -48,10 +52,10 @@ function timestamps end function timestamps(::ScanSeg, array) st = array.scans mjd = array.mjd - # Shift the central time to the middle of the scan dt = (st.stop .- st.start) t0 = st.start .+ dt./2 + return IntegrationTime.(mjd, t0, dt) end @@ -60,7 +64,6 @@ end function timestamps(::IntegSeg, array) ts = unique(array[:Ti]) - st = array.scans mjd = array.mjd # TODO build in the dt into the data format @@ -74,10 +77,10 @@ function timestamps(::IntegSeg, array) end function timestamps(::TrackSeg, array) - st = array.scans mjd = array.mjd tstart, tend = extrema(array[:Ti]) + tend = tend + 1 dt = tend - tstart if iszero(dt) dt = 1/3600 @@ -85,3 +88,15 @@ function timestamps(::TrackSeg, array) # TODO build in the dt into the data format return (IntegrationTime(mjd, (tend-tstart)/2 + tstart, dt),) end + + +struct SpectralWindow <: FrequencySegmentation end + + +function freqchannels(::SpectralWindow, array) + Fr = unique(array[:Fr]) + if length(Fr) > 1 + Fr[end] = nextfloat(Fr[end]) + end + return FrequencyChannel.(Fr, array.bandwidth, 1:length(Fr)) +end diff --git a/src/instrument/site_array.jl b/src/instrument/site_array.jl index 0e53ef3e..8059fba5 100644 --- a/src/instrument/site_array.jl +++ b/src/instrument/site_array.jl @@ -16,7 +16,7 @@ which will grab the first 10 time and frequency points for the ALMA site. Otherwise indexing into the array will return an element whose time, frequency, and site are the element of the `times`, `frequencies`, and `sites` arrays respectively. 
""" -struct SiteArray{T, N, A<:AbstractArray{T,N}, Ti<:AbstractArray{<:IntegrationTime, N}, Fr<:AbstractArray{<:Number, N}, Sy<:AbstractArray{<:Any, N}} <: AbstractArray{T, N} +struct SiteArray{T, N, A<:AbstractArray{T,N}, Ti<:AbstractArray{<:IntegrationTime, N}, Fr<:AbstractArray{<:FrequencyChannel, N}, Sy<:AbstractArray{<:Any, N}} <: AbstractArray{T, N} data::A times::Ti frequencies::Fr @@ -111,55 +111,55 @@ function site(arr::SiteArray, p) end -function time(arr::SiteArray, a::Union{AbstractInterval, IntegrationTime}) +function time(arr::SiteArray, a::Union{AbstractInterval, Real}) inds = findall(in(a), times(arr)) nd = view(parent(arr), inds) return SiteArray(nd, view(times(arr), inds), view(frequencies(arr), inds), view(sites(arr), inds)) end -function frequency(arr::SiteArray, a::Union{AbstractInterval, FrequencyChannel}) +function frequency(arr::SiteArray, a::Union{AbstractInterval, Real}) inds = findall(in(a), times(arr)) nd = view(parent(arr), inds) return SiteArray(nd, view(times(arr), inds), view(frequencies(arr), inds), view(sites(arr), inds)) end +_equalorin(x::T, y::T) where {T} = x == y +_equalorin(x::Real, y) = x ∈ y +_equalorin(x, y::Real) = y ∈ x +_equalorin(x, y) = y ∈ x +_equalorin(x, ::typeof(Base.Colon())) = true +_equalorin(::typeof(Base.Colon()), x) = true +const Indexable = Union{Integer, AbstractArray{<:Integer}, BitArray} -@inline function _maybe_all(arr, X) - if X isa Base.Colon - ext = extrema(arr) - return ClosedInterval(ext[1], ext[2]) - else - return X - end -end - -function Base.getindex(arr::SiteArray; F=Base.Colon(), S=Base.Colon(), T=Base.Colon()) - T2 = _maybe_all(times(arr), T) - F2 = _maybe_all(frequencies(arr), F) - S2 = S isa Base.Colon ? 
unique(S) : S - return select_region(arr, S2, T2, F2) -end - -function select_region(arr::SiteArray, S::Symbol, T::Union{IntegrationTime, AbstractInterval}, F::Union{FrequencyChannel, AbstractInterval}) - select_region(arr, (S,), T, F) -end - - -function select_region(arr::SiteArray, site, T::Union{IntegrationTime, AbstractInterval}, F::Union{FrequencyChannel, AbstractInterval}) - inds = findall(i->((Comrade.sites(arr)[i] ∈ site)&&(Comrade.times(arr)[i] ∈ T)), eachindex(arr)) +function Base.getindex(arr::SiteArray; Fr=Base.Colon(), S=Base.Colon(), Ti=Base.Colon()) + Fr2 = isa(Fr, Indexable) ? unique(arr.frequencies)[Fr] : Fr + S2 = isa(S, Indexable) ? unique(arr.sites)[S] : S + Ti2 = isa(Ti, Indexable) ? unique(arr.times)[Ti] : Ti + inds = findall(i->(_equalorin(S2, Comrade.sites(arr)[i])&&_equalorin(Ti2, Comrade.times(arr)[i])&&_equalorin(Fr2, Comrade.frequencies(arr)[i])), eachindex(arr)) nd = view(parent(arr), inds) return SiteArray(nd, view(times(arr), inds), view(frequencies(arr), inds), view(sites(arr), inds)) end -struct SiteLookup{L<:NamedTuple, N, Ti<:AbstractArray{<:IntegrationTime, N}, Fr<:AbstractArray{<:Number, N}, Sy<:AbstractArray{<:Any, N}} + +struct SiteLookup{L<:NamedTuple, N, Ti<:AbstractArray{<:IntegrationTime, N}, Fr<:AbstractArray{<:FrequencyChannel, N}, Sy<:AbstractArray{<:Any, N}} lookup::L times::Ti frequencies::Fr sites::Sy end +times(s::SiteLookup) = s.times +sites(s::SiteLookup) = s.sites +frequencies(s::SiteLookup) = s.frequencies +lookup(s::SiteLookup) = s.lookup + +EnzymeRules.inactive(::typeof(times), ::SiteLookup) = nothing +EnzymeRules.inactive(::typeof(frequencies), ::SiteLookup) = nothing +EnzymeRules.inactive(::typeof(sites), ::SiteLookup) = nothing +EnzymeRules.inactive(::typeof(lookup), ::SiteLookup) = nothing + function sitemap!(f, out::AbstractArray, gains::AbstractArray, slook::SiteLookup) - map(slook.lookup) do site + map(lookup(slook)) do site ysite = @view gains[site] outsite = @view out[site] outsite .= f.(ysite) @@ 
-173,9 +173,10 @@ function sitemap(f, gains::AbstractArray{T}, slook::SiteLookup) where {T} end function sitemap!(::typeof(cumsum), out::AbstractArray, gains::AbstractArray, slook::SiteLookup) - map(slook.lookup) do site + map(lookup(slook)) do site ys = @view gains[site] cumsum!(ys, ys) + nothing end return out end @@ -190,8 +191,28 @@ function SiteLookup(s::SiteArray) end function SiteLookup(times::AbstractVector, frequencies::AbstractArray, sites::AbstractArray) - slist = Tuple(sort(unique(sites))) - return SiteLookup(NamedTuple{slist}(map(p->findall(==(p), sites), slist)), times, frequencies, sites) + slist = sort(unique(sites)) + flist = sort(unique(frequencies)) + if length(flist) == 1 + return SiteLookup(NamedTuple{Tuple(slist)}(map(p->findall(==(p), sites), slist)), times, frequencies, sites) + else + # Find sites first + sflist = Symbol[] + inds = Vector{Int}[] + for s in slist + sinds = findall(==(s), sites) + for (i,f) in enumerate(flist) + finds = findall(==(f), @view(frequencies[sinds])) + if !isempty(finds) + ss = Symbol(s, i) + push!(sflist, ss) + push!(inds, finds) + end + end + end + lookup = NamedTuple{Tuple(sflist)}(Tuple(inds)) + return SiteLookup(lookup, times, frequencies, sites) + end end """ @@ -201,10 +222,13 @@ Construct a site array with the entries `arr` and the site ordering implied by `sitelookup`. 
""" function SiteArray(a::AbstractArray, map::SiteLookup) - return SiteArray(a, map.times, map.frequencies, map.sites) + return SiteArray(a, times(map), frequencies(map), sites(map)) end -function SiteArray(data::SiteArray{T, N}, times::AbstractArray{<:IntegrationTime, N}, frequencies::AbstractArray{<:Number, N}, sites::AbstractArray{<:Number, N}) where {T, N} +function SiteArray(data::SiteArray{T, N}, + times::AbstractArray{<:IntegrationTime, N}, + frequencies::AbstractArray{<:FrequencyChannel, N}, + sites::AbstractArray{<:Any, N}) where {T, N} return data end diff --git a/src/skymodels/models.jl b/src/skymodels/models.jl index 8859d495..abdf880f 100755 --- a/src/skymodels/models.jl +++ b/src/skymodels/models.jl @@ -72,7 +72,7 @@ function VLBISkyModels.FourierDualDomain(grid::AbstractRectiGrid, array::Abstrac return FourierDualDomain(grid, domain(array; executor), alg) end -struct ObservedSkyModel{F, G<:VLBISkyModels.AbstractFourierDualDomain, M} <: AbstractSkyModel +struct ObservedSkyModel{F, G<:VLBISkyModels.AbstractDomain, M} <: AbstractSkyModel f::F grid::G metadata::M @@ -82,6 +82,14 @@ function domain(m::AbstractSkyModel; kwargs...) return getfield(m, :grid) end +# If we are using a analytic model then we don't need to plan the FT and we +# can save some memory by not storing the plans. +struct AnalyticAlg <: FourierTransform end +struct AnalyticPlan <: VLBISkyModels.AbstractPlan end +VLBISkyModels.getplan(::AnalyticPlan) = nothing +VLBISkyModels.getphases(::AnalyticPlan) = nothing +VLBISkyModels.create_plans(::AnalyticAlg, imgdomain, visdomain) = (AnalyticPlan(), AnalyticPlan()) + """ ObservedSkyModel(sky::AbstractSkyModel, array::AbstractArrayConfiguration) @@ -92,9 +100,18 @@ pass that to a [`VLBIPosterior`](@ref) object instead. 
""" function ObservedSkyModel(m::SkyModel, arr::AbstractArrayConfiguration) - return ObservedSkyModel(m.f, FourierDualDomain(m.grid, arr, m.algorithm), m.metadata) + x = rand(NamedDist(m.prior)) + ms = m.f(x, m.metadata) + # if analytic don't bother planning the FT + if ComradeBase.visanalytic(typeof(ms)) === ComradeBase.IsAnalytic() + g = FourierDualDomain(m.grid, arr, AnalyticAlg()) + else + g = FourierDualDomain(m.grid, arr, m.algorithm) + end + return ObservedSkyModel(m.f, g, m.metadata) end + function set_array(m::AbstractSkyModel, array::AbstractArrayConfiguration) return ObservedSkyModel(m, array), m.prior end diff --git a/test/Core/core.jl b/test/Core/core.jl index 7523c531..8b6bd6b0 100644 --- a/test/Core/core.jl +++ b/test/Core/core.jl @@ -13,4 +13,3 @@ include(joinpath(@__DIR__, "observation.jl")) include(joinpath(@__DIR__, "partially_fixed.jl")) include(joinpath(@__DIR__, "models.jl")) include(joinpath(@__DIR__, "bayes.jl")) -# include(joinpath(@__DIR__, "rules.jl")) diff --git a/test/Core/models.jl b/test/Core/models.jl index e7f7a80b..93e6628a 100644 --- a/test/Core/models.jl +++ b/test/Core/models.jl @@ -20,6 +20,90 @@ _ntequal(x::T, y::T) where {T<:NamedTuple} = ntequal(values(x), values(y)) _ntequal(x::T, y::T) where {T<:Tuple} = map(_ntequal, x, y) _ntequal(x, y) = x ≈ y +function build_mfvis(vistuple...) 
+ configs = arrayconfig.(vistuple) + vis = vistuple[1] + newdatatables = Comrade.StructArray(reduce(vcat, Comrade.datatable.(configs))) + newscans = reduce(vcat, getproperty.(configs, :scans)) + newconfig = Comrade.EHTArrayConfiguration(vis.config.bandwidth, + vis.config.tarr, + newscans, + vis.config.mjd, + vis.config.ra, + vis.config.dec, + vis.config.source, + :UTC, + newdatatables) + newmeasurement = reduce(vcat, Comrade.measurement.(vistuple)) + newnoise = reduce(vcat, Comrade.noise.(vistuple)) + + return Comrade.EHTObservationTable{Comrade.datumtype(vis)}(newmeasurement,newnoise,newconfig) +end + + +function isequalmissing(x, y) + xm = x |> ismissing |> collect + ym = y |> ismissing |> collect + return xm == ym +end + + +function test_caltable(c1, sites) + @test Tables.istable(typeof(c1)) + @test Tables.rowaccess(typeof(c1)) + @test Tables.rows(c1) === c1 + @test Tables.columnaccess(c1) + clmns = Tables.columns(c1) + @test clmns[1] == Comrade.times(c1) + @test clmns[2] == Comrade.frequencies(c1) + @test Bool(prod(skipmissing(Tables.matrix(clmns)[:,begin+2:end]) .== skipmissing(Comrade.gmat(c1)))) + @test c1.Ti == Comrade.times(c1) + @test c1.Ti == Tables.getcolumn(c1, 1) + @test c1.Fr == Comrade.frequencies(c1) + @test c1.Fr == Tables.getcolumn(c1, 2) + @test isequalmissing(c1.AA, Tables.getcolumn(c1, 3)) + + @test maximum(abs, skipmissing(c1.AA) .- skipmissing(Tables.getcolumn(c1, :AA))) ≈ 0 + @test maximum(abs, skipmissing(c1.AA) .- skipmissing(Tables.getcolumn(c1, 3))) ≈ 0 + @test Tables.columnnames(c1) == [:Ti, :Fr, sites...] 
+ + c1row = Tables.getrow(c1, 30) + @test eltype(c1) == typeof(c1row) + @test c1row.Ti == c1.Ti[30] + @test c1row.AA == c1.AA[30] + @test c1row.Fr == c1.Fr[30] + @test Tables.getcolumn(c1row, :AA) == c1.AA[30] + @test Tables.getcolumn(c1row, :Ti) == c1.Ti[30] + @test Tables.getcolumn(c1row, :Fr) == c1.Fr[30] + @test Tables.getcolumn(c1row, 3) == c1.AA[30] + @test Tables.getcolumn(c1row, 2) == c1.Fr[30] + @test Tables.getcolumn(c1row, 1) == c1.Ti[30] + @test propertynames(c1) == propertynames(c1row) == [:Ti, :Fr, sites...] + @test Tables.getcolumn(c1row, Float64, 1, :Ti) == c1.Ti[30] + @test Tables.getcolumn(c1row, Float64, 2, :Fr) == c1.Fr[30] + @test Tables.getcolumn(c1row, Float64, 3, :AA) == c1.AA[30] + @test isequalmissing(c1[1:10, :AA], c1.AA[1:10]) + @test isequalmissing(c1[[1,2], :AA], c1.AA[[1,2]]) + @test isequalmissing(@view(c1[1:10, :AA]), @view(c1.AA[1:10])) + @test isequalmissing(@view(c1[[1,2], :AA]), @view(c1.AA[[1,2]])) + + Tables.schema(c1) isa Tables.Schema + Tables.getcolumn(c1, Float64, 1, :test) + Tables.getcolumn(c1, Float64, 2, :test) + + c1[1, :AA] + c1[!, :AA] + c1[:, :AA] + @test length(c1) == length(c1.AA) + @test c1[1 ,:] isa Comrade.CalTableRow + @test length(Tables.getrow(c1, 1:5)) == 5 + + plot(c1) + plot(c1, datagains=true) + plot(c1, sites=(:AA,)) + + show(c1) +end @testset "SkyModel" begin @@ -105,8 +189,13 @@ end @test Comrade.SiteArray(x.lg, sl) == x.lg - Comrade.time(x.lg, 5.0..6.0) - Comrade.frequency(x.lg, 1.0..400.0) + @inferred Comrade.time(x.lg, 5.0..6.0) + @inferred Comrade.frequency(x.lg, 1.0..400.0) + + @test x.lg ≈ SiteArray(x.lg, x.lg.times, x.lg.frequencies, x.lg.sites) + @inferred x.lg[1,1,1] + x.lg[1,1,1,1] = 1.0 + @test x.lg[1] ≈ 1.0 # ps = ProjectTo(x.lg) # @test ps(x.lg) == x.lg @@ -213,7 +302,7 @@ end lgR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))), gpR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, inv(π ^2))); phase=true, refant=SEFDReference(0.0)), lgrat= ArrayPrior(IIDSitePrior(ScanSeg(), 
Normal(0.0, 0.1)), phase=false), - gprat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))), + gprat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1)), refant=SingleReference(:AA, 0.0)), dRx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), dRy = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), dLx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), @@ -365,45 +454,133 @@ end @testset "caltable test" begin c1 = caltable(x.lgR) - @test Tables.istable(typeof(c1)) - @test Tables.rowaccess(typeof(c1)) - @test Tables.rows(c1) === c1 - @test Tables.columnaccess(c1) - clmns = Tables.columns(c1) - @test clmns[1] == Comrade.scantimes(c1) - @test Bool(prod(skipmissing(Tables.matrix(clmns)[:,begin+1:end]) .== skipmissing(Comrade.gmat(c1)))) - @test c1.time == Comrade.scantimes(c1) - @test c1.time == Tables.getcolumn(c1, 1) - @test maximum(abs, skipmissing(c1.AA) .- skipmissing(Tables.getcolumn(c1, :AA))) ≈ 0 - @test maximum(abs, skipmissing(c1.AA) .- skipmissing(Tables.getcolumn(c1, 2))) ≈ 0 - @test Tables.columnnames(c1) == [:time, sort(sites(amp))...] - - c1row = Tables.getrow(c1, 30) - @test eltype(c1) == typeof(c1row) - @test c1row.time == c1.time[30] - @test c1row.AA == c1.AA[30] - @test Tables.getcolumn(c1row, :AA) == c1.AA[30] - @test Tables.getcolumn(c1row, :time) == c1.time[30] - @test Tables.getcolumn(c1row, 2) == c1.AA[30] - @test Tables.getcolumn(c1row, 1) == c1.time[30] - @test propertynames(c1) == propertynames(c1row) == [:time, sort(sites(amp))...] 
- - Tables.schema(c1) isa Tables.Schema - Tables.getcolumn(c1, Float64, 1, :test) - Tables.getcolumn(c1, Float64, 2, :test) - - c1[1, :AA] - c1[!, :AA] - c1[:, :AA] - @test length(c1) == length(c1.AA) - @test c1[1 ,:] isa Comrade.CalTableRow - @test length(Tables.getrow(c1, 1:5)) == 5 - - plot(c1) - plot(c1, datagains=true) - plot(c1, sites=(:AA,)) - - show(c1) + test_caltable(c1, sort(sites(amp))) + end + + end + + + @testset "Coherencies Multifrequency" begin + dcoh2 = deepcopy(dcoh) + dcoh2.config[:Fr] .= 345e9 + dcohmf = build_mfvis(dcoh, dcoh2) + vissi = CoherencyMatrix.(Comrade.measurement(dcoh), Ref(CirBasis())) + vismf = CoherencyMatrix.(Comrade.measurement(dcohmf), Ref(CirBasis())) + G = JonesG() do x + gR = exp(x.lgR + 1im*x.gpR) + gL = gR*exp(x.lgrat + 1im*x.gprat) + return gR, gL + end + + D = JonesD() do x + dR = complex(x.dRx, x.dRy) + dL = complex(x.dLx, x.dLy) + return dR, dL + end + + R = JonesR(;add_fr=true) + + J = JonesSandwich(*, G, D, R) + F = JonesF() + + intprior = ( + lgR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))), + gpR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, inv(π ^2))); phase=true, refant=SEFDReference(0.0)), + lgrat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1)), phase=false), + gprat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.1))), + dRx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), + dRy = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), + dLx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), + dLy = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))), + ) + + + intm = InstrumentModel(J, intprior) + show(IOBuffer(), MIME"text/plain"(), intm) + + ointsi, printsi = Comrade.set_array(intm, arrayconfig(dcoh)) + ointmf, printmf = Comrade.set_array(intm, arrayconfig(dcohmf)) + + + Rsi = Comrade.preallocate_jones(F, arrayconfig(dcoh), CirBasis()) + Rmf = Comrade.preallocate_jones(R, arrayconfig(dcohmf), CirBasis()) + # Check that the copied matrices are identical + @test 
Rsi.matrices[1] ≈ Rmf.matrices[1][1:length(Rsi.matrices[1])] + @test Rsi.matrices[1] ≈ Rmf.matrices[1][length(Rsi.matrices[1])+1:end] + @test Rsi.matrices[2] ≈ Rmf.matrices[2][1:length(Rsi.matrices[1])] + @test Rsi.matrices[2] ≈ Rmf.matrices[2][length(Rsi.matrices[1])+1:end] + + for p in propertynames(ointsi.bsitelookup) + L = length(ointsi.bsitelookup[p].indices_1) + @test ointsi.bsitelookup[p].indices_1 == ointmf.bsitelookup[p].indices_1[1:L] + @test ointsi.bsitelookup[p].indices_2 == ointmf.bsitelookup[p].indices_2[1:L] + @test 2*L == length(ointmf.bsitelookup[p].indices_1) + end + + pintmf, _ = Comrade.set_array(InstrumentModel(R), arrayconfig(dcohmf)) + + xsi = rand(printsi) + xmf = rand(printmf) + + for s in sites(dcoh) + map(x->fill!(x, 0.0), xsi) + map(x->fill!(x, 0.0), xmf) + + inds1si = findall(x->(x[1]==s), dcoh[:baseline].sites) + inds2si = findall(x->(x[2]==s), dcoh[:baseline].sites) + nindssi = findall(x->(x[1]!=s && x[2]!=s), dcoh[:baseline].sites) + + inds1mf = findall(x->(x[1]==s), dcohmf[:baseline].sites) + inds2mf = findall(x->(x[2]==s), dcohmf[:baseline].sites) + nindsmf = findall(x->(x[1]!=s && x[2]!=s), dcohmf[:baseline].sites) + + xsilgR = xsi.lgR[S=s] + xsilgR .= log(2) + xmflgR = xmf.lgR[S=s] + xmflgR[1:length(xsilgR)] .= xsilgR + xmflgR[length(xsilgR)+1:end] .= 2 .* xsilgR + + xsilgrat = xsi.lgrat[S=s] + xsilgrat .= -log(2) + xmflgrat = xmf.lgrat[S=s] + xmflgrat[1:length(xsilgrat)] .= xsilgrat + xmflgrat[length(xsilgrat)+1:end] .= xsilgrat + vmf = Comrade.apply_instrument(vismf, ointmf, (;instrument=xmf)) + vsi = Comrade.apply_instrument(vissi, ointsi, (;instrument=xsi)) + Gmf = SMatrix{2,2}(2.0, 0.0, 0.0, 2.0) + @test vsi[inds1si] ≈ vmf[inds1si] + @test vsi[inds1si] .* Ref(Gmf) ≈ vmf[inds1mf[length(inds1si)+1:end]] + + # Now phases + map(x->fill!(x, 0.0), xsi) + map(x->fill!(x, 0.0), xmf) + + xsigpR = xsi.gpR[S=s] + xsigpR .= π/3 + xmfgpR = xmf.gpR[S=s] + xmfgpR[1:length(xsigpR)] .= xsigpR + xmfgpR[length(xsilgR)+1:end] .= 2 .* xsigpR + + 
xsigprat = xsi.gprat[S=s] + xsigprat .= -π/6 + xmfgprat = xmf.gprat[S=s] + xmfgprat[1:length(xsigprat)] .= xsigprat + xmfgprat[length(xsigprat)+1:end] .= xsigprat + + vmf = Comrade.apply_instrument(vismf, ointmf, (;instrument=xmf)) + vsi = Comrade.apply_instrument(vissi, ointsi, (;instrument=xsi)) + Gmf = SMatrix{2,2}(exp(1im*π/3), 0.0, 0.0, exp(1im*π/3)) + @test vsi[inds1si] ≈ vmf[inds1si] + @test vsi[inds1si] .* Ref(Gmf) ≈ vmf[inds1mf[length(inds1si)+1:end]] + + + end + + + @testset "caltable test" begin + xmf = rand(printmf) + c1 = caltable(xmf.lgR) + test_caltable(c1, sort(sites(amp))) end end @@ -416,5 +593,24 @@ end @test length(tt) < length(ts) ≤ length(ti) end + @testset "IntegrationTime" begin + ti = Comrade.IntegrationTime(10, 5.0, 0.1) + @test Comrade.mjd(ti) == ti.mjd + @test ti.t0 ∈ Comrade.interval(ti) + @test Comrade._center(ti) == ti.t0 + @test Comrade._region(ti) == 0.1 + end + + @testset "FrequencyChannel" begin + fc = Comrade.FrequencyChannel(230e9, 8e9, 1) + @test Comrade.channel(fc) == 1 + @test fc.central ∈ Comrade.interval(fc) + @test Comrade._center(fc) == fc.central + @test Comrade._region(fc) == 8e9 + @test 86e9 < fc + @test fc < 345e9 + end + + end diff --git a/test/Project.toml b/test/Project.toml index b4519898..4259d1c3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -37,4 +37,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff" VLBIImagePriors = "b1ba175b-8447-452c-b961-7db2d6f7a029" VLBISkyModels = "d6343c73-7174-4e0f-bb64-562643efbeca" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/test_util.jl b/test/test_util.jl index cf60e055..ae916459 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -8,6 +8,7 @@ function load_data() joinpath(@__DIR__, "../examples/Data/array.txt"), polrep="circ" ) + obspol.add_scans() m = ehtim.model.Model() m = m.add_gauss(1.0, μas2rad(40.0), μas2rad(20.0), π/3, 0.0, 0.0)