From ddb8163d0c2ef04a2d522aa6ef56b28d67d4d036 Mon Sep 17 00:00:00 2001 From: Dev Null Date: Wed, 26 Jun 2024 12:44:53 +0800 Subject: [PATCH] implement antenna (tile) selection (`--sel-ants`) (#151) * add baseline_selection in correct_geometry, correct_cable_lengths, * use marlu issue-150 * use vis_sel to reduce args in correct_geometry * fix calc_part_uvws for ant-sel, run ant-sel tests --- Cargo.toml | 4 +- benches/expensive_benches.rs | 8 ++-- src/cli.rs | 52 ++++++++++++++++++++- src/corrections.rs | 69 +++++++++++++-------------- src/io/mod.rs | 91 ++++++++++++++++++++++++++++++++++++ src/io/mwaf.rs | 1 - src/preprocessing.rs | 4 +- 7 files changed, 182 insertions(+), 47 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 450f4b0..666ffea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,7 +69,7 @@ ndarray = { version = "0.15.4", features = ["approx-0_5"] } tempfile = "3.3" [build-dependencies] -built = { version = "~0.7.3", default-features = false, features = [ +built = { version = "~0.7", default-features = false, features = [ "chrono", "git2", ] } @@ -91,6 +91,6 @@ opt-level = 3 [patch.crates-io] # marlu = { path = "../Marlu" } -# marlu = { git = "https://github.com/MWATelescope/Marlu", branch = "DUT1" } +# marlu = { git = "https://github.com/MWATelescope/Marlu", branch = "birli-150" } # mwalib = { path = "../mwalib" } # mwalib = { git = "https://github.com/MWATelescope/mwalib", branch = "digital_gains_plus" } diff --git a/benches/expensive_benches.rs b/benches/expensive_benches.rs index 8543603..62e6081 100644 --- a/benches/expensive_benches.rs +++ b/benches/expensive_benches.rs @@ -120,6 +120,7 @@ fn bench_correct_cable_lengths_mwax_half_1247842824(crt: &mut Criterion) { black_box(&corr_ctx), black_box(jones_array.view_mut()), black_box(&vis_sel.coarse_chan_range), + black_box(&vis_sel.baseline_idxs), false, ) }) @@ -147,6 +148,7 @@ fn bench_correct_cable_lengths_ord_half_1196175296(crt: &mut Criterion) { black_box(&corr_ctx), black_box(jones_array.view_mut()), black_box(&vis_sel.coarse_chan_range), + black_box(&vis_sel.baseline_idxs), false, ) }) @@ -173,8 +175,7 @@ fn bench_correct_geometry_mwax_half_1247842824(crt: &mut Criterion) { correct_geometry( black_box(&corr_ctx), black_box(jones_array.view_mut()), - black_box(&vis_sel.timestep_range), - black_box(&vis_sel.coarse_chan_range), + black_box(&vis_sel), None, None, false, @@ -203,8 +204,7 @@ fn bench_correct_geometry_ord_half_1196175296(crt: &mut Criterion) { correct_geometry( black_box(&corr_ctx), black_box(jones_array.view_mut()), - black_box(&vis_sel.timestep_range), - black_box(&vis_sel.coarse_chan_range), + black_box(&vis_sel), None, None, false, diff --git a/src/cli.rs b/src/cli.rs index 57c7b66..ae3ebdf 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1841,8 +1841,8 @@ mod tests { use tempfile::tempdir; use crate::{ - test_common::get_1254670392_avg_paths, test_common::get_mwax_data_paths, BirliContext, - VisSelection, + test_common::{get_1254670392_avg_paths, get_mwax_data_paths}, + BirliContext, BirliError, VisSelection, }; #[test] @@ -2143,6 +2143,54 @@ mod tests { let BirliContext { flag_ctx, .. } = BirliContext::from_args(&args).unwrap(); } + + /// Test that corrections work correctly with `--sel-ants` + #[test] + fn test_sel_ants_baselines() { + let (metafits_path, gpufits_paths) = get_1254670392_avg_paths(); + + #[rustfmt::skip] + let args = vec![ + "birli", + "-m", metafits_path, + "--sel-ants", "1", "2", "3", "4", + "--no-draw-progress", + gpufits_paths[0], + gpufits_paths[1], + ]; + + let birli_ctx = BirliContext::from_args(&args).unwrap(); + + // check baseline_idxs is the correct size + assert_eq!(&birli_ctx.vis_sel.baseline_idxs.len(), &10); + assert_eq!( + &birli_ctx.vis_sel.baseline_idxs, + &[128, 129, 130, 131, 255, 256, 257, 381, 382, 506] + ); + + birli_ctx.run().unwrap(); + } + + /// Test `--sel-ants` handles invalid antenna idxs + #[test] + fn test_sel_ants_invalid() { + let (metafits_path, gpufits_paths) = get_1254670392_avg_paths(); + + #[rustfmt::skip] + let args = vec![ + "birli", + "-m", metafits_path, + "--sel-ants", "0", "11", "3", "999", + "--no-draw-progress", + gpufits_paths[0], + gpufits_paths[1], + ]; + + assert!(matches!( + BirliContext::from_args(&args), + Err(BirliError::CLIError(_)) + )); + } } #[cfg(test)] diff --git a/src/corrections.rs b/src/corrections.rs index 1de9593..8473df1 100644 --- a/src/corrections.rs +++ b/src/corrections.rs @@ -12,7 +12,7 @@ use marlu::{ io::error::BadArrayShape, mwalib::{CorrelatorContext, MWAVersion}, precession::precess_time, - Complex, LatLngHeight, RADec, XyzGeodetic, UVW, + Complex, LatLngHeight, RADec, VisSelection, XyzGeodetic, UVW, }; use std::{f64::consts::TAU, ops::Range}; use thiserror::Error; @@ -55,7 +55,7 @@ use thiserror::Error; /// read_mwalib(&vis_sel, &corr_ctx, jones_array.view_mut(), flag_array.view_mut(), false) /// .unwrap(); /// -/// correct_cable_lengths(&corr_ctx, jones_array.view_mut(), &vis_sel.coarse_chan_range, false); +/// correct_cable_lengths(&corr_ctx, jones_array.view_mut(), &vis_sel.coarse_chan_range, &vis_sel.baseline_idxs, false); /// ``` /// /// # Accuracy @@ -65,9 +65,9 @@ use thiserror::Error; pub fn correct_cable_lengths( corr_ctx: &CorrelatorContext, mut jones_array: ArrayViewMut3>, + // TODO: take a VisSelection coarse_chan_range: &Range, - // TODO: allow subset of baselines - // baseline_idxs: &[usize], + baseline_idxs: &[usize], draw_progress: bool, ) { trace!("start correct_cable_lengths"); @@ -77,9 +77,14 @@ pub fn correct_cable_lengths( let all_freqs_hz = corr_ctx.get_fine_chan_freqs_hz_array(&coarse_chan_range.clone().collect::>()); - let ant_pairs = (meta_ctx.baselines) + let ant_pairs = baseline_idxs .iter() - .map(|b| (b.ant1_index, b.ant2_index)) + .map(|b| { + ( + meta_ctx.baselines[*b].ant1_index, + meta_ctx.baselines[*b].ant2_index, + ) + }) .collect::>(); let draw_target = if draw_progress { @@ -188,26 +193,21 @@ pub fn correct_cable_lengths( /// read_mwalib(&vis_sel, &corr_ctx, jones_array.view_mut(), flag_array.view_mut(), false) /// .unwrap(); /// -/// correct_cable_lengths(&corr_ctx, jones_array.view_mut(), &vis_sel.coarse_chan_range, false); +/// correct_cable_lengths(&corr_ctx, jones_array.view_mut(), &vis_sel.coarse_chan_range, &vis_sel.baseline_idxs, false); /// /// correct_geometry( /// &corr_ctx, /// jones_array.view_mut(), -/// &vis_sel.timestep_range, -/// &vis_sel.coarse_chan_range, +/// &vis_sel, /// None, /// None, /// false, /// ); /// ``` -#[allow(clippy::too_many_arguments)] pub fn correct_geometry( corr_ctx: &CorrelatorContext, mut jones_array: ArrayViewMut3>, - timestep_range: &Range, - coarse_chan_range: &Range, - // TODO: allow subset of baselines - // baseline_idxs: &[usize], + vis_sel: &VisSelection, array_pos: Option, phase_centre: Option, draw_progress: bool, @@ -222,12 +222,10 @@ pub fn correct_geometry( LatLngHeight::mwa() }); - let timesteps = &corr_ctx.timesteps[timestep_range.clone()]; - - let baselines = &corr_ctx.metafits_context.baselines; + let timesteps = &corr_ctx.timesteps[vis_sel.timestep_range.clone()]; - let all_freqs_hz = - corr_ctx.get_fine_chan_freqs_hz_array(&coarse_chan_range.clone().collect::>()); + let all_freqs_hz = corr_ctx + .get_fine_chan_freqs_hz_array(&vis_sel.coarse_chan_range.clone().collect::>()); let jones_dims = jones_array.dim(); let integration_time_s = corr_ctx.metafits_context.corr_int_time_ms as f64 / 1000.0; @@ -236,17 +234,14 @@ pub fn correct_geometry( .unwrap_or_else(|| RADec::from_mwalib_phase_or_pointing(&corr_ctx.metafits_context)); let tiles_xyz_geod = XyzGeodetic::get_tiles(&corr_ctx.metafits_context, array_pos.latitude_rad); - let ant_pairs = baselines - .iter() - .map(|b| (b.ant1_index, b.ant2_index)) - .collect::>(); + // use baseline_idxs to select antpairs out of corr_ctx.metafits_context.baselines + let ant_pairs = vis_sel.get_ant_pairs(&corr_ctx.metafits_context); let centroid_timestamps = timesteps .iter() .map(|t| Epoch::from_gpst_seconds(t.gps_time_ms as f64 / 1000.0 + integration_time_s / 2.0)) .collect::>(); let dut1 = Duration::from_seconds(corr_ctx.metafits_context.dut1.unwrap_or(0.0)); let part_uvws = calc_part_uvws( - &ant_pairs, ¢roid_timestamps, dut1, phase_centre, @@ -329,9 +324,9 @@ pub enum DigitalGainCorrection { pub fn correct_digital_gains( corr_ctx: &CorrelatorContext, jones_array: ArrayViewMut3>, + // TODO: take a VisSelection coarse_chan_range: &Range, ant_pairs: &[(usize, usize)], - // TODO: take a VisSelection ) -> Result<(), DigitalGainCorrection> { let num_fine_chans_per_coarse = corr_ctx.metafits_context.num_corr_fine_chans_per_coarse; @@ -656,15 +651,14 @@ pub fn scrunch_gains( // UVWs are in units of meters. To get the UVWs in units of wavelengths, divide by the wavelength. // uvw at ts, (ant1, ant2) = part_uvw[ant1] - part_uvw[ant2] fn calc_part_uvws( - ant_pairs: &[(usize, usize)], centroid_timestamps: &[Epoch], dut1: Duration, phase_centre: RADec, array_pos: LatLngHeight, tile_xyzs: &[XyzGeodetic], ) -> Array2 { - let max_ant = ant_pairs.iter().map(|&(a, b)| a.max(b)).max().unwrap(); - let mut part_uvws = Array2::from_elem((centroid_timestamps.len(), max_ant + 1), UVW::default()); + let nants = tile_xyzs.len(); + let mut part_uvws = Array2::from_elem((centroid_timestamps.len(), nants), UVW::default()); for (t, &epoch) in centroid_timestamps.iter().enumerate() { let prec = precess_time( array_pos.longitude_rad, @@ -712,7 +706,8 @@ mod tests { #[test] fn test_cable_length_corrections_mwax() { let corr_ctx = get_mwax_context(); - let vis_sel = VisSelection::from_mwalib(&corr_ctx).unwrap(); + let mut vis_sel = VisSelection::from_mwalib(&corr_ctx).unwrap(); + vis_sel.baseline_idxs = vec![0, 1]; // Create a blank array to store flags and visibilities let fine_chans_per_coarse = corr_ctx.metafits_context.num_corr_fine_chans_per_coarse; @@ -814,6 +809,7 @@ mod tests { &corr_ctx, jones_array.view_mut(), &vis_sel.coarse_chan_range, + &vis_sel.baseline_idxs, false, ); @@ -954,6 +950,7 @@ mod tests { &corr_ctx, jones_array.view_mut(), &vis_sel.coarse_chan_range, + &vis_sel.baseline_idxs, false, ); @@ -1108,8 +1105,7 @@ mod tests { correct_geometry( &corr_ctx, jones_array.view_mut(), - &vis_sel.timestep_range, - &vis_sel.coarse_chan_range, + &vis_sel, None, None, false, @@ -1139,7 +1135,9 @@ mod tests { #[test] fn test_geometric_corrections_mwax() { let corr_ctx = get_mwax_context(); - let vis_sel = VisSelection::from_mwalib(&corr_ctx).unwrap(); + let mut vis_sel = VisSelection::from_mwalib(&corr_ctx).unwrap(); + // only select baselines we're looking at + vis_sel.baseline_idxs = vec![0, 1]; // Create a blank array to store flags and visibilities let fine_chans_per_coarse = corr_ctx.metafits_context.num_corr_fine_chans_per_coarse; @@ -1177,7 +1175,7 @@ mod tests { ]) ); - // ts 0, chan 0 (cc 0, fc 0), baseline 5 + // ts 0, chan 0 (cc 0, fc 0), baseline 1 let viz_0_0_1 = jones_array[(0, 0, 1)]; compare_jones!( viz_0_0_1, @@ -1189,7 +1187,7 @@ mod tests { ]) ); - // ts 3, chan 3 (cc 1, fc 1), baseline 5 + // ts 3, chan 3 (cc 1, fc 1), baseline 1 let viz_3_3_1 = jones_array[(3, 3, 1)]; compare_jones!( viz_3_3_1, @@ -1256,8 +1254,7 @@ mod tests { correct_geometry( &corr_ctx, jones_array.view_mut(), - &vis_sel.timestep_range, - &vis_sel.coarse_chan_range, + &vis_sel, None, None, false, diff --git a/src/io/mod.rs b/src/io/mod.rs index 8a3dfa6..1d757df 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -157,7 +157,9 @@ pub fn read_mwalib( let shape = vis_sel.get_shape(fine_chans_per_coarse); let (num_timesteps, _, _) = shape; let num_coarse_chans = vis_sel.coarse_chan_range.len(); + let max_bl_idx = corr_ctx.metafits_context.baselines.len(); + // check output array dimensions if jones_array.dim() != shape { return Err(SelectionError::BadArrayShape { argument: "jones_array".to_string(), @@ -176,6 +178,15 @@ pub fn read_mwalib( }); }; + // check all selected baseline idxs are < max_bl_idx + if vis_sel.baseline_idxs.iter().any(|&idx| idx >= max_bl_idx) { + return Err(SelectionError::BadBaselineIdx { + function: "VisSelection::read_mwalib".to_string(), + expected: format!(" < {max_bl_idx}"), + received: format!("{:?}", vis_sel.baseline_idxs.clone()), + }); + } + // since we are using read_by_baseline_into_buffer, the visibilities are read in order: // baseline,frequency,pol,r,i @@ -583,6 +594,86 @@ pub fn write_ms>( Ok(()) } +#[cfg(test)] +mod tests { + use approx::assert_abs_diff_eq; + use marlu::{Complex, Jones}; + + use crate::{compare_jones, test_common::get_mwax_context, VisSelection}; + + use super::read_mwalib; + + // test read_mwalib with bad vis_sel.baseline_idxs + #[test] + fn test_read_bad_baseline_sel() { + let corr_ctx = get_mwax_context(); + + let mut vis_sel = VisSelection::from_mwalib(&corr_ctx).unwrap(); + vis_sel.baseline_idxs = vec![5]; + + let fine_chans_per_coarse = corr_ctx.metafits_context.num_corr_fine_chans_per_coarse; + let mut flag_array = vis_sel.allocate_flags(fine_chans_per_coarse).unwrap(); + let mut jones_array = vis_sel.allocate_jones(fine_chans_per_coarse).unwrap(); + assert!(read_mwalib( + &vis_sel, + &corr_ctx, + jones_array.view_mut(), + flag_array.view_mut(), + false, + ) + .is_err()); + } + + // test read_mwalib with custom vis_sel.baseline_idxs + #[test] + fn test_read_baseline_sel() { + let corr_ctx = get_mwax_context(); + + let mut vis_sel = VisSelection::from_mwalib(&corr_ctx).unwrap(); + vis_sel.baseline_idxs = vec![1]; + + let fine_chans_per_coarse = corr_ctx.metafits_context.num_corr_fine_chans_per_coarse; + let mut flag_array = vis_sel.allocate_flags(fine_chans_per_coarse).unwrap(); + assert_eq!(flag_array.shape(), &[4, 4, 1]); + let mut jones_array = vis_sel.allocate_jones(fine_chans_per_coarse).unwrap(); + assert_eq!(jones_array.shape(), &[4, 4, 1]); + let weight_array = vis_sel.allocate_weights(fine_chans_per_coarse).unwrap(); + assert_eq!(weight_array.shape(), &[4, 4, 1]); + read_mwalib( + &vis_sel, + &corr_ctx, + jones_array.view_mut(), + flag_array.view_mut(), + false, + ) + .unwrap(); + + // ts 0, chan 0 (cc 0, fc 0), baseline 1 + let viz_0_0_1 = jones_array[(0, 0, 0)]; + compare_jones!( + viz_0_0_1, + Jones::from([ + Complex::new(0x410010 as f32, 0x410011 as f32), + Complex::new(0x410012 as f32, 0x410013 as f32), + Complex::new(0x410014 as f32, 0x410015 as f32), + Complex::new(0x410016 as f32, 0x410017 as f32), + ]) + ); + + // ts 3, chan 3 (cc 1, fc 1), baseline 1 + let viz_3_3_1 = jones_array[(3, 3, 0)]; + compare_jones!( + viz_3_3_1, + Jones::from([ + Complex::new(0x410718 as f32, 0x410719 as f32), + Complex::new(0x41071a as f32, 0x41071b as f32), + Complex::new(0x41071c as f32, 0x41071d as f32), + Complex::new(0x41071e as f32, 0x41071f as f32), + ]) + ); + } +} + #[cfg(test)] #[cfg(feature = "aoflagger")] /// Tests which require the use of the aoflagger feature diff --git a/src/io/mwaf.rs b/src/io/mwaf.rs index 6e184a7..45e1562 100644 --- a/src/io/mwaf.rs +++ b/src/io/mwaf.rs @@ -1346,7 +1346,6 @@ mod tests { num_ants, ); let int_time = 1e-3 * meta_ctx.corr_int_time_ms as f32; - dbg!(int_time); flag_ctx.flag_init = int_time; flag_ctx.flag_end = int_time; flag_ctx.timestep_flags[1] = true; diff --git a/src/preprocessing.rs b/src/preprocessing.rs index 3c43a00..9d2d988 100644 --- a/src/preprocessing.rs +++ b/src/preprocessing.rs @@ -170,6 +170,7 @@ impl<'a> PreprocessContext<'a> { corr_ctx, jones_array.view_mut(), &vis_sel.coarse_chan_range, + &vis_sel.baseline_idxs, self.draw_progress ) ); @@ -234,8 +235,7 @@ impl<'a> PreprocessContext<'a> { correct_geometry( corr_ctx, jones_array.view_mut(), - &vis_sel.timestep_range, - &vis_sel.coarse_chan_range, + vis_sel, Some(self.array_pos), Some(self.phase_centre), self.draw_progress,