diff --git a/demo/04_ssins.py b/demo/04_ssins.py index 20025f7..985b55a 100755 --- a/demo/04_ssins.py +++ b/demo/04_ssins.py @@ -11,33 +11,39 @@ from matplotlib import pyplot as plt import numpy as np from astropy.time import Time -import argparse +from argparse import ArgumentParser, SUPPRESS from itertools import groupby import re +import time +from pathlib import Path from matplotlib.axis import Axis def get_parser(): - parser = argparse.ArgumentParser() + parser = ArgumentParser() # arguments for SS.read() group_read = parser.add_argument_group("SS.read") group_read.add_argument( "files", nargs="+", - help="raw .fits (with .metafits), .uvfits supported", + help="Raw .fits (with .metafits), .uvfits, .uvh5, .ms supported", ) - group_read.add_argument( + group_mutex = group_read.add_mutually_exclusive_group() + group_mutex.add_argument("--diff", default=True, help=SUPPRESS) + group_mutex.add_argument( "--no-diff", - default=False, - action="store_true", - help="don't difference visibilities in time (sky-subtract)", + action="store_false", + dest="diff", + help="Don't difference visibilities in time (sky-subtract)", ) - group_read.add_argument( + group_mutex = group_read.add_mutually_exclusive_group() + group_mutex.add_argument("--flag-init", default=True, help=SUPPRESS) + group_mutex.add_argument( "--no-flag-init", - default=False, - action="store_true", - help="skip flagging of edge channels, quack time", + action="store_false", + dest="flag_init", + help="Skip flagging of quack time, edge channels", ) group_read.add_argument( "--remove-coarse-band", @@ -51,10 +57,12 @@ def get_parser(): action="store_true", help="Correct van vleck quantization artifacts in legacy correlator. slow!", ) - group_read.add_argument( + group_mutex = group_read.add_mutually_exclusive_group() + group_mutex.add_argument("--remove-flagged-ants", default=True, help=SUPPRESS) + group_mutex.add_argument( "--include-flagged-ants", - default=False, - action="store_true", + action="store_false", + dest="remove_flagged_ants", help="Include flagged antenna when reading raw files", ) group_read.add_argument( @@ -62,7 +70,8 @@ def get_parser(): default=None, nargs=1, choices=["original"], - help="original = apply flags from visibilities before running ssins (not recommended)", + type=str, + help="original = apply flags from visibilities before running ssins (only recommended for --plot-type=flags)", ) # arguments for SS.select() @@ -92,14 +101,14 @@ def get_parser(): ) # arguments for SSINS.INS - parser.add_argument( + group_ins = parser.add_argument_group("SSINS.INS") + group_ins.add_argument( "--spectrum-type", default="auto", choices=["auto", "cross"], help="analyse auto-correlations or cross-correlations. default: auto", ) - - parser.add_argument( + group_ins.add_argument( "--crosses", action="store_const", const="cross", @@ -107,36 +116,67 @@ def get_parser(): help="shorthand for --spectrum-type=cross", ) - parser.add_argument( - "--sigchain", - default=False, - action="store_true", - help="analyse z-scores for each tile and sum", - ) - # arguments for SSINS.MF - parser.add_argument( + group_mf = parser.add_argument_group("SSINS.MF") + group_mf.add_argument( "--threshold", default=5, type=float, help="match filter significance threshold. 0 disables match filter", ) - parser.add_argument( + group_mutex = group_mf.add_mutually_exclusive_group() + group_mutex.add_argument("--narrow", default=True, help=SUPPRESS) + group_mutex.add_argument( "--no-narrow", - default=False, - help="Don't look for narroband RFI", + action="store_false", + dest="narrow", + help="Don't look for narrowband RFI", + ) + group_mutex = group_mf.add_mutually_exclusive_group() + group_mutex.add_argument("--streak", default=True, help=SUPPRESS) + group_mutex.add_argument( + "--no-streak", + action="store_false", + dest="streak", + help="Don't look for streak RFI", + ) + + # plotting + group_plot = parser.add_argument_group("plotting") + group_plot.add_argument( + "--plot-type", + default="spectrum", + choices=["spectrum", "sigchain", "flags"], + ) + group_plot.add_argument( + "--sigchain", + action="store_const", + const="sigchain", + dest="plot_type", + help="analyse z-scores for each tile and sum", + ) + group_plot.add_argument( + "--flags", + action="store_const", + const="flags", + dest="plot_type", + help="analyse flag occupancy", ) - # other + group_plot.add_argument( + "--cmap", + default="viridis", + help="matplotlib.colormaps.get_cmap, default: viridis", + ) - parser.add_argument( + group_plot.add_argument( "--suffix", default="", type=str, - help="additional text to add to filename", + help="additional text to add to plot output filename", ) + group_plot.add_argument("--fontsize", default=8, help="plot tick label font size") - parser.add_argument("--fontsize", default=8, help="plot font size") return parser @@ -202,14 +242,18 @@ def get_unflagged_ants(ss: UVData, args): all_ant_names = np.array(ss.antenna_names) present_ant_nums = np.unique(ss.ant_1_array) present_ant_mask = np.where(np.isin(all_ant_nums, present_ant_nums))[0] - present_ant_names = np.array([*map(str.upper, all_ant_names[present_ant_mask])]) + + def sanitize(s: str): + return s.upper().strip() + + present_ant_names = np.array([*map(sanitize, all_ant_names[present_ant_mask])]) assert len(present_ant_nums) == len(present_ant_names) if args.sel_ants: - sel_ants = np.array([*map(str.upper, args.sel_ants)]) + sel_ants = np.array([*map(sanitize, args.sel_ants)]) return present_ant_nums[np.where(np.isin(present_ant_names, sel_ants))[0]] elif args.skip_ants: - skip_ants = np.array([*map(str.upper, args.skip_ants)]) + skip_ants = np.array([*map(sanitize, args.skip_ants)]) return present_ant_nums[np.where(~np.isin(present_ant_names, skip_ants))[0]] return present_ant_nums @@ -222,7 +266,7 @@ def get_gps_times(uvd: UVData): def get_suffix(args): suffix = args.suffix suffix = f".{args.spectrum_type}{suffix}" - if not args.no_diff: + if args.diff: suffix = f".diff{suffix}" if len(args.sel_ants) == 1: suffix = f"{suffix}.{args.sel_ants[0]}" @@ -237,7 +281,12 @@ def get_match_filter(ss, args): """ https://ssins.readthedocs.io/en/latest/match_filter.html """ - return MF(ss.freq_array, args.threshold, streak=True, narrow=(not args.no_narrow)) + return MF( + freq_array=ss.freq_array, + sig_thresh=args.threshold, + streak=args.streak, + narrow=args.narrow, + ) # #### # @@ -330,6 +379,10 @@ def slice(scores, axis): ], ) + plt.gcf().set_size_inches( + 8 * len(pols), (len(unflagged_ants) + ss.Nfreqs) * args.fontsize / 72 + ) + def plot_spectrum(ss, args, obsname, suffix, cmap): # incoherent noise spectrum https://ssins.readthedocs.io/en/latest/incoherent_noise_spectrum.html @@ -348,17 +401,73 @@ def plot_spectrum(ss, args, obsname, suffix, cmap): len(pols), sharex=True, sharey=True, - )[ - 1 - ].reshape((2, len(pols))) + )[1] + subplots = subplots.reshape((2, len(pols))) for i, pol in enumerate(pols): - # axis for metric being plotted - ax_met: Axis = subplots[0, i] + ax_mets = [ + ("vis_amps", ins.metric_array[..., i]), + ("z_score", ins.sig_array[..., i]), + ] + + for a, (name, metric) in enumerate(ax_mets): + ax: Axis = subplots[a, i] + ax.set_title(f"{obsname} {name}{suffix} {pol if len(pols) > 1 else ''}") + ax.imshow( + metric, + aspect="auto", + interpolation="none", + cmap=cmap, + extent=[ + np.min(freqs_mhz), + np.max(freqs_mhz), + np.max(gps_times), + np.min(gps_times), + ], + ) + + if i == 0: + ax.set_ylabel("GPS Time [s]") + + if a == len(ax_mets) - 1: + ax.set_xlabel("Frequency channel [MHz]") + + plt.gcf().set_size_inches(8 * len(pols), 16) + + +def plot_flags(ss: UVData, args, obsname, suffix, cmap): + pols = ss.get_pols() + gps_times = get_gps_times(ss) + freqs_mhz = (ss.freq_array) / 1e6 + + occupancy = np.sum( + ss.flag_array.reshape(ss.Ntimes, ss.Nbls, ss.Nspws, ss.Nfreqs, len(pols)), + axis=(1, 2), + ).astype(np.float64) + full_occupancy_value = ss.Nbls * ss.Nspws + occupancy[occupancy == full_occupancy_value] = np.nan + max_occupancy = np.nanmax(occupancy) + print(f"{max_occupancy=}") + # clip at half occupancy + # occupancy[occupancy <= full_occupancy_value / 2] = full_occupancy_value / 2 - ax_met.set_title(f"{obsname} vis amps{suffix} {pol if len(pols) > 1 else ''}") - ax_met.imshow( - ins.metric_array[..., i], + occupancy /= full_occupancy_value + + subplots = plt.subplots( + len(pols), + 1, + sharex=True, + sharey=True, + )[1] + if len(pols) == 1: + subplots = [subplots] + + for i, pol in enumerate(pols): + ax: Axis = subplots[i] + + ax.set_title(f"{obsname} occupancy{suffix} {pol if len(pols) > 1 else ''}") + ax.imshow( + occupancy[..., i], aspect="auto", interpolation="none", cmap=cmap, @@ -370,26 +479,65 @@ def plot_spectrum(ss, args, obsname, suffix, cmap): ], ) - # axis for significance - ax_sig: Axis = subplots[1, i] - ax_sig.set_title(f"{obsname} z-score{suffix} {pol if len(pols) > 1 else ''}") - ax_sig.imshow( - ins.sig_array[..., i], - aspect="auto", - interpolation="none", - cmap=cmap, - extent=[ - np.min(freqs_mhz), - np.max(freqs_mhz), - np.max(gps_times), - np.min(gps_times), - ], + ax.set_ylabel("GPS Time [s]") + + if i == len(pols) - 1: + ax.set_xlabel("Frequency channel [MHz]") + + plt.gcf().set_size_inches(16, np.min([9, 4 * len(pols)])) + + +def du_bs(path: Path, bs=1024 * 1024): + if path.is_file(): + return path.stat().st_size / bs + return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / bs + + +def read_raw(uvd: UVData, metafits, raw_fits, read_kwargs): + file_sizes_mb = {f: du_bs(Path(f)) for f in raw_fits} + total_size_mb = sum(file_sizes_mb.values()) + print(f"reading total {int(total_size_mb)}MB of raw files") + + start = time.time() + if len(raw_fits) <= 1: + uvd.read([metafits, *raw_fits], read_data=True, **read_kwargs) + read_time = time.time() - start + print(f"read took {int(read_time)}s. {int(total_size_mb/read_time)} MB/s") + return uvd + + # group and read raw by channel to save memory + raw_channel_groups = group_raw_by_channel(metafits, raw_fits) + # channels not always aligned in time + times = mwalib_get_common_times(metafits, raw_fits) + time_array = times.jd.astype(float) + n_chs = len(raw_channel_groups) + for ch_idx, ch in enumerate(sorted([*raw_channel_groups.keys()])): + # make a new UVData object, the same type as + uvd_ = type(uvd)() + channel_raw_fits = raw_channel_groups[ch] + channel_size_mb = sum([file_sizes_mb[f] for f in channel_raw_fits]) + print( + f"reading channel {ch}: {int(channel_size_mb)}MB of raw files ({ch_idx+1} of {n_chs})" ) - if i == 0: - ax_met.set_ylabel("GPS Time [s]") - ax_sig.set_ylabel("GPS Time [s]") + ch_start = time.time() + uvd_.read( + [metafits, *raw_channel_groups[ch]], + read_data=True, + times=time_array, + **read_kwargs, + ) + read_time = time.time() - ch_start + print( + f"reading channel {ch} took {int(read_time)}s. {int(channel_size_mb/read_time)} MB/s" + ) + if uvd.data_array is None: + uvd = uvd_ + else: + uvd.__add__(uvd_, inplace=True) - ax_sig.set_xlabel("Frequency channel [MHz]") + read_time = time.time() - start + print(f"read took {int(read_time)}s. {int(total_size_mb/read_time)} MB/s") + return uvd def main(): @@ -402,21 +550,24 @@ def main(): print(f"reading from {file_groups=}") # sky-subtract https://ssins.readthedocs.io/en/latest/sky_subtract.html ss = SS() + + flag_choice = args.flag_choice + if type(flag_choice) is list: + flag_choice = flag_choice[0] read_kwargs = { - "diff": (not args.no_diff), # difference timesteps + "diff": args.diff, # difference timesteps "remove_coarse_band": args.remove_coarse_band, # does not work with low freq res "correct_van_vleck": args.correct_van_vleck, # slow - "remove_flagged_ants": (not args.include_flagged_ants), # remove flagged ants - "flag_init": (not args.no_flag_init), + "remove_flagged_ants": args.remove_flagged_ants, # remove flagged ants + "flag_init": args.flag_init, "ant_str": args.spectrum_type, - "flag_choice": args.flag_choice, + "flag_choice": flag_choice, } # output name is basename of metafits, first uvfits or first ms if provided base = None # metafits and mwaf flag files only used if raw fits supplied - metafits = None - raw_fits = None + other_types = set(file_groups.keys()) - set([".fits", ".metafits"]) if ".fits" in file_groups: if ".metafits" not in file_groups: raise UserWarning(f"fits supplied, but no metafits in {args.files}") @@ -424,64 +575,51 @@ def main(): raise UserWarning(f"multiple metafits supplied in {args.files}") metafits = file_groups[".metafits"][0] base, _ = splitext(metafits) - raw_fits = file_groups[".fits"] - if len(raw_fits) > 1: - times = mwalib_get_common_times(metafits, raw_fits) - time_array = times.jd.astype(float) - # group and read raw by channel to save memory - raw_channel_groups = group_raw_by_channel(metafits, raw_fits) - for ch in sorted([*raw_channel_groups.keys()]): - ss_ = type(ss)() - ss_.read( - [metafits, *raw_channel_groups[ch]], - read_data=True, - times=time_array, - **read_kwargs, - ) - if ss.data_array is None: - ss = ss_ - else: - ss.__add__(ss_, inplace=True) - else: - ss.read([metafits, *raw_fits], read_data=True, **read_kwargs) - elif ".uvfits" in file_groups and ".ms" in file_groups: - raise UserWarning(f"both ms and uvfits in {args.files}") - elif ".uvfits" in file_groups or ".ms" in file_groups: - vis = file_groups.get(".uvfits", []) + file_groups.get(".ms", []) + ss = read_raw(ss, metafits, file_groups[".fits"], read_kwargs) + elif len(other_types) > 1: + raise UserWarning(f"multiple file types found ({[*other_types]}) {args.files}") + elif len(other_types.intersection([".uvfits", ".uvh5", ".ms"])) == 1: + vis = sum( + [ + file_groups.get(".uvfits", []), + file_groups.get(".uvh5", []), + file_groups.get(".ms", []), + ], + start=[], + ) base, _ = os.path.splitext(vis[0]) + + total_size_mb = sum(du_bs(Path(f)) for f in vis) + print(f"reading total {int(total_size_mb)}MB") + start = time.time() ss.read(vis, read_data=True, **read_kwargs) + read_time = time.time() - start + print(f"read took {int(read_time)}s. {int(total_size_mb/read_time)} MB/s") else: parser.print_usage() exit(1) - unflagged_ants = get_unflagged_ants(ss, args) - select_kwargs = {} if args.sel_pols: select_kwargs["polarizations"] = args.sel_pols ss.select(inplace=True, **select_kwargs) - ss.apply_flags(flag_choice=args.flag_choice) + # TODO: ss.apply_flags(flag_choice=flag_choice) ? plt.style.use("dark_background") - cmap = mpl.colormaps.get_cmap("viridis") + cmap = mpl.colormaps.get_cmap(args.cmap) cmap.set_bad(color="#00000000") suffix = get_suffix(args) - - pols = ss.get_pols() obsname = base.split("/")[-1] plt.suptitle(f"{obsname}{suffix}") - if args.sigchain: + if args.plot_type == "sigchain": plot_sigchain(ss, args, obsname, suffix, cmap) - plt.gcf().set_size_inches( - 8 * len(pols), (len(unflagged_ants) + ss.Nfreqs) * args.fontsize / 72 - ) - figname = f"{base}{suffix}.sigchain.png" - else: + elif args.plot_type == "spectrum": plot_spectrum(ss, args, obsname, suffix, cmap) - plt.gcf().set_size_inches(8 * len(pols), 16) - figname = f"{base}{suffix}.spectrum.png" + elif args.plot_type == "flags": + plot_flags(ss, args, obsname, suffix, cmap) + figname = f"{base}{suffix}.{args.plot_type}.png" plt.savefig(figname, bbox_inches="tight") print(f"wrote {figname}")