Skip to content

Commit

Permalink
bench.py: update plotting for paper
Browse files Browse the repository at this point in the history
  • Loading branch information
lgarrison committed Sep 10, 2024
1 parent 1a7a689 commit 4330088
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 48 deletions.
82 changes: 49 additions & 33 deletions scripts/bench.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from pathlib import Path

import matplotlib
Expand All @@ -22,14 +23,19 @@
DEFAULT_NF = None # 10**5
DEFAULT_DTYPE = 'f8'
DEFAULT_METHODS = ['cufinufft', 'finufft', 'astropy', 'finufft_par']
NTHREAD_MAX = nifty_ls.utils.get_avail_cpus()
NTHREAD_MAX = len(os.sched_getaffinity(0))
DEFAULT_FFTW = nifty_ls.finufft.FFTW_MEASURE
DEFAULT_EPS = 1e-9

plt.rcParams['font.family'] = 'serif'
plt.rcParams['mathtext.fontset'] = 'dejavuserif'


def do_nifty_finufft(*args, **kwargs):
return nifty_ls.finufft.lombscargle(
*args, **kwargs, finufft_kwargs={'fftw': DEFAULT_FFTW, 'eps': DEFAULT_EPS}
*args,
**kwargs,
finufft_kwargs={'fftw': DEFAULT_FFTW, 'eps': DEFAULT_EPS},
)


Expand Down Expand Up @@ -95,7 +101,7 @@ def run_one(
f0, df, Nf = nifty_ls.utils.validate_frequency_grid(None, fmax, Nf, t)

def func():
return METHODS[method](t, y, dy, f0, df, Nf, **kwargs)
return METHODS[method](t=t, y=y, dy=dy, fmin=f0, df=df, Nf=Nf, **kwargs)

res = {
'method': method,
Expand Down Expand Up @@ -123,9 +129,7 @@ def func():
def get_plot_kwargs(method, nthread_max=NTHREAD_MAX):
if method == 'finufft_par':
label = (
'nifty-ls (finufft)'
if nthread_max == 1
else 'nifty-ls (finufft, multi-threaded)'
'nifty-ls (finufft)' if nthread_max == 1 else 'nifty-ls (finufft, parallel)'
)
color = 'C1'
ls = '--'
Expand All @@ -138,13 +142,13 @@ def get_plot_kwargs(method, nthread_max=NTHREAD_MAX):
color = 'C2'
ls = '-'
elif method == 'astropy':
label = r'astropy (${\tt fast}$ method)'
label = r'Astropy (${\tt fast}$ method)'
color = 'C0'
ls = '-'
elif method == 'astropy_worst':
label = r'astropy (worst case)'
label = r'Astropy (worst case)'
color = 'C0'
ls = '--'
ls = ':'
else:
label = method
color = 'C3'
Expand Down Expand Up @@ -187,6 +191,12 @@ def cli():
@click.option(
'-o', '--results-file', default='bench_results.ecsv', help='File to save results to'
)
@click.option(
'-t',
'--nthread-max',
default=NTHREAD_MAX,
help='Maximum number of threads to use for parallel finufft',
)
def bench(
logmin,
logmax,
Expand All @@ -196,6 +206,7 @@ def bench(
sweep='Nf',
results_file='bench_results.ecsv',
batch_size=1,
nthread_max=NTHREAD_MAX,
):
# process args
dtype = np.dtype(dtype).type
Expand Down Expand Up @@ -223,19 +234,20 @@ def bench(
print(f'{method} took {res["time"]:.4g} sec ({Nf=})')
all_res.append(res)

for res in all_res:
del res['power']

all_res = Table(
all_res,
meta={
'sweep': sweep,
'nthread_max': NTHREAD_MAX,
'nthread_max': nthread_max,
'auto_Nf': DEFAULT_NF is None and sweep == 'N',
'batch_size': batch_size,
},
)

# compare(all_res)

del all_res['power']
all_res.write(results_file, overwrite=True)

plot_fname = Path(results_file).with_suffix('.png')
Expand All @@ -244,15 +256,17 @@ def bench(

@cli.command()
@click.argument('results_file')
def plot(results_file):
@click.option('--paper', is_flag=True, help='Use paper style')
def plot(results_file, paper=False):
results = ascii.read(results_file)
plot_fname = Path(results_file).with_suffix('.png')
_plot(results, fname=plot_fname)
_plot(results, fname=plot_fname, paper=paper)


def _plot(all_res: Table, sort=True, fname='bench_results.png'):
def _plot(all_res: Table, sort=True, fname='bench_results.png', paper=False):
all_res = all_res.group_by('method')
fig, ax = plt.subplots()
figsize = (w := 3.8, w / 1.3) if paper else (6, 4)
fig, ax = plt.subplots(figsize=figsize, layout='constrained')
ax: plt.Axes
sweep = all_res.meta['sweep']
const = 'N' if sweep == 'Nf' else 'Nf'
Expand All @@ -276,20 +290,21 @@ def _plot(all_res: Table, sort=True, fname='bench_results.png'):
if all_res.meta['batch_size'] > 1:
lines.append(f'Batch size = {all_res.meta["batch_size"]}')

if auto_Nf and var == 'N':
lines.append(f'{const_desc} $\\approx 12{var}$')
else:
lines.append(
f'{const_desc}: {eval(const+"_const")}',
if not paper:
if auto_Nf and var == 'N':
lines.append(f'{const_desc} $\\approx 12{var}$')
else:
lines.append(
f'{const_desc}: {eval(const+"_const")}',
)
lines = '\n'.join(lines)
ax.annotate(
lines,
xy=(0.99, 0.01),
xycoords='axes fraction',
ha='right',
va='bottom',
)
lines = '\n'.join(lines)
ax.annotate(
lines,
xy=(0.99, 0.01),
xycoords='axes fraction',
ha='right',
va='bottom',
)

xline = all_res[var].max()
yline = all_res['time'][all_res[var] == xline].min()
Expand All @@ -300,15 +315,16 @@ def _plot(all_res: Table, sort=True, fname='bench_results.png'):

ax.tick_params(right=True, top=True, which='both')

ax.legend()
ax.legend(fontsize='small', frameon=False)
if sweep == 'Nf':
ax.set_xlabel('Number of frequencies')
ax.set_xlabel('$N_f$')
elif sweep == 'N':
ax.set_xlabel('Number of data points $N$')
ax.set_xlabel('$N_d$')
ax.set_ylabel('Time (seconds)')
ax.set_xscale('log')
ax.set_yscale('log')
fig.savefig(fname, bbox_inches='tight')
fig.savefig(fname)
fig.savefig(Path(fname).with_suffix('.pdf'))


if __name__ == '__main__':
Expand Down
43 changes: 28 additions & 15 deletions scripts/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,15 @@ def run(dtype, methods, ref, output_file):

@cli.command()
@click.argument('results_file')
def analyze(results_file):
@click.option('--paper', is_flag=True, help='Generate plots for paper')
def analyze(results_file, paper=False):
with asdf.open(results_file, lazy_load=False, copy_arrays=True) as af:
tables = af['data']
plot_fname = Path(results_file).with_suffix('.png')
_analyze(tables, fname=plot_fname)
_analyze(tables, fname=plot_fname, paper=paper)


def _analyze(all_tables, fname, plot=True):
def _analyze(all_tables, fname, plot=True, paper=False):
# dtype = all_tables[0]['dtype'][0]

# each table is one N/Nf value
Expand Down Expand Up @@ -145,12 +146,13 @@ def _analyze(all_tables, fname, plot=True):

if plot:
# _plot_box(all_info, fname)
_plot_99(all_info, fname)
_plot_99(all_info, fname, paper=paper)


def _plot_99(all_info, fname):
def _plot_99(all_info, fname, paper=False):
"""Plot the median and 99th percentile error"""
fig, ax = plt.subplots()
figsize = (w := 3.8, w / 1.3) if paper else (6, 4)
fig, ax = plt.subplots(figsize=figsize, layout='constrained')
ax: plt.Axes
ax.set_yscale('log')
ax.set_xscale('log')
Expand All @@ -163,7 +165,7 @@ def _plot_99(all_info, fname):

lo = [info['bxp_stats']['med'] for info in all_info]
hi = [info['bxp_stats']['whishi'] for info in all_info]
o = 0.1
o = 0.125
offsets = {
'cufinufft': 1.0,
'finufft': 1.0 - o,
Expand All @@ -181,7 +183,7 @@ def _plot_99(all_info, fname):
ax.plot(x, lo[i], 'o', **plot_kwargs)
ax.plot(x, hi[i], 'v', **plot_kwargs)

ax.set_ylim(1e-15)
ax.set_ylim(1e-17)

# Make dummy entries for the v and o symbols
cap1 = ax.plot([], [], 'o', color='black', label='Median')
Expand All @@ -197,23 +199,35 @@ def _plot_99(all_info, fname):
lines['cufinufft'],
],
[
'Median & 99th percentile',
'astropy',
'astropy (worst case)',
'50th & 99th percentile',
'Astropy',
'Astropy (worst case)',
'finufft',
'cufinufft',
],
ncol=1,
ncol=2 if paper else 1,
loc='lower right',
handler_map={tuple: HandlerTuple(ndivide=None)},
fontsize='small',
)

ax.set_xlabel('Number of data points $N$')
ax.set_xlabel('$N_d$')
ax.set_ylabel('Fractional error')
ax.tick_params(right=True, top=True, which='both')
logymaxtick = int(np.log10(ax.get_ylim()[1]))
logymintick = int(np.ceil(np.log10(ax.get_ylim()[0])))
numyticks = logymaxtick - logymintick
ticks = np.logspace(logymintick, logymaxtick, numyticks + 1)
ax.set_yticks(
ticks,
labels=[
f'$10^{{{np.log10(t):.0f}}}$' if i % 3 == 0 else ''
for i, t in enumerate(ticks)
],
)

fig.tight_layout()
fig.savefig(fname)
fig.savefig(Path(fname).with_suffix('.pdf'))


def _plot_box(all_info, fname):
Expand All @@ -236,7 +250,6 @@ def _plot_box(all_info, fname):
# ax.set_xticklabels([f'{method}\nN={N}, Nf={Nf}' for method, N, Nf in zip(methods, Ns, Nfs)], rotation=45)
ax.set_ylabel('Fractional error')

fig.tight_layout()
fig.savefig(fname)


Expand Down

0 comments on commit 4330088

Please sign in to comment.