Skip to content

Commit

Permalink
add script to process sota; add interactive leaderboard
Browse files Browse the repository at this point in the history
  • Loading branch information
juannat7 committed Nov 12, 2024
1 parent 3eea5fc commit c064285
Show file tree
Hide file tree
Showing 17 changed files with 355 additions and 14 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ ChaosBench is a benchmark project to improve and extend the predictability range

## Benchmarking
- [Baseline Models](https://leap-stc.github.io/ChaosBench/baseline.html)
- [Leaderboard](https://leap-stc.github.io/ChaosBench/leaderboard.html)

## Citation
If you find any of the code and dataset useful, feel free to acknowledge our work through:
Expand Down
Binary file modified docs/center_acc.pdf
Binary file not shown.
Binary file modified docs/center_ens_acc.pdf
Binary file not shown.
Binary file modified docs/center_ratio_acc.pdf
Binary file not shown.
Binary file modified docs/sota_acc.pdf
Binary file not shown.
Binary file modified docs/sota_bias.pdf
Binary file not shown.
Binary file modified docs/sota_rmse.pdf
Binary file not shown.
Binary file modified docs/sota_sdiv.pdf
Binary file not shown.
Binary file modified docs/sota_ssim.pdf
Binary file not shown.
79 changes: 69 additions & 10 deletions notebooks/04a_s2s_eval_iter.ipynb

Large diffs are not rendered by default.

136 changes: 136 additions & 0 deletions scripts/generate_interactive_html.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import pandas as pd
import numpy as np
import plotly.graph_objs as go
from pathlib import Path


def generate_trace(df, model_name, metric, headline_var, color):
    """Generate a single line trace for a model, metric, and variable.

    Args:
        df: DataFrame of scores; must contain a column named ``headline_var``.
            Its rows are plotted against 1..N on the x-axis.
        model_name: Model label used in the legend entry.
        metric: Metric identifier (e.g. ``'rmse'``); part of the trace tag.
        headline_var: Column to plot (e.g. ``'t-850'``); part of the trace tag.
        color: Line color passed through to Plotly.

    Returns:
        A ``go.Scatter`` tagged with ``"<metric>-<headline_var>"`` in
        ``customdata[0]``, which the dropdown logic uses to toggle visibility.
        Only the ``rmse`` / ``t-850`` traces start out visible.
    """
    # ``squeeze`` collapses a one-row column to a 0-d array, which is not a
    # sequence of y values; ``atleast_1d`` keeps the result 1-D in that case
    # and leaves every longer column untouched.
    y = np.atleast_1d(df[headline_var].to_numpy().squeeze())
    return go.Scatter(
        x=np.arange(1, len(df) + 1),
        y=y,
        mode='lines',
        name=f"{model_name} vs. ERA5",
        customdata=[f"{metric}-{headline_var}"],
        line=dict(width=3, color=color),
        visible=(metric == 'rmse' and headline_var == 't-850')
    )


def update_visibility(fig, selected_metric, selected_variable):
    """Return the trace visibility mask for a selected metric and variable.

    Args:
        fig: Figure-like object whose ``data`` is an iterable of traces; each
            trace is expected to carry a ``"<metric>-<variable>"`` tag in
            ``customdata[0]``.
        selected_metric: Metric identifier to show (e.g. ``'rmse'``).
        selected_variable: Variable identifier to show (e.g. ``'t-850'``).

    Returns:
        A list with one boolean per trace, in ``fig.data`` order; ``True``
        where the trace tag equals ``"<selected_metric>-<selected_variable>"``.
    """
    selected_combo = f"{selected_metric}-{selected_variable}"
    mask = []
    for trace in fig.data:
        tag = trace.customdata
        # A trace without a tag (customdata unset or empty) can never match;
        # hide it instead of failing on the index lookup.
        has_tag = tag is not None and len(tag) > 0
        mask.append(bool(has_tag and tag[0] == selected_combo))
    return mask


def create_dropdown_buttons(metrics, headline_vars, fig):
    """Create dropdown buttons for every metric-variable combination.

    Args:
        metrics: Iterable of metric identifiers (e.g. ``['rmse', 'acc']``).
        headline_vars: Iterable of variable names; a dict keyed by variable
            name works, as does a plain sequence.
        fig: Figure whose current traces determine each button's visibility
            mask, so all traces must be added before calling this.

    Returns:
        A list of Plotly ``update`` button specs, ordered metric-major, each
        labelled ``"<METRIC> - <Variable>"``.
    """
    buttons = []
    for metric in metrics:
        # Iterate the container itself (as plot_metrics does) rather than
        # calling ``.keys()``, so dicts and plain sequences are both accepted.
        for var in headline_vars:
            buttons.append({
                "method": "update",
                "label": f"{metric.upper()} - {var.capitalize()}",
                "args": [
                    {"visible": update_visibility(fig, metric, var)}
                ],
            })
    return buttons


def configure_layout(fig, title, metrics, headline_vars):
    """Attach the metric/variable dropdown and the axis/title settings to ``fig``.

    The dropdown holds one button per metric-variable pair and is anchored by
    its top-left corner at x=0, y=1.1 in the figure's layout coordinates.
    The figure is modified in place; nothing is returned.
    """
    selector_menu = dict(
        buttons=create_dropdown_buttons(metrics, headline_vars, fig),
        direction="down",
        showactive=True,
        x=0,
        xanchor="left",
        y=1.1,
        yanchor="top",
        name="Metric-Variable",
    )
    fig.update_layout(
        updatemenus=[selector_menu],
        title=title,
        xaxis_title="Number of Days Ahead",
        hovermode="x unified",
    )


def save_figure(fig, output_dir, filename):
    """Write ``fig`` to ``output_dir / filename`` through ``fig.write_html``."""
    fig.write_html(output_dir / filename)


def plot_metrics(metrics, headline_vars, model_names, data_path, output_dir, title, filename, ensemble=False):
    """Generate and save an interactive plot for the given metrics, models, and variables.

    Scores are read from
    ``<data_path>/<model>[_ensemble]/eval/<metric>_<model>.csv``; a
    metric/model pair whose CSV does not exist is skipped.

    Args:
        metrics: Iterable of metric identifiers (e.g. ``['rmse', 'acc']``).
        headline_vars: Iterable of CSV column names to plot (a dict keyed by
            variable name works).
        model_names: Model names; each model keeps one color across all of
            its traces, cycling through the palette if there are more models
            than colors.
        data_path: ``Path`` to the directory holding the per-model folders.
        output_dir: ``Path`` to the directory the HTML file is written to.
        title: Figure title.
        filename: Name of the HTML file to write.
        ensemble: When true, read from the ``<model>_ensemble`` folders.
    """
    fig = go.Figure()
    linecolors = [
        'black', '#1f77b4', '#ff7f0e', '#2ca02c',
        '#d62728', '#9467bd', '#8c564b', '#e377c2'
    ]
    dir_suffix = "_ensemble" if ensemble else ""

    for metric in metrics:
        for model_idx, model_name in enumerate(model_names):
            # The CSV depends only on (metric, model): locate and read it once
            # here rather than once per headline variable.
            csv_path = data_path / f"{model_name}{dir_suffix}/eval/{metric}_{model_name}.csv"
            if not csv_path.exists():
                continue
            df = pd.read_csv(csv_path)
            color = linecolors[model_idx % len(linecolors)]
            for headline_var in headline_vars:
                fig.add_trace(generate_trace(df, model_name, metric, headline_var, color))

    # The dropdown masks are computed from the traces, so configure last.
    configure_layout(fig, title, metrics, headline_vars)
    save_figure(fig, output_dir, filename)


def main():
    """
    Main driver to generate interactive HTML for metrics display.

    Writes ``control.html`` (deterministic metrics) and ``ensemble.html``
    (probabilistic metrics) to ``website/html``, reading scores from ``logs``.
    Both directories are resolved relative to the parent of the directory that
    contains this script, so the script can be launched from any working
    directory.

    Usage example: `python generate_interactive_html.py`
    """
    # Anchor paths to the script location instead of the current working
    # directory; identical to '../logs' and '../website/html' when run from
    # the script's own directory.
    base_dir = Path(__file__).resolve().parent.parent
    logs_dir = base_dir / 'logs'
    output_dir = base_dir / 'website' / 'html'
    output_dir.mkdir(parents=True, exist_ok=True)

    control_model_names = [
        'climatology', 'panguweather', 'graphcast', 'fourcastnetv2',
        'ecmwf', 'ncep', 'ukmo', 'cma',
    ]
    ensemble_model_names = ['ecmwf', 'ncep', 'ukmo', 'cma']

    # Variable name -> unit; only the names are used by the plotting code.
    headline_vars = {'t-850': 'K', 'z-500': 'gpm', 'q-700': 'g/kg'}

    # Control (deterministic metrics) plot
    plot_metrics(
        metrics=['rmse', 'acc'],
        headline_vars=headline_vars,
        model_names=control_model_names,
        data_path=logs_dir,
        output_dir=output_dir,
        title="Control (Deterministic Metrics)",
        filename="control.html"
    )

    # Ensemble (probabilistic metrics) plot
    plot_metrics(
        metrics=['rmse', 'crpss'],
        headline_vars=headline_vars,
        model_names=ensemble_model_names,
        data_path=logs_dir,
        output_dir=output_dir,
        title="Ensemble (Probabilistic Metrics)",
        filename="ensemble.html",
        ensemble=True
    )


if __name__ == "__main__":
    main()
102 changes: 102 additions & 0 deletions scripts/process_sota.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import argparse
import subprocess
from pathlib import Path
import config

import numpy as np
import xarray as xr

def process_param_levels(ds):
    """Flatten variable/level pairs into one dataset of per-level variables.

    Every data variable in ``ds`` is split along its ``isobaricInhPa``
    coordinate; each slice is stored as ``"<variable>-<level>"`` (level cast
    to int) with the ``isobaricInhPa`` coordinate dropped.

    Returns:
        A new ``xr.Dataset`` holding the flattened variables.
    """
    flattened = {}
    for name in ds.data_vars:
        field = ds[name]
        flattened.update({
            f'{name}-{int(level)}': field.sel(isobaricInhPa=level).drop_vars('isobaricInhPa')
            for level in field.isobaricInhPa.values
        })
    return xr.Dataset(flattened)


def main(args):
    """
    Main driver to generate and process forecasts from S2S data-driven benchmark models.
    For now, we only process bi-weekly forecasts (decorrelated timescale)
    But the script provided here can be easily extended for any resolution, with minor modifications...
    The script relies on the excellent `ai-models` package maintained by the ECMWF
    README: to setup, follow instructions from https://github.com/ecmwf-lab/ai-models.
            you might also need to setup CDS API, follow instructions from https://cds.climate.copernicus.eu/how-to-api
    Usage example:
        (1) Panguweather    : `python process_sota.py --model_name panguweather --years 2022`
        (2) Graphcast       : `python process_sota.py --model_name graphcast --years 2022`
        (3) FourcastNetV2   : `python process_sota.py --model_name fourcastnetv2 --years 2022`

    Args:
        args: Parsed CLI namespace with ``model_name`` (one of the three
            supported models) and ``years`` (iterable of years).

    For each initialization date without an existing ``.zarr`` output, the
    forecast is run through ``ai-models``, post-processed, and written to
    ``config.DATA_DIR/<model_name>``. A date whose ``ai-models`` run fails is
    reported and skipped. Temporary GRIB and index files are always removed.
    """
    assert args.model_name in ['fourcastnetv2', 'panguweather', 'graphcast'], \
        f'Unsupported model_name: {args.model_name}'

    output_dir = config.DATA_DIR / f'{args.model_name}'
    asset_dir = output_dir / 'assets'
    # The temporary GRIB file and the final .zarr stores are placed here, so
    # the directory has to exist before the first forecast is requested.
    output_dir.mkdir(parents=True, exist_ok=True)

    model_code = f'{args.model_name}-small' if args.model_name == 'fourcastnetv2' else args.model_name
    # Initialization dates: the 1st and the 15th of every month ('0101', '0115', ..., '1215')
    mm_dds = [f'{month:02d}{day}' for month in range(1, 13) for day in ('01', '15')]

    for year in args.years:
        year = int(year)

        # NOTE: Feel free to relax this monthly implementation to get e.g., daily forecasts
        # At present (2024): biweekly (decorrelated timescale)
        # The eval_sota.py script should be able to automatically handle different forecast frequency
        for mm_dd in mm_dds:
            output_daily_file = output_dir / f'{args.model_name}_full_1.5deg_{year}{mm_dd}.zarr'

            if output_daily_file.exists():
                continue  # already processed by a previous run

            temp_file = output_dir / f'{args.model_name}_{year}{mm_dd}.grib'

            # Get prediction (paths are passed as plain strings in the argv list)
            command = [
                "ai-models",
                "--input", "cds",
                "--path", str(temp_file),
                "--assets", str(asset_dir),
                "--date", f"{year}{mm_dd}",
                "--time", "0000",
                "--lead-time", "1056",
                model_code
            ]

            result = subprocess.run(command, capture_output=True, text=True)

            try:
                if result.returncode != 0:
                    # Without a forecast there is nothing to process: report
                    # the error and move on to the next initialization date.
                    print(result.stderr)  # Print any error message
                    continue

                # Process prediction; the context manager closes the GRIB file
                # before it is deleted below.
                with xr.open_dataset(
                    temp_file,
                    backend_kwargs={'filter_by_keys': {'typeOfLevel': 'isobaricInhPa'}}
                ) as dataset:
                    dataset['z'] = dataset['z'] / config.G_CONSTANT  # to gpm conversion
                    dataset = process_param_levels(dataset)
                    # every 4th step --> daily resolution (assumes 6-hourly output steps)
                    dataset = dataset.isel(step=slice(4, None, 4))
                    dataset = dataset.coarsen(latitude=6, longitude=6, boundary='trim').mean()
                    dataset = dataset.interp(
                        latitude=np.linspace(dataset.latitude.values.max(), dataset.latitude.values.min(), 121)
                    )

                    # Break down into daily .zarr (cloud-optimized)
                    dataset.to_zarr(output_daily_file)

            finally:
                # Post-process (cleaning up files), also when the run or the processing failed
                idx_files = list(temp_file.parent.glob(f"{temp_file.stem}.*.idx"))
                for idx_file in idx_files:
                    idx_file.unlink()

                if temp_file.exists():
                    temp_file.unlink()


if __name__ == "__main__":
    # Both options are mandatory: without them main() fails later with an
    # opaque error, so let argparse report the missing option up front.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', required=True,
                        help='Provide the name of the model, e.g., fourcastnetv2, graphcast, panguweather...')
    parser.add_argument('--years', nargs='+', required=True,
                        help='Provide the years to evaluate on...')

    args = parser.parse_args()
    main(args)
2 changes: 1 addition & 1 deletion website/_toc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ chapters:
sections:
- file: metrics
- file: baseline

- file: leaderboard
14 changes: 14 additions & 0 deletions website/html/control.html

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions website/html/ensemble.html

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions website/leaderboard.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Leaderboard
We provide a leaderboard for both deterministic and probabilistic models/metrics.
This interactive page is inspired by the WeatherBench 2 design (https://sites.research.google/weatherbench/).
All results are evaluated on the year 2022.

## Deterministic Models
[Interactive Control](https://github.com/leap-stc/ChaosBench/tree/main/website/html/control.html)

## Ensemble Models
[Interactive Ensemble](https://github.com/leap-stc/ChaosBench/tree/main/website/html/ensemble.html)
11 changes: 8 additions & 3 deletions website/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,26 @@ $ chmod +x process.sh
```
**Step 3**: Download the data
```
# Required for inputs and climatology (e.g., normalization)
# Required for inputs and climatology (e.g., for normalization; 1979-)
$ ./process.sh era5
$ ./process.sh lra5
$ ./process.sh oras5
$ ./process.sh climatology
# Optional: control (deterministic) forecasts
# Optional: control (deterministic) forecasts (2018-)
$ ./process.sh ukmo
$ ./process.sh ncep
$ ./process.sh cma
$ ./process.sh ecmwf
# Optional: perturbed (ensemble) forecasts
# Optional: perturbed (ensemble) forecasts (2022-)
$ ./process.sh ukmo_ensemble
$ ./process.sh ncep_ensemble
$ ./process.sh cma_ensemble
$ ./process.sh ecmwf_ensemble
# Optional: state-of-the-art (deterministic) forecasts (2022-)
$ ./process.sh panguweather
$ ./process.sh fourcastnetv2
$ ./process.sh graphcast
```

0 comments on commit c064285

Please sign in to comment.