-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add script to process sota; add interactive leaderboard
- Loading branch information
Showing
17 changed files
with
355 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
import pandas as pd | ||
import numpy as np | ||
import plotly.graph_objs as go | ||
from pathlib import Path | ||
|
||
|
||
def generate_trace(df, model_name, metric, headline_var, color): | ||
"""Generate a single trace for a model, metric, and variable.""" | ||
y = df[headline_var].to_numpy().squeeze() | ||
return go.Scatter( | ||
x=np.arange(1, df.shape[0] + 1), | ||
y=y, | ||
mode='lines', | ||
name=f"{model_name} vs. ERA5", | ||
customdata=[f"{metric}-{headline_var}"], | ||
line=dict(width=3, color=color), | ||
visible=(metric == 'rmse' and headline_var == 't-850') | ||
) | ||
|
||
|
||
def update_visibility(fig, selected_metric, selected_variable): | ||
"""Update trace visibility based on selected metric and variable.""" | ||
selected_combo = f"{selected_metric}-{selected_variable}" | ||
return [ | ||
trace.customdata[0] == selected_combo | ||
for trace in fig.data | ||
] | ||
|
||
|
||
def create_dropdown_buttons(metrics, headline_vars, fig): | ||
"""Create dropdown buttons for metric-variable combinations.""" | ||
return [ | ||
{ | ||
"method": "update", | ||
"label": f"{metric.upper()} - {var.capitalize()}", | ||
"args": [ | ||
{"visible": update_visibility(fig, metric, var)} | ||
], | ||
} | ||
for metric in metrics | ||
for var in headline_vars.keys() | ||
] | ||
|
||
|
||
def configure_layout(fig, title, metrics, headline_vars): | ||
"""Configure layout and dropdown menu for the figure.""" | ||
fig.update_layout( | ||
updatemenus=[{ | ||
"buttons": create_dropdown_buttons(metrics, headline_vars, fig), | ||
"direction": "down", | ||
"showactive": True, | ||
"x": 0, | ||
"xanchor": "left", | ||
"y": 1.1, | ||
"yanchor": "top", | ||
"name": "Metric-Variable" | ||
}], | ||
title=title, | ||
xaxis_title="Number of Days Ahead", | ||
hovermode="x unified" | ||
) | ||
|
||
|
||
def save_figure(fig, output_dir, filename): | ||
"""Save the figure as an interactive HTML file.""" | ||
output_path = output_dir / filename | ||
fig.write_html(output_path) | ||
|
||
|
||
def plot_metrics(metrics, headline_vars, model_names, data_path, output_dir, title, filename, ensemble=False): | ||
"""Generate interactive plots for the given metrics, models, and variables.""" | ||
fig = go.Figure() | ||
linecolors = [ | ||
'black', '#1f77b4', '#ff7f0e', '#2ca02c', | ||
'#d62728', '#9467bd', '#8c564b', '#e377c2' | ||
] | ||
|
||
for metric in metrics: | ||
for model_idx, model_name in enumerate(model_names): | ||
color = linecolors[model_idx % len(linecolors)] | ||
for headline_var in headline_vars: | ||
if ensemble: | ||
csv_path = data_path / f"{model_name}_ensemble/eval/{metric}_{model_name}.csv" | ||
else: | ||
csv_path = data_path / f"{model_name}/eval/{metric}_{model_name}.csv" | ||
if csv_path.exists(): | ||
df = pd.read_csv(csv_path) | ||
fig.add_trace(generate_trace(df, model_name, metric, headline_var, color)) | ||
|
||
configure_layout(fig, title, metrics, headline_vars) | ||
save_figure(fig, output_dir, filename) | ||
|
||
|
||
def main(): | ||
""" | ||
Main driver to generate interactive HTML for metrics display | ||
Usage example: `python compute_climatology.py --dataset_name era5 --is_spatial 0` | ||
""" | ||
|
||
output_dir = Path('../website/html') | ||
output_dir.mkdir(parents=True, exist_ok=True) | ||
|
||
control_model_names = [ | ||
'climatology', 'panguweather', 'graphcast', 'fourcastnetv2', | ||
'ecmwf', 'ncep', 'ukmo', 'cma', | ||
] | ||
ensemble_model_names = ['ecmwf', 'ncep', 'ukmo', 'cma'] | ||
|
||
headline_vars = {'t-850': 'K', 'z-500': 'gpm', 'q-700': 'g/kg'} | ||
|
||
# Control (deterministic metrics) plot | ||
plot_metrics( | ||
metrics=['rmse', 'acc'], | ||
headline_vars=headline_vars, | ||
model_names=control_model_names, | ||
data_path=Path('../logs'), | ||
output_dir=output_dir, | ||
title="Control (Deterministic Metrics)", | ||
filename="control.html" | ||
) | ||
|
||
# Ensemble (probabilistic metrics) plot | ||
plot_metrics( | ||
metrics=['rmse', 'crpss'], | ||
headline_vars=headline_vars, | ||
model_names=ensemble_model_names, | ||
data_path=Path('../logs'), | ||
output_dir=output_dir, | ||
title="Ensemble (Probabilistic Metrics)", | ||
filename="ensemble.html", | ||
ensemble=True | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import argparse | ||
import subprocess | ||
from pathlib import Path | ||
import config | ||
|
||
import numpy as np | ||
import xarray as xr | ||
|
||
def process_param_levels(ds): | ||
"""Flatten variable/level and return list of available parameters""" | ||
|
||
flat_ds = {} | ||
|
||
for data_var in list(ds.data_vars): | ||
curr_ds = ds[data_var] | ||
data_levels = curr_ds.isobaricInhPa.values | ||
|
||
for data_level in data_levels: | ||
flat_varname = f'{data_var}-{int(data_level)}' | ||
flat_ds[flat_varname] = curr_ds.sel(isobaricInhPa=data_level).drop_vars('isobaricInhPa') | ||
|
||
return xr.Dataset(flat_ds) | ||
|
||
|
||
def main(args): | ||
""" | ||
Main driver to download S2S data-driven benchmark data based | ||
For now, we only process bi-weekly forecasts (decorrelated timescale) | ||
But the script provided here can be easily extended for any resolution, with minor modifications... | ||
The script relies on the excellent `ai-models` package maintained by the ECMWF | ||
README: to setup, follow instructions from https://github.com/ecmwf-lab/ai-models. | ||
you might also need to setup CDS API, follow instructions from https://cds.climate.copernicus.eu/how-to-api | ||
Usage example: | ||
(1) Panguweather : `python process_sota.py --model_name panguweather --years 2022` | ||
(2) Graphcast : `python process_sota.py --model_name graphcast --years 2022` | ||
(3) FourcastNetV2 : `python process_sota.py --model_name fourcastnetv2 --years 2022` | ||
""" | ||
assert args.model_name in ['fourcastnetv2', 'panguweather', 'graphcast'] | ||
|
||
output_dir = config.DATA_DIR / f'{args.model_name}' | ||
asset_dir = output_dir / 'assets' | ||
|
||
model_code = f'{args.model_name}-small' if args.model_name == 'fourcastnetv2' else args.model_name | ||
mm_dds = ['0101', '0115', '0201', '0215', '0301', '0315', '0401', '0415', '0501', '0515', '0601', '0615', | ||
'0701', '0715', '0801', '0815', '0901', '0915', '1001', '1015', '1101', '1115', '1201', '1215'] | ||
for year in args.years: | ||
year = int(year) | ||
|
||
# NOTE: Feel free to relax this monthly implementation to get e.g., daily forecasts | ||
# At present (2024): biweekly (decorrelated timescale) | ||
# The eval_sota.py script should be able to automatically handle different forecast frequency | ||
for mm_dd in mm_dds: | ||
output_daily_file = output_dir / f'{args.model_name}_full_1.5deg_{year}{mm_dd}.zarr' | ||
|
||
if not output_daily_file.exists(): | ||
|
||
temp_file = output_dir / f'{args.model_name}_{year}{mm_dd}.grib' | ||
|
||
# Get prediction | ||
command = [ | ||
"ai-models", | ||
"--input", "cds", | ||
"--path", temp_file, | ||
"--assets", asset_dir, | ||
"--date", f"{year}{mm_dd}", | ||
"--time", "0000", | ||
"--lead-time", "1056", | ||
model_code | ||
] | ||
|
||
result = subprocess.run(command, capture_output=True, text=True) | ||
if result.returncode != 0: | ||
print(result.stderr) # Print any error message | ||
|
||
# Process prediction | ||
dataset = xr.open_dataset(temp_file, backend_kwargs={'filter_by_keys': {'typeOfLevel': 'isobaricInhPa'}}) | ||
dataset['z'] = dataset['z'] / config.G_CONSTANT # to gpm conversion | ||
dataset = process_param_levels(dataset) | ||
dataset = dataset.isel(step=slice(4, None, 4)) # every 24th-hour --> daily resolution | ||
dataset = dataset.coarsen(latitude=6, longitude=6, boundary='trim').mean() | ||
dataset = dataset.interp(latitude=np.linspace(dataset.latitude.values.max(), dataset.latitude.values.min(), 121)) | ||
|
||
# Break down into daily .zarr (cloud-optimized) | ||
dataset.to_zarr(output_daily_file) | ||
|
||
# Post-process (cleaning up files) | ||
idx_files = list(temp_file.parent.glob(f"{temp_file.stem}.*.idx")) | ||
for idx_file in idx_files: | ||
idx_file.unlink() | ||
|
||
temp_file.unlink() | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--model_name', help='Provide the name of the model, e.g., fourcastnetv2, graphcast, panguweather...') | ||
parser.add_argument('--years', nargs='+', help='Provide the years to evaluate on...') | ||
|
||
args = parser.parse_args() | ||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,4 +14,4 @@ chapters: | |
sections: | ||
- file: metrics | ||
- file: baseline | ||
|
||
- file: leaderboard |
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Leaderboard | ||
We provide leaderboard for both deterministic and probabilistic models/metrics. | ||
This interactive page is inspired by the WeatherBench 2 design (https://sites.research.google/weatherbench/) | ||
All results are evaluated on the year 2022. | ||
|
||
## Deterministic Models | ||
[Interactive Control](https://github.com/leap-stc/ChaosBench/tree/main/website/html/control.html) | ||
|
||
## Ensemble Models | ||
[Interactive Ensemble](https://github.com/leap-stc/ChaosBench/tree/main/website/html/ensemble.html) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters