Skip to content

Commit

Permalink
add crosses, signal chain and fontsize options
Browse files Browse the repository at this point in the history
  • Loading branch information
d3v-null committed Sep 13, 2024
1 parent a50aea4 commit 52bbc61
Showing 1 changed file with 168 additions and 54 deletions.
222 changes: 168 additions & 54 deletions demo/04_ssins.py
Original file line number Diff line number Diff line change
@@ -6,10 +6,17 @@
import numpy as np
from astropy.time import Time
import argparse
import pandas as pd

parser = argparse.ArgumentParser()
parser.add_argument("--no-diff", default=False, action="store_true")
parser.add_argument("--no-cache", default=False, action="store_true")
group = parser.add_mutually_exclusive_group()
group.add_argument("--crosses", default=False, action="store_true", help="plot crosses instead of default: autos")
group.add_argument("--sigchain", default=False, action="store_true", help="plot sigchain instead of default: spectrum", )

parser.add_argument("--threshold", default=5, help="match filter significance threshold")
parser.add_argument("--fontsize", default=8, help="plot font size")
parser.add_argument("files", nargs="+")
args = parser.parse_args()

@@ -26,36 +33,49 @@

# sky-subtract https://ssins.readthedocs.io/en/latest/sky_subtract.html
ss = SS()
diff = not args.no_diff
suffix = ".diff" if diff else ""
cache = f"{base}{suffix}.ssa.h5"

def get_unflagged_ants(ss, args):
# TODO: args to flag additional ants
return np.unique(ss.ant_1_array)

suffix = "" if args.no_diff else ".diff"
spectrum_type = "cross" if args.crosses else "auto"
suffix += f".{spectrum_type}"
cache = f"{base}{suffix}.h5"
if not args.no_cache and os.path.exists(cache):
print(f"reading from {cache=}")
ss.read_uvh5(cache, read_data=True, use_future_array_shapes=True)
unflagged_ants = get_unflagged_ants(ss, args)
else:
print(f"reading from {args.files=}")
ss.read(
args.files,
read_data=True,
diff=diff, # difference timesteps
diff=(not args.no_diff), # difference timesteps
remove_coarse_band=False, # does not work with low freq res
correct_van_vleck=False, # slow
remove_flagged_ants=True, # remove flagged antennas
)
# just look at autos
unflagged_ants = np.unique(ss.ant_1_array)
ss = ss.select(bls=[(a, a) for a in unflagged_ants], inplace=False)

unflagged_ants = get_unflagged_ants(ss, args)
def cross_filter(pair):
a, b = pair
return a != b and a in unflagged_ants and b in unflagged_ants
def auto_filter(pair):
a, b = pair
return a == b and a in unflagged_ants
bls = [*filter( {"cross": cross_filter, "auto": auto_filter}[spectrum_type], ss.get_antpairs() )]
ss = ss.select(bls=bls, inplace=False)
ss.apply_flags(flag_choice="original")
if not args.no_cache:
ss.write_uvh5(cache)
print(f"wrote ss to {cache=}")

# incoherent noise spectrum https://ssins.readthedocs.io/en/latest/incoherent_noise_spectrum.html
ins = INS(ss, spectrum_type="auto")
ins = INS(ss, spectrum_type=spectrum_type)

# match filter https://ssins.readthedocs.io/en/latest/match_filter.html
threshold = 5
mf = MF(ss.freq_array, threshold, streak=True, narrow=False)
mf = MF(ss.freq_array, args.threshold, streak=True, narrow=False)
mf.apply_match_test(ins)
ins.sig_array[~np.isfinite(ins.sig_array)] = 0

@@ -66,53 +86,147 @@
cmap = mpl.colormaps.get_cmap("viridis")
cmap.set_bad(color="#00000000")

fontsize = 8
figsize = (16, 16)
pols = ss.get_pols()
subplot_rows = 2

subplots = plt.subplots(
subplot_rows,
len(pols),
figsize=figsize,
sharex=True,
sharey=True,
)[1]

time_labels = [*Time(np.unique(ss.time_array), format="jd").iso]
ntimes_visible = figsize[1] * 72 / fontsize / 2.4 / subplot_rows
time_stride = int(max(1, len(time_labels) // ntimes_visible))
chan_labels = [f"{ch: 8.3f}" for ch in ss.freq_array / 1e6]
nchans_visible = figsize[0] * 72 / fontsize / 2.0 / len(pols)
chan_stride = int(max(1, len(chan_labels) // nchans_visible))

for i, pol in enumerate(ss.get_pols()):
ax_met, ax_sig = subplots[:, i]

ax_met.set_title(
f'{base.split("/")[-1]} ss vis amps {"diff " if diff else " "}{pol}'
)
ax_met.imshow(
ins.metric_array[..., i], aspect="auto", interpolation="none", cmap=cmap
)
ax_sig.set_title(f'{base.split("/")[-1]} z-score {pol}')
ax_sig.imshow(ins.sig_array[..., i], aspect="auto", interpolation="none", cmap=cmap)
if i == 0:
ax_met.set_ylabel("Timestep [UTC]")
ax_met.set_yticks(np.arange(ss.Ntimes)[::time_stride])
ax_met.set_yticklabels(
time_labels[::time_stride], fontsize=fontsize, fontfamily="monospace"
gps_times = np.unique(Time(ss.time_array, format="jd").gps)
int_time = np.unique(ss.integration_time)[0]
obsid = round(gps_times[0] - int_time / 2)
title = f"{obsid}"
plt.title(title)

def plot_sigchain(ss, args):
ant_mask = np.where(np.isin(ss.antenna_numbers, unflagged_ants))[0]
antenna_numbers = np.array(ss.antenna_numbers)[ant_mask]
antenna_names = np.array(ss.antenna_names)[ant_mask]

# build a scores array for each signal chain
scores = np.zeros((len(unflagged_ants), ss.Ntimes, ss.Nfreqs, ss.Npols))
for ant_idx, (ant_num, ant_name) in enumerate(zip(antenna_numbers, antenna_names)):
print(ant_idx, ant_num, ant_name)
if ant_num not in unflagged_ants:
continue
# select only the auto-correlation for this antenna
ssa = ss.select(antenna_nums=[(ant_num, ant_num)], inplace=False)
ins = INS(ssa, spectrum_type="auto")
mf.apply_match_test(ins)
ins.sig_array[~np.isfinite(ins.sig_array)] = 0
scores[ant_idx] = ins.sig_array

plt.style.use("dark_background")
subplots = plt.subplots(
2,
len(pols),
height_ratios=[len(unflagged_ants), ss.Nfreqs],
)[1]

def slice(scores, axis):
pscore = np.sqrt(np.sum(scores**2, axis=-1))
# make per-antenna anomalies stand out more
# pscore = pscore / np.nanstd(pscore, axis=1)[:, np.newaxis]
# subtract minimum value
# pscore -= np.nanmin(pscore)
return pscore

# pad names
namelen = max(len(name) for name in antenna_names)
antnames = [f"{name: <{namelen}}" for name in antenna_names]
channames = [f"{ch: 8.4f}" for ch in ss.freq_array / 1e6]

for i, pol in enumerate(ss.get_pols()):
ax_signal, ax_spectrum = subplots[:, i]
title = f"{obsid} Autos z-score magnitude {pol}"

# by signal chain: [ant, time]
signal_pscore = slice(scores[..., i], axis=-1)
print(signal_pscore.shape)
signal_df = pd.DataFrame(
signal_pscore.transpose(), index=gps_times, columns=antnames
)
ax_sig.set_ylabel("Timestep [UTC]")
ax_sig.set_yticks(np.arange(ss.Ntimes)[::time_stride])
ax_sig.set_yticklabels(
time_labels[::time_stride], fontsize=fontsize, fontfamily="monospace"
signal_df.to_csv(
f"{base}.auto_sig.{pol}.tsv",
index_label="gps_time",
float_format="%.3f",
sep="\t",
)

ax_sig.set_xlabel("Frequency channel [MHz]")
ax_sig.set_xticks(np.arange(ss.Nfreqs)[::chan_stride])
ax_sig.set_xticklabels(chan_labels[::chan_stride], fontsize=fontsize, rotation=90)
ax_signal.imshow(signal_pscore, aspect="auto", interpolation="none", cmap=cmap)
ax_signal.set_ylabel("Antenna")
ax_spectrum.set_xlabel("Timestep")
ax_signal.set_yticks(np.arange(len(unflagged_ants)))
ax_signal.tick_params(pad=True)
ax_signal.set_yticklabels(antnames, fontsize=args.fontsize)
ax_signal.set_title(f"{obsid} Autos z-score row-normalized {pol}")

# by spectrum: [freq, time]
spectrum_pscore = slice(scores[..., i].transpose(2, 1, 0), axis=-1)
spectrum_df = pd.DataFrame(
spectrum_pscore.transpose(), index=gps_times, columns=channames
)
spectrum_df.to_csv(
f"{base}.auto_spx.{pol}.tsv",
index_label="gps_time",
float_format="%.3f",
sep="\t",
)
ax_spectrum.set_xlabel("Timestep")
ax_spectrum.set_ylabel("Frequency channel")
ax_spectrum.set_yticks(np.arange(ss.Nfreqs))
ax_spectrum.set_yticklabels(channames, fontsize=args.fontsize)
ax_spectrum.tick_params(pad=True)
ax_spectrum.imshow(spectrum_pscore, aspect="auto", interpolation="none", cmap=cmap)

def plot_spectrum(ss, args):
pols = ss.get_pols()
subplot_rows = 2

subplots = plt.subplots(
subplot_rows,
len(pols),
sharex=True,
sharey=True,
)[1]

time_labels = [*Time(np.unique(ss.time_array), format="jd").iso]
ntimes_visible = 16 * 72 / args.fontsize / 2.4 / subplot_rows
time_stride = int(max(1, len(time_labels) // ntimes_visible))
chan_labels = [f"{ch: 8.3f}" for ch in ss.freq_array / 1e6]
nchans_visible = 16 * 72 / args.fontsize / 2.0 / len(pols)
chan_stride = int(max(1, len(chan_labels) // nchans_visible))

for i, pol in enumerate(pols):
ax_met, ax_sig = subplots[:, i]

ax_met.set_title(
f'{base.split("/")[-1]} ss vis amps{suffix} {pol}'
)
ax_met.imshow(
ins.metric_array[..., i], aspect="auto", interpolation="none", cmap=cmap
)
ax_sig.set_title(f'{base.split("/")[-1]} z-score {pol}')
ax_sig.imshow(ins.sig_array[..., i], aspect="auto", interpolation="none", cmap=cmap)
if i == 0:
ax_met.set_ylabel("Timestep [UTC]")
ax_met.set_yticks(np.arange(ss.Ntimes)[::time_stride])
ax_met.set_yticklabels(
time_labels[::time_stride], fontsize=args.fontsize, fontfamily="monospace"
)
ax_sig.set_ylabel("Timestep [UTC]")
ax_sig.set_yticks(np.arange(ss.Ntimes)[::time_stride])
ax_sig.set_yticklabels(
time_labels[::time_stride], fontsize=args.fontsize, fontfamily="monospace"
)

ax_sig.set_xlabel("Frequency channel [MHz]")
ax_sig.set_xticks(np.arange(ss.Nfreqs)[::chan_stride])
ax_sig.set_xticklabels(chan_labels[::chan_stride], fontsize=args.fontsize, rotation=90)

if args.sigchain:
plot_sigchain(ss, args)
plt.gcf().set_size_inches(8 * len(pols), (len(unflagged_ants) + ss.Ntimes) * args.fontsize / 72)
figname = f"{base}{suffix}.sigchain.png"
else:
plot_spectrum(ss, args)
plt.gcf().set_size_inches(8 * len(pols), 16)
figname = f"{base}{suffix}.spectrum.png"

figname = f"{base}{suffix}.auto_spectrum.png"
plt.savefig(figname, bbox_inches="tight")
print(f"wrote {figname}")
print(f"wrote {figname}")

0 comments on commit 52bbc61

Please sign in to comment.