Skip to content

Commit

Permalink
fix #41 averaging, add tests
Browse files Browse the repository at this point in the history
use averaging algorithm from Marlu
  • Loading branch information
d3v-null committed Nov 30, 2024
1 parent ae5d0ae commit 4ca41b2
Show file tree
Hide file tree
Showing 5 changed files with 84,285 additions and 34 deletions.
50 changes: 18 additions & 32 deletions src/averaging/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,55 +574,41 @@ pub(crate) fn vis_average(
/// to be negative; if all of the weights in the chunk are negative or 0, the
/// averaged visibility is considered "flagged".
#[inline]
fn vis_average_weights_non_zero(
pub(super) fn vis_average_weights_non_zero(
jones_chunk_tf: ArrayView2<Jones<f32>>,
weight_chunk_tf: ArrayView2<f32>,
jones_to: &mut Jones<f32>,
weight_to: &mut f32,
) {
let mut jones_weighted_sum = Jones::default();
let mut weight_sum = 0.0;
let mut flagged = true;
let mut jones_sum = Jones::default();
let mut unflagged_weight_sum = 0.0;
let mut flagged_weight_sum = 0.0;
let mut all_flagged = true;

// iterate through time chunks
jones_chunk_tf
.iter()
.zip_eq(weight_chunk_tf.iter())
.for_each(|(jones, weight)| {
let jones = Jones::<f64>::from(*jones);
let weight = *weight as f64;

if weight > 0.0 {
// This visibility is not flagged.
if flagged {
// If previous visibilities were flagged, we need to discard
// that information.
jones_weighted_sum = jones * weight;
weight_sum = weight;
flagged = false;
} else {
// Otherwise, we're accumulating this unflagged vis.
jones_weighted_sum += jones * weight;
weight_sum += weight;
}
jones_sum += jones;

let weight_abs_f64 = (*weight as f64).abs();
if *weight > 0.0 {
all_flagged = false;
jones_weighted_sum += jones * weight_abs_f64;
unflagged_weight_sum += weight_abs_f64;
} else {
// This visibility is flagged.
if flagged {
// If all prior vis were also flagged, we accumulate here.
jones_weighted_sum += jones * weight;
weight_sum += weight;
}
// Nothing needs to be done if there were preceding unflagged
// vis.
flagged_weight_sum += weight_abs_f64;
}
});

if weight_sum == 0.0 {
// If the weight is 0, we can't divide the accumulated vis by the
// accumulated weight. So, divide by the chunk size instead.
*jones_to = Jones::from(jones_weighted_sum / jones_chunk_tf.len() as f64);
if all_flagged || unflagged_weight_sum <= 0.0 {
*jones_to = Jones::from(jones_sum / jones_chunk_tf.len() as f64);
*weight_to = -flagged_weight_sum as f32;
} else {
*jones_to = Jones::from(jones_weighted_sum / weight_sum);
*jones_to = Jones::from(jones_weighted_sum / unflagged_weight_sum);
*weight_to = unflagged_weight_sum as f32;
}
*weight_to = weight_sum as f32;
}
137 changes: 136 additions & 1 deletion src/averaging/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,7 @@ fn test_vis_average_non_uniform_weights() {
&HashSet::new(),
);

// (1 * 2) + (11 * 5) = 57
assert_abs_diff_eq!(jones_to_fb[(0, 0)], Jones::identity() * 57.0 / 13.0);
assert_abs_diff_eq!(weight_to_fb[(0, 0)], 13.0);

Expand All @@ -619,7 +620,8 @@ fn test_vis_average_non_uniform_weights() {
);

// The first channel's weight accumulates only negatives.
assert_abs_diff_eq!(jones_to_fb[(0, 0)], Jones::identity() * 57.0 / 13.0);
// (1 + 5) / 2 = 3
assert_abs_diff_eq!(jones_to_fb[(0, 0)], Jones::identity() * 3.0);
assert_abs_diff_eq!(weight_to_fb[(0, 0)], -13.0);

assert_abs_diff_eq!(jones_to_fb[(1, 0)], Jones::identity() * 84.0 / 16.0);
Expand Down Expand Up @@ -726,3 +728,136 @@ fn test_vis_average_non_uniform_weights_non_integral_array_shapes() {
assert_abs_diff_eq!(jones_to_fb[(1, 0)], Jones::identity() * 93.0 / 18.0);
assert_abs_diff_eq!(weight_to_fb[(1, 0)], 18.0);
}

#[test]
fn test_vis_average_half_flagged() {
// 2 timesteps, 4 channels, 1 baseline.
let jones_from_tfb = array![
[
[Jones::identity()],
[Jones::identity() * 2.0],
[Jones::identity() * 3.0],
[Jones::identity() * 4.0]
],
[
[Jones::identity() * 5.0],
[Jones::identity() * 6.0],
[Jones::identity() * 7.0],
[Jones::identity() * 8.0]
]
];
let weight_from_tfb = array![[[1.], [1.], [0.], [0.]], [[1.], [1.], [0.], [0.]]];
let mut jones_to_fb = Array2::default((2, 1));
let mut weight_to_fb = Array2::default(jones_to_fb.dim());

vis_average(
jones_from_tfb.view(),
jones_to_fb.view_mut(),
weight_from_tfb.view(),
weight_to_fb.view_mut(),
&HashSet::new(),
);

assert_abs_diff_eq!(jones_to_fb[(0, 0)], Jones::identity() * 14. / 4.);
assert_abs_diff_eq!(weight_to_fb[(0, 0)], 4.0);

assert_abs_diff_eq!(jones_to_fb[(1, 0)], Jones::identity() * 22. / 4.);
assert_abs_diff_eq!(weight_to_fb[(1, 0)], 0.0);
}

#[test]
fn test_vis_average_weights_non_zero_half_flagged() {
// 2 timesteps, 2 channels
#[rustfmt::skip]
let jones_from_tf = array![
[
Jones::<f32>::identity(),
Jones::identity() * 2.0,
],
[
Jones::identity() * 3.0,
Jones::identity() * 4.0,
],
];
#[rustfmt::skip]
let weight_from_tf = array![
[1., 1.],
[1., 1.]
];
let mut jones_to = Jones::<f32>::default();
let mut weight_to = 0.0;

vis_average_weights_non_zero(
jones_from_tf.view(),
weight_from_tf.view(),
&mut jones_to,
&mut weight_to,
);

assert_abs_diff_eq!(jones_to, Jones::identity() * 10. / 4.);
assert_abs_diff_eq!(weight_to, 4.0);

#[rustfmt::skip]
let weight_from_tf = array![
[1., -1.],
[-1., 1.]
];

vis_average_weights_non_zero(
jones_from_tf.view(),
weight_from_tf.view(),
&mut jones_to,
&mut weight_to,
);

assert_abs_diff_eq!(jones_to, Jones::identity() * 5. / 2.);
assert_abs_diff_eq!(weight_to, 2.0);

#[rustfmt::skip]
let weight_from_tf = array![
[1., 0.],
[0., 1.]
];

vis_average_weights_non_zero(
jones_from_tf.view(),
weight_from_tf.view(),
&mut jones_to,
&mut weight_to,
);

assert_abs_diff_eq!(jones_to, Jones::identity() * 5. / 2.);
assert_abs_diff_eq!(weight_to, 2.0);

#[rustfmt::skip]
let weight_from_tf = array![
[0., 0.],
[0., 0.]
];

vis_average_weights_non_zero(
jones_from_tf.view(),
weight_from_tf.view(),
&mut jones_to,
&mut weight_to,
);

assert_abs_diff_eq!(jones_to, Jones::identity() * 10. / 4.);
assert_abs_diff_eq!(weight_to, 0.0);

#[rustfmt::skip]
let weight_from_tf = array![
[-1., -1.],
[-1., -1.]
];

vis_average_weights_non_zero(
jones_from_tf.view(),
weight_from_tf.view(),
&mut jones_to,
&mut weight_to,
);

assert_abs_diff_eq!(jones_to, Jones::identity() * 10. / 4.);
assert_abs_diff_eq!(weight_to, -4.0);
}
68 changes: 67 additions & 1 deletion src/cli/vis_convert/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ use std::{
};

use clap::Parser;
use ndarray::prelude::*;
use tempfile::TempDir;

use super::VisConvertArgs;
use crate::{
io::read::VisRead,
tests::{get_reduced_1090008640_raw, DataAsStrings},
params::VisConvertParams,
tests::{get_reduced_1061316544_uvfits, get_reduced_1090008640_raw, DataAsStrings},
MsReader, UvfitsReader,
};

Expand Down Expand Up @@ -150,3 +152,67 @@ fn test_per_coarse_chan_flags_and_smallest_contiguous_band_writing() {
}
}
}

#[test]
fn test_averaging_flags() {
let temp_dir = TempDir::new().expect("couldn't make tmp dir");
let uvfits_converted = temp_dir.path().join("converted.uvfits");
let DataAsStrings { vis, .. } = get_reduced_1061316544_uvfits();
// let vis = PathBuf::from(&vis[0]);
let uvfits_converted_string = uvfits_converted.display().to_string();
#[rustfmt::skip]
let args = vec![
"vis-convert",
"--data", &vis[0],
"--outputs", &uvfits_converted_string,
"--freq-average", "80kHz",
"--ignore-input-data-fine-channel-flags",
];

let vis_convert_args = VisConvertArgs::parse_from(args);

let vis_convert_params = vis_convert_args.parse().unwrap();
vis_convert_params.run().unwrap();
let VisConvertParams {
input_vis_params, ..
} = vis_convert_params;

let uvreader = UvfitsReader::new(uvfits_converted, None, None).unwrap();
// let obs_context = uvreader.get_obs_context();
let num_unflagged_tiles = input_vis_params.get_num_unflagged_tiles();
let num_unflagged_cross_baselines = (num_unflagged_tiles * (num_unflagged_tiles - 1)) / 2;
let num_fine_channels = input_vis_params.spw.chanblocks.len();
let flagged_channels = input_vis_params.spw.flagged_chan_indices;
let cross_vis_shape = (num_fine_channels, num_unflagged_cross_baselines);
let auto_vis_shape = (num_fine_channels, num_unflagged_tiles);

let mut cross_data_fb = Array2::zeros(cross_vis_shape);
let mut cross_weights_fb = Array2::zeros(cross_vis_shape);
let mut auto_data_fb = Array2::zeros(auto_vis_shape);
let mut auto_weights_fb = Array2::zeros(auto_vis_shape);
let timestep = 0;
uvreader
.read_crosses_and_autos(
cross_data_fb.view_mut(),
cross_weights_fb.view_mut(),
auto_data_fb.view_mut(),
auto_weights_fb.view_mut(),
timestep,
&input_vis_params.tile_baseline_flags,
&flagged_channels,
)
.unwrap();
auto_data_fb.indexed_iter().for_each(|((chan, tile), val)| {
let j = val;
if tile == 76 {
return;
}
assert!(
j[0].re > 0.0,
"auto_data_fb[{}, {}][0].re = {}",
chan,
tile,
j[0].re
);
});
}
9 changes: 9 additions & 0 deletions src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ pub(crate) fn deflate_gz_into_tempfile<T: AsRef<Path>>(file: T) -> TempPath {
}

const DATA_DIR_1090008640: &str = "test_files/1090008640";
const DATA_DIR_1061316544: &str = "test_files/1061316544";

#[derive(Default)]
pub(crate) struct DataAsStrings {
pub(crate) metafits: String,
pub(crate) vis: Vec<String>,
Expand Down Expand Up @@ -63,6 +65,13 @@ pub(crate) fn get_reduced_1090008640_uvfits() -> DataAsStrings {
data
}

pub(crate) fn get_reduced_1061316544_uvfits() -> DataAsStrings {
DataAsStrings {
vis: vec![format!("{DATA_DIR_1061316544}/1061316544.uvfits")],
..DataAsStrings::default()
}
}

pub(crate) fn get_reduced_1090008640_raw_pbs() -> DataAsPathBufs {
let DataAsStrings {
metafits,
Expand Down
Loading

0 comments on commit 4ca41b2

Please sign in to comment.