-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding nbsphinx + setting up sample notebook
- Loading branch information
Kalle Westerling
committed
Oct 11, 2023
1 parent
a225d22
commit a3a9bb8
Showing
4 changed files
with
156 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"nbsphinx": "hidden" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import sys, os\n", | ||
"sys.path.append(os.path.abspath(\"../../../\"))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Tutorial: Quickstart" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Here we will demonstrate a simple example of training a convolutional conditional neural process (ConvCNP) to spatially interpolate ERA5 data.\n", | ||
"\n", | ||
"We can go from imports to predictions with a trained model in less than 30 lines of code!" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import deepsensor.torch\n", | ||
"from deepsensor.data.loader import TaskLoader\n", | ||
"from deepsensor.data.processor import DataProcessor\n", | ||
"from deepsensor.model.convnp import ConvNP\n", | ||
"from deepsensor.train.train import train_epoch\n", | ||
"\n", | ||
"import xarray as xr\n", | ||
"import pandas as pd\n", | ||
"import numpy as np\n", | ||
"\n", | ||
"# Load raw data\n", | ||
"ds_raw = xr.tutorial.open_dataset(\"air_temperature\")\n", | ||
"\n", | ||
"# Normalise data\n", | ||
"data_processor = DataProcessor(x1_name=\"lat\", x1_map=(15, 75), x2_name=\"lon\", x2_map=(200, 330))\n", | ||
"ds = data_processor(ds_raw)\n", | ||
"\n", | ||
"# Set up task loader\n", | ||
"task_loader = TaskLoader(context=ds, target=ds)\n", | ||
"\n", | ||
"# Set up model\n", | ||
"model = ConvNP(data_processor, task_loader)\n", | ||
"\n", | ||
"# Generate training tasks with up to 10% of grid cells passed as context and all grid cells\n", | ||
"# passed as targets\n", | ||
"train_tasks = []\n", | ||
"for date in pd.date_range(\"2013-01-01\", \"2014-11-30\")[::7]:\n", | ||
" task = task_loader(date, context_sampling=np.random.uniform(0.0, 0.1), target_sampling=\"all\")\n", | ||
" train_tasks.append(task)\n", | ||
"\n", | ||
"# Train model\n", | ||
"for epoch in range(10):\n", | ||
" train_epoch(model, train_tasks, progress_bar=True)\n", | ||
"\n", | ||
"# Predict on new task with 10% of context data and a dense grid of target points\n", | ||
"test_task = task_loader(\"2014-12-31\", 0.1)\n", | ||
"mean_ds, std_ds = model.predict(test_task, X_t=ds_raw)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"After training, the model can predict directly to `xarray` in your data's original units and coordinate system:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"mean_ds" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We can also predict directly to `pandas` containing a timeseries of predictions at off-grid locations\n", | ||
"by passing a `numpy` array of target locations to the `X_t` argument of `.predict`:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Predict at two off-grid locations for three days in December 2014\n", | ||
"test_tasks = task_loader(pd.date_range(\"2014-12-01\", \"2014-12-31\"), 0.1)\n", | ||
"mean_df, std_df = model.predict(test_tasks, X_t=np.array([[50, 280], [40, 250]]).T)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"mean_df" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "mr_py38", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
Sphinx==7.2.6 | ||
sphinx-rtd-theme==1.3.0 | ||
sphinx-rtd-theme==1.3.0 | ||
nbsphinx==0.9.3 |