diff --git a/demo/04_ssins.py b/demo/04_ssins.py index f91b109..deaf3db 100755 --- a/demo/04_ssins.py +++ b/demo/04_ssins.py @@ -478,9 +478,18 @@ def plot_spectrum(ss, args, obsname, suffix, cmap): mf = get_match_filter(ss, args) apply_match_test(mf, ins, args) - ins.sig_array[~np.isfinite(ins.sig_array)] = 0 - + ins.sig_array[~np.isfinite(ins.sig_array)] = np.nan + ins.metric_array[~np.isfinite(ins.sig_array)] = np.nan pols = ss.get_pols() + + # occupancy = np.sum( + # ss.flag_array.reshape(ss.Ntimes, ss.Nbls, ss.Nspws, ss.Nfreqs, len(pols)), + # axis=(1, 2), + # ).astype(np.float64) + # full_occupancy_value = np.max(occupancy) + # ins.sig_array[occupancy > full_occupancy_value / 2] = np.nan + # ins.metric_array[occupancy > full_occupancy_value / 2] = np.nan + gps_times = get_gps_times(ss) freqs_mhz = (ss.freq_array) / 1e6 channames = [f"{ch: 8.4f}" for ch in freqs_mhz] @@ -533,6 +542,8 @@ def plot_spectrum(ss, args, obsname, suffix, cmap): plt.gcf().set_size_inches(8 * len(pols), 16) + return ins + def plot_flags(ss: UVData, args, obsname, suffix, cmap): pols = ss.get_pols() @@ -786,7 +797,11 @@ def main(): if args.plot_type == "sigchain": plot_sigchain(ss, args, obsname, suffix, cmap) elif args.plot_type == "spectrum": - plot_spectrum(ss, args, obsname, suffix, cmap) + ins = plot_spectrum(ss, args, obsname, suffix, cmap) + maskname = f"{base}{suffix}" + ins.write(f"{base}{suffix}", output_type="mask", clobber=True) + print(f"wrote {maskname}_SSINS_mask.h5") + elif args.plot_type == "flags": plot_flags(ss, args, obsname, suffix, cmap)