
Commit

Adding nbsphinx + setting up sample notebook
Kalle Westerling authored and tom-andersson committed Oct 11, 2023
1 parent 42502db commit f2bee1f
Showing 4 changed files with 156 additions and 86 deletions.
6 changes: 5 additions & 1 deletion docs/conf.py
@@ -25,7 +25,11 @@
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
 
-extensions = ["sphinx.ext.autodoc", "sphinx.ext.intersphinx"]
+extensions = [
+    "sphinx.ext.autodoc",
+    "sphinx.ext.intersphinx",
+    "nbsphinx",
+]
 
 templates_path = ["_templates"]
 exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
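Beyond registering the extension, nbsphinx reads a few optional `conf.py` settings that control how notebooks are built. A minimal sketch for reference (these values are illustrative defaults, not part of this commit):

```python
# Optional nbsphinx settings for conf.py (illustrative; not part of this commit).

# "auto" executes only notebooks that have no stored outputs;
# "always"/"never" force or skip execution for every notebook.
nbsphinx_execute = "auto"

# Fail the docs build if a notebook cell raises, rather than
# embedding the traceback in the rendered page.
nbsphinx_allow_errors = False

# Per-cell execution timeout in seconds.
nbsphinx_timeout = 60
```
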
149 changes: 149 additions & 0 deletions docs/getting-started/tutorials/quickstart.ipynb
@@ -0,0 +1,149 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"nbsphinx": "hidden"
},
"outputs": [],
"source": [
"import sys, os\n",
"sys.path.append(os.path.abspath(\"../../../\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tutorial: Quickstart"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we demonstrate a simple example of training a convolutional conditional neural process (ConvCNP) to spatially interpolate gridded air temperature data.\n",
"\n",
"We can go from imports to predictions with a trained model in fewer than 30 lines of code!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import deepsensor.torch\n",
"from deepsensor.data.loader import TaskLoader\n",
"from deepsensor.data.processor import DataProcessor\n",
"from deepsensor.model.convnp import ConvNP\n",
"from deepsensor.train.train import train_epoch\n",
"\n",
"import xarray as xr\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"# Load raw data\n",
"ds_raw = xr.tutorial.open_dataset(\"air_temperature\")\n",
"\n",
"# Normalise data\n",
"data_processor = DataProcessor(x1_name=\"lat\", x1_map=(15, 75), x2_name=\"lon\", x2_map=(200, 330))\n",
"ds = data_processor(ds_raw)\n",
"\n",
"# Set up task loader\n",
"task_loader = TaskLoader(context=ds, target=ds)\n",
"\n",
"# Set up model\n",
"model = ConvNP(data_processor, task_loader)\n",
"\n",
"# Generate training tasks with up to 10% of grid cells passed as context and all grid cells\n",
"# passed as targets\n",
"train_tasks = []\n",
"for date in pd.date_range(\"2013-01-01\", \"2014-11-30\")[::7]:\n",
" task = task_loader(date, context_sampling=np.random.uniform(0.0, 0.1), target_sampling=\"all\")\n",
" train_tasks.append(task)\n",
"\n",
"# Train model\n",
"for epoch in range(10):\n",
" train_epoch(model, train_tasks, progress_bar=True)\n",
"\n",
"# Predict on new task with 10% of context data and a dense grid of target points\n",
"test_task = task_loader(\"2014-12-31\", 0.1)\n",
"mean_ds, std_ds = model.predict(test_task, X_t=ds_raw)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After training, the model can predict directly to `xarray` in your data's original units and coordinate system:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mean_ds"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also predict directly to a `pandas` DataFrame containing a time series of predictions at off-grid locations\n",
"by passing a `numpy` array of target locations to the `X_t` argument of `.predict`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Predict at two off-grid locations for each day in December 2014\n",
"test_tasks = task_loader(pd.date_range(\"2014-12-01\", \"2014-12-31\"), 0.1)\n",
"mean_df, std_df = model.predict(test_tasks, X_t=np.array([[50, 280], [40, 250]]).T)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mean_df"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "mr_py38",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
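The `x1_map=(15, 75)` and `x2_map=(200, 330)` arguments in the notebook define the latitude/longitude ranges used to normalise coordinates. As a rough sketch of the kind of linear rescaling this implies (illustrative only; `DataProcessor` handles this internally, and the details below are assumptions):

```python
import numpy as np

def normalise_coords(lat, lon, x1_map=(15, 75), x2_map=(200, 330)):
    """Linearly rescale latitude/longitude into [0, 1] using the map ranges.

    Illustrative sketch only -- DeepSensor's DataProcessor does the real work.
    """
    x1 = (np.asarray(lat, dtype=float) - x1_map[0]) / (x1_map[1] - x1_map[0])
    x2 = (np.asarray(lon, dtype=float) - x2_map[0]) / (x2_map[1] - x2_map[0])
    return x1, x2

# Endpoints map to 0 and 1; the midpoint of each range maps to 0.5.
x1, x2 = normalise_coords([15.0, 45.0, 75.0], [200.0, 265.0, 330.0])
```

Choosing map ranges that cover the data's full extent keeps all normalised coordinates inside the unit square, which is the convention the notebook's `DataProcessor` call appears to follow.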
84 changes: 0 additions & 84 deletions docs/getting-started/tutorials/quickstart.rst

This file was deleted.

3 changes: 2 additions & 1 deletion requirements.docs.txt
@@ -1,2 +1,3 @@
 Sphinx==7.2.6
-sphinx-rtd-theme==1.3.0
+sphinx-rtd-theme==1.3.0
+nbsphinx==0.9.3
