From eef63576c528e9ea016f539d576d242e35148b45 Mon Sep 17 00:00:00 2001 From: Lehman Garrison Date: Thu, 29 Aug 2024 14:54:14 -0400 Subject: [PATCH] bench.py: update plotting for paper --- scripts/bench.py | 82 +++++++++++++++++++++++++++++------------------- scripts/sweep.py | 43 ++++++++++++++++--------- 2 files changed, 77 insertions(+), 48 deletions(-) diff --git a/scripts/bench.py b/scripts/bench.py index e3e513a..d33cd29 100644 --- a/scripts/bench.py +++ b/scripts/bench.py @@ -1,3 +1,4 @@ +import os from pathlib import Path import matplotlib @@ -22,14 +23,19 @@ DEFAULT_NF = None # 10**5 DEFAULT_DTYPE = 'f8' DEFAULT_METHODS = ['cufinufft', 'finufft', 'astropy', 'finufft_par'] -NTHREAD_MAX = nifty_ls.utils.get_avail_cpus() +NTHREAD_MAX = len(os.sched_getaffinity(0)) DEFAULT_FFTW = nifty_ls.finufft.FFTW_MEASURE DEFAULT_EPS = 1e-9 +plt.rcParams['font.family'] = 'serif' +plt.rcParams['mathtext.fontset'] = 'dejavuserif' + def do_nifty_finufft(*args, **kwargs): return nifty_ls.finufft.lombscargle( - *args, **kwargs, finufft_kwargs={'fftw': DEFAULT_FFTW, 'eps': DEFAULT_EPS} + *args, + **kwargs, + finufft_kwargs={'fftw': DEFAULT_FFTW, 'eps': DEFAULT_EPS}, ) @@ -95,7 +101,7 @@ def run_one( f0, df, Nf = nifty_ls.utils.validate_frequency_grid(None, fmax, Nf, t) def func(): - return METHODS[method](t, y, dy, f0, df, Nf, **kwargs) + return METHODS[method](t=t, y=y, dy=dy, fmin=f0, df=df, Nf=Nf, **kwargs) res = { 'method': method, @@ -123,9 +129,7 @@ def func(): def get_plot_kwargs(method, nthread_max=NTHREAD_MAX): if method == 'finufft_par': label = ( - 'nifty-ls (finufft)' - if nthread_max == 1 - else 'nifty-ls (finufft, multi-threaded)' + 'nifty-ls (finufft)' if nthread_max == 1 else 'nifty-ls (finufft, parallel)' ) color = 'C1' ls = '--' @@ -138,13 +142,13 @@ def get_plot_kwargs(method, nthread_max=NTHREAD_MAX): color = 'C2' ls = '-' elif method == 'astropy': - label = r'astropy (${\tt fast}$ method)' + label = r'Astropy (${\tt fast}$ method)' color = 'C0' ls = '-' elif method == 'astropy_worst': - label = r'astropy (worst case)' + label = r'Astropy (worst case)' color = 'C0' - ls = '--' + ls = ':' else: label = method color = 'C3' @@ -187,6 +191,12 @@ def cli(): @click.option( '-o', '--results-file', default='bench_results.ecsv', help='File to save results to' ) +@click.option( + '-t', + '--nthread-max', + default=NTHREAD_MAX, + help='Maximum number of threads to use for parallel finufft', +) def bench( logmin, logmax, @@ -196,6 +206,7 @@ def bench( sweep='Nf', results_file='bench_results.ecsv', batch_size=1, + nthread_max=NTHREAD_MAX, ): # process args dtype = np.dtype(dtype).type @@ -223,19 +234,20 @@ def bench( print(f'{method} took {res["time"]:.4g} sec ({Nf=})') all_res.append(res) + for res in all_res: + del res['power'] + all_res = Table( all_res, meta={ 'sweep': sweep, - 'nthread_max': NTHREAD_MAX, + 'nthread_max': nthread_max, 'auto_Nf': DEFAULT_NF is None and sweep == 'N', 'batch_size': batch_size, }, ) # compare(all_res) - - del all_res['power'] all_res.write(results_file, overwrite=True) plot_fname = Path(results_file).with_suffix('.png') @@ -244,15 +256,17 @@ def bench( @cli.command() @click.argument('results_file') -def plot(results_file): +@click.option('--paper', is_flag=True, help='Use paper style') +def plot(results_file, paper=False): results = ascii.read(results_file) plot_fname = Path(results_file).with_suffix('.png') - _plot(results, fname=plot_fname) + _plot(results, fname=plot_fname, paper=paper) -def _plot(all_res: Table, sort=True, fname='bench_results.png'): +def _plot(all_res: Table, sort=True, fname='bench_results.png', paper=False): all_res = all_res.group_by('method') - fig, ax = plt.subplots() + figsize = (w := 3.8, w / 1.3) if paper else (6, 4) + fig, ax = plt.subplots(figsize=figsize, layout='constrained') ax: plt.Axes sweep = all_res.meta['sweep'] const = 'N' if sweep == 'Nf' else 'Nf' @@ -276,20 +290,21 @@ def _plot(all_res: Table, sort=True, fname='bench_results.png'): if all_res.meta['batch_size'] > 1: lines.append(f'Batch size = {all_res.meta["batch_size"]}') - if auto_Nf and var == 'N': - lines.append(f'{const_desc} $\\approx 12{var}$') - else: - lines.append( - f'{const_desc}: {eval(const+"_const")}', + if not paper: + if auto_Nf and var == 'N': + lines.append(f'{const_desc} $\\approx 12{var}$') + else: + lines.append( + f'{const_desc}: {eval(const+"_const")}', + ) + lines = '\n'.join(lines) + ax.annotate( + lines, + xy=(0.99, 0.01), + xycoords='axes fraction', + ha='right', + va='bottom', ) - lines = '\n'.join(lines) - ax.annotate( - lines, - xy=(0.99, 0.01), - xycoords='axes fraction', - ha='right', - va='bottom', - ) xline = all_res[var].max() yline = all_res['time'][all_res[var] == xline].min() @@ -300,15 +315,16 @@ def _plot(all_res: Table, sort=True, fname='bench_results.png'): ax.tick_params(right=True, top=True, which='both') - ax.legend() + ax.legend(fontsize='small', frameon=False) if sweep == 'Nf': - ax.set_xlabel('Number of frequencies') + ax.set_xlabel('$N_f$') elif sweep == 'N': - ax.set_xlabel('Number of data points $N$') + ax.set_xlabel('$N_d$') ax.set_ylabel('Time (seconds)') ax.set_xscale('log') ax.set_yscale('log') - fig.savefig(fname, bbox_inches='tight') + fig.savefig(fname) + fig.savefig(Path(fname).with_suffix('.pdf')) if __name__ == '__main__': diff --git a/scripts/sweep.py b/scripts/sweep.py index 8934aa6..49c0679 100644 --- a/scripts/sweep.py +++ b/scripts/sweep.py @@ -81,14 +81,15 @@ def run(dtype, methods, ref, output_file): @cli.command() @click.argument('results_file') -def analyze(results_file): +@click.option('--paper', is_flag=True, help='Generate plots for paper') +def analyze(results_file, paper=False): with asdf.open(results_file, lazy_load=False, copy_arrays=True) as af: tables = af['data'] plot_fname = Path(results_file).with_suffix('.png') - _analyze(tables, fname=plot_fname) + _analyze(tables, fname=plot_fname, paper=paper) -def _analyze(all_tables, fname, plot=True): +def _analyze(all_tables, fname, plot=True, paper=False): # dtype = all_tables[0]['dtype'][0] # each table is one N/Nf value @@ -145,12 +146,13 @@ def _analyze(all_tables, fname, plot=True): if plot: # _plot_box(all_info, fname) - _plot_99(all_info, fname) + _plot_99(all_info, fname, paper=paper) -def _plot_99(all_info, fname): +def _plot_99(all_info, fname, paper=False): """Plot the median and 99th percentile error""" - fig, ax = plt.subplots() + figsize = (w := 3.8, w / 1.3) if paper else (6, 4) + fig, ax = plt.subplots(figsize=figsize, layout='constrained') ax: plt.Axes ax.set_yscale('log') ax.set_xscale('log') @@ -163,7 +165,7 @@ def _plot_99(all_info, fname): lo = [info['bxp_stats']['med'] for info in all_info] hi = [info['bxp_stats']['whishi'] for info in all_info] - o = 0.1 + o = 0.125 offsets = { 'cufinufft': 1.0, 'finufft': 1.0 - o, @@ -181,7 +183,7 @@ def _plot_99(all_info, fname): ax.plot(x, lo[i], 'o', **plot_kwargs) ax.plot(x, hi[i], 'v', **plot_kwargs) - ax.set_ylim(1e-15) + ax.set_ylim(1e-17) # Make dummy entries for the v and o symbols cap1 = ax.plot([], [], 'o', color='black', label='Median') @@ -197,23 +199,35 @@ def _plot_99(all_info, fname): lines['cufinufft'], ], [ - 'Median & 99th percentile', - 'astropy', - 'astropy (worst case)', + '50th & 99th percentile', + 'Astropy', + 'Astropy (worst case)', 'finufft', 'cufinufft', ], - ncol=1, + ncol=2 if paper else 1, loc='lower right', handler_map={tuple: HandlerTuple(ndivide=None)}, fontsize='small', ) - ax.set_xlabel('Number of data points $N$') + ax.set_xlabel('$N_d$') ax.set_ylabel('Fractional error') + ax.tick_params(right=True, top=True, which='both') + logymaxtick = int(np.log10(ax.get_ylim()[1])) + logymintick = int(np.ceil(np.log10(ax.get_ylim()[0]))) + numyticks = logymaxtick - logymintick + ticks = np.logspace(logymintick, logymaxtick, numyticks + 1) + ax.set_yticks( + ticks, + labels=[ + f'$10^{{{np.log10(t):.0f}}}$' if i % 3 == 0 else '' + for i, t in enumerate(ticks) + ], + ) - fig.tight_layout() fig.savefig(fname) + fig.savefig(Path(fname).with_suffix('.pdf')) def _plot_box(all_info, fname): @@ -236,7 +250,6 @@ def _plot_box(all_info, fname): # ax.set_xticklabels([f'{method}\nN={N}, Nf={Nf}' for method, N, Nf in zip(methods, Ns, Nfs)], rotation=45) ax.set_ylabel('Fractional error') - fig.tight_layout() fig.savefig(fname)